1use std::collections::HashMap;
2use std::sync::Arc;
3
4use rill_core::math::Transcendental;
5use rill_core::traits::{Node, NodeId, NodeMetadata, NodeVariant, Params};
6
/// Errors reported by the node registry.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare errors directly
/// (for example with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RegistryError {
    /// No constructor is registered under the requested type name.
    ///
    /// The payload is the name that was looked up.
    UnknownType(String),
}
17
18impl std::fmt::Display for RegistryError {
19 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20 match self {
21 Self::UnknownType(name) => write!(f, "unknown node type: {name}"),
22 }
23 }
24}
25
26impl std::error::Error for RegistryError {}
27
28pub trait NodeConstructor<T: Transcendental, const BUF_SIZE: usize>: Send + Sync {
39 fn type_name(&self) -> &'static str;
41
42 fn construct(&self, id: NodeId, params: &Params) -> NodeVariant<T, BUF_SIZE>;
51
52 fn clone_box(&self) -> Box<dyn NodeConstructor<T, BUF_SIZE>>;
54}
55
56pub struct NodeFactory<T: Transcendental, const BUF_SIZE: usize> {
70 entries: HashMap<&'static str, Box<dyn NodeConstructor<T, BUF_SIZE>>>,
71}
72
73impl<T: Transcendental, const BUF_SIZE: usize> Clone for NodeFactory<T, BUF_SIZE> {
74 fn clone(&self) -> Self {
75 Self {
76 entries: self
77 .entries
78 .iter()
79 .map(|(k, v)| (*k, v.clone_box()))
80 .collect(),
81 }
82 }
83}
84
85impl<T: Transcendental, const BUF_SIZE: usize> Default for NodeFactory<T, BUF_SIZE> {
86 fn default() -> Self {
87 Self::new()
88 }
89}
90
91impl<T: Transcendental, const BUF_SIZE: usize> NodeFactory<T, BUF_SIZE> {
92 pub fn new() -> Self {
94 Self {
95 entries: HashMap::new(),
96 }
97 }
98
99 pub fn register(&mut self, ctor: impl NodeConstructor<T, BUF_SIZE> + 'static) {
105 let name = ctor.type_name();
106 self.entries.insert(name, Box::new(ctor));
107 }
108
109 pub fn register_fn(
114 &mut self,
115 type_name: &'static str,
116 f: impl Fn(NodeId, &Params) -> NodeVariant<T, BUF_SIZE> + Send + Sync + 'static,
117 ) {
118 self.entries.insert(
119 type_name,
120 Box::new(ClosureCtor {
121 type_name,
122 f: Arc::new(f),
123 }),
124 );
125 }
126
127 pub fn construct(
132 &self,
133 type_name: &str,
134 id: NodeId,
135 params: &Params,
136 ) -> Result<NodeVariant<T, BUF_SIZE>, RegistryError> {
137 self.entries
138 .get(type_name)
139 .ok_or_else(|| RegistryError::UnknownType(type_name.to_string()))
140 .map(|ctor| ctor.construct(id, params))
141 }
142
143 pub fn contains(&self, type_name: &str) -> bool {
145 self.entries.contains_key(type_name)
146 }
147
148 pub fn list_types(&self) -> Vec<&'static str> {
150 self.entries.keys().copied().collect()
151 }
152
153 pub fn len(&self) -> usize {
155 self.entries.len()
156 }
157
158 pub fn is_empty(&self) -> bool {
160 self.entries.is_empty()
161 }
162
163 pub fn metadata(&self, type_name: &str) -> Option<NodeMetadata> {
169 self.entries.get(type_name).map(|ctor| {
170 let dummy = Params::new(44100.0);
171 let variant = ctor.construct(NodeId(u32::MAX), &dummy);
172 variant.metadata()
173 })
174 }
175}
176
177#[allow(clippy::type_complexity)]
182struct ClosureCtor<T: Transcendental, const BUF_SIZE: usize> {
183 type_name: &'static str,
184 f: Arc<dyn Fn(NodeId, &Params) -> NodeVariant<T, BUF_SIZE> + Send + Sync>,
185}
186
187impl<T: Transcendental, const BUF_SIZE: usize> NodeConstructor<T, BUF_SIZE>
188 for ClosureCtor<T, BUF_SIZE>
189{
190 fn type_name(&self) -> &'static str {
191 self.type_name
192 }
193
194 fn construct(&self, id: NodeId, params: &Params) -> NodeVariant<T, BUF_SIZE> {
195 (self.f)(id, params)
196 }
197
198 fn clone_box(&self) -> Box<dyn NodeConstructor<T, BUF_SIZE>> {
199 Box::new(ClosureCtor {
200 type_name: self.type_name,
201 f: self.f.clone(),
202 })
203 }
204}
205
/// Registers a closure-based node constructor on a registry.
///
/// `node_ctor!(registry, "type/name", closure)` expands to
/// `registry.register_fn("type/name", closure)`. A trailing comma after the
/// closure is accepted.
///
/// The expansion is a plain expression with no trailing `;`, so the macro works
/// both as a statement (`node_ctor!(..);`) and in expression position without
/// tripping the `semicolon_in_expressions_from_macros` lint.
#[macro_export]
macro_rules! node_ctor {
    ($registry:expr, $type_name:expr, $ctor:expr $(,)?) => {
        $registry.register_fn($type_name, $ctor)
    };
}
235
/// Unit tests for [`NodeFactory`], driven by a minimal do-nothing node.
#[cfg(test)]
mod tests {
    use super::*;

