Skip to main content

swf_runtime/
expression.rs

1use crate::error::{WorkflowError, WorkflowResult};
2use serde_json::Value;
3use std::collections::HashMap;
4use std::sync::LazyLock;
5use swf_core::models::expression::{is_strict_expr, sanitize_expr};
6
/// Trait for pluggable expression evaluation engines.
///
/// Implement this trait to add support for expression languages beyond JQ
/// (e.g., CEL, JavaScript). Register engines with `WorkflowRunner::with_expression_engine()`.
///
/// Implementations must be `Send + Sync`; the registry in this module stores
/// them as `Arc<dyn ExpressionEngine>`.
///
/// Expression routing uses the `engine:` prefix convention:
/// - `jq: .foo` → JQ engine
/// - `cel: payload.model.startsWith("gpt")` → CEL engine
/// - No prefix → default engine (JQ)
///
/// # Example
///
/// Both trait methods are synchronous, so a plain `impl` block is enough.
///
/// ```no_run
/// use serde_json::Value;
/// use std::collections::HashMap;
/// use swf_runtime::{ExpressionEngine, WorkflowResult};
///
/// struct CelEngine;
///
/// impl ExpressionEngine for CelEngine {
///     fn engine_prefix(&self) -> &str { "cel" }
///
///     fn evaluate(
///         &self,
///         expression: &str,
///         input: &Value,
///         vars: &HashMap<String, Value>,
///     ) -> WorkflowResult<Value> {
///         // Implement CEL evaluation here
///         Ok(Value::Null)
///     }
/// }
/// ```
pub trait ExpressionEngine: Send + Sync {
    /// Returns the prefix that routes expressions to this engine (e.g., "cel", "js").
    ///
    /// The prefix is given without the trailing `:`; it is also the key under
    /// which the engine is stored when registered.
    fn engine_prefix(&self) -> &str;

    /// Evaluates an expression against the given input with variable bindings.
    ///
    /// * `expression` - the expression text to evaluate.
    /// * `input` - the JSON document the expression is evaluated against.
    /// * `vars` - named variable bindings available to the expression.
    ///
    /// Returns the resulting JSON value, or an error if evaluation fails.
    fn evaluate(
        &self,
        expression: &str,
        input: &Value,
        vars: &HashMap<String, Value>,
    ) -> WorkflowResult<Value>;
}
54
55/// Registry of expression engines, keyed by prefix.
56#[derive(Default, Clone)]
57pub struct ExpressionEngineRegistry {
58    engines: std::sync::Arc<HashMap<String, std::sync::Arc<dyn ExpressionEngine>>>,
59}
60
61impl ExpressionEngineRegistry {
62    pub fn new() -> Self {
63        Self::default()
64    }
65
66    pub fn register(&mut self, engine: std::sync::Arc<dyn ExpressionEngine>) {
67        let key = engine.engine_prefix().to_string();
68        std::sync::Arc::make_mut(&mut self.engines).insert(key, engine);
69    }
70
71    pub fn get(&self, prefix: &str) -> Option<std::sync::Arc<dyn ExpressionEngine>> {
72        self.engines.get(prefix).cloned()
73    }
74}
75
/// Splits an expression into its engine prefix and the remaining expression text.
///
/// Leading whitespace is ignored. The prefix is everything before the first
/// `:` and must be non-empty and purely ASCII-alphabetic, so JQ text such as
/// `.foo` or `{a: 1}` is never mistaken for a prefixed expression. Whitespace
/// after the colon is trimmed from the remainder.
///
/// Returns `Some((prefix, remaining_expression))` when a prefix is found,
/// `None` otherwise.
pub fn strip_engine_prefix(expr: &str) -> Option<(&str, &str)> {
    let (prefix, remainder) = expr.trim_start().split_once(':')?;
    let looks_like_engine_name =
        !prefix.is_empty() && prefix.bytes().all(|b| b.is_ascii_alphabetic());
    looks_like_engine_name.then(|| (prefix, remainder.trim_start()))
}
93
94/// Evaluates an expression, routing to the appropriate engine based on prefix.
95/// Falls back to JQ evaluation if no prefix is found.
96pub fn evaluate_with_engines(
97    expression: &str,
98    input: &Value,
99    vars: &HashMap<String, Value>,
100    engines: &ExpressionEngineRegistry,
101) -> WorkflowResult<Value> {
102    // Try to strip engine prefix
103    if let Some((prefix, rest)) = strip_engine_prefix(expression) {
104        if let Some(engine) = engines.get(prefix) {
105            return engine.evaluate(rest, input, vars);
106        }
107        // Unknown prefix: fall through to JQ (treat as JQ expression with colon)
108    }
109    evaluate_jq(expression, input, vars)
110}
111
/// Compiled JQ filter cache key: `(expression, variable_names)`.
///
/// The second element is the sorted list of variable names joined with a NUL
/// (`"\0"`) separator, as built by `evaluate_jq`.
type CacheKey = (String, String);

