cel/
objects.rs

1use crate::common::ast::{operators, EntryExpr, Expr};
2use crate::context::Context;
3use crate::functions::FunctionContext;
4use crate::{ExecutionError, Expression};
5use std::cmp::Ordering;
6use std::collections::HashMap;
7use std::convert::{Infallible, TryFrom, TryInto};
8use std::fmt::{Display, Formatter};
9use std::ops;
10use std::ops::Deref;
11use std::sync::Arc;
12#[cfg(feature = "chrono")]
13use std::sync::LazyLock;
14
15use crate::common::value::CelVal;
16#[cfg(feature = "chrono")]
17use chrono::TimeZone;
18
/// Upper bound of the CEL timestamp range: `9999-12-31T23:59:59.999999999Z`.
///
/// Timestamp values are limited to the range of values which can be serialized as a string:
/// `["0001-01-01T00:00:00Z", "9999-12-31T23:59:59.999999999Z"]`. That range is narrower on
/// both ends than what [`chrono::DateTime`] can represent, so chrono's own overflow
/// detection is not sufficient and spec-compliant checks are made against these bounds.
///
/// https://github.com/google/cel-spec/blob/master/doc/langdef.md#overflow
#[cfg(feature = "chrono")]
static MAX_TIMESTAMP: LazyLock<chrono::DateTime<chrono::FixedOffset>> = LazyLock::new(|| {
    // A zero offset east of Greenwich is UTC.
    let utc = chrono::FixedOffset::east_opt(0).unwrap();
    let last_instant = chrono::NaiveDate::from_ymd_opt(9999, 12, 31)
        .and_then(|date| date.and_hms_nano_opt(23, 59, 59, 999_999_999))
        .unwrap();
    utc.from_utc_datetime(&last_instant)
});
35
/// Lower bound of the CEL timestamp range: `0001-01-01T00:00:00Z`.
#[cfg(feature = "chrono")]
static MIN_TIMESTAMP: LazyLock<chrono::DateTime<chrono::FixedOffset>> = LazyLock::new(|| {
    // A zero offset east of Greenwich is UTC.
    let utc = chrono::FixedOffset::east_opt(0).unwrap();
    let first_instant = chrono::NaiveDate::from_ymd_opt(1, 1, 1)
        .and_then(|date| date.and_hms_opt(0, 0, 0))
        .unwrap();
    utc.from_utc_datetime(&first_instant)
});
46
47#[derive(Debug, PartialEq, Clone)]
48#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
49pub struct Map {
50    pub map: Arc<HashMap<Key, Value>>,
51}
52
53impl PartialOrd for Map {
54    fn partial_cmp(&self, _: &Self) -> Option<Ordering> {
55        None
56    }
57}
58
59impl Map {
60    /// Returns a reference to the value corresponding to the key. Implicitly converts between int
61    /// and uint keys.
62    pub fn get(&self, key: &Key) -> Option<&Value> {
63        self.map.get(key).or_else(|| {
64            // Also check keys that are cross type comparable.
65            let converted = match key {
66                Key::Int(k) => Key::Uint(u64::try_from(*k).ok()?),
67                Key::Uint(k) => Key::Int(i64::try_from(*k).ok()?),
68                _ => return None,
69            };
70            self.map.get(&converted)
71        })
72    }
73}
74
/// The value kinds that may be used as keys of a [`Map`].
///
/// The derived ordering compares the variant first (in declaration order) and the payload
/// second, so the variant order below must not be changed casually.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum Key {
    /// Signed 64-bit integer key.
    Int(i64),
    /// Unsigned 64-bit integer key.
    Uint(u64),
    /// Boolean key.
    Bool(bool),
    /// String key; shared so that cloning is cheap.
    String(Arc<String>),
}
83
84/// Implement conversions from primitive types to [`Key`]
85impl From<String> for Key {
86    fn from(v: String) -> Self {
87        Key::String(v.into())
88    }
89}
90
91impl From<Arc<String>> for Key {
92    fn from(v: Arc<String>) -> Self {
93        Key::String(v)
94    }
95}
96
97impl<'a> From<&'a str> for Key {
98    fn from(v: &'a str) -> Self {
99        Key::String(Arc::new(v.into()))
100    }
101}
102
103impl From<bool> for Key {
104    fn from(v: bool) -> Self {
105        Key::Bool(v)
106    }
107}
108
109impl From<i64> for Key {
110    fn from(v: i64) -> Self {
111        Key::Int(v)
112    }
113}
114
115impl From<u64> for Key {
116    fn from(v: u64) -> Self {
117        Key::Uint(v)
118    }
119}
120
121impl serde::Serialize for Key {
122    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
123    where
124        S: serde::Serializer,
125    {
126        match self {
127            Key::Int(v) => v.serialize(serializer),
128            Key::Uint(v) => v.serialize(serializer),
129            Key::Bool(v) => v.serialize(serializer),
130            Key::String(v) => v.serialize(serializer),
131        }
132    }
133}
134
135impl Display for Key {
136    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
137        match self {
138            Key::Int(v) => write!(f, "{v}"),
139            Key::Uint(v) => write!(f, "{v}"),
140            Key::Bool(v) => write!(f, "{v}"),
141            Key::String(v) => write!(f, "{v}"),
142        }
143    }
144}
145
146/// Implement conversions from [`Key`] into [`Value`]
147impl TryInto<Key> for Value {
148    type Error = Value;
149
150    #[inline(always)]
151    fn try_into(self) -> Result<Key, Self::Error> {
152        match self {
153            Value::Int(v) => Ok(Key::Int(v)),
154            Value::UInt(v) => Ok(Key::Uint(v)),
155            Value::String(v) => Ok(Key::String(v)),
156            Value::Bool(v) => Ok(Key::Bool(v)),
157            _ => Err(self),
158        }
159    }
160}
161
162// Implement conversion from HashMap<K, V> into CelMap
163impl<K: Into<Key>, V: Into<Value>> From<HashMap<K, V>> for Map {
164    fn from(map: HashMap<K, V>) -> Self {
165        let mut new_map = HashMap::with_capacity(map.len());
166        for (k, v) in map {
167            new_map.insert(k.into(), v.into());
168        }
169        Map {
170            map: Arc::new(new_map),
171        }
172    }
173}
174
175pub trait TryIntoValue {
176    type Error: std::error::Error + 'static + Send + Sync;
177    fn try_into_value(self) -> Result<Value, Self::Error>;
178}
179
180impl<T: serde::Serialize> TryIntoValue for T {
181    type Error = crate::ser::SerializationError;
182    fn try_into_value(self) -> Result<Value, Self::Error> {
183        crate::ser::to_value(self)
184    }
185}
186impl TryIntoValue for Value {
187    type Error = Infallible;
188    fn try_into_value(self) -> Result<Value, Self::Error> {
189        Ok(self)
190    }
191}
192
193#[derive(Debug, Clone)]
194#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
195pub enum Value {
196    List(Arc<Vec<Value>>),
197    Map(Map),
198
199    Function(Arc<String>, Option<Box<Value>>),
200
201    // Atoms
202    Int(i64),
203    UInt(u64),
204    Float(f64),
205    String(Arc<String>),
206    Bytes(Arc<Vec<u8>>),
207    Bool(bool),
208    #[cfg(feature = "chrono")]
209    Duration(chrono::Duration),
210    #[cfg(feature = "chrono")]
211    Timestamp(chrono::DateTime<chrono::FixedOffset>),
212    Null,
213}
214
215impl From<CelVal> for Value {
216    fn from(val: CelVal) -> Self {
217        match val {
218            CelVal::String(s) => Value::String(Arc::new(s)),
219            CelVal::Boolean(b) => Value::Bool(b),
220            CelVal::Int(i) => Value::Int(i),
221            CelVal::UInt(u) => Value::UInt(u),
222            CelVal::Double(d) => Value::Float(d),
223            CelVal::Bytes(bytes) => Value::Bytes(Arc::new(bytes)),
224            CelVal::Null => Value::Null,
225            v => unimplemented!("{v:?}"),
226        }
227    }
228}
229
/// The kind of a [`Value`], without its payload. Used when reporting type errors.
#[derive(Debug, Clone, Copy)]
pub enum ValueType {
    /// Ordered sequence of values.
    List,
    /// Key/value mapping.
    Map,
    /// Function reference.
    Function,
    /// Signed 64-bit integer.
    Int,
    /// Unsigned 64-bit integer.
    UInt,
    /// 64-bit floating point number.
    Float,
    /// String.
    String,
    /// Byte sequence.
    Bytes,
    /// Boolean.
    Bool,
    /// Duration.
    Duration,
    /// Timestamp.
    Timestamp,
    /// The null value.
    Null,
}
245
246impl Display for ValueType {
247    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
248        match self {
249            ValueType::List => write!(f, "list"),
250            ValueType::Map => write!(f, "map"),
251            ValueType::Function => write!(f, "function"),
252            ValueType::Int => write!(f, "int"),
253            ValueType::UInt => write!(f, "uint"),
254            ValueType::Float => write!(f, "float"),
255            ValueType::String => write!(f, "string"),
256            ValueType::Bytes => write!(f, "bytes"),
257            ValueType::Bool => write!(f, "bool"),
258            ValueType::Duration => write!(f, "duration"),
259            ValueType::Timestamp => write!(f, "timestamp"),
260            ValueType::Null => write!(f, "null"),
261        }
262    }
263}
264
265impl Value {
266    pub fn type_of(&self) -> ValueType {
267        match self {
268            Value::List(_) => ValueType::List,
269            Value::Map(_) => ValueType::Map,
270            Value::Function(_, _) => ValueType::Function,
271            Value::Int(_) => ValueType::Int,
272            Value::UInt(_) => ValueType::UInt,
273            Value::Float(_) => ValueType::Float,
274            Value::String(_) => ValueType::String,
275            Value::Bytes(_) => ValueType::Bytes,
276            Value::Bool(_) => ValueType::Bool,
277            #[cfg(feature = "chrono")]
278            Value::Duration(_) => ValueType::Duration,
279            #[cfg(feature = "chrono")]
280            Value::Timestamp(_) => ValueType::Timestamp,
281            Value::Null => ValueType::Null,
282        }
283    }
284
285    pub fn error_expected_type(&self, expected: ValueType) -> ExecutionError {
286        ExecutionError::UnexpectedType {
287            got: self.type_of().to_string(),
288            want: expected.to_string(),
289        }
290    }
291}
292
293impl From<&Value> for Value {
294    fn from(value: &Value) -> Self {
295        value.clone()
296    }
297}
298
299impl PartialEq for Value {
300    fn eq(&self, other: &Self) -> bool {
301        match (self, other) {
302            (Value::Map(a), Value::Map(b)) => a == b,
303            (Value::List(a), Value::List(b)) => a == b,
304            (Value::Function(a1, a2), Value::Function(b1, b2)) => a1 == b1 && a2 == b2,
305            (Value::Int(a), Value::Int(b)) => a == b,
306            (Value::UInt(a), Value::UInt(b)) => a == b,
307            (Value::Float(a), Value::Float(b)) => a == b,
308            (Value::String(a), Value::String(b)) => a == b,
309            (Value::Bytes(a), Value::Bytes(b)) => a == b,
310            (Value::Bool(a), Value::Bool(b)) => a == b,
311            (Value::Null, Value::Null) => true,
312            #[cfg(feature = "chrono")]
313            (Value::Duration(a), Value::Duration(b)) => a == b,
314            #[cfg(feature = "chrono")]
315            (Value::Timestamp(a), Value::Timestamp(b)) => a == b,
316            // Allow different numeric types to be compared without explicit casting.
317            (Value::Int(a), Value::UInt(b)) => a
318                .to_owned()
319                .try_into()
320                .map(|a: u64| a == *b)
321                .unwrap_or(false),
322            (Value::Int(a), Value::Float(b)) => (*a as f64) == *b,
323            (Value::UInt(a), Value::Int(b)) => a
324                .to_owned()
325                .try_into()
326                .map(|a: i64| a == *b)
327                .unwrap_or(false),
328            (Value::UInt(a), Value::Float(b)) => (*a as f64) == *b,
329            (Value::Float(a), Value::Int(b)) => *a == (*b as f64),
330            (Value::Float(a), Value::UInt(b)) => *a == (*b as f64),
331            (_, _) => false,
332        }
333    }
334}
335
// Marker impl declaring the `PartialEq` impl for `Value` a full equivalence relation.
// NOTE(review): `Value::Float` is compared with `f64` equality, under which NaN is not
// equal to itself, so reflexivity does not hold for NaN floats — confirm no caller relies
// on `Eq` for such values.
impl Eq for Value {}
337
338impl PartialOrd for Value {
339    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
340        match (self, other) {
341            (Value::Int(a), Value::Int(b)) => Some(a.cmp(b)),
342            (Value::UInt(a), Value::UInt(b)) => Some(a.cmp(b)),
343            (Value::Float(a), Value::Float(b)) => a.partial_cmp(b),
344            (Value::String(a), Value::String(b)) => Some(a.cmp(b)),
345            (Value::Bool(a), Value::Bool(b)) => Some(a.cmp(b)),
346            (Value::Null, Value::Null) => Some(Ordering::Equal),
347            #[cfg(feature = "chrono")]
348            (Value::Duration(a), Value::Duration(b)) => Some(a.cmp(b)),
349            #[cfg(feature = "chrono")]
350            (Value::Timestamp(a), Value::Timestamp(b)) => Some(a.cmp(b)),
351            // Allow different numeric types to be compared without explicit casting.
352            (Value::Int(a), Value::UInt(b)) => Some(
353                a.to_owned()
354                    .try_into()
355                    .map(|a: u64| a.cmp(b))
356                    // If the i64 doesn't fit into a u64 it must be less than 0.
357                    .unwrap_or(Ordering::Less),
358            ),
359            (Value::Int(a), Value::Float(b)) => (*a as f64).partial_cmp(b),
360            (Value::UInt(a), Value::Int(b)) => Some(
361                a.to_owned()
362                    .try_into()
363                    .map(|a: i64| a.cmp(b))
364                    // If the u64 doesn't fit into a i64 it must be greater than i64::MAX.
365                    .unwrap_or(Ordering::Greater),
366            ),
367            (Value::UInt(a), Value::Float(b)) => (*a as f64).partial_cmp(b),
368            (Value::Float(a), Value::Int(b)) => a.partial_cmp(&(*b as f64)),
369            (Value::Float(a), Value::UInt(b)) => a.partial_cmp(&(*b as f64)),
370            _ => None,
371        }
372    }
373}
374
375impl From<&Key> for Value {
376    fn from(value: &Key) -> Self {
377        match value {
378            Key::Int(v) => Value::Int(*v),
379            Key::Uint(v) => Value::UInt(*v),
380            Key::Bool(v) => Value::Bool(*v),
381            Key::String(v) => Value::String(v.clone()),
382        }
383    }
384}
385
386impl From<Key> for Value {
387    fn from(value: Key) -> Self {
388        match value {
389            Key::Int(v) => Value::Int(v),
390            Key::Uint(v) => Value::UInt(v),
391            Key::Bool(v) => Value::Bool(v),
392            Key::String(v) => Value::String(v),
393        }
394    }
395}
396
397impl From<&Key> for Key {
398    fn from(key: &Key) -> Self {
399        key.clone()
400    }
401}
402
403// Convert Vec<T> to Value
404impl<T: Into<Value>> From<Vec<T>> for Value {
405    fn from(v: Vec<T>) -> Self {
406        Value::List(v.into_iter().map(|v| v.into()).collect::<Vec<_>>().into())
407    }
408}
409
410// Convert Vec<u8> to Value
411impl From<Vec<u8>> for Value {
412    fn from(v: Vec<u8>) -> Self {
413        Value::Bytes(v.into())
414    }
415}
416
417// Convert String to Value
418impl From<String> for Value {
419    fn from(v: String) -> Self {
420        Value::String(v.into())
421    }
422}
423
424impl From<&str> for Value {
425    fn from(v: &str) -> Self {
426        Value::String(v.to_string().into())
427    }
428}
429
430// Convert Option<T> to Value
431impl<T: Into<Value>> From<Option<T>> for Value {
432    fn from(v: Option<T>) -> Self {
433        match v {
434            Some(v) => v.into(),
435            None => Value::Null,
436        }
437    }
438}
439
440// Convert HashMap<K, V> to Value
441impl<K: Into<Key>, V: Into<Value>> From<HashMap<K, V>> for Value {
442    fn from(v: HashMap<K, V>) -> Self {
443        Value::Map(v.into())
444    }
445}
446
447impl From<ExecutionError> for ResolveResult {
448    fn from(value: ExecutionError) -> Self {
449        Err(value)
450    }
451}
452
/// Outcome of resolving an expression: the resulting [`Value`] on success, or the
/// [`ExecutionError`] that stopped evaluation.
pub type ResolveResult = Result<Value, ExecutionError>;
454
455impl From<Value> for ResolveResult {
456    fn from(value: Value) -> Self {
457        Ok(value)
458    }
459}
460
461impl Value {
462    pub fn resolve_all(expr: &[Expression], ctx: &Context) -> ResolveResult {
463        let mut res = Vec::with_capacity(expr.len());
464        for expr in expr {
465            res.push(Value::resolve(expr, ctx)?);
466        }
467        Ok(Value::List(res.into()))
468    }
469
470    #[inline(always)]
471    pub fn resolve(expr: &Expression, ctx: &Context) -> ResolveResult {
472        match &expr.expr {
473            Expr::Literal(val) => Ok(val.clone().into()),
474            Expr::Call(call) => {
475                if call.args.len() == 3 && call.func_name == operators::CONDITIONAL {
476                    let cond = Value::resolve(&call.args[0], ctx)?;
477                    return if cond.to_bool() {
478                        Value::resolve(&call.args[1], ctx)
479                    } else {
480                        Value::resolve(&call.args[2], ctx)
481                    };
482                }
483                if call.args.len() == 2 {
484                    match call.func_name.as_str() {
485                        operators::ADD => {
486                            return Value::resolve(&call.args[0], ctx)?
487                                + Value::resolve(&call.args[1], ctx)?
488                        }
489                        operators::SUBSTRACT => {
490                            return Value::resolve(&call.args[0], ctx)?
491                                - Value::resolve(&call.args[1], ctx)?
492                        }
493                        operators::DIVIDE => {
494                            return Value::resolve(&call.args[0], ctx)?
495                                / Value::resolve(&call.args[1], ctx)?
496                        }
497                        operators::MULTIPLY => {
498                            return Value::resolve(&call.args[0], ctx)?
499                                * Value::resolve(&call.args[1], ctx)?
500                        }
501                        operators::MODULO => {
502                            return Value::resolve(&call.args[0], ctx)?
503                                % Value::resolve(&call.args[1], ctx)?
504                        }
505                        operators::EQUALS => {
506                            return Value::Bool(
507                                Value::resolve(&call.args[0], ctx)?
508                                    .eq(&Value::resolve(&call.args[1], ctx)?),
509                            )
510                            .into()
511                        }
512                        operators::NOT_EQUALS => {
513                            return Value::Bool(
514                                Value::resolve(&call.args[0], ctx)?
515                                    .ne(&Value::resolve(&call.args[1], ctx)?),
516                            )
517                            .into()
518                        }
519                        operators::LESS => {
520                            let left = Value::resolve(&call.args[0], ctx)?;
521                            let right = Value::resolve(&call.args[1], ctx)?;
522                            return Value::Bool(
523                                left.partial_cmp(&right)
524                                    .ok_or(ExecutionError::ValuesNotComparable(left, right))?
525                                    == Ordering::Less,
526                            )
527                            .into();
528                        }
529                        operators::LESS_EQUALS => {
530                            let left = Value::resolve(&call.args[0], ctx)?;
531                            let right = Value::resolve(&call.args[1], ctx)?;
532                            return Value::Bool(
533                                left.partial_cmp(&right)
534                                    .ok_or(ExecutionError::ValuesNotComparable(left, right))?
535                                    != Ordering::Greater,
536                            )
537                            .into();
538                        }
539                        operators::GREATER => {
540                            let left = Value::resolve(&call.args[0], ctx)?;
541                            let right = Value::resolve(&call.args[1], ctx)?;
542                            return Value::Bool(
543                                left.partial_cmp(&right)
544                                    .ok_or(ExecutionError::ValuesNotComparable(left, right))?
545                                    == Ordering::Greater,
546                            )
547                            .into();
548                        }
549                        operators::GREATER_EQUALS => {
550                            let left = Value::resolve(&call.args[0], ctx)?;
551                            let right = Value::resolve(&call.args[1], ctx)?;
552                            return Value::Bool(
553                                left.partial_cmp(&right)
554                                    .ok_or(ExecutionError::ValuesNotComparable(left, right))?
555                                    != Ordering::Less,
556                            )
557                            .into();
558                        }
559                        operators::IN => {
560                            let left = Value::resolve(&call.args[0], ctx)?;
561                            let right = Value::resolve(&call.args[1], ctx)?;
562                            match (left, right) {
563                                (Value::String(l), Value::String(r)) => {
564                                    return Value::Bool(r.contains(&*l)).into()
565                                }
566                                (any, Value::List(v)) => {
567                                    return Value::Bool(v.contains(&any)).into()
568                                }
569                                (any, Value::Map(m)) => match any.try_into() {
570                                    Ok(key) => return Value::Bool(m.map.contains_key(&key)).into(),
571                                    Err(_) => return Value::Bool(false).into(),
572                                },
573                                (left, right) => {
574                                    Err(ExecutionError::ValuesNotComparable(left, right))?
575                                }
576                            }
577                        }
578                        operators::LOGICAL_OR => {
579                            let left = Value::resolve(&call.args[0], ctx)?;
580                            return if left.to_bool() {
581                                left.into()
582                            } else {
583                                Value::resolve(&call.args[1], ctx)
584                            };
585                        }
586                        operators::LOGICAL_AND => {
587                            let left = Value::resolve(&call.args[0], ctx)?;
588                            return if !left.to_bool() {
589                                Value::Bool(false)
590                            } else {
591                                let right = Value::resolve(&call.args[1], ctx)?;
592                                Value::Bool(right.to_bool())
593                            }
594                            .into();
595                        }
596                        operators::INDEX => {
597                            let value = Value::resolve(&call.args[0], ctx)?;
598                            let idx = Value::resolve(&call.args[1], ctx)?;
599                            return match (value, idx) {
600                                (Value::List(items), Value::Int(idx)) => items
601                                    .get(idx as usize)
602                                    .cloned()
603                                    .unwrap_or(Value::Null)
604                                    .into(),
605                                (Value::String(str), Value::Int(idx)) => {
606                                    match str.get(idx as usize..(idx + 1) as usize) {
607                                        None => Ok(Value::Null),
608                                        Some(str) => Ok(Value::String(str.to_string().into())),
609                                    }
610                                }
611                                (Value::Map(map), Value::String(property)) => map
612                                    .get(&property.into())
613                                    .cloned()
614                                    .unwrap_or(Value::Null)
615                                    .into(),
616                                (Value::Map(map), Value::Bool(property)) => map
617                                    .get(&property.into())
618                                    .cloned()
619                                    .unwrap_or(Value::Null)
620                                    .into(),
621                                (Value::Map(map), Value::Int(property)) => map
622                                    .get(&property.into())
623                                    .cloned()
624                                    .unwrap_or(Value::Null)
625                                    .into(),
626                                (Value::Map(map), Value::UInt(property)) => map
627                                    .get(&property.into())
628                                    .cloned()
629                                    .unwrap_or(Value::Null)
630                                    .into(),
631                                (Value::Map(_), index) => {
632                                    Err(ExecutionError::UnsupportedMapIndex(index))
633                                }
634                                (Value::List(_), index) => {
635                                    Err(ExecutionError::UnsupportedListIndex(index))
636                                }
637                                (value, index) => {
638                                    Err(ExecutionError::UnsupportedIndex(value, index))
639                                }
640                            };
641                        }
642                        _ => (),
643                    }
644                }
645                if call.args.len() == 1 {
646                    let expr = Value::resolve(&call.args[0], ctx)?;
647                    match call.func_name.as_str() {
648                        operators::LOGICAL_NOT => return Ok(Value::Bool(!expr.to_bool())),
649                        operators::NEGATE => {
650                            return match expr {
651                                Value::Int(i) => Ok(Value::Int(-i)),
652                                Value::Float(f) => Ok(Value::Float(-f)),
653                                value => {
654                                    Err(ExecutionError::UnsupportedUnaryOperator("minus", value))
655                                }
656                            }
657                        }
658                        operators::NOT_STRICTLY_FALSE => {
659                            return match expr {
660                                Value::Bool(b) => Ok(Value::Bool(b)),
661                                _ => Ok(Value::Bool(true)),
662                            }
663                        }
664                        _ => (),
665                    }
666                }
667                let func = ctx.get_function(call.func_name.as_str()).ok_or_else(|| {
668                    ExecutionError::UndeclaredReference(call.func_name.clone().into())
669                })?;
670                match &call.target {
671                    None => {
672                        let mut ctx = FunctionContext::new(
673                            call.func_name.clone().into(),
674                            None,
675                            ctx,
676                            call.args.clone(),
677                        );
678                        (func)(&mut ctx)
679                    }
680                    Some(target) => {
681                        let mut ctx = FunctionContext::new(
682                            call.func_name.clone().into(),
683                            Some(Value::resolve(target, ctx)?),
684                            ctx,
685                            call.args.clone(),
686                        );
687                        (func)(&mut ctx)
688                    }
689                }
690            }
691            Expr::Ident(name) => ctx.get_variable(name),
692            Expr::Select(select) => {
693                let left = Value::resolve(select.operand.deref(), ctx)?;
694                if select.test {
695                    match &left {
696                        Value::Map(map) => {
697                            for key in map.map.deref().keys() {
698                                if key.to_string().eq(&select.field) {
699                                    return Ok(Value::Bool(true));
700                                }
701                            }
702                            Ok(Value::Bool(false))
703                        }
704                        _ => Ok(Value::Bool(false)),
705                    }
706                } else {
707                    left.member(&select.field, ctx)
708                }
709            }
710            Expr::List(list_expr) => {
711                let list = list_expr
712                    .elements
713                    .iter()
714                    .map(|i| Value::resolve(i, ctx))
715                    .collect::<Result<Vec<_>, _>>()?;
716                Value::List(list.into()).into()
717            }
718            Expr::Map(map_expr) => {
719                let mut map = HashMap::with_capacity(map_expr.entries.len());
720                for entry in map_expr.entries.iter() {
721                    let (k, v) = match &entry.expr {
722                        EntryExpr::StructField(_) => panic!("WAT?"),
723                        EntryExpr::MapEntry(e) => (&e.key, &e.value),
724                    };
725                    let key = Value::resolve(k, ctx)?
726                        .try_into()
727                        .map_err(ExecutionError::UnsupportedKeyType)?;
728                    let value = Value::resolve(v, ctx)?;
729                    map.insert(key, value);
730                }
731                Ok(Value::Map(Map {
732                    map: Arc::from(map),
733                }))
734            }
735            Expr::Comprehension(comprehension) => {
736                let accu_init = Value::resolve(&comprehension.accu_init, ctx)?;
737                let iter = Value::resolve(&comprehension.iter_range, ctx)?;
738                let mut ctx = ctx.new_inner_scope();
739                ctx.add_variable(&comprehension.accu_var, accu_init)
740                    .expect("Failed to add accu variable");
741
742                match iter {
743                    Value::List(items) => {
744                        for item in items.deref() {
745                            if !Value::resolve(&comprehension.loop_cond, &ctx)?.to_bool() {
746                                break;
747                            }
748                            ctx.add_variable_from_value(&comprehension.iter_var, item.clone());
749                            let accu = Value::resolve(&comprehension.loop_step, &ctx)?;
750                            ctx.add_variable_from_value(&comprehension.accu_var, accu);
751                        }
752                    }
753                    Value::Map(map) => {
754                        for key in map.map.deref().keys() {
755                            if !Value::resolve(&comprehension.loop_cond, &ctx)?.to_bool() {
756                                break;
757                            }
758                            ctx.add_variable_from_value(&comprehension.iter_var, key.clone());
759                            let accu = Value::resolve(&comprehension.loop_step, &ctx)?;
760                            ctx.add_variable_from_value(&comprehension.accu_var, accu);
761                        }
762                    }
763                    t => todo!("Support {t:?}"),
764                }
765                Value::resolve(&comprehension.result, &ctx)
766            }
767            Expr::Struct(_) => todo!("Support structs!"),
768            Expr::Unspecified => panic!("Can't evaluate Unspecified Expr"),
769        }
770    }
771
772    // >> a(b)
773    // Member(Ident("a"),
774    //        FunctionCall([Ident("b")]))
775    // >> a.b(c)
776    // Member(Member(Ident("a"),
777    //               Attribute("b")),
778    //        FunctionCall([Ident("c")]))
779
780    fn member(self, name: &str, ctx: &Context) -> ResolveResult {
781        // todo! Ideally we would avoid creating a String just to create a Key for lookup in the
782        // map, but this would require something like the `hashbrown` crate's `Equivalent` trait.
783        let name: Arc<String> = name.to_owned().into();
784
785        // This will always either be because we're trying to access
786        // a property on self, or a method on self.
787        let child = match self {
788            Value::Map(ref m) => m.map.get(&name.clone().into()).cloned(),
789            _ => None,
790        };
791
792        // If the property is both an attribute and a method, then we
793        // give priority to the property. Maybe we can implement lookahead
794        // to see if the next token is a function call?
795        if let Some(child) = child {
796            child.into()
797        } else if ctx.has_function(&name) {
798            Value::Function(name.clone(), Some(self.into())).into()
799        } else {
800            ExecutionError::NoSuchKey(name.clone()).into()
801        }
802    }
803
804    #[inline(always)]
805    fn to_bool(&self) -> bool {
806        match self {
807            Value::List(v) => !v.is_empty(),
808            Value::Map(v) => !v.map.is_empty(),
809            Value::Int(v) => *v != 0,
810            Value::UInt(v) => *v != 0,
811            Value::Float(v) => *v != 0.0,
812            Value::String(v) => !v.is_empty(),
813            Value::Bytes(v) => !v.is_empty(),
814            Value::Bool(v) => *v,
815            Value::Null => false,
816            #[cfg(feature = "chrono")]
817            Value::Duration(v) => v.num_nanoseconds().map(|n| n != 0).unwrap_or(false),
818            #[cfg(feature = "chrono")]
819            Value::Timestamp(v) => v.timestamp_nanos_opt().unwrap_or_default() > 0,
820            Value::Function(_, _) => false,
821        }
822    }
823}
824
825impl ops::Add<Value> for Value {
826    type Output = ResolveResult;
827
828    #[inline(always)]
829    fn add(self, rhs: Value) -> Self::Output {
830        match (self, rhs) {
831            (Value::Int(l), Value::Int(r)) => l
832                .checked_add(r)
833                .ok_or(ExecutionError::Overflow("add", l.into(), r.into()))
834                .map(Value::Int),
835
836            (Value::UInt(l), Value::UInt(r)) => l
837                .checked_add(r)
838                .ok_or(ExecutionError::Overflow("add", l.into(), r.into()))
839                .map(Value::UInt),
840
841            (Value::Float(l), Value::Float(r)) => Value::Float(l + r).into(),
842
843            (Value::List(mut l), Value::List(mut r)) => {
844                {
845                    // If this is the only reference to `l`, we can append to it in place.
846                    // `l` is replaced with a clone otherwise.
847                    let l = Arc::make_mut(&mut l);
848
849                    // Likewise, if this is the only reference to `r`, we can move its values
850                    // instead of cloning them.
851                    match Arc::get_mut(&mut r) {
852                        Some(r) => l.append(r),
853                        None => l.extend(r.iter().cloned()),
854                    }
855                }
856
857                Ok(Value::List(l))
858            }
859            (Value::String(mut l), Value::String(r)) => {
860                // If this is the only reference to `l`, we can append to it in place.
861                // `l` is replaced with a clone otherwise.
862                Arc::make_mut(&mut l).push_str(&r);
863                Ok(Value::String(l))
864            }
865            #[cfg(feature = "chrono")]
866            (Value::Duration(l), Value::Duration(r)) => l
867                .checked_add(&r)
868                .ok_or(ExecutionError::Overflow("add", l.into(), r.into()))
869                .map(Value::Duration),
870            #[cfg(feature = "chrono")]
871            (Value::Timestamp(l), Value::Duration(r)) => checked_op(TsOp::Add, &l, &r),
872            #[cfg(feature = "chrono")]
873            (Value::Duration(l), Value::Timestamp(r)) => r
874                .checked_add_signed(l)
875                .ok_or(ExecutionError::Overflow("add", l.into(), r.into()))
876                .map(Value::Timestamp),
877            (left, right) => Err(ExecutionError::UnsupportedBinaryOperator(
878                "add", left, right,
879            )),
880        }
881    }
882}
883
884impl ops::Sub<Value> for Value {
885    type Output = ResolveResult;
886
887    #[inline(always)]
888    fn sub(self, rhs: Value) -> Self::Output {
889        match (self, rhs) {
890            (Value::Int(l), Value::Int(r)) => l
891                .checked_sub(r)
892                .ok_or(ExecutionError::Overflow("sub", l.into(), r.into()))
893                .map(Value::Int),
894
895            (Value::UInt(l), Value::UInt(r)) => l
896                .checked_sub(r)
897                .ok_or(ExecutionError::Overflow("sub", l.into(), r.into()))
898                .map(Value::UInt),
899
900            (Value::Float(l), Value::Float(r)) => Value::Float(l - r).into(),
901
902            #[cfg(feature = "chrono")]
903            (Value::Duration(l), Value::Duration(r)) => l
904                .checked_sub(&r)
905                .ok_or(ExecutionError::Overflow("sub", l.into(), r.into()))
906                .map(Value::Duration),
907            #[cfg(feature = "chrono")]
908            (Value::Timestamp(l), Value::Duration(r)) => checked_op(TsOp::Sub, &l, &r),
909            #[cfg(feature = "chrono")]
910            (Value::Timestamp(l), Value::Timestamp(r)) => {
911                Value::Duration(l.signed_duration_since(r)).into()
912            }
913            (left, right) => Err(ExecutionError::UnsupportedBinaryOperator(
914                "sub", left, right,
915            )),
916        }
917    }
918}
919
920impl ops::Div<Value> for Value {
921    type Output = ResolveResult;
922
923    #[inline(always)]
924    fn div(self, rhs: Value) -> Self::Output {
925        match (self, rhs) {
926            (Value::Int(l), Value::Int(r)) => {
927                if r == 0 {
928                    Err(ExecutionError::DivisionByZero(l.into()))
929                } else {
930                    l.checked_div(r)
931                        .ok_or(ExecutionError::Overflow("div", l.into(), r.into()))
932                        .map(Value::Int)
933                }
934            }
935
936            (Value::UInt(l), Value::UInt(r)) => l
937                .checked_div(r)
938                .ok_or(ExecutionError::DivisionByZero(l.into()))
939                .map(Value::UInt),
940
941            (Value::Float(l), Value::Float(r)) => Value::Float(l / r).into(),
942
943            (left, right) => Err(ExecutionError::UnsupportedBinaryOperator(
944                "div", left, right,
945            )),
946        }
947    }
948}
949
950impl ops::Mul<Value> for Value {
951    type Output = ResolveResult;
952
953    #[inline(always)]
954    fn mul(self, rhs: Value) -> Self::Output {
955        match (self, rhs) {
956            (Value::Int(l), Value::Int(r)) => l
957                .checked_mul(r)
958                .ok_or(ExecutionError::Overflow("mul", l.into(), r.into()))
959                .map(Value::Int),
960
961            (Value::UInt(l), Value::UInt(r)) => l
962                .checked_mul(r)
963                .ok_or(ExecutionError::Overflow("mul", l.into(), r.into()))
964                .map(Value::UInt),
965
966            (Value::Float(l), Value::Float(r)) => Value::Float(l * r).into(),
967
968            (left, right) => Err(ExecutionError::UnsupportedBinaryOperator(
969                "mul", left, right,
970            )),
971        }
972    }
973}
974
975impl ops::Rem<Value> for Value {
976    type Output = ResolveResult;
977
978    #[inline(always)]
979    fn rem(self, rhs: Value) -> Self::Output {
980        match (self, rhs) {
981            (Value::Int(l), Value::Int(r)) => {
982                if r == 0 {
983                    Err(ExecutionError::RemainderByZero(l.into()))
984                } else {
985                    l.checked_rem(r)
986                        .ok_or(ExecutionError::Overflow("rem", l.into(), r.into()))
987                        .map(Value::Int)
988                }
989            }
990
991            (Value::UInt(l), Value::UInt(r)) => l
992                .checked_rem(r)
993                .ok_or(ExecutionError::RemainderByZero(l.into()))
994                .map(Value::UInt),
995
996            (left, right) => Err(ExecutionError::UnsupportedBinaryOperator(
997                "rem", left, right,
998            )),
999        }
1000    }
1001}
1002
/// A binary arithmetic operation that can be applied to a timestamp and a duration.
#[cfg(feature = "chrono")]
enum TsOp {
    /// Add the duration to the timestamp.
    Add,
    /// Subtract the duration from the timestamp.
    Sub,
}
1010
#[cfg(feature = "chrono")]
impl TsOp {
    /// Returns the operator name used when reporting errors for this operation.
    fn str(&self) -> &'static str {
        let label = match *self {
            Self::Add => "add",
            Self::Sub => "sub",
        };
        label
    }
}
1020
/// Applies the timestamp operation [`TsOp`] to `lhs` and `rhs`, rejecting any result that is
/// not a valid CEL timestamp.
///
/// Two conditions are reported as [`ExecutionError::Overflow`], carrying the operator name
/// and both operands:
/// * the underlying date-time arithmetic itself overflows, or
/// * the result lies outside the inclusive range [`MIN_TIMESTAMP`]..=[`MAX_TIMESTAMP`]
///   mandated by the cel-spec.
#[cfg(feature = "chrono")]
fn checked_op(
    op: TsOp,
    lhs: &chrono::DateTime<chrono::FixedOffset>,
    rhs: &chrono::Duration,
) -> ResolveResult {
    // Both failure modes produce the same error, so build it in one place.
    let overflow = || ExecutionError::Overflow(op.str(), (*lhs).into(), (*rhs).into());

    // `None` here means the data type itself could not represent the result.
    let shifted = match op {
        TsOp::Add => lhs.checked_add_signed(*rhs),
        TsOp::Sub => lhs.checked_sub_signed(*rhs),
    };

    match shifted {
        // Representable and within the cel-spec limits.
        Some(ts) if ts >= *MIN_TIMESTAMP && ts <= *MAX_TIMESTAMP => Value::Timestamp(ts).into(),
        _ => Err(overflow()),
    }
}
1052
#[cfg(test)]
mod tests {
    use crate::{objects::Key, Context, ExecutionError, Program, Value};
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Compiles `source` (panicking if it does not parse) and evaluates it against `context`.
    fn run(source: &str, context: &Context) -> crate::ResolveResult {
        Program::compile(source).unwrap().execute(context)
    }