    use rill_core::time::RenderContext;
    use rill_core::traits::node::NodeState;
    use rill_core::traits::port::Port;
    use rill_core::traits::NodeCategory;
    use rill_core::traits::Processor;
    use rill_core::traits::Source;
    use rill_core::traits::{ParamValue, ProcessResult};

    /// Minimal node with a single output port and no parameters.
    ///
    /// It implements both `Source` and `Processor`, so the same type can be wrapped
    /// in either `NodeVariant`; the metadata fields let a test pick what it reports.
    struct TestSource<T: Transcendental, const B: usize> {
        id: NodeId,
        state: NodeState<T, B>,
        output: Port<T, B>,
        meta_name: &'static str,
        meta_cat: NodeCategory,
    }

    impl<T: Transcendental, const B: usize> TestSource<T, B> {
        /// Creates the node with id 0, reporting itself as the source "TestSource".
        fn new() -> Self {
            Self {
                id: NodeId(0),
                state: NodeState::new(44100.0),
                output: Port::output(NodeId(0), 0, "out"),
                meta_name: "TestSource",
                meta_cat: NodeCategory::Source,
            }
        }

        /// Sets the node id and the sample rate stored in its state in one call.
        fn set_id_and_init(&mut self, id: NodeId, sample_rate: f32) {
            self.id = id;
            self.state.sample_rate = sample_rate;
        }
    }

    impl<T: Transcendental, const B: usize> Node<T, B> for TestSource<T, B> {
        fn metadata(&self) -> NodeMetadata {
            NodeMetadata::new(self.meta_name, self.meta_cat)
        }
        fn init(&mut self, sample_rate: f32) {
            self.state.sample_rate = sample_rate;
        }
        fn reset(&mut self) {}
        // The test node exposes no parameters: reads return `None`, writes are accepted
        // and ignored.
        fn get_parameter(&self, _: &rill_core::traits::ParameterId) -> Option<ParamValue> {
            None
        }
        fn set_parameter(
            &mut self,
            _: &rill_core::traits::ParameterId,
            _: ParamValue,
        ) -> ProcessResult<()> {
            Ok(())
        }
        fn id(&self) -> NodeId {
            self.id
        }
        fn set_id(&mut self, id: NodeId) {
            self.id = id;
        }
        fn input_port(&self, _: usize) -> Option<&Port<T, B>> {
            None
        }
        fn input_port_mut(&mut self, _: usize) -> Option<&mut Port<T, B>> {
            None
        }
        // Only output index 0 exists.
        fn output_port(&self, index: usize) -> Option<&Port<T, B>> {
            if index == 0 {
                Some(&self.output)
            } else {
                None
            }
        }
        fn output_port_mut(&mut self, index: usize) -> Option<&mut Port<T, B>> {
            if index == 0 {
                Some(&mut self.output)
            } else {
                None
            }
        }
        fn control_port(&self, _: usize) -> Option<&Port<T, B>> {
            None
        }
        fn control_port_mut(&mut self, _: usize) -> Option<&mut Port<T, B>> {
            None
        }
        fn state(&self) -> &NodeState<T, B> {
            &self.state
        }
        fn state_mut(&mut self) -> &mut NodeState<T, B> {
            &mut self.state
        }
    }

    impl<T: Transcendental, const B: usize> Source<T, B> for TestSource<T, B> {
        /// Does nothing and reports success.
        fn generate(
            &mut self,
            _: &RenderContext,
            _: &[T],
            _: &[RenderContext],
        ) -> ProcessResult<()> {
            Ok(())
        }
    }

    impl<T: Transcendental, const B: usize> Processor<T, B> for TestSource<T, B> {
        /// Does nothing and reports success.
        fn process(
            &mut self,
            _: &RenderContext,
            _: &[&[T; B]],
            _: &[T],
            _: &[RenderContext],
            _: &[&[T; B]],
        ) -> ProcessResult<()> {
            Ok(())
        }
        fn latency(&self) -> usize {
            0
        }
    }

