1mod types;
2use crate::{types::CelByteCode, CelValueDyn};
3use std::{collections::HashMap, fmt};
4pub use types::{ByteCode, JmpWhen};
5
6use crate::{
7 context::construct_type, utils::ScopedCounter, BindContext, CelContext, CelError, CelResult,
8 CelValue, RsCelFunction, RsCelMacro,
9};
10
11use types::CelStackValue;
12
13use self::types::RsCallable;
14
15struct InterpStack<'a, 'b> {
16 stack: Vec<CelStackValue<'b>>,
17
18 ctx: &'a Interpreter<'b>,
19}
20
21impl<'a, 'b> InterpStack<'a, 'b> {
22 fn new(ctx: &'b Interpreter) -> InterpStack<'a, 'b> {
23 InterpStack {
24 stack: Vec::new(),
25 ctx,
26 }
27 }
28
29 fn push(&mut self, val: CelStackValue<'b>) {
30 self.stack.push(val);
31 }
32
33 fn push_val(&mut self, val: CelValue) {
34 self.stack.push(CelStackValue::Value(val));
35 }
36
37 fn pop(&mut self) -> CelResult<CelStackValue> {
38 match self.stack.pop() {
39 Some(stack_val) => {
40 if let CelStackValue::Value(val) = stack_val {
41 if let CelValue::Ident(name) = val {
42 if let Some(val) = self.ctx.get_type_by_name(&name) {
43 return Ok(CelStackValue::Value(val.clone()));
44 }
45
46 if let Some(val) = self.ctx.get_param_by_name(&name) {
47 return Ok(CelStackValue::Value(val.clone()));
48 }
49
50 if let Some(ctx) = self.ctx.cel {
51 if let Some(prog) = ctx.get_program(&name) {
53 return self.ctx.run_raw(prog.bytecode(), true).map(|x| x.into());
54 }
55 }
56
57 Ok(CelValue::from_err(CelError::binding(&name)).into())
58 } else {
59 Ok(val.into())
60 }
61 } else {
62 Ok(stack_val)
63 }
64 }
65 None => Err(CelError::runtime("No value on stack!")),
66 }
67 }
68
69 fn pop_val(&mut self) -> CelResult<CelValue> {
70 self.pop()?.into_value()
71 }
72
73 fn pop_noresolve(&mut self) -> CelResult<CelStackValue<'b>> {
74 match self.stack.pop() {
75 Some(val) => Ok(val),
76 None => Err(CelError::runtime("No value on stack!")),
77 }
78 }
79
80 fn pop_tryresolve(&mut self) -> CelResult<CelStackValue<'b>> {
81 match self.stack.pop() {
82 Some(val) => match val.try_into()? {
83 CelValue::Ident(name) => {
84 if let Some(val) = self.ctx.get_param_by_name(&name) {
85 Ok(val.clone().into())
86 } else {
87 Ok(CelStackValue::Value(CelValue::from_ident(&name)))
88 }
89 }
90 other => Ok(CelStackValue::Value(other.into())),
91 },
92 None => Err(CelError::runtime("No value on stack!")),
93 }
94 }
95}
96
97impl<'a, 'b> fmt::Debug for InterpStack<'a, 'b> {
98 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99 write!(f, "{:?}", self.stack)
100 }
101}
102
103pub struct Interpreter<'a> {
104 cel: Option<&'a CelContext>,
105 bindings: Option<&'a BindContext<'a>>,
106 depth: ScopedCounter,
107}
108
109impl<'a> Interpreter<'a> {
110 pub fn new(cel: &'a CelContext, bindings: &'a BindContext) -> Interpreter<'a> {
111 Interpreter {
112 cel: Some(cel),
113 bindings: Some(bindings),
114 depth: ScopedCounter::new(),
115 }
116 }
117
118 pub fn empty() -> Interpreter<'a> {
119 Interpreter {
120 cel: None,
121 bindings: None,
122 depth: ScopedCounter::new(),
123 }
124 }
125
126 pub fn add_bindings(&mut self, bindings: &'a BindContext) {
127 self.bindings = Some(bindings);
128 }
129
130 pub fn cel_copy(&self) -> Option<CelContext> {
131 self.cel.cloned()
132 }
133
134 pub fn bindings_copy(&self) -> Option<BindContext> {
135 self.bindings.cloned()
136 }
137
138 pub fn run_program(&self, name: &str) -> CelResult<CelValue> {
139 match self.cel {
140 Some(cel) => match cel.get_program(name) {
141 Some(prog) => self.run_raw(prog.bytecode(), true),
142 None => Err(CelError::binding(&name)),
143 },
144 None => Err(CelError::internal("No CEL context bound to interpreter")),
145 }
146 }
147
148 pub fn run_raw(&self, prog: &CelByteCode, resolve: bool) -> CelResult<CelValue> {
149 let mut pc: usize = 0;
150 let mut stack = InterpStack::new(self);
151
152 let count = self.depth.inc();
153
154 if count.count() > 32 {
155 return Err(CelError::runtime("Max call depth excceded"));
156 }
157
158 while pc < prog.len() {
159 let oldpc = pc;
160 pc += 1;
161 match &prog[oldpc] {
162 ByteCode::Push(val) => stack.push(val.clone().into()),
163 ByteCode::Or => {
164 let v2 = stack.pop_val()?;
165 let v1 = stack.pop_val()?;
166
167 stack.push_val(v1.or(&v2))
168 }
169 ByteCode::And => {
170 let v2 = stack.pop_val()?;
171 let v1 = stack.pop_val()?;
172
173 stack.push_val(v1.and(v2))
174 }
175 ByteCode::Not => {
176 let v1 = stack.pop_val()?;
177
178 stack.push_val(!v1);
179 }
180 ByteCode::Neg => {
181 let v1 = stack.pop_val()?;
182
183 stack.push_val(-v1);
184 }
185 ByteCode::Add => {
186 let v2 = stack.pop_val()?;
187 let v1 = stack.pop_val()?;
188
189 stack.push_val(v1 + v2);
190 }
191 ByteCode::Sub => {
192 let v2 = stack.pop_val()?;
193 let v1 = stack.pop_val()?;
194
195 stack.push_val(v1 - v2);
196 }
197 ByteCode::Mul => {
198 let v2 = stack.pop_val()?;
199 let v1 = stack.pop_val()?;
200
201 stack.push_val(v1 * v2);
202 }
203 ByteCode::Div => {
204 let v2 = stack.pop_val()?;
205 let v1 = stack.pop_val()?;
206
207 stack.push_val(v1 / v2);
208 }
209 ByteCode::Mod => {
210 let v2 = stack.pop_val()?;
211 let v1 = stack.pop_val()?;
212
213 stack.push_val(v1 % v2);
214 }
215 ByteCode::Lt => {
216 let v2 = stack.pop_val()?;
217 let v1 = stack.pop_val()?;
218
219 stack.push_val(v1.lt(v2));
220 }
221 ByteCode::Le => {
222 let v2 = stack.pop_val()?;
223 let v1 = stack.pop_val()?;
224
225 stack.push_val(v1.le(v2));
226 }
227 ByteCode::Eq => {
228 let v2 = stack.pop_val()?;
229 let v1 = stack.pop_val()?;
230
231 stack.push_val(CelValueDyn::eq(&v1, &v2));
232 }
233 ByteCode::Ne => {
234 let v2 = stack.pop_val()?;
235 let v1 = stack.pop_val()?;
236
237 stack.push_val(v1.neq(v2));
238 }
239 ByteCode::Ge => {
240 let v2 = stack.pop_val()?;
241 let v1 = stack.pop_val()?;
242
243 stack.push_val(v1.ge(v2));
244 }
245 ByteCode::Gt => {
246 let v2 = stack.pop_val()?;
247 let v1 = stack.pop_val()?;
248
249 stack.push_val(v1.gt(v2));
250 }
251 ByteCode::In => {
252 let rhs = stack.pop_val()?;
253 let lhs = stack.pop_val()?;
254
255 stack.push_val(lhs.in_(rhs));
256 }
257 ByteCode::Jmp(dist) => pc = pc + *dist as usize,
258 ByteCode::JmpCond {
259 when,
260 dist,
261 leave_val,
262 } => {
263 let mut v1 = stack.pop_val()?;
264 match when {
265 JmpWhen::True => {
266 if cfg!(feature = "type_prop") {
267 if v1.is_truthy() {
268 v1 = CelValue::true_();
269 pc += *dist as usize
270 }
271 } else if let CelValue::Err(ref _e) = v1 {
272 } else if let CelValue::Bool(v) = v1 {
274 if v {
275 pc += *dist as usize
276 }
277 } else {
278 return Err(CelError::invalid_op(&format!(
279 "JMP TRUE invalid on type {:?}",
280 v1.as_type()
281 )));
282 }
283 }
284 JmpWhen::False => {
285 if cfg!(feature = "type_prop") {
286 if !v1.is_truthy() {
287 v1 = CelValue::false_();
288 pc += *dist as usize
289 }
290 } else if let CelValue::Bool(v) = v1 {
291 if !v {
292 pc += *dist as usize
293 }
294 } else {
295 return Err(CelError::invalid_op(&format!(
296 "JMP FALSE invalid on type {:?}",
297 v1.as_type()
298 )));
299 }
300 }
301 };
302 if *leave_val {
303 stack.push_val(v1);
304 }
305 }
306 ByteCode::MkList(size) => {
307 let mut v = Vec::new();
308
309 for _ in 0..*size {
310 v.push(stack.pop_val()?)
311 }
312
313 v.reverse();
314 stack.push_val(v.into());
315 }
316 ByteCode::MkDict(size) => {
317 let mut map = HashMap::new();
318
319 for _ in 0..*size {
320 let key = if let CelValue::String(key) = stack.pop_val()? {
321 key
322 } else {
323 return Err(CelError::value("Only strings can be used as Object keys"));
324 };
325
326 map.insert(key, stack.pop_val()?);
327 }
328
329 stack.push_val(map.into());
330 }
331 ByteCode::Index => {
332 let index = stack.pop_val()?;
333 let obj = stack.pop_val()?;
334
335 stack.push_val(obj.index(index));
336 }
337 ByteCode::Access => {
338 let index = stack.pop_noresolve()?;
339 if let CelValue::Ident(ident) = index.as_value()? {
340 let obj = stack.pop()?.into_value()?;
341 match obj {
342 CelValue::Map(ref map) => match map.get(ident.as_str()) {
343 Some(val) => stack.push_val(val.clone()),
344 None => match self.callable_by_name(ident.as_str()) {
345 Ok(callable) => stack.push(CelStackValue::BoundCall {
346 callable,
347 value: obj,
348 }),
349 Err(_) => {
350 stack.push(
351 CelValue::from_err(CelError::attribute(
352 "obj",
353 ident.as_str(),
354 ))
355 .into(),
356 );
357 }
358 },
359 },
360 #[cfg(feature = "protobuf")]
361 CelValue::Message(msg) => {
362 let desc = msg.descriptor_dyn();
363
364 if let Some(field) = desc.field_by_name(ident.as_str()) {
365 stack.push_val(
366 field.get_singular_field_or_default(msg.as_ref()).into(),
367 )
368 } else {
369 return Err(CelError::attribute("msg", ident.as_str()));
370 }
371 }
372 CelValue::Dyn(d) => {
373 stack.push_val(d.access(ident.as_str()));
374 }
375 _ => {
376 if let Some(bindings) = self.bindings {
377 if bindings.get_func(ident.as_str()).is_some()
378 || bindings.get_macro(ident.as_str()).is_some()
379 {
380 stack.push(CelStackValue::BoundCall {
381 callable: self.callable_by_name(ident.as_str())?,
382 value: obj,
383 });
384 } else {
385 stack.push(
386 CelValue::from_err(CelError::attribute(
387 "obj",
388 ident.as_str(),
389 ))
390 .into(),
391 );
392 }
393 } else {
394 return Err(CelError::Runtime(
395 "Invalid state: no bindings".to_string(),
396 ));
397 }
398 }
399 }
400 } else {
401 let obj_type = stack.pop()?.into_value()?.as_type();
402 stack.push(
403 CelValue::from_err(CelError::value(&format!(
404 "Index operator invalid between {:?} and {:?}",
405 index.into_value()?.as_type(),
406 obj_type
407 )))
408 .into(),
409 );
410 }
411 }
412 ByteCode::Call(n_args) => {
413 let mut args = Vec::new();
414
415 for _ in 0..*n_args {
416 args.push(stack.pop()?.into_value()?)
417 }
418
419 match stack.pop_noresolve()? {
420 CelStackValue::BoundCall { callable, value } => match callable {
421 RsCallable::Function(func) => {
422 let arg_values = self.resolve_args(args)?;
423 stack.push_val(func(value, arg_values));
424 }
425 RsCallable::Macro(macro_) => {
426 stack.push_val(self.call_macro(&value, &args, macro_)?);
427 }
428 },
429 CelStackValue::Value(value) => match value {
430 CelValue::Ident(func_name) => {
431 if let Some(func) = self.get_func_by_name(&func_name) {
432 let arg_values = self.resolve_args(args)?;
433 stack.push_val(func(CelValue::from_null(), arg_values));
434 } else if let Some(macro_) = self.get_macro_by_name(&func_name) {
435 stack.push_val(self.call_macro(
436 &CelValue::from_null(),
437 &args,
438 macro_,
439 )?);
440 } else if let Some(CelValue::Type(type_name)) =
441 self.get_type_by_name(&func_name)
442 {
443 let arg_values = self.resolve_args(args)?;
444 stack.push_val(construct_type(type_name, arg_values));
445 } else {
446 stack.push_val(CelValue::from_err(CelError::runtime(
447 &format!("{} is not callable", func_name),
448 )));
449 }
450 }
451 CelValue::Type(type_name) => {
452 let arg_values = self.resolve_args(args)?;
453 stack.push_val(construct_type(&type_name, arg_values));
454 }
455 other => stack.push_val(
456 CelValue::from_err(CelError::runtime(&format!(
457 "{:?} cannot be called",
458 other
459 )))
460 .into(),
461 ),
462 },
463 };
464 }
465 ByteCode::FmtString(nsegments) => {
466 let mut segments = Vec::new();
467 for _ in 0..*nsegments {
468 segments.push(stack.pop_val()?);
469 }
470
471 let mut working = String::new();
472 for seg in segments.into_iter().rev() {
473 if let CelValue::String(s) = seg {
474 working.push_str(&s)
475 } else {
476 return Err(CelError::Runtime(
477 "Expected string from format string specifier".to_string(),
478 ));
479 }
480 }
481
482 stack.push_val(CelValue::String(working));
483 }
484 };
485 }
486
487 if resolve {
488 match stack.pop() {
489 Ok(val) => {
490 let cel: CelValue = val.try_into()?;
491 cel.into_result()
492 }
493 Err(err) => Err(err),
494 }
495 } else {
496 match stack.pop_tryresolve() {
497 Ok(val) => {
498 let cel: CelValue = val.try_into()?;
499 cel.into_result()
500 }
501 Err(err) => Err(err),
502 }
503 }
504 }
505
506 fn call_macro(
507 &self,
508 this: &CelValue,
509 args: &Vec<CelValue>,
510 macro_: &RsCelMacro,
511 ) -> Result<CelValue, CelError> {
512 let mut v = Vec::new();
513 for arg in args.iter() {
514 if let CelValue::ByteCode(bc) = arg {
515 v.push(bc);
516 } else {
517 return Err(CelError::internal("macro args must be bytecode"));
518 }
519 }
520 let res = macro_(self, this.clone(), &v);
521 Ok(res)
522 }
523
524 fn resolve_args(&self, args: Vec<CelValue>) -> Result<Vec<CelValue>, CelError> {
525 let mut arg_values = Vec::new();
526 for arg in args.into_iter() {
527 if let CelValue::ByteCode(bc) = arg {
528 arg_values.push(self.run_raw(&bc, true)?);
529 } else {
530 arg_values.push(arg)
531 }
532 }
533 Ok(arg_values)
534 }
535
536 fn get_param_by_name(&self, name: &str) -> Option<&'a CelValue> {
537 self.bindings?.get_param(name)
538 }
539
540 fn get_func_by_name(&self, name: &str) -> Option<&'a RsCelFunction> {
541 self.bindings?.get_func(name)
542 }
543
544 fn get_macro_by_name(&self, name: &str) -> Option<&'a RsCelMacro> {
545 self.bindings?.get_macro(name)
546 }
547
548 fn get_type_by_name(&self, name: &str) -> Option<&'a CelValue> {
549 self.bindings?.get_type(name)
550 }
551
552 fn callable_by_name(&self, name: &str) -> CelResult<RsCallable> {
553 if let Some(func) = self.get_func_by_name(name) {
554 Ok(RsCallable::Function(func))
555 } else if let Some(macro_) = self.get_macro_by_name(name) {
556 Ok(RsCallable::Macro(macro_))
557 } else {
558 Err(CelError::value(&format!("{} is not callable", name)))
559 }
560 }
561}
562
#[cfg(test)]
mod test {
    use crate::{types::CelByteCode, CelValue};

