1use rill_core::math::Transcendental;
2use rill_core::traits::{NodeId, NodeMetadata, NodeParams, NodeVariant, SignalNode};
3use std::collections::HashMap;
4
/// Errors returned by [`NodeRegistry`] operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RegistryError {
    /// No constructor is registered under the requested type name.
    ///
    /// Carries the type name that was looked up.
    UnknownType(String),
}

impl std::fmt::Display for RegistryError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnknownType(name) => write!(f, "unknown node type: {name}"),
        }
    }
}

impl std::error::Error for RegistryError {}
25
26pub trait NodeConstructor<T: Transcendental, const BUF_SIZE: usize>: Send + Sync {
37 fn type_name(&self) -> &'static str;
39
40 fn construct(&self, id: NodeId, params: &NodeParams) -> NodeVariant<T, BUF_SIZE>;
49}
50
51pub struct NodeRegistry<T: Transcendental, const BUF_SIZE: usize> {
65 entries: HashMap<&'static str, Box<dyn NodeConstructor<T, BUF_SIZE>>>,
66}
67
68impl<T: Transcendental, const BUF_SIZE: usize> Default for NodeRegistry<T, BUF_SIZE> {
69 fn default() -> Self {
70 Self::new()
71 }
72}
73
74impl<T: Transcendental, const BUF_SIZE: usize> NodeRegistry<T, BUF_SIZE> {
75 pub fn new() -> Self {
77 Self {
78 entries: HashMap::new(),
79 }
80 }
81
82 pub fn register(&mut self, ctor: impl NodeConstructor<T, BUF_SIZE> + 'static) {
88 let name = ctor.type_name();
89 self.entries.insert(name, Box::new(ctor));
90 }
91
92 pub fn register_fn(
97 &mut self,
98 type_name: &'static str,
99 f: impl Fn(NodeId, &NodeParams) -> NodeVariant<T, BUF_SIZE> + Send + Sync + 'static,
100 ) {
101 self.entries.insert(
102 type_name,
103 Box::new(ClosureCtor {
104 type_name,
105 f: Box::new(f),
106 }),
107 );
108 }
109
110 pub fn construct(
115 &self,
116 type_name: &str,
117 id: NodeId,
118 params: &NodeParams,
119 ) -> Result<NodeVariant<T, BUF_SIZE>, RegistryError> {
120 self.entries
121 .get(type_name)
122 .ok_or_else(|| RegistryError::UnknownType(type_name.to_string()))
123 .map(|ctor| ctor.construct(id, params))
124 }
125
126 pub fn contains(&self, type_name: &str) -> bool {
128 self.entries.contains_key(type_name)
129 }
130
131 pub fn list_types(&self) -> Vec<&'static str> {
133 self.entries.keys().copied().collect()
134 }
135
136 pub fn len(&self) -> usize {
138 self.entries.len()
139 }
140
141 pub fn is_empty(&self) -> bool {
143 self.entries.is_empty()
144 }
145
146 pub fn metadata(&self, type_name: &str) -> Option<NodeMetadata> {
152 self.entries.get(type_name).map(|ctor| {
153 let dummy = NodeParams::new(44100.0);
154 let variant = ctor.construct(NodeId(u32::MAX), &dummy);
155 variant.metadata()
156 })
157 }
158}
159
160#[allow(clippy::type_complexity)]
165struct ClosureCtor<T: Transcendental, const BUF_SIZE: usize> {
166 type_name: &'static str,
167 f: Box<dyn Fn(NodeId, &NodeParams) -> NodeVariant<T, BUF_SIZE> + Send + Sync>,
168}
169
170impl<T: Transcendental, const BUF_SIZE: usize> NodeConstructor<T, BUF_SIZE>
171 for ClosureCtor<T, BUF_SIZE>
172{
173 fn type_name(&self) -> &'static str {
174 self.type_name
175 }
176
177 fn construct(&self, id: NodeId, params: &NodeParams) -> NodeVariant<T, BUF_SIZE> {
178 (self.f)(id, params)
179 }
180}
181
/// Registers a closure constructor on a registry.
///
/// `node_ctor!(registry, "type/name", |id, params| ...)` expands to
/// `registry.register_fn("type/name", |id, params| ...);`. A trailing comma
/// after the closure is accepted, so multi-line invocations formatted with a
/// trailing comma also work.
#[macro_export]
macro_rules! node_ctor {
    ($registry:expr, $type_name:expr, $ctor:expr $(,)?) => {
        $registry.register_fn($type_name, $ctor);
    };
}
211
#[cfg(test)]
mod tests {
    use super::*;
    use rill_core::time::ClockTick;
    use rill_core::traits::node::NodeState;
    use rill_core::traits::port::Port;
    use rill_core::traits::NodeCategory;
    use rill_core::traits::Processor;
    use rill_core::traits::Source;
    use rill_core::traits::{ParamValue, ProcessResult};

