Skip to main content

rill_graph/
registry.rs

1use rill_core::math::Transcendental;
2use rill_core::traits::{NodeId, NodeMetadata, NodeParams, NodeVariant, SignalNode};
3use std::collections::HashMap;
4
5// ============================================================================
6// Registry Error
7// ============================================================================
8
/// Errors that can occur during node construction via the registry.
///
/// Derives `PartialEq`/`Eq` so callers (and tests) can compare errors
/// directly instead of destructuring them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RegistryError {
    /// No constructor registered for the given type name.
    ///
    /// Carries the type name that was looked up.
    UnknownType(String),
}

impl std::fmt::Display for RegistryError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnknownType(name) => write!(f, "unknown node type: {name}"),
        }
    }
}

impl std::error::Error for RegistryError {}
25
26// ============================================================================
27// NodeConstructor Trait
28// ============================================================================
29
30/// Factory trait for creating graph nodes by type name.
31///
32/// Each node type that wants to be constructable via the registry
33/// implements this trait. The [`construct`](Self::construct) method
34/// receives a [`NodeId`] and [`NodeParams`] and must return the
35/// appropriate [`NodeVariant`].
36pub trait NodeConstructor<T: Transcendental, const BUF_SIZE: usize>: Send + Sync {
37    /// Canonical name for this node type (e.g. `"rill/sine_osc"`).
38    fn type_name(&self) -> &'static str;
39
40    /// Build a fully initialised node variant.
41    ///
42    /// Implementations should:
43    /// 1. Extract parameters from `params`.
44    /// 2. Create the concrete node.
45    /// 3. Call [`SignalNode::set_id`] with the given `id`.
46    /// 4. Call [`SignalNode::init`] with `params.sample_rate`.
47    /// 5. Wrap in the correct [`NodeVariant`] variant.
48    fn construct(&self, id: NodeId, params: &NodeParams) -> NodeVariant<T, BUF_SIZE>;
49}
50
51// ============================================================================
52// NodeRegistry
53// ============================================================================
54
55/// A registry of named node constructors.
56///
57/// Register constructors with [`register`](Self::register), then create
58/// nodes by type name with [`construct`](Self::construct).
59///
60/// # Type parameters
61///
62/// - `T` — sample type (typically `f32`)
63/// - `BUF_SIZE` — block size (must match the target graph)
64pub struct NodeRegistry<T: Transcendental, const BUF_SIZE: usize> {
65    entries: HashMap<&'static str, Box<dyn NodeConstructor<T, BUF_SIZE>>>,
66}
67
68impl<T: Transcendental, const BUF_SIZE: usize> Default for NodeRegistry<T, BUF_SIZE> {
69    fn default() -> Self {
70        Self::new()
71    }
72}
73
74impl<T: Transcendental, const BUF_SIZE: usize> NodeRegistry<T, BUF_SIZE> {
75    /// Create an empty registry.
76    pub fn new() -> Self {
77        Self {
78            entries: HashMap::new(),
79        }
80    }
81
82    /// Register a node constructor.
83    ///
84    /// The constructor's [`type_name`](NodeConstructor::type_name) is used
85    /// as the lookup key. If a constructor with the same name already exists,
86    /// it is replaced.
87    pub fn register(&mut self, ctor: impl NodeConstructor<T, BUF_SIZE> + 'static) {
88        let name = ctor.type_name();
89        self.entries.insert(name, Box::new(ctor));
90    }
91
92    /// Register a node type via a closure.
93    ///
94    /// This is a convenience wrapper around [`register`](Self::register) for
95    /// cases where a full struct + trait impl is not needed.
96    pub fn register_fn(
97        &mut self,
98        type_name: &'static str,
99        f: impl Fn(NodeId, &NodeParams) -> NodeVariant<T, BUF_SIZE> + Send + Sync + 'static,
100    ) {
101        self.entries.insert(
102            type_name,
103            Box::new(ClosureCtor {
104                type_name,
105                f: Box::new(f),
106            }),
107        );
108    }
109
110    /// Construct a node by type name.
111    ///
112    /// Returns [`RegistryError::UnknownType`] if the name has not been
113    /// registered.
114    pub fn construct(
115        &self,
116        type_name: &str,
117        id: NodeId,
118        params: &NodeParams,
119    ) -> Result<NodeVariant<T, BUF_SIZE>, RegistryError> {
120        self.entries
121            .get(type_name)
122            .ok_or_else(|| RegistryError::UnknownType(type_name.to_string()))
123            .map(|ctor| ctor.construct(id, params))
124    }
125
126    /// Check whether a type name is registered.
127    pub fn contains(&self, type_name: &str) -> bool {
128        self.entries.contains_key(type_name)
129    }
130
131    /// List all registered type names.
132    pub fn list_types(&self) -> Vec<&'static str> {
133        self.entries.keys().copied().collect()
134    }
135
136    /// Number of registered constructors.
137    pub fn len(&self) -> usize {
138        self.entries.len()
139    }
140
141    /// True when no constructors are registered.
142    pub fn is_empty(&self) -> bool {
143        self.entries.is_empty()
144    }
145
146    /// Get metadata for a registered type without constructing a node.
147    ///
148    /// This requires constructing a temporary node and immediately
149    /// discarding it. If performance is a concern, cache the metadata
150    /// alongside the constructor in the registry.
151    pub fn metadata(&self, type_name: &str) -> Option<NodeMetadata> {
152        self.entries.get(type_name).map(|ctor| {
153            let dummy = NodeParams::new(44100.0);
154            let variant = ctor.construct(NodeId(u32::MAX), &dummy);
155            variant.metadata()
156        })
157    }
158}
159
160// ============================================================================
161// Internal: closure-based constructor wrapper
162// ============================================================================
163
164#[allow(clippy::type_complexity)]
165struct ClosureCtor<T: Transcendental, const BUF_SIZE: usize> {
166    type_name: &'static str,
167    f: Box<dyn Fn(NodeId, &NodeParams) -> NodeVariant<T, BUF_SIZE> + Send + Sync>,
168}
169
170impl<T: Transcendental, const BUF_SIZE: usize> NodeConstructor<T, BUF_SIZE>
171    for ClosureCtor<T, BUF_SIZE>
172{
173    fn type_name(&self) -> &'static str {
174        self.type_name
175    }
176
177    fn construct(&self, id: NodeId, params: &NodeParams) -> NodeVariant<T, BUF_SIZE> {
178        (self.f)(id, params)
179    }
180}
181
182// ============================================================================
183// Node Ctor Macro
184// ============================================================================
185
/// Register a node constructor closure under a type name.
///
/// Shorthand for [`NodeRegistry::register_fn`]:
/// `node_ctor!(registry, type_name, closure)` expands to
/// `registry.register_fn(type_name, closure)`.
///
/// The expansion is a single expression with no trailing semicolon. This
/// means the macro works both as a statement and in expression position,
/// without triggering the `semicolon_in_expressions_from_macros` lint. A
/// trailing comma after the closure is accepted.
///
/// # Example
///
/// ```rust
/// use rill_graph::{node_ctor, NodeRegistry};
/// use rill_core::traits::{NodeId, NodeParams, NodeVariant, Source, SignalNode};
///
/// // Inside a function that has access to a &mut NodeRegistry<f32, 64>:
/// fn register(registry: &mut NodeRegistry<f32, 64>) {
///     node_ctor!(registry, "test/my_source", |id, params| {
///         // construct and return NodeVariant
///         todo!()
///     });
/// }
/// ```
#[macro_export]
macro_rules! node_ctor {
    ($registry:expr, $type_name:expr, $ctor:expr $(,)?) => {
        $registry.register_fn($type_name, $ctor)
    };
}
211
212// ============================================================================
213// Tests
214// ============================================================================
215
#[cfg(test)]
mod tests {
    use super::*;
    use rill_core::time::ClockTick;
    use rill_core::traits::node::NodeState;
    use rill_core::traits::port::Port;
    use rill_core::traits::NodeCategory;
    use rill_core::traits::Processor;
    use rill_core::traits::Source;
    use rill_core::traits::{ParamValue, ProcessResult};

