rill_patchbay/
function_registry.rs1use std::collections::HashMap;
2use std::sync::Arc;
3
/// A stateless, named transfer function.
///
/// It is called with the input value and a map of named numeric parameters
/// (e.g. `"min"`/`"max"` for the built-in `clip`) and returns the output value.
/// The `Send + Sync` bounds allow a registry holding these functions to be
/// shared across threads.
pub type NamedFunction = Arc<dyn Fn(f64, &HashMap<String, f64>) -> f64 + Send + Sync>;

/// A table of [`NamedFunction`]s looked up by name.
///
/// Cloning copies the name table; the functions themselves are shared through
/// their [`Arc`] rather than duplicated.
#[derive(Clone)]
pub struct FunctionRegistry {
    functions: HashMap<String, NamedFunction>,
}

impl FunctionRegistry {
    /// Creates an empty registry with no functions registered.
    ///
    /// Use [`FunctionRegistry::builtin`] for a registry pre-populated with the
    /// standard functions.
    pub fn new() -> Self {
        Self {
            functions: HashMap::new(),
        }
    }

    /// Registers `f` under `name`, replacing any function already registered
    /// under that name.
    pub fn register(&mut self, name: impl Into<String>, f: NamedFunction) {
        self.functions.insert(name.into(), f);
    }

    /// Applies the function registered as `name` to `input`, passing `params`
    /// through to it unchanged.
    ///
    /// Returns `None` if no function is registered under `name`.
    pub fn apply(&self, name: &str, input: f64, params: &HashMap<String, f64>) -> Option<f64> {
        self.functions.get(name).map(|f| f(input, params))
    }

    /// Reads parameter `key` from `params`, or `default` when it is absent.
    fn param(params: &HashMap<String, f64>, key: &str, default: f64) -> f64 {
        params.get(key).copied().unwrap_or(default)
    }

    /// Creates a registry pre-populated with the built-in functions:
    ///
    /// - `tanh`: hyperbolic tangent of the input; no parameters.
    /// - `clip`: clamps the input to `[min, max]` (defaults `-1.0`, `1.0`).
    ///   Reversed bounds are swapped and a NaN bound is treated as unbounded.
    /// - `scale`: linearly maps `[from_min, from_max]` (defaults `0.0`, `1.0`)
    ///   onto `[to_min, to_max]` (defaults `0.0`, `1.0`). Inputs outside the
    ///   source range are extrapolated, not clamped. A zero-width source range
    ///   yields `to_min`.
    /// - `invert`: `1.0 - x`, so `0.0` maps to `1.0` and vice versa.
    /// - `abs`: absolute value; no parameters.
    /// - `smooth`: multiplies the input by `factor` (default `0.5`).
    /// - `quantize`: rounds the input to the nearest multiple of `1 / steps`
    ///   (default `steps = 12.0`). `steps = 0` passes the input through.
    pub fn builtin() -> Self {
        let mut reg = Self::new();

        reg.register("tanh", Arc::new(|x, _| x.tanh()));
        reg.register(
            "clip",
            Arc::new(|x, p| {
                let lo = Self::param(p, "min", -1.0);
                let hi = Self::param(p, "max", 1.0);
                // `f64::clamp` panics if `min > max` or either bound is NaN. The
                // bounds come from caller-supplied params, so normalise them
                // rather than panicking on a bad setting.
                let lo = if lo.is_nan() { f64::NEG_INFINITY } else { lo };
                let hi = if hi.is_nan() { f64::INFINITY } else { hi };
                let (lo, hi) = if lo <= hi { (lo, hi) } else { (hi, lo) };
                x.clamp(lo, hi)
            }),
        );
        reg.register(
            "scale",
            Arc::new(|x, p| {
                let from_lo = Self::param(p, "from_min", 0.0);
                let from_hi = Self::param(p, "from_max", 1.0);
                let to_lo = Self::param(p, "to_min", 0.0);
                let to_hi = Self::param(p, "to_max", 1.0);
                let span = from_hi - from_lo;
                // A zero-width source range would divide by zero and emit
                // NaN/inf downstream; pin the output to the bottom of the
                // target range instead.
                if span == 0.0 {
                    to_lo
                } else {
                    let norm = (x - from_lo) / span;
                    to_lo + norm * (to_hi - to_lo)
                }
            }),
        );
        reg.register("invert", Arc::new(|x, _| 1.0 - x));
        reg.register("abs", Arc::new(|x, _| x.abs()));
        reg.register(
            "smooth",
            // NOTE(review): despite its name this keeps no state between calls,
            // so it performs no temporal smoothing; it is a plain gain of
            // `factor`. Real smoothing would need state carried between calls.
            Arc::new(|x, p| x * Self::param(p, "factor", 0.5)),
        );
        reg.register(
            "quantize",
            Arc::new(|x, p| {
                let steps = Self::param(p, "steps", 12.0);
                // `steps == 0` would compute 0/0 = NaN; pass the input through.
                if steps == 0.0 {
                    x
                } else {
                    (x * steps).round() / steps
                }
            }),
        );

        reg
    }
}

