Skip to main content

frp_engine/
transform.rs

1//! Named transform registry and edge transform evaluation.
2
3use std::collections::{BTreeMap, HashMap};
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::sync::RwLock;
8
9use frp_domain::EdgeTransform;
10use frp_plexus::Value;
11
12use crate::error::EngineError;
13
/// Boxed, pinned, type-erased future that eventually yields a `T`.
///
/// The future is `Send`, so it can be moved across threads, and `'static`
/// (spelled out below; it is also the implicit default for `Box<dyn …>`),
/// so it cannot borrow from the caller's stack.
pub type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;
16
17/// A registry of named async transform functions keyed by name.
18///
19/// Both sync and async callables can be registered:
20/// - [`register`](TransformRegistry::register) — wraps a sync `Fn` in an async block.
21/// - [`register_async`](TransformRegistry::register_async) — stores a native async fn directly.
22///
23/// A shared [`rhai::Engine`] is embedded for evaluating [`EdgeTransform::Script`]
24/// variants. The engine has conservative safety limits applied by default
25/// (`max_operations = 50_000`, `max_call_levels = 32`).
26#[derive(Clone)]
27pub struct TransformRegistry {
28    fns: HashMap<String, Arc<dyn Fn(Vec<Value>) -> BoxFuture<Value> + Send + Sync>>,
29    pub(crate) script_engine: Arc<rhai::Engine>,
30    ast_cache: Arc<RwLock<HashMap<String, rhai::AST>>>,
31}
32
33impl Default for TransformRegistry {
34    fn default() -> Self {
35        Self::new()
36    }
37}
38
39impl TransformRegistry {
40    /// Create an empty registry with a sandboxed Rhai scripting engine.
41    ///
42    /// The engine has the following safety limits:
43    /// - Maximum 50,000 operations per script evaluation.
44    /// - Maximum 32 nested function call levels.
45    pub fn new() -> Self {
46        let mut engine = rhai::Engine::new();
47        engine.set_max_operations(50_000u64);
48        engine.set_max_call_levels(32);
49        TransformRegistry {
50            fns: HashMap::new(),
51            script_engine: Arc::new(engine),
52            ast_cache: Arc::new(RwLock::new(HashMap::new())),
53        }
54    }
55
56    /// Register a **synchronous** transform function.
57    ///
58    /// The function is wrapped in an async block so it integrates seamlessly
59    /// with the rest of the async execution path.
60    pub fn register(
61        &mut self,
62        name: impl Into<String>,
63        f: impl Fn(Vec<Value>) -> Value + Send + Sync + 'static,
64    ) {
65        let f = Arc::new(f);
66        self.fns.insert(
67            name.into(),
68            Arc::new(move |inputs: Vec<Value>| {
69                let f = Arc::clone(&f);
70                Box::pin(async move { f(inputs) }) as BoxFuture<Value>
71            }),
72        );
73    }
74
75    /// Register a **native async** transform function.
76    ///
77    /// Use this when your transform needs to `await` (e.g. database lookups,
78    /// HTTP calls, inter-task channels).
79    pub fn register_async<F, Fut>(
80        &mut self,
81        name: impl Into<String>,
82        f: F,
83    )
84    where
85        F: Fn(Vec<Value>) -> Fut + Send + Sync + 'static,
86        Fut: Future<Output = Value> + Send + 'static,
87    {
88        self.fns.insert(
89            name.into(),
90            Arc::new(move |inputs: Vec<Value>| {
91                Box::pin(f(inputs)) as BoxFuture<Value>
92            }),
93        );
94    }
95
96    /// Retrieve a named transform function by name.
97    pub fn get(
98        &self,
99        name: &str,
100    ) -> Option<&Arc<dyn Fn(Vec<Value>) -> BoxFuture<Value> + Send + Sync>> {
101        self.fns.get(name)
102    }
103
104    fn get_or_compile_ast(&self, code: &str) -> Result<rhai::AST, EngineError> {
105        {
106            let cache = self
107                .ast_cache
108                .read()
109                .map_err(|_| EngineError::ExecutionFailed("script AST cache read lock poisoned".to_string()))?;
110            if let Some(ast) = cache.get(code).cloned() {
111                return Ok(ast);
112            }
113        }
114
115        let mut cache = self
116            .ast_cache
117            .write()
118            .map_err(|_| EngineError::ExecutionFailed("script AST cache write lock poisoned".to_string()))?;
119
120        if let Some(ast) = cache.get(code).cloned() {
121            return Ok(ast);
122        }
123
124        let compiled = self
125            .script_engine
126            .compile(code)
127            .map_err(|e| EngineError::ExecutionFailed(format!("script compile failed: {e}")))?;
128        cache.insert(code.to_string(), compiled.clone());
129        Ok(compiled)
130    }
131}
132
133/// Evaluate an [`EdgeTransform`] against a list of input values.
134///
135/// - `PassThrough` — returns the first input value, or `Value::Null` if empty.
136/// - `Named(name)` — looks up `name` in `registry`, calls it, and awaits the result.
137/// - `Inline(f)` — calls `f` synchronously (sync closure, no await).
138/// - `Script(code)` — evaluates `code` via the embedded Rhai engine. The script
139///   receives the inputs as an `inputs` array variable and must return a value.
140pub async fn eval_transform(
141    transform: &EdgeTransform,
142    inputs: Vec<Value>,
143    registry: &TransformRegistry,
144) -> Result<Value, EngineError> {
145    match transform {
146        EdgeTransform::PassThrough => Ok(inputs.into_iter().next().unwrap_or(Value::Null)),
147        EdgeTransform::Named(name) => {
148            let f = registry
149                .get(name)
150                .ok_or_else(|| EngineError::TransformNotFound(name.clone()))?;
151            Ok(f(inputs).await)
152        }
153        EdgeTransform::Inline(f) => Ok(f(&inputs)),
154        EdgeTransform::Script(code) => {
155            let inputs_arr: rhai::Array =
156                inputs.into_iter().map(value_to_dynamic).collect();
157            let mut scope = rhai::Scope::new();
158            scope.push("inputs", inputs_arr);
159
160            let ast = registry.get_or_compile_ast(code)?;
161
162            let dyn_result = registry
163                .script_engine
164                .eval_ast_with_scope::<rhai::Dynamic>(&mut scope, &ast)
165                .map_err(|e| EngineError::ExecutionFailed(format!("script eval failed: {e}")))?;
166            Ok(dynamic_to_value(dyn_result))
167        }
168    }
169}
170
171// ---------------------------------------------------------------------------
172// Rhai ↔ Value conversions
173// ---------------------------------------------------------------------------
174
175/// Convert a [`Value`] into a [`rhai::Dynamic`] for use inside Rhai scripts.
176fn value_to_dynamic(v: Value) -> rhai::Dynamic {
177    match v {
178        Value::Null => rhai::Dynamic::UNIT,
179        Value::Bool(b) => rhai::Dynamic::from(b),
180        Value::Int(i) => rhai::Dynamic::from(i),
181        Value::Float(f) => rhai::Dynamic::from(f),
182        Value::Str(s) => rhai::Dynamic::from(s),
183        Value::Bytes(b) => {
184            let blob: rhai::Blob = b;
185            rhai::Dynamic::from_blob(blob)
186        }
187        Value::List(l) => {
188            let arr: rhai::Array = l.into_iter().map(value_to_dynamic).collect();
189            rhai::Dynamic::from_array(arr)
190        }
191        Value::Map(m) => {
192            let map: rhai::Map = m
193                .into_iter()
194                .map(|(k, v)| (k.into(), value_to_dynamic(v)))
195                .collect();
196            rhai::Dynamic::from_map(map)
197        }
198    }
199}
200
201/// Convert a [`rhai::Dynamic`] back to a [`Value`].
202///
203/// Unknown Rhai types (custom objects, etc.) map to [`Value::Null`].
204fn dynamic_to_value(d: rhai::Dynamic) -> Value {
205    if d.is_unit() {
206        Value::Null
207    } else if d.is::<bool>() {
208        Value::Bool(d.cast::<bool>())
209    } else if d.is::<i64>() {
210        Value::Int(d.cast::<i64>())
211    } else if d.is::<f64>() {
212        Value::Float(d.cast::<f64>())
213    } else if d.is::<rhai::ImmutableString>() {
214        Value::Str(d.cast::<rhai::ImmutableString>().to_string())
215    } else if d.is::<rhai::Blob>() {
216        Value::Bytes(d.cast::<rhai::Blob>())
217    } else if d.is::<rhai::Array>() {
218        Value::List(
219            d.cast::<rhai::Array>()
220                .into_iter()
221                .map(dynamic_to_value)
222                .collect(),
223        )
224    } else if d.is::<rhai::Map>() {
225        let map: BTreeMap<String, Value> = d
226            .cast::<rhai::Map>()
227            .into_iter()
228            .map(|(k, v)| (k.to_string(), dynamic_to_value(v)))
229            .collect();
230        Value::Map(map)
231    } else {
232        Value::Null
233    }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use frp_plexus::Value;
    use std::collections::BTreeMap;

    /// One value of every `Value` variant, including nested collections.
    fn sample_values() -> Vec<Value> {
        let mut map = BTreeMap::new();
        map.insert("n".to_string(), Value::Int(1));
        map.insert("s".to_string(), Value::Str("x".to_string()));
        vec![
            Value::Null,
            Value::Bool(true),
            Value::Int(-7),
            Value::Float(1.5),
            Value::Str("hello".to_string()),
            Value::Bytes(vec![1, 2, 3]),
            Value::List(vec![Value::Int(1), Value::Str("a".to_string())]),
            Value::Map(map),
        ]
    }

    #[tokio::test]
    async fn passthrough_returns_first() {
        let reg = TransformRegistry::new();
        let result = eval_transform(
            &EdgeTransform::PassThrough,
            vec![Value::Int(42), Value::Int(99)],
            &reg,
        )
        .await
        .unwrap();
        assert_eq!(result, Value::Int(42));
    }

    #[tokio::test]
    async fn passthrough_empty_returns_null() {
        let reg = TransformRegistry::new();
        let result = eval_transform(&EdgeTransform::PassThrough, vec![], &reg)
            .await
            .unwrap();
        assert_eq!(result, Value::Null);
    }

    #[tokio::test]
    async fn named_sync_transform_found() {
        let mut reg = TransformRegistry::new();
        reg.register("double", |inputs| {
            if let Some(Value::Int(n)) = inputs.first() {
                Value::Int(n * 2)
            } else {
                Value::Null
            }
        });
        let result = eval_transform(
            &EdgeTransform::Named("double".to_string()),
            vec![Value::Int(5)],
            &reg,
        )
        .await
        .unwrap();
        assert_eq!(result, Value::Int(10));
    }

    #[tokio::test]
    async fn named_async_transform_found() {
        let mut reg = TransformRegistry::new();
        reg.register_async("async_double", |inputs| async move {
            if let Some(Value::Int(n)) = inputs.first() {
                Value::Int(n * 2)
            } else {
                Value::Null
            }
        });
        let result = eval_transform(
            &EdgeTransform::Named("async_double".to_string()),
            vec![Value::Int(6)],
            &reg,
        )
        .await
        .unwrap();
        assert_eq!(result, Value::Int(12));
    }

    #[tokio::test]
    async fn named_transform_not_found() {
        let reg = TransformRegistry::new();
        let err = eval_transform(
            &EdgeTransform::Named("missing".to_string()),
            vec![],
            &reg,
        )
        .await
        .unwrap_err();
        assert!(matches!(err, EngineError::TransformNotFound(_)));
    }

    #[tokio::test]
    async fn register_same_name_replaces_previous() {
        let mut reg = TransformRegistry::new();
        reg.register("f", |_| Value::Int(1));
        reg.register("f", |_| Value::Int(2));
        let result = eval_transform(&EdgeTransform::Named("f".to_string()), vec![], &reg)
            .await
            .unwrap();
        assert_eq!(result, Value::Int(2));
    }

    #[tokio::test]
    async fn inline_transform_called() {
        let reg = TransformRegistry::new();
        let t = EdgeTransform::Inline(Arc::new(|_inputs| Value::Bool(true)));
        let result = eval_transform(&t, vec![], &reg).await.unwrap();
        assert_eq!(result, Value::Bool(true));
    }

    #[tokio::test]
    async fn script_transform_arithmetic() {
        let reg = TransformRegistry::new();
        let t = EdgeTransform::Script("inputs[0] + inputs[1]".to_string());
        let result = eval_transform(&t, vec![Value::Int(3), Value::Int(4)], &reg)
            .await
            .unwrap();
        assert_eq!(result, Value::Int(7));
    }

    #[tokio::test]
    async fn script_transform_string_passthrough() {
        let reg = TransformRegistry::new();
        let t = EdgeTransform::Script("inputs[0]".to_string());
        let result = eval_transform(&t, vec![Value::Str("hello".to_string())], &reg)
            .await
            .unwrap();
        assert_eq!(result, Value::Str("hello".to_string()));
    }

    #[tokio::test]
    async fn script_transform_roundtrips_every_value_kind() {
        // Returning the whole `inputs` array exercises Value -> Dynamic -> Value
        // for every variant, including nested lists/maps, bytes and null.
        let reg = TransformRegistry::new();
        let t = EdgeTransform::Script("inputs".to_string());
        let result = eval_transform(&t, sample_values(), &reg).await.unwrap();
        assert_eq!(result, Value::List(sample_values()));
    }

    #[tokio::test]
    async fn script_transform_reuses_cached_ast() {
        let reg = TransformRegistry::new();
        let t = EdgeTransform::Script("inputs[0] + 1".to_string());

        let first = eval_transform(&t, vec![Value::Int(1)], &reg).await.unwrap();
        let second = eval_transform(&t, vec![Value::Int(2)], &reg).await.unwrap();

        assert_eq!(first, Value::Int(2));
        assert_eq!(second, Value::Int(3));
        assert_eq!(reg.ast_cache.read().unwrap().len(), 1);
    }

    #[tokio::test]
    async fn cloned_registry_shares_ast_cache() {
        let reg = TransformRegistry::new();
        let clone = reg.clone();
        let t = EdgeTransform::Script("inputs[0]".to_string());

        eval_transform(&t, vec![Value::Int(1)], &clone).await.unwrap();

        assert_eq!(reg.ast_cache.read().unwrap().len(), 1);
    }

    #[tokio::test]
    async fn script_transform_concurrent_uses_single_cached_ast() {
        let reg = TransformRegistry::new();
        let t = EdgeTransform::Script("inputs[0] * 2".to_string());

        let mut tasks = Vec::new();
        for i in 0_i64..16_i64 {
            let reg_clone = reg.clone();
            let t_clone = t.clone();
            tasks.push(tokio::spawn(async move {
                eval_transform(&t_clone, vec![Value::Int(i)], &reg_clone).await
            }));
        }

        for (i, task) in tasks.into_iter().enumerate() {
            let value = task.await.unwrap().unwrap();
            assert_eq!(value, Value::Int((i as i64) * 2));
        }

        assert_eq!(reg.ast_cache.read().unwrap().len(), 1);
    }

    #[tokio::test]
    async fn script_transform_error_on_invalid_code() {
        let reg = TransformRegistry::new();
        let t = EdgeTransform::Script("!!!invalid!!!".to_string());
        let err = eval_transform(&t, vec![], &reg).await.unwrap_err();
        assert!(matches!(err, EngineError::ExecutionFailed(_)));
        assert_eq!(reg.ast_cache.read().unwrap().len(), 0);
    }

    #[tokio::test]
    async fn script_runtime_error_still_caches_compiled_ast() {
        // Compilation succeeds; evaluation fails on an out-of-bounds index.
        let reg = TransformRegistry::new();
        let t = EdgeTransform::Script("inputs[5]".to_string());
        let err = eval_transform(&t, vec![], &reg).await.unwrap_err();
        match err {
            EngineError::ExecutionFailed(msg) => assert!(msg.starts_with("script eval failed")),
            other => panic!("expected ExecutionFailed, got {other:?}"),
        }
        assert_eq!(reg.ast_cache.read().unwrap().len(), 1);
    }

    #[tokio::test]
    async fn script_runaway_loop_hits_operation_limit() {
        let reg = TransformRegistry::new();
        let t = EdgeTransform::Script("let n = 0; loop { n += 1; }".to_string());
        let err = eval_transform(&t, vec![], &reg).await.unwrap_err();
        match err {
            EngineError::ExecutionFailed(msg) => assert!(msg.starts_with("script eval failed")),
            other => panic!("expected ExecutionFailed, got {other:?}"),
        }
    }
}