1use rill_core::math::Transcendental;
2use rill_core::traits::{SignalNode, NodeId, NodeMetadata, NodeParams, NodeVariant};
3use std::collections::HashMap;
4
/// Error returned by [`NodeRegistry::construct`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RegistryError {
    /// No constructor is registered under the requested type name (carried as the payload).
    UnknownType(String),
}

impl std::fmt::Display for RegistryError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnknownType(name) => write!(f, "unknown node type: {name}"),
        }
    }
}

impl std::error::Error for RegistryError {}
25
26pub trait NodeConstructor<T: Transcendental, const BUF_SIZE: usize>: Send + Sync {
37 fn type_name(&self) -> &'static str;
39
40 fn construct(&self, id: NodeId, params: &NodeParams) -> NodeVariant<T, BUF_SIZE>;
49}
50
51pub struct NodeRegistry<T: Transcendental, const BUF_SIZE: usize> {
65 entries: HashMap<&'static str, Box<dyn NodeConstructor<T, BUF_SIZE>>>,
66}
67
68impl<T: Transcendental, const BUF_SIZE: usize> Default for NodeRegistry<T, BUF_SIZE> {
69 fn default() -> Self {
70 Self::new()
71 }
72}
73
74impl<T: Transcendental, const BUF_SIZE: usize> NodeRegistry<T, BUF_SIZE> {
75 pub fn new() -> Self {
77 Self {
78 entries: HashMap::new(),
79 }
80 }
81
82 pub fn register(&mut self, ctor: impl NodeConstructor<T, BUF_SIZE> + 'static) {
88 let name = ctor.type_name();
89 self.entries.insert(name, Box::new(ctor));
90 }
91
92 pub fn register_fn(
97 &mut self,
98 type_name: &'static str,
99 f: impl Fn(NodeId, &NodeParams) -> NodeVariant<T, BUF_SIZE> + Send + Sync + 'static,
100 ) {
101 self.entries.insert(
102 type_name,
103 Box::new(ClosureCtor {
104 type_name,
105 f: Box::new(f),
106 }),
107 );
108 }
109
110 pub fn construct(
115 &self,
116 type_name: &str,
117 id: NodeId,
118 params: &NodeParams,
119 ) -> Result<NodeVariant<T, BUF_SIZE>, RegistryError> {
120 self.entries
121 .get(type_name)
122 .ok_or_else(|| RegistryError::UnknownType(type_name.to_string()))
123 .map(|ctor| ctor.construct(id, params))
124 }
125
126 pub fn contains(&self, type_name: &str) -> bool {
128 self.entries.contains_key(type_name)
129 }
130
131 pub fn list_types(&self) -> Vec<&'static str> {
133 self.entries.keys().copied().collect()
134 }
135
136 pub fn len(&self) -> usize {
138 self.entries.len()
139 }
140
141 pub fn is_empty(&self) -> bool {
143 self.entries.is_empty()
144 }
145
146 pub fn metadata(&self, type_name: &str) -> Option<NodeMetadata> {
152 self.entries.get(type_name).map(|ctor| {
153 let dummy = NodeParams::new(44100.0);
154 let variant = ctor.construct(NodeId(u32::MAX), &dummy);
155 variant.metadata()
156 })
157 }
158}
159
160struct ClosureCtor<T: Transcendental, const BUF_SIZE: usize> {
165 type_name: &'static str,
166 f: Box<dyn Fn(NodeId, &NodeParams) -> NodeVariant<T, BUF_SIZE> + Send + Sync>,
167}
168
169impl<T: Transcendental, const BUF_SIZE: usize> NodeConstructor<T, BUF_SIZE>
170 for ClosureCtor<T, BUF_SIZE>
171{
172 fn type_name(&self) -> &'static str {
173 self.type_name
174 }
175
176 fn construct(&self, id: NodeId, params: &NodeParams) -> NodeVariant<T, BUF_SIZE> {
177 (self.f)(id, params)
178 }
179}
180
/// Registers a closure-based constructor on a registry.
///
/// `node_ctor!(registry, "type/name", |id, params| ...)` expands to
/// `registry.register_fn("type/name", |id, params| ...);`. A trailing comma
/// after the constructor expression is accepted.
#[macro_export]
macro_rules! node_ctor {
    ($registry:expr, $type_name:expr, $ctor:expr $(,)?) => {
        $registry.register_fn($type_name, $ctor);
    };
}
210
#[cfg(test)]
mod tests {
    use super::*;
    use rill_core::time::ClockTick;
    use rill_core::traits::node::NodeState;
    use rill_core::traits::port::Port;
    use rill_core::traits::{
        NodeCategory, ParamValue, ParameterId, ProcessResult, Processor, Source,
    };