    use super::{types::ByteCode, Interpreter};
    use test_case::test_case;

    /// Binary operators applied to `4 <op> 3`.
    #[test_case(ByteCode::Add, 7.into())]
    #[test_case(ByteCode::Sub, 1.into())]
    #[test_case(ByteCode::Mul, 12.into())]
    #[test_case(ByteCode::Div, 1.into())]
    #[test_case(ByteCode::Mod, 1.into())]
    #[test_case(ByteCode::Lt, false.into())]
    #[test_case(ByteCode::Le, false.into())]
    #[test_case(ByteCode::Eq, false.into())]
    #[test_case(ByteCode::Ne, true.into())]
    #[test_case(ByteCode::Ge, true.into())]
    #[test_case(ByteCode::Gt, true.into())]
    fn test_interp_ops(op: ByteCode, expected: CelValue) {
        let mut prog =
            CelByteCode::from_vec(vec![ByteCode::Push(4.into()), ByteCode::Push(3.into())]);
        prog.push(op);
        let interp = Interpreter::empty();

        // assert_eq! reports both values on failure, unlike assert!(a == b).
        assert_eq!(interp.run_raw(&prog, true).unwrap(), expected);
    }

    /// Logical operators applied to `true <op> false`.
    #[test_case(ByteCode::And, false.into())]
    #[test_case(ByteCode::Or, true.into())]
    fn test_interp_bool_ops(op: ByteCode, expected: CelValue) {
        let mut prog = CelByteCode::from_vec(vec![
            ByteCode::Push(true.into()),
            ByteCode::Push(false.into()),
        ]);
        prog.push(op);
        let interp = Interpreter::empty();

        assert_eq!(interp.run_raw(&prog, true).unwrap(), expected);
    }

    /// `MkList` must preserve the order in which elements were pushed.
    #[test]
    fn test_interp_mklist_order() {
        let prog = CelByteCode::from_vec(vec![
            ByteCode::Push(4.into()),
            ByteCode::Push(3.into()),
            ByteCode::MkList(2),
        ]);
        let interp = Interpreter::empty();
        let items: Vec<CelValue> = vec![4.into(), 3.into()];
        let expected: CelValue = items.into();

        assert_eq!(interp.run_raw(&prog, true).unwrap(), expected);
    }

    /// Popping from an empty stack is reported as an error, not a panic.
    #[test]
    fn test_interp_stack_underflow() {
        let prog = CelByteCode::from_vec(vec![ByteCode::Add]);
        let interp = Interpreter::empty();

        assert!(interp.run_raw(&prog, true).is_err());
    }
}