    /// Do-nothing node used to exercise the registry.
    ///
    /// The metadata name and category are plain fields, so the same type can
    /// stand in for both a source and a processor.
    struct TestSource<T: Transcendental, const B: usize> {
        id: NodeId,
        state: NodeState<T, B>,
        output: Port<T, B>,
        meta_name: &'static str,
        meta_cat: NodeCategory,
    }

    impl<T: Transcendental, const B: usize> TestSource<T, B> {
        fn new() -> Self {
            Self {
                id: NodeId(0),
                state: NodeState::new(44100.0),
                output: Port::output(NodeId(0), 0, "out"),
                meta_name: "TestSource",
                meta_cat: NodeCategory::Source,
            }
        }

        /// Applies the id and sample rate that a constructor is handed.
        fn set_id_and_init(&mut self, id: NodeId, sample_rate: f32) {
            self.id = id;
            self.state.sample_rate = sample_rate;
        }
    }

    impl<T: Transcendental, const B: usize> SignalNode<T, B> for TestSource<T, B> {
        fn metadata(&self) -> rill_core::traits::NodeMetadata {
            rill_core::traits::NodeMetadata::new(self.meta_name, self.meta_cat)
        }

        fn init(&mut self, sample_rate: f32) {
            self.state.sample_rate = sample_rate;
        }

        fn reset(&mut self) {}

        fn get_parameter(
            &self,
            _: &rill_core::traits::ParameterId,
        ) -> Option<rill_core::traits::ParamValue> {
            None
        }

        fn set_parameter(
            &mut self,
            _: &rill_core::traits::ParameterId,
            _: rill_core::traits::ParamValue,
        ) -> ProcessResult<()> {
            Ok(())
        }

        fn id(&self) -> NodeId {
            self.id
        }

        fn set_id(&mut self, id: NodeId) {
            self.id = id;
        }

        fn input_port(&self, _: usize) -> Option<&Port<T, B>> {
            None
        }

        fn input_port_mut(&mut self, _: usize) -> Option<&mut Port<T, B>> {
            None
        }

        fn output_port(&self, index: usize) -> Option<&Port<T, B>> {
            if index == 0 {
                Some(&self.output)
            } else {
                None
            }
        }

        fn output_port_mut(&mut self, index: usize) -> Option<&mut Port<T, B>> {
            if index == 0 {
                Some(&mut self.output)
            } else {
                None
            }
        }

        fn control_port(&self, _: usize) -> Option<&Port<T, B>> {
            None
        }

        fn control_port_mut(&mut self, _: usize) -> Option<&mut Port<T, B>> {
            None
        }

        fn state(&self) -> &NodeState<T, B> {
            &self.state
        }

        fn state_mut(&mut self) -> &mut NodeState<T, B> {
            &mut self.state
        }
    }

    impl<T: Transcendental, const B: usize> Source<T, B> for TestSource<T, B> {
        fn generate(&mut self, _: &ClockTick, _: &[T], _: &[ClockTick]) -> ProcessResult<()> {
            Ok(())
        }
    }

    impl<T: Transcendental, const B: usize> Processor<T, B> for TestSource<T, B> {
        fn process(
            &mut self,
            _: &ClockTick,
            _: &[&[T; B]],
            _: &[T],
            _: &[ClockTick],
            _: &[&[T; B]],
        ) -> ProcessResult<()> {
            Ok(())
        }

        fn latency(&self) -> usize {
            0
        }
    }

    /// Constructor producing a `TestSource` wrapped as a `Source` variant.
    struct TestSourceCtor;

    impl<T: Transcendental, const B: usize> NodeConstructor<T, B> for TestSourceCtor {
        fn type_name(&self) -> &'static str {
            "test/source"
        }

