1use rill_core::math::Transcendental;
2use rill_core::traits::{SignalNode, NodeId, NodeMetadata, NodeParams, NodeVariant};
3use std::collections::HashMap;
4
/// Errors produced when looking up a node type in a `NodeRegistry`.
///
/// The type is `PartialEq`/`Eq` so callers and tests can compare errors
/// directly (for example with `assert_eq!`) instead of destructuring them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RegistryError {
    /// No constructor is registered under the requested type name.
    ///
    /// The payload is the type name that was asked for.
    UnknownType(String),
}
15
16impl std::fmt::Display for RegistryError {
17 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18 match self {
19 Self::UnknownType(name) => write!(f, "unknown node type: {name}"),
20 }
21 }
22}
23
24impl std::error::Error for RegistryError {}
25
26pub trait NodeConstructor<T: Transcendental, const BUF_SIZE: usize>: Send + Sync {
37 fn type_name(&self) -> &'static str;
39
40 fn construct(&self, id: NodeId, params: &NodeParams) -> NodeVariant<T, BUF_SIZE>;
49}
50
51pub struct NodeRegistry<T: Transcendental, const BUF_SIZE: usize> {
65 entries: HashMap<&'static str, Box<dyn NodeConstructor<T, BUF_SIZE>>>,
66}
67
68impl<T: Transcendental, const BUF_SIZE: usize> Default for NodeRegistry<T, BUF_SIZE> {
69 fn default() -> Self {
70 Self::new()
71 }
72}
73
74impl<T: Transcendental, const BUF_SIZE: usize> NodeRegistry<T, BUF_SIZE> {
75 pub fn new() -> Self {
77 Self {
78 entries: HashMap::new(),
79 }
80 }
81
82 pub fn register(&mut self, ctor: impl NodeConstructor<T, BUF_SIZE> + 'static) {
88 let name = ctor.type_name();
89 self.entries.insert(name, Box::new(ctor));
90 }
91
92 pub fn register_fn(
97 &mut self,
98 type_name: &'static str,
99 f: impl Fn(NodeId, &NodeParams) -> NodeVariant<T, BUF_SIZE> + Send + Sync + 'static,
100 ) {
101 self.entries.insert(
102 type_name,
103 Box::new(ClosureCtor {
104 type_name,
105 f: Box::new(f),
106 }),
107 );
108 }
109
110 pub fn construct(
115 &self,
116 type_name: &str,
117 id: NodeId,
118 params: &NodeParams,
119 ) -> Result<NodeVariant<T, BUF_SIZE>, RegistryError> {
120 self.entries
121 .get(type_name)
122 .ok_or_else(|| RegistryError::UnknownType(type_name.to_string()))
123 .map(|ctor| ctor.construct(id, params))
124 }
125
126 pub fn contains(&self, type_name: &str) -> bool {
128 self.entries.contains_key(type_name)
129 }
130
131 pub fn list_types(&self) -> Vec<&'static str> {
133 self.entries.keys().copied().collect()
134 }
135
136 pub fn len(&self) -> usize {
138 self.entries.len()
139 }
140
141 pub fn is_empty(&self) -> bool {
143 self.entries.is_empty()
144 }
145
146 pub fn metadata(&self, type_name: &str) -> Option<NodeMetadata> {
152 self.entries.get(type_name).map(|ctor| {
153 let dummy = NodeParams::new(44100.0);
154 let variant = ctor.construct(NodeId(u32::MAX), &dummy);
155 variant.metadata()
156 })
157 }
158}
159
160struct ClosureCtor<T: Transcendental, const BUF_SIZE: usize> {
165 type_name: &'static str,
166 f: Box<dyn Fn(NodeId, &NodeParams) -> NodeVariant<T, BUF_SIZE> + Send + Sync>,
167}
168
169impl<T: Transcendental, const BUF_SIZE: usize> NodeConstructor<T, BUF_SIZE>
170 for ClosureCtor<T, BUF_SIZE>
171{
172 fn type_name(&self) -> &'static str {
173 self.type_name
174 }
175
176 fn construct(&self, id: NodeId, params: &NodeParams) -> NodeVariant<T, BUF_SIZE> {
177 (self.f)(id, params)
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use rill_core::time::ClockTick;
    use rill_core::traits::node::NodeState;
    use rill_core::traits::port::Port;
    use rill_core::traits::{
        NodeCategory, ParamValue, ParameterId, ProcessResult, Processor, Source,
    };

    /// Minimal do-nothing node used to exercise the registry.
    ///
    /// It implements both `Source` and `Processor`, so the same struct can be
    /// wrapped in either `NodeVariant::Source` or `NodeVariant::Processor`.
    /// The reported metadata name and category are plain fields so individual
    /// tests can override them.
    struct TestSource<T: Transcendental, const B: usize> {
        id: NodeId,
        state: NodeState<T, B>,
        output: Port<T, B>,
        meta_name: &'static str,
        meta_cat: NodeCategory,
    }

    impl<T: Transcendental, const B: usize> TestSource<T, B> {
        /// Creates a node with id 0, a single output port, and "TestSource" metadata.
        fn new() -> Self {
            Self {
                id: NodeId(0),
                state: NodeState::new(44100.0),
                output: Port::output(NodeId(0), 0, "out"),
                meta_name: "TestSource",
                meta_cat: NodeCategory::Source,
            }
        }

        /// Stores `id` and `sample_rate` on the node in one call.
        fn set_id_and_init(&mut self, id: NodeId, sample_rate: f32) {
            self.id = id;
            self.state.sample_rate = sample_rate;
        }
    }

    impl<T: Transcendental, const B: usize> SignalNode<T, B> for TestSource<T, B> {
        fn metadata(&self) -> NodeMetadata {
            NodeMetadata::new(self.meta_name, self.meta_cat)
        }

        fn init(&mut self, sample_rate: f32) {
            self.state.sample_rate = sample_rate;
        }

        fn reset(&mut self) {}

        fn get_parameter(&self, _: &ParameterId) -> Option<ParamValue> {
            None
        }

        fn set_parameter(&mut self, _: &ParameterId, _: ParamValue) -> ProcessResult<()> {
            Ok(())
        }

        fn id(&self) -> NodeId {
            self.id
        }

        fn set_id(&mut self, id: NodeId) {
            self.id = id;
        }

        fn input_port(&self, _: usize) -> Option<&Port<T, B>> {
            None
        }

        fn input_port_mut(&mut self, _: usize) -> Option<&mut Port<T, B>> {
            None
        }

        fn output_port(&self, index: usize) -> Option<&Port<T, B>> {
            if index == 0 {
                Some(&self.output)
            } else {
                None
            }
        }

        fn output_port_mut(&mut self, index: usize) -> Option<&mut Port<T, B>> {
            if index == 0 {
                Some(&mut self.output)
            } else {
                None
            }
        }

        fn control_port(&self, _: usize) -> Option<&Port<T, B>> {
            None
        }

        fn control_port_mut(&mut self, _: usize) -> Option<&mut Port<T, B>> {
            None
        }

        fn state(&self) -> &NodeState<T, B> {
            &self.state
        }

        fn state_mut(&mut self) -> &mut NodeState<T, B> {
            &mut self.state
        }
    }

    impl<T: Transcendental, const B: usize> Source<T, B> for TestSource<T, B> {
        fn generate(&mut self, _: &ClockTick, _: &[T], _: &[ClockTick]) -> ProcessResult<()> {
            Ok(())
        }
    }

    impl<T: Transcendental, const B: usize> Processor<T, B> for TestSource<T, B> {
        fn process(
            &mut self,
            _: &ClockTick,
            _: &[&[T; B]],
            _: &[T],
            _: &[ClockTick],
            _: &[&[T; B]],
        ) -> ProcessResult<()> {
            Ok(())
        }

        fn latency(&self) -> usize {
            0
        }
    }

    /// Constructor registered as "test/source"; yields a `Source` variant.
    struct TestSourceCtor;

    impl<T: Transcendental, const B: usize> NodeConstructor<T, B> for TestSourceCtor {
        fn type_name(&self) -> &'static str {
            "test/source"
        }

        fn construct(&self, id: NodeId, params: &NodeParams) -> NodeVariant<T, B> {
            let mut node = TestSource::<T, B>::new();
            node.set_id_and_init(id, params.sample_rate);
            NodeVariant::Source(Box::new(node))
        }
    }

    /// Constructor registered as "test/processor"; yields a `Processor`
    /// variant whose metadata name is "Noop".
    struct TestProcessorCtor;

    impl<T: Transcendental, const B: usize> NodeConstructor<T, B> for TestProcessorCtor {
        fn type_name(&self) -> &'static str {
            "test/processor"
        }

        fn construct(&self, id: NodeId, params: &NodeParams) -> NodeVariant<T, B> {
            let mut node = TestSource::<T, B>::new();
            node.meta_name = "Noop";
            node.meta_cat = NodeCategory::Processor;
            node.set_id_and_init(id, params.sample_rate);
            NodeVariant::Processor(Box::new(node))
        }
    }

    #[test]
    fn test_registry_empty() {
        let registry = NodeRegistry::<f32, 64>::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_registry_default_is_empty() {
        let registry = NodeRegistry::<f32, 64>::default();
        assert!(registry.is_empty());
        assert!(!registry.contains("test/source"));
    }

    #[test]
    fn test_registry_register_and_construct() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register(TestSourceCtor);

        assert!(registry.contains("test/source"));
        assert_eq!(registry.len(), 1);

        let params = NodeParams::new(48000.0);
        let variant = registry
            .construct("test/source", NodeId(42), &params)
            .expect("should construct");

        match &variant {
            NodeVariant::Source(_) => {}
            _ => panic!("expected Source variant"),
        }

        assert_eq!(variant.metadata().name, "TestSource");
    }

    #[test]
    fn test_registry_unknown_type() {
        let registry = NodeRegistry::<f32, 64>::new();
        let params = NodeParams::new(44100.0);
        let result = registry.construct("nonexistent", NodeId(0), &params);
        assert!(result.is_err());
        match result {
            Err(RegistryError::UnknownType(name)) => assert_eq!(name, "nonexistent"),
            Ok(_) => panic!("expected UnknownType error"),
        }
    }

    #[test]
    fn test_registry_error_display() {
        let err = RegistryError::UnknownType("nonexistent".to_string());
        assert_eq!(err.to_string(), "unknown node type: nonexistent");
    }

    #[test]
    fn test_registry_register_fn() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register_fn("test/fn_ctor", |id, params| {
            let mut node = TestSource::<f32, 64>::new();
            node.set_id(id);
            node.init(params.sample_rate);
            NodeVariant::Source(Box::new(node))
        });

        assert!(registry.contains("test/fn_ctor"));
        let params = NodeParams::new(44100.0);
        let variant = registry
            .construct("test/fn_ctor", NodeId(1), &params)
            .expect("should construct from fn");
        match variant {
            NodeVariant::Source(_) => {}
            _ => panic!("expected Source variant"),
        }
    }

    #[test]
    fn test_registry_list_types() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register(TestSourceCtor);
        registry.register(TestProcessorCtor);

        // Sort locally so the assertion never depends on map iteration order.
        let mut types = registry.list_types();
        types.sort();
        assert_eq!(types, vec!["test/processor", "test/source"]);
    }

    #[test]
    fn test_registry_replace() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register(TestSourceCtor);
        assert_eq!(registry.len(), 1);

        // Registering the same type name again must not add a second entry.
        registry.register(TestSourceCtor);
        assert_eq!(registry.len(), 1);

        // A different constructor under the same name takes over the entry.
        registry.register_fn("test/source", |id, params| {
            let mut node = TestSource::<f32, 64>::new();
            node.meta_name = "Replacement";
            node.set_id_and_init(id, params.sample_rate);
            NodeVariant::Source(Box::new(node))
        });
        assert_eq!(registry.len(), 1);
        let meta = registry
            .metadata("test/source")
            .expect("type should still be registered");
        assert_eq!(meta.name, "Replacement");
    }

    #[test]
    fn test_registry_metadata() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register(TestSourceCtor);

        let meta = registry.metadata("test/source");
        assert!(meta.is_some());
        assert_eq!(meta.unwrap().name, "TestSource");
    }

    #[test]
    fn test_registry_metadata_unknown_type() {
        let registry = NodeRegistry::<f32, 64>::new();
        assert!(registry.metadata("nonexistent").is_none());
    }

    #[test]
    fn test_construct_with_params() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register_fn("test/with_params", |id, params| {
            // The values supplied below must win over the fallbacks given here.
            let freq = params.get_f32("frequency", 440.0);
            assert_eq!(freq, 220.0);
            let amp = params.get_f32("amplitude", 0.5);
            assert_eq!(amp, 0.8);

            let mut node = TestSource::<f32, 64>::new();
            node.set_id(id);
            node.init(params.sample_rate);
            NodeVariant::Source(Box::new(node))
        });

        let params = NodeParams::new(44100.0)
            .with("frequency", ParamValue::Float(220.0))
            .with("amplitude", ParamValue::Float(0.8));
        let result = registry.construct("test/with_params", NodeId(0), &params);
        assert!(result.is_ok());
    }
}