filter_expr_evaler/
lib.rs

1//! Evaluator for filter expressions.
2
3mod asm;
4mod asm_codegen;
5mod ast_runner;
6mod bc;
7mod bc_codegen;
8mod bc_runner;
9mod builtin;
10mod callable;
11mod ctx;
12mod error;
13mod value;
14
15use std::collections::BTreeMap;
16use std::sync::{Arc, RwLock};
17
18use filter_expr::{Expr, FilterExpr};
19use moka::sync::Cache;
20use regex::Regex;
21
22use crate::ast_runner::AstRunner;
23use crate::bc::Bytecode;
24
25pub use crate::callable::{ArcFunction, Function, FunctionContext};
26pub use crate::callable::{ArcMethod, Method, MethodContext};
27pub use crate::ctx::{Context, SimpleContext};
28pub use crate::error::Error;
29pub use crate::value::{Value, ValueType};
30
31/// The environment for the filter expression evaluator.
32#[derive(Clone)]
33pub struct FilterExprEvalerEnv {
34    inner: Arc<RwLock<FilterExprEvalerEnvInner>>,
35}
36
37struct FilterExprEvalerEnvInner {
38    /// The cached compiled regex.
39    ///
40    /// It is useful to avoid compiling the same regex multiple times.
41    cached_regex: Cache<String, Arc<Regex>>,
42
43    /// The functions.
44    functions: BTreeMap<String, ArcFunction>,
45
46    /// The methods.
47    methods: BTreeMap<(String, ValueType), ArcMethod>,
48}
49
50impl FilterExprEvalerEnv {
51    /// Create a new environment.
52    pub(crate) fn new() -> Self {
53        // Initialize the builtin functions.
54        let mut functions: BTreeMap<String, ArcFunction> = BTreeMap::new();
55        functions.insert("matches".to_string(), Arc::new(builtin::FunctionMatches));
56        functions.insert("type".to_string(), Arc::new(builtin::FunctionType));
57
58        // Initialize the builtin methods.
59        let mut methods: BTreeMap<(String, ValueType), ArcMethod> = BTreeMap::new();
60        methods.insert(
61            ("to_uppercase".to_string(), ValueType::Str),
62            Arc::new(builtin::MethodStrToUppercase),
63        );
64        methods.insert(
65            ("to_lowercase".to_string(), ValueType::Str),
66            Arc::new(builtin::MethodStrToLowercase),
67        );
68        methods.insert(
69            ("contains".to_string(), ValueType::Str),
70            Arc::new(builtin::MethodStrContains),
71        );
72        methods.insert(
73            ("starts_with".to_string(), ValueType::Str),
74            Arc::new(builtin::MethodStrStartsWith),
75        );
76        methods.insert(
77            ("ends_with".to_string(), ValueType::Str),
78            Arc::new(builtin::MethodStrEndsWith),
79        );
80
81        let inner = Arc::new(RwLock::new(FilterExprEvalerEnvInner {
82            cached_regex: Cache::new(128),
83            functions,
84            methods,
85        }));
86
87        Self { inner }
88    }
89
90    /// Add a function to the environment.
91    pub(crate) fn add_function(&self, name: String, function: ArcFunction) -> Result<(), Error> {
92        self.inner
93            .write()
94            .map_err(|e| Error::Internal(format!("failed to lock env: {e}")))?
95            .functions
96            .insert(name, function);
97        Ok(())
98    }
99
100    /// Add a method to the environment.
101    pub(crate) fn add_method(
102        &self,
103        name: String,
104        obj_type: ValueType,
105        method: ArcMethod,
106    ) -> Result<(), Error> {
107        self.inner
108            .write()
109            .map_err(|e| Error::Internal(format!("failed to lock env: {e}")))?
110            .methods
111            .insert((name, obj_type), method);
112        Ok(())
113    }
114
115    /// Get a function from the environment.
116    pub(crate) fn get_function(&self, name: &str) -> Result<ArcFunction, Error> {
117        let inner = self
118            .inner
119            .read()
120            .map_err(|e| Error::Internal(format!("failed to lock env: {e}")))?;
121        let function = inner
122            .functions
123            .get(name)
124            .ok_or_else(|| Error::NoSuchFunction {
125                function: name.to_string(),
126            })?;
127        Ok(Arc::clone(function))
128    }
129
130    /// Get a method from the environment.
131    pub(crate) fn get_method(&self, name: &str, obj_type: ValueType) -> Result<ArcMethod, Error> {
132        let inner = self
133            .inner
134            .read()
135            .map_err(|e| Error::Internal(format!("failed to lock env: {e}")))?;
136        let method = inner
137            .methods
138            .get(&(name.to_string(), obj_type))
139            .ok_or_else(|| Error::NoSuchMethod {
140                method: name.to_string(),
141                obj_type,
142            })?;
143        Ok(Arc::clone(method))
144    }
145}
146
147impl FilterExprEvalerEnv {
148    /// Get a regex (if cached, return the cached one; otherwise, compile and
149    /// cache it).
150    pub(crate) fn get_regex(&self, pattern: &str) -> Result<Arc<Regex>, Error> {
151        let inner = self
152            .inner
153            .read()
154            .map_err(|e| Error::Internal(format!("failed to lock env: {e}")))?;
155        let cached_regex = inner.cached_regex.get(pattern);
156        if let Some(cached) = cached_regex {
157            Ok(cached)
158        } else {
159            let regex = Regex::new(pattern)
160                .map_err(|e| Error::Internal(format!("failed to compile regex: {e}")))?;
161            let regex_arc = Arc::new(regex);
162            inner
163                .cached_regex
164                .insert(pattern.to_string(), regex_arc.clone());
165            Ok(regex_arc)
166        }
167    }
168}
169
170pub struct FilterExprEvaler {
171    /// The cached compiled bytecode.
172    ///
173    /// Used to avoid compiling the same expression multiple times.
174    cached_bytecode: Cache<Expr, Arc<Bytecode>>,
175
176    /// The global environment shared by runners.
177    env: FilterExprEvalerEnv,
178}
179
180impl Default for FilterExprEvaler {
181    fn default() -> Self {
182        Self::new()
183    }
184}
185
186impl FilterExprEvaler {
187    /// Create a new filter expression evaluator.
188    pub fn new() -> Self {
189        Self {
190            cached_bytecode: Cache::new(128),
191            env: FilterExprEvalerEnv::new(),
192        }
193    }
194
195    /// Evaluate the filter expression using the default runner.
196    pub async fn eval(&self, filter_expr: &FilterExpr, ctx: &dyn Context) -> Result<bool, Error> {
197        self.eval_by_bytecode_runner(filter_expr, ctx).await
198    }
199
200    /// Evaluate the filter expression using AST runner.
201    pub async fn eval_by_ast_runner(
202        &self,
203        filter_expr: &FilterExpr,
204        ctx: &dyn Context,
205    ) -> Result<bool, Error> {
206        if let Some(expr) = filter_expr.expr() {
207            let ast_runner = AstRunner::new(expr, self.env.clone());
208            let value = ast_runner.run(ctx).await?;
209
210            match value {
211                Value::Bool(b) => Ok(b),
212                _ => Err(Error::InvalidValue(format!("{value:?} is not a bool"))),
213            }
214        } else {
215            Ok(true)
216        }
217    }
218
219    /// Evaluate the filter expression using bytecode execution.
220    pub async fn eval_by_bytecode_runner(
221        &self,
222        filter_expr: &FilterExpr,
223        ctx: &dyn Context,
224    ) -> Result<bool, Error> {
225        if let Some(expr) = filter_expr.expr() {
226            let bytecode = {
227                let cached_bytecode = self.cached_bytecode.get(expr);
228
229                // Check if bytecode is cached, generate if not.
230                if let Some(cached) = cached_bytecode {
231                    cached
232                } else {
233                    // Generate bytecode and cache it.
234                    let asm = asm_codegen::AsmCodegen::new().codegen(expr);
235                    let bytecode = bc_codegen::BytecodeCodegen::new().codegen(asm);
236                    let bytecode_arc = Arc::new(bytecode);
237
238                    self.cached_bytecode
239                        .insert(expr.clone(), bytecode_arc.clone());
240                    bytecode_arc
241                }
242            };
243
244            let runner = bc_runner::BytecodeRunner::new(&bytecode, self.env.clone());
245            let value = unsafe { runner.run(ctx).await }?;
246
247            match value {
248                Value::Bool(b) => Ok(b),
249                _ => Err(Error::InvalidValue(format!("{value:?} is not a bool"))),
250            }
251        } else {
252            Ok(true)
253        }
254    }
255
256    pub fn add_function(&self, name: String, function: ArcFunction) -> Result<(), Error> {
257        self.env.add_function(name, function)
258    }
259
260    pub fn add_method(
261        &self,
262        name: String,
263        obj_type: ValueType,
264        method: ArcMethod,
265    ) -> Result<(), Error> {
266        self.env.add_method(name, obj_type, method)
267    }
268}
269
#[cfg(test)]
mod tests {
    // `super::*` already brings in `Function`, `FunctionContext`, `Arc`, etc.
    use super::*;

