filter_expr_evaler/
lib.rs

1//! Evaluator for filter expressions.
2
3mod asm;
4mod asm_codegen;
5mod ast_runner;
6mod bc;
7mod bc_codegen;
8mod bc_runner;
9mod ctx;
10mod error;
11mod value;
12
13use std::collections::BTreeMap;
14use std::sync::{Arc, Mutex};
15
16use filter_expr::{Expr, FilterExpr};
17use regex::Regex;
18
19use crate::ast_runner::AstRunner;
20use crate::bc::Bytecode;
21
22pub use crate::ctx::{Context, ExprFn, ExprFnContext, SimpleContext};
23pub use crate::error::Error;
24pub use crate::value::{Value, ValueType};
25
26struct FilterExprEvalerCache {
27    /// The cached compiled bytecode.
28    ///
29    /// Used to avoid compiling the same expression multiple times.
30    pub(crate) cached_bytecode: Option<Arc<Bytecode>>,
31
32    /// The cached compiled regex.
33    ///
34    /// It is useful to avoid compiling the same regex multiple times.
35    pub(crate) cached_regex: BTreeMap<String, Arc<Regex>>,
36}
37
38pub struct FilterExprEvaler {
39    filter_expr: FilterExpr,
40
41    cache: Arc<Mutex<FilterExprEvalerCache>>,
42}
43
44impl FilterExprEvaler {
45    /// Create a new filter expression evaluator.
46    pub fn new(filter_expr: FilterExpr) -> Self {
47        Self {
48            filter_expr,
49            cache: Arc::new(Mutex::new(FilterExprEvalerCache {
50                cached_bytecode: None,
51                cached_regex: BTreeMap::new(),
52            })),
53        }
54    }
55
56    /// Evaluate the filter expression using the default runner.
57    pub async fn eval(&self, ctx: &dyn Context) -> Result<bool, Error> {
58        self.eval_by_bytecode_runner(ctx).await
59    }
60
61    /// Evaluate the filter expression in the given context.
62    pub async fn eval_by_bytecode_runner(&self, ctx: &dyn Context) -> Result<bool, Error> {
63        if let Some(expr) = self.filter_expr.expr() {
64            let value = FilterExprEvalerInner::new(self, expr)
65                .eval_by_bytecode_runner(ctx)
66                .await?;
67            match value {
68                Value::Bool(b) => Ok(b),
69                _ => Err(Error::InvalidValue(format!("{value:?}"))),
70            }
71        } else {
72            Ok(true)
73        }
74    }
75
76    /// Evaluate the filter expression using AST runner.
77    pub async fn eval_by_ast_runner(&self, ctx: &dyn Context) -> Result<bool, Error> {
78        if let Some(expr) = self.filter_expr.expr() {
79            let value = FilterExprEvalerInner::new(self, expr)
80                .eval_by_ast_runner(ctx)
81                .await?;
82            match value {
83                Value::Bool(b) => Ok(b),
84                _ => Err(Error::InvalidValue(format!("{value:?}"))),
85            }
86        } else {
87            Ok(true)
88        }
89    }
90
91    /// Evaluate the filter expression using bytecode execution.
92    pub async fn eval_by_bytecode(&self, ctx: &dyn Context) -> Result<bool, Error> {
93        if let Some(expr) = self.filter_expr.expr() {
94            let value = FilterExprEvalerInner::new(self, expr)
95                .eval_by_bytecode_runner(ctx)
96                .await?;
97            match value {
98                Value::Bool(b) => Ok(b),
99                _ => Err(Error::InvalidValue(format!("{value:?}"))),
100            }
101        } else {
102            Ok(true)
103        }
104    }
105}
106
107struct FilterExprEvalerInner<'a> {
108    evaler: &'a FilterExprEvaler,
109    expr: &'a Expr,
110}
111
112impl<'a> FilterExprEvalerInner<'a> {
113    pub fn new(evaler: &'a FilterExprEvaler, expr: &'a Expr) -> Self {
114        Self { evaler, expr }
115    }
116
117    /// Evaluate the expression using AST runner.
118    pub async fn eval_by_ast_runner(&self, ctx: &dyn Context) -> Result<Value, Error> {
119        AstRunner::new(self.expr, self.evaler.cache.clone())
120            .run(ctx)
121            .await
122    }
123
124    /// Evaluate the expression using bytecode runner.
125    pub async fn eval_by_bytecode_runner(&self, ctx: &dyn Context) -> Result<Value, Error> {
126        let bytecode = {
127            // Check if bytecode is cached, generate if not.
128            let mut cache = self
129                .evaler
130                .cache
131                .lock()
132                .map_err(|e| Error::Internal(format!("failed to lock cache: {e}")))?;
133            if cache.cached_bytecode.is_none() {
134                // Generate bytecode and cache it.
135                let asm = asm_codegen::AsmCodegen::new().codegen(self.expr);
136                let bytecode = bc_codegen::BytecodeCodegen::new().codegen(asm);
137                cache.cached_bytecode = Some(Arc::new(bytecode));
138            }
139
140            // Get the cached bytecode.
141            match cache.cached_bytecode.as_ref() {
142                Some(bytecode) => bytecode.clone(),
143                None => unreachable!("bytecode should be already cached"),
144            }
145        };
146
147        let runner = bc_runner::BytecodeRunner::new(&bytecode, self.evaler.cache.clone());
148
149        // Safety: The bytecode is generated by a internal AsmCodegen.
150        unsafe { runner.run(ctx).await }
151    }
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_parse_and_then_eval() {
        // Parse the filter-expr:
        //
        //     name = 'John' AND age > 18 AND 1 > 0
        // =====================================================================
        parse_and_do_test_cases(
            "name = 'John' AND age > 18 AND 1 > 0",
            vec![
                (simple_context! { "name": "John", "age": 19 }, true),
                (simple_context! { "name": "John", "age": 18 }, false),
            ],
        )
        .await;

        // Parse the filter-expr:
        //
        //     name = "John" AND age IN [18, 19, 20, 22] AND 1 > 0
        // =====================================================================
        parse_and_do_test_cases(
            r#"name = "John" AND age IN [18, 19, 20, 22] AND 1 > 0"#,
            vec![
                (simple_context! { "name": "John", "age": 19 }, true),
                (simple_context! { "name": "John", "age": 23 }, false),
            ],
        )
        .await;

        // Parse the filter-expr:
        //
        //     matches(name, "^J.*n$")
        // =====================================================================
        parse_and_do_test_cases(
            r#"matches(name, "^J.*n$")"#,
            vec![
                (simple_context! { "name": "John" }, true),
                (simple_context! { "name": "Jane" }, false),
            ],
        )
        .await;

        // Parse the filter-expr:
        //
        //     custom_add(a, b) = 3
        // =====================================================================
        fn with_custom_add_fn(mut ctx: SimpleContext) -> SimpleContext {
            ctx.add_fn("custom_add".to_string(), Box::new(CustomAddFn));
            ctx
        }
        parse_and_do_test_cases(
            r#"custom_add(a, b) = 3"#,
            vec![
                (with_custom_add_fn(simple_context! { "a": 1, "b": 2 }), true),
                (
                    with_custom_add_fn(simple_context! { "a": 1, "b": 3 }),
                    false,
                ),
            ],
        )
        .await;

        // Parse the filter-expr:
        //
        //     name != null
        // =====================================================================
        parse_and_do_test_cases(
            r#"name != null"#,
            vec![
                (simple_context! { "name": Value::Null }, false),
                (simple_context! { "name": "John" }, true),
            ],
        )
        .await;

        // Parse the filter-expr:
        //
        //     open > 1.5 AND age > 17.5 AND age < 18.5 AND is_peter = true
        // =====================================================================
        parse_and_do_test_cases(
            r#"open > 1.5 AND age > 17.5 AND age < 18.5 AND is_peter = true"#,
            vec![(
                simple_context! { "open": 1.6, "age": 18, "is_peter": true },
                true,
            )],
        )
        .await;

        // Parse the filter-expr:
        //
        //     name.to_uppercase() = 'JOHN'
        // =====================================================================
        parse_and_do_test_cases(
            r#"name.to_uppercase() = 'JOHN'"#,
            vec![
                (simple_context! { "name": "john" }, true),
                (simple_context! { "name": "Jane" }, false),
                (simple_context! { "name": "John" }, true),
            ],
        )
        .await;

        // Parse the filter-expr:
        //
        //     name.contains('John')
        // =====================================================================
        parse_and_do_test_cases(
            r#"name.contains('John')"#,
            vec![
                (simple_context! { "name": "John" }, true),
                (simple_context! { "name": "Jane" }, false),
                (simple_context! { "name": "The John is a good boy." }, true),
            ],
        )
        .await;

        // Parse the filter-expr:
        //
        //     type(name) = 'str'
        //     type(name) = 'null'
        //     type(foo.contains('bar')) = 'bool'
        //     type(age) = 'i64'
        //     type(open) = 'f64'
        //     type(maybe_i64_or_f64) IN ['i64', 'f64']
        // =====================================================================
        parse_and_do_test_cases(
            r#"type(name) = 'str'"#,
            vec![
                (simple_context! { "name": "John" }, true),
                (simple_context! { "name": 18 }, false),
                (simple_context! { "name": Value::Null }, false),
            ],
        )
        .await;

        parse_and_do_test_cases(
            r#"type(name) = 'null'"#,
            vec![
                (simple_context! { "name": "John" }, false),
                (simple_context! { "name": Value::Null }, true),
                (simple_context! { "name": 18 }, false),
            ],
        )
        .await;

        parse_and_do_test_cases(
            r#"type(foo.contains('bar')) = 'bool'"#,
            vec![
                (simple_context! { "foo": "foobar" }, true),
                (simple_context! { "foo": "bar and foo" }, true),
            ],
        )
        .await;

        parse_and_do_test_cases(
            r#"type(age) = 'i64'"#,
            vec![
                (simple_context! { "age": 18 }, true),
                (simple_context! { "age": 18.5 }, false),
            ],
        )
        .await;

        parse_and_do_test_cases(
            r#"type(open) = 'f64'"#,
            vec![
                (simple_context! { "open": 18 }, false),
                (simple_context! { "open": 18.5 }, true),
                (simple_context! { "open": "18" }, false),
            ],
        )
        .await;

        parse_and_do_test_cases(
            r#"type(maybe_i64_or_f64) IN ['i64', 'f64']"#,
            vec![
                (simple_context! { "maybe_i64_or_f64": 18 }, true),
                (simple_context! { "maybe_i64_or_f64": 18.5 }, true),
                (simple_context! { "maybe_i64_or_f64": "18" }, false),
            ],
        )
        .await;

        // Parse the filter-expr:
        //
        //     name.starts_with('J')
        //     name.ends_with('n')
        // =====================================================================
        parse_and_do_test_cases(
            r#"name.starts_with('J')"#,
            vec![
                (simple_context! { "name": "John" }, true),
                (simple_context! { "name": "Peterlits" }, false),
            ],
        )
        .await;

        parse_and_do_test_cases(
            r#"name.ends_with('n')"#,
            vec![
                (simple_context! { "name": "John" }, true),
                (simple_context! { "name": "Jane" }, false),
            ],
        )
        .await;
    }

    /// Parse `input` once, then check every `(context, expected)` pair through
    /// the bytecode runner, the AST runner and the default `eval` entry point.
    ///
    /// A single evaluator is reused for all cases, so the cached bytecode path
    /// is exercised as well.
    async fn parse_and_do_test_cases(input: &str, test_cases: Vec<(SimpleContext, bool)>) {
        let filter_expr =
            FilterExpr::parse(input).unwrap_or_else(|_| panic!("failed to parse: {input}"));
        let evaler = FilterExprEvaler::new(filter_expr);

        for (ctx, expected) in test_cases {
            let result = evaler.eval_by_bytecode_runner(&ctx).await;
            assert_eval_result("bytecode runner", input, &ctx, result, expected);

            let result = evaler.eval_by_ast_runner(&ctx).await;
            assert_eval_result("ast runner", input, &ctx, result, expected);

            let result = evaler.eval(&ctx).await;
            assert_eval_result("default runner", input, &ctx, result, expected);
        }
    }

    /// Assert that one runner succeeded and produced `expected`, reporting the
    /// runner name, input, context and (on failure) the evaluation error.
    fn assert_eval_result(
        runner: &str,
        input: &str,
        ctx: &SimpleContext,
        result: Result<bool, Error>,
        expected: bool,
    ) {
        let result =
            result.unwrap_or_else(|e| panic!("failed to eval by {runner}: {input}: {e:?}"));
        assert_eq!(
            result, expected,
            "{input} failed with context with {runner} {ctx:?}"
        );
    }

    /// Test-only function `custom_add(a, b)` that adds two `i64` arguments.
    struct CustomAddFn;

    #[async_trait::async_trait]
    impl ExprFn for CustomAddFn {
        async fn call(&self, ctx: ExprFnContext) -> Result<Value, Error> {
            if ctx.args.len() != 2 {
                return Err(Error::InvalidArgumentCountForFunction {
                    function: "custom_add".to_string(),
                    expected: 2,
                    got: ctx.args.len(),
                });
            }
            let a = match ctx.args[0] {
                Value::I64(a) => a,
                _ => {
                    return Err(Error::InvalidArgumentTypeForFunction {
                        function: "custom_add".to_string(),
                        index: 0,
                        expected: ValueType::I64,
                        got: ctx.args[0].typ(),
                    });
                }
            };
            let b = match ctx.args[1] {
                Value::I64(b) => b,
                _ => {
                    return Err(Error::InvalidArgumentTypeForFunction {
                        function: "custom_add".to_string(),
                        index: 1,
                        expected: ValueType::I64,
                        got: ctx.args[1].typ(),
                    });
                }
            };
            Ok(Value::i64(a + b))
        }
    }
}