Skip to main content

rill_graph/
registry.rs

1use rill_core::math::Transcendental;
2use rill_core::traits::{SignalNode, NodeId, NodeMetadata, NodeParams, NodeVariant};
3use std::collections::HashMap;
4
5// ============================================================================
6// Registry Error
7// ============================================================================
8
/// Errors that can occur during node construction via the registry.
///
/// The type implements [`std::fmt::Display`] and [`std::error::Error`], so it
/// can be propagated with `?` into boxed error types. It also derives
/// `PartialEq`/`Eq` so that callers and tests can compare errors directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RegistryError {
    /// No constructor registered for the given type name.
    ///
    /// The payload is the type name that was looked up.
    UnknownType(String),
}

impl std::fmt::Display for RegistryError {
    /// Write a human-readable, single-line description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Deliberately no catch-all arm: adding a variant must force this
        // match to be revisited.
        match self {
            Self::UnknownType(name) => write!(f, "unknown node type: {name}"),
        }
    }
}

impl std::error::Error for RegistryError {}
25
26// ============================================================================
27// NodeConstructor Trait
28// ============================================================================
29
30/// Factory trait for creating graph nodes by type name.
31///
32/// Each node type that wants to be constructable via the registry
33/// implements this trait. The [`construct`](Self::construct) method
34/// receives a [`NodeId`] and [`NodeParams`] and must return the
35/// appropriate [`NodeVariant`].
36pub trait NodeConstructor<T: Transcendental, const BUF_SIZE: usize>: Send + Sync {
37    /// Canonical name for this node type (e.g. `"rill/sine_osc"`).
38    fn type_name(&self) -> &'static str;
39
40    /// Build a fully initialised node variant.
41    ///
42    /// Implementations should:
43    /// 1. Extract parameters from `params`.
44    /// 2. Create the concrete node.
45    /// 3. Call [`SignalNode::set_id`] with the given `id`.
46    /// 4. Call [`SignalNode::init`] with `params.sample_rate`.
47    /// 5. Wrap in the correct [`NodeVariant`] variant.
48    fn construct(&self, id: NodeId, params: &NodeParams) -> NodeVariant<T, BUF_SIZE>;
49}
50
51// ============================================================================
52// NodeRegistry
53// ============================================================================
54
55/// A registry of named node constructors.
56///
57/// Register constructors with [`register`](Self::register), then create
58/// nodes by type name with [`construct`](Self::construct).
59///
60/// # Type parameters
61///
62/// - `T` — sample type (typically `f32`)
63/// - `BUF_SIZE` — block size (must match the target graph)
64pub struct NodeRegistry<T: Transcendental, const BUF_SIZE: usize> {
65    entries: HashMap<&'static str, Box<dyn NodeConstructor<T, BUF_SIZE>>>,
66}
67
68impl<T: Transcendental, const BUF_SIZE: usize> Default for NodeRegistry<T, BUF_SIZE> {
69    fn default() -> Self {
70        Self::new()
71    }
72}
73
74impl<T: Transcendental, const BUF_SIZE: usize> NodeRegistry<T, BUF_SIZE> {
75    /// Create an empty registry.
76    pub fn new() -> Self {
77        Self {
78            entries: HashMap::new(),
79        }
80    }
81
82    /// Register a node constructor.
83    ///
84    /// The constructor's [`type_name`](NodeConstructor::type_name) is used
85    /// as the lookup key. If a constructor with the same name already exists,
86    /// it is replaced.
87    pub fn register(&mut self, ctor: impl NodeConstructor<T, BUF_SIZE> + 'static) {
88        let name = ctor.type_name();
89        self.entries.insert(name, Box::new(ctor));
90    }
91
92    /// Register a node type via a closure.
93    ///
94    /// This is a convenience wrapper around [`register`](Self::register) for
95    /// cases where a full struct + trait impl is not needed.
96    pub fn register_fn(
97        &mut self,
98        type_name: &'static str,
99        f: impl Fn(NodeId, &NodeParams) -> NodeVariant<T, BUF_SIZE> + Send + Sync + 'static,
100    ) {
101        self.entries.insert(
102            type_name,
103            Box::new(ClosureCtor {
104                type_name,
105                f: Box::new(f),
106            }),
107        );
108    }
109
110    /// Construct a node by type name.
111    ///
112    /// Returns [`RegistryError::UnknownType`] if the name has not been
113    /// registered.
114    pub fn construct(
115        &self,
116        type_name: &str,
117        id: NodeId,
118        params: &NodeParams,
119    ) -> Result<NodeVariant<T, BUF_SIZE>, RegistryError> {
120        self.entries
121            .get(type_name)
122            .ok_or_else(|| RegistryError::UnknownType(type_name.to_string()))
123            .map(|ctor| ctor.construct(id, params))
124    }
125
126    /// Check whether a type name is registered.
127    pub fn contains(&self, type_name: &str) -> bool {
128        self.entries.contains_key(type_name)
129    }
130
131    /// List all registered type names.
132    pub fn list_types(&self) -> Vec<&'static str> {
133        self.entries.keys().copied().collect()
134    }
135
136    /// Number of registered constructors.
137    pub fn len(&self) -> usize {
138        self.entries.len()
139    }
140
141    /// True when no constructors are registered.
142    pub fn is_empty(&self) -> bool {
143        self.entries.is_empty()
144    }
145
146    /// Get metadata for a registered type without constructing a node.
147    ///
148    /// This requires constructing a temporary node and immediately
149    /// discarding it. If performance is a concern, cache the metadata
150    /// alongside the constructor in the registry.
151    pub fn metadata(&self, type_name: &str) -> Option<NodeMetadata> {
152        self.entries.get(type_name).map(|ctor| {
153            let dummy = NodeParams::new(44100.0);
154            let variant = ctor.construct(NodeId(u32::MAX), &dummy);
155            variant.metadata()
156        })
157    }
158}
159
160// ============================================================================
161// Internal: closure-based constructor wrapper
162// ============================================================================
163
164struct ClosureCtor<T: Transcendental, const BUF_SIZE: usize> {
165    type_name: &'static str,
166    f: Box<dyn Fn(NodeId, &NodeParams) -> NodeVariant<T, BUF_SIZE> + Send + Sync>,
167}
168
169impl<T: Transcendental, const BUF_SIZE: usize> NodeConstructor<T, BUF_SIZE>
170    for ClosureCtor<T, BUF_SIZE>
171{
172    fn type_name(&self) -> &'static str {
173        self.type_name
174    }
175
176    fn construct(&self, id: NodeId, params: &NodeParams) -> NodeVariant<T, BUF_SIZE> {
177        (self.f)(id, params)
178    }
179}
180
181// ============================================================================
182// Node Ctor Macro
183// ============================================================================
184
/// Register a node constructor by type name.
///
/// Shorthand for [`NodeRegistry::register_fn`]. Expands to the call
/// `registry.register_fn(type_name, closure)` *without* a trailing
/// semicolon, so the macro can be used both as a statement
/// (`node_ctor!(..);`) and in expression position. A trailing comma after
/// the closure is accepted.
///
/// # Example
///
/// ```rust
/// use rill_graph::{node_ctor, NodeRegistry};
/// use rill_core::traits::{NodeId, NodeParams, NodeVariant, Source, SignalNode};
///
/// // Inside a function that has access to a &mut NodeRegistry<f32, 64>:
/// fn register(registry: &mut NodeRegistry<f32, 64>) {
///     node_ctor!(registry, "test/my_source", |id, params| {
///         // construct and return NodeVariant
///         todo!()
///     });
/// }
/// ```
#[macro_export]
macro_rules! node_ctor {
    ($registry:expr, $type_name:expr, $ctor:expr $(,)?) => {
        $registry.register_fn($type_name, $ctor)
    };
}
210
211// ============================================================================
212// Tests
213// ============================================================================
214
#[cfg(test)]
mod tests {
    use super::*;
    use rill_core::time::ClockTick;
    use rill_core::traits::node::NodeState;
    use rill_core::traits::port::Port;
    use rill_core::traits::{
        NodeCategory, ParamValue, ParameterId, ProcessResult, Processor, Source,
    };