    /// Minimal node for exercising the registry.
    ///
    /// It has a single output port and configurable metadata. All processing
    /// methods are no-ops.
    struct TestSource<T: Transcendental, const B: usize> {
        id: NodeId,
        state: NodeState<T, B>,
        output: Port<T, B>,
        meta_name: &'static str,
        meta_cat: NodeCategory,
    }

    impl<T: Transcendental, const B: usize> TestSource<T, B> {
        fn new() -> Self {
            Self {
                id: NodeId(0),
                state: NodeState::new(44100.0),
                output: Port::output(NodeId(0), 0, "out"),
                meta_name: "TestSource",
                meta_cat: NodeCategory::Source,
            }
        }

        fn set_id_and_init(&mut self, id: NodeId, sample_rate: f32) {
            self.id = id;
            self.state.sample_rate = sample_rate;
        }
    }

    impl<T: Transcendental, const B: usize> SignalNode<T, B> for TestSource<T, B> {
        fn metadata(&self) -> NodeMetadata {
            NodeMetadata::new(self.meta_name, self.meta_cat)
        }
        fn init(&mut self, sample_rate: f32) {
            self.state.sample_rate = sample_rate;
        }
        fn reset(&mut self) {}
        fn get_parameter(&self, _: &ParameterId) -> Option<ParamValue> {
            None
        }
        fn set_parameter(&mut self, _: &ParameterId, _: ParamValue) -> ProcessResult<()> {
            Ok(())
        }
        fn id(&self) -> NodeId {
            self.id
        }
        fn set_id(&mut self, id: NodeId) {
            self.id = id;
        }
        fn input_port(&self, _: usize) -> Option<&Port<T, B>> {
            None
        }
        fn input_port_mut(&mut self, _: usize) -> Option<&mut Port<T, B>> {
            None
        }
        fn output_port(&self, index: usize) -> Option<&Port<T, B>> {
            if index == 0 {
                Some(&self.output)
            } else {
                None
            }
        }
        fn output_port_mut(&mut self, index: usize) -> Option<&mut Port<T, B>> {
            if index == 0 {
                Some(&mut self.output)
            } else {
                None
            }
        }
        fn control_port(&self, _: usize) -> Option<&Port<T, B>> {
            None
        }
        fn control_port_mut(&mut self, _: usize) -> Option<&mut Port<T, B>> {
            None
        }
        fn state(&self) -> &NodeState<T, B> {
            &self.state
        }
        fn state_mut(&mut self) -> &mut NodeState<T, B> {
            &mut self.state
        }
    }

    impl<T: Transcendental, const B: usize> Source<T, B> for TestSource<T, B> {
        fn generate(&mut self, _: &ClockTick, _: &[T], _: &[ClockTick]) -> ProcessResult<()> {
            Ok(())
        }
    }

    impl<T: Transcendental, const B: usize> Processor<T, B> for TestSource<T, B> {
        fn process(
            &mut self,
            _: &ClockTick,
            _: &[&[T; B]],
            _: &[T],
            _: &[ClockTick],
            _: &[&[T; B]],
        ) -> ProcessResult<()> {
            Ok(())
        }
        fn latency(&self) -> usize {
            0
        }
    }