        fn construct(&self, id: NodeId, params: &NodeParams) -> NodeVariant<T, B> {
            let mut node = TestSource::<T, B>::new();
            node.set_id_and_init(id, params.sample_rate);
            NodeVariant::Source(Box::new(node))
        }
    }

    /// Constructor producing a `TestSource` relabelled as a `Processor` variant.
    struct TestProcessorCtor;

    impl<T: Transcendental, const B: usize> NodeConstructor<T, B> for TestProcessorCtor {
        fn type_name(&self) -> &'static str {
            "test/processor"
        }

        fn construct(&self, id: NodeId, params: &NodeParams) -> NodeVariant<T, B> {
            let mut node = TestSource::<T, B>::new();
            node.meta_name = "Noop";
            node.meta_cat = NodeCategory::Processor;
            node.set_id_and_init(id, params.sample_rate);
            NodeVariant::Processor(Box::new(node))
        }
    }

    #[test]
    fn test_registry_empty() {
        let registry = NodeRegistry::<f32, 64>::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert!(registry.list_types().is_empty());
        assert!(!registry.contains("test/source"));
    }

    #[test]
    fn test_registry_register_and_construct() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register(TestSourceCtor);

        assert!(registry.contains("test/source"));
        assert_eq!(registry.len(), 1);

        let params = NodeParams::new(48000.0);
        let variant = registry
            .construct("test/source", NodeId(42), &params)
            .expect("should construct");

        assert!(
            matches!(variant, NodeVariant::Source(_)),
            "expected Source variant"
        );
        assert_eq!(variant.metadata().name, "TestSource");
    }

    #[test]
    fn test_registry_unknown_type() {
        let registry = NodeRegistry::<f32, 64>::new();
        let params = NodeParams::new(44100.0);
        match registry.construct("nonexistent", NodeId(0), &params) {
            Err(RegistryError::UnknownType(name)) => assert_eq!(name, "nonexistent"),
            _ => panic!("expected UnknownType error"),
        }
    }

    #[test]
    fn test_registry_error_message() {
        let err = RegistryError::UnknownType("osc/missing".to_string());
        assert_eq!(err.to_string(), "unknown node type: osc/missing");
    }

    #[test]
    fn test_registry_register_fn() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register_fn("test/fn_ctor", |id, params| {
            let mut node = TestSource::<f32, 64>::new();
            node.set_id(id);
            node.init(params.sample_rate);
            NodeVariant::Source(Box::new(node))
        });

        assert!(registry.contains("test/fn_ctor"));
        let params = NodeParams::new(44100.0);
        let variant = registry
            .construct("test/fn_ctor", NodeId(1), &params)
            .expect("should construct from fn");
        assert!(
            matches!(variant, NodeVariant::Source(_)),
            "expected Source variant"
        );
    }

    #[test]
    fn test_registry_list_types() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register(TestSourceCtor);
        registry.register(TestProcessorCtor);

        let mut types = registry.list_types();
        types.sort();
        assert_eq!(types, vec!["test/processor", "test/source"]);
    }

    #[test]
    fn test_registry_replace() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register(TestSourceCtor);
        assert_eq!(registry.len(), 1);

        // Re-registering the same name must replace the old constructor, not add one.
        registry.register_fn("test/source", |id, params| {
            let mut node = TestSource::<f32, 64>::new();
            node.meta_name = "Replacement";
            node.set_id_and_init(id, params.sample_rate);
            NodeVariant::Source(Box::new(node))
        });
        assert_eq!(registry.len(), 1);

        let meta = registry
            .metadata("test/source")
            .expect("replaced type still has metadata");
        assert_eq!(meta.name, "Replacement");
    }

    #[test]
    fn test_registry_metadata() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register(TestSourceCtor);

        let meta = registry
            .metadata("test/source")
            .expect("registered type has metadata");
        assert_eq!(meta.name, "TestSource");
        assert!(registry.metadata("test/missing").is_none());
    }

    #[test]
    fn test_construct_with_params() {
        let mut registry = NodeRegistry::<f32, 64>::new();
        registry.register_fn("test/with_params", |id, params| {
            let freq = params.get_f32("frequency", 440.0);
            assert_eq!(freq, 220.0);
            let amp = params.get_f32("amplitude", 0.5);
            assert_eq!(amp, 0.8);

            let mut node = TestSource::<f32, 64>::new();
            node.set_id(id);
            node.init(params.sample_rate);
            NodeVariant::Source(Box::new(node))
        });

        let params = NodeParams::new(44100.0)
            .with("frequency", ParamValue::Float(220.0))
            .with("amplitude", ParamValue::Float(0.8));
        let result = registry.construct("test/with_params", NodeId(0), &params);
        assert!(result.is_ok());
    }
}
480}