    /// Asserts that evaluating `program` in a default context fails with exactly `expected`.
    fn test_execution_error(program: &str, expected: ExecutionError) {
        let outcome = run(program, &Context::default());
        assert_eq!(outcome.unwrap_err(), expected);
    }

    #[test]
    fn test_indexed_map_access() {
        let headers = HashMap::from([("Content-Type", "application/json".to_string())]);
        let mut context = Context::default();
        context.add_variable_from_value("headers", headers);

        let value = run("headers[\"Content-Type\"]", &context).unwrap();
        assert_eq!(value, "application/json".into());
    }

    #[test]
    fn test_numeric_map_access() {
        let numbers = HashMap::from([(Key::Uint(1), "one".to_string())]);
        let mut context = Context::default();
        context.add_variable_from_value("numbers", numbers);

        let value = run("numbers[1]", &context).unwrap();
        assert_eq!(value, "one".into());
    }

    #[test]
    fn test_heterogeneous_compare() {
        let context = Context::default();

        let value = run("1 < uint(2)", &context).unwrap();
        assert_eq!(value, true.into());

        let value = run("1 < 1.1", &context).unwrap();
        assert_eq!(value, true.into());

        let value = run("uint(0) > -10", &context).unwrap();
        assert_eq!(
            value,
            true.into(),
            "negative signed ints should be less than uints"
        );
    }

