1use std::collections::HashMap;
2use std::sync::Arc;
3
4use rill_core::math::Transcendental;
5use rill_core::traits::{Node, NodeId, NodeMetadata, NodeVariant, Params};
6
/// Errors returned by [`NodeFactory`] operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RegistryError {
    /// No constructor is registered under the requested type name.
    ///
    /// Carries the type name that was looked up.
    UnknownType(String),
}

impl std::fmt::Display for RegistryError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnknownType(name) => write!(f, "unknown node type: {name}"),
        }
    }
}

impl std::error::Error for RegistryError {}
27
28pub trait NodeConstructor<T: Transcendental, const BUF_SIZE: usize>: Send + Sync {
39 fn type_name(&self) -> &'static str;
41
42 fn construct(&self, id: NodeId, params: &Params) -> NodeVariant<T, BUF_SIZE>;
51
52 fn clone_box(&self) -> Box<dyn NodeConstructor<T, BUF_SIZE>>;
54}
55
56pub struct NodeFactory<T: Transcendental, const BUF_SIZE: usize> {
70 entries: HashMap<&'static str, Box<dyn NodeConstructor<T, BUF_SIZE>>>,
71}
72
73impl<T: Transcendental, const BUF_SIZE: usize> Clone for NodeFactory<T, BUF_SIZE> {
74 fn clone(&self) -> Self {
75 Self {
76 entries: self
77 .entries
78 .iter()
79 .map(|(k, v)| (*k, v.clone_box()))
80 .collect(),
81 }
82 }
83}
84
85impl<T: Transcendental, const BUF_SIZE: usize> Default for NodeFactory<T, BUF_SIZE> {
86 fn default() -> Self {
87 Self::new()
88 }
89}
90
91impl<T: Transcendental, const BUF_SIZE: usize> NodeFactory<T, BUF_SIZE> {
92 pub fn new() -> Self {
94 Self {
95 entries: HashMap::new(),
96 }
97 }
98
99 pub fn register(&mut self, ctor: impl NodeConstructor<T, BUF_SIZE> + 'static) {
105 let name = ctor.type_name();
106 self.entries.insert(name, Box::new(ctor));
107 }
108
109 pub fn register_fn(
114 &mut self,
115 type_name: &'static str,
116 f: impl Fn(NodeId, &Params) -> NodeVariant<T, BUF_SIZE> + Send + Sync + 'static,
117 ) {
118 self.entries.insert(
119 type_name,
120 Box::new(ClosureCtor {
121 type_name,
122 f: Arc::new(f),
123 }),
124 );
125 }
126
127 pub fn construct(
132 &self,
133 type_name: &str,
134 id: NodeId,
135 params: &Params,
136 ) -> Result<NodeVariant<T, BUF_SIZE>, RegistryError> {
137 self.entries
138 .get(type_name)
139 .ok_or_else(|| RegistryError::UnknownType(type_name.to_string()))
140 .map(|ctor| ctor.construct(id, params))
141 }
142
143 pub fn contains(&self, type_name: &str) -> bool {
145 self.entries.contains_key(type_name)
146 }
147
148 pub fn list_types(&self) -> Vec<&'static str> {
150 self.entries.keys().copied().collect()
151 }
152
153 pub fn len(&self) -> usize {
155 self.entries.len()
156 }
157
158 pub fn is_empty(&self) -> bool {
160 self.entries.is_empty()
161 }
162
163 pub fn metadata(&self, type_name: &str) -> Option<NodeMetadata> {
169 self.entries.get(type_name).map(|ctor| {
170 let dummy = Params::new(44100.0);
171 let variant = ctor.construct(NodeId(u32::MAX), &dummy);
172 variant.metadata()
173 })
174 }
175}
176
177#[allow(clippy::type_complexity)]
182struct ClosureCtor<T: Transcendental, const BUF_SIZE: usize> {
183 type_name: &'static str,
184 f: Arc<dyn Fn(NodeId, &Params) -> NodeVariant<T, BUF_SIZE> + Send + Sync>,
185}
186
187impl<T: Transcendental, const BUF_SIZE: usize> NodeConstructor<T, BUF_SIZE>
188 for ClosureCtor<T, BUF_SIZE>
189{
190 fn type_name(&self) -> &'static str {
191 self.type_name
192 }
193
194 fn construct(&self, id: NodeId, params: &Params) -> NodeVariant<T, BUF_SIZE> {
195 (self.f)(id, params)
196 }
197
198 fn clone_box(&self) -> Box<dyn NodeConstructor<T, BUF_SIZE>> {
199 Box::new(ClosureCtor {
200 type_name: self.type_name,
201 f: self.f.clone(),
202 })
203 }
204}
205
/// Registers a closure constructor on a factory.
///
/// `node_ctor!(registry, "type/name", closure)` expands to
/// `registry.register_fn("type/name", closure);`. A trailing comma after the
/// closure argument is accepted, matching how rustfmt lays out multi-line
/// closure arguments.
///
/// ```ignore
/// node_ctor!(factory, "my/node", |id, params| {
///     // build and return a NodeVariant here
/// });
/// ```
#[macro_export]
macro_rules! node_ctor {
    ($registry:expr, $type_name:expr, $ctor:expr $(,)?) => {
        $registry.register_fn($type_name, $ctor);
    };
}
235
#[cfg(test)]
mod tests {
    use super::*;

