cala_cel_interpreter/
interpreter.rs

1use serde::{Deserialize, Serialize};
2
3use cel_parser::{
4    ast::{self, ArithmeticOp, Expression, RelationOp},
5    parser::ExpressionParser,
6};
7
8use crate::{cel_type::*, context::*, error::*, value::*};
9
10#[derive(Debug, Clone, Deserialize, Serialize)]
11#[serde(try_from = "String")]
12#[serde(into = "String")]
13pub struct CelExpression {
14    source: String,
15    expr: Expression,
16}
17
18impl CelExpression {
19    pub fn try_evaluate<'a, T: TryFrom<CelResult<'a>, Error = ResultCoercionError>>(
20        &'a self,
21        ctx: &CelContext,
22    ) -> Result<T, CelError> {
23        let res = self.evaluate(ctx)?;
24        Ok(T::try_from(CelResult {
25            expr: &self.expr,
26            val: res,
27        })?)
28    }
29
30    pub fn evaluate(&self, ctx: &CelContext) -> Result<CelValue, CelError> {
31        match evaluate_expression(&self.expr, ctx)? {
32            EvalType::Value(val) => Ok(val),
33            EvalType::ContextItem(ContextItem::Value(val)) => Ok(val.clone()),
34            _ => Err(CelError::Unexpected(
35                "evaluate didn't return a value".to_string(),
36            )),
37        }
38    }
39}
40
41impl std::fmt::Display for CelExpression {
42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        write!(f, "{}", self.source)
44    }
45}
46
47enum EvalType<'a> {
48    Value(CelValue),
49    ContextItem(&'a ContextItem),
50    MemberFn(&'a CelValue, &'a CelMemberFunction),
51}
52
53impl EvalType<'_> {
54    fn try_into_bool(self) -> Result<bool, CelError> {
55        if let EvalType::Value(val) = self {
56            val.try_bool()
57        } else {
58            Err(CelError::Unexpected(
59                "Expression didn't resolve to a bool".to_string(),
60            ))
61        }
62    }
63
64    fn try_into_key(self) -> Result<CelKey, CelError> {
65        if let EvalType::Value(val) = self {
66            match val {
67                CelValue::Int(i) => Ok(CelKey::Int(i)),
68                CelValue::UInt(u) => Ok(CelKey::UInt(u)),
69                CelValue::Bool(b) => Ok(CelKey::Bool(b)),
70                CelValue::String(s) => Ok(CelKey::String(s)),
71                _ => Err(CelError::Unexpected(
72                    "Expression didn't resolve to a valid key".to_string(),
73                )),
74            }
75        } else {
76            Err(CelError::Unexpected(
77                "Expression didn't resolve to value".to_string(),
78            ))
79        }
80    }
81
82    fn try_into_value(self) -> Result<CelValue, CelError> {
83        if let EvalType::Value(val) = self {
84            Ok(val)
85        } else {
86            Err(CelError::Unexpected("Couldn't unwrap value".to_string()))
87        }
88    }
89}
90
91fn evaluate_expression<'a>(
92    expr: &Expression,
93    ctx: &'a CelContext,
94) -> Result<EvalType<'a>, CelError> {
95    match evaluate_expression_inner(expr, ctx) {
96        Ok(val) => Ok(val),
97        Err(e) => Err(CelError::EvaluationError(format!("{expr:?}"), Box::new(e))),
98    }
99}
100
101fn evaluate_expression_inner<'a>(
102    expr: &Expression,
103    ctx: &'a CelContext,
104) -> Result<EvalType<'a>, CelError> {
105    use Expression::*;
106    match expr {
107        Ternary(cond, left, right) => {
108            if evaluate_expression(cond, ctx)?.try_into_bool()? {
109                evaluate_expression(left, ctx)
110            } else {
111                evaluate_expression(right, ctx)
112            }
113        }
114        Member(expr, member) => {
115            let ident = evaluate_expression(expr, ctx)?;
116            evaluate_member(ident, member, ctx)
117        }
118        Map(entries) => {
119            let mut map = CelMap::new();
120            for (k, v) in entries {
121                let key = evaluate_expression(k, ctx)?;
122                let value = evaluate_expression(v, ctx)?;
123                map.insert(key.try_into_key()?, value.try_into_value()?)
124            }
125            Ok(EvalType::Value(CelValue::from(map)))
126        }
127        Ident(name) => Ok(EvalType::ContextItem(ctx.lookup_ident(name)?)),
128        Literal(val) => Ok(EvalType::Value(CelValue::from(val))),
129        Arithmetic(op, left, right) => {
130            let left = evaluate_expression(left, ctx)?;
131            let right = evaluate_expression(right, ctx)?;
132            Ok(EvalType::Value(evaluate_arithmetic(
133                *op,
134                left.try_into_value()?,
135                right.try_into_value()?,
136            )?))
137        }
138        Relation(op, left, right) => {
139            let left = evaluate_expression(left, ctx)?;
140            let right = evaluate_expression(right, ctx)?;
141            Ok(EvalType::Value(evaluate_relation(
142                *op,
143                left.try_into_value()?,
144                right.try_into_value()?,
145            )?))
146        }
147        e => Err(CelError::Unexpected(format!("unimplemented {e:?}"))),
148    }
149}
150
151fn evaluate_member<'a>(
152    target: EvalType<'a>,
153    member: &ast::Member,
154    ctx: &'a CelContext,
155) -> Result<EvalType<'a>, CelError> {
156    use ast::Member::*;
157    match member {
158        Attribute(name) => match target {
159            EvalType::Value(CelValue::Map(map)) if map.contains_key(name) => {
160                Ok(EvalType::Value(map.get(name)))
161            }
162            EvalType::ContextItem(ContextItem::Value(CelValue::Map(map))) => {
163                Ok(EvalType::Value(map.get(name)))
164            }
165            EvalType::ContextItem(ContextItem::Package(p)) => {
166                Ok(EvalType::ContextItem(p.lookup(name)?))
167            }
168            EvalType::ContextItem(ContextItem::Value(v)) => {
169                Ok(EvalType::MemberFn(v, ctx.lookup_member_fn(v, name)?))
170            }
171            _ => Err(CelError::IllegalTarget),
172        },
173        FunctionCall(exprs) => match target {
174            EvalType::ContextItem(ContextItem::Function(f)) => {
175                let mut args = Vec::new();
176                for e in exprs {
177                    args.push(evaluate_expression(e, ctx)?.try_into_value()?)
178                }
179                Ok(EvalType::Value(f(args)?))
180            }
181            EvalType::ContextItem(ContextItem::Package(p)) => {
182                evaluate_member(EvalType::ContextItem(p.package_self()?), member, ctx)
183            }
184            EvalType::MemberFn(v, f) => {
185                let mut args = Vec::new();
186                for e in exprs {
187                    args.push(evaluate_expression(e, ctx)?.try_into_value()?)
188                }
189                Ok(EvalType::Value(f(v, args)?))
190            }
191            _ => Err(CelError::IllegalTarget),
192        },
193        _ => unimplemented!(),
194    }
195}
196
197fn evaluate_arithmetic(
198    op: ArithmeticOp,
199    left: CelValue,
200    right: CelValue,
201) -> Result<CelValue, CelError> {
202    use CelValue::*;
203    match op {
204        ArithmeticOp::Multiply => match (&left, &right) {
205            (UInt(l), UInt(r)) => Ok(UInt(l * r)),
206            (Int(l), Int(r)) => Ok(Int(l * r)),
207            (Double(l), Double(r)) => Ok(Double(l * r)),
208            (Decimal(l), Decimal(r)) => Ok(Decimal(l * r)),
209            _ => Err(CelError::NoMatchingOverload(format!(
210                "Cannot apply '*' to {:?} and {:?}",
211                CelType::from(&left),
212                CelType::from(&right)
213            ))),
214        },
215        ArithmeticOp::Add => match (&left, &right) {
216            (UInt(l), UInt(r)) => Ok(UInt(l + r)),
217            (Int(l), Int(r)) => Ok(Int(l + r)),
218            (Double(l), Double(r)) => Ok(Double(l + r)),
219            (Decimal(l), Decimal(r)) => Ok(Decimal(l + r)),
220            _ => Err(CelError::NoMatchingOverload(format!(
221                "Cannot apply '+' to {:?} and {:?}",
222                CelType::from(&left),
223                CelType::from(&right)
224            ))),
225        },
226        ArithmeticOp::Subtract => match (&left, &right) {
227            (UInt(l), UInt(r)) => Ok(UInt(l - r)),
228            (Int(l), Int(r)) => Ok(Int(l - r)),
229            (Double(l), Double(r)) => Ok(Double(l - r)),
230            (Decimal(l), Decimal(r)) => Ok(Decimal(l - r)),
231            _ => Err(CelError::NoMatchingOverload(format!(
232                "Cannot apply '-' to {:?} and {:?}",
233                CelType::from(&left),
234                CelType::from(&right)
235            ))),
236        },
237        _ => unimplemented!(),
238    }
239}
240
241fn evaluate_relation(
242    op: RelationOp,
243    left: CelValue,
244    right: CelValue,
245) -> Result<CelValue, CelError> {
246    use CelValue::*;
247    match op {
248        RelationOp::LessThan => match (&left, &right) {
249            (UInt(l), UInt(r)) => Ok(Bool(l < r)),
250            (Int(l), Int(r)) => Ok(Bool(l < r)),
251            (Double(l), Double(r)) => Ok(Bool(l < r)),
252            (Decimal(l), Decimal(r)) => Ok(Bool(l < r)),
253            _ => Err(CelError::NoMatchingOverload(format!(
254                "Cannot apply '<' to {:?} and {:?}",
255                CelType::from(&left),
256                CelType::from(&right)
257            ))),
258        },
259        RelationOp::LessThanEq => match (&left, &right) {
260            (UInt(l), UInt(r)) => Ok(Bool(l <= r)),
261            (Int(l), Int(r)) => Ok(Bool(l <= r)),
262            (Double(l), Double(r)) => Ok(Bool(l <= r)),
263            (Decimal(l), Decimal(r)) => Ok(Bool(l <= r)),
264            _ => Err(CelError::NoMatchingOverload(format!(
265                "Cannot apply '<=' to {:?} and {:?}",
266                CelType::from(&left),
267                CelType::from(&right)
268            ))),
269        },
270        RelationOp::GreaterThan => match (&left, &right) {
271            (UInt(l), UInt(r)) => Ok(Bool(l > r)),
272            (Int(l), Int(r)) => Ok(Bool(l > r)),
273            (Double(l), Double(r)) => Ok(Bool(l > r)),
274            (Decimal(l), Decimal(r)) => Ok(Bool(l > r)),
275            _ => Err(CelError::NoMatchingOverload(format!(
276                "Cannot apply '>' to {:?} and {:?}",
277                CelType::from(&left),
278                CelType::from(&right)
279            ))),
280        },
281        RelationOp::GreaterThanEq => match (&left, &right) {
282            (UInt(l), UInt(r)) => Ok(Bool(l >= r)),
283            (Int(l), Int(r)) => Ok(Bool(l >= r)),
284            (Double(l), Double(r)) => Ok(Bool(l >= r)),
285            (Decimal(l), Decimal(r)) => Ok(Bool(l >= r)),
286            _ => Err(CelError::NoMatchingOverload(format!(
287                "Cannot apply '>=' to {:?} and {:?}",
288                CelType::from(&left),
289                CelType::from(&right)
290            ))),
291        },
292        RelationOp::Equals => match (&left, &right) {
293            (UInt(l), UInt(r)) => Ok(Bool(l == r)),
294            (Int(l), Int(r)) => Ok(Bool(l == r)),
295            (Double(l), Double(r)) => Ok(Bool(l == r)),
296            (Decimal(l), Decimal(r)) => Ok(Bool(l == r)),
297            _ => Err(CelError::NoMatchingOverload(format!(
298                "Cannot apply '==' to {:?} and {:?}",
299                CelType::from(&left),
300                CelType::from(&right)
301            ))),
302        },
303        RelationOp::NotEquals => match (&left, &right) {
304            (UInt(l), UInt(r)) => Ok(Bool(l != r)),
305            (Int(l), Int(r)) => Ok(Bool(l != r)),
306            (Double(l), Double(r)) => Ok(Bool(l != r)),
307            (Decimal(l), Decimal(r)) => Ok(Bool(l != r)),
308            _ => Err(CelError::NoMatchingOverload(format!(
309                "Cannot apply '!=' to {:?} and {:?}",
310                CelType::from(&left),
311                CelType::from(&right)
312            ))),
313        },
314        _ => unimplemented!(),
315    }
316}
317
318impl From<CelExpression> for String {
319    fn from(expr: CelExpression) -> Self {
320        expr.source
321    }
322}
323
324impl TryFrom<String> for CelExpression {
325    type Error = CelError;
326
327    fn try_from(source: String) -> Result<Self, Self::Error> {
328        let expr = ExpressionParser::new()
329            .parse(&source)
330            .map_err(|e| CelError::CelParseError(e.to_string()))?;
331        Ok(Self { source, expr })
332    }
333}
334impl TryFrom<&str> for CelExpression {
335    type Error = CelError;
336
337    fn try_from(source: &str) -> Result<Self, Self::Error> {
338        Self::try_from(source.to_string())
339    }
340}
341impl std::str::FromStr for CelExpression {
342    type Err = CelError;
343
344    fn from_str(source: &str) -> Result<Self, Self::Err> {
345        Self::try_from(source.to_string())
346    }
347}
348
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::NaiveDate;

    /// Parses `source`, panicking when it is not a valid expression.
    fn parse(source: &str) -> CelExpression {
        source.parse::<CelExpression>().unwrap()
    }

    #[test]
    fn literals() {
        let context = CelContext::new();

        assert_eq!(
            parse("true").evaluate(&context).unwrap(),
            CelValue::Bool(true)
        );
        assert_eq!(parse("1").evaluate(&context).unwrap(), CelValue::Int(1));
        assert_eq!(parse("-1").evaluate(&context).unwrap(), CelValue::Int(-1));
        assert_eq!(
            parse("'hello'").evaluate(&context).unwrap(),
            CelValue::String("hello".to_string().into())
        );

        // Tokenizer needs fixing
        // let expression = "1u".parse::<CelExpression>().unwrap();
        // assert_eq!(expression.evaluate(&context).unwrap(), CelValue::UInt(1))
    }

    #[test]
    fn logic() {
        let context = CelContext::new();

        let expression = parse("true || false ? false && true : true");
        assert_eq!(
            expression.evaluate(&context).unwrap(),
            CelValue::Bool(false)
        );

        let expression = parse("true && false ? false : true || false");
        assert_eq!(expression.evaluate(&context).unwrap(), CelValue::Bool(true));
    }

    #[test]
    fn lookup() {
        let mut hello = CelMap::new();
        hello.insert("world", 42);
        let mut params = CelMap::new();
        params.insert("hello", hello);
        let mut context = CelContext::new();
        context.add_variable("params", params);

        let expression = parse("params.hello.world");
        assert_eq!(expression.evaluate(&context).unwrap(), CelValue::Int(42));
    }

    #[test]
    fn top_level_function() {
        let context = CelContext::new();
        let expression = parse("date('2022-10-10')");
        assert_eq!(
            expression.evaluate(&context).unwrap(),
            CelValue::Date(NaiveDate::parse_from_str("2022-10-10", "%Y-%m-%d").unwrap())
        );
    }

    #[test]
    fn cast_function() {
        let context = CelContext::new();
        let expression = parse("decimal('1')");
        assert_eq!(
            expression.evaluate(&context).unwrap(),
            CelValue::Decimal(1.into())
        );
    }

    #[test]
    fn package_function() -> anyhow::Result<()> {
        let context = CelContext::new();
        let expression = parse("decimal.Add(decimal('1'), decimal('2'))");
        assert_eq!(expression.evaluate(&context)?, CelValue::Decimal(3.into()));
        Ok(())
    }

    #[test]
    fn function_on_timestamp() -> anyhow::Result<()> {
        use chrono::{DateTime, Utc};

        let time: DateTime<Utc> = "1940-12-21T00:00:00Z".parse().unwrap();
        let mut context = CelContext::new();
        context.add_variable("now", time);

        let expression = parse("now.format('%d/%m/%Y')");
        assert_eq!(expression.evaluate(&context)?, CelValue::from("21/12/1940"));

        Ok(())
    }
}