Skip to main content

rill_graph/
registry.rs

1use rill_core::math::Transcendental;
2use rill_core::traits::{SignalNode, NodeId, NodeMetadata, NodeParams, NodeVariant};
3use std::collections::HashMap;
4
5// ============================================================================
6// Registry Error
7// ============================================================================
8
/// Errors that can occur during node construction via the registry.
///
/// Derives `PartialEq`/`Eq` so callers (and tests) can compare errors
/// directly instead of pattern-matching on every variant.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RegistryError {
    /// No constructor registered for the given type name.
    ///
    /// Carries the type name that was looked up.
    UnknownType(String),
}

impl std::fmt::Display for RegistryError {
    /// Formats the error as `unknown node type: <name>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnknownType(name) => write!(f, "unknown node type: {name}"),
        }
    }
}

impl std::error::Error for RegistryError {}
25
26// ============================================================================
27// NodeConstructor Trait
28// ============================================================================
29
30/// Factory trait for creating graph nodes by type name.
31///
32/// Each node type that wants to be constructable via the registry
33/// implements this trait. The [`construct`](Self::construct) method
34/// receives a [`NodeId`] and [`NodeParams`] and must return the
35/// appropriate [`NodeVariant`].
36pub trait NodeConstructor<T: Transcendental, const BUF_SIZE: usize>: Send + Sync {
37    /// Canonical name for this node type (e.g. `"rill/sine_osc"`).
38    fn type_name(&self) -> &'static str;
39
40    /// Build a fully initialised node variant.
41    ///
42    /// Implementations should:
43    /// 1. Extract parameters from `params`.
44    /// 2. Create the concrete node.
45    /// 3. Call [`SignalNode::set_id`] with the given `id`.
46    /// 4. Call [`SignalNode::init`] with `params.sample_rate`.
47    /// 5. Wrap in the correct [`NodeVariant`] variant.
48    fn construct(&self, id: NodeId, params: &NodeParams) -> NodeVariant<T, BUF_SIZE>;
49}
50
51// ============================================================================
52// NodeRegistry
53// ============================================================================
54
55/// A registry of named node constructors.
56///
57/// Register constructors with [`register`](Self::register), then create
58/// nodes by type name with [`construct`](Self::construct).
59///
60/// # Type parameters
61///
62/// - `T` — sample type (typically `f32`)
63/// - `BUF_SIZE` — block size (must match the target graph)
64pub struct NodeRegistry<T: Transcendental, const BUF_SIZE: usize> {
65    entries: HashMap<&'static str, Box<dyn NodeConstructor<T, BUF_SIZE>>>,
66}
67
68impl<T: Transcendental, const BUF_SIZE: usize> Default for NodeRegistry<T, BUF_SIZE> {
69    fn default() -> Self {
70        Self::new()
71    }
72}
73
74impl<T: Transcendental, const BUF_SIZE: usize> NodeRegistry<T, BUF_SIZE> {
75    /// Create an empty registry.
76    pub fn new() -> Self {
77        Self {
78            entries: HashMap::new(),
79        }
80    }
81
82    /// Register a node constructor.
83    ///
84    /// The constructor's [`type_name`](NodeConstructor::type_name) is used
85    /// as the lookup key. If a constructor with the same name already exists,
86    /// it is replaced.
87    pub fn register(&mut self, ctor: impl NodeConstructor<T, BUF_SIZE> + 'static) {
88        let name = ctor.type_name();
89        self.entries.insert(name, Box::new(ctor));
90    }
91
92    /// Register a node type via a closure.
93    ///
94    /// This is a convenience wrapper around [`register`](Self::register) for
95    /// cases where a full struct + trait impl is not needed.
96    pub fn register_fn(
97        &mut self,
98        type_name: &'static str,
99        f: impl Fn(NodeId, &NodeParams) -> NodeVariant<T, BUF_SIZE> + Send + Sync + 'static,
100    ) {
101        self.entries.insert(
102            type_name,
103            Box::new(ClosureCtor {
104                type_name,
105                f: Box::new(f),
106            }),
107        );
108    }
109
110    /// Construct a node by type name.
111    ///
112    /// Returns [`RegistryError::UnknownType`] if the name has not been
113    /// registered.
114    pub fn construct(
115        &self,
116        type_name: &str,
117        id: NodeId,
118        params: &NodeParams,
119    ) -> Result<NodeVariant<T, BUF_SIZE>, RegistryError> {
120        self.entries
121            .get(type_name)
122            .ok_or_else(|| RegistryError::UnknownType(type_name.to_string()))
123            .map(|ctor| ctor.construct(id, params))
124    }
125
126    /// Check whether a type name is registered.
127    pub fn contains(&self, type_name: &str) -> bool {
128        self.entries.contains_key(type_name)
129    }
130
131    /// List all registered type names.
132    pub fn list_types(&self) -> Vec<&'static str> {
133        self.entries.keys().copied().collect()
134    }
135
136    /// Number of registered constructors.
137    pub fn len(&self) -> usize {
138        self.entries.len()
139    }
140
141    /// True when no constructors are registered.
142    pub fn is_empty(&self) -> bool {
143        self.entries.is_empty()
144    }
145
146    /// Get metadata for a registered type without constructing a node.
147    ///
148    /// This requires constructing a temporary node and immediately
149    /// discarding it. If performance is a concern, cache the metadata
150    /// alongside the constructor in the registry.
151    pub fn metadata(&self, type_name: &str) -> Option<NodeMetadata> {
152        self.entries.get(type_name).map(|ctor| {
153            let dummy = NodeParams::new(44100.0);
154            let variant = ctor.construct(NodeId(u32::MAX), &dummy);
155            variant.metadata()
156        })
157    }
158}
159
160// ============================================================================
161// Internal: closure-based constructor wrapper
162// ============================================================================
163
164struct ClosureCtor<T: Transcendental, const BUF_SIZE: usize> {
165    type_name: &'static str,
166    f: Box<dyn Fn(NodeId, &NodeParams) -> NodeVariant<T, BUF_SIZE> + Send + Sync>,
167}
168
169impl<T: Transcendental, const BUF_SIZE: usize> NodeConstructor<T, BUF_SIZE>
170    for ClosureCtor<T, BUF_SIZE>
171{
172    fn type_name(&self) -> &'static str {
173        self.type_name
174    }
175
176    fn construct(&self, id: NodeId, params: &NodeParams) -> NodeVariant<T, BUF_SIZE> {
177        (self.f)(id, params)
178    }
179}
180
181// ============================================================================
182// Tests
183// ============================================================================
184
#[cfg(test)]
mod tests {
    use super::*;
    use rill_core::traits::Source;
    use rill_core::traits::Processor;
    use rill_core::time::ClockTick;
    use rill_core::traits::{ParamValue, ProcessResult};
    use rill_core::traits::node::NodeState;
    use rill_core::traits::port::Port;
    use rill_core::traits::NodeCategory;