    use rill_core::time::ClockTick;
    use rill_core::traits::node::NodeState;
    use rill_core::traits::port::Port;
    use rill_core::traits::NodeCategory;
    use rill_core::traits::Processor;
    use rill_core::traits::Source;
    use rill_core::traits::{ParamValue, ProcessResult};

    /// Minimal node used to exercise the factory. Its reported name and
    /// category come from `meta_name` / `meta_cat`, so the same type can stand
    /// in for a source or a processor.
    struct TestSource<T: Transcendental, const B: usize> {
        id: NodeId,
        state: NodeState<T, B>,
        output: Port<T, B>,
        meta_name: &'static str,
        meta_cat: NodeCategory,
    }

    impl<T: Transcendental, const B: usize> TestSource<T, B> {
        fn new() -> Self {
            Self {
                id: NodeId(0),
                state: NodeState::new(44100.0),
                output: Port::output(NodeId(0), 0, "out"),
                meta_name: "TestSource",
                meta_cat: NodeCategory::Source,
            }
        }

        fn set_id_and_init(&mut self, id: NodeId, sample_rate: f32) {
            self.id = id;
            self.state.sample_rate = sample_rate;
        }
    }

    impl<T: Transcendental, const B: usize> Node<T, B> for TestSource<T, B> {
        fn metadata(&self) -> rill_core::traits::NodeMetadata {
            rill_core::traits::NodeMetadata::new(self.meta_name, self.meta_cat)
        }
        fn init(&mut self, sample_rate: f32) {
            self.state.sample_rate = sample_rate;
        }
        fn reset(&mut self) {}
        fn get_parameter(
            &self,
            _: &rill_core::traits::ParameterId,
        ) -> Option<rill_core::traits::ParamValue> {
            None
        }
        fn set_parameter(
            &mut self,
            _: &rill_core::traits::ParameterId,
            _: rill_core::traits::ParamValue,
        ) -> ProcessResult<()> {
            Ok(())
        }
        fn id(&self) -> NodeId {
            self.id
        }
        fn set_id(&mut self, id: NodeId) {
            self.id = id;
        }
        fn input_port(&self, _: usize) -> Option<&Port<T, B>> {
            None
        }
        fn input_port_mut(&mut self, _: usize) -> Option<&mut Port<T, B>> {
            None
        }
        fn output_port(&self, index: usize) -> Option<&Port<T, B>> {
            if index == 0 {
                Some(&self.output)
            } else {
                None
            }
        }
        fn output_port_mut(&mut self, index: usize) -> Option<&mut Port<T, B>> {
            if index == 0 {
                Some(&mut self.output)
            } else {
                None
            }
        }
        fn control_port(&self, _: usize) -> Option<&Port<T, B>> {
            None
        }
        fn control_port_mut(&mut self, _: usize) -> Option<&mut Port<T, B>> {
            None
        }
        fn state(&self) -> &NodeState<T, B> {
            &self.state
        }
        fn state_mut(&mut self) -> &mut NodeState<T, B> {
            &mut self.state
        }
    }

    impl<T: Transcendental, const B: usize> Source<T, B> for TestSource<T, B> {
        fn generate(&mut self, _: &ClockTick, _: &[T], _: &[ClockTick]) -> ProcessResult<()> {
            Ok(())
        }
    }

    impl<T: Transcendental, const B: usize> Processor<T, B> for TestSource<T, B> {
        fn process(
            &mut self,
            _: &ClockTick,
            _: &[&[T; B]],
            _: &[T],
            _: &[ClockTick],
            _: &[&[T; B]],
        ) -> ProcessResult<()> {
            Ok(())
        }
        fn latency(&self) -> usize {
            0
        }
    }