    #[test]
    fn test_float_compare() {
        let context = Context::default();

        let value = run("1.0 > 0.0", &context).unwrap();
        assert_eq!(value, true.into());

        let value = run("double('NaN') == double('NaN')", &context).unwrap();
        assert_eq!(value, false.into(), "NaN should not equal itself");

        let result = run("1.0 > double('NaN')", &context);
        assert!(
            result.is_err(),
            "NaN should not be comparable with inequality operators"
        );
    }

    #[test]
    fn test_invalid_compare() {
        let context = Context::default();

        let value = run("{} == []", &context).unwrap();
        assert_eq!(value, false.into());
    }

    #[test]
    fn test_size_fn_var() {
        // `size` is both a built-in function and a user variable in this expression.
        let mut context = Context::default();
        let requests = vec![Value::Int(42), Value::Int(42)];
        context
            .add_variable("requests", Value::List(Arc::new(requests)))
            .unwrap();
        context.add_variable("size", Value::Int(3)).unwrap();

        let value = run("size(requests) + size == 5", &context).unwrap();
        assert_eq!(value, Value::Bool(true));
    }

    #[test]
    fn test_invalid_sub() {
        let expected =
            ExecutionError::UnsupportedBinaryOperator("sub", "foo".into(), Value::Int(10));
        test_execution_error("'foo' - 10", expected);
    }