    /// Parse a range of filter expressions and check that the bytecode runner
    /// and the AST runner both produce the expected verdict for each context.
    #[tokio::test]
    async fn test_parse_and_then_eval() {
        let evaler = FilterExprEvaler::new();
        evaler
            .add_function("custom_add".to_string(), Arc::new(CustomAddFn))
            .unwrap();

        macro_rules! parse_and_do_test_cases {
            ($input:expr, $test_cases:expr $(,)?) => {
                parse_and_do_test_cases(&evaler, $input, $test_cases).await
            };
        }

        // Parse the filter-expr:
        //
        //     name = 'John' AND age > 18 AND 1 > 0
        // =====================================================================
        parse_and_do_test_cases!(
            "name = 'John' AND age > 18 AND 1 > 0",
            &[
                (simple_context! { "name": "John", "age": 19 }, true),
                (simple_context! { "name": "John", "age": 18 }, false),
            ],
        );

        // Parse the filter-expr:
        //
        //     name = "John" AND age IN [18, 19, 20, 22] AND 1 > 0
        // =====================================================================
        parse_and_do_test_cases!(
            r#"name = "John" AND age IN [18, 19, 20, 22] AND 1 > 0"#,
            &[
                (simple_context! { "name": "John", "age": 19 }, true),
                (simple_context! { "name": "John", "age": 23 }, false),
            ],
        );

        // Parse the filter-expr:
        //
        //     matches(name, "^J.*n$")
        // =====================================================================
        parse_and_do_test_cases!(
            r#"matches(name, "^J.*n$")"#,
            &[
                (simple_context! { "name": "John" }, true),
                (simple_context! { "name": "Jane" }, false),
            ],
        );

        // Parse the filter-expr:
        //
        //     custom_add(a, b) = 3
        // =====================================================================
        parse_and_do_test_cases!(
            r#"custom_add(a, b) = 3"#,
            &[
                (simple_context! { "a": 1, "b": 2 }, true),
                (simple_context! { "a": 1, "b": 3 }, false),
            ],
        );

        // Parse the filter-expr:
        //
        //     name != null
        // =====================================================================
        parse_and_do_test_cases!(
            r#"name != null"#,
            &[
                (simple_context! { "name": Value::Null }, false),
                (simple_context! { "name": "John" }, true),
            ],
        );

        // Parse the filter-expr:
        //
        //     open > 1.5 AND age > 17.5 AND age < 18.5 AND is_peter = true
        // =====================================================================
        parse_and_do_test_cases!(
            r#"open > 1.5 AND age > 17.5 AND age < 18.5 AND is_peter = true"#,
            &[(
                simple_context! { "open": 1.6, "age": 18, "is_peter": true },
                true,
            )],
        );

        // Parse the filter-expr:
        //
        //     name.to_uppercase() = 'JOHN'
        // =====================================================================
        parse_and_do_test_cases!(
            r#"name.to_uppercase() = 'JOHN'"#,
            &[
                (simple_context! { "name": "john" }, true),
                (simple_context! { "name": "Jane" }, false),
                (simple_context! { "name": "John" }, true),
            ],
        );

        // Parse the filter-expr:
        //
        //     name.contains('John')
        // =====================================================================
        parse_and_do_test_cases!(
            r#"name.contains('John')"#,
            &[
                (simple_context! { "name": "John" }, true),
                (simple_context! { "name": "Jane" }, false),
                (simple_context! { "name": "The John is a good boy." }, true),
            ],
        );

        // Parse the filter-expr:
        //
        //     type(name) = 'str'
        //     type(name) = 'null'
        //     type(foo.contains('bar')) = 'bool'
        //     type(age) = 'i64'
        //     type(open) = 'f64'
        //     type(maybe_i64_or_f64) IN ['i64', 'f64']
        // =====================================================================
        parse_and_do_test_cases!(
            r#"type(name) = 'str'"#,
            &[
                (simple_context! { "name": "John" }, true),
                (simple_context! { "name": 18 }, false),
                (simple_context! { "name": Value::Null }, false),
            ],
        );

        parse_and_do_test_cases!(
            r#"type(name) = 'null'"#,
            &[
                (simple_context! { "name": "John" }, false),
                (simple_context! { "name": Value::Null }, true),
                (simple_context! { "name": 18 }, false),
            ],
        );

        parse_and_do_test_cases!(
            r#"type(foo.contains('bar')) = 'bool'"#,
            &[
                (simple_context! { "foo": "foobar" }, true),
                (simple_context! { "foo": "bar and foo" }, true),
            ],
        );

        parse_and_do_test_cases!(
            r#"type(age) = 'i64'"#,
            &[
                (simple_context! { "age": 18 }, true),
                (simple_context! { "age": 18.5 }, false),
            ],
        );

        parse_and_do_test_cases!(
            r#"type(open) = 'f64'"#,
            &[
                (simple_context! { "open": 18 }, false),
                (simple_context! { "open": 18.5 }, true),
                (simple_context! { "open": "18" }, false),
            ],
        );

        parse_and_do_test_cases!(
            r#"type(maybe_i64_or_f64) IN ['i64', 'f64']"#,
            &[
                (simple_context! { "maybe_i64_or_f64": 18 }, true),
                (simple_context! { "maybe_i64_or_f64": 18.5 }, true),
                (simple_context! { "maybe_i64_or_f64": "18" }, false),
            ],
        );

        // Parse the filter-expr:
        //
        //     name.starts_with('J')
        //     name.ends_with('n')
        // =====================================================================
        parse_and_do_test_cases!(
            r#"name.starts_with('J')"#,
            &[
                (simple_context! { "name": "John" }, true),
                (simple_context! { "name": "Peterlits" }, false),
            ],
        );

        parse_and_do_test_cases!(
            r#"name.ends_with('n')"#,
            &[
                (simple_context! { "name": "John" }, true),
                (simple_context! { "name": "Jane" }, false),
            ],
        );
    }