/// Global cache for compiled JQ filters.
/// Key: (expression_text, sorted_variable_names_joined)
/// Value: compiled Filter that can be reused with matching variable bindings
///
/// The map is created lazily on first access and guarded by an `RwLock`.
/// NOTE(review): nothing in this file removes entries, so the cache grows with
/// the number of distinct (expression, variable-name set) pairs evaluated.
static FILTER_CACHE: LazyLock<
    std::sync::RwLock<HashMap<CacheKey, jaq_core::Filter<jaq_core::Native<jaq_json::Val>>>>,
> = LazyLock::new(|| std::sync::RwLock::new(HashMap::new()));
121
122/// Evaluates a JQ expression against a JSON input with variable bindings.
123/// Uses a global cache to avoid recompiling the same expression with the same variable names.
124pub fn evaluate_jq(
125    expression: &str,
126    input: &Value,
127    vars: &HashMap<String, Value>,
128) -> WorkflowResult<Value> {
129    use jaq_core::{load, Compiler, Ctx, RcIter};
130    use jaq_json::Val;
131
132    // Prepare global variable names in a stable sorted order
133    let mut var_names: Vec<String> = vars.keys().cloned().collect();
134    var_names.sort();
135    let var_name_refs: Vec<&str> = var_names.iter().map(|s| s.as_str()).collect();
136
137    // Build cache key from expression and variable names
138    let cache_key = (expression.to_string(), var_names.join("\0"));
139
140    // Try to get a compiled filter from the cache
141    let filter = {
142        let cache = FILTER_CACHE.read().unwrap_or_else(|e| e.into_inner());
143        cache.get(&cache_key).cloned()
144    };
145
146    let filter = match filter {
147        Some(f) => f,
148        None => {
149            // Parse the expression
150            let program = load::File {
151                code: expression,
152                path: (),
153            };
154            let loader = load::Loader::new(jaq_std::defs().chain(jaq_json::defs()));
155            let arena = load::Arena::default();
156
157            let modules = loader.load(&arena, program).map_err(|e| {
158                WorkflowError::expression(
159                    format!("failed to parse jq expression '{}': {:?}", expression, e),
160                    "",
161                )
162            })?;
163
164            // Compile with standard functions and global variables
165            let filter = Compiler::default()
166                .with_funs(jaq_std::funs().chain(jaq_json::funs()))
167                .with_global_vars(var_name_refs)
168                .compile(modules)
169                .map_err(|errs| {
170                    WorkflowError::expression(
171                        format!(
172                            "failed to compile jq expression '{}': {:?}",
173                            expression, errs
174                        ),
175                        "",
176                    )
177                })?;
178
179            // Store in cache
180            let mut cache = FILTER_CACHE.write().unwrap_or_else(|e| e.into_inner());
181            cache.entry(cache_key).or_insert(filter).clone()
182        }
183    };
184
185    // Convert serde_json::Value to jaq Val
186    let jaq_input = Val::from(input.clone());
187
188    // Build variable bindings for jaq context using the same key order as var_names
189    let var_vals: Vec<Val> = var_names
190        .iter()
191        .map(|k| Val::from(vars[k].clone()))
192        .collect();
193    let inputs = RcIter::new(core::iter::empty());
194
195    let out = filter.run((Ctx::new(var_vals, &inputs), jaq_input));
196
197    let mut results = Vec::new();
198    for item in out {
199        match item {
200            Ok(val) => {
201                let json_val: Value = val.into();
202                results.push(json_val);
203            }
204            Err(e) => {
205                return Err(WorkflowError::expression(
206                    format!("jq evaluation error: {:?}", e),
207                    "",
208                ));
209            }
210        }
211    }
212
213    match results.len() {
214        0 => Err(WorkflowError::expression(
215            "no result from jq evaluation",
216            "",
217        )),
218        1 => Ok(results.into_iter().next().unwrap_or(Value::Null)),
219        _ => Ok(Value::Array(results)),
220    }
221}
222
223/// Recursively traverses a JSON structure and evaluates all runtime expressions
224pub fn traverse_and_evaluate(
225    node: &mut Value,
226    input: &Value,
227    vars: &HashMap<String, Value>,
228) -> WorkflowResult<()> {
229    match node {
230        Value::Object(map) => {
231            for (_key, value) in map.iter_mut() {
232                traverse_and_evaluate(value, input, vars)?;
233            }
234        }
235        Value::Array(arr) => {
236            for item in arr.iter_mut() {
237                traverse_and_evaluate(item, input, vars)?;
238            }
239        }
240        Value::String(s) if is_strict_expr(s) => {
241            let expr = sanitize_expr(s);
242            let result = evaluate_jq(&expr, input, vars)?;
243            *node = result;
244        }
245        _ => {}
246    }
247    Ok(())
248}
249
250/// Evaluates an expression and returns the result as a boolean
251pub fn traverse_and_evaluate_bool(
252    expr: &str,
253    input: &Value,
254    vars: &HashMap<String, Value>,
255) -> WorkflowResult<bool> {
256    if expr.is_empty() {
257        return Ok(false);
258    }
259
260    // Normalize: add ${} if not strict
261    let normalized = if is_strict_expr(expr) {
262        expr.to_string()
263    } else {
264        swf_core::models::expression::normalize_expr(expr)
265    };
266
267    let sanitized = sanitize_expr(&normalized);
268    let result = evaluate_jq(&sanitized, input, vars)?;
269
270    match result {
271        Value::Bool(b) => Ok(b),
272        _ => Ok(false),
273    }
274}
275
276/// Evaluates an optional runtime expression object (input.from, output.as, etc.)
277pub fn traverse_and_evaluate_obj(
278    obj: Option<&Value>,
279    input: &Value,
280    vars: &HashMap<String, Value>,
281    task_name: &str,
282) -> WorkflowResult<Value> {
283    match obj {
284        None => Ok(input.clone()),
285        Some(value) => {
286            let mut result = value.clone();
287            traverse_and_evaluate(&mut result, input, vars)
288                .map_err(|e| WorkflowError::expression(format!("{}", e), task_name))?;
289            Ok(result)
290        }
291    }
292}
293
294/// Evaluates a string that may contain a JQ expression (${...}).
295///
296/// Supports three forms:
297/// 1. Full expression: `${ .foo }` — evaluates the whole thing as JQ
298/// 2. Embedded expressions: `http://host/${ .id }/path` — substitutes each `${...}` inline
299/// 3. Plain string: `hello` — returned as-is
300///
301/// Returns the result as a String (JSON values are converted via Display).
302pub fn evaluate_expression_str(
303    expr: &str,
304    input: &Value,
305    vars: &HashMap<String, Value>,
306    task_name: &str,
307) -> WorkflowResult<String> {
308    if is_strict_expr(expr) {
309        // Full expression: ${ .foo } -> evaluate the whole thing
310        let sanitized = sanitize_expr(expr);
311        let result = evaluate_jq(&sanitized, input, vars)
312            .map_err(|e| WorkflowError::expression(format!("{}", e), task_name))?;
313        match result {
314            Value::String(s) => Ok(s),
315            other => Ok(other.to_string()),
316        }
317    } else if expr.contains("${") {
318        // Embedded expression: http://host/${ .id }/path -> substitute each ${...}
319        evaluate_embedded_expressions(expr, input, vars, task_name)
320    } else {
321        Ok(expr.to_string())
322    }
323}
324
325/// Evaluates embedded ${...} expressions within a string, replacing each with its JQ result
326fn evaluate_embedded_expressions(
327    s: &str,
328    input: &Value,
329    vars: &HashMap<String, Value>,
330    task_name: &str,
331) -> WorkflowResult<String> {
332    let mut result = String::new();
333    let mut chars = s.chars().peekable();
334
335    while let Some(c) = chars.next() {
336        if c == '$' && chars.peek() == Some(&'{') {
337            // Found ${...} - find the matching }
338            chars.next(); // consume '{'
339            let mut depth = 1;
340            let mut expr_buf = String::new();
341            #[allow(clippy::while_let_on_iterator)]
342            while let Some(ec) = chars.next() {
343                match ec {
344                    '{' => depth += 1,
345                    '}' => {
346                        depth -= 1;
347                        if depth == 0 {
348                            break;
349                        }
350                    }
351                    _ => {}
352                }
353                expr_buf.push(ec);
354            }
355            // Evaluate the expression
356            let sanitized = sanitize_expr(&expr_buf);
357            let val = evaluate_jq(&sanitized, input, vars)
358                .map_err(|e| WorkflowError::expression(format!("{}", e), task_name))?;
359            match val {
360                Value::String(vs) => result.push_str(&vs),
361                other => result.push_str(&other.to_string()),
362            }
363        } else {
364            result.push(c);
365        }
366    }
367
368    Ok(result)
369}
370
371/// Evaluates a `Value` that may contain a JQ expression.
372///
373/// - String values are prepared (normalized + sanitized) and evaluated as JQ.
374/// - Non-string values have embedded `${...}` expressions evaluated via traverse_and_evaluate.
375pub fn evaluate_value_expr(
376    value: &Value,
377    input: &Value,
378    vars: &HashMap<String, Value>,
379    task_name: &str,
380) -> WorkflowResult<Value> {
381    match value {
382        Value::String(expr) => {
383            let sanitized = prepare_expression(expr);
384            evaluate_jq(&sanitized, input, vars)
385                .map_err(|e| WorkflowError::expression(format!("{}", e), task_name))
386        }
387        _ => traverse_and_evaluate_obj(Some(value), input, vars, task_name),
388    }
389}
390
391/// Prepares an expression for JQ evaluation by normalizing and sanitizing.
392///
393/// If the expression is a strict expression (`${...}`), strips the `${` and `}` wrapper.
394/// Otherwise, normalizes it first (adding `${}` wrapper if missing) then sanitizes.
395pub fn prepare_expression(expr: &str) -> String {
396    if is_strict_expr(expr) {
397        sanitize_expr(expr)
398    } else {
399        let normalized = swf_core::models::expression::normalize_expr(expr);
400        sanitize_expr(&normalized)
401    }
402}
403
404/// Evaluates a JQ expression string and returns the JSON result.
405///
406/// If the string is a strict expression (`${...}`), evaluates it as JQ.
407/// Otherwise, tries to parse it as JSON.
408pub fn evaluate_expression_json(
409    expr: &str,
410    input: &Value,
411    vars: &HashMap<String, Value>,
412    task_name: &str,
413) -> WorkflowResult<Value> {
414    if is_strict_expr(expr) {
415        let sanitized = sanitize_expr(expr);
416        evaluate_jq(&sanitized, input, vars)
417            .map_err(|e| WorkflowError::expression(format!("{}", e), task_name))
418    } else {
419        // Not an expression, try parsing as JSON
420        serde_json::from_str(expr).map_err(|e| {
421            WorkflowError::expression(
422                format!("failed to parse non-expression value as JSON: {}", e),
423                task_name,
424            )
425        })
426    }
427}
428
// Unit tests for the JQ evaluator and the traverse/evaluate helpers in this module.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // === evaluate_jq tests ===

    /// A top-level field lookup returns that field's value.
    #[test]
    fn test_evaluate_jq_simple_path() {
        let input = json!({"foo": "bar"});
        let vars = HashMap::new();
        let result = evaluate_jq(".foo", &input, &vars).unwrap();
        assert_eq!(result, json!("bar"));
    }

    /// A dotted path reaches into nested objects.
    #[test]
    fn test_evaluate_jq_nested_path() {
        let input = json!({"foo": {"bar": 42}});
        let vars = HashMap::new();
        let result = evaluate_jq(".foo.bar", &input, &vars).unwrap();
        assert_eq!(result, json!(42));
    }

    /// Variable names are registered including their leading `$`.
    #[test]
    fn test_evaluate_jq_with_variable() {
        let input = json!({});
        let mut vars = HashMap::new();
        vars.insert("$input".to_string(), json!({"x": 1}));
        let result = evaluate_jq("$input.x", &input, &vars).unwrap();
        assert_eq!(result, json!(1));
    }

    /// Referencing a variable that was not supplied is an error.
    #[test]
    fn test_evaluate_jq_undefined_variable() {
        let input = json!({"foo": "bar"});
        let vars = HashMap::new();
        let result = evaluate_jq("$undefinedVar", &input, &vars);
        assert!(result.is_err());
    }

    /// Syntactically invalid JQ is reported as an error rather than panicking.
    #[test]
    fn test_evaluate_jq_invalid_expression() {
        let input = json!({"foo": "bar"});
        let vars = HashMap::new();
        let result = evaluate_jq(".foo(", &input, &vars);
        assert!(result.is_err());
    }

    /// A filter producing several outputs has them collected into one array.
    #[test]
    fn test_evaluate_jq_array_result() {
        let input = json!({"items": [1, 2, 3]});
        let vars = HashMap::new();
        let result = evaluate_jq(".items[]", &input, &vars).unwrap();
        assert_eq!(result, json!([1, 2, 3]));
    }

    /// Standard-library functions such as `length` are available.
    #[test]
    fn test_evaluate_jq_length_function() {
        let input = json!({"items": [1, 2, 3]});
        let vars = HashMap::new();
        let result = evaluate_jq(".items | length", &input, &vars).unwrap();
        assert_eq!(result, json!(3));
    }

    /// Arithmetic between two fields.
    #[test]
    fn test_evaluate_jq_arithmetic() {
        let input = json!({"a": 10, "b": 3});
        let vars = HashMap::new();
        let result = evaluate_jq(".a - .b", &input, &vars).unwrap();
        assert_eq!(result, json!(7));
    }

    /// Re-evaluating the same expression still yields input-specific results
    /// and leaves an entry in the global filter cache.
    #[test]
    fn test_jq_filter_cache_hit() {
        // Evaluate the same expression twice - second call should use cache
        let input1 = json!({"x": 1});
        let input2 = json!({"x": 2});
        let vars = HashMap::new();

        let result1 = evaluate_jq(".x", &input1, &vars).unwrap();
        assert_eq!(result1, json!(1));

        let result2 = evaluate_jq(".x", &input2, &vars).unwrap();
        assert_eq!(result2, json!(2));

        // Verify cache has entries
        let cache = FILTER_CACHE.read().unwrap();
        assert!(!cache.is_empty());
    }

    // === traverse_and_evaluate tests (port from Go TestTraverseAndEvaluate) ===

    /// Values without expressions pass through unchanged.
    #[test]
    fn test_traverse_no_expression() {
        let mut node = json!({
            "key": "value",
            "num": 123
        });
        let input = json!(null);
        let vars = HashMap::new();
        traverse_and_evaluate(&mut node, &input, &vars).unwrap();
        assert_eq!(node["key"], json!("value"));
        assert_eq!(node["num"], json!(123));
    }

    /// An expression string in an object field is replaced; siblings are kept.
    #[test]
    fn test_traverse_and_evaluate_object() {
        let mut node = json!({
            "name": "${.foo}",
            "count": 42
        });
        let input = json!({"foo": "hello"});
        let vars = HashMap::new();
        traverse_and_evaluate(&mut node, &input, &vars).unwrap();
        assert_eq!(node["name"], json!("hello"));
        assert_eq!(node["count"], json!(42));
    }

    /// Expression strings inside arrays are evaluated element by element.
    #[test]
    fn test_traverse_expression_in_array() {
        let mut node = json!(["static", "${.foo}"]);
        let input = json!({"foo": "bar"});
        let vars = HashMap::new();
        traverse_and_evaluate(&mut node, &input, &vars).unwrap();
        assert_eq!(node[0], json!("static"));
        assert_eq!(node[1], json!("bar"));
    }

    /// Traversal descends into nested objects.
    #[test]
    fn test_traverse_and_evaluate_nested_expr() {
        let mut node = json!({
            "data": {
                "inner": "${.name}"
            }
        });
        let input = json!({"name": "world"});
        let vars = HashMap::new();
        traverse_and_evaluate(&mut node, &input, &vars).unwrap();
        assert_eq!(node["data"]["inner"], json!("world"));
    }

    /// Traversal descends through an array into the objects it contains.
    #[test]
    fn test_traverse_nested_structure_in_array() {
        let mut node = json!({
            "level1": [{"expr": "${.foo}"}]
        });
        let input = json!({"foo": "nestedValue"});
        let vars = HashMap::new();
        traverse_and_evaluate(&mut node, &input, &vars).unwrap();
        assert_eq!(node["level1"][0]["expr"], json!("nestedValue"));
    }

    /// Variable bindings are visible to expressions found during traversal.
    #[test]
    fn test_traverse_with_vars() {
        let mut node = json!({"expr": "${$myVar}"});
        let input = json!({});
        let mut vars = HashMap::new();
        vars.insert("$myVar".to_string(), json!("HelloVars"));
        traverse_and_evaluate(&mut node, &input, &vars).unwrap();
        assert_eq!(node["expr"], json!("HelloVars"));
    }

    /// An invalid expression encountered during traversal surfaces as an error.
    #[test]
    fn test_traverse_invalid_jq_expression() {
        let mut node = json!("${ .foo( }");
        let input = json!({"foo": "bar"});
        let vars = HashMap::new();
        let result = traverse_and_evaluate(&mut node, &input, &vars);
        assert!(result.is_err());
    }

    // === traverse_and_evaluate_bool tests ===

    /// A comparison that holds evaluates to `true`.
    #[test]
    fn test_traverse_and_evaluate_bool_true() {
        let input = json!({"x": 1});
        let vars = HashMap::new();
        let result = traverse_and_evaluate_bool("${.x == 1}", &input, &vars).unwrap();
        assert!(result);
    }

    /// A comparison that does not hold evaluates to `false`.
    #[test]
    fn test_traverse_and_evaluate_bool_false() {
        let input = json!({"x": 1});
        let vars = HashMap::new();
        let result = traverse_and_evaluate_bool("${.x == 2}", &input, &vars).unwrap();
        assert!(!result);
    }

    /// The empty expression is `false` without being evaluated.
    #[test]
    fn test_traverse_and_evaluate_bool_empty() {
        let input = json!({});
        let vars = HashMap::new();
        let result = traverse_and_evaluate_bool("", &input, &vars).unwrap();
        assert!(!result);
    }

    // === traverse_and_evaluate_obj tests ===

    /// With no template object the input is returned as-is.
    #[test]
    fn test_traverse_and_evaluate_obj_none() {
        let input = json!({"key": "value"});
        let vars = HashMap::new();
        let result = traverse_and_evaluate_obj(None, &input, &vars, "test").unwrap();
        assert_eq!(result, input);
    }

    /// Expressions inside the template object are evaluated against the input.
    #[test]
    fn test_traverse_and_evaluate_obj_with_expression() {
        let obj = json!({"result": "${.value}"});
        let input = json!({"value": 42});
        let vars = HashMap::new();
        let result = traverse_and_evaluate_obj(Some(&obj), &input, &vars, "test").unwrap();
        assert_eq!(result["result"], json!(42));
    }

    /// Probes `+=` support: the result is checked only when evaluation
    /// succeeds, so this test passes whether or not the operator is supported.
    #[test]
    fn test_jq_update_operator() {
        // Test if jaq supports the += update operator (used in Java SDK's for-sum.yaml)
        let input = json!({"incr": [2, 3], "counter": 6});
        let vars = HashMap::new();
        let result = evaluate_jq(".incr += [5]", &input, &vars);
        // jaq 2.x may or may not support +=; if not, we'll get an error
        match result {
            Ok(val) => {
                // If supported, the result should have incr updated
                assert_eq!(val["incr"], json!([2, 3, 5]));
                assert_eq!(val["counter"], json!(6));
            }
            Err(_) => {
                // += not supported - this is expected for jaq
                // The Java SDK uses jq which supports update operators
            }
        }
    }

    /// The `if`/`else` + concatenation form appends to an existing array.
    #[test]
    fn test_jq_if_then_else_with_concat() {
        // Alternative to += that works in jaq
        // Matches Java SDK's for-sum.yaml export.as expression pattern
        let input = json!({"incr": [2, 3], "counter": 6});
        let vars = HashMap::new();
        // This is the if-then-else part that builds a new object
        let result = evaluate_jq(
            "if .incr == null then {incr: [5]} else {incr: (.incr + [5])} end",
            &input,
            &vars,
        )
        .unwrap();
        assert_eq!(result["incr"], json!([2, 3, 5]));
    }

    /// The same form seeds the array when the field is missing (null).
    #[test]
    fn test_jq_if_then_else_null_check() {
        // First iteration: .incr is null
        let input = json!({"counter": 0});
        let vars = HashMap::new();
        let result = evaluate_jq(
            "if .incr == null then {incr: [2]} else {incr: (.incr + [2])} end",
            &input,
            &vars,
        )
        .unwrap();
        assert_eq!(result["incr"], json!([2]));
    }
}
690        )
691        .unwrap();
692        assert_eq!(result["incr"], json!([2]));
693    }
694}