Skip to main content

rill_graph/
factory.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use rill_core::math::Transcendental;
5use rill_core::traits::{Node, NodeId, NodeMetadata, NodeVariant, Params};
6
7// ============================================================================
8// Registry Error
9// ============================================================================
10
/// Errors that can occur during node construction via the registry.
///
/// Derives `PartialEq`/`Eq` so callers (and tests) can compare errors with
/// `assert_eq!` instead of pattern-matching.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RegistryError {
    /// No constructor registered for the given type name.
    ///
    /// Carries the type name that was looked up.
    UnknownType(String),
}

impl std::fmt::Display for RegistryError {
    /// Human-readable message, e.g. `unknown node type: rill/foo`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnknownType(name) => write!(f, "unknown node type: {name}"),
        }
    }
}

// No underlying cause is wrapped, so the default `source()` (which returns
// `None`) is the correct behaviour.
impl std::error::Error for RegistryError {}
27
28// ============================================================================
29// NodeConstructor Trait
30// ============================================================================
31
32/// Factory trait for creating graph nodes by type name.
33///
34/// Each node type that wants to be constructable via the registry
35/// implements this trait. The [`construct`](Self::construct) method
36/// receives a [`NodeId`] and [`Params`] and must return the
37/// appropriate [`NodeVariant`].
38pub trait NodeConstructor<T: Transcendental, const BUF_SIZE: usize>: Send + Sync {
39    /// Canonical name for this node type (e.g. `"rill/sine_osc"`).
40    fn type_name(&self) -> &'static str;
41
42    /// Build a fully initialised node variant.
43    ///
44    /// Implementations should:
45    /// 1. Extract parameters from `params`.
46    /// 2. Create the concrete node.
47    /// 3. Call [`Node::set_id`] with the given `id`.
48    /// 4. Call [`Node::init`] with `params.sample_rate`.
49    /// 5. Wrap in the correct [`NodeVariant`] variant.
50    fn construct(&self, id: NodeId, params: &Params) -> NodeVariant<T, BUF_SIZE>;
51
52    /// Clone this constructor into a boxed trait object.
53    fn clone_box(&self) -> Box<dyn NodeConstructor<T, BUF_SIZE>>;
54}
55
56// ============================================================================
57// NodeFactory
58// ============================================================================
59
60/// A registry of named node constructors.
61///
62/// Register constructors with [`register`](Self::register), then create
63/// nodes by type name with [`construct`](Self::construct).
64///
65/// # Type parameters
66///
67/// - `T` — sample type (typically `f32`)
68/// - `BUF_SIZE` — block size (must match the target graph)
69pub struct NodeFactory<T: Transcendental, const BUF_SIZE: usize> {
70    entries: HashMap<&'static str, Box<dyn NodeConstructor<T, BUF_SIZE>>>,
71}
72
73impl<T: Transcendental, const BUF_SIZE: usize> Clone for NodeFactory<T, BUF_SIZE> {
74    fn clone(&self) -> Self {
75        Self {
76            entries: self
77                .entries
78                .iter()
79                .map(|(k, v)| (*k, v.clone_box()))
80                .collect(),
81        }
82    }
83}
84
85impl<T: Transcendental, const BUF_SIZE: usize> Default for NodeFactory<T, BUF_SIZE> {
86    fn default() -> Self {
87        Self::new()
88    }
89}
90
91impl<T: Transcendental, const BUF_SIZE: usize> NodeFactory<T, BUF_SIZE> {
92    /// Create an empty registry.
93    pub fn new() -> Self {
94        Self {
95            entries: HashMap::new(),
96        }
97    }
98
99    /// Register a node constructor.
100    ///
101    /// The constructor's [`type_name`](NodeConstructor::type_name) is used
102    /// as the lookup key. If a constructor with the same name already exists,
103    /// it is replaced.
104    pub fn register(&mut self, ctor: impl NodeConstructor<T, BUF_SIZE> + 'static) {
105        let name = ctor.type_name();
106        self.entries.insert(name, Box::new(ctor));
107    }
108
109    /// Register a node type via a closure.
110    ///
111    /// This is a convenience wrapper around [`register`](Self::register) for
112    /// cases where a full struct + trait impl is not needed.
113    pub fn register_fn(
114        &mut self,
115        type_name: &'static str,
116        f: impl Fn(NodeId, &Params) -> NodeVariant<T, BUF_SIZE> + Send + Sync + 'static,
117    ) {
118        self.entries.insert(
119            type_name,
120            Box::new(ClosureCtor {
121                type_name,
122                f: Arc::new(f),
123            }),
124        );
125    }
126
127    /// Construct a node by type name.
128    ///
129    /// Returns [`RegistryError::UnknownType`] if the name has not been
130    /// registered.
131    pub fn construct(
132        &self,
133        type_name: &str,
134        id: NodeId,
135        params: &Params,
136    ) -> Result<NodeVariant<T, BUF_SIZE>, RegistryError> {
137        self.entries
138            .get(type_name)
139            .ok_or_else(|| RegistryError::UnknownType(type_name.to_string()))
140            .map(|ctor| ctor.construct(id, params))
141    }
142
143    /// Check whether a type name is registered.
144    pub fn contains(&self, type_name: &str) -> bool {
145        self.entries.contains_key(type_name)
146    }
147
148    /// List all registered type names.
149    pub fn list_types(&self) -> Vec<&'static str> {
150        self.entries.keys().copied().collect()
151    }
152
153    /// Number of registered constructors.
154    pub fn len(&self) -> usize {
155        self.entries.len()
156    }
157
158    /// True when no constructors are registered.
159    pub fn is_empty(&self) -> bool {
160        self.entries.is_empty()
161    }
162
163    /// Get metadata for a registered type without constructing a node.
164    ///
165    /// This requires constructing a temporary node and immediately
166    /// discarding it. If performance is a concern, cache the metadata
167    /// alongside the constructor in the registry.
168    pub fn metadata(&self, type_name: &str) -> Option<NodeMetadata> {
169        self.entries.get(type_name).map(|ctor| {
170            let dummy = Params::new(44100.0);
171            let variant = ctor.construct(NodeId(u32::MAX), &dummy);
172            variant.metadata()
173        })
174    }
175}
176
177// ============================================================================
178// Internal: closure-based constructor wrapper
179// ============================================================================
180
181#[allow(clippy::type_complexity)]
182struct ClosureCtor<T: Transcendental, const BUF_SIZE: usize> {
183    type_name: &'static str,
184    f: Arc<dyn Fn(NodeId, &Params) -> NodeVariant<T, BUF_SIZE> + Send + Sync>,
185}
186
187impl<T: Transcendental, const BUF_SIZE: usize> NodeConstructor<T, BUF_SIZE>
188    for ClosureCtor<T, BUF_SIZE>
189{
190    fn type_name(&self) -> &'static str {
191        self.type_name
192    }
193
194    fn construct(&self, id: NodeId, params: &Params) -> NodeVariant<T, BUF_SIZE> {
195        (self.f)(id, params)
196    }
197
198    fn clone_box(&self) -> Box<dyn NodeConstructor<T, BUF_SIZE>> {
199        Box::new(ClosureCtor {
200            type_name: self.type_name,
201            f: self.f.clone(),
202        })
203    }
204}
205
206// ============================================================================
207// Node Ctor Macro
208// ============================================================================
209
/// Register a node constructor by type name.
///
/// Shorthand for [`NodeFactory::register_fn`]:
/// `node_ctor!(registry, type_name, closure)` expands to
/// `registry.register_fn(type_name, closure)`. A trailing comma is accepted.
///
/// The expansion is a block expression with no trailing semicolon. It can
/// therefore be used as a statement, in tail position or as a match arm
/// without tripping `semicolon_in_expressions_from_macros`.
///
/// # Example
///
/// ```rust
/// use rill_graph::{node_ctor, NodeFactory};
///
/// // Inside a function that has access to a &mut NodeFactory<f32, 64>:
/// fn register(registry: &mut NodeFactory<f32, 64>) {
///     node_ctor!(registry, "test/my_source", |_id, _params| {
///         // construct and return a NodeVariant
///         todo!()
///     });
/// }
/// ```
#[macro_export]
macro_rules! node_ctor {
    ($registry:expr, $type_name:expr, $ctor:expr $(,)?) => {{
        $registry.register_fn($type_name, $ctor)
    }};
}
235
236// ============================================================================
237// Tests
238// ============================================================================
239
// Unit tests for `NodeFactory`: registration, lookup, replacement, metadata
// probing and parameter forwarding, driven by a minimal no-op node type.
#[cfg(test)]
mod tests {
    use super::*;

