// rill_graph/factory.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use rill_core::math::Transcendental;
5use rill_core::traits::{Node, NodeId, NodeMetadata, NodeVariant, Params};
6
// ============================================================================
// Registry Error
// ============================================================================

/// Errors that can occur during node construction via the registry
/// ([`NodeFactory`]).
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare errors directly,
/// e.g. `assert_eq!(err, RegistryError::UnknownType("x".into()))`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RegistryError {
    /// No constructor is registered under the given type name.
    ///
    /// The payload is the type name that was looked up.
    UnknownType(String),
}

impl std::fmt::Display for RegistryError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnknownType(name) => write!(f, "unknown node type: {name}"),
        }
    }
}

impl std::error::Error for RegistryError {}
27
28// ============================================================================
29// NodeConstructor Trait
30// ============================================================================
31
32/// Factory trait for creating graph nodes by type name.
33///
34/// Each node type that wants to be constructable via the registry
35/// implements this trait. The [`construct`](Self::construct) method
36/// receives a [`NodeId`] and [`Params`] and must return the
37/// appropriate [`NodeVariant`].
38pub trait NodeConstructor<T: Transcendental, const BUF_SIZE: usize>: Send + Sync {
39    /// Canonical name for this node type (e.g. `"rill/sine_osc"`).
40    fn type_name(&self) -> &'static str;
41
42    /// Build a fully initialised node variant.
43    ///
44    /// Implementations should:
45    /// 1. Extract parameters from `params`.
46    /// 2. Create the concrete node.
47    /// 3. Call [`Node::set_id`] with the given `id`.
48    /// 4. Call [`Node::init`] with `params.sample_rate`.
49    /// 5. Wrap in the correct [`NodeVariant`] variant.
50    fn construct(&self, id: NodeId, params: &Params) -> NodeVariant<T, BUF_SIZE>;
51
52    /// Clone this constructor into a boxed trait object.
53    fn clone_box(&self) -> Box<dyn NodeConstructor<T, BUF_SIZE>>;
54}
55
56// ============================================================================
57// NodeFactory
58// ============================================================================
59
60/// A registry of named node constructors.
61///
62/// Register constructors with [`register`](Self::register), then create
63/// nodes by type name with [`construct`](Self::construct).
64///
65/// # Type parameters
66///
67/// - `T` — sample type (typically `f32`)
68/// - `BUF_SIZE` — block size (must match the target graph)
69pub struct NodeFactory<T: Transcendental, const BUF_SIZE: usize> {
70    entries: HashMap<&'static str, Box<dyn NodeConstructor<T, BUF_SIZE>>>,
71}
72
73impl<T: Transcendental, const BUF_SIZE: usize> Clone for NodeFactory<T, BUF_SIZE> {
74    fn clone(&self) -> Self {
75        Self {
76            entries: self
77                .entries
78                .iter()
79                .map(|(k, v)| (*k, v.clone_box()))
80                .collect(),
81        }
82    }
83}
84
85impl<T: Transcendental, const BUF_SIZE: usize> Default for NodeFactory<T, BUF_SIZE> {
86    fn default() -> Self {
87        Self::new()
88    }
89}
90
91impl<T: Transcendental, const BUF_SIZE: usize> NodeFactory<T, BUF_SIZE> {
92    /// Create an empty registry.
93    pub fn new() -> Self {
94        Self {
95            entries: HashMap::new(),
96        }
97    }
98
99    /// Register a node constructor.
100    ///
101    /// The constructor's [`type_name`](NodeConstructor::type_name) is used
102    /// as the lookup key. If a constructor with the same name already exists,
103    /// it is replaced.
104    pub fn register(&mut self, ctor: impl NodeConstructor<T, BUF_SIZE> + 'static) {
105        let name = ctor.type_name();
106        self.entries.insert(name, Box::new(ctor));
107    }
108
109    /// Register a node type via a closure.
110    ///
111    /// This is a convenience wrapper around [`register`](Self::register) for
112    /// cases where a full struct + trait impl is not needed.
113    pub fn register_fn(
114        &mut self,
115        type_name: &'static str,
116        f: impl Fn(NodeId, &Params) -> NodeVariant<T, BUF_SIZE> + Send + Sync + 'static,
117    ) {
118        self.entries.insert(
119            type_name,
120            Box::new(ClosureCtor {
121                type_name,
122                f: Arc::new(f),
123            }),
124        );
125    }
126
127    /// Construct a node by type name.
128    ///
129    /// Returns [`RegistryError::UnknownType`] if the name has not been
130    /// registered.
131    pub fn construct(
132        &self,
133        type_name: &str,
134        id: NodeId,
135        params: &Params,
136    ) -> Result<NodeVariant<T, BUF_SIZE>, RegistryError> {
137        self.entries
138            .get(type_name)
139            .ok_or_else(|| RegistryError::UnknownType(type_name.to_string()))
140            .map(|ctor| ctor.construct(id, params))
141    }
142
143    /// Check whether a type name is registered.
144    pub fn contains(&self, type_name: &str) -> bool {
145        self.entries.contains_key(type_name)
146    }
147
148    /// List all registered type names.
149    pub fn list_types(&self) -> Vec<&'static str> {
150        self.entries.keys().copied().collect()
151    }
152
153    /// Number of registered constructors.
154    pub fn len(&self) -> usize {
155        self.entries.len()
156    }
157
158    /// True when no constructors are registered.
159    pub fn is_empty(&self) -> bool {
160        self.entries.is_empty()
161    }
162
163    /// Get metadata for a registered type without constructing a node.
164    ///
165    /// This requires constructing a temporary node and immediately
166    /// discarding it. If performance is a concern, cache the metadata
167    /// alongside the constructor in the registry.
168    pub fn metadata(&self, type_name: &str) -> Option<NodeMetadata> {
169        self.entries.get(type_name).map(|ctor| {
170            let dummy = Params::new(44100.0);
171            let variant = ctor.construct(NodeId(u32::MAX), &dummy);
172            variant.metadata()
173        })
174    }
175}
176
177// ============================================================================
178// Internal: closure-based constructor wrapper
179// ============================================================================
180
181#[allow(clippy::type_complexity)]
182struct ClosureCtor<T: Transcendental, const BUF_SIZE: usize> {
183    type_name: &'static str,
184    f: Arc<dyn Fn(NodeId, &Params) -> NodeVariant<T, BUF_SIZE> + Send + Sync>,
185}
186
187impl<T: Transcendental, const BUF_SIZE: usize> NodeConstructor<T, BUF_SIZE>
188    for ClosureCtor<T, BUF_SIZE>
189{
190    fn type_name(&self) -> &'static str {
191        self.type_name
192    }
193
194    fn construct(&self, id: NodeId, params: &Params) -> NodeVariant<T, BUF_SIZE> {
195        (self.f)(id, params)
196    }
197
198    fn clone_box(&self) -> Box<dyn NodeConstructor<T, BUF_SIZE>> {
199        Box::new(ClosureCtor {
200            type_name: self.type_name,
201            f: self.f.clone(),
202        })
203    }
204}
205
// ============================================================================
// Node Ctor Macro
// ============================================================================

