Skip to main content

filter_expr/
expr.rs

1use std::cmp::Ordering;
2use std::hash::{Hash, Hasher};
3
4use crate::{Error, Transform, TransformContext, TransformResult};
5
/// The expression.
///
/// It is an AST of the filter expression. Every node is either a leaf (a field
/// reference or a literal) or an operator / call whose operands are themselves
/// `Expr` nodes.
///
/// `PartialEq`, `Eq`, `PartialOrd`, `Ord` and `Hash` are implemented by hand
/// instead of derived because the `F64` payload (`f64`) implements none of
/// `Eq`, `Ord` or `Hash`.
#[derive(Debug, Clone)]
pub enum Expr {
    /// A reference to a field by name.
    Field(String),
    /// Access of the named field on the value of the boxed expression
    /// (`obj.field`).
    FieldAccess(Box<Expr>, String),

    /// A string literal.
    Str(String),
    /// A 64-bit signed integer literal.
    I64(i64),
    /// A 64-bit floating-point literal.
    F64(f64),
    /// A boolean literal.
    Bool(bool),
    /// The `null` literal.
    Null,

    /// An array whose elements are arbitrary expressions.
    Array(Vec<Expr>),

    /// A call of the named function with the given arguments.
    FuncCall(String, Vec<Expr>),
    /// A method call stored as `(method_name, receiver, arguments)`.
    ///
    /// Note the field order: the method name comes first, whereas
    /// `Expr::method_call_` takes the receiver first.
    MethodCall(String, Box<Expr>, Vec<Expr>),

    /// `left > right`.
    Gt(Box<Expr>, Box<Expr>),
    /// `left < right`.
    Lt(Box<Expr>, Box<Expr>),
    /// `left >= right`.
    Ge(Box<Expr>, Box<Expr>),
    /// `left <= right`.
    Le(Box<Expr>, Box<Expr>),
    /// `left == right`.
    Eq(Box<Expr>, Box<Expr>),
    /// `left != right`.
    Ne(Box<Expr>, Box<Expr>),
    /// `left in right`.
    In(Box<Expr>, Box<Expr>),