    // ── Test helpers ────────────────────────────────────────────────

    /// Minimal node with a single output port and no behaviour.
    ///
    /// It implements both `Source` and `Processor` so the same type can back
    /// either `NodeVariant` in the tests; `meta_name` / `meta_cat` control
    /// what `metadata()` reports.
    struct TestSource<T: Transcendental, const B: usize> {
        id: NodeId,
        state: NodeState<T, B>,
        output: Port<T, B>,
        meta_name: &'static str,
        meta_cat: NodeCategory,
    }

    impl<T: Transcendental, const B: usize> TestSource<T, B> {
        /// A node with id 0 that reports itself as a source named "TestSource".
        fn new() -> Self {
            Self {
                id: NodeId(0),
                state: NodeState::new(44100.0),
                output: Port::output(NodeId(0), 0, "out"),
                meta_name: "TestSource",
                meta_cat: NodeCategory::Source,
            }
        }

        /// Build a node through `SignalNode::set_id` and `SignalNode::init`,
        /// the steps the `NodeConstructor::construct` docs prescribe.
        fn configured(id: NodeId, params: &NodeParams) -> Self {
            let mut node = Self::new();
            node.set_id(id);
            node.init(params.sample_rate);
            node
        }
    }

    impl<T: Transcendental, const B: usize> SignalNode<T, B> for TestSource<T, B> {
        fn metadata(&self) -> NodeMetadata {
            NodeMetadata::new(self.meta_name, self.meta_cat)
        }