    // ── Test helpers ────────────────────────────────────────────────

    /// Minimal node used as both a source and a processor in tests; its
    /// reported metadata name/category can be overridden per constructor.
    struct TestSource<T: Transcendental, const B: usize> {
        id: NodeId,
        state: NodeState<T, B>,
        output: Port<T, B>,
        meta_name: &'static str,
        meta_cat: NodeCategory,
    }

    impl<T: Transcendental, const B: usize> TestSource<T, B> {
        fn new() -> Self {
            Self {
                id: NodeId(0),
                state: NodeState::new(44100.0),
                output: Port::output(NodeId(0), 0, "out"),
                meta_name: "TestSource",
                meta_cat: NodeCategory::Source,
            }
        }

        fn set_id_and_init(&mut self, id: NodeId, sample_rate: f32) {
            self.id = id;
            self.state.sample_rate = sample_rate;
        }
    }

    impl<T: Transcendental, const B: usize> SignalNode<T, B> for TestSource<T, B> {
        fn metadata(&self) -> rill_core::traits::NodeMetadata {
            rill_core::traits::NodeMetadata::new(self.meta_name, self.meta_cat)
        }
        fn init(&mut self, sample_rate: f32) { self.state.sample_rate = sample_rate; }
        fn reset(&mut self) {}
        fn get_parameter(&self, _: &rill_core::traits::ParameterId) -> Option<rill_core::traits::ParamValue> { None }
        fn set_parameter(&mut self, _: &rill_core::traits::ParameterId, _: rill_core::traits::ParamValue) -> ProcessResult<()> { Ok(()) }
        fn id(&self) -> NodeId { self.id }
        fn set_id(&mut self, id: NodeId) { self.id = id; }
        fn input_port(&self, _: usize) -> Option<&Port<T, B>> { None }
        fn input_port_mut(&mut self, _: usize) -> Option<&mut Port<T, B>> { None }
        fn output_port(&self, index: usize) -> Option<&Port<T, B>> {
            if index == 0 { Some(&self.output) } else { None }
        }
        fn output_port_mut(&mut self, index: usize) -> Option<&mut Port<T, B>> {
            if index == 0 { Some(&mut self.output) } else { None }
        }
        fn control_port(&self, _: usize) -> Option<&Port<T, B>> { None }
        fn control_port_mut(&mut self, _: usize) -> Option<&mut Port<T, B>> { None }
        fn state(&self) -> &NodeState<T, B> { &self.state }
        fn state_mut(&mut self) -> &mut NodeState<T, B> { &mut self.state }
    }

    impl<T: Transcendental, const B: usize> Source<T, B> for TestSource<T, B> {
        fn generate(&mut self, _: &ClockTick, _: &[T], _: &[ClockTick]) -> ProcessResult<()> { Ok(()) }
    }

    impl<T: Transcendental, const B: usize> Processor<T, B> for TestSource<T, B> {
        fn process(&mut self, _: &ClockTick, _: &[&[T; B]], _: &[T], _: &[ClockTick], _: &[&[T; B]]) -> ProcessResult<()> { Ok(()) }
        fn latency(&self) -> usize { 0 }
    }