    // ── Test helpers ────────────────────────────────────────────────

    /// Test double implementing both `Source` and `Processor`, so it can be
    /// wrapped in either `NodeVariant::Source` or `NodeVariant::Processor`.
    /// Its reported metadata name/category are configurable per test.
    struct TestSource<T: Transcendental, const B: usize> {
        id: NodeId,
        state: NodeState<T, B>,
        output: Port<T, B>,
        meta_name: &'static str,
        meta_cat: NodeCategory,
    }

    impl<T: Transcendental, const B: usize> TestSource<T, B> {
        fn new() -> Self {
            Self {
                id: NodeId(0),
                state: NodeState::new(44100.0),
                output: Port::output(NodeId(0), 0, "out"),
                meta_name: "TestSource",
                meta_cat: NodeCategory::Source,
            }
        }

        /// Assigns `id` and stores `sample_rate` directly, without going
        /// through the `SignalNode` trait methods.
        fn set_id_and_init(&mut self, id: NodeId, sample_rate: f32) {
            self.id = id;
            self.state.sample_rate = sample_rate;
        }
    }

    impl<T: Transcendental, const B: usize> SignalNode<T, B> for TestSource<T, B> {
        fn metadata(&self) -> rill_core::traits::NodeMetadata {
            rill_core::traits::NodeMetadata::new(self.meta_name, self.meta_cat)
        }
        fn init(&mut self, sample_rate: f32) {
            self.state.sample_rate = sample_rate;
        }
        fn reset(&mut self) {}
        fn get_parameter(
            &self,
            _: &rill_core::traits::ParameterId,
        ) -> Option<rill_core::traits::ParamValue> {
            None
        }
        fn set_parameter(
            &mut self,
            _: &rill_core::traits::ParameterId,
            _: rill_core::traits::ParamValue,
        ) -> ProcessResult<()> {
            Ok(())
        }
        fn id(&self) -> NodeId {
            self.id
        }
        fn set_id(&mut self, id: NodeId) {
            self.id = id;
        }
        fn input_port(&self, _: usize) -> Option<&Port<T, B>> {
            None
        }
        fn input_port_mut(&mut self, _: usize) -> Option<&mut Port<T, B>> {
            None
        }
        fn output_port(&self, index: usize) -> Option<&Port<T, B>> {
            if index == 0 {
                Some(&self.output)
            } else {
                None
            }
        }
        fn output_port_mut(&mut self, index: usize) -> Option<&mut Port<T, B>> {
            if index == 0 {
                Some(&mut self.output)
            } else {
                None
            }
        }
        fn control_port(&self, _: usize) -> Option<&Port<T, B>> {
            None
        }
        fn control_port_mut(&mut self, _: usize) -> Option<&mut Port<T, B>> {
            None
        }
        fn state(&self) -> &NodeState<T, B> {
            &self.state
        }
        fn state_mut(&mut self) -> &mut NodeState<T, B> {
            &mut self.state
        }
    }