    #[test]
    fn test_invalid_add() {
        let expected =
            ExecutionError::UnsupportedBinaryOperator("add", "foo".into(), Value::Int(10));
        test_execution_error("'foo' + 10", expected);
    }

    #[test]
    fn test_invalid_div() {
        let expected =
            ExecutionError::UnsupportedBinaryOperator("div", "foo".into(), Value::Int(10));
        test_execution_error("'foo' / 10", expected);
    }

    #[test]
    fn test_invalid_rem() {
        let expected =
            ExecutionError::UnsupportedBinaryOperator("rem", "foo".into(), Value::Int(10));
        test_execution_error("'foo' % 10", expected);
    }

    #[test]
    fn out_of_bound_list_access() {
        let mut context = Context::default();
        context
            .add_variable("list", Value::List(Arc::new(vec![])))
            .unwrap();

        let result = run("list[10]", &context);
        assert_eq!(result.unwrap(), Value::Null);
    }

    #[test]
    fn reference_to_value() {
        let test = "example".to_string();
        let expected = Value::String(Arc::new(String::from("example")));

        // A borrowed string converts directly...
        let direct: Value = test.as_str().into();
        assert_eq!(direct, expected);

        // ...and as the element of a vector.
        let indirect: Value = vec![test.as_str()].into();
        assert_eq!(indirect, Value::List(Arc::new(vec![expected])));
    }

