filter_expr/
expr.rs

1use crate::{Error, ctx::Context};
2
/// A filter-expression syntax tree.
///
/// Leaf variants hold a literal or a field reference; the remaining variants
/// combine sub-expressions. Trees are evaluated with `Expr::eval`.
#[derive(Clone, Debug, PartialEq)]
pub enum Expr {
    /// Reference to a variable, resolved by name through the evaluation context.
    Field(String),
    /// String literal.
    Str(String),
    /// 64-bit signed integer literal.
    I64(i64),
    /// 64-bit floating-point literal.
    F64(f64),
    /// Boolean literal.
    Bool(bool),
    /// The null literal.
    Null,
    /// Array literal whose elements are themselves expressions.
    Array(Vec<Expr>),

    /// Call of the function named by the first field with the given arguments.
    FuncCall(String, Vec<Expr>),

    /// `left > right`.
    Gt(Box<Expr>, Box<Expr>),
    /// `left < right`.
    Lt(Box<Expr>, Box<Expr>),
    /// `left >= right`.
    Ge(Box<Expr>, Box<Expr>),
    /// `left <= right`.
    Le(Box<Expr>, Box<Expr>),
    /// `left == right`.
    Eq(Box<Expr>, Box<Expr>),
    /// `left != right`.
    Ne(Box<Expr>, Box<Expr>),
    /// `left in right`; the right operand is expected to evaluate to an array.
    In(Box<Expr>, Box<Expr>),