    /// Constructor registered as "test/source"; yields a `NodeVariant::Source`.
    struct TestSourceCtor;
    impl<T: Transcendental, const B: usize> NodeConstructor<T, B> for TestSourceCtor {
        fn type_name(&self) -> &'static str {
            "test/source"
        }
        fn construct(&self, id: NodeId, params: &Params) -> NodeVariant<T, B> {
            let mut node = TestSource::<T, B>::new();
            node.set_id_and_init(id, params.sample_rate);
            NodeVariant::Source(Box::new(node))
        }
        fn clone_box(&self) -> Box<dyn NodeConstructor<T, B>> {
            Box::new(Self)
        }
    }

    /// Constructor registered as "test/processor"; reuses `TestSource` but reports
    /// itself as the processor "Noop" and yields a `NodeVariant::Processor`.
    struct TestProcessorCtor;
    impl<T: Transcendental, const B: usize> NodeConstructor<T, B> for TestProcessorCtor {
        fn type_name(&self) -> &'static str {
            "test/processor"
        }
        fn construct(&self, id: NodeId, params: &Params) -> NodeVariant<T, B> {
            let mut node = TestSource::<T, B>::new();
            node.meta_name = "Noop";
            node.meta_cat = NodeCategory::Processor;
            node.set_id_and_init(id, params.sample_rate);
            NodeVariant::Processor(Box::new(node))
        }
        fn clone_box(&self) -> Box<dyn NodeConstructor<T, B>> {
            Box::new(Self)
        }
    }

    /// A fresh factory has nothing registered.
    #[test]
    fn test_registry_empty() {
        let registry = NodeFactory::<f32, 64>::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    /// A registered constructor can be looked up and produces the expected variant.
    #[test]
    fn test_registry_register_and_construct() {
        let mut registry = NodeFactory::<f32, 64>::new();
        registry.register(TestSourceCtor);

        assert!(registry.contains("test/source"));
        assert_eq!(registry.len(), 1);

        let params = Params::new(48000.0);
        let variant = registry
            .construct("test/source", NodeId(42), &params)
            .expect("should construct");

        assert!(
            matches!(variant, NodeVariant::Source(_)),
            "expected Source variant"
        );
        assert_eq!(variant.metadata().name, "TestSource");
    }

    /// Constructing an unregistered type reports the requested name in the error.
    #[test]
    fn test_registry_unknown_type() {
        let registry = NodeFactory::<f32, 64>::new();
        let params = Params::new(44100.0);
        match registry.construct("nonexistent", NodeId(0), &params) {
            Err(RegistryError::UnknownType(name)) => assert_eq!(name, "nonexistent"),
            _ => panic!("expected UnknownType error"),
        }
    }

    /// Closures registered through `register_fn` behave like regular constructors.
    #[test]
    fn test_registry_register_fn() {
        let mut registry = NodeFactory::<f32, 64>::new();
        registry.register_fn("test/fn_ctor", |id, params| {
            let mut node = TestSource::<f32, 64>::new();
            node.set_id(id);
            node.init(params.sample_rate);
            NodeVariant::Source(Box::new(node))
        });

        assert!(registry.contains("test/fn_ctor"));
        let params = Params::new(44100.0);
        let variant = registry
            .construct("test/fn_ctor", NodeId(1), &params)
            .expect("should construct from fn");
        assert!(
            matches!(variant, NodeVariant::Source(_)),
            "expected Source variant"
        );
    }

    /// `list_types` reports every registered name; the test sorts the result itself
    /// so it does not depend on the order in which the factory returns them.
    #[test]
    fn test_registry_list_types() {
        let mut registry = NodeFactory::<f32, 64>::new();
        registry.register(TestSourceCtor);
        registry.register(TestProcessorCtor);

        let mut types = registry.list_types();
        types.sort();
        assert_eq!(types, vec!["test/processor", "test/source"]);
    }

    /// Registering the same type name twice replaces the entry instead of adding one.
    #[test]
    fn test_registry_replace() {
        let mut registry = NodeFactory::<f32, 64>::new();
        registry.register(TestSourceCtor);
        assert_eq!(registry.len(), 1);

        registry.register(TestSourceCtor);
        assert_eq!(registry.len(), 1);
    }

    /// `metadata` returns the metadata reported by a node of the registered type.
    #[test]
    fn test_registry_metadata() {
        let mut registry = NodeFactory::<f32, 64>::new();
        registry.register(TestSourceCtor);

        let meta = registry
            .metadata("test/source")
            .expect("metadata should exist for a registered type");
        assert_eq!(meta.name, "TestSource");
    }

    /// Parameters passed to `construct` reach the registered closure unchanged.
    #[test]
    fn test_construct_with_params() {
        let mut registry = NodeFactory::<f32, 64>::new();
        registry.register_fn("test/with_params", |id, params| {
            // The defaults (440.0 / 0.5) differ from the values supplied below, so
            // these assertions fail if the parameters are not forwarded.
            let freq = params.get_f32("frequency", 440.0);
            assert_eq!(freq, 220.0);
            let amp = params.get_f32("amplitude", 0.5);
            assert_eq!(amp, 0.8);

            let mut node = TestSource::<f32, 64>::new();
            node.set_id(id);
            node.init(params.sample_rate);
            NodeVariant::Source(Box::new(node))
        });

        let params = Params::new(44100.0)
            .with("frequency", ParamValue::Float(220.0))
            .with("amplitude", ParamValue::Float(0.8));
        let result = registry.construct("test/with_params", NodeId(0), &params);
        assert!(result.is_ok());
    }
}
516}