    #[test]
    fn test_short_circuit_and() {
        let mut context = Context::default();
        let data: HashMap<String, String> = HashMap::new();
        context.add_variable_from_value("data", data);

        let value = run("has(data.x) && data.x.startsWith(\"foo\")", &context);
        println!("{value:?}");
        assert!(
            value.is_ok(),
            "The AND expression should support short-circuit evaluation."
        );
    }

    #[test]
    fn invalid_int_math() {
        use ExecutionError::*;

        let cases = vec![
            ("1 / 0".to_string(), DivisionByZero(1.into())),
            ("1 % 0".to_string(), RemainderByZero(1.into())),
            (
                format!("{} + 1", i64::MAX),
                Overflow("add", i64::MAX.into(), 1.into()),
            ),
            (
                format!("{} - 1", i64::MIN),
                Overflow("sub", i64::MIN.into(), 1.into()),
            ),
            (
                format!("{} * 2", i64::MAX),
                Overflow("mul", i64::MAX.into(), 2.into()),
            ),
            (
                format!("{} / -1", i64::MIN),
                Overflow("div", i64::MIN.into(), (-1).into()),
            ),
            (
                format!("{} % -1", i64::MIN),
                Overflow("rem", i64::MIN.into(), (-1).into()),
            ),
        ];

        for (expr, err) in cases {
            test_execution_error(&expr, err);
        }
    }