    struct TestSourceCtor;
    impl<T: Transcendental, const B: usize> NodeConstructor<T, B> for TestSourceCtor {
        fn type_name(&self) -> &'static str { "test/source" }
        fn construct(&self, id: NodeId, params: &NodeParams) -> NodeVariant<T, B> {
            let mut node = TestSource::<T, B>::new();
            node.set_id_and_init(id, params.sample_rate);
            NodeVariant::Source(Box::new(node))
        }
    }

    struct TestProcessorCtor;
    impl<T: Transcendental, const B: usize> NodeConstructor<T, B> for TestProcessorCtor {
        fn type_name(&self) -> &'static str { "test/processor" }
        fn construct(&self, id: NodeId, params: &NodeParams) -> NodeVariant<T, B> {
            let mut node = TestSource::<T, B>::new();
            node.meta_name = "Noop";
            node.meta_cat = NodeCategory::Processor;
            node.set_id_and_init(id, params.sample_rate);
            NodeVariant::Processor(Box::new(node))
        }
    }

    // ── Tests ───────────────────────────────────────────────────────

    #[test]
    fn test_registry_empty() {
        let registry = NodeRegistry::<f32, 64>::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert!(registry.list_types().is_empty());

        // `Default` must produce the same empty registry as `new`.
        let defaulted = NodeRegistry::<f32, 64>::default();
        assert!(defaulted.is_empty());
    }

    #[test]
    fn test_registry_register_and_construct() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register(TestSourceCtor);

        assert!(registry.contains("test/source"));
        assert!(!registry.contains("test/processor"));
        assert_eq!(registry.len(), 1);

        let params = NodeParams::new(48000.0);
        let variant = registry.construct("test/source", NodeId(42), &params)
            .expect("should construct");

        match &variant {
            NodeVariant::Source(_) => {}
            _ => panic!("expected Source variant"),
        }

        // The constructed node reports the metadata of the registered type.
        assert_eq!(variant.metadata().name, "TestSource");
    }

    #[test]
    fn test_registry_unknown_type() {
        let registry = NodeRegistry::<f32, 64>::new();
        let params = NodeParams::new(44100.0);
        let result = registry.construct("nonexistent", NodeId(0), &params);
        assert!(result.is_err());
        match result {
            Err(RegistryError::UnknownType(name)) => assert_eq!(name, "nonexistent"),
            _ => panic!("expected UnknownType error"),
        }
    }

    #[test]
    fn test_registry_error_display() {
        let err = RegistryError::UnknownType("nonexistent".to_string());
        assert_eq!(err.to_string(), "unknown node type: nonexistent");
    }

    #[test]
    fn test_registry_register_fn() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register_fn("test/fn_ctor", |id, params| {
            let mut node = TestSource::<f32, 64>::new();
            node.set_id(id);
            node.init(params.sample_rate);
            NodeVariant::Source(Box::new(node))
        });

        assert!(registry.contains("test/fn_ctor"));
        let params = NodeParams::new(44100.0);
        let variant = registry.construct("test/fn_ctor", NodeId(1), &params)
            .expect("should construct from fn");
        match variant {
            NodeVariant::Source(_) => {}
            _ => panic!("expected Source variant"),
        }
    }

    #[test]
    fn test_registry_list_types() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register(TestSourceCtor);
        registry.register(TestProcessorCtor);

        let mut types = registry.list_types();
        types.sort();
        assert_eq!(types, vec!["test/processor", "test/source"]);
    }

    #[test]
    fn test_registry_replace() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register(TestSourceCtor);
        assert_eq!(registry.len(), 1);

        // Registering a *different* constructor under the same name must
        // replace the old one rather than add a second entry.
        registry.register_fn("test/source", |id, params| {
            let mut node = TestSource::<f32, 64>::new();
            node.meta_name = "Replacement";
            node.meta_cat = NodeCategory::Processor;
            node.set_id_and_init(id, params.sample_rate);
            NodeVariant::Processor(Box::new(node))
        });
        assert_eq!(registry.len(), 1);

        let params = NodeParams::new(44100.0);
        let variant = registry
            .construct("test/source", NodeId(7), &params)
            .expect("replacement should construct");
        assert!(matches!(&variant, NodeVariant::Processor(_)));
        assert_eq!(variant.metadata().name, "Replacement");
    }

    #[test]
    fn test_registry_metadata() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register(TestSourceCtor);

        let meta = registry.metadata("test/source");
        assert!(meta.is_some());
        assert_eq!(meta.unwrap().name, "TestSource");

        assert!(registry.metadata("nonexistent").is_none());
    }

    #[test]
    fn test_construct_with_params() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register_fn("test/with_params", |id, params| {
            let freq = params.get_f32("frequency", 440.0);
            assert_eq!(freq, 220.0);
            let amp = params.get_f32("amplitude", 0.5);
            assert_eq!(amp, 0.8);

            let mut node = TestSource::<f32, 64>::new();
            node.set_id(id);
            node.init(params.sample_rate);
            NodeVariant::Source(Box::new(node))
        });

        let params = NodeParams::new(44100.0)
            .with("frequency", ParamValue::Float(220.0))
            .with("amplitude", ParamValue::Float(0.8));
        let result = registry.construct("test/with_params", NodeId(0), &params);
        assert!(result.is_ok());
    }
}