    use rill_core::time::RenderContext;
    use rill_core::traits::node::NodeState;
    use rill_core::traits::port::Port;
    use rill_core::traits::NodeCategory;
    use rill_core::traits::Processor;
    use rill_core::traits::Source;
    use rill_core::traits::{ParamValue, ProcessResult};

    // ── Test helpers ────────────────────────────────────────────────

    /// Minimal no-op node shared by all tests.
    ///
    /// It has a single output port (index 0) and no input or control ports.
    /// The metadata name/category are plain fields, so the same type can be
    /// relabelled as a processor (see `TestProcessorCtor`).
    struct TestSource<T: Transcendental, const B: usize> {
        id: NodeId,
        state: NodeState<T, B>,
        output: Port<T, B>,
        // Returned verbatim by `Node::metadata`.
        meta_name: &'static str,
        meta_cat: NodeCategory,
    }

    impl<T: Transcendental, const B: usize> TestSource<T, B> {
        /// Fresh node with id 0, state from `NodeState::new(44100.0)` and
        /// metadata `"TestSource"` / `NodeCategory::Source`.
        fn new() -> Self {
            Self {
                id: NodeId(0),
                state: NodeState::new(44100.0),
                output: Port::output(NodeId(0), 0, "out"),
                meta_name: "TestSource",
                meta_cat: NodeCategory::Source,
            }
        }

