Skip to main content

rill_patchbay/
function_registry.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
/// Signature of a function stored in a `FunctionRegistry`.
///
/// Serialized `MappingDef` and `ServoDef` values refer to such a function by name,
/// in place of a raw closure (which cannot be serialized).
///
/// The first argument is the input value and the second is a map of named numeric
/// parameters; the return value is the mapped output. Because the closure is
/// `Send + Sync` behind an `Arc`, the handle itself can be cloned cheaply and shared
/// between threads.
pub type NamedFunction = Arc<dyn Fn(f64, &HashMap<String, f64>) -> f64 + Sync + Send>;
10/// Registry of named functions for serialization-safe custom transforms.
11///
12/// Provides the bridge between the non-serializable `Transform::Custom(Arc<dyn Fn>)`
13/// and the serializable `TransformDef::NamedFunction { name, params }`.
14///
15/// # Example
16///
17/// ```
18/// use rill_patchbay::function_registry::FunctionRegistry;
19///
20/// let reg = FunctionRegistry::builtin();
21/// let out = reg.apply("tanh", 0.5, &Default::default()).unwrap();
22/// assert!((out - 0.5f64.tanh()).abs() < 1e-10);
23/// ```
24#[derive(Clone)]
25pub struct FunctionRegistry {
26    functions: HashMap<String, NamedFunction>,
27}
28
29impl FunctionRegistry {
30    /// Create an empty registry.
31    pub fn new() -> Self {
32        Self {
33            functions: HashMap::new(),
34        }
35    }
36
37    /// Register a named function.
38    pub fn register(
39        &mut self,
40        name: impl Into<String>,
41        f: NamedFunction,
42    ) {
43        self.functions.insert(name.into(), f);
44    }
45
46    /// Apply a named function.
47    ///
48    /// Returns `None` if the function name is not registered.
49    pub fn apply(&self, name: &str, input: f64, params: &HashMap<String, f64>) -> Option<f64> {
50        self.functions.get(name).map(|f| f(input, params))
51    }
52
53    /// Fill with built-in functions.
54    pub fn builtin() -> Self {
55        let mut reg = Self::new();
56
57        reg.register("tanh", Arc::new(|x, _| x.tanh()));
58        reg.register("clip", Arc::new(|x, p| {
59            let lo = p.get("min").copied().unwrap_or(-1.0);
60            let hi = p.get("max").copied().unwrap_or(1.0);
61            x.clamp(lo, hi)
62        }));
63        reg.register("scale", Arc::new(|x, p| {
64            let from_lo = p.get("from_min").copied().unwrap_or(0.0);
65            let from_hi = p.get("from_max").copied().unwrap_or(1.0);
66            let to_lo = p.get("to_min").copied().unwrap_or(0.0);
67            let to_hi = p.get("to_max").copied().unwrap_or(1.0);
68            let norm = (x - from_lo) / (from_hi - from_lo);
69            to_lo + norm * (to_hi - to_lo)
70        }));
71        reg.register("invert", Arc::new(|x, _| 1.0 - x));
72        reg.register("abs", Arc::new(|x, _| x.abs()));
73        reg.register("smooth", Arc::new(|x, p| {
74            let factor = p.get("factor").copied().unwrap_or(0.5);
75            x * factor
76            // Note: true smoothing requires state (one-pole), handled at runtime
77        }));
78        reg.register("quantize", Arc::new(|x, p| {
79            let steps = p.get("steps").copied().unwrap_or(12.0);
80            (x * steps).round() / steps
81        }));
82
83        reg
84    }
85}
86
87impl Default for FunctionRegistry {
88    fn default() -> Self {
89        Self::new()
90    }
91}
92
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a parameter map from `(key, value)` pairs.
    fn params(pairs: &[(&str, f64)]) -> HashMap<String, f64> {
        pairs.iter().map(|&(k, v)| (k.to_string(), v)).collect()
    }

    /// Apply a built-in function, panicking if it is not registered.
    fn apply_builtin(name: &str, input: f64, p: &HashMap<String, f64>) -> f64 {
        FunctionRegistry::builtin()
            .apply(name, input, p)
            .unwrap_or_else(|| panic!("built-in `{}` is not registered", name))
    }

    /// Assert two floats agree to within 1e-10.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-10,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_builtin_tanh() {
        assert_close(apply_builtin("tanh", 0.5, &HashMap::new()), 0.5f64.tanh());
    }

    #[test]
    fn test_builtin_clip() {
        let p = params(&[("min", -0.5), ("max", 0.5)]);
        assert_close(apply_builtin("clip", 2.0, &p), 0.5);
        assert_close(apply_builtin("clip", -2.0, &p), -0.5);
        assert_close(apply_builtin("clip", 0.1, &p), 0.1);
    }

    #[test]
    fn test_builtin_clip_defaults() {
        let p = HashMap::new();
        assert_close(apply_builtin("clip", 3.0, &p), 1.0);
        assert_close(apply_builtin("clip", -3.0, &p), -1.0);
        assert_close(apply_builtin("clip", 0.25, &p), 0.25);
    }

    #[test]
    fn test_builtin_scale() {
        let p = params(&[
            ("from_min", 0.0),
            ("from_max", 1.0),
            ("to_min", 0.0),
            ("to_max", 127.0),
        ]);
        assert_close(apply_builtin("scale", 0.5, &p), 63.5);
    }

    #[test]
    fn test_builtin_scale_defaults_is_identity() {
        assert_close(apply_builtin("scale", 0.3, &HashMap::new()), 0.3);
    }

    #[test]
    fn test_builtin_invert() {
        assert_close(apply_builtin("invert", 0.25, &HashMap::new()), 0.75);
    }

    #[test]
    fn test_builtin_abs() {
        assert_close(apply_builtin("abs", -0.3, &HashMap::new()), 0.3);
        assert_close(apply_builtin("abs", 0.3, &HashMap::new()), 0.3);
    }

    #[test]
    fn test_builtin_smooth() {
        assert_close(apply_builtin("smooth", 0.8, &HashMap::new()), 0.4);
        assert_close(apply_builtin("smooth", 0.8, &params(&[("factor", 0.25)])), 0.2);
    }

    #[test]
    fn test_builtin_quantize() {
        assert_close(apply_builtin("quantize", 0.26, &params(&[("steps", 4.0)])), 0.25);
        // Default of 12 steps: 0.51 * 12 = 6.12 -> 6 -> 0.5.
        assert_close(apply_builtin("quantize", 0.51, &HashMap::new()), 0.5);
    }

    #[test]
    fn test_builtin_registers_all_names() {
        let reg = FunctionRegistry::builtin();
        for name in &["tanh", "clip", "scale", "invert", "abs", "smooth", "quantize"] {
            assert!(
                reg.apply(name, 0.0, &HashMap::new()).is_some(),
                "missing built-in `{}`",
                name
            );
        }
    }

    #[test]
    fn test_register_and_replace() {
        let mut reg = FunctionRegistry::new();
        reg.register("double", Arc::new(|x, _| x * 2.0));
        assert_eq!(reg.apply("double", 1.5, &HashMap::new()), Some(3.0));
        // Registering the same name again replaces the previous function.
        reg.register("double", Arc::new(|x, _| x * 3.0));
        assert_eq!(reg.apply("double", 1.5, &HashMap::new()), Some(4.5));
    }

    #[test]
    fn test_unknown_function() {
        let reg = FunctionRegistry::new();
        assert!(reg.apply("nonexistent", 0.0, &HashMap::new()).is_none());
    }
}