Skip to main content

agent_orchestrator/prehook/
cel.rs

1use crate::config::{
2    ConvergenceContext, ItemFinalizeContext, StepPrehookContext, WorkflowFinalizeRule,
3};
4use anyhow::Result;
5use cel_interpreter::{Program, Value as CelValue};
6
7use super::context::{
8    build_convergence_cel_context, build_finalize_cel_context, build_step_prehook_cel_context,
9};
10
11/// Evaluates a step prehook CEL expression against the provided context.
12pub fn evaluate_step_prehook_expression(
13    expression: &str,
14    context: &StepPrehookContext,
15) -> Result<bool> {
16    let compiled = std::panic::catch_unwind(|| Program::compile(expression))
17        .map_err(|_| anyhow::anyhow!("step '{}' prehook compilation panicked", context.step))?;
18    let program = compiled.map_err(|err| {
19        anyhow::anyhow!(
20            "step '{}' prehook compilation failed: {}",
21            context.step,
22            err
23        )
24    })?;
25    let cel_context = build_step_prehook_cel_context(context)?;
26    let value = program.execute(&cel_context).map_err(|err| {
27        anyhow::anyhow!("step '{}' prehook execution failed: {}", context.step, err)
28    })?;
29    match value {
30        CelValue::Bool(v) => Ok(v),
31        other => {
32            anyhow::bail!(
33                "step '{}' prehook must return bool, got {:?}",
34                context.step,
35                other
36            );
37        }
38    }
39}
40
41/// Evaluates a finalize-rule CEL expression against the provided context.
42pub fn evaluate_finalize_rule_expression(
43    rule: &WorkflowFinalizeRule,
44    context: &ItemFinalizeContext,
45) -> Result<bool> {
46    let expression = rule.when.trim();
47    let compiled = std::panic::catch_unwind(|| Program::compile(expression))
48        .map_err(|_| anyhow::anyhow!("finalize rule '{}' compilation panicked", rule.id))?;
49    let program = compiled.map_err(|err| {
50        anyhow::anyhow!("finalize rule '{}' compilation failed: {}", rule.id, err)
51    })?;
52    let cel_context = build_finalize_cel_context(context)?;
53    let value = program
54        .execute(&cel_context)
55        .map_err(|err| anyhow::anyhow!("finalize rule '{}' execution failed: {}", rule.id, err))?;
56    match value {
57        CelValue::Bool(v) => Ok(v),
58        other => anyhow::bail!(
59            "finalize rule '{}' must return bool, got {:?}",
60            rule.id,
61            other
62        ),
63    }
64}
65
66/// Evaluates a webhook payload filter CEL expression.
67///
68/// The `payload` JSON value is injected as a CEL variable named `payload`.
69/// Top-level string/number/bool fields are accessible as `payload.field_name`.
70pub fn evaluate_webhook_filter(expression: &str, payload: &serde_json::Value) -> Result<bool> {
71    let compiled = std::panic::catch_unwind(|| Program::compile(expression))
72        .map_err(|_| anyhow::anyhow!("webhook filter compilation panicked"))?;
73    let program =
74        compiled.map_err(|err| anyhow::anyhow!("webhook filter compilation failed: {}", err))?;
75
76    let mut cel_context = cel_interpreter::Context::default();
77    // Inject the full payload as a JSON string variable for complex access.
78    let payload_str = serde_json::to_string(payload).unwrap_or_default();
79    cel_context
80        .add_variable("payload_json", payload_str)
81        .map_err(|e| anyhow::anyhow!("webhook filter context: {e}"))?;
82    // Inject top-level fields as individual variables for direct access.
83    if let serde_json::Value::Object(map) = payload {
84        for (key, val) in map {
85            let var_name = format!("payload_{key}");
86            match val {
87                serde_json::Value::String(s) => {
88                    let _ = cel_context.add_variable(var_name, s.clone());
89                }
90                serde_json::Value::Number(n) => {
91                    if let Some(i) = n.as_i64() {
92                        let _ = cel_context.add_variable(var_name, i);
93                    } else if let Some(f) = n.as_f64() {
94                        let _ = cel_context.add_variable(var_name, f);
95                    }
96                }
97                serde_json::Value::Bool(b) => {
98                    let _ = cel_context.add_variable(var_name, *b);
99                }
100                _ => {
101                    // Nested objects/arrays: serialize as JSON string
102                    let s = serde_json::to_string(val).unwrap_or_default();
103                    let _ = cel_context.add_variable(var_name, s);
104                }
105            }
106        }
107    }
108
109    let value = program
110        .execute(&cel_context)
111        .map_err(|err| anyhow::anyhow!("webhook filter execution failed: {}", err))?;
112    match value {
113        CelValue::Bool(v) => Ok(v),
114        other => anyhow::bail!("webhook filter must return bool, got {:?}", other),
115    }
116}
117
118/// Evaluates a convergence CEL expression against the provided context.
119pub fn evaluate_convergence_expression(
120    expression: &str,
121    context: &ConvergenceContext,
122) -> Result<bool> {
123    let compiled = std::panic::catch_unwind(|| Program::compile(expression))
124        .map_err(|_| anyhow::anyhow!("convergence_expr compilation panicked"))?;
125    let program =
126        compiled.map_err(|err| anyhow::anyhow!("convergence_expr compilation failed: {}", err))?;
127    let cel_context = build_convergence_cel_context(context)?;
128    let value = program
129        .execute(&cel_context)
130        .map_err(|err| anyhow::anyhow!("convergence_expr execution failed: {}", err))?;
131    match value {
132        CelValue::Bool(v) => Ok(v),
133        other => anyhow::bail!("convergence_expr must return bool, got {:?}", other),
134    }
135}
136
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn webhook_filter_matches_string_field() {
        let payload = serde_json::json!({"type": "message", "channel": "C123"});
        assert!(evaluate_webhook_filter("payload_type == 'message'", &payload).unwrap());
    }

    #[test]
    fn webhook_filter_rejects_non_matching() {
        let payload = serde_json::json!({"type": "reaction", "channel": "C123"});
        assert!(!evaluate_webhook_filter("payload_type == 'message'", &payload).unwrap());
    }

    #[test]
    fn webhook_filter_matches_bool_field() {
        let payload = serde_json::json!({"active": true, "count": 5});
        assert!(evaluate_webhook_filter("payload_active == true", &payload).unwrap());
    }

    #[test]
    fn webhook_filter_matches_number_field() {
        let payload = serde_json::json!({"count": 42});
        assert!(evaluate_webhook_filter("payload_count > 10", &payload).unwrap());
    }

    #[test]
    fn webhook_filter_matches_float_field() {
        // Non-integer numbers are injected as CEL doubles.
        let payload = serde_json::json!({"ratio": 0.75});
        assert!(evaluate_webhook_filter("payload_ratio > 0.5", &payload).unwrap());
    }

    #[test]
    fn webhook_filter_serializes_nested_values_as_json() {
        // Objects, arrays and null are exposed as their JSON serialization.
        let payload = serde_json::json!({"meta": {"k": "v"}, "tags": [1, 2], "gone": null});
        assert!(evaluate_webhook_filter(r#"payload_meta == '{"k":"v"}'"#, &payload).unwrap());
        assert!(evaluate_webhook_filter("payload_tags == '[1,2]'", &payload).unwrap());
        assert!(evaluate_webhook_filter("payload_gone == 'null'", &payload).unwrap());
    }

    #[test]
    fn webhook_filter_exposes_full_payload_json() {
        let payload = serde_json::json!({"a": 1});
        assert!(evaluate_webhook_filter(r#"payload_json == '{"a":1}'"#, &payload).unwrap());
    }

    #[test]
    fn webhook_filter_rejects_non_bool_result() {
        let payload = serde_json::json!({"count": 1});
        let err = evaluate_webhook_filter("payload_count", &payload).unwrap_err();
        assert!(err.to_string().contains("must return bool"));
    }

    #[test]
    fn webhook_filter_empty_payload() {
        let payload = serde_json::json!({});
        // Expression referencing non-existent variable should fail
        assert!(evaluate_webhook_filter("payload_type == 'x'", &payload).is_err());
    }

    #[test]
    fn webhook_filter_invalid_expression() {
        let payload = serde_json::json!({"a": 1});
        assert!(evaluate_webhook_filter("invalid %%% syntax", &payload).is_err());
    }
}