        /// Assign `id` and store `sample_rate` in the node state.
        ///
        /// Writes the fields directly. For this type that is the same effect
        /// as `set_id` followed by `init`, but it bypasses the `Node` trait.
        fn set_id_and_init(&mut self, id: NodeId, sample_rate: f32) {
            self.id = id;
            self.state.sample_rate = sample_rate;
        }
    }

    // Stub `Node` impl: no parameters, nothing to reset, one output port.
    impl<T: Transcendental, const B: usize> Node<T, B> for TestSource<T, B> {
        fn metadata(&self) -> rill_core::traits::NodeMetadata {
            rill_core::traits::NodeMetadata::new(self.meta_name, self.meta_cat)
        }
        fn init(&mut self, sample_rate: f32) {
            self.state.sample_rate = sample_rate;
        }
        fn reset(&mut self) {}
        fn get_parameter(
            &self,
            _: &rill_core::traits::ParameterId,
        ) -> Option<rill_core::traits::ParamValue> {
            None
        }
        fn set_parameter(
            &mut self,
            _: &rill_core::traits::ParameterId,
            _: rill_core::traits::ParamValue,
        ) -> ProcessResult<()> {
            Ok(())
        }
        fn id(&self) -> NodeId {
            self.id
        }
        fn set_id(&mut self, id: NodeId) {
            self.id = id;
        }
        fn input_port(&self, _: usize) -> Option<&Port<T, B>> {
            None
        }
        fn input_port_mut(&mut self, _: usize) -> Option<&mut Port<T, B>> {
            None
        }
        // Only index 0 exists.
        fn output_port(&self, index: usize) -> Option<&Port<T, B>> {
            if index == 0 {
                Some(&self.output)
            } else {
                None
            }
        }
        fn output_port_mut(&mut self, index: usize) -> Option<&mut Port<T, B>> {
            if index == 0 {
                Some(&mut self.output)
            } else {
                None
            }
        }
        fn control_port(&self, _: usize) -> Option<&Port<T, B>> {
            None
        }
        fn control_port_mut(&mut self, _: usize) -> Option<&mut Port<T, B>> {
            None
        }
        fn state(&self) -> &NodeState<T, B> {
            &self.state
        }
        fn state_mut(&mut self) -> &mut NodeState<T, B> {
            &mut self.state
        }
    }

    // No-op generation; these tests never render audio.
    impl<T: Transcendental, const B: usize> Source<T, B> for TestSource<T, B> {
        fn generate(
            &mut self,
            _: &RenderContext,
            _: &[T],
            _: &[RenderContext],
            _: &rill_core::time::ClockTick,
        ) -> ProcessResult<()> {
            Ok(())
        }
    }

    // No-op processing. Presumably required so `TestProcessorCtor` can box
    // this type into `NodeVariant::Processor`.
    impl<T: Transcendental, const B: usize> Processor<T, B> for TestSource<T, B> {
        fn process(
            &mut self,
            _: &RenderContext,
            _: &[&[T; B]],
            _: &[T],
            _: &[RenderContext],
            _: &[&[T; B]],
        ) -> ProcessResult<()> {
            Ok(())
        }
        fn latency(&self) -> usize {
            0
        }
    }