    /// Logical AND of the operands.
    And(Vec<Expr>),
    /// Logical OR of the operands.
    Or(Vec<Expr>),
    /// Logical negation of the operand.
    Not(Box<Expr>),
}
27
28#[allow(unused)]
29impl Expr {
30    pub(crate) fn field_<T: Into<String>>(field: T) -> Self {
31        Self::Field(field.into())
32    }
33
34    pub(crate) fn str_<T: Into<String>>(value: T) -> Self {
35        Self::Str(value.into())
36    }
37
38    pub(crate) fn i64_<T: Into<i64>>(value: T) -> Self {
39        Self::I64(value.into())
40    }
41
42    pub(crate) fn f64_<T: Into<f64>>(value: T) -> Self {
43        Self::F64(value.into())
44    }
45
46    pub(crate) fn bool_<T: Into<bool>>(value: T) -> Self {
47        Self::Bool(value.into())
48    }
49
50    pub(crate) fn null_() -> Self {
51        Self::Null
52    }
53
54    pub(crate) fn array_<T: Into<Vec<Expr>>>(value: T) -> Self {
55        Self::Array(value.into())
56    }
57
58    pub(crate) fn gt_(left: Expr, right: Expr) -> Self {
59        Self::Gt(Box::new(left), Box::new(right))
60    }
61
62    pub(crate) fn lt_(left: Expr, right: Expr) -> Self {
63        Self::Lt(Box::new(left), Box::new(right))
64    }
65
66    pub(crate) fn ge_(left: Expr, right: Expr) -> Self {
67        Self::Ge(Box::new(left), Box::new(right))
68    }
69
70    pub(crate) fn le_(left: Expr, right: Expr) -> Self {
71        Self::Le(Box::new(left), Box::new(right))
72    }
73
74    pub(crate) fn eq_(left: Expr, right: Expr) -> Self {
75        Self::Eq(Box::new(left), Box::new(right))
76    }
77
78    pub(crate) fn ne_(left: Expr, right: Expr) -> Self {
79        Self::Ne(Box::new(left), Box::new(right))
80    }
81
82    pub(crate) fn in_(left: Expr, right: Expr) -> Self {
83        Self::In(Box::new(left), Box::new(right))
84    }
85
86    pub(crate) fn and_<T: Into<Vec<Expr>>>(value: T) -> Self {
87        Self::And(value.into())
88    }
89
90    pub(crate) fn or_<T: Into<Vec<Expr>>>(value: T) -> Self {
91        Self::Or(value.into())
92    }
93
94    pub(crate) fn not_(self) -> Self {
95        Self::Not(Box::new(self))
96    }
97}
98
99impl Expr {
100    pub(crate) async fn eval(&self, ctx: &dyn Context) -> Result<ExprValue, Error> {
101        match self {
102            Self::Field(field) => {
103                let value = ctx.get_var(field).await?;
104                Ok(value)
105            }
106            Self::Str(value) => Ok(ExprValue::Str(value.clone())),
107            Self::I64(value) => Ok(ExprValue::I64(value.clone())),
108            Self::F64(value) => Ok(ExprValue::F64(value.clone())),
109            Self::Bool(value) => Ok(ExprValue::Bool(value.clone())),
110            Self::Null => Ok(ExprValue::Null),
111            Self::Array(value) => self.eval_array(value, ctx).await,
112
113            Self::FuncCall(func, args) => self.eval_func_call(func, args, ctx).await,
114
115            Self::Gt(left, right) => {
116                let left_value = Box::pin(left.eval(ctx)).await?;
117                let right_value = Box::pin(right.eval(ctx)).await?;
118                match left_value.partial_cmp(&right_value) {
119                    Some(ordering) => Ok(ExprValue::Bool(ordering == std::cmp::Ordering::Greater)),
120                    None => Err(Error::TypeMismatch(
121                        format!("{:?}", left_value),
122                        format!("{:?}", right_value),
123                    )),
124                }
125            }
126            Self::Lt(left, right) => {
127                let left_value = Box::pin(left.eval(ctx)).await?;
128                let right_value = Box::pin(right.eval(ctx)).await?;
129                match left_value.partial_cmp(&right_value) {
130                    Some(ordering) => Ok(ExprValue::Bool(ordering == std::cmp::Ordering::Less)),
131                    None => Err(Error::TypeMismatch(
132                        format!("{:?}", left_value),
133                        format!("{:?}", right_value),
134                    )),
135                }
136            }
137            Self::Ge(left, right) => {
138                let left_value = Box::pin(left.eval(ctx)).await?;
139                let right_value = Box::pin(right.eval(ctx)).await?;
140                match left_value.partial_cmp(&right_value) {
141                    Some(ordering) => Ok(ExprValue::Bool(
142                        ordering == std::cmp::Ordering::Greater
143                            || ordering == std::cmp::Ordering::Equal,
144                    )),
145                    None => Err(Error::TypeMismatch(
146                        format!("{:?}", left_value),
147                        format!("{:?}", right_value),
148                    )),
149                }
150            }
151            Self::Le(left, right) => {
152                let left_value = Box::pin(left.eval(ctx)).await?;
153                let right_value = Box::pin(right.eval(ctx)).await?;
154                match left_value.partial_cmp(&right_value) {
155                    Some(ordering) => Ok(ExprValue::Bool(
156                        ordering == std::cmp::Ordering::Less
157                            || ordering == std::cmp::Ordering::Equal,
158                    )),
159                    None => Err(Error::TypeMismatch(
160                        format!("{:?}", left_value),
161                        format!("{:?}", right_value),
162                    )),
163                }
164            }
165            Self::Eq(left, right) => {
166                let left_value = Box::pin(left.eval(ctx)).await?;
167                let right_value = Box::pin(right.eval(ctx)).await?;
168                Ok(ExprValue::Bool(left_value == right_value))
169            }
170            Self::Ne(left, right) => {
171                let left_value = Box::pin(left.eval(ctx)).await?;
172                let right_value = Box::pin(right.eval(ctx)).await?;
173                Ok(ExprValue::Bool(left_value != right_value))
174            }
175            Self::In(left, right) => {
176                let left_value = Box::pin(left.eval(ctx)).await?;
177                let right_value = Box::pin(right.eval(ctx)).await?;
178                match right_value {
179                    ExprValue::Array(array) => Ok(ExprValue::Bool(array.contains(&left_value))),
180                    _ => Err(Error::TypeMismatch(
181                        format!("{:?}", right_value),
182                        format!("{:?}", left_value),
183                    )),
184                }
185            }
186
187            Self::And(exprs) => {
188                let mut result = true;
189                for expr in exprs {
190                    let value = Box::pin(expr.eval(ctx)).await?;
191                    match value {
192                        ExprValue::Bool(b) => result = result && b,
193                        _ => {
194                            return Err(Error::InvalidValue(format!(
195                                "expected bool, got {:?}",
196                                value
197                            )));
198                        }
199                    }
200                }
201                Ok(ExprValue::Bool(result))
202            }
203            Self::Or(exprs) => {
204                let mut result = false;
205                for expr in exprs {
206                    let value = Box::pin(expr.eval(ctx)).await?;
207                    match value {
208                        ExprValue::Bool(b) => result = result || b,
209                        _ => {
210                            return Err(Error::InvalidValue(format!(
211                                "expected bool, got {:?}",
212                                value
213                            )));
214                        }
215                    }
216                }
217                Ok(ExprValue::Bool(result))
218            }
219            Self::Not(expr) => {
220                let value = Box::pin(expr.eval(ctx)).await?;
221                match value {
222                    ExprValue::Bool(b) => Ok(ExprValue::Bool(!b)),
223                    _ => Err(Error::TypeMismatch(format!("{:?}", value), format!("bool"))),
224                }
225            }
226        }
227    }
228
229    pub(crate) async fn eval_array(&self, array: &[Expr], ctx: &dyn Context) -> Result<ExprValue, Error> {
230        let mut values = vec![];
231        for expr in array {
232            let value = Box::pin(expr.eval(ctx)).await?;
233            values.push(value);
234        }
235        Ok(ExprValue::Array(values))
236    }
237
238    pub(crate) async fn eval_func_call(
239        &self,
240        func: &str,
241        args: &[Expr],
242        ctx: &dyn Context,
243    ) -> Result<ExprValue, Error> {
244        // Evaluate the arguments.
245        let mut args_values = vec![];
246        for arg in args {
247            let value = Box::pin(arg.eval(ctx)).await?;
248            args_values.push(value);
249        }
250
251        // Get the function to call.
252        let func_name = func;
253        let func = ctx.get_fn(func).await;
254
255        // Call the function or call the builtin function.
256        if let Some(func) = func {
257            return func.call(ExprFnContext { args: args_values }).await;
258        } else {
259            match func_name {
260                "matches" => self.eval_builtin_func_call_matches(&args_values).await,
261                _ => Err(Error::NoSuchFunction(func_name.to_string())),
262            }
263        }
264    }
265
266    pub(crate) async fn eval_builtin_func_call_matches(
267        &self,
268        args: &[ExprValue],
269    ) -> Result<ExprValue, Error> {
270        if args.len() != 2 {
271            return Err(Error::InvalidArgumentCount {
272                expected: 2,
273                got: args.len(),
274            });
275        }
276        let text = match &args[0] {
277            ExprValue::Str(s) => s,
278            _ => {
279                return Err(Error::InvalidArgumentType {
280                    expected: "string".to_string(),
281                    got: format!("{:?}", args[0]),
282                });
283            }
284        };
285        let pattern = match &args[1] {
286            ExprValue::Str(s) => s,
287            _ => {
288                return Err(Error::InvalidArgumentType {
289                    expected: "string".to_string(),
290                    got: format!("{:?}", args[1]),
291                });
292            }
293        };
294        let pattern = regex::Regex::new(&pattern);
295        let pattern = match pattern {
296            Ok(pattern) => pattern,
297            Err(e) => {
298                return Err(Error::Internal(format!("failed to compile regex: {}", e)));
299            }
300        };
301
302        let matches = pattern.is_match(&text);
303        Ok(ExprValue::Bool(matches))
304    }
305}
306
/// A value produced by evaluating an [`Expr`]; compared by the `==`, `!=` and
/// `in` operators through `PartialEq`.
///
/// Equality is implemented by hand rather than derived so that it agrees with
/// this type's `PartialOrd` implementation, which compares `I64` and `F64`
/// numerically and arrays element by element. With this impl, `a == b` holds
/// exactly when `a.partial_cmp(&b) == Some(Ordering::Equal)`, as the
/// `PartialOrd` contract requires. In particular:
///
/// - `I64(1) == F64(1.0)`. A derived impl says `false` here even though both
///   `I64(1) <= F64(1.0)` and `I64(1) >= F64(1.0)` hold.
/// - Arrays are equal when they have the same length and pairwise-equal elements.
/// - `F64(NAN)` is not equal to anything, itself included.
/// - Any other pair of different variants is never equal.
#[derive(Debug, Clone)]
pub enum ExprValue {
    Str(String),
    I64(i64),
    F64(f64),
    Bool(bool),
    Null,