    impl<T: Transcendental, const B: usize> Source<T, B> for TestSource<T, B> {
        fn generate(&mut self, _: &ClockTick, _: &[T], _: &[ClockTick]) -> ProcessResult<()> {
            Ok(())
        }
    }

    impl<T: Transcendental, const B: usize> Processor<T, B> for TestSource<T, B> {
        fn process(
            &mut self,
            _: &ClockTick,
            _: &[&[T; B]],
            _: &[T],
            _: &[ClockTick],
            _: &[&[T; B]],
        ) -> ProcessResult<()> {
            Ok(())
        }
        fn latency(&self) -> usize {
            0
        }
    }

    /// Constructor registered as `"test/source"`; yields a `Source` variant.
    struct TestSourceCtor;
    impl<T: Transcendental, const B: usize> NodeConstructor<T, B> for TestSourceCtor {
        fn type_name(&self) -> &'static str {
            "test/source"
        }
        fn construct(&self, id: NodeId, params: &NodeParams) -> NodeVariant<T, B> {
            let mut node = TestSource::<T, B>::new();
            node.set_id_and_init(id, params.sample_rate);
            NodeVariant::Source(Box::new(node))
        }
    }

    /// Constructor registered as `"test/processor"`; yields a `Processor`
    /// variant whose metadata is named `"Noop"`.
    struct TestProcessorCtor;
    impl<T: Transcendental, const B: usize> NodeConstructor<T, B> for TestProcessorCtor {
        fn type_name(&self) -> &'static str {
            "test/processor"
        }
        fn construct(&self, id: NodeId, params: &NodeParams) -> NodeVariant<T, B> {
            let mut node = TestSource::<T, B>::new();
            node.meta_name = "Noop";
            node.meta_cat = NodeCategory::Processor;
            node.set_id_and_init(id, params.sample_rate);
            NodeVariant::Processor(Box::new(node))
        }
    }

    // ── Tests ───────────────────────────────────────────────────────

    #[test]
    fn test_registry_empty() {
        let registry = NodeRegistry::<f32, 64>::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert!(!registry.contains("test/source"));
        assert!(registry.list_types().is_empty());
    }

    #[test]
    fn test_registry_register_and_construct() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register(TestSourceCtor);

        assert!(registry.contains("test/source"));
        assert_eq!(registry.len(), 1);

        let params = NodeParams::new(48000.0);
        let variant = registry
            .construct("test/source", NodeId(42), &params)
            .expect("should construct");

        match &variant {
            NodeVariant::Source(_) => {}
            _ => panic!("expected Source variant"),
        }

        // The node built by `TestSourceCtor` is the one returned. Only its
        // metadata is checked here; the stored sample rate is not observed.
        assert_eq!(variant.metadata().name, "TestSource");
    }

    #[test]
    fn test_registry_unknown_type() {
        let registry = NodeRegistry::<f32, 64>::new();
        let params = NodeParams::new(44100.0);
        let result = registry.construct("nonexistent", NodeId(0), &params);
        assert!(result.is_err());
        match result {
            Err(RegistryError::UnknownType(name)) => assert_eq!(name, "nonexistent"),
            _ => panic!("expected UnknownType error"),
        }
    }

    #[test]
    fn test_registry_error_display() {
        let err = RegistryError::UnknownType("nonexistent".to_string());
        assert_eq!(err.to_string(), "unknown node type: nonexistent");
    }

    #[test]
    fn test_registry_register_fn() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register_fn("test/fn_ctor", |id, params| {
            let mut node = TestSource::<f32, 64>::new();
            node.set_id(id);
            node.init(params.sample_rate);
            NodeVariant::Source(Box::new(node))
        });

        assert!(registry.contains("test/fn_ctor"));
        let params = NodeParams::new(44100.0);
        let variant = registry
            .construct("test/fn_ctor", NodeId(1), &params)
            .expect("should construct from fn");
        match variant {
            NodeVariant::Source(_) => {}
            _ => panic!("expected Source variant"),
        }
    }

    #[test]
    fn test_registry_list_types() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register(TestSourceCtor);
        registry.register(TestProcessorCtor);

        let mut types = registry.list_types();
        types.sort();
        assert_eq!(types, vec!["test/processor", "test/source"]);
    }

    #[test]
    fn test_registry_replace() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register(TestSourceCtor);
        assert_eq!(registry.len(), 1);

        // Registering under an existing name must replace the old constructor,
        // not add a second entry. Use a distinguishable constructor so the
        // replacement is actually observable.
        registry.register_fn("test/source", |id, params| {
            let mut node = TestSource::<f32, 64>::new();
            node.meta_name = "Replacement";
            node.set_id_and_init(id, params.sample_rate);
            NodeVariant::Source(Box::new(node))
        });
        assert_eq!(registry.len(), 1);

        let params = NodeParams::new(44100.0);
        let variant = registry
            .construct("test/source", NodeId(7), &params)
            .expect("replacement should construct");
        assert_eq!(variant.metadata().name, "Replacement");
    }

    #[test]
    fn test_registry_metadata() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register(TestSourceCtor);

        let meta = registry.metadata("test/source");
        assert!(meta.is_some());
        assert_eq!(meta.unwrap().name, "TestSource");

        // Unregistered names yield `None`.
        assert!(registry.metadata("nonexistent").is_none());
    }

    #[test]
    fn test_construct_with_params() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register_fn("test/with_params", |id, params| {
            let freq = params.get_f32("frequency", 440.0);
            assert_eq!(freq, 220.0);
            let amp = params.get_f32("amplitude", 0.5);
            assert_eq!(amp, 0.8);

            let mut node = TestSource::<f32, 64>::new();
            node.set_id(id);
            node.init(params.sample_rate);
            NodeVariant::Source(Box::new(node))
        });

        let params = NodeParams::new(44100.0)
            .with("frequency", ParamValue::Float(220.0))
            .with("amplitude", ParamValue::Float(0.8));
        let result = registry.construct("test/with_params", NodeId(0), &params);
        assert!(result.is_ok());
    }
}