    /// Registers as `"test/source"` and yields a `NodeVariant::Source`.
    struct TestSourceCtor;
    impl<T: Transcendental, const B: usize> NodeConstructor<T, B> for TestSourceCtor {
        fn type_name(&self) -> &'static str {
            "test/source"
        }
        fn construct(&self, id: NodeId, params: &Params) -> NodeVariant<T, B> {
            let mut node = TestSource::<T, B>::new();
            node.set_id_and_init(id, params.sample_rate);
            NodeVariant::Source(Box::new(node))
        }
        fn clone_box(&self) -> Box<dyn NodeConstructor<T, B>> {
            Box::new(Self)
        }
    }

    /// Registers as `"test/processor"`. It reuses `TestSource`, relabelled as
    /// `"Noop"` / `NodeCategory::Processor`, and wraps it in
    /// `NodeVariant::Processor`.
    struct TestProcessorCtor;
    impl<T: Transcendental, const B: usize> NodeConstructor<T, B> for TestProcessorCtor {
        fn type_name(&self) -> &'static str {
            "test/processor"
        }
        fn construct(&self, id: NodeId, params: &Params) -> NodeVariant<T, B> {
            let mut node = TestSource::<T, B>::new();
            node.meta_name = "Noop";
            node.meta_cat = NodeCategory::Processor;
            node.set_id_and_init(id, params.sample_rate);
            NodeVariant::Processor(Box::new(node))
        }
        fn clone_box(&self) -> Box<dyn NodeConstructor<T, B>> {
            Box::new(Self)
        }
    }

    // ── Tests ───────────────────────────────────────────────────────

    /// A freshly created registry has no entries.
    #[test]
    fn test_registry_empty() {
        let registry = NodeFactory::<f32, 64>::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    /// A struct-based constructor is found by name and builds a Source variant.
    #[test]
    fn test_registry_register_and_construct() {
        let mut registry = NodeFactory::<f32, 64>::new();
        registry.register(TestSourceCtor);

        assert!(registry.contains("test/source"));
        assert_eq!(registry.len(), 1);

        let params = Params::new(48000.0);
        let variant = registry
            .construct("test/source", NodeId(42), &params)
            .expect("should construct");

        match &variant {
            NodeVariant::Source(_) => {}
            _ => panic!("expected Source variant"),
        }

        // Identify the constructed node by its metadata name. This does not
        // inspect the stored sample rate or id.
        assert_eq!(variant.metadata().name, "TestSource");
    }

    /// Looking up an unregistered name yields `UnknownType` carrying that name.
    #[test]
    fn test_registry_unknown_type() {
        let registry = NodeFactory::<f32, 64>::new();
        let params = Params::new(44100.0);
        let result = registry.construct("nonexistent", NodeId(0), &params);
        assert!(result.is_err());
        match result {
            Err(RegistryError::UnknownType(name)) => assert_eq!(name, "nonexistent"),
            _ => panic!("expected UnknownType error"),
        }
    }

    /// A closure registered via `register_fn` is found and constructed.
    #[test]
    fn test_registry_register_fn() {
        let mut registry = NodeFactory::<f32, 64>::new();
        registry.register_fn("test/fn_ctor", |id, params| {
            let mut node = TestSource::<f32, 64>::new();
            node.set_id(id);
            node.init(params.sample_rate);
            NodeVariant::Source(Box::new(node))
        });

        assert!(registry.contains("test/fn_ctor"));
        let params = Params::new(44100.0);
        let variant = registry
            .construct("test/fn_ctor", NodeId(1), &params)
            .expect("should construct from fn");
        match variant {
            NodeVariant::Source(_) => {}
            _ => panic!("expected Source variant"),
        }
    }

    /// `list_types` reports every registered name. The result is sorted here
    /// so the assertion does not depend on map iteration order.
    #[test]
    fn test_registry_list_types() {
        let mut registry = NodeFactory::<f32, 64>::new();
        registry.register(TestSourceCtor);
        registry.register(TestProcessorCtor);

        let mut types = registry.list_types();
        types.sort();
        assert_eq!(types, vec!["test/processor", "test/source"]);
    }

    /// Re-registering the same name keeps a single entry.
    ///
    /// NOTE(review): the same constructor is registered twice, so this cannot
    /// tell "replace" apart from "keep first".
    #[test]
    fn test_registry_replace() {
        let mut registry = NodeFactory::<f32, 64>::new();
        registry.register(TestSourceCtor);
        assert_eq!(registry.len(), 1);

        // Registering again under the same name replaces.
        registry.register(TestSourceCtor);
        assert_eq!(registry.len(), 1);
    }

    /// Metadata can be read for a registered type without the caller
    /// constructing a node.
    #[test]
    fn test_registry_metadata() {
        let mut registry = NodeFactory::<f32, 64>::new();
        registry.register(TestSourceCtor);

        let meta = registry.metadata("test/source");
        assert!(meta.is_some());
        assert_eq!(meta.unwrap().name, "TestSource");
    }

    /// Extra `Params` entries reach the constructor closure. The assertions
    /// inside the closure run during `construct`.
    #[test]
    fn test_construct_with_params() {
        let mut registry = NodeFactory::<f32, 64>::new();
        registry.register_fn("test/with_params", |id, params| {
            let freq = params.get_f32("frequency", 440.0);
            assert_eq!(freq, 220.0);
            let amp = params.get_f32("amplitude", 0.5);
            assert_eq!(amp, 0.8);

            let mut node = TestSource::<f32, 64>::new();
            node.set_id(id);
            node.init(params.sample_rate);
            NodeVariant::Source(Box::new(node))
        });

        let params = Params::new(44100.0)
            .with("frequency", ParamValue::Float(220.0))
            .with("amplitude", ParamValue::Float(0.8));
        let result = registry.construct("test/with_params", NodeId(0), &params);
        assert!(result.is_ok());
    }
}