/// Register a node constructor by type name.
///
/// Shorthand for [`NodeFactory::register_fn`]:
/// `node_ctor!(registry, type_name, closure)` expands to
/// `registry.register_fn(type_name, closure)`.
///
/// The expansion is a single expression of type `()` with no trailing
/// semicolon. The macro can therefore be used as a statement or in expression
/// position without tripping the `semicolon_in_expressions_from_macros` lint.
/// A trailing comma after the closure is accepted.
///
/// # Example
///
/// ```rust
/// use rill_graph::{node_ctor, NodeFactory};
///
/// // Inside a function that has access to a &mut NodeFactory<f32, 64>:
/// fn register(registry: &mut NodeFactory<f32, 64>) {
///     node_ctor!(registry, "test/my_source", |_id, _params| {
///         // construct and return a NodeVariant
///         todo!()
///     });
/// }
/// ```
#[macro_export]
macro_rules! node_ctor {
    ($registry:expr, $type_name:expr, $ctor:expr $(,)?) => {
        $registry.register_fn($type_name, $ctor)
    };
}
235
// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    use rill_core::time::ClockTick;
    use rill_core::traits::node::NodeState;
    use rill_core::traits::port::Port;
    use rill_core::traits::NodeCategory;
    use rill_core::traits::Processor;
    use rill_core::traits::Source;
    use rill_core::traits::{ParamValue, ProcessResult};

    // ── Test helpers ────────────────────────────────────────────────

    /// Minimal node with one output port that implements both `Source` and
    /// `Processor`. Its metadata name/category are plain fields, so one type
    /// can stand in for several registered node kinds.
    struct TestSource<T: Transcendental, const B: usize> {
        id: NodeId,
        state: NodeState<T, B>,
        output: Port<T, B>,
        meta_name: &'static str,
        meta_cat: NodeCategory,
    }

    impl<T: Transcendental, const B: usize> TestSource<T, B> {
        fn new() -> Self {
            Self {
                id: NodeId(0),
                state: NodeState::new(44100.0),
                output: Port::output(NodeId(0), 0, "out"),
                meta_name: "TestSource",
                meta_cat: NodeCategory::Source,
            }
        }

        /// Sets the id and stores `sample_rate` in the node state.
        fn set_id_and_init(&mut self, id: NodeId, sample_rate: f32) {
            self.id = id;
            self.state.sample_rate = sample_rate;
        }
    }

    impl<T: Transcendental, const B: usize> Node<T, B> for TestSource<T, B> {
        fn metadata(&self) -> rill_core::traits::NodeMetadata {
            rill_core::traits::NodeMetadata::new(self.meta_name, self.meta_cat)
        }
        fn init(&mut self, sample_rate: f32) {
            self.state.sample_rate = sample_rate;
        }
        fn reset(&mut self) {}
        fn get_parameter(
            &self,
            _: &rill_core::traits::ParameterId,
        ) -> Option<rill_core::traits::ParamValue> {
            None
        }
        fn set_parameter(
            &mut self,
            _: &rill_core::traits::ParameterId,
            _: rill_core::traits::ParamValue,
        ) -> ProcessResult<()> {
            Ok(())
        }
        fn id(&self) -> NodeId {
            self.id
        }
        fn set_id(&mut self, id: NodeId) {
            self.id = id;
        }
        fn input_port(&self, _: usize) -> Option<&Port<T, B>> {
            None
        }
        fn input_port_mut(&mut self, _: usize) -> Option<&mut Port<T, B>> {
            None
        }
        fn output_port(&self, index: usize) -> Option<&Port<T, B>> {
            if index == 0 {
                Some(&self.output)
            } else {
                None
            }
        }
        fn output_port_mut(&mut self, index: usize) -> Option<&mut Port<T, B>> {
            if index == 0 {
                Some(&mut self.output)
            } else {
                None
            }
        }
        fn control_port(&self, _: usize) -> Option<&Port<T, B>> {
            None
        }
        fn control_port_mut(&mut self, _: usize) -> Option<&mut Port<T, B>> {
            None
        }
        fn state(&self) -> &NodeState<T, B> {
            &self.state
        }
        fn state_mut(&mut self) -> &mut NodeState<T, B> {
            &mut self.state
        }
    }

    impl<T: Transcendental, const B: usize> Source<T, B> for TestSource<T, B> {
        fn generate(&mut self, _: &ClockTick, _: &[T], _: &[ClockTick]) -> ProcessResult<()> {
            Ok(())
        }
    }

    impl<T: Transcendental, const B: usize> Processor<T, B> for TestSource<T, B> {
        fn process(
            &mut self,
            _: &ClockTick,
            _: &[&[T; B]],
            _: &[T],
            _: &[ClockTick],
            _: &[&[T; B]],
        ) -> ProcessResult<()> {
            Ok(())
        }
        fn latency(&self) -> usize {
            0
        }
    }

    /// Struct-based constructor producing a `Source` variant named "TestSource".
    struct TestSourceCtor;
    impl<T: Transcendental, const B: usize> NodeConstructor<T, B> for TestSourceCtor {
        fn type_name(&self) -> &'static str {
            "test/source"
        }
        fn construct(&self, id: NodeId, params: &Params) -> NodeVariant<T, B> {
            let mut node = TestSource::<T, B>::new();
            node.set_id_and_init(id, params.sample_rate);
            NodeVariant::Source(Box::new(node))
        }
        fn clone_box(&self) -> Box<dyn NodeConstructor<T, B>> {
            Box::new(Self)
        }
    }

    /// Struct-based constructor producing a `Processor` variant named "Noop".
    struct TestProcessorCtor;
    impl<T: Transcendental, const B: usize> NodeConstructor<T, B> for TestProcessorCtor {
        fn type_name(&self) -> &'static str {
            "test/processor"
        }
        fn construct(&self, id: NodeId, params: &Params) -> NodeVariant<T, B> {
            let mut node = TestSource::<T, B>::new();
            node.meta_name = "Noop";
            node.meta_cat = NodeCategory::Processor;
            node.set_id_and_init(id, params.sample_rate);
            NodeVariant::Processor(Box::new(node))
        }
        fn clone_box(&self) -> Box<dyn NodeConstructor<T, B>> {
            Box::new(Self)
        }
    }

    // ── Tests ───────────────────────────────────────────────────────

    #[test]
    fn test_registry_empty() {
        let registry = NodeFactory::<f32, 64>::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_registry_register_and_construct() {
        let mut registry = NodeFactory::<f32, 64>::new();
        registry.register(TestSourceCtor);

        assert!(registry.contains("test/source"));
        assert_eq!(registry.len(), 1);

        let params = Params::new(48000.0);
        let variant = registry
            .construct("test/source", NodeId(42), &params)
            .expect("should construct");

        assert!(
            matches!(variant, NodeVariant::Source(_)),
            "expected Source variant"
        );

        // The constructed node reports the metadata of the registered type.
        assert_eq!(variant.metadata().name, "TestSource");
    }

    #[test]
    fn test_registry_unknown_type() {
        let registry = NodeFactory::<f32, 64>::new();
        let params = Params::new(44100.0);
        match registry.construct("nonexistent", NodeId(0), &params) {
            Err(RegistryError::UnknownType(name)) => assert_eq!(name, "nonexistent"),
            _ => panic!("expected UnknownType error"),
        }
    }

    #[test]
    fn test_registry_register_fn() {
        let mut registry = NodeFactory::<f32, 64>::new();
        registry.register_fn("test/fn_ctor", |id, params| {
            let mut node = TestSource::<f32, 64>::new();
            node.set_id(id);
            node.init(params.sample_rate);
            NodeVariant::Source(Box::new(node))
        });

        assert!(registry.contains("test/fn_ctor"));
        let params = Params::new(44100.0);
        let variant = registry
            .construct("test/fn_ctor", NodeId(1), &params)
            .expect("should construct from fn");
        assert!(
            matches!(variant, NodeVariant::Source(_)),
            "expected Source variant"
        );
    }

    #[test]
    fn test_registry_list_types() {
        let mut registry = NodeFactory::<f32, 64>::new();
        registry.register(TestSourceCtor);
        registry.register(TestProcessorCtor);

        let mut types = registry.list_types();
        types.sort();
        assert_eq!(types, vec!["test/processor", "test/source"]);
    }

    #[test]
    fn test_registry_replace() {
        let mut registry = NodeFactory::<f32, 64>::new();
        registry.register(TestSourceCtor);
        assert_eq!(registry.len(), 1);

        // Registering again under the same name replaces rather than adds.
        registry.register(TestSourceCtor);
        assert_eq!(registry.len(), 1);

        // A closure registered under the same name replaces the struct ctor.
        registry.register_fn("test/source", |id, params| {
            let mut node = TestSource::<f32, 64>::new();
            node.meta_name = "Replacement";
            node.meta_cat = NodeCategory::Processor;
            node.set_id_and_init(id, params.sample_rate);
            NodeVariant::Processor(Box::new(node))
        });
        assert_eq!(registry.len(), 1);

        let variant = registry
            .construct("test/source", NodeId(7), &Params::new(44100.0))
            .expect("replacement should construct");
        assert!(
            matches!(variant, NodeVariant::Processor(_)),
            "expected the replacement Processor variant"
        );
        assert_eq!(variant.metadata().name, "Replacement");
    }

    #[test]
    fn test_registry_metadata() {
        let mut registry = NodeFactory::<f32, 64>::new();
        registry.register(TestSourceCtor);

        let meta = registry
            .metadata("test/source")
            .expect("metadata for registered type");
        assert_eq!(meta.name, "TestSource");
    }

    #[test]
    fn test_registry_metadata_unknown_type() {
        let registry = NodeFactory::<f32, 64>::new();
        assert!(registry.metadata("nonexistent").is_none());
    }

    #[test]
    fn test_registry_clone_keeps_all_constructors() {
        let mut registry = NodeFactory::<f32, 64>::new();
        registry.register(TestSourceCtor);
        registry.register_fn("test/fn_ctor", |id, params| {
            let mut node = TestSource::<f32, 64>::new();
            node.set_id(id);
            node.init(params.sample_rate);
            NodeVariant::Source(Box::new(node))
        });

        let cloned = registry.clone();
        assert_eq!(cloned.len(), registry.len());

        let mut types = cloned.list_types();
        types.sort();
        assert_eq!(types, vec!["test/fn_ctor", "test/source"]);

        let params = Params::new(48000.0);
        for name in ["test/fn_ctor", "test/source"] {
            let variant = cloned
                .construct(name, NodeId(3), &params)
                .expect("cloned registry should construct");
            assert!(
                matches!(variant, NodeVariant::Source(_)),
                "expected Source variant"
            );
        }
    }

    #[test]
    fn test_node_ctor_macro_registers() {
        let mut registry = NodeFactory::<f32, 64>::new();
        node_ctor!(registry, "test/macro", |id, params| {
            let mut node = TestSource::<f32, 64>::new();
            node.set_id_and_init(id, params.sample_rate);
            NodeVariant::Source(Box::new(node))
        });

        assert!(registry.contains("test/macro"));
        let variant = registry
            .construct("test/macro", NodeId(5), &Params::new(44100.0))
            .expect("macro-registered ctor should construct");
        assert!(
            matches!(variant, NodeVariant::Source(_)),
            "expected Source variant"
        );
    }

    #[test]
    fn test_construct_with_params() {
        let mut registry = NodeFactory::<f32, 64>::new();
        registry.register_fn("test/with_params", |id, params| {
            let freq = params.get_f32("frequency", 440.0);
            assert_eq!(freq, 220.0);
            let amp = params.get_f32("amplitude", 0.5);
            assert_eq!(amp, 0.8);

            let mut node = TestSource::<f32, 64>::new();
            node.set_id(id);
            node.init(params.sample_rate);
            NodeVariant::Source(Box::new(node))
        });

        let params = Params::new(44100.0)
            .with("frequency", ParamValue::Float(220.0))
            .with("amplitude", ParamValue::Float(0.8));
        registry
            .construct("test/with_params", NodeId(0), &params)
            .expect("should construct with params");
    }
}