1use std::collections::HashMap;
2use std::sync::Arc;
3
4use rill_core::math::Transcendental;
5use rill_core::traits::{Node, NodeId, NodeMetadata, NodeVariant, Params};
6
/// Errors reported by [`NodeFactory`] when a node cannot be built.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare errors directly
/// (for example with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RegistryError {
    /// No constructor is registered under the requested type name.
    ///
    /// The payload is the name that was looked up.
    UnknownType(String),
}
17
18impl std::fmt::Display for RegistryError {
19 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20 match self {
21 Self::UnknownType(name) => write!(f, "unknown node type: {name}"),
22 }
23 }
24}
25
/// Lets [`RegistryError`] be used as a standard error type (for example behind
/// `Box<dyn std::error::Error>`); every trait method keeps its default implementation.
impl std::error::Error for RegistryError {}
27
/// A recipe for building one kind of node.
///
/// Implementors must be `Send + Sync`. [`NodeFactory`] stores them boxed,
/// keyed by a type name.
pub trait NodeConstructor<T: Transcendental, const BUF_SIZE: usize>: Send + Sync {
    /// Returns the type name this constructor is registered under when passed
    /// to `NodeFactory::register`.
    fn type_name(&self) -> &'static str;

    /// Builds a node for `id`, configured from `params`.
    fn construct(&self, id: NodeId, params: &Params) -> NodeVariant<T, BUF_SIZE>;