    /// Constructor registered as `"test/source"`; yields a `Source` variant.
    struct TestSourceCtor;
    impl<T: Transcendental, const B: usize> NodeConstructor<T, B> for TestSourceCtor {
        fn type_name(&self) -> &'static str {
            "test/source"
        }
        fn construct(&self, id: NodeId, params: &Params) -> NodeVariant<T, B> {
            let mut node = TestSource::<T, B>::new();
            node.set_id_and_init(id, params.sample_rate);
            NodeVariant::Source(Box::new(node))
        }
        fn clone_box(&self) -> Box<dyn NodeConstructor<T, B>> {
            Box::new(Self)
        }
    }

    /// Constructor registered as `"test/processor"`; yields a `Processor` variant.
    struct TestProcessorCtor;
    impl<T: Transcendental, const B: usize> NodeConstructor<T, B> for TestProcessorCtor {
        fn type_name(&self) -> &'static str {
            "test/processor"
        }
        fn construct(&self, id: NodeId, params: &Params) -> NodeVariant<T, B> {
            let mut node = TestSource::<T, B>::new();
            node.meta_name = "Noop";
            node.meta_cat = NodeCategory::Processor;
            node.set_id_and_init(id, params.sample_rate);
            NodeVariant::Processor(Box::new(node))
        }
        fn clone_box(&self) -> Box<dyn NodeConstructor<T, B>> {
            Box::new(Self)
        }
    }

    #[test]
    fn test_registry_empty() {
        let registry = NodeFactory::<f32, 64>::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_registry_register_and_construct() {
        let mut registry = NodeFactory::<f32, 64>::new();
        registry.register(TestSourceCtor);

        assert!(registry.contains("test/source"));
        assert_eq!(registry.len(), 1);

        let params = Params::new(48000.0);
        let variant = registry
            .construct("test/source", NodeId(42), &params)
            .expect("should construct");

        match &variant {
            NodeVariant::Source(_) => {}
            _ => panic!("expected Source variant"),
        }

        assert_eq!(variant.metadata().name, "TestSource");
    }

    #[test]
    fn test_registry_unknown_type() {
        let registry = NodeFactory::<f32, 64>::new();
        let params = Params::new(44100.0);
        let result = registry.construct("nonexistent", NodeId(0), &params);
        assert!(result.is_err());
        match result {
            Err(RegistryError::UnknownType(name)) => assert_eq!(name, "nonexistent"),
            _ => panic!("expected UnknownType error"),
        }
    }

    #[test]
    fn test_registry_error_display() {
        let err = RegistryError::UnknownType("missing/type".to_string());
        assert_eq!(err.to_string(), "unknown node type: missing/type");
    }

    #[test]
    fn test_registry_register_fn() {
        let mut registry = NodeFactory::<f32, 64>::new();
        registry.register_fn("test/fn_ctor", |id, params| {
            let mut node = TestSource::<f32, 64>::new();
            node.set_id(id);
            node.init(params.sample_rate);
            NodeVariant::Source(Box::new(node))
        });

        assert!(registry.contains("test/fn_ctor"));
        let params = Params::new(44100.0);
        let variant = registry
            .construct("test/fn_ctor", NodeId(1), &params)
            .expect("should construct from fn");
        match variant {
            NodeVariant::Source(_) => {}
            _ => panic!("expected Source variant"),
        }
    }

    #[test]
    fn test_registry_list_types() {
        let mut registry = NodeFactory::<f32, 64>::new();
        registry.register(TestSourceCtor);
        registry.register(TestProcessorCtor);

        let mut types = registry.list_types();
        types.sort();
        assert_eq!(types, vec!["test/processor", "test/source"]);
    }

    #[test]
    fn test_registry_replace() {
        let mut registry = NodeFactory::<f32, 64>::new();
        registry.register(TestSourceCtor);
        assert_eq!(registry.len(), 1);

        registry.register(TestSourceCtor);
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_clone_is_independent() {
        let mut registry = NodeFactory::<f32, 64>::new();
        registry.register(TestSourceCtor);

        let mut cloned = registry.clone();
        cloned.register(TestProcessorCtor);

        // Registering on the clone must not affect the original.
        assert_eq!(registry.len(), 1);
        assert_eq!(cloned.len(), 2);

        let params = Params::new(44100.0);
        assert!(cloned.construct("test/source", NodeId(3), &params).is_ok());
    }

    #[test]
    fn test_registry_metadata() {
        let mut registry = NodeFactory::<f32, 64>::new();
        registry.register(TestSourceCtor);

        let meta = registry.metadata("test/source");
        assert!(meta.is_some());
        assert_eq!(meta.unwrap().name, "TestSource");
    }

    #[test]
    fn test_construct_with_params() {
        let mut registry = NodeFactory::<f32, 64>::new();
        registry.register_fn("test/with_params", |id, params| {
            let freq = params.get_f32("frequency", 440.0);
            assert_eq!(freq, 220.0);
            let amp = params.get_f32("amplitude", 0.5);
            assert_eq!(amp, 0.8);

            let mut node = TestSource::<f32, 64>::new();
            node.set_id(id);
            node.init(params.sample_rate);
            NodeVariant::Source(Box::new(node))
        });

        let params = Params::new(44100.0)
            .with("frequency", ParamValue::Float(220.0))
            .with("amplitude", ParamValue::Float(0.8));
        let result = registry.construct("test/with_params", NodeId(0), &params);
        assert!(result.is_ok());
    }
}
511}