1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
use crate::magic::{Function, FunctionRegistry, Handler};
use crate::objects::{TryIntoValue, Value};
use crate::{functions, ExecutionError};
use cel_parser::Expression;
use std::collections::HashMap;

/// Context is a collection of variables and functions that can be used
/// by the interpreter to resolve expressions. A context is either a root
/// context or a child context.
///
/// The default context ([`Context::default`]) is a root context and contains
/// all of the built-in functions. A child context is created by calling the
/// inherent `.clone()` method on an existing context. Despite the name, this
/// does not copy anything: it creates an empty child that borrows its parent.
///
/// A child context has its own variables (which can be added to), and it
/// also references its parent context. This allows variables to be
/// overridden within the child context while names that are not defined
/// locally still resolve in the child's ancestors. Child contexts can be
/// nested arbitrarily deep; each one references only its direct parent, and
/// variable lookups walk the chain up to the root.
///
/// So why is this important? Some CEL macros, such as `.map`, declare
/// intermediate user-specified identifiers that should only be available
/// within the macro and should not override variables in the parent
/// context. The `.map` macro can create a child of the parent context, add
/// the intermediate identifier to the child, and then evaluate the map
/// expression.
///
/// ```text
/// Intermediate variable stored in child context
///               ↓
/// [1, 2, 3].map(x, x * 2) == [2, 4, 6]
///                  ↑
/// Only in scope for the duration of the map expression
/// ```
///
pub enum Context<'a> {
    /// The top-level context. It owns the function registry and its own
    /// variables.
    Root {
        functions: FunctionRegistry,
        variables: HashMap<String, Value>,
    },
    /// A nested scope. It owns its own variables and borrows its parent,
    /// which is consulted for any variable not found locally and which
    /// supplies the functions.
    Child {
        parent: &'a Context<'a>,
        variables: HashMap<String, Value>,
    },
}

impl<'a> Context<'a> {
    pub fn add_variable<S, V>(
        &mut self,
        name: S,
        value: V,
    ) -> Result<(), Box<dyn std::error::Error>>
    where
        S: Into<String>,
        V: TryIntoValue,
    {
        match self {
            Context::Root { variables, .. } => {
                variables.insert(name.into(), value.try_into_value()?);
            }
            Context::Child { variables, .. } => {
                variables.insert(name.into(), value.try_into_value()?);
            }
        }
        Ok(())
    }

    pub fn add_variable_from_value<S, V>(&mut self, name: S, value: V)
    where
        S: Into<String>,
        V: Into<Value>,
    {
        match self {
            Context::Root { variables, .. } => {
                variables.insert(name.into(), value.into());
            }
            Context::Child { variables, .. } => {
                variables.insert(name.into(), value.into());
            }
        }
    }

    pub fn get_variable<S>(&self, name: S) -> Result<Value, ExecutionError>
    where
        S: Into<String>,
    {
        let name = name.into();
        match self {
            Context::Child { variables, parent } => variables
                .get(&name)
                .cloned()
                .or_else(|| parent.get_variable(&name).ok())
                .ok_or(ExecutionError::UndeclaredReference(name.into())),
            Context::Root { variables, .. } => variables
                .get(&name)
                .cloned()
                .ok_or(ExecutionError::UndeclaredReference(name.into())),
        }
    }

    pub(crate) fn has_function<S>(&self, name: S) -> bool
    where
        S: Into<String>,
    {
        let name = name.into();
        match self {
            Context::Root { functions, .. } => functions.has(&name),
            Context::Child { parent, .. } => parent.has_function(name),
        }
    }

    pub(crate) fn get_function<S>(&self, name: S) -> Option<Box<dyn Function>>
    where
        S: Into<String>,
    {
        let name = name.into();
        match self {
            Context::Root { functions, .. } => functions.get(&name),
            Context::Child { parent, .. } => parent.get_function(name),
        }
    }

    pub fn add_function<T: 'static, F: 'static>(&mut self, name: &str, value: F)
    where
        F: Handler<T> + 'static,
    {
        if let Context::Root { functions, .. } = self {
            functions.add(name, value);
        };
    }

    pub fn resolve(&self, expr: &Expression) -> Result<Value, ExecutionError> {
        Value::resolve(expr, self)
    }

    pub fn resolve_all(&self, exprs: &[Expression]) -> Result<Value, ExecutionError> {
        Value::resolve_all(exprs, self)
    }

    pub fn clone(&self) -> Context {
        Context::Child {
            parent: self,
            variables: Default::default(),
        }
    }
}

impl<'a> Default for Context<'a> {
    fn default() -> Self {
        let mut ctx = Context::Root {
            variables: Default::default(),
            functions: Default::default(),
        };
        ctx.add_function("contains", functions::contains);
        ctx.add_function("size", functions::size);
        ctx.add_function("has", functions::has);
        ctx.add_function("map", functions::map);
        ctx.add_function("filter", functions::filter);
        ctx.add_function("all", functions::all);
        ctx.add_function("max", functions::max);
        ctx.add_function("startsWith", functions::starts_with);
        ctx.add_function("duration", functions::duration);
        ctx.add_function("timestamp", functions::timestamp);
        ctx.add_function("string", functions::string);
        ctx.add_function("double", functions::double);
        ctx.add_function("exists", functions::exists);
        ctx.add_function("exists_one", functions::exists_one);
        ctx.add_function("int", functions::int);
        ctx.add_function("uint", functions::uint);
        ctx
    }
}