cala_cel_interpreter/
interpreter.rs

1use serde::{Deserialize, Serialize};
2use tracing::instrument;
3
4use cel_parser::{
5    ast::{self, ArithmeticOp, Expression, RelationOp},
6    parse_expression,
7};
8
9use crate::{cel_type::*, context::*, error::*, value::*};
10
11#[derive(Debug, Clone, Deserialize, Serialize)]
12#[serde(try_from = "String")]
13#[serde(into = "String")]
14pub struct CelExpression {
15    source: String,
16    expr: Expression,
17}
18
19impl CelExpression {
20    pub fn try_evaluate<'a, T: TryFrom<CelResult<'a>, Error = ResultCoercionError>>(
21        &'a self,
22        ctx: &CelContext,
23    ) -> Result<T, CelError> {
24        let res = self.evaluate(ctx)?;
25        Ok(T::try_from(CelResult {
26            expr: &self.expr,
27            val: res,
28        })?)
29    }
30
31    #[instrument(name = "cel.evaluate", skip_all, fields(expression = %self.source, context = tracing::field::Empty, result = tracing::field::Empty), err)]
32    pub fn evaluate(&self, ctx: &CelContext) -> Result<CelValue, CelError> {
33        // Record context with actual values for debugging
34        let context_debug = ctx.debug_context();
35        if !context_debug.is_empty() {
36            tracing::Span::current().record("context", &context_debug);
37        }
38
39        let result = match evaluate_expression(&self.expr, ctx)? {
40            EvalType::Value(val) => Ok(val),
41            EvalType::ContextItem(ContextItem::Value(val)) => Ok(val.clone()),
42            _ => Err(CelError::Unexpected(
43                "evaluate didn't return a value".to_string(),
44            )),
45        }?;
46
47        // Record the result value for debugging
48        tracing::Span::current().record("result", format!("{:?}", result));
49
50        Ok(result)
51    }
52}
53
54impl std::fmt::Display for CelExpression {
55    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56        write!(f, "{}", self.source)
57    }
58}
59
60enum EvalType<'a> {
61    Value(CelValue),
62    ContextItem(&'a ContextItem),
63    MemberFn(&'a CelValue, &'a CelMemberFunction),
64}
65
66impl EvalType<'_> {
67    fn try_into_bool(self) -> Result<bool, CelError> {
68        if let EvalType::Value(val) = self {
69            val.try_bool()
70        } else {
71            Err(CelError::Unexpected(
72                "Expression didn't resolve to a bool".to_string(),
73            ))
74        }
75    }
76
77    fn try_into_key(self) -> Result<CelKey, CelError> {
78        if let EvalType::Value(val) = self {
79            match val {
80                CelValue::Int(i) => Ok(CelKey::Int(i)),
81                CelValue::UInt(u) => Ok(CelKey::UInt(u)),
82                CelValue::Bool(b) => Ok(CelKey::Bool(b)),
83                CelValue::String(s) => Ok(CelKey::String(s)),
84                _ => Err(CelError::Unexpected(
85                    "Expression didn't resolve to a valid key".to_string(),
86                )),
87            }
88        } else {
89            Err(CelError::Unexpected(
90                "Expression didn't resolve to value".to_string(),
91            ))
92        }
93    }
94
95    fn try_into_value(self) -> Result<CelValue, CelError> {
96        if let EvalType::Value(val) = self {
97            Ok(val)
98        } else {
99            Err(CelError::Unexpected("Couldn't unwrap value".to_string()))
100        }
101    }
102}
103
104#[instrument(name = "cel.evaluate_expression", skip_all, level = "debug", err)]
105fn evaluate_expression<'a>(
106    expr: &Expression,
107    ctx: &'a CelContext,
108) -> Result<EvalType<'a>, CelError> {
109    match evaluate_expression_inner(expr, ctx) {
110        Ok(val) => Ok(val),
111        Err(e) => Err(CelError::EvaluationError(format!("{expr:?}"), Box::new(e))),
112    }
113}
114
115#[instrument(name = "cel.evaluate_expr", skip_all, level = "debug", err)]
116fn evaluate_expression_inner<'a>(
117    expr: &Expression,
118    ctx: &'a CelContext,
119) -> Result<EvalType<'a>, CelError> {
120    use Expression::*;
121    match expr {
122        Ternary(cond, left, right) => {
123            if evaluate_expression(cond, ctx)?.try_into_bool()? {
124                evaluate_expression(left, ctx)
125            } else {
126                evaluate_expression(right, ctx)
127            }
128        }
129        Member(expr, member) => {
130            let ident = evaluate_expression(expr, ctx)?;
131            evaluate_member(ident, member, ctx)
132        }
133        Has(expr) => {
134            // The 'has' macro checks if a field exists in a map
135            // It expects an expression of the form e.f or e.f.g.h (a Member expression)
136            // For nested fields like a.b.c.d, it evaluates a.b.c and checks if 'd' exists
137
138            // Helper function to extract the last field and the target expression
139            fn extract_last_field(
140                expr: &Expression,
141            ) -> Option<(&Expression, &std::sync::Arc<String>)> {
142                match expr {
143                    Expression::Member(target, member) => match member.as_ref() {
144                        ast::Member::Attribute(field_name) => Some((target.as_ref(), field_name)),
145                        _ => None,
146                    },
147                    _ => None,
148                }
149            }
150
151            if let Some((target_expr, field_name)) = extract_last_field(expr.as_ref()) {
152                // Evaluate the target expression (everything except the last field)
153                let target = evaluate_expression(target_expr, ctx)?;
154
155                // Check if the field exists in the map
156                let has_field = match target {
157                    EvalType::Value(CelValue::Map(map)) => map.contains_key(field_name.as_str()),
158                    EvalType::ContextItem(ContextItem::Value(CelValue::Map(map))) => {
159                        map.contains_key(field_name.as_str())
160                    }
161                    _ => {
162                        // For non-map types, has() should return an error
163                        return Err(CelError::IllegalTarget);
164                    }
165                };
166
167                Ok(EvalType::Value(CelValue::Bool(has_field)))
168            } else {
169                Err(CelError::Unexpected(
170                    "has() expects a member expression".to_string(),
171                ))
172            }
173        }
174        Map(entries) => {
175            let mut map = CelMap::new();
176            for (k, v) in entries {
177                let key = evaluate_expression(k, ctx)?;
178                let value = evaluate_expression(v, ctx)?;
179                map.insert(key.try_into_key()?, value.try_into_value()?)
180            }
181            Ok(EvalType::Value(CelValue::from(map)))
182        }
183        Ident(name) => Ok(EvalType::ContextItem(ctx.lookup_ident(name)?)),
184        Literal(val) => Ok(EvalType::Value(CelValue::from(val))),
185        Arithmetic(op, left, right) => {
186            let left = evaluate_expression(left, ctx)?;
187            let right = evaluate_expression(right, ctx)?;
188            Ok(EvalType::Value(evaluate_arithmetic(
189                *op,
190                left.try_into_value()?,
191                right.try_into_value()?,
192            )?))
193        }
194        Relation(op, left, right) => {
195            let left = evaluate_expression(left, ctx)?;
196            let right = evaluate_expression(right, ctx)?;
197            Ok(EvalType::Value(evaluate_relation(
198                *op,
199                left.try_into_value()?,
200                right.try_into_value()?,
201            )?))
202        }
203        Unary(op, expr) => {
204            use ast::UnaryOp;
205            match op {
206                UnaryOp::Not => {
207                    let val = evaluate_expression(expr, ctx)?.try_into_bool()?;
208                    Ok(EvalType::Value(CelValue::Bool(!val)))
209                }
210                _ => Err(CelError::Unexpected(format!(
211                    "unimplemented unary op: {op:?}"
212                ))),
213            }
214        }
215        e => Err(CelError::Unexpected(format!("unimplemented {e:?}"))),
216    }
217}
218
219#[instrument(name = "cel.evaluate_member", skip_all, level = "debug", err)]
220fn evaluate_member<'a>(
221    target: EvalType<'a>,
222    member: &ast::Member,
223    ctx: &'a CelContext,
224) -> Result<EvalType<'a>, CelError> {
225    use ast::Member::*;
226    match member {
227        Attribute(name) => match target {
228            EvalType::Value(CelValue::Map(map)) if map.contains_key(name) => {
229                Ok(EvalType::Value(map.get(name)))
230            }
231            EvalType::ContextItem(ContextItem::Value(CelValue::Map(map))) => {
232                Ok(EvalType::Value(map.get(name)))
233            }
234            EvalType::ContextItem(ContextItem::Package(p)) => {
235                Ok(EvalType::ContextItem(p.lookup(name)?))
236            }
237            EvalType::ContextItem(ContextItem::Value(v)) => {
238                Ok(EvalType::MemberFn(v, ctx.lookup_member_fn(v, name)?))
239            }
240            _ => Err(CelError::IllegalTarget),
241        },
242        FunctionCall(exprs) => match target {
243            EvalType::ContextItem(ContextItem::Function(f)) => {
244                let mut args = Vec::new();
245                for e in exprs {
246                    args.push(evaluate_expression(e, ctx)?.try_into_value()?)
247                }
248                Ok(EvalType::Value(f(args)?))
249            }
250            EvalType::ContextItem(ContextItem::Package(p)) => {
251                evaluate_member(EvalType::ContextItem(p.package_self()?), member, ctx)
252            }
253            EvalType::MemberFn(v, f) => {
254                let mut args = Vec::new();
255                for e in exprs {
256                    args.push(evaluate_expression(e, ctx)?.try_into_value()?)
257                }
258                Ok(EvalType::Value(f(v, args)?))
259            }
260            _ => Err(CelError::IllegalTarget),
261        },
262        _ => unimplemented!(),
263    }
264}
265
266#[instrument(name = "cel.evaluate_arithmetic", skip_all, level = "debug", err)]
267fn evaluate_arithmetic(
268    op: ArithmeticOp,
269    left: CelValue,
270    right: CelValue,
271) -> Result<CelValue, CelError> {
272    use CelValue::*;
273    match op {
274        ArithmeticOp::Multiply => match (&left, &right) {
275            (UInt(l), UInt(r)) => Ok(UInt(l * r)),
276            (Int(l), Int(r)) => Ok(Int(l * r)),
277            (Double(l), Double(r)) => Ok(Double(l * r)),
278            (Decimal(l), Decimal(r)) => Ok(Decimal(l * r)),
279            _ => Err(CelError::NoMatchingOverload(format!(
280                "Cannot apply '*' to {:?} and {:?}",
281                CelType::from(&left),
282                CelType::from(&right)
283            ))),
284        },
285        ArithmeticOp::Add => match (&left, &right) {
286            (UInt(l), UInt(r)) => Ok(UInt(l + r)),
287            (Int(l), Int(r)) => Ok(Int(l + r)),
288            (Double(l), Double(r)) => Ok(Double(l + r)),
289            (Decimal(l), Decimal(r)) => Ok(Decimal(l + r)),
290            _ => Err(CelError::NoMatchingOverload(format!(
291                "Cannot apply '+' to {:?} and {:?}",
292                CelType::from(&left),
293                CelType::from(&right)
294            ))),
295        },
296        ArithmeticOp::Subtract => match (&left, &right) {
297            (UInt(l), UInt(r)) => Ok(UInt(l - r)),
298            (Int(l), Int(r)) => Ok(Int(l - r)),
299            (Double(l), Double(r)) => Ok(Double(l - r)),
300            (Decimal(l), Decimal(r)) => Ok(Decimal(l - r)),
301            _ => Err(CelError::NoMatchingOverload(format!(
302                "Cannot apply '-' to {:?} and {:?}",
303                CelType::from(&left),
304                CelType::from(&right)
305            ))),
306        },
307        _ => unimplemented!(),
308    }
309}
310
311#[instrument(name = "cel.evaluate_relation", skip_all, level = "debug", err)]
312fn evaluate_relation(
313    op: RelationOp,
314    left: CelValue,
315    right: CelValue,
316) -> Result<CelValue, CelError> {
317    use CelValue::*;
318    match op {
319        RelationOp::LessThan => match (&left, &right) {
320            (UInt(l), UInt(r)) => Ok(Bool(l < r)),
321            (Int(l), Int(r)) => Ok(Bool(l < r)),
322            (Double(l), Double(r)) => Ok(Bool(l < r)),
323            (Decimal(l), Decimal(r)) => Ok(Bool(l < r)),
324            (Date(l), Date(r)) => Ok(Bool(l < r)),
325            (Timestamp(l), Timestamp(r)) => Ok(Bool(l < r)),
326            _ => Err(CelError::NoMatchingOverload(format!(
327                "Cannot apply '<' to {:?} and {:?}",
328                CelType::from(&left),
329                CelType::from(&right)
330            ))),
331        },
332        RelationOp::LessThanEq => match (&left, &right) {
333            (UInt(l), UInt(r)) => Ok(Bool(l <= r)),
334            (Int(l), Int(r)) => Ok(Bool(l <= r)),
335            (Double(l), Double(r)) => Ok(Bool(l <= r)),
336            (Decimal(l), Decimal(r)) => Ok(Bool(l <= r)),
337            (Date(l), Date(r)) => Ok(Bool(l <= r)),
338            (Timestamp(l), Timestamp(r)) => Ok(Bool(l <= r)),
339            _ => Err(CelError::NoMatchingOverload(format!(
340                "Cannot apply '<=' to {:?} and {:?}",
341                CelType::from(&left),
342                CelType::from(&right)
343            ))),
344        },
345        RelationOp::GreaterThan => match (&left, &right) {
346            (UInt(l), UInt(r)) => Ok(Bool(l > r)),
347            (Int(l), Int(r)) => Ok(Bool(l > r)),
348            (Double(l), Double(r)) => Ok(Bool(l > r)),
349            (Decimal(l), Decimal(r)) => Ok(Bool(l > r)),
350            (Date(l), Date(r)) => Ok(Bool(l > r)),
351            (Timestamp(l), Timestamp(r)) => Ok(Bool(l > r)),
352            _ => Err(CelError::NoMatchingOverload(format!(
353                "Cannot apply '>' to {:?} and {:?}",
354                CelType::from(&left),
355                CelType::from(&right)
356            ))),
357        },
358        RelationOp::GreaterThanEq => match (&left, &right) {
359            (UInt(l), UInt(r)) => Ok(Bool(l >= r)),
360            (Int(l), Int(r)) => Ok(Bool(l >= r)),
361            (Double(l), Double(r)) => Ok(Bool(l >= r)),
362            (Decimal(l), Decimal(r)) => Ok(Bool(l >= r)),
363            (Date(l), Date(r)) => Ok(Bool(l >= r)),
364            (Timestamp(l), Timestamp(r)) => Ok(Bool(l >= r)),
365            _ => Err(CelError::NoMatchingOverload(format!(
366                "Cannot apply '>=' to {:?} and {:?}",
367                CelType::from(&left),
368                CelType::from(&right)
369            ))),
370        },
371        RelationOp::Equals => match (&left, &right) {
372            (UInt(l), UInt(r)) => Ok(Bool(l == r)),
373            (Int(l), Int(r)) => Ok(Bool(l == r)),
374            (Double(l), Double(r)) => Ok(Bool(l == r)),
375            (Decimal(l), Decimal(r)) => Ok(Bool(l == r)),
376            (Date(l), Date(r)) => Ok(Bool(l == r)),
377            (Timestamp(l), Timestamp(r)) => Ok(Bool(l == r)),
378            _ => Err(CelError::NoMatchingOverload(format!(
379                "Cannot apply '==' to {:?} and {:?}",
380                CelType::from(&left),
381                CelType::from(&right)
382            ))),
383        },
384        RelationOp::NotEquals => match (&left, &right) {
385            (UInt(l), UInt(r)) => Ok(Bool(l != r)),
386            (Int(l), Int(r)) => Ok(Bool(l != r)),
387            (Double(l), Double(r)) => Ok(Bool(l != r)),
388            (Decimal(l), Decimal(r)) => Ok(Bool(l != r)),
389            (Date(l), Date(r)) => Ok(Bool(l != r)),
390            (Timestamp(l), Timestamp(r)) => Ok(Bool(l != r)),
391            _ => Err(CelError::NoMatchingOverload(format!(
392                "Cannot apply '!=' to {:?} and {:?}",
393                CelType::from(&left),
394                CelType::from(&right)
395            ))),
396        },
397        _ => unimplemented!(),
398    }
399}
400
401impl From<CelExpression> for String {
402    fn from(expr: CelExpression) -> Self {
403        expr.source
404    }
405}
406
407impl TryFrom<String> for CelExpression {
408    type Error = CelError;
409
410    fn try_from(source: String) -> Result<Self, Self::Error> {
411        let expr = parse_expression(source.clone()).map_err(CelError::CelParseError)?;
412        Ok(Self { source, expr })
413    }
414}
415impl TryFrom<&str> for CelExpression {
416    type Error = CelError;
417
418    fn try_from(source: &str) -> Result<Self, Self::Error> {
419        Self::try_from(source.to_string())
420    }
421}
422impl std::str::FromStr for CelExpression {
423    type Err = CelError;
424
425    fn from_str(source: &str) -> Result<Self, Self::Err> {
426        Self::try_from(source.to_string())
427    }
428}
429
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::NaiveDate;

    /// Parses `source` and evaluates it against `context`, panicking on any failure.
    fn eval(source: &str, context: &CelContext) -> CelValue {
        let expression: CelExpression = source.parse().unwrap();
        expression.evaluate(context).unwrap()
    }

    /// Builds a context in which `name` is bound to the map `value`.
    fn context_with(name: &'static str, value: CelMap) -> CelContext {
        let mut context = CelContext::new();
        context.add_variable(name, value);
        context
    }

    #[test]
    fn literals() {
        let context = CelContext::new();
        assert_eq!(eval("true", &context), CelValue::Bool(true));
        assert_eq!(eval("1", &context), CelValue::Int(1));
        assert_eq!(eval("-1", &context), CelValue::Int(-1));
        assert_eq!(
            eval("'hello'", &context),
            CelValue::String("hello".to_string().into())
        );
        // TODO: unsigned literals (`1u`, expected `CelValue::UInt(1)`) cannot be tested until
        // the tokenizer is fixed.
    }

    #[test]
    fn logic() {
        let context = CelContext::new();
        assert_eq!(
            eval("true || false ? false && true : true", &context),
            CelValue::Bool(false)
        );
        assert_eq!(
            eval("true && false ? false : true || false", &context),
            CelValue::Bool(true)
        );
    }

    #[test]
    fn lookup() {
        let mut hello = CelMap::new();
        hello.insert("world", 42);
        let mut params = CelMap::new();
        params.insert("hello", hello);
        let context = context_with("params", params);

        assert_eq!(eval("params.hello.world", &context), CelValue::Int(42));
    }

    #[test]
    fn to_level_function() {
        let expected = NaiveDate::parse_from_str("2022-10-10", "%Y-%m-%d").unwrap();
        assert_eq!(
            eval("date('2022-10-10')", &CelContext::new()),
            CelValue::Date(expected)
        );
    }

    #[test]
    fn cast_function() {
        assert_eq!(
            eval("decimal('1')", &CelContext::new()),
            CelValue::Decimal(1.into())
        );
    }

    #[test]
    fn package_function() -> anyhow::Result<()> {
        let expression: CelExpression = "decimal.Add(decimal('1'), decimal('2'))"
            .parse()
            .unwrap();
        let context = CelContext::new();
        assert_eq!(expression.evaluate(&context)?, CelValue::Decimal(3.into()));
        Ok(())
    }

    #[test]
    fn has_macro_with_map() {
        // Top-level field: present and absent.
        let mut params = CelMap::new();
        params.insert("hello", "world");
        let context = context_with("params", params);
        assert_eq!(eval("has(params.hello)", &context), CelValue::Bool(true));
        assert_eq!(eval("has(params.missing)", &context), CelValue::Bool(false));

        // One level of nesting.
        let mut nested = CelMap::new();
        nested.insert("field", 42);
        let mut params = CelMap::new();
        params.insert("nested", nested);
        let context = context_with("params", params);
        assert_eq!(
            eval("has(params.nested.field)", &context),
            CelValue::Bool(true)
        );

        // Deep nesting (a.b.c.d): present and absent final field.
        let mut settings = CelMap::new();
        settings.insert("maxConnections", 100);
        settings.insert("timeout", 30);
        let mut database = CelMap::new();
        database.insert("settings", settings);
        let mut config = CelMap::new();
        config.insert("database", database);
        let context = context_with("config", config);
        assert_eq!(
            eval("has(config.database.settings.maxConnections)", &context),
            CelValue::Bool(true)
        );
        assert_eq!(
            eval("has(config.database.settings.missingField)", &context),
            CelValue::Bool(false)
        );
    }

    #[test]
    fn function_on_timestamp() -> anyhow::Result<()> {
        use chrono::{DateTime, Utc};

        let time: DateTime<Utc> = "1940-12-21T00:00:00Z".parse().unwrap();
        let mut context = CelContext::new();
        context.add_variable("now", time);

        let expression: CelExpression = "now.format('%d/%m/%Y')".parse().unwrap();
        assert_eq!(expression.evaluate(&context)?, CelValue::from("21/12/1940"));
        Ok(())
    }
}