    /// Parse `input` once, then assert that both runners yield the expected
    /// verdict for every `(context, expected)` pair.
    async fn parse_and_do_test_cases(
        evaler: &FilterExprEvaler,
        input: &str,
        test_cases: &[(SimpleContext, bool)],
    ) {
        let filter_expr =
            FilterExpr::parse(input).unwrap_or_else(|_| panic!("failed to parse: {input}"));

        for (ctx, expected) in test_cases {
            let result = evaler
                .eval_by_bytecode_runner(&filter_expr, ctx)
                .await
                .unwrap_or_else(|e| {
                    panic!("failed to eval by bytecode runner (input={input}): {e}")
                });
            assert_eq!(
                result, *expected,
                "{input} failed with context with bytecode runner {ctx:?}"
            );
            let result = evaler
                .eval_by_ast_runner(&filter_expr, ctx)
                .await
                .unwrap_or_else(|e| panic!("failed to eval by ast runner (input={input}): {e}"));
            assert_eq!(
                result, *expected,
                "{input} failed with context with ast runner {ctx:?}"
            );
        }
    }

    /// Test-only function `custom_add(a, b)` that adds two `i64` arguments.
    struct CustomAddFn;

    #[async_trait::async_trait]
    impl Function for CustomAddFn {
        async fn call(&self, ctx: FunctionContext<'_, '_>) -> Result<Value, Error> {
            if ctx.args.len() != 2 {
                return Err(Error::InvalidArgumentCountForFunction {
                    function: "custom_add".to_string(),
                    expected: 2,
                    got: ctx.args.len(),
                });
            }
            let a = match ctx.args[0] {
                Value::I64(a) => a,
                _ => {
                    return Err(Error::InvalidArgumentTypeForFunction {
                        function: "custom_add".to_string(),
                        index: 0,
                        expected: ValueType::I64,
                        got: ctx.args[0].typ(),
                    });
                }
            };
            let b = match ctx.args[1] {
                Value::I64(b) => b,
                _ => {
                    return Err(Error::InvalidArgumentTypeForFunction {
                        function: "custom_add".to_string(),
                        index: 1,
                        expected: ValueType::I64,
                        got: ctx.args[1].typ(),
                    });
                }
            };
            Ok(Value::i64(a + b))
        }
    }
}