    /// Logical AND of all operands.
    And(Vec<Expr>),
    /// Logical OR of all operands.
    Or(Vec<Expr>),
    /// Logical NOT of the operand.
    Not(Box<Expr>),
}
37
38impl PartialOrd for Expr {
39    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
40        Some(self.cmp(other))
41    }
42}
43
44impl Ord for Expr {
45    fn cmp(&self, other: &Self) -> Ordering {
46        use Expr::*;
47        use ordered_float::OrderedFloat;
48
49        // Helper function to get variant order.
50        fn variant_order(expr: &Expr) -> u16 {
51            use Expr::*;
52            match expr {
53                Field(..) => 100,
54                FieldAccess(..) => 101,
55
56                Str(..) => 200,
57                I64(..) => 201,
58                F64(..) => 202,
59                Bool(..) => 203,
60                Null => 204,
61
62                Array(..) => 300,
63
64                FuncCall(..) => 400,
65                MethodCall(..) => 401,
66
67                Gt(..) => 500,
68                Lt(..) => 501,
69                Ge(..) => 502,
70                Le(..) => 503,
71                Eq(..) => 504,
72                Ne(..) => 505,
73                In(..) => 506,
74
75                And(..) => 600,
76                Or(..) => 601,
77                Not(..) => 602,
78            }
79        }
80
81        // First compare by variant order.
82        match variant_order(self).cmp(&variant_order(other)) {
83            Ordering::Equal => {
84                // Same variant, compare contents.
85                match (self, other) {
86                    (Field(a), Field(b)) => a.cmp(b),
87                    (FieldAccess(o1, f1), FieldAccess(o2, f2)) => match o1.cmp(o2) {
88                        Ordering::Equal => f1.cmp(f2),
89                        other => other,
90                    },
91
92                    (Str(a), Str(b)) => a.cmp(b),
93                    (I64(a), I64(b)) => a.cmp(b),
94                    (F64(a), F64(b)) => OrderedFloat(*a).cmp(&OrderedFloat(*b)),
95                    (Bool(a), Bool(b)) => a.cmp(b),
96                    (Null, Null) => Ordering::Equal,
97
98                    (Array(a), Array(b)) => a.cmp(b),
99
100                    (FuncCall(f1, a1), FuncCall(f2, a2)) => match f1.cmp(f2) {
101                        Ordering::Equal => a1.cmp(a2),
102                        other => other,
103                    },
104                    (MethodCall(m1, o1, a1), MethodCall(m2, o2, a2)) => match m1.cmp(m2) {
105                        Ordering::Equal => match o1.cmp(o2) {
106                            Ordering::Equal => a1.cmp(a2),
107                            other => other,
108                        },
109                        other => other,
110                    },
111
112                    (Gt(l1, r1), Gt(l2, r2)) => match l1.cmp(l2) {
113                        Ordering::Equal => r1.cmp(r2),
114                        other => other,
115                    },
116                    (Lt(l1, r1), Lt(l2, r2)) => match l1.cmp(l2) {
117                        Ordering::Equal => r1.cmp(r2),
118                        other => other,
119                    },
120                    (Ge(l1, r1), Ge(l2, r2)) => match l1.cmp(l2) {
121                        Ordering::Equal => r1.cmp(r2),
122                        other => other,
123                    },
124                    (Le(l1, r1), Le(l2, r2)) => match l1.cmp(l2) {
125                        Ordering::Equal => r1.cmp(r2),
126                        other => other,
127                    },
128                    (Eq(l1, r1), Eq(l2, r2)) => match l1.cmp(l2) {
129                        Ordering::Equal => r1.cmp(r2),
130                        other => other,
131                    },
132                    (Ne(l1, r1), Ne(l2, r2)) => match l1.cmp(l2) {
133                        Ordering::Equal => r1.cmp(r2),
134                        other => other,
135                    },
136                    (In(l1, r1), In(l2, r2)) => match l1.cmp(l2) {
137                        Ordering::Equal => r1.cmp(r2),
138                        other => other,
139                    },
140
141                    (And(a), And(b)) => a.cmp(b),
142                    (Or(a), Or(b)) => a.cmp(b),
143                    (Not(a), Not(b)) => a.cmp(b),
144
145                    _ => unreachable!(),
146                }
147            }
148            other => other,
149        }
150    }
151}
152
153impl PartialEq for Expr {
154    fn eq(&self, other: &Self) -> bool {
155        self.cmp(other) == Ordering::Equal
156    }
157}
158
159impl Eq for Expr {}
160
161impl Hash for Expr {
162    fn hash<H: Hasher>(&self, state: &mut H) {
163        use Expr::*;
164        use ordered_float::OrderedFloat;
165
166        // Hash the discriminant first.
167        std::mem::discriminant(self).hash(state);
168
169        // Then hash the contents.
170        match self {
171            Field(s) => s.hash(state),
172            FieldAccess(obj, field) => {
173                obj.hash(state);
174                field.hash(state);
175            }
176
177            Str(s) => s.hash(state),
178            I64(i) => i.hash(state),
179            F64(f) => OrderedFloat(*f).hash(state),
180            Bool(b) => b.hash(state),
181            Null => {},
182            Array(v) => v.hash(state),
183            FuncCall(name, args) => {
184                name.hash(state);
185                args.hash(state);
186            }
187            MethodCall(method, obj, args) => {
188                method.hash(state);
189                obj.hash(state);
190                args.hash(state);
191            }
192            Gt(l, r) => {
193                l.hash(state);
194                r.hash(state);
195            }
196            Lt(l, r) => {
197                l.hash(state);
198                r.hash(state);
199            }
200            Ge(l, r) => {
201                l.hash(state);
202                r.hash(state);
203            }
204            Le(l, r) => {
205                l.hash(state);
206                r.hash(state);
207            }
208            Eq(l, r) => {
209                l.hash(state);
210                r.hash(state);
211            }
212            Ne(l, r) => {
213                l.hash(state);
214                r.hash(state);
215            }
216            In(l, r) => {
217                l.hash(state);
218                r.hash(state);
219            }
220            And(v) => v.hash(state),
221            Or(v) => v.hash(state),
222            Not(e) => e.hash(state),
223        }
224    }
225}
226
227impl Expr {
228    pub fn field_<T: Into<String>>(field: T) -> Self {
229        Self::Field(field.into())
230    }
231
232    pub fn field_access_(obj: Expr, field: impl Into<String>) -> Self {
233        Self::FieldAccess(Box::new(obj), field.into())
234    }
235
236    pub fn str_<T: Into<String>>(value: T) -> Self {
237        Self::Str(value.into())
238    }
239
240    pub fn i64_<T: Into<i64>>(value: T) -> Self {
241        Self::I64(value.into())
242    }
243
244    pub fn f64_<T: Into<f64>>(value: T) -> Self {
245        Self::F64(value.into())
246    }
247
248    pub fn bool_<T: Into<bool>>(value: T) -> Self {
249        Self::Bool(value.into())
250    }
251
252    pub fn null_() -> Self {
253        Self::Null
254    }
255
256    pub fn array_<T: Into<Vec<Expr>>>(value: T) -> Self {
257        Self::Array(value.into())
258    }
259
260    pub fn func_call_(func: impl Into<String>, args: Vec<Expr>) -> Self {
261        Self::FuncCall(func.into(), args)
262    }
263
264    pub fn method_call_(obj: Expr, method: impl Into<String>, args: Vec<Expr>) -> Self {
265        Self::MethodCall(method.into(), Box::new(obj), args)
266    }
267
268    pub fn gt_(left: Expr, right: Expr) -> Self {
269        Self::Gt(Box::new(left), Box::new(right))
270    }
271
272    pub fn lt_(left: Expr, right: Expr) -> Self {
273        Self::Lt(Box::new(left), Box::new(right))
274    }
275
276    pub fn ge_(left: Expr, right: Expr) -> Self {
277        Self::Ge(Box::new(left), Box::new(right))
278    }
279
280    pub fn le_(left: Expr, right: Expr) -> Self {
281        Self::Le(Box::new(left), Box::new(right))
282    }
283
284    pub fn eq_(left: Expr, right: Expr) -> Self {
285        Self::Eq(Box::new(left), Box::new(right))
286    }
287
288    pub fn ne_(left: Expr, right: Expr) -> Self {
289        Self::Ne(Box::new(left), Box::new(right))
290    }
291
292    pub fn in_(left: Expr, right: Expr) -> Self {
293        Self::In(Box::new(left), Box::new(right))
294    }
295
296    pub fn and_<T: Into<Vec<Expr>>>(value: T) -> Self {
297        Self::And(value.into())
298    }
299
300    pub fn or_<T: Into<Vec<Expr>>>(value: T) -> Self {
301        Self::Or(value.into())
302    }
303
304    pub fn not_(self) -> Self {
305        Self::Not(Box::new(self))
306    }
307}
308
309impl Expr {
310    /// Recursively transform an expression using the provided transformer.
311    ///
312    /// ```rust
313    /// use filter_expr::{Expr, Transform};
314    /// use async_trait::async_trait;
315    ///
316    /// struct MyTransformer;
317    ///
318    /// #[async_trait]
319    /// impl Transform for MyTransformer {
320    ///     async fn transform(&mut self, expr: Expr) -> Result<Expr, filter_expr::Error> {
321    ///         // Transform the expression before recursing
322    ///         Ok(match expr {
323    ///             Expr::Field(name) if name == "old_name" => {
324    ///                 Expr::Field("new_name".to_string())
325    ///             }
326    ///             other => other,
327    ///         })
328    ///     }
329    /// }
330    ///
331    /// # #[tokio::main]
332    /// # async fn main() {
333    /// let expr = Expr::Field("old_name".to_string());
334    /// let mut transformer = MyTransformer;
335    /// let result = expr.transform(&mut transformer).await.unwrap();
336    /// assert_eq!(result, Expr::Field("new_name".to_string()));
337    /// # }
338    /// ```
339    pub async fn transform<F: Transform>(self, transformer: &mut F) -> Result<Expr, Error> {
340        let ctx = TransformContext { depth: 0 };
341
342        return Self::transform_expr(transformer, self, ctx).await;
343    }
344
345    async fn transform_expr<F: Transform>(transformer: &mut F, expr: Expr, ctx: TransformContext) -> Result<Expr, Error> {
346        let this = transformer.transform(expr, ctx.clone()).await;
347
348        match this {
349            TransformResult::Continue(expr) => {
350                return Box::pin(Self::transform_children(transformer, expr, ctx)).await;
351            }
352            TransformResult::Stop(expr) => {
353                return Ok(expr);
354            }
355            TransformResult::Err(err) => {
356                return Err(Error::Transform(err));
357            }
358        }
359    }
360
361    async fn transform_children<F: Transform>(transformer: &mut F, expr: Expr, mut ctx: TransformContext) -> Result<Expr, Error> {
362        ctx.depth += 1;
363
364        Ok(match expr {
365            // Do nothing if the expression have no children.
366            Expr::Field(name) => Expr::Field(name),
367            Expr::FieldAccess(obj, field) => {
368                let obj = Box::new(Self::transform_expr(transformer, *obj, ctx.clone()).await?);
369                Expr::FieldAccess(obj, field)
370            }
371
372            Expr::Str(value) => Expr::Str(value),
373            Expr::I64(value) => Expr::I64(value),
374            Expr::F64(value) => Expr::F64(value),
375            Expr::Bool(value) => Expr::Bool(value),
376            Expr::Null => Expr::Null,
377            Expr::Array(value) => Expr::Array(value),
378
379            Expr::FuncCall(func, args) => {
380                let mut transformed_args = Vec::new();
381                for arg in args {
382                    transformed_args.push(Self::transform_expr(transformer, arg, ctx.clone()).await?);
383                }
384                Expr::FuncCall(func, transformed_args)
385            }
386            Expr::MethodCall(method, obj, args) => {
387                let obj = Box::new(Self::transform_expr(transformer, *obj, ctx.clone()).await?);
388                let mut transformed_args = Vec::new();
389                for arg in args {
390                    transformed_args.push(Self::transform_expr(transformer, arg, ctx.clone()).await?);
391                }
392                Expr::MethodCall(method, obj, transformed_args)
393            }
394
395            Expr::Gt(left, right) => {
396                let left = Box::new(Self::transform_expr(transformer, *left, ctx.clone()).await?);
397                let right = Box::new(Self::transform_expr(transformer, *right, ctx).await?);
398                Expr::Gt(left, right)
399            }
400            Expr::Lt(left, right) => {
401                let left = Box::new(Self::transform_expr(transformer, *left, ctx.clone()).await?);
402                let right = Box::new(Self::transform_expr(transformer, *right, ctx).await?);
403                Expr::Lt(left, right)
404            }
405            Expr::Ge(left, right) => {
406                let left = Box::new(Self::transform_expr(transformer, *left, ctx.clone()).await?);
407                let right = Box::new(Self::transform_expr(transformer, *right, ctx).await?);
408                Expr::Ge(left, right)
409            }
410            Expr::Le(left, right) => {
411                let left = Box::new(Self::transform_expr(transformer, *left, ctx.clone()).await?);
412                let right = Box::new(Self::transform_expr(transformer, *right, ctx).await?);
413                Expr::Le(left, right)
414            }
415            Expr::Eq(left, right) => {
416                let left = Box::new(Self::transform_expr(transformer, *left, ctx.clone()).await?);
417                let right = Box::new(Self::transform_expr(transformer, *right, ctx).await?);
418                Expr::Eq(left, right)
419            }
420            Expr::Ne(left, right) => {
421                let left = Box::new(Self::transform_expr(transformer, *left, ctx.clone()).await?);
422                let right = Box::new(Self::transform_expr(transformer, *right, ctx).await?);
423                Expr::Ne(left, right)
424            }
425            Expr::In(left, right) => {
426                let left = Box::new(Self::transform_expr(transformer, *left, ctx.clone()).await?);
427                let right = Box::new(Self::transform_expr(transformer, *right, ctx).await?);
428                Expr::In(left, right)
429            }
430            Expr::And(exprs) => {
431                let mut transformed_exprs = Vec::new();
432                for e in exprs {
433                    transformed_exprs.push(Self::transform_expr(transformer, e, ctx.clone()).await?);
434                }
435                Expr::And(transformed_exprs)
436            }
437            Expr::Or(exprs) => {
438                let mut transformed_exprs = Vec::new();
439                for e in exprs {
440                    transformed_exprs.push(Self::transform_expr(transformer, e, ctx.clone()).await?);
441                }
442                Expr::Or(transformed_exprs)
443            }
444            Expr::Not(expr) => {
445                let expr = Box::new(Self::transform_expr(transformer, *expr, ctx).await?);
446                Expr::Not(expr)
447            }
448        })
449    }
450}
451
452impl Expr {
453    /// Optimize the expression by applying constant folding and simplification
454    /// rules.
455    ///
456    /// Examples:
457    /// 
458    /// - `true AND true` → `true`
459    /// - `true AND false` → `false`
460    /// - `NOT NOT expr` → `expr`
461    /// - `1 > 2` → `false`
462    pub fn optimize(self) -> Self {
463        use Expr::*;
464
465        match self {
466            // Leaf nodes - no optimization needed.
467            Field(_) | Str(_) | I64(_) | F64(_) | Bool(_) | Null => self,
468
469            // Field access - optimize the object.
470            FieldAccess(obj, field) => {
471                FieldAccess(Box::new(obj.optimize()), field)
472            }
473
474            // Array - optimize all elements.
475            Array(elements) => {
476                Array(elements.into_iter().map(|e| e.optimize()).collect())
477            }
478
479            // Function call - optimize all arguments.
480            FuncCall(func, args) => {
481                FuncCall(func, args.into_iter().map(|a| a.optimize()).collect())
482            }
483
484            // Method call - optimize object and arguments.
485            MethodCall(method, obj, args) => {
486                MethodCall(method, Box::new(obj.optimize()), args.into_iter().map(|a| a.optimize()).collect())
487            }
488
489            // Comparison operators - optimize and fold constants.
490            Gt(left, right) => {
491                let left = left.optimize();
492                let right = right.optimize();
493                match (&left, &right) {
494                    (I64(a), I64(b)) => Bool(*a > *b),
495                    (F64(a), F64(b)) => Bool(*a > *b),
496                    (Str(a), Str(b)) => Bool(a > b),
497                    _ => Gt(Box::new(left), Box::new(right)),
498                }
499            }
500            Lt(left, right) => {
501                let left = left.optimize();
502                let right = right.optimize();
503                match (&left, &right) {
504                    (I64(a), I64(b)) => Bool(*a < *b),
505                    (F64(a), F64(b)) => Bool(*a < *b),
506                    (Str(a), Str(b)) => Bool(a < b),
507                    _ => Lt(Box::new(left), Box::new(right)),
508                }
509            }
510            Ge(left, right) => {
511                let left = left.optimize();
512                let right = right.optimize();
513                match (&left, &right) {
514                    (I64(a), I64(b)) => Bool(*a >= *b),
515                    (F64(a), F64(b)) => Bool(*a >= *b),
516                    (Str(a), Str(b)) => Bool(a >= b),
517                    _ => Ge(Box::new(left), Box::new(right)),
518                }
519            }
520            Le(left, right) => {
521                let left = left.optimize();
522                let right = right.optimize();
523                match (&left, &right) {
524                    (I64(a), I64(b)) => Bool(*a <= *b),
525                    (F64(a), F64(b)) => Bool(*a <= *b),
526                    (Str(a), Str(b)) => Bool(a <= b),
527                    _ => Le(Box::new(left), Box::new(right)),
528                }
529            }
530            Eq(left, right) => {
531                let left = left.optimize();
532                let right = right.optimize();
533                match (&left, &right) {
534                    (I64(a), I64(b)) => Bool(*a == *b),
535                    (F64(a), F64(b)) => Bool(*a == *b),
536                    (Str(a), Str(b)) => Bool(a == b),
537                    (Bool(a), Bool(b)) => Bool(*a == *b),
538                    (Null, Null) => Bool(true),
539                    _ => Eq(Box::new(left), Box::new(right)),
540                }
541            }
542            Ne(left, right) => {
543                let left = left.optimize();
544                let right = right.optimize();
545                match (&left, &right) {
546                    (I64(a), I64(b)) => Bool(*a != *b),
547                    (F64(a), F64(b)) => Bool(*a != *b),
548                    (Str(a), Str(b)) => Bool(a != b),
549                    (Bool(a), Bool(b)) => Bool(*a != *b),
550                    (Null, Null) => Bool(false),
551                    _ => Ne(Box::new(left), Box::new(right)),
552                }
553            }
554            In(left, right) => {
555                let left = left.optimize();
556                let right = right.optimize();
557                In(Box::new(left), Box::new(right))
558            }
559
560            // AND optimization.
561            And(exprs) => {
562                let mut optimized: Vec<Expr> = Vec::new();
563                let mut has_false = false;
564
565                for expr in exprs {
566                    let opt_expr = expr.optimize();
567                    match &opt_expr {
568                        Bool(true) => {
569                            // Skip true values in AND.
570                            continue;
571                        }
572                        Bool(false) => {
573                            // If any operand is false, the whole AND is false.
574                            has_false = true;
575                            break;
576                        }
577                        _ => {
578                            optimized.push(opt_expr);
579                        }
580                    }
581                }
582
583                if has_false {
584                    Bool(false)
585                } else if optimized.is_empty() {
586                    // Empty AND is true (identity element)
587                    Bool(true)
588                } else if optimized.len() == 1 {
589                    // Single element AND is just that element
590                    optimized.into_iter().next().unwrap()
591                } else {
592                    And(optimized)
593                }
594            }
595
596            // OR optimization.
597            Or(exprs) => {
598                let mut optimized: Vec<Expr> = Vec::new();
599                let mut has_true = false;
600
601                for expr in exprs {
602                    let opt_expr = expr.optimize();
603                    match &opt_expr {
604                        Bool(false) => {
605                            // Skip false values in OR
606                            continue;
607                        }
608                        Bool(true) => {
609                            // If any operand is true, the whole OR is true
610                            has_true = true;
611                            break;
612                        }
613                        _ => {
614                            optimized.push(opt_expr);
615                        }
616                    }
617                }
618
619                if has_true {
620                    Bool(true)
621                } else if optimized.is_empty() {
622                    // Empty OR is false (identity element)
623                    Bool(false)
624                } else if optimized.len() == 1 {
625                    // Single element OR is just that element
626                    optimized.into_iter().next().unwrap()
627                } else {
628                    Or(optimized)
629                }
630            }
631
632            // NOT optimization.
633            Not(expr) => {
634                let opt_expr = expr.optimize();
635                match opt_expr {
636                    Bool(true) => Bool(false),
637                    Bool(false) => Bool(true),
638                    Not(inner) => {
639                        // Double negation: NOT NOT expr -> expr
640                        *inner
641                    }
642                    other => Not(Box::new(other)),
643                }
644            }
645        }
646    }
647}
648
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_optimize() {
        let expr = Expr::And(vec![Expr::Bool(true), Expr::Bool(false)]);
        let optimized = expr.optimize();
        assert_eq!(optimized, Expr::Bool(false));

        let expr = Expr::Or(vec![Expr::Bool(false), Expr::Bool(true)]);
        let optimized = expr.optimize();
        assert_eq!(optimized, Expr::Bool(true));

        let expr = Expr::Not(Box::new(Expr::Bool(true)));
        let optimized = expr.optimize();
        assert_eq!(optimized, Expr::Bool(false));
    }

    /// `NOT NOT expr` → `expr` (documented example of `optimize`).
    #[test]
    fn test_optimize_double_negation() {
        let expr = Expr::field_("x").not_().not_();
        assert_eq!(expr.optimize(), Expr::field_("x"));
    }

    /// Comparisons between literals of the same kind are folded.
    #[test]
    fn test_optimize_folds_literal_comparisons() {
        // `1 > 2` → `false` (documented example of `optimize`).
        assert_eq!(Expr::gt_(Expr::i64_(1i64), Expr::i64_(2i64)).optimize(), Expr::Bool(false));
        assert_eq!(Expr::le_(Expr::f64_(1.5f64), Expr::f64_(1.5f64)).optimize(), Expr::Bool(true));
        assert_eq!(Expr::lt_(Expr::str_("a"), Expr::str_("b")).optimize(), Expr::Bool(true));
        assert_eq!(Expr::eq_(Expr::null_(), Expr::null_()).optimize(), Expr::Bool(true));
        assert_eq!(Expr::ne_(Expr::null_(), Expr::null_()).optimize(), Expr::Bool(false));
        assert_eq!(Expr::ne_(Expr::bool_(true), Expr::bool_(false)).optimize(), Expr::Bool(true));
    }

    /// Literals of different kinds are not folded.
    #[test]
    fn test_optimize_leaves_mixed_types_unfolded() {
        let expr = Expr::gt_(Expr::i64_(1i64), Expr::f64_(0.5f64));
        assert_eq!(expr.clone().optimize(), expr);
    }

    /// Identity elements and single-operand collapsing for AND / OR.
    #[test]
    fn test_optimize_and_or_identities() {
        assert_eq!(Expr::And(vec![]).optimize(), Expr::Bool(true));
        assert_eq!(Expr::Or(vec![]).optimize(), Expr::Bool(false));
        assert_eq!(
            Expr::And(vec![Expr::Bool(true), Expr::field_("x")]).optimize(),
            Expr::field_("x")
        );
        assert_eq!(
            Expr::Or(vec![Expr::Bool(false), Expr::field_("a"), Expr::field_("b")]).optimize(),
            Expr::Or(vec![Expr::field_("a"), Expr::field_("b")])
        );
    }

    /// Folding propagates bottom-up through nested nodes.
    #[test]
    fn test_optimize_nested() {
        let expr = Expr::And(vec![
            Expr::gt_(Expr::i64_(2i64), Expr::i64_(1i64)),
            Expr::Not(Box::new(Expr::Bool(false))),
        ]);
        assert_eq!(expr.optimize(), Expr::Bool(true));
    }

    /// Ordering is by variant rank first, and `F64` equality goes through
    /// `OrderedFloat`.
    #[test]
    fn test_ordering_and_equality() {
        assert!(Expr::field_("z") < Expr::str_("a"));
        assert_eq!(Expr::F64(f64::NAN), Expr::F64(f64::NAN));
    }
}
667}