    /// Returns a boxed copy of this constructor.
    ///
    /// `NodeFactory`'s `Clone` impl calls this for every registered entry.
    fn clone_box(&self) -> Box<dyn NodeConstructor<T, BUF_SIZE>>;
}
55
/// Registry that maps node type names to the constructors that build them.
pub struct NodeFactory<T: Transcendental, const BUF_SIZE: usize> {
    /// Registered constructors, keyed by type name.
    entries: HashMap<&'static str, Box<dyn NodeConstructor<T, BUF_SIZE>>>,
}
72
73impl<T: Transcendental, const BUF_SIZE: usize> Clone for NodeFactory<T, BUF_SIZE> {
74 fn clone(&self) -> Self {
75 Self {
76 entries: self
77 .entries
78 .iter()
79 .map(|(k, v)| (*k, v.clone_box()))
80 .collect(),
81 }
82 }
83}
84
85impl<T: Transcendental, const BUF_SIZE: usize> Default for NodeFactory<T, BUF_SIZE> {
86 fn default() -> Self {
87 Self::new()
88 }
89}
90
91impl<T: Transcendental, const BUF_SIZE: usize> NodeFactory<T, BUF_SIZE> {
92 pub fn new() -> Self {
94 Self {
95 entries: HashMap::new(),
96 }
97 }
98
99 pub fn register(&mut self, ctor: impl NodeConstructor<T, BUF_SIZE> + 'static) {
105 let name = ctor.type_name();
106 self.entries.insert(name, Box::new(ctor));
107 }
108
109 pub fn register_fn(
114 &mut self,
115 type_name: &'static str,
116 f: impl Fn(NodeId, &Params) -> NodeVariant<T, BUF_SIZE> + Send + Sync + 'static,
117 ) {
118 self.entries.insert(
119 type_name,
120 Box::new(ClosureCtor {
121 type_name,
122 f: Arc::new(f),
123 }),
124 );
125 }
126
127 pub fn construct(
132 &self,
133 type_name: &str,
134 id: NodeId,
135 params: &Params,
136 ) -> Result<NodeVariant<T, BUF_SIZE>, RegistryError> {
137 self.entries
138 .get(type_name)
139 .ok_or_else(|| RegistryError::UnknownType(type_name.to_string()))
140 .map(|ctor| ctor.construct(id, params))
141 }
142
143 pub fn contains(&self, type_name: &str) -> bool {
145 self.entries.contains_key(type_name)
146 }
147
148 pub fn list_types(&self) -> Vec<&'static str> {
150 self.entries.keys().copied().collect()
151 }
152
153 pub fn len(&self) -> usize {
155 self.entries.len()
156 }
157
158 pub fn is_empty(&self) -> bool {
160 self.entries.is_empty()
161 }
162
163 pub fn metadata(&self, type_name: &str) -> Option<NodeMetadata> {
169 self.entries.get(type_name).map(|ctor| {
170 let dummy = Params::new(44100.0);
171 let variant = ctor.construct(NodeId(u32::MAX), &dummy);
172 variant.metadata()
173 })
174 }
175}
176
/// Adapter that lets a plain closure act as a [`NodeConstructor`].
///
/// Created by `NodeFactory::register_fn`.
// NOTE(review): the allow presumably exists for the long `Arc<dyn Fn(..)>`
// type of the `f` field — confirm before removing.
#[allow(clippy::type_complexity)]
struct ClosureCtor<T: Transcendental, const BUF_SIZE: usize> {
    /// Type name reported by `type_name()`.
    type_name: &'static str,
    /// The wrapped closure, held in an `Arc` so `clone_box` can share it
    /// instead of requiring the closure itself to be `Clone`.
    f: Arc<dyn Fn(NodeId, &Params) -> NodeVariant<T, BUF_SIZE> + Send + Sync>,
}
186
187impl<T: Transcendental, const BUF_SIZE: usize> NodeConstructor<T, BUF_SIZE>
188 for ClosureCtor<T, BUF_SIZE>
189{
190 fn type_name(&self) -> &'static str {
191 self.type_name
192 }
193
194 fn construct(&self, id: NodeId, params: &Params) -> NodeVariant<T, BUF_SIZE> {
195 (self.f)(id, params)
196 }
197
198 fn clone_box(&self) -> Box<dyn NodeConstructor<T, BUF_SIZE>> {
199 Box::new(ClosureCtor {
200 type_name: self.type_name,
201 f: self.f.clone(),
202 })
203 }
204}
205
/// Registers a closure-based constructor on a registry.
///
/// Forwards to `$registry.register_fn($type_name, $ctor)`. A trailing comma
/// after the last argument is accepted.
///
/// The expansion is a block, so the macro works both as a statement and in
/// expression position (where it evaluates to `()`), without tripping the
/// `semicolon_in_expressions_from_macros` lint.
///
/// ```ignore
/// node_ctor!(registry, "my/node", |id, params| build_node(id, params));
/// ```
#[macro_export]
macro_rules! node_ctor {
    ($registry:expr, $type_name:expr, $ctor:expr $(,)?) => {{
        $registry.register_fn($type_name, $ctor);
    }};
}
235
236#[cfg(test)]
241mod tests {
242 use super::*;
243
244 use rill_core::time::RenderContext;
245 use rill_core::traits::node::NodeState;
246 use rill_core::traits::port::Port;
247 use rill_core::traits::NodeCategory;
248 use rill_core::traits::Processor;
249 use rill_core::traits::Source;
250 use rill_core::traits::{ParamValue, ProcessResult};
251
    /// Minimal node used by the registry tests; its processing methods are no-ops.
    struct TestSource<T: Transcendental, const B: usize> {
        /// Identifier returned by `id()`.
        id: NodeId,
        /// Node state; the tests only ever write its `sample_rate`.
        state: NodeState<T, B>,
        /// The single output port, exposed at index 0.
        output: Port<T, B>,
        /// Name reported through `metadata()`.
        meta_name: &'static str,
        /// Category reported through `metadata()`.
        meta_cat: NodeCategory,
    }
261
262 impl<T: Transcendental, const B: usize> TestSource<T, B> {
263 fn new() -> Self {
264 Self {
265 id: NodeId(0),
266 state: NodeState::new(44100.0),
267 output: Port::output(NodeId(0), 0, "out"),
268 meta_name: "TestSource",
269 meta_cat: NodeCategory::Source,
270 }
271 }
272
273 fn set_id_and_init(&mut self, id: NodeId, sample_rate: f32) {
274 self.id = id;
275 self.state.sample_rate = sample_rate;
276 }
277 }
278
279 impl<T: Transcendental, const B: usize> Node<T, B> for TestSource<T, B> {
280 fn metadata(&self) -> rill_core::traits::NodeMetadata {
281 rill_core::traits::NodeMetadata::new(self.meta_name, self.meta_cat)
282 }
283 fn init(&mut self, sample_rate: f32) {
284 self.state.sample_rate = sample_rate;
285 }
286 fn reset(&mut self) {}
287 fn get_parameter(
288 &self,
289 _: &rill_core::traits::ParameterId,
290 ) -> Option<rill_core::traits::ParamValue> {
291 None
292 }
293 fn set_parameter(
294 &mut self,
295 _: &rill_core::traits::ParameterId,
296 _: rill_core::traits::ParamValue,
297 ) -> ProcessResult<()> {
298 Ok(())
299 }
300 fn id(&self) -> NodeId {
301 self.id
302 }
303 fn set_id(&mut self, id: NodeId) {
304 self.id = id;
305 }
306 fn input_port(&self, _: usize) -> Option<&Port<T, B>> {
307 None
308 }
309 fn input_port_mut(&mut self, _: usize) -> Option<&mut Port<T, B>> {
310 None
311 }
312 fn output_port(&self, index: usize) -> Option<&Port<T, B>> {
313 if index == 0 {
314 Some(&self.output)
315 } else {
316 None
317 }
318 }
319 fn output_port_mut(&mut self, index: usize) -> Option<&mut Port<T, B>> {
320 if index == 0 {
321 Some(&mut self.output)
322 } else {
323 None
324 }
325 }
326 fn control_port(&self, _: usize) -> Option<&Port<T, B>> {
327 None
328 }
329 fn control_port_mut(&mut self, _: usize) -> Option<&mut Port<T, B>> {
330 None
331 }
332 fn state(&self) -> &NodeState<T, B> {
333 &self.state
334 }
335 fn state_mut(&mut self) -> &mut NodeState<T, B> {
336 &mut self.state
337 }
338 }
339
    /// No-op `Source` implementation; presumably what allows `TestSource` to be
    /// boxed into `NodeVariant::Source` by the test constructors.
    impl<T: Transcendental, const B: usize> Source<T, B> for TestSource<T, B> {
        /// Ignores every argument and reports success without doing any work.
        fn generate(
            &mut self,
            _: &RenderContext,
            _: &[T],
            _: &[RenderContext],
            _: &rill_core::time::ClockTick,
        ) -> ProcessResult<()> {
            Ok(())
        }
    }
351
    /// No-op `Processor` implementation; presumably what allows `TestSource` to
    /// be boxed into `NodeVariant::Processor` by `TestProcessorCtor`.
    impl<T: Transcendental, const B: usize> Processor<T, B> for TestSource<T, B> {
        /// Ignores every argument and reports success without doing any work.
        fn process(
            &mut self,
            _: &RenderContext,
            _: &[&[T; B]],
            _: &[T],
            _: &[RenderContext],
            _: &[&[T; B]],
        ) -> ProcessResult<()> {
            Ok(())
        }
        /// Always reports zero latency.
        fn latency(&self) -> usize {
            0
        }
    }
367
368 struct TestSourceCtor;
369 impl<T: Transcendental, const B: usize> NodeConstructor<T, B> for TestSourceCtor {
370 fn type_name(&self) -> &'static str {
371 "test/source"
372 }
373 fn construct(&self, id: NodeId, params: &Params) -> NodeVariant<T, B> {
374 let mut node = TestSource::<T, B>::new();
375 node.set_id_and_init(id, params.sample_rate);
376 NodeVariant::Source(Box::new(node))
377 }
378 fn clone_box(&self) -> Box<dyn NodeConstructor<T, B>> {
379 Box::new(Self)
380 }
381 }
382
383 struct TestProcessorCtor;
384 impl<T: Transcendental, const B: usize> NodeConstructor<T, B> for TestProcessorCtor {
385 fn type_name(&self) -> &'static str {
386 "test/processor"
387 }
388 fn construct(&self, id: NodeId, params: &Params) -> NodeVariant<T, B> {
389 let mut node = TestSource::<T, B>::new();
390 node.meta_name = "Noop";
391 node.meta_cat = NodeCategory::Processor;
392 node.set_id_and_init(id, params.sample_rate);
393 NodeVariant::Processor(Box::new(node))
394 }
395 fn clone_box(&self) -> Box<dyn NodeConstructor<T, B>> {
396 Box::new(Self)
397 }
398 }
399
400 #[test]
403 fn test_registry_empty() {
404 let registry = NodeFactory::<f32, 64>::new();
405 assert!(registry.is_empty());
406 assert_eq!(registry.len(), 0);
407 }
408
409 #[test]
410 fn test_registry_register_and_construct() {
411 let mut registry = NodeFactory::<f32, 64>::new();
412 registry.register(TestSourceCtor);
413
414 assert!(registry.contains("test/source"));
415 assert_eq!(registry.len(), 1);
416
417 let params = Params::new(48000.0);
418 let variant = registry
419 .construct("test/source", NodeId(42), ¶ms)
420 .expect("should construct");
421
422 match &variant {
423 NodeVariant::Source(_) => {}
424 _ => panic!("expected Source variant"),
425 }
426
427 assert_eq!(variant.metadata().name, "TestSource");
429 }
430
431 #[test]
432 fn test_registry_unknown_type() {
433 let registry = NodeFactory::<f32, 64>::new();
434 let params = Params::new(44100.0);
435 let result = registry.construct("nonexistent", NodeId(0), ¶ms);
436 assert!(result.is_err());
437 match result {
438 Err(RegistryError::UnknownType(name)) => assert_eq!(name, "nonexistent"),
439 _ => panic!("expected UnknownType error"),
440 }
441 }
442
443 #[test]
444 fn test_registry_register_fn() {
445 let mut registry = NodeFactory::<f32, 64>::new();
446 registry.register_fn("test/fn_ctor", |id, params| {
447 let mut node = TestSource::<f32, 64>::new();
448 node.set_id(id);
449 node.init(params.sample_rate);
450 NodeVariant::Source(Box::new(node))
451 });
452
453 assert!(registry.contains("test/fn_ctor"));
454 let params = Params::new(44100.0);
455 let variant = registry
456 .construct("test/fn_ctor", NodeId(1), ¶ms)
457 .expect("should construct from fn");
458 match variant {
459 NodeVariant::Source(_) => {}
460 _ => panic!("expected Source variant"),
461 }
462 }
463
464 #[test]
465 fn test_registry_list_types() {
466 let mut registry = NodeFactory::<f32, 64>::new();
467 registry.register(TestSourceCtor);
468 registry.register(TestProcessorCtor);
469
470 let mut types = registry.list_types();
471 types.sort();
472 assert_eq!(types, vec!["test/processor", "test/source"]);
473 }
474
475 #[test]
476 fn test_registry_replace() {
477 let mut registry = NodeFactory::<f32, 64>::new();
478 registry.register(TestSourceCtor);
479 assert_eq!(registry.len(), 1);
480
481 registry.register(TestSourceCtor);
483 assert_eq!(registry.len(), 1);
484 }
485
486 #[test]
487 fn test_registry_metadata() {
488 let mut registry = NodeFactory::<f32, 64>::new();
489 registry.register(TestSourceCtor);
490
491 let meta = registry.metadata("test/source");
492 assert!(meta.is_some());
493 assert_eq!(meta.unwrap().name, "TestSource");
494 }
495
496 #[test]
497 fn test_construct_with_params() {
498 let mut registry = NodeFactory::<f32, 64>::new();
499 registry.register_fn("test/with_params", |id, params| {
500 let freq = params.get_f32("frequency", 440.0);
501 assert_eq!(freq, 220.0);
502 let amp = params.get_f32("amplitude", 0.5);
503 assert_eq!(amp, 0.8);
504
505 let mut node = TestSource::<f32, 64>::new();
506 node.set_id(id);
507 node.init(params.sample_rate);
508 NodeVariant::Source(Box::new(node))
509 });
510
511 let params = Params::new(44100.0)
512 .with("frequency", ParamValue::Float(220.0))
513 .with("amplitude", ParamValue::Float(0.8));
514 let result = registry.construct("test/with_params", NodeId(0), ¶ms);
515 assert!(result.is_ok());
516 }
517}