    Array(Vec<ExprValue>),
}

impl PartialEq for ExprValue {
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Str(a), Self::Str(b)) => a == b,
            (Self::I64(a), Self::I64(b)) => a == b,
            (Self::F64(a), Self::F64(b)) => a == b,
            (Self::Bool(a), Self::Bool(b)) => a == b,
            (Self::Null, Self::Null) => true,
            // Same `as f64` conversion that `partial_cmp` uses for mixed numerics.
            (Self::I64(i), Self::F64(f)) | (Self::F64(f), Self::I64(i)) => (*i as f64) == *f,
            // Element-wise, recursing through this impl.
            (Self::Array(a), Self::Array(b)) => a == b,
            _ => false,
        }
    }
}
317
318impl PartialOrd for ExprValue {
319    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
320        match (self, other) {
321            (ExprValue::Str(a), ExprValue::Str(b)) => a.partial_cmp(b),
322            (ExprValue::I64(a), ExprValue::I64(b)) => a.partial_cmp(b),
323            (ExprValue::F64(a), ExprValue::F64(b)) => a.partial_cmp(b),
324            (ExprValue::Bool(a), ExprValue::Bool(b)) => a.partial_cmp(b),
325            (ExprValue::Null, ExprValue::Null) => Some(std::cmp::Ordering::Equal),
326
327            (ExprValue::F64(a), ExprValue::I64(b)) => a.partial_cmp(&(*b as f64)),
328            (ExprValue::I64(a), ExprValue::F64(b)) => (*a as f64).partial_cmp(b),
329
330            (ExprValue::Array(a), ExprValue::Array(b)) => a.partial_cmp(b),
331
332            (ExprValue::Null, _) => Some(std::cmp::Ordering::Greater),
333            (_, ExprValue::Null) => Some(std::cmp::Ordering::Less),
334
335            _ => None, // Different types cannot be compared...
336        }
337    }
338}
339
340impl Into<ExprValue> for String {
341    fn into(self) -> ExprValue {
342        ExprValue::Str(self)
343    }
344}
345
346impl Into<ExprValue> for &str {
347    fn into(self) -> ExprValue {
348        ExprValue::Str(self.to_string())
349    }
350}
351
352impl Into<ExprValue> for i64 {
353    fn into(self) -> ExprValue {
354        ExprValue::I64(self)
355    }
356}
357
358impl Into<ExprValue> for f64 {
359    fn into(self) -> ExprValue {
360        ExprValue::F64(self)
361    }
362}
363
364impl Into<ExprValue> for bool {
365    fn into(self) -> ExprValue {
366        ExprValue::Bool(self)
367    }
368}
369
370impl<T: Into<ExprValue>> Into<ExprValue> for Vec<T> {
371    fn into(self) -> ExprValue {
372        ExprValue::Array(self.into_iter().map(|item| item.into()).collect())
373    }
374}
375
376pub struct ExprFnContext {
377    pub args: Vec<ExprValue>,
378}
379
380#[async_trait::async_trait]
381pub trait ExprFn: Send + Sync {
382    async fn call(&self, ctx: ExprFnContext) -> Result<ExprValue, Error>;
383}
384
385pub type BoxedExprFn = Box<dyn ExprFn>;
386
387/// A trait for transforming AST expressions.
388///
389/// This trait allows you to recursively transform expressions by visiting
390/// all sub-expressions. The `transform` method is called recursively on all
391/// sub-expressions, allowing you to transform the AST in a composable way.
392///
393/// # Example
394///
395/// ```rust
396/// use filter_expr::{Expr, Transform};
397///
398/// struct MyTransformer;
399///
400/// impl Transform for MyTransformer {
401///     fn transform(&mut self, expr: Expr) -> Expr {
402///         // Transform the expression before recursing
403///         match expr {
404///             Expr::Field(name) if name == "old_name" => {
405///                 Expr::Field("new_name".to_string())
406///             }
407///             other => other,
408///         }
409///     }
410/// }
411/// ```
412pub trait Transform {
413    /// Transform an expression by recursively transforming all sub-expressions.
414    fn transform(&mut self, expr: Expr) -> Expr
415    where
416        Self: Sized;
417}
418
419impl Expr {
420    /// Recursively transform an expression using the provided transformer.
421    /// 
422    /// ```rust
423    /// use filter_expr::{Expr, Transform};
424    ///
425    /// struct MyTransformer;
426    ///
427    /// impl Transform for MyTransformer {
428    ///     fn transform(&mut self, expr: Expr) -> Expr {
429    ///         // Transform the expression before recursing
430    ///         match expr {
431    ///             Expr::Field(name) if name == "old_name" => {
432    ///                 Expr::Field("new_name".to_string())
433    ///             }
434    ///             other => other,
435    ///         }
436    ///     }
437    /// }
438    ///
439    /// let expr = Expr::Field("old_name".to_string());
440    /// let mut transformer = MyTransformer;
441    /// let result = expr.transform(&mut transformer);
442    /// assert_eq!(result, Expr::Field("new_name".to_string()));
443    /// ```
444    pub fn transform<F: Transform>(self, transformer: &mut F) -> Expr {
445        let this = transformer.transform(self);
446        
447        match this {
448            Expr::Field(name) => Expr::Field(name),
449            Expr::Str(value) => Expr::Str(value),
450            Expr::I64(value) => Expr::I64(value),
451            Expr::F64(value) => Expr::F64(value),
452            Expr::Bool(value) => Expr::Bool(value),
453            Expr::Null => Expr::Null,
454            Expr::Array(value) => Expr::Array(value),
455            Expr::FuncCall(func, args) => {
456                let args = args
457                    .into_iter()
458                    .map(|arg| transformer.transform(arg))
459                    .collect();
460                Expr::FuncCall(func, args)
461            }
462            Expr::Gt(left, right) => {
463                let left = Box::new(transformer.transform(*left));
464                let right = Box::new(transformer.transform(*right));
465                Expr::Gt(left, right)
466            }
467            Expr::Lt(left, right) => {
468                let left = Box::new(transformer.transform(*left));
469                let right = Box::new(transformer.transform(*right));
470                Expr::Lt(left, right)
471            }
472            Expr::Ge(left, right) => {
473                let left = Box::new(transformer.transform(*left));
474                let right = Box::new(transformer.transform(*right));
475                Expr::Ge(left, right)
476            }
477            Expr::Le(left, right) => {
478                let left = Box::new(transformer.transform(*left));
479                let right = Box::new(transformer.transform(*right));
480                Expr::Le(left, right)
481            }
482            Expr::Eq(left, right) => {
483                let left = Box::new(transformer.transform(*left));
484                let right = Box::new(transformer.transform(*right));
485                Expr::Eq(left, right)
486            }
487            Expr::Ne(left, right) => {
488                let left = Box::new(transformer.transform(*left));
489                let right = Box::new(transformer.transform(*right));
490                Expr::Ne(left, right)
491            }
492            Expr::In(left, right) => {
493                let left = Box::new(transformer.transform(*left));
494                let right = Box::new(transformer.transform(*right));
495                Expr::In(left, right)
496            }
497            Expr::And(exprs) => {
498                let exprs = exprs
499                    .into_iter()
500                    .map(|e| transformer.transform(e))
501                    .collect();
502                Expr::And(exprs)
503            }
504            Expr::Or(exprs) => {
505                let exprs = exprs
506                    .into_iter()
507                    .map(|e| transformer.transform(e))
508                    .collect();
509                Expr::Or(exprs)
510            }
511            Expr::Not(expr) => {
512                let expr = Box::new(transformer.transform(*expr));
513                Expr::Not(expr)
514            }
515        }
516    }
517}
518
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `ExprValue::Str` from a literal.
    fn text(value: &str) -> ExprValue {
        ExprValue::Str(value.to_string())
    }

    #[test]
    fn test_expr_value_ordering() {
        let arr1 = ExprValue::Array(vec![ExprValue::I64(1), ExprValue::I64(2)]);
        let arr2 = ExprValue::Array(vec![ExprValue::I64(1), ExprValue::I64(3)]);

        // Each pair is (smaller, larger); both `<` and the mirrored `>` must hold.
        let ordered_pairs = [
            (text("a"), text("b")),
            (ExprValue::I64(1), ExprValue::I64(2)),
            (ExprValue::F64(1.0), ExprValue::F64(2.0)),
            (ExprValue::Bool(false), ExprValue::Bool(true)),
            // Integers and floats compare numerically across types.
            (ExprValue::I64(1), ExprValue::F64(2.0)),
            (ExprValue::F64(1.0), ExprValue::I64(2)),
            // Null sorts after everything else.
            (text("a"), ExprValue::Null),
            (ExprValue::I64(1), ExprValue::Null),
            (arr1.clone(), arr2),
        ];
        for (smaller, larger) in &ordered_pairs {
            assert!(smaller < larger, "expected {smaller:?} < {larger:?}");
            assert!(larger > smaller, "expected {larger:?} > {smaller:?}");
        }

        // `<=` holds between a value and an equal copy of itself.
        for value in [
            text("a"),
            ExprValue::I64(1),
            ExprValue::F64(1.0),
            ExprValue::Bool(false),
            arr1,
        ] {
            let same = value.clone();
            assert!(value <= same, "expected {value:?} <= itself");
        }

        assert_eq!(ExprValue::Null, ExprValue::Null);

        // Mismatched non-numeric, non-null kinds cannot be compared.
        let incomparable = [
            (text("a"), ExprValue::I64(1)),
            (ExprValue::I64(1), ExprValue::Bool(true)),
            (text("a"), ExprValue::Bool(false)),
            (ExprValue::Array(vec![]), ExprValue::I64(1)),
        ];
        for (lhs, rhs) in &incomparable {
            assert!(
                lhs.partial_cmp(rhs).is_none(),
                "expected {lhs:?} and {rhs:?} to be incomparable"
            );
        }
    }

    #[test]
    fn test_transform_expr() {
        // Renames every reference to `from` into `to`; all other nodes pass through.
        struct RenameField {
            from: String,
            to: String,
        }

        impl Transform for RenameField {
            fn transform(&mut self, expr: Expr) -> Expr {
                match expr {
                    Expr::Field(name) if name == self.from => Expr::Field(self.to.clone()),
                    unchanged => unchanged,
                }
            }
        }

        let mut renamer = RenameField {
            from: "old_name".to_string(),
            to: "new_name".to_string(),
        };
        let input = Expr::eq_(Expr::field_("old_name"), Expr::str_("value"));
        let expected = Expr::eq_(Expr::field_("new_name"), Expr::str_("value"));

        assert_eq!(input.transform(&mut renamer), expected);
    }
}