1use serde::{Deserialize, Serialize};
2
3use cel_parser::{
4 ast::{self, ArithmeticOp, Expression, RelationOp},
5 parser::ExpressionParser,
6};
7
8use crate::{cel_type::*, context::*, error::*, value::*};
9
10#[derive(Debug, Clone, Deserialize, Serialize)]
11#[serde(try_from = "String")]
12#[serde(into = "String")]
13pub struct CelExpression {
14 source: String,
15 expr: Expression,
16}
17
18impl CelExpression {
19 pub fn try_evaluate<'a, T: TryFrom<CelResult<'a>, Error = ResultCoercionError>>(
20 &'a self,
21 ctx: &CelContext,
22 ) -> Result<T, CelError> {
23 let res = self.evaluate(ctx)?;
24 Ok(T::try_from(CelResult {
25 expr: &self.expr,
26 val: res,
27 })?)
28 }
29
30 pub fn evaluate(&self, ctx: &CelContext) -> Result<CelValue, CelError> {
31 match evaluate_expression(&self.expr, ctx)? {
32 EvalType::Value(val) => Ok(val),
33 EvalType::ContextItem(ContextItem::Value(val)) => Ok(val.clone()),
34 _ => Err(CelError::Unexpected(
35 "evaluate didn't return a value".to_string(),
36 )),
37 }
38 }
39}
40
41impl std::fmt::Display for CelExpression {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 write!(f, "{}", self.source)
44 }
45}
46
/// Intermediate result of evaluating a (sub-)expression.
enum EvalType<'a> {
    /// An owned value produced by the evaluation itself.
    Value(CelValue),
    /// An item borrowed from the evaluation context (identifier or package lookup).
    ContextItem(&'a ContextItem),
    /// A member function looked up for a context value, together with that
    /// value; it is invoked when followed by a function call.
    MemberFn(&'a CelValue, &'a CelMemberFunction),
}
52
53impl EvalType<'_> {
54 fn try_into_bool(self) -> Result<bool, CelError> {
55 if let EvalType::Value(val) = self {
56 val.try_bool()
57 } else {
58 Err(CelError::Unexpected(
59 "Expression didn't resolve to a bool".to_string(),
60 ))
61 }
62 }
63
64 fn try_into_key(self) -> Result<CelKey, CelError> {
65 if let EvalType::Value(val) = self {
66 match val {
67 CelValue::Int(i) => Ok(CelKey::Int(i)),
68 CelValue::UInt(u) => Ok(CelKey::UInt(u)),
69 CelValue::Bool(b) => Ok(CelKey::Bool(b)),
70 CelValue::String(s) => Ok(CelKey::String(s)),
71 _ => Err(CelError::Unexpected(
72 "Expression didn't resolve to a valid key".to_string(),
73 )),
74 }
75 } else {
76 Err(CelError::Unexpected(
77 "Expression didn't resolve to value".to_string(),
78 ))
79 }
80 }
81
82 fn try_into_value(self) -> Result<CelValue, CelError> {
83 if let EvalType::Value(val) = self {
84 Ok(val)
85 } else {
86 Err(CelError::Unexpected("Couldn't unwrap value".to_string()))
87 }
88 }
89}
90
91fn evaluate_expression<'a>(
92 expr: &Expression,
93 ctx: &'a CelContext,
94) -> Result<EvalType<'a>, CelError> {
95 match evaluate_expression_inner(expr, ctx) {
96 Ok(val) => Ok(val),
97 Err(e) => Err(CelError::EvaluationError(format!("{expr:?}"), Box::new(e))),
98 }
99}
100
101fn evaluate_expression_inner<'a>(
102 expr: &Expression,
103 ctx: &'a CelContext,
104) -> Result<EvalType<'a>, CelError> {
105 use Expression::*;
106 match expr {
107 Ternary(cond, left, right) => {
108 if evaluate_expression(cond, ctx)?.try_into_bool()? {
109 evaluate_expression(left, ctx)
110 } else {
111 evaluate_expression(right, ctx)
112 }
113 }
114 Member(expr, member) => {
115 let ident = evaluate_expression(expr, ctx)?;
116 evaluate_member(ident, member, ctx)
117 }
118 Map(entries) => {
119 let mut map = CelMap::new();
120 for (k, v) in entries {
121 let key = evaluate_expression(k, ctx)?;
122 let value = evaluate_expression(v, ctx)?;
123 map.insert(key.try_into_key()?, value.try_into_value()?)
124 }
125 Ok(EvalType::Value(CelValue::from(map)))
126 }
127 Ident(name) => Ok(EvalType::ContextItem(ctx.lookup_ident(name)?)),
128 Literal(val) => Ok(EvalType::Value(CelValue::from(val))),
129 Arithmetic(op, left, right) => {
130 let left = evaluate_expression(left, ctx)?;
131 let right = evaluate_expression(right, ctx)?;
132 Ok(EvalType::Value(evaluate_arithmetic(
133 *op,
134 left.try_into_value()?,
135 right.try_into_value()?,
136 )?))
137 }
138 Relation(op, left, right) => {
139 let left = evaluate_expression(left, ctx)?;
140 let right = evaluate_expression(right, ctx)?;
141 Ok(EvalType::Value(evaluate_relation(
142 *op,
143 left.try_into_value()?,
144 right.try_into_value()?,
145 )?))
146 }
147 e => Err(CelError::Unexpected(format!("unimplemented {e:?}"))),
148 }
149}
150
151fn evaluate_member<'a>(
152 target: EvalType<'a>,
153 member: &ast::Member,
154 ctx: &'a CelContext,
155) -> Result<EvalType<'a>, CelError> {
156 use ast::Member::*;
157 match member {
158 Attribute(name) => match target {
159 EvalType::Value(CelValue::Map(map)) if map.contains_key(name) => {
160 Ok(EvalType::Value(map.get(name)))
161 }
162 EvalType::ContextItem(ContextItem::Value(CelValue::Map(map))) => {
163 Ok(EvalType::Value(map.get(name)))
164 }
165 EvalType::ContextItem(ContextItem::Package(p)) => {
166 Ok(EvalType::ContextItem(p.lookup(name)?))
167 }
168 EvalType::ContextItem(ContextItem::Value(v)) => {
169 Ok(EvalType::MemberFn(v, ctx.lookup_member_fn(v, name)?))
170 }
171 _ => Err(CelError::IllegalTarget),
172 },
173 FunctionCall(exprs) => match target {
174 EvalType::ContextItem(ContextItem::Function(f)) => {
175 let mut args = Vec::new();
176 for e in exprs {
177 args.push(evaluate_expression(e, ctx)?.try_into_value()?)
178 }
179 Ok(EvalType::Value(f(args)?))
180 }
181 EvalType::ContextItem(ContextItem::Package(p)) => {
182 evaluate_member(EvalType::ContextItem(p.package_self()?), member, ctx)
183 }
184 EvalType::MemberFn(v, f) => {
185 let mut args = Vec::new();
186 for e in exprs {
187 args.push(evaluate_expression(e, ctx)?.try_into_value()?)
188 }
189 Ok(EvalType::Value(f(v, args)?))
190 }
191 _ => Err(CelError::IllegalTarget),
192 },
193 _ => unimplemented!(),
194 }
195}
196
197fn evaluate_arithmetic(
198 op: ArithmeticOp,
199 left: CelValue,
200 right: CelValue,
201) -> Result<CelValue, CelError> {
202 use CelValue::*;
203 match op {
204 ArithmeticOp::Multiply => match (&left, &right) {
205 (UInt(l), UInt(r)) => Ok(UInt(l * r)),
206 (Int(l), Int(r)) => Ok(Int(l * r)),
207 (Double(l), Double(r)) => Ok(Double(l * r)),
208 (Decimal(l), Decimal(r)) => Ok(Decimal(l * r)),
209 _ => Err(CelError::NoMatchingOverload(format!(
210 "Cannot apply '*' to {:?} and {:?}",
211 CelType::from(&left),
212 CelType::from(&right)
213 ))),
214 },
215 ArithmeticOp::Add => match (&left, &right) {
216 (UInt(l), UInt(r)) => Ok(UInt(l + r)),
217 (Int(l), Int(r)) => Ok(Int(l + r)),
218 (Double(l), Double(r)) => Ok(Double(l + r)),
219 (Decimal(l), Decimal(r)) => Ok(Decimal(l + r)),
220 _ => Err(CelError::NoMatchingOverload(format!(
221 "Cannot apply '+' to {:?} and {:?}",
222 CelType::from(&left),
223 CelType::from(&right)
224 ))),
225 },
226 ArithmeticOp::Subtract => match (&left, &right) {
227 (UInt(l), UInt(r)) => Ok(UInt(l - r)),
228 (Int(l), Int(r)) => Ok(Int(l - r)),
229 (Double(l), Double(r)) => Ok(Double(l - r)),
230 (Decimal(l), Decimal(r)) => Ok(Decimal(l - r)),
231 _ => Err(CelError::NoMatchingOverload(format!(
232 "Cannot apply '-' to {:?} and {:?}",
233 CelType::from(&left),
234 CelType::from(&right)
235 ))),
236 },
237 _ => unimplemented!(),
238 }
239}
240
241fn evaluate_relation(
242 op: RelationOp,
243 left: CelValue,
244 right: CelValue,
245) -> Result<CelValue, CelError> {
246 use CelValue::*;
247 match op {
248 RelationOp::LessThan => match (&left, &right) {
249 (UInt(l), UInt(r)) => Ok(Bool(l < r)),
250 (Int(l), Int(r)) => Ok(Bool(l < r)),
251 (Double(l), Double(r)) => Ok(Bool(l < r)),
252 (Decimal(l), Decimal(r)) => Ok(Bool(l < r)),
253 _ => Err(CelError::NoMatchingOverload(format!(
254 "Cannot apply '<' to {:?} and {:?}",
255 CelType::from(&left),
256 CelType::from(&right)
257 ))),
258 },
259 RelationOp::LessThanEq => match (&left, &right) {
260 (UInt(l), UInt(r)) => Ok(Bool(l <= r)),
261 (Int(l), Int(r)) => Ok(Bool(l <= r)),
262 (Double(l), Double(r)) => Ok(Bool(l <= r)),
263 (Decimal(l), Decimal(r)) => Ok(Bool(l <= r)),
264 _ => Err(CelError::NoMatchingOverload(format!(
265 "Cannot apply '<=' to {:?} and {:?}",
266 CelType::from(&left),
267 CelType::from(&right)
268 ))),
269 },
270 RelationOp::GreaterThan => match (&left, &right) {
271 (UInt(l), UInt(r)) => Ok(Bool(l > r)),
272 (Int(l), Int(r)) => Ok(Bool(l > r)),
273 (Double(l), Double(r)) => Ok(Bool(l > r)),
274 (Decimal(l), Decimal(r)) => Ok(Bool(l > r)),
275 _ => Err(CelError::NoMatchingOverload(format!(
276 "Cannot apply '>' to {:?} and {:?}",
277 CelType::from(&left),
278 CelType::from(&right)
279 ))),
280 },
281 RelationOp::GreaterThanEq => match (&left, &right) {
282 (UInt(l), UInt(r)) => Ok(Bool(l >= r)),
283 (Int(l), Int(r)) => Ok(Bool(l >= r)),
284 (Double(l), Double(r)) => Ok(Bool(l >= r)),
285 (Decimal(l), Decimal(r)) => Ok(Bool(l >= r)),
286 _ => Err(CelError::NoMatchingOverload(format!(
287 "Cannot apply '>=' to {:?} and {:?}",
288 CelType::from(&left),
289 CelType::from(&right)
290 ))),
291 },
292 RelationOp::Equals => match (&left, &right) {
293 (UInt(l), UInt(r)) => Ok(Bool(l == r)),
294 (Int(l), Int(r)) => Ok(Bool(l == r)),
295 (Double(l), Double(r)) => Ok(Bool(l == r)),
296 (Decimal(l), Decimal(r)) => Ok(Bool(l == r)),
297 _ => Err(CelError::NoMatchingOverload(format!(
298 "Cannot apply '==' to {:?} and {:?}",
299 CelType::from(&left),
300 CelType::from(&right)
301 ))),
302 },
303 RelationOp::NotEquals => match (&left, &right) {
304 (UInt(l), UInt(r)) => Ok(Bool(l != r)),
305 (Int(l), Int(r)) => Ok(Bool(l != r)),
306 (Double(l), Double(r)) => Ok(Bool(l != r)),
307 (Decimal(l), Decimal(r)) => Ok(Bool(l != r)),
308 _ => Err(CelError::NoMatchingOverload(format!(
309 "Cannot apply '!=' to {:?} and {:?}",
310 CelType::from(&left),
311 CelType::from(&right)
312 ))),
313 },
314 _ => unimplemented!(),
315 }
316}
317
318impl From<CelExpression> for String {
319 fn from(expr: CelExpression) -> Self {
320 expr.source
321 }
322}
323
324impl TryFrom<String> for CelExpression {
325 type Error = CelError;
326
327 fn try_from(source: String) -> Result<Self, Self::Error> {
328 let expr = ExpressionParser::new()
329 .parse(&source)
330 .map_err(|e| CelError::CelParseError(e.to_string()))?;
331 Ok(Self { source, expr })
332 }
333}
334impl TryFrom<&str> for CelExpression {
335 type Error = CelError;
336
337 fn try_from(source: &str) -> Result<Self, Self::Error> {
338 Self::try_from(source.to_string())
339 }
340}
341impl std::str::FromStr for CelExpression {
342 type Err = CelError;
343
344 fn from_str(source: &str) -> Result<Self, Self::Err> {
345 Self::try_from(source.to_string())
346 }
347}
348
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::NaiveDate;

    /// Boolean, integer and string literals evaluate to the matching value.
    #[test]
    fn literals() {
        let expression = "true".parse::<CelExpression>().unwrap();
        let context = CelContext::new();
        assert_eq!(expression.evaluate(&context).unwrap(), CelValue::Bool(true));

        let expression = "1".parse::<CelExpression>().unwrap();
        assert_eq!(expression.evaluate(&context).unwrap(), CelValue::Int(1));

        let expression = "-1".parse::<CelExpression>().unwrap();
        assert_eq!(expression.evaluate(&context).unwrap(), CelValue::Int(-1));

        let expression = "'hello'".parse::<CelExpression>().unwrap();
        assert_eq!(
            expression.evaluate(&context).unwrap(),
            CelValue::String("hello".to_string().into())
        );
    }

    /// Logical operators combined with the ternary operator.
    #[test]
    fn logic() {
        let expression = "true || false ? false && true : true"
            .parse::<CelExpression>()
            .unwrap();
        let context = CelContext::new();
        assert_eq!(
            expression.evaluate(&context).unwrap(),
            CelValue::Bool(false)
        );
        let expression = "true && false ? false : true || false"
            .parse::<CelExpression>()
            .unwrap();
        assert_eq!(expression.evaluate(&context).unwrap(), CelValue::Bool(true));
    }

    /// Nested attribute access on a map variable registered in the context.
    #[test]
    fn lookup() {
        let expression = "params.hello.world".parse::<CelExpression>().unwrap();
        let mut hello = CelMap::new();
        hello.insert("world", 42);
        let mut params = CelMap::new();
        params.insert("hello", hello);
        let mut context = CelContext::new();
        context.add_variable("params", params);
        assert_eq!(expression.evaluate(&context).unwrap(), CelValue::Int(42));
    }

    /// A function called directly by name, without a package or receiver.
    #[test]
    fn top_level_function() {
        let expression = "date('2022-10-10')".parse::<CelExpression>().unwrap();
        let context = CelContext::new();
        assert_eq!(
            expression.evaluate(&context).unwrap(),
            CelValue::Date(NaiveDate::parse_from_str("2022-10-10", "%Y-%m-%d").unwrap())
        );
    }

    /// Calling a package by its own name converts the argument.
    #[test]
    fn cast_function() {
        let expression = "decimal('1')".parse::<CelExpression>().unwrap();
        let context = CelContext::new();
        assert_eq!(
            expression.evaluate(&context).unwrap(),
            CelValue::Decimal(1.into())
        );
    }

    /// A function looked up inside a package.
    #[test]
    fn package_function() -> anyhow::Result<()> {
        let expression = "decimal.Add(decimal('1'), decimal('2'))"
            .parse::<CelExpression>()
            .unwrap();
        let context = CelContext::new();
        assert_eq!(expression.evaluate(&context)?, CelValue::Decimal(3.into()));
        Ok(())
    }

    /// A member function invoked on a timestamp variable from the context.
    #[test]
    fn function_on_timestamp() -> anyhow::Result<()> {
        use chrono::{DateTime, Utc};

        let time: DateTime<Utc> = "1940-12-21T00:00:00Z".parse().unwrap();
        let mut context = CelContext::new();
        context.add_variable("now", time);

        let expression = "now.format('%d/%m/%Y')".parse::<CelExpression>().unwrap();
        assert_eq!(expression.evaluate(&context)?, CelValue::from("21/12/1940"));

        Ok(())
    }
}