    #[test]
    fn invalid_uint_math() {
        use ExecutionError::*;

        let cases = vec![
            ("1u / 0u".to_string(), DivisionByZero(1u64.into())),
            ("1u % 0u".to_string(), RemainderByZero(1u64.into())),
            (
                format!("{}u + 1u", u64::MAX),
                Overflow("add", u64::MAX.into(), 1u64.into()),
            ),
            (
                "0u - 1u".to_string(),
                Overflow("sub", 0u64.into(), 1u64.into()),
            ),
            (
                format!("{}u * 2u", u64::MAX),
                Overflow("mul", u64::MAX.into(), 2u64.into()),
            ),
        ];

        for (expr, err) in cases {
            test_execution_error(&expr, err);
        }
    }

    #[test]
    fn test_function_identifier() {
        /// Binds the receiver to `ident` in a fresh inner scope and evaluates `expr` there.
        fn with(
            ftx: &crate::FunctionContext,
            crate::extractors::This(this): crate::extractors::This<Value>,
            ident: crate::extractors::Identifier,
            expr: crate::parser::Expression,
        ) -> crate::ResolveResult {
            let mut ptx = ftx.ptx.new_inner_scope();
            ptx.add_variable_from_value(&ident, this);
            ptx.resolve(&expr)
        }

        let mut context = Context::default();
        context.add_function("with", with);

        let expected: Vec<Value> = vec![1, 2, 1, 2].into_iter().map(Value::Int).collect();
        let value = run("[1,2].with(a, a + a)", &context);
        assert_eq!(value, Ok(Value::List(Arc::new(expected))));
    }
}