impl Default for FunctionRegistry {
    /// Returns an empty registry, the same as [`FunctionRegistry::new`].
    ///
    /// This does *not* include the built-in functions; use
    /// [`FunctionRegistry::builtin`] for those.
    fn default() -> Self {
        Self::new()
    }
}
100
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_builtin_tanh() {
        let reg = FunctionRegistry::builtin();
        let params = HashMap::new();
        let out = reg.apply("tanh", 0.5, &params).unwrap();
        assert!((out - 0.5f64.tanh()).abs() < 1e-10);
    }

    #[test]
    fn test_builtin_clip() {
        let reg = FunctionRegistry::builtin();
        let mut params = HashMap::new();
        params.insert("min".into(), -0.5);
        params.insert("max".into(), 0.5);
        let out = reg.apply("clip", 2.0, &params).unwrap();
        assert!((out - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_builtin_scale() {
        let reg = FunctionRegistry::builtin();
        let mut params = HashMap::new();
        params.insert("from_min".into(), 0.0);
        params.insert("from_max".into(), 1.0);
        params.insert("to_min".into(), 0.0);
        params.insert("to_max".into(), 127.0);
        let out = reg.apply("scale", 0.5, &params).unwrap();
        assert!((out - 63.5).abs() < 1e-10);
    }

    #[test]
    fn test_builtin_invert_and_abs() {
        let reg = FunctionRegistry::builtin();
        let params = HashMap::new();
        assert_eq!(reg.apply("invert", 0.25, &params), Some(0.75));
        assert_eq!(reg.apply("abs", -2.0, &params), Some(2.0));
    }

    #[test]
    fn test_builtin_quantize() {
        let reg = FunctionRegistry::builtin();
        let mut params = HashMap::new();
        params.insert("steps".into(), 4.0);
        // 0.26 * 4 = 1.04 rounds to 1, i.e. the nearest quarter is 0.25.
        assert_eq!(reg.apply("quantize", 0.26, &params), Some(0.25));
    }

    #[test]
    fn test_register_replaces_existing() {
        let mut reg = FunctionRegistry::builtin();
        reg.register("abs", Arc::new(|x, _| x * 10.0));
        assert_eq!(reg.apply("abs", -1.0, &HashMap::new()), Some(-10.0));
    }

    #[test]
    fn test_default_is_empty() {
        // `Default` is the empty registry, not `builtin()`.
        let reg = FunctionRegistry::default();
        assert!(reg.apply("tanh", 0.0, &HashMap::new()).is_none());
    }

    #[test]
    fn test_unknown_function() {
        let reg = FunctionRegistry::new();
        assert!(reg.apply("nonexistent", 0.0, &HashMap::new()).is_none());
    }
}