    /// Registers as "test/source" and builds a `NodeVariant::Source`.
    struct TestSourceCtor;
    impl<T: Transcendental, const B: usize> NodeConstructor<T, B> for TestSourceCtor {
        fn type_name(&self) -> &'static str {
            "test/source"
        }
        fn construct(&self, id: NodeId, params: &NodeParams) -> NodeVariant<T, B> {
            let mut node = TestSource::<T, B>::new();
            node.set_id_and_init(id, params.sample_rate);
            NodeVariant::Source(Box::new(node))
        }
    }

    /// Registers as "test/processor".
    ///
    /// Builds a `NodeVariant::Processor` whose metadata name is "Noop".
    struct TestProcessorCtor;
    impl<T: Transcendental, const B: usize> NodeConstructor<T, B> for TestProcessorCtor {
        fn type_name(&self) -> &'static str {
            "test/processor"
        }
        fn construct(&self, id: NodeId, params: &NodeParams) -> NodeVariant<T, B> {
            let mut node = TestSource::<T, B>::new();
            node.meta_name = "Noop";
            node.meta_cat = NodeCategory::Processor;
            node.set_id_and_init(id, params.sample_rate);
            NodeVariant::Processor(Box::new(node))
        }
    }

    #[test]
    fn test_registry_empty() {
        let registry = NodeRegistry::<f32, 64>::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_registry_default_is_empty() {
        let registry = NodeRegistry::<f32, 64>::default();
        assert!(registry.is_empty());
        assert!(registry.list_types().is_empty());
    }

    #[test]
    fn test_registry_register_and_construct() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register(TestSourceCtor);

        assert!(registry.contains("test/source"));
        assert_eq!(registry.len(), 1);

        let params = NodeParams::new(48000.0);
        let variant = registry
            .construct("test/source", NodeId(42), &params)
            .expect("should construct");

        assert!(
            matches!(&variant, NodeVariant::Source(_)),
            "expected Source variant"
        );
        assert_eq!(variant.metadata().name, "TestSource");
    }

    #[test]
    fn test_registry_construct_processor() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register(TestProcessorCtor);

        let params = NodeParams::new(44100.0);
        let variant = registry
            .construct("test/processor", NodeId(7), &params)
            .expect("should construct processor");

        assert!(
            matches!(&variant, NodeVariant::Processor(_)),
            "expected Processor variant"
        );
        assert_eq!(variant.metadata().name, "Noop");
    }

    #[test]
    fn test_registry_unknown_type() {
        let registry = NodeRegistry::<f32, 64>::new();
        let params = NodeParams::new(44100.0);
        match registry.construct("nonexistent", NodeId(0), &params) {
            Err(RegistryError::UnknownType(name)) => assert_eq!(name, "nonexistent"),
            _ => panic!("expected UnknownType error"),
        }
    }

    #[test]
    fn test_registry_error_display() {
        let err = RegistryError::UnknownType("foo/bar".to_string());
        assert_eq!(err.to_string(), "unknown node type: foo/bar");
    }

    #[test]
    fn test_registry_register_fn() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register_fn("test/fn_ctor", |id, params| {
            let mut node = TestSource::<f32, 64>::new();
            node.set_id(id);
            node.init(params.sample_rate);
            NodeVariant::Source(Box::new(node))
        });

        assert!(registry.contains("test/fn_ctor"));
        let params = NodeParams::new(44100.0);
        let variant = registry
            .construct("test/fn_ctor", NodeId(1), &params)
            .expect("should construct from fn");
        assert!(
            matches!(&variant, NodeVariant::Source(_)),
            "expected Source variant"
        );
    }

    #[test]
    fn test_node_ctor_macro_registers() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        node_ctor!(registry, "test/macro", |id, params| {
            let mut node = TestSource::<f32, 64>::new();
            node.set_id_and_init(id, params.sample_rate);
            NodeVariant::Source(Box::new(node))
        });
        assert!(registry.contains("test/macro"));
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_list_types() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register(TestSourceCtor);
        registry.register(TestProcessorCtor);

        let mut types = registry.list_types();
        types.sort();
        assert_eq!(types, vec!["test/processor", "test/source"]);
    }

    #[test]
    fn test_registry_replace() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register(TestSourceCtor);
        assert_eq!(registry.len(), 1);

        // Re-registering the same type name replaces the entry rather than adding one.
        registry.register(TestSourceCtor);
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_metadata() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register(TestSourceCtor);

        let meta = registry
            .metadata("test/source")
            .expect("metadata for a registered type");
        assert_eq!(meta.name, "TestSource");
    }

    #[test]
    fn test_registry_metadata_unknown_type() {
        let registry = NodeRegistry::<f32, 64>::new();
        assert!(registry.metadata("nonexistent").is_none());
        assert!(!registry.contains("nonexistent"));
    }

    #[test]
    fn test_construct_with_params() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register_fn("test/with_params", |id, params| {
            let freq = params.get_f32("frequency", 440.0);
            assert_eq!(freq, 220.0);
            let amp = params.get_f32("amplitude", 0.5);
            assert_eq!(amp, 0.8);

            let mut node = TestSource::<f32, 64>::new();
            node.set_id(id);
            node.init(params.sample_rate);
            NodeVariant::Source(Box::new(node))
        });

        let params = NodeParams::new(44100.0)
            .with("frequency", ParamValue::Float(220.0))
            .with("amplitude", ParamValue::Float(0.8));
        let result = registry.construct("test/with_params", NodeId(0), &params);
        assert!(result.is_ok());
    }
}