rscel/interp/
mod.rs

1mod types;
2use crate::{types::CelByteCode, CelValueDyn};
3use std::{collections::HashMap, fmt};
4pub use types::{ByteCode, JmpWhen};
5
6use crate::{
7    context::construct_type, utils::ScopedCounter, BindContext, CelContext, CelError, CelResult,
8    CelValue, RsCelFunction, RsCelMacro,
9};
10
11use types::CelStackValue;
12
13use self::types::RsCallable;
14
15struct InterpStack<'a, 'b> {
16    stack: Vec<CelStackValue<'b>>,
17
18    ctx: &'a Interpreter<'b>,
19}
20
21impl<'a, 'b> InterpStack<'a, 'b> {
22    fn new(ctx: &'b Interpreter) -> InterpStack<'a, 'b> {
23        InterpStack {
24            stack: Vec::new(),
25            ctx,
26        }
27    }
28
29    fn push(&mut self, val: CelStackValue<'b>) {
30        self.stack.push(val);
31    }
32
33    fn push_val(&mut self, val: CelValue) {
34        self.stack.push(CelStackValue::Value(val));
35    }
36
37    fn pop(&mut self) -> CelResult<CelStackValue> {
38        match self.stack.pop() {
39            Some(stack_val) => {
40                if let CelStackValue::Value(val) = stack_val {
41                    if let CelValue::Ident(name) = val {
42                        if let Some(val) = self.ctx.get_type_by_name(&name) {
43                            return Ok(CelStackValue::Value(val.clone()));
44                        }
45
46                        if let Some(val) = self.ctx.get_param_by_name(&name) {
47                            return Ok(CelStackValue::Value(val.clone()));
48                        }
49
50                        if let Some(ctx) = self.ctx.cel {
51                            // Allow for loaded programs to run as values
52                            if let Some(prog) = ctx.get_program(&name) {
53                                return self.ctx.run_raw(prog.bytecode(), true).map(|x| x.into());
54                            }
55                        }
56
57                        Ok(CelValue::from_err(CelError::binding(&name)).into())
58                    } else {
59                        Ok(val.into())
60                    }
61                } else {
62                    Ok(stack_val)
63                }
64            }
65            None => Err(CelError::runtime("No value on stack!")),
66        }
67    }
68
69    fn pop_val(&mut self) -> CelResult<CelValue> {
70        self.pop()?.into_value()
71    }
72
73    fn pop_noresolve(&mut self) -> CelResult<CelStackValue<'b>> {
74        match self.stack.pop() {
75            Some(val) => Ok(val),
76            None => Err(CelError::runtime("No value on stack!")),
77        }
78    }
79
80    fn pop_tryresolve(&mut self) -> CelResult<CelStackValue<'b>> {
81        match self.stack.pop() {
82            Some(val) => match val.try_into()? {
83                CelValue::Ident(name) => {
84                    if let Some(val) = self.ctx.get_param_by_name(&name) {
85                        Ok(val.clone().into())
86                    } else {
87                        Ok(CelStackValue::Value(CelValue::from_ident(&name)))
88                    }
89                }
90                other => Ok(CelStackValue::Value(other.into())),
91            },
92            None => Err(CelError::runtime("No value on stack!")),
93        }
94    }
95}
96
97impl<'a, 'b> fmt::Debug for InterpStack<'a, 'b> {
98    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99        write!(f, "{:?}", self.stack)
100    }
101}
102
103pub struct Interpreter<'a> {
104    cel: Option<&'a CelContext>,
105    bindings: Option<&'a BindContext<'a>>,
106    depth: ScopedCounter,
107}
108
109impl<'a> Interpreter<'a> {
110    pub fn new(cel: &'a CelContext, bindings: &'a BindContext) -> Interpreter<'a> {
111        Interpreter {
112            cel: Some(cel),
113            bindings: Some(bindings),
114            depth: ScopedCounter::new(),
115        }
116    }
117
118    pub fn empty() -> Interpreter<'a> {
119        Interpreter {
120            cel: None,
121            bindings: None,
122            depth: ScopedCounter::new(),
123        }
124    }
125
126    pub fn add_bindings(&mut self, bindings: &'a BindContext) {
127        self.bindings = Some(bindings);
128    }
129
130    pub fn cel_copy(&self) -> Option<CelContext> {
131        self.cel.cloned()
132    }
133
134    pub fn bindings_copy(&self) -> Option<BindContext> {
135        self.bindings.cloned()
136    }
137
138    pub fn run_program(&self, name: &str) -> CelResult<CelValue> {
139        match self.cel {
140            Some(cel) => match cel.get_program(name) {
141                Some(prog) => self.run_raw(prog.bytecode(), true),
142                None => Err(CelError::binding(&name)),
143            },
144            None => Err(CelError::internal("No CEL context bound to interpreter")),
145        }
146    }
147
148    pub fn run_raw(&self, prog: &CelByteCode, resolve: bool) -> CelResult<CelValue> {
149        let mut pc: usize = 0;
150        let mut stack = InterpStack::new(self);
151
152        let count = self.depth.inc();
153
154        if count.count() > 32 {
155            return Err(CelError::runtime("Max call depth excceded"));
156        }
157
158        while pc < prog.len() {
159            let oldpc = pc;
160            pc += 1;
161            match &prog[oldpc] {
162                ByteCode::Push(val) => stack.push(val.clone().into()),
163                ByteCode::Or => {
164                    let v2 = stack.pop_val()?;
165                    let v1 = stack.pop_val()?;
166
167                    stack.push_val(v1.or(&v2))
168                }
169                ByteCode::And => {
170                    let v2 = stack.pop_val()?;
171                    let v1 = stack.pop_val()?;
172
173                    stack.push_val(v1.and(v2))
174                }
175                ByteCode::Not => {
176                    let v1 = stack.pop_val()?;
177
178                    stack.push_val(!v1);
179                }
180                ByteCode::Neg => {
181                    let v1 = stack.pop_val()?;
182
183                    stack.push_val(-v1);
184                }
185                ByteCode::Add => {
186                    let v2 = stack.pop_val()?;
187                    let v1 = stack.pop_val()?;
188
189                    stack.push_val(v1 + v2);
190                }
191                ByteCode::Sub => {
192                    let v2 = stack.pop_val()?;
193                    let v1 = stack.pop_val()?;
194
195                    stack.push_val(v1 - v2);
196                }
197                ByteCode::Mul => {
198                    let v2 = stack.pop_val()?;
199                    let v1 = stack.pop_val()?;
200
201                    stack.push_val(v1 * v2);
202                }
203                ByteCode::Div => {
204                    let v2 = stack.pop_val()?;
205                    let v1 = stack.pop_val()?;
206
207                    stack.push_val(v1 / v2);
208                }
209                ByteCode::Mod => {
210                    let v2 = stack.pop_val()?;
211                    let v1 = stack.pop_val()?;
212
213                    stack.push_val(v1 % v2);
214                }
215                ByteCode::Lt => {
216                    let v2 = stack.pop_val()?;
217                    let v1 = stack.pop_val()?;
218
219                    stack.push_val(v1.lt(v2));
220                }
221                ByteCode::Le => {
222                    let v2 = stack.pop_val()?;
223                    let v1 = stack.pop_val()?;
224
225                    stack.push_val(v1.le(v2));
226                }
227                ByteCode::Eq => {
228                    let v2 = stack.pop_val()?;
229                    let v1 = stack.pop_val()?;
230
231                    stack.push_val(CelValueDyn::eq(&v1, &v2));
232                }
233                ByteCode::Ne => {
234                    let v2 = stack.pop_val()?;
235                    let v1 = stack.pop_val()?;
236
237                    stack.push_val(v1.neq(v2));
238                }
239                ByteCode::Ge => {
240                    let v2 = stack.pop_val()?;
241                    let v1 = stack.pop_val()?;
242
243                    stack.push_val(v1.ge(v2));
244                }
245                ByteCode::Gt => {
246                    let v2 = stack.pop_val()?;
247                    let v1 = stack.pop_val()?;
248
249                    stack.push_val(v1.gt(v2));
250                }
251                ByteCode::In => {
252                    let rhs = stack.pop_val()?;
253                    let lhs = stack.pop_val()?;
254
255                    stack.push_val(lhs.in_(rhs));
256                }
257                ByteCode::Jmp(dist) => pc = pc + *dist as usize,
258                ByteCode::JmpCond {
259                    when,
260                    dist,
261                    leave_val,
262                } => {
263                    let mut v1 = stack.pop_val()?;
264                    match when {
265                        JmpWhen::True => {
266                            if cfg!(feature = "type_prop") {
267                                if v1.is_truthy() {
268                                    v1 = CelValue::true_();
269                                    pc += *dist as usize
270                                }
271                            } else if let CelValue::Err(ref _e) = v1 {
272                                // do nothing
273                            } else if let CelValue::Bool(v) = v1 {
274                                if v {
275                                    pc += *dist as usize
276                                }
277                            } else {
278                                return Err(CelError::invalid_op(&format!(
279                                    "JMP TRUE invalid on type {:?}",
280                                    v1.as_type()
281                                )));
282                            }
283                        }
284                        JmpWhen::False => {
285                            if cfg!(feature = "type_prop") {
286                                if !v1.is_truthy() {
287                                    v1 = CelValue::false_();
288                                    pc += *dist as usize
289                                }
290                            } else if let CelValue::Bool(v) = v1 {
291                                if !v {
292                                    pc += *dist as usize
293                                }
294                            } else {
295                                return Err(CelError::invalid_op(&format!(
296                                    "JMP FALSE invalid on type {:?}",
297                                    v1.as_type()
298                                )));
299                            }
300                        }
301                    };
302                    if *leave_val {
303                        stack.push_val(v1);
304                    }
305                }
306                ByteCode::MkList(size) => {
307                    let mut v = Vec::new();
308
309                    for _ in 0..*size {
310                        v.push(stack.pop_val()?)
311                    }
312
313                    v.reverse();
314                    stack.push_val(v.into());
315                }
316                ByteCode::MkDict(size) => {
317                    let mut map = HashMap::new();
318
319                    for _ in 0..*size {
320                        let key = if let CelValue::String(key) = stack.pop_val()? {
321                            key
322                        } else {
323                            return Err(CelError::value("Only strings can be used as Object keys"));
324                        };
325
326                        map.insert(key, stack.pop_val()?);
327                    }
328
329                    stack.push_val(map.into());
330                }
331                ByteCode::Index => {
332                    let index = stack.pop_val()?;
333                    let obj = stack.pop_val()?;
334
335                    stack.push_val(obj.index(index));
336                }
337                ByteCode::Access => {
338                    let index = stack.pop_noresolve()?;
339                    if let CelValue::Ident(ident) = index.as_value()? {
340                        let obj = stack.pop()?.into_value()?;
341                        match obj {
342                            CelValue::Map(ref map) => match map.get(ident.as_str()) {
343                                Some(val) => stack.push_val(val.clone()),
344                                None => match self.callable_by_name(ident.as_str()) {
345                                    Ok(callable) => stack.push(CelStackValue::BoundCall {
346                                        callable,
347                                        value: obj,
348                                    }),
349                                    Err(_) => {
350                                        stack.push(
351                                            CelValue::from_err(CelError::attribute(
352                                                "obj",
353                                                ident.as_str(),
354                                            ))
355                                            .into(),
356                                        );
357                                    }
358                                },
359                            },
360                            #[cfg(feature = "protobuf")]
361                            CelValue::Message(msg) => {
362                                let desc = msg.descriptor_dyn();
363
364                                if let Some(field) = desc.field_by_name(ident.as_str()) {
365                                    stack.push_val(
366                                        field.get_singular_field_or_default(msg.as_ref()).into(),
367                                    )
368                                } else {
369                                    return Err(CelError::attribute("msg", ident.as_str()));
370                                }
371                            }
372                            CelValue::Dyn(d) => {
373                                stack.push_val(d.access(ident.as_str()));
374                            }
375                            _ => {
376                                if let Some(bindings) = self.bindings {
377                                    if bindings.get_func(ident.as_str()).is_some()
378                                        || bindings.get_macro(ident.as_str()).is_some()
379                                    {
380                                        stack.push(CelStackValue::BoundCall {
381                                            callable: self.callable_by_name(ident.as_str())?,
382                                            value: obj,
383                                        });
384                                    } else {
385                                        stack.push(
386                                            CelValue::from_err(CelError::attribute(
387                                                "obj",
388                                                ident.as_str(),
389                                            ))
390                                            .into(),
391                                        );
392                                    }
393                                } else {
394                                    return Err(CelError::Runtime(
395                                        "Invalid state: no bindings".to_string(),
396                                    ));
397                                }
398                            }
399                        }
400                    } else {
401                        let obj_type = stack.pop()?.into_value()?.as_type();
402                        stack.push(
403                            CelValue::from_err(CelError::value(&format!(
404                                "Index operator invalid between {:?} and {:?}",
405                                index.into_value()?.as_type(),
406                                obj_type
407                            )))
408                            .into(),
409                        );
410                    }
411                }
412                ByteCode::Call(n_args) => {
413                    let mut args = Vec::new();
414
415                    for _ in 0..*n_args {
416                        args.push(stack.pop()?.into_value()?)
417                    }
418
419                    match stack.pop_noresolve()? {
420                        CelStackValue::BoundCall { callable, value } => match callable {
421                            RsCallable::Function(func) => {
422                                let arg_values = self.resolve_args(args)?;
423                                stack.push_val(func(value, arg_values));
424                            }
425                            RsCallable::Macro(macro_) => {
426                                stack.push_val(self.call_macro(&value, &args, macro_)?);
427                            }
428                        },
429                        CelStackValue::Value(value) => match value {
430                            CelValue::Ident(func_name) => {
431                                if let Some(func) = self.get_func_by_name(&func_name) {
432                                    let arg_values = self.resolve_args(args)?;
433                                    stack.push_val(func(CelValue::from_null(), arg_values));
434                                } else if let Some(macro_) = self.get_macro_by_name(&func_name) {
435                                    stack.push_val(self.call_macro(
436                                        &CelValue::from_null(),
437                                        &args,
438                                        macro_,
439                                    )?);
440                                } else if let Some(CelValue::Type(type_name)) =
441                                    self.get_type_by_name(&func_name)
442                                {
443                                    let arg_values = self.resolve_args(args)?;
444                                    stack.push_val(construct_type(type_name, arg_values));
445                                } else {
446                                    stack.push_val(CelValue::from_err(CelError::runtime(
447                                        &format!("{} is not callable", func_name),
448                                    )));
449                                }
450                            }
451                            CelValue::Type(type_name) => {
452                                let arg_values = self.resolve_args(args)?;
453                                stack.push_val(construct_type(&type_name, arg_values));
454                            }
455                            other => stack.push_val(
456                                CelValue::from_err(CelError::runtime(&format!(
457                                    "{:?} cannot be called",
458                                    other
459                                )))
460                                .into(),
461                            ),
462                        },
463                    };
464                }
465                ByteCode::FmtString(nsegments) => {
466                    let mut segments = Vec::new();
467                    for _ in 0..*nsegments {
468                        segments.push(stack.pop_val()?);
469                    }
470
471                    let mut working = String::new();
472                    for seg in segments.into_iter().rev() {
473                        if let CelValue::String(s) = seg {
474                            working.push_str(&s)
475                        } else {
476                            return Err(CelError::Runtime(
477                                "Expected string from format string specifier".to_string(),
478                            ));
479                        }
480                    }
481
482                    stack.push_val(CelValue::String(working));
483                }
484            };
485        }
486
487        if resolve {
488            match stack.pop() {
489                Ok(val) => {
490                    let cel: CelValue = val.try_into()?;
491                    cel.into_result()
492                }
493                Err(err) => Err(err),
494            }
495        } else {
496            match stack.pop_tryresolve() {
497                Ok(val) => {
498                    let cel: CelValue = val.try_into()?;
499                    cel.into_result()
500                }
501                Err(err) => Err(err),
502            }
503        }
504    }
505
506    fn call_macro(
507        &self,
508        this: &CelValue,
509        args: &Vec<CelValue>,
510        macro_: &RsCelMacro,
511    ) -> Result<CelValue, CelError> {
512        let mut v = Vec::new();
513        for arg in args.iter() {
514            if let CelValue::ByteCode(bc) = arg {
515                v.push(bc);
516            } else {
517                return Err(CelError::internal("macro args must be bytecode"));
518            }
519        }
520        let res = macro_(self, this.clone(), &v);
521        Ok(res)
522    }
523
524    fn resolve_args(&self, args: Vec<CelValue>) -> Result<Vec<CelValue>, CelError> {
525        let mut arg_values = Vec::new();
526        for arg in args.into_iter() {
527            if let CelValue::ByteCode(bc) = arg {
528                arg_values.push(self.run_raw(&bc, true)?);
529            } else {
530                arg_values.push(arg)
531            }
532        }
533        Ok(arg_values)
534    }
535
536    fn get_param_by_name(&self, name: &str) -> Option<&'a CelValue> {
537        self.bindings?.get_param(name)
538    }
539
540    fn get_func_by_name(&self, name: &str) -> Option<&'a RsCelFunction> {
541        self.bindings?.get_func(name)
542    }
543
544    fn get_macro_by_name(&self, name: &str) -> Option<&'a RsCelMacro> {
545        self.bindings?.get_macro(name)
546    }
547
548    fn get_type_by_name(&self, name: &str) -> Option<&'a CelValue> {
549        self.bindings?.get_type(name)
550    }
551
552    fn callable_by_name(&self, name: &str) -> CelResult<RsCallable> {
553        if let Some(func) = self.get_func_by_name(name) {
554            Ok(RsCallable::Function(func))
555        } else if let Some(macro_) = self.get_macro_by_name(name) {
556            Ok(RsCallable::Macro(macro_))
557        } else {
558            Err(CelError::value(&format!("{} is not callable", name)))
559        }
560    }
561}
562
#[cfg(test)]
mod test {
    use crate::{types::CelByteCode, CelValue};

    use super::{types::ByteCode, Interpreter};
    use test_case::test_case;

    // Each case pushes 4 and then 3, applies `op`, and checks the value left
    // on the stack, i.e. the result of `4 <op> 3`.
    #[test_case(ByteCode::Add, 7.into())]
    #[test_case(ByteCode::Sub, 1.into())]
    #[test_case(ByteCode::Mul, 12.into())]
    #[test_case(ByteCode::Div, 1.into())]
    #[test_case(ByteCode::Mod, 1.into())]
    #[test_case(ByteCode::Lt, false.into())]
    #[test_case(ByteCode::Le, false.into())]
    #[test_case(ByteCode::Eq, false.into())]
    #[test_case(ByteCode::Ne, true.into())]
    #[test_case(ByteCode::Ge, true.into())]
    #[test_case(ByteCode::Gt, true.into())]
    fn test_interp_ops(op: ByteCode, expected: CelValue) {
        let mut prog =
            CelByteCode::from_vec(vec![ByteCode::Push(4.into()), ByteCode::Push(3.into())]);
        prog.push(op);
        let interp = Interpreter::empty();

        // `assert_eq!` prints both values when a case fails.
        assert_eq!(interp.run_raw(&prog, true).unwrap(), expected);
    }
}