        fn init(&mut self, sample_rate: f32) {
            self.state.sample_rate = sample_rate;
        }

        fn reset(&mut self) {}

        fn get_parameter(&self, _: &ParameterId) -> Option<ParamValue> {
            None
        }

        fn set_parameter(&mut self, _: &ParameterId, _: ParamValue) -> ProcessResult<()> {
            Ok(())
        }

        fn id(&self) -> NodeId {
            self.id
        }

        fn set_id(&mut self, id: NodeId) {
            self.id = id;
        }

        fn input_port(&self, _: usize) -> Option<&Port<T, B>> {
            None
        }

        fn input_port_mut(&mut self, _: usize) -> Option<&mut Port<T, B>> {
            None
        }

        fn output_port(&self, index: usize) -> Option<&Port<T, B>> {
            if index == 0 { Some(&self.output) } else { None }
        }

        fn output_port_mut(&mut self, index: usize) -> Option<&mut Port<T, B>> {
            if index == 0 { Some(&mut self.output) } else { None }
        }

        fn control_port(&self, _: usize) -> Option<&Port<T, B>> {
            None
        }

        fn control_port_mut(&mut self, _: usize) -> Option<&mut Port<T, B>> {
            None
        }

        fn state(&self) -> &NodeState<T, B> {
            &self.state
        }

        fn state_mut(&mut self) -> &mut NodeState<T, B> {
            &mut self.state
        }
    }

    impl<T: Transcendental, const B: usize> Source<T, B> for TestSource<T, B> {
        fn generate(&mut self, _: &ClockTick, _: &[T], _: &[ClockTick]) -> ProcessResult<()> {
            Ok(())
        }
    }

    impl<T: Transcendental, const B: usize> Processor<T, B> for TestSource<T, B> {
        fn process(
            &mut self,
            _: &ClockTick,
            _: &[&[T; B]],
            _: &[T],
            _: &[ClockTick],
            _: &[&[T; B]],
        ) -> ProcessResult<()> {
            Ok(())
        }

        fn latency(&self) -> usize {
            0
        }
    }

    /// Struct-based constructor producing a `Source` variant.
    struct TestSourceCtor;

    impl<T: Transcendental, const B: usize> NodeConstructor<T, B> for TestSourceCtor {
        fn type_name(&self) -> &'static str {
            "test/source"
        }

