1use serde::{Deserialize, Serialize};
2use tracing::instrument;
3
4use cel_parser::{
5 ast::{self, ArithmeticOp, Expression, RelationOp},
6 parse_expression,
7};
8
9use crate::{cel_type::*, context::*, error::*, value::*};
10
11#[derive(Debug, Clone, Deserialize, Serialize)]
12#[serde(try_from = "String")]
13#[serde(into = "String")]
14pub struct CelExpression {
15 source: String,
16 expr: Expression,
17}
18
19impl CelExpression {
20 pub fn try_evaluate<'a, T: TryFrom<CelResult<'a>, Error = ResultCoercionError>>(
21 &'a self,
22 ctx: &CelContext,
23 ) -> Result<T, CelError> {
24 let res = self.evaluate(ctx)?;
25 Ok(T::try_from(CelResult {
26 expr: &self.expr,
27 val: res,
28 })?)
29 }
30
31 #[instrument(name = "cel.evaluate", skip_all, fields(expression = %self.source, context = tracing::field::Empty, result = tracing::field::Empty), err)]
32 pub fn evaluate(&self, ctx: &CelContext) -> Result<CelValue, CelError> {
33 let context_debug = ctx.debug_context();
35 if !context_debug.is_empty() {
36 tracing::Span::current().record("context", &context_debug);
37 }
38
39 let result = match evaluate_expression(&self.expr, ctx)? {
40 EvalType::Value(val) => Ok(val),
41 EvalType::ContextItem(ContextItem::Value(val)) => Ok(val.clone()),
42 _ => Err(CelError::Unexpected(
43 "evaluate didn't return a value".to_string(),
44 )),
45 }?;
46
47 tracing::Span::current().record("result", format!("{:?}", result));
49
50 Ok(result)
51 }
52}
53
54impl std::fmt::Display for CelExpression {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 write!(f, "{}", self.source)
57 }
58}
59
60enum EvalType<'a> {
61 Value(CelValue),
62 ContextItem(&'a ContextItem),
63 MemberFn(&'a CelValue, &'a CelMemberFunction),
64}
65
66impl EvalType<'_> {
67 fn try_into_bool(self) -> Result<bool, CelError> {
68 if let EvalType::Value(val) = self {
69 val.try_bool()
70 } else {
71 Err(CelError::Unexpected(
72 "Expression didn't resolve to a bool".to_string(),
73 ))
74 }
75 }
76
77 fn try_into_key(self) -> Result<CelKey, CelError> {
78 if let EvalType::Value(val) = self {
79 match val {
80 CelValue::Int(i) => Ok(CelKey::Int(i)),
81 CelValue::UInt(u) => Ok(CelKey::UInt(u)),
82 CelValue::Bool(b) => Ok(CelKey::Bool(b)),
83 CelValue::String(s) => Ok(CelKey::String(s)),
84 _ => Err(CelError::Unexpected(
85 "Expression didn't resolve to a valid key".to_string(),
86 )),
87 }
88 } else {
89 Err(CelError::Unexpected(
90 "Expression didn't resolve to value".to_string(),
91 ))
92 }
93 }
94
95 fn try_into_value(self) -> Result<CelValue, CelError> {
96 if let EvalType::Value(val) = self {
97 Ok(val)
98 } else {
99 Err(CelError::Unexpected("Couldn't unwrap value".to_string()))
100 }
101 }
102}
103
104#[instrument(name = "cel.evaluate_expression", skip_all, level = "debug", err)]
105fn evaluate_expression<'a>(
106 expr: &Expression,
107 ctx: &'a CelContext,
108) -> Result<EvalType<'a>, CelError> {
109 match evaluate_expression_inner(expr, ctx) {
110 Ok(val) => Ok(val),
111 Err(e) => Err(CelError::EvaluationError(format!("{expr:?}"), Box::new(e))),
112 }
113}
114
115#[instrument(name = "cel.evaluate_expr", skip_all, level = "debug", err)]
116fn evaluate_expression_inner<'a>(
117 expr: &Expression,
118 ctx: &'a CelContext,
119) -> Result<EvalType<'a>, CelError> {
120 use Expression::*;
121 match expr {
122 Ternary(cond, left, right) => {
123 if evaluate_expression(cond, ctx)?.try_into_bool()? {
124 evaluate_expression(left, ctx)
125 } else {
126 evaluate_expression(right, ctx)
127 }
128 }
129 Member(expr, member) => {
130 let ident = evaluate_expression(expr, ctx)?;
131 evaluate_member(ident, member, ctx)
132 }
133 Has(expr) => {
134 fn extract_last_field(
140 expr: &Expression,
141 ) -> Option<(&Expression, &std::sync::Arc<String>)> {
142 match expr {
143 Expression::Member(target, member) => match member.as_ref() {
144 ast::Member::Attribute(field_name) => Some((target.as_ref(), field_name)),
145 _ => None,
146 },
147 _ => None,
148 }
149 }
150
151 if let Some((target_expr, field_name)) = extract_last_field(expr.as_ref()) {
152 let target = evaluate_expression(target_expr, ctx)?;
154
155 let has_field = match target {
157 EvalType::Value(CelValue::Map(map)) => map.contains_key(field_name.as_str()),
158 EvalType::ContextItem(ContextItem::Value(CelValue::Map(map))) => {
159 map.contains_key(field_name.as_str())
160 }
161 _ => {
162 return Err(CelError::IllegalTarget);
164 }
165 };
166
167 Ok(EvalType::Value(CelValue::Bool(has_field)))
168 } else {
169 Err(CelError::Unexpected(
170 "has() expects a member expression".to_string(),
171 ))
172 }
173 }
174 Map(entries) => {
175 let mut map = CelMap::new();
176 for (k, v) in entries {
177 let key = evaluate_expression(k, ctx)?;
178 let value = evaluate_expression(v, ctx)?;
179 map.insert(key.try_into_key()?, value.try_into_value()?)
180 }
181 Ok(EvalType::Value(CelValue::from(map)))
182 }
183 Ident(name) => Ok(EvalType::ContextItem(ctx.lookup_ident(name)?)),
184 Literal(val) => Ok(EvalType::Value(CelValue::from(val))),
185 Arithmetic(op, left, right) => {
186 let left = evaluate_expression(left, ctx)?;
187 let right = evaluate_expression(right, ctx)?;
188 Ok(EvalType::Value(evaluate_arithmetic(
189 *op,
190 left.try_into_value()?,
191 right.try_into_value()?,
192 )?))
193 }
194 Relation(op, left, right) => {
195 let left = evaluate_expression(left, ctx)?;
196 let right = evaluate_expression(right, ctx)?;
197 Ok(EvalType::Value(evaluate_relation(
198 *op,
199 left.try_into_value()?,
200 right.try_into_value()?,
201 )?))
202 }
203 Unary(op, expr) => {
204 use ast::UnaryOp;
205 match op {
206 UnaryOp::Not => {
207 let val = evaluate_expression(expr, ctx)?.try_into_bool()?;
208 Ok(EvalType::Value(CelValue::Bool(!val)))
209 }
210 _ => Err(CelError::Unexpected(format!(
211 "unimplemented unary op: {op:?}"
212 ))),
213 }
214 }
215 e => Err(CelError::Unexpected(format!("unimplemented {e:?}"))),
216 }
217}
218
219#[instrument(name = "cel.evaluate_member", skip_all, level = "debug", err)]
220fn evaluate_member<'a>(
221 target: EvalType<'a>,
222 member: &ast::Member,
223 ctx: &'a CelContext,
224) -> Result<EvalType<'a>, CelError> {
225 use ast::Member::*;
226 match member {
227 Attribute(name) => match target {
228 EvalType::Value(CelValue::Map(map)) if map.contains_key(name) => {
229 Ok(EvalType::Value(map.get(name)))
230 }
231 EvalType::ContextItem(ContextItem::Value(CelValue::Map(map))) => {
232 Ok(EvalType::Value(map.get(name)))
233 }
234 EvalType::ContextItem(ContextItem::Package(p)) => {
235 Ok(EvalType::ContextItem(p.lookup(name)?))
236 }
237 EvalType::ContextItem(ContextItem::Value(v)) => {
238 Ok(EvalType::MemberFn(v, ctx.lookup_member_fn(v, name)?))
239 }
240 _ => Err(CelError::IllegalTarget),
241 },
242 FunctionCall(exprs) => match target {
243 EvalType::ContextItem(ContextItem::Function(f)) => {
244 let mut args = Vec::new();
245 for e in exprs {
246 args.push(evaluate_expression(e, ctx)?.try_into_value()?)
247 }
248 Ok(EvalType::Value(f(args)?))
249 }
250 EvalType::ContextItem(ContextItem::Package(p)) => {
251 evaluate_member(EvalType::ContextItem(p.package_self()?), member, ctx)
252 }
253 EvalType::MemberFn(v, f) => {
254 let mut args = Vec::new();
255 for e in exprs {
256 args.push(evaluate_expression(e, ctx)?.try_into_value()?)
257 }
258 Ok(EvalType::Value(f(v, args)?))
259 }
260 _ => Err(CelError::IllegalTarget),
261 },
262 _ => unimplemented!(),
263 }
264}
265
266#[instrument(name = "cel.evaluate_arithmetic", skip_all, level = "debug", err)]
267fn evaluate_arithmetic(
268 op: ArithmeticOp,
269 left: CelValue,
270 right: CelValue,
271) -> Result<CelValue, CelError> {
272 use CelValue::*;
273 match op {
274 ArithmeticOp::Multiply => match (&left, &right) {
275 (UInt(l), UInt(r)) => Ok(UInt(l * r)),
276 (Int(l), Int(r)) => Ok(Int(l * r)),
277 (Double(l), Double(r)) => Ok(Double(l * r)),
278 (Decimal(l), Decimal(r)) => Ok(Decimal(l * r)),
279 _ => Err(CelError::NoMatchingOverload(format!(
280 "Cannot apply '*' to {:?} and {:?}",
281 CelType::from(&left),
282 CelType::from(&right)
283 ))),
284 },
285 ArithmeticOp::Add => match (&left, &right) {
286 (UInt(l), UInt(r)) => Ok(UInt(l + r)),
287 (Int(l), Int(r)) => Ok(Int(l + r)),
288 (Double(l), Double(r)) => Ok(Double(l + r)),
289 (Decimal(l), Decimal(r)) => Ok(Decimal(l + r)),
290 _ => Err(CelError::NoMatchingOverload(format!(
291 "Cannot apply '+' to {:?} and {:?}",
292 CelType::from(&left),
293 CelType::from(&right)
294 ))),
295 },
296 ArithmeticOp::Subtract => match (&left, &right) {
297 (UInt(l), UInt(r)) => Ok(UInt(l - r)),
298 (Int(l), Int(r)) => Ok(Int(l - r)),
299 (Double(l), Double(r)) => Ok(Double(l - r)),
300 (Decimal(l), Decimal(r)) => Ok(Decimal(l - r)),
301 _ => Err(CelError::NoMatchingOverload(format!(
302 "Cannot apply '-' to {:?} and {:?}",
303 CelType::from(&left),
304 CelType::from(&right)
305 ))),
306 },
307 _ => unimplemented!(),
308 }
309}
310
311#[instrument(name = "cel.evaluate_relation", skip_all, level = "debug", err)]
312fn evaluate_relation(
313 op: RelationOp,
314 left: CelValue,
315 right: CelValue,
316) -> Result<CelValue, CelError> {
317 use CelValue::*;
318 match op {
319 RelationOp::LessThan => match (&left, &right) {
320 (UInt(l), UInt(r)) => Ok(Bool(l < r)),
321 (Int(l), Int(r)) => Ok(Bool(l < r)),
322 (Double(l), Double(r)) => Ok(Bool(l < r)),
323 (Decimal(l), Decimal(r)) => Ok(Bool(l < r)),
324 (Date(l), Date(r)) => Ok(Bool(l < r)),
325 (Timestamp(l), Timestamp(r)) => Ok(Bool(l < r)),
326 _ => Err(CelError::NoMatchingOverload(format!(
327 "Cannot apply '<' to {:?} and {:?}",
328 CelType::from(&left),
329 CelType::from(&right)
330 ))),
331 },
332 RelationOp::LessThanEq => match (&left, &right) {
333 (UInt(l), UInt(r)) => Ok(Bool(l <= r)),
334 (Int(l), Int(r)) => Ok(Bool(l <= r)),
335 (Double(l), Double(r)) => Ok(Bool(l <= r)),
336 (Decimal(l), Decimal(r)) => Ok(Bool(l <= r)),
337 (Date(l), Date(r)) => Ok(Bool(l <= r)),
338 (Timestamp(l), Timestamp(r)) => Ok(Bool(l <= r)),
339 _ => Err(CelError::NoMatchingOverload(format!(
340 "Cannot apply '<=' to {:?} and {:?}",
341 CelType::from(&left),
342 CelType::from(&right)
343 ))),
344 },
345 RelationOp::GreaterThan => match (&left, &right) {
346 (UInt(l), UInt(r)) => Ok(Bool(l > r)),
347 (Int(l), Int(r)) => Ok(Bool(l > r)),
348 (Double(l), Double(r)) => Ok(Bool(l > r)),
349 (Decimal(l), Decimal(r)) => Ok(Bool(l > r)),
350 (Date(l), Date(r)) => Ok(Bool(l > r)),
351 (Timestamp(l), Timestamp(r)) => Ok(Bool(l > r)),
352 _ => Err(CelError::NoMatchingOverload(format!(
353 "Cannot apply '>' to {:?} and {:?}",
354 CelType::from(&left),
355 CelType::from(&right)
356 ))),
357 },
358 RelationOp::GreaterThanEq => match (&left, &right) {
359 (UInt(l), UInt(r)) => Ok(Bool(l >= r)),
360 (Int(l), Int(r)) => Ok(Bool(l >= r)),
361 (Double(l), Double(r)) => Ok(Bool(l >= r)),
362 (Decimal(l), Decimal(r)) => Ok(Bool(l >= r)),
363 (Date(l), Date(r)) => Ok(Bool(l >= r)),
364 (Timestamp(l), Timestamp(r)) => Ok(Bool(l >= r)),
365 _ => Err(CelError::NoMatchingOverload(format!(
366 "Cannot apply '>=' to {:?} and {:?}",
367 CelType::from(&left),
368 CelType::from(&right)
369 ))),
370 },
371 RelationOp::Equals => match (&left, &right) {
372 (UInt(l), UInt(r)) => Ok(Bool(l == r)),
373 (Int(l), Int(r)) => Ok(Bool(l == r)),
374 (Double(l), Double(r)) => Ok(Bool(l == r)),
375 (Decimal(l), Decimal(r)) => Ok(Bool(l == r)),
376 (Date(l), Date(r)) => Ok(Bool(l == r)),
377 (Timestamp(l), Timestamp(r)) => Ok(Bool(l == r)),
378 _ => Err(CelError::NoMatchingOverload(format!(
379 "Cannot apply '==' to {:?} and {:?}",
380 CelType::from(&left),
381 CelType::from(&right)
382 ))),
383 },
384 RelationOp::NotEquals => match (&left, &right) {
385 (UInt(l), UInt(r)) => Ok(Bool(l != r)),
386 (Int(l), Int(r)) => Ok(Bool(l != r)),
387 (Double(l), Double(r)) => Ok(Bool(l != r)),
388 (Decimal(l), Decimal(r)) => Ok(Bool(l != r)),
389 (Date(l), Date(r)) => Ok(Bool(l != r)),
390 (Timestamp(l), Timestamp(r)) => Ok(Bool(l != r)),
391 _ => Err(CelError::NoMatchingOverload(format!(
392 "Cannot apply '!=' to {:?} and {:?}",
393 CelType::from(&left),
394 CelType::from(&right)
395 ))),
396 },
397 _ => unimplemented!(),
398 }
399}
400
401impl From<CelExpression> for String {
402 fn from(expr: CelExpression) -> Self {
403 expr.source
404 }
405}
406
407impl TryFrom<String> for CelExpression {
408 type Error = CelError;
409
410 fn try_from(source: String) -> Result<Self, Self::Error> {
411 let expr = parse_expression(source.clone()).map_err(CelError::CelParseError)?;
412 Ok(Self { source, expr })
413 }
414}
415impl TryFrom<&str> for CelExpression {
416 type Error = CelError;
417
418 fn try_from(source: &str) -> Result<Self, Self::Error> {
419 Self::try_from(source.to_string())
420 }
421}
422impl std::str::FromStr for CelExpression {
423 type Err = CelError;
424
425 fn from_str(source: &str) -> Result<Self, Self::Err> {
426 Self::try_from(source.to_string())
427 }
428}
429
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::NaiveDate;

    #[test]
    fn literals() {
        let expression = "true".parse::<CelExpression>().unwrap();
        let context = CelContext::new();
        assert_eq!(expression.evaluate(&context).unwrap(), CelValue::Bool(true));

        let expression = "1".parse::<CelExpression>().unwrap();
        assert_eq!(expression.evaluate(&context).unwrap(), CelValue::Int(1));

        let expression = "-1".parse::<CelExpression>().unwrap();
        assert_eq!(expression.evaluate(&context).unwrap(), CelValue::Int(-1));

        let expression = "'hello'".parse::<CelExpression>().unwrap();
        assert_eq!(
            expression.evaluate(&context).unwrap(),
            CelValue::String("hello".to_string().into())
        );
    }

    #[test]
    fn logic() {
        let expression = "true || false ? false && true : true"
            .parse::<CelExpression>()
            .unwrap();
        let context = CelContext::new();
        assert_eq!(
            expression.evaluate(&context).unwrap(),
            CelValue::Bool(false)
        );
        let expression = "true && false ? false : true || false"
            .parse::<CelExpression>()
            .unwrap();
        assert_eq!(expression.evaluate(&context).unwrap(), CelValue::Bool(true))
    }

    #[test]
    fn lookup() {
        let expression = "params.hello.world".parse::<CelExpression>().unwrap();
        let mut hello = CelMap::new();
        hello.insert("world", 42);
        let mut params = CelMap::new();
        params.insert("hello", hello);
        let mut context = CelContext::new();
        context.add_variable("params", params);
        assert_eq!(expression.evaluate(&context).unwrap(), CelValue::Int(42));
    }

    #[test]
    fn to_level_function() {
        let expression = "date('2022-10-10')".parse::<CelExpression>().unwrap();
        let context = CelContext::new();
        assert_eq!(
            expression.evaluate(&context).unwrap(),
            CelValue::Date(NaiveDate::parse_from_str("2022-10-10", "%Y-%m-%d").unwrap())
        );
    }

    #[test]
    fn cast_function() {
        let expression = "decimal('1')".parse::<CelExpression>().unwrap();
        let context = CelContext::new();
        assert_eq!(
            expression.evaluate(&context).unwrap(),
            CelValue::Decimal(1.into())
        );
    }

    #[test]
    fn package_function() -> anyhow::Result<()> {
        let expression = "decimal.Add(decimal('1'), decimal('2'))"
            .parse::<CelExpression>()
            .unwrap();
        let context = CelContext::new();
        assert_eq!(expression.evaluate(&context)?, CelValue::Decimal(3.into()));
        Ok(())
    }

    #[test]
    fn arithmetic() {
        let context = CelContext::new();

        let expression = "1 + 2".parse::<CelExpression>().unwrap();
        assert_eq!(expression.evaluate(&context).unwrap(), CelValue::Int(3));

        let expression = "2 * 3 - 1".parse::<CelExpression>().unwrap();
        assert_eq!(expression.evaluate(&context).unwrap(), CelValue::Int(5));

        let expression = "decimal('1') + decimal('2')"
            .parse::<CelExpression>()
            .unwrap();
        assert_eq!(
            expression.evaluate(&context).unwrap(),
            CelValue::Decimal(3.into())
        );

        // Operands of different types have no matching overload.
        let expression = "1 + decimal('1')".parse::<CelExpression>().unwrap();
        assert!(expression.evaluate(&context).is_err());
    }

    #[test]
    fn comparisons() {
        let context = CelContext::new();
        let cases = [
            ("1 < 2", true),
            ("2 <= 1", false),
            ("3 > 2", true),
            ("2 >= 2", true),
            ("1 == 1", true),
            ("1 != 1", false),
            ("decimal('1') < decimal('2')", true),
            ("date('2022-10-10') < date('2022-10-11')", true),
        ];
        for (source, expected) in cases.iter() {
            let expression = source.parse::<CelExpression>().unwrap();
            assert_eq!(
                expression.evaluate(&context).unwrap(),
                CelValue::Bool(*expected),
                "{}",
                source
            );
        }
    }

    #[test]
    fn has_macro_with_map() {
        let expression = "has(params.hello)".parse::<CelExpression>().unwrap();
        let mut params = CelMap::new();
        params.insert("hello", "world");
        let mut context = CelContext::new();
        context.add_variable("params", params);
        assert_eq!(expression.evaluate(&context).unwrap(), CelValue::Bool(true));

        let expression = "has(params.missing)".parse::<CelExpression>().unwrap();
        let mut params = CelMap::new();
        params.insert("hello", "world");
        let mut context = CelContext::new();
        context.add_variable("params", params);
        assert_eq!(
            expression.evaluate(&context).unwrap(),
            CelValue::Bool(false)
        );

        let expression = "has(params.nested.field)".parse::<CelExpression>().unwrap();
        let mut nested = CelMap::new();
        nested.insert("field", 42);
        let mut params = CelMap::new();
        params.insert("nested", nested);
        let mut context = CelContext::new();
        context.add_variable("params", params);
        assert_eq!(expression.evaluate(&context).unwrap(), CelValue::Bool(true));

        let expression = "has(config.database.settings.maxConnections)"
            .parse::<CelExpression>()
            .unwrap();
        let mut settings = CelMap::new();
        settings.insert("maxConnections", 100);
        settings.insert("timeout", 30);
        let mut database = CelMap::new();
        database.insert("settings", settings);
        let mut config = CelMap::new();
        config.insert("database", database);
        let mut context = CelContext::new();
        context.add_variable("config", config);
        assert_eq!(expression.evaluate(&context).unwrap(), CelValue::Bool(true));

        let expression = "has(config.database.settings.missingField)"
            .parse::<CelExpression>()
            .unwrap();
        let mut settings = CelMap::new();
        settings.insert("maxConnections", 100);
        let mut database = CelMap::new();
        database.insert("settings", settings);
        let mut config = CelMap::new();
        config.insert("database", database);
        let mut context = CelContext::new();
        context.add_variable("config", config);
        assert_eq!(
            expression.evaluate(&context).unwrap(),
            CelValue::Bool(false)
        );
    }

    #[test]
    fn has_macro_rejects_non_map_target() {
        // `params.hello` is a string, so asking for a field on it is an error.
        let expression = "has(params.hello.world)".parse::<CelExpression>().unwrap();
        let mut params = CelMap::new();
        params.insert("hello", "world");
        let mut context = CelContext::new();
        context.add_variable("params", params);
        assert!(expression.evaluate(&context).is_err());
    }

    #[test]
    fn function_on_timestamp() -> anyhow::Result<()> {
        use chrono::{DateTime, Utc};

        let time: DateTime<Utc> = "1940-12-21T00:00:00Z".parse().unwrap();
        let mut context = CelContext::new();
        context.add_variable("now", time);

        let expression = "now.format('%d/%m/%Y')".parse::<CelExpression>().unwrap();
        assert_eq!(expression.evaluate(&context)?, CelValue::from("21/12/1940"));

        Ok(())
    }
}