rill_patchbay/
function_registry.rs1use std::collections::HashMap;
2use std::sync::Arc;
3
/// A shareable transfer function: maps an input value to an output value,
/// optionally reading named `f64` parameters from the supplied map.
///
/// Stored behind an [`Arc`] so handles can be cloned without copying the
/// closure, and bounded by `Send + Sync` so it may be shared across threads.
pub type NamedFunction =
    Arc<dyn Fn(f64, &HashMap<String, f64>) -> f64 + Sync + Send + 'static>;
9
10#[derive(Clone)]
25pub struct FunctionRegistry {
26 functions: HashMap<String, NamedFunction>,
27}
28
29impl FunctionRegistry {
30 pub fn new() -> Self {
32 Self {
33 functions: HashMap::new(),
34 }
35 }
36
37 pub fn register(
39 &mut self,
40 name: impl Into<String>,
41 f: NamedFunction,
42 ) {
43 self.functions.insert(name.into(), f);
44 }
45
46 pub fn apply(&self, name: &str, input: f64, params: &HashMap<String, f64>) -> Option<f64> {
50 self.functions.get(name).map(|f| f(input, params))
51 }
52
53 pub fn builtin() -> Self {
55 let mut reg = Self::new();
56
57 reg.register("tanh", Arc::new(|x, _| x.tanh()));
58 reg.register("clip", Arc::new(|x, p| {
59 let lo = p.get("min").copied().unwrap_or(-1.0);
60 let hi = p.get("max").copied().unwrap_or(1.0);
61 x.clamp(lo, hi)
62 }));
63 reg.register("scale", Arc::new(|x, p| {
64 let from_lo = p.get("from_min").copied().unwrap_or(0.0);
65 let from_hi = p.get("from_max").copied().unwrap_or(1.0);
66 let to_lo = p.get("to_min").copied().unwrap_or(0.0);
67 let to_hi = p.get("to_max").copied().unwrap_or(1.0);
68 let norm = (x - from_lo) / (from_hi - from_lo);
69 to_lo + norm * (to_hi - to_lo)
70 }));
71 reg.register("invert", Arc::new(|x, _| 1.0 - x));
72 reg.register("abs", Arc::new(|x, _| x.abs()));
73 reg.register("smooth", Arc::new(|x, p| {
74 let factor = p.get("factor").copied().unwrap_or(0.5);
75 x * factor
76 }));
78 reg.register("quantize", Arc::new(|x, p| {
79 let steps = p.get("steps").copied().unwrap_or(12.0);
80 (x * steps).round() / steps
81 }));
82
83 reg
84 }
85}
86
87impl Default for FunctionRegistry {
88 fn default() -> Self {
89 Self::new()
90 }
91}
92
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a parameter map from `(key, value)` pairs.
    fn params(pairs: &[(&str, f64)]) -> HashMap<String, f64> {
        pairs.iter().map(|&(k, v)| (k.to_string(), v)).collect()
    }

    #[test]
    fn test_builtin_tanh() {
        let reg = FunctionRegistry::builtin();
        let params = HashMap::new();
        let out = reg.apply("tanh", 0.5, &params).unwrap();
        assert!((out - 0.5f64.tanh()).abs() < 1e-10);
    }

    #[test]
    fn test_builtin_clip() {
        let reg = FunctionRegistry::builtin();
        let p = params(&[("min", -0.5), ("max", 0.5)]);
        let out = reg.apply("clip", 2.0, &p).unwrap();
        assert!((out - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_builtin_clip_defaults() {
        let reg = FunctionRegistry::builtin();
        assert_eq!(reg.apply("clip", 2.0, &HashMap::new()), Some(1.0));
        assert_eq!(reg.apply("clip", -2.0, &HashMap::new()), Some(-1.0));
    }

    #[test]
    fn test_builtin_scale() {
        let reg = FunctionRegistry::builtin();
        let p = params(&[
            ("from_min", 0.0),
            ("from_max", 1.0),
            ("to_min", 0.0),
            ("to_max", 127.0),
        ]);
        let out = reg.apply("scale", 0.5, &p).unwrap();
        assert!((out - 63.5).abs() < 1e-10);
    }

    #[test]
    fn test_builtin_invert_and_abs() {
        let reg = FunctionRegistry::builtin();
        assert_eq!(reg.apply("invert", 0.25, &HashMap::new()), Some(0.75));
        assert_eq!(reg.apply("abs", -0.3, &HashMap::new()), Some(0.3));
    }

    #[test]
    fn test_builtin_smooth_default_factor() {
        let reg = FunctionRegistry::builtin();
        assert_eq!(reg.apply("smooth", 0.8, &HashMap::new()), Some(0.4));
    }

    #[test]
    fn test_builtin_quantize_default_steps() {
        let reg = FunctionRegistry::builtin();
        assert_eq!(reg.apply("quantize", 0.26, &HashMap::new()), Some(0.25));
    }

    #[test]
    fn test_register_replaces_existing() {
        let mut reg = FunctionRegistry::builtin();
        reg.register("tanh", Arc::new(|x, _| x * 2.0));
        assert_eq!(reg.apply("tanh", 1.5, &HashMap::new()), Some(3.0));
    }

    #[test]
    fn test_unknown_function() {
        let reg = FunctionRegistry::new();
        assert!(reg.apply("nonexistent", 0.0, &HashMap::new()).is_none());
        let builtin = FunctionRegistry::builtin();
        assert!(builtin.apply("nonexistent", 0.0, &HashMap::new()).is_none());
    }
}