        fn construct(&self, id: NodeId, params: &NodeParams) -> NodeVariant<T, B> {
            NodeVariant::Source(Box::new(TestSource::<T, B>::configured(id, params)))
        }
    }

    /// Struct-based constructor producing a `Processor` variant named "Noop".
    struct TestProcessorCtor;

    impl<T: Transcendental, const B: usize> NodeConstructor<T, B> for TestProcessorCtor {
        fn type_name(&self) -> &'static str {
            "test/processor"
        }

        fn construct(&self, id: NodeId, params: &NodeParams) -> NodeVariant<T, B> {
            let mut node = TestSource::<T, B>::configured(id, params);
            node.meta_name = "Noop";
            node.meta_cat = NodeCategory::Processor;
            NodeVariant::Processor(Box::new(node))
        }
    }

    // ── Tests ───────────────────────────────────────────────────────

    #[test]
    fn test_registry_empty() {
        let registry = NodeRegistry::<f32, 64>::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);

        // `Default` must agree with `new`.
        assert!(NodeRegistry::<f32, 64>::default().is_empty());
    }

    #[test]
    fn test_registry_register_and_construct() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register(TestSourceCtor);

        assert!(registry.contains("test/source"));
        assert_eq!(registry.len(), 1);

        let params = NodeParams::new(48000.0);
        let variant = registry
            .construct("test/source", NodeId(42), &params)
            .expect("should construct");

        assert!(
            matches!(variant, NodeVariant::Source(_)),
            "expected Source variant"
        );

        // The returned node must be the one built by the registered
        // constructor, identified here by its metadata name.
        assert_eq!(variant.metadata().name, "TestSource");
    }

    #[test]
    fn test_registry_unknown_type() {
        let registry = NodeRegistry::<f32, 64>::new();
        let params = NodeParams::new(44100.0);

        let err = match registry.construct("nonexistent", NodeId(0), &params) {
            Err(err) => err,
            Ok(_) => panic!("expected UnknownType error"),
        };

        assert!(matches!(&err, RegistryError::UnknownType(name) if name == "nonexistent"));
        assert_eq!(err.to_string(), "unknown node type: nonexistent");
    }

    #[test]
    fn test_registry_register_fn() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register_fn("test/fn_ctor", |id, params| {
            NodeVariant::Source(Box::new(TestSource::<f32, 64>::configured(id, params)))
        });

        assert!(registry.contains("test/fn_ctor"));

        let params = NodeParams::new(44100.0);
        let variant = registry
            .construct("test/fn_ctor", NodeId(1), &params)
            .expect("should construct from fn");
        assert!(
            matches!(variant, NodeVariant::Source(_)),
            "expected Source variant"
        );
    }

    #[test]
    fn test_registry_node_ctor_macro() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        node_ctor!(registry, "test/macro_ctor", |id, params| {
            NodeVariant::Source(Box::new(TestSource::<f32, 64>::configured(id, params)))
        });

        assert!(registry.contains("test/macro_ctor"));
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_list_types() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register(TestSourceCtor);
        registry.register(TestProcessorCtor);

        // Sort locally so the assertion does not depend on listing order.
        let mut types = registry.list_types();
        types.sort();
        assert_eq!(types, vec!["test/processor", "test/source"]);
    }

    #[test]
    fn test_registry_replace() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        let params = NodeParams::new(44100.0);

        registry.register(TestSourceCtor);
        assert_eq!(registry.len(), 1);
        let before = registry
            .construct("test/source", NodeId(0), &params)
            .expect("should construct");
        assert!(matches!(before, NodeVariant::Source(_)));

        // Register a *different* constructor under the same name so the test
        // can tell whether the entry was actually replaced.
        registry.register_fn("test/source", |id, params| {
            <TestProcessorCtor as NodeConstructor<f32, 64>>::construct(
                &TestProcessorCtor,
                id,
                params,
            )
        });
        assert_eq!(registry.len(), 1);

        let after = registry
            .construct("test/source", NodeId(0), &params)
            .expect("should construct");
        assert!(
            matches!(after, NodeVariant::Processor(_)),
            "the later registration should win"
        );
    }

    #[test]
    fn test_registry_metadata() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register(TestSourceCtor);

        let meta = registry
            .metadata("test/source")
            .expect("registered type should have metadata");
        assert_eq!(meta.name, "TestSource");

        // Unregistered names yield no metadata.
        assert!(registry.metadata("test/missing").is_none());
    }

    #[test]
    fn test_construct_with_params() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register_fn("test/with_params", |id, params| {
            // The values passed to `construct` must reach the closure and
            // take precedence over the fallback defaults given here.
            let freq = params.get_f32("frequency", 440.0);
            assert_eq!(freq, 220.0);
            let amp = params.get_f32("amplitude", 0.5);
            assert_eq!(amp, 0.8);

            NodeVariant::Source(Box::new(TestSource::<f32, 64>::configured(id, params)))
        });

        let params = NodeParams::new(44100.0)
            .with("frequency", ParamValue::Float(220.0))
            .with("amplitude", ParamValue::Float(0.8));
        let result = registry.construct("test/with_params", NodeId(0), &params);
        assert!(result.is_ok());
    }
}