// filter_expr/expr.rs

1use std::cmp::Ordering;
2use std::hash::{Hash, Hasher};
3
4use crate::{Error, Transform, TransformContext, TransformResult};
5
/// A filter expression, represented as an abstract syntax tree.
///
/// Literal and field-reference variants are leaves; the remaining variants
/// combine one or more sub-expressions.
#[derive(Clone, Debug)]
pub enum Expr {
    // ----- Leaves -----
    /// A field referenced by name.
    Field(String),
    /// A string literal.
    Str(String),
    /// A signed 64-bit integer literal.
    I64(i64),
    /// A 64-bit floating-point literal.
    F64(f64),
    /// A boolean literal.
    Bool(bool),
    /// The null literal.
    Null,
    /// A list of expressions.
    Array(Vec<Expr>),

    // ----- Calls -----
    /// A function call: `(function name, arguments)`.
    FuncCall(String, Vec<Expr>),
    /// A method call: `(method name, receiver, arguments)`.
    MethodCall(String, Box<Expr>, Vec<Expr>),

    // ----- Binary comparisons: `(left, right)` -----
    Gt(Box<Expr>, Box<Expr>),
    Lt(Box<Expr>, Box<Expr>),
    Ge(Box<Expr>, Box<Expr>),
    Le(Box<Expr>, Box<Expr>),
    Eq(Box<Expr>, Box<Expr>),
    Ne(Box<Expr>, Box<Expr>),
    In(Box<Expr>, Box<Expr>),

    // ----- Boolean combinators -----
    /// Conjunction of all operands.
    And(Vec<Expr>),
    /// Disjunction of all operands.
    Or(Vec<Expr>),
    /// Negation of a single operand.
    Not(Box<Expr>),
}
34
35impl PartialOrd for Expr {
36    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
37        Some(self.cmp(other))
38    }
39}
40
41impl Ord for Expr {
42    fn cmp(&self, other: &Self) -> Ordering {
43        use Expr::*;
44        use ordered_float::OrderedFloat;
45
46        // Helper function to get variant order.
47        fn variant_order(expr: &Expr) -> u8 {
48            use Expr::*;
49            match expr {
50                Field(..) => 0,
51                Str(..) => 1,
52                I64(..) => 2,
53                F64(..) => 3,
54                Bool(..) => 4,
55                Null => 5,
56                Array(..) => 6,
57
58                FuncCall(..) => 7,
59                MethodCall(..) => 8,
60
61                Gt(..) => 9,
62                Lt(..) => 10,
63                Ge(..) => 11,
64                Le(..) => 12,
65                Eq(..) => 13,
66                Ne(..) => 14,
67                In(..) => 15,
68
69                And(..) => 16,
70                Or(..) => 17,
71                Not(..) => 18,
72            }
73        }
74
75        // First compare by variant order.
76        match variant_order(self).cmp(&variant_order(other)) {
77            Ordering::Equal => {
78                // Same variant, compare contents.
79                match (self, other) {
80                    (Field(a), Field(b)) => a.cmp(b),
81                    (Str(a), Str(b)) => a.cmp(b),
82                    (I64(a), I64(b)) => a.cmp(b),
83                    (F64(a), F64(b)) => OrderedFloat(*a).cmp(&OrderedFloat(*b)),
84                    (Bool(a), Bool(b)) => a.cmp(b),
85                    (Null, Null) => Ordering::Equal,
86                    (Array(a), Array(b)) => a.cmp(b),
87                    (FuncCall(f1, a1), FuncCall(f2, a2)) => match f1.cmp(f2) {
88                        Ordering::Equal => a1.cmp(a2),
89                        other => other,
90                    },
91                    (MethodCall(m1, o1, a1), MethodCall(m2, o2, a2)) => match m1.cmp(m2) {
92                        Ordering::Equal => match o1.cmp(o2) {
93                            Ordering::Equal => a1.cmp(a2),
94                            other => other,
95                        },
96                        other => other,
97                    },
98                    (Gt(l1, r1), Gt(l2, r2)) => match l1.cmp(l2) {
99                        Ordering::Equal => r1.cmp(r2),
100                        other => other,
101                    },
102                    (Lt(l1, r1), Lt(l2, r2)) => match l1.cmp(l2) {
103                        Ordering::Equal => r1.cmp(r2),
104                        other => other,
105                    },
106                    (Ge(l1, r1), Ge(l2, r2)) => match l1.cmp(l2) {
107                        Ordering::Equal => r1.cmp(r2),
108                        other => other,
109                    },
110                    (Le(l1, r1), Le(l2, r2)) => match l1.cmp(l2) {
111                        Ordering::Equal => r1.cmp(r2),
112                        other => other,
113                    },
114                    (Eq(l1, r1), Eq(l2, r2)) => match l1.cmp(l2) {
115                        Ordering::Equal => r1.cmp(r2),
116                        other => other,
117                    },
118                    (Ne(l1, r1), Ne(l2, r2)) => match l1.cmp(l2) {
119                        Ordering::Equal => r1.cmp(r2),
120                        other => other,
121                    },
122                    (In(l1, r1), In(l2, r2)) => match l1.cmp(l2) {
123                        Ordering::Equal => r1.cmp(r2),
124                        other => other,
125                    },
126                    (And(a), And(b)) => a.cmp(b),
127                    (Or(a), Or(b)) => a.cmp(b),
128                    (Not(a), Not(b)) => a.cmp(b),
129                    _ => unreachable!(),
130                }
131            }
132            other => other,
133        }
134    }
135}
136
137impl PartialEq for Expr {
138    fn eq(&self, other: &Self) -> bool {
139        self.cmp(other) == Ordering::Equal
140    }
141}
142
143impl Eq for Expr {}
144
145impl Hash for Expr {
146    fn hash<H: Hasher>(&self, state: &mut H) {
147        use Expr::*;
148        use ordered_float::OrderedFloat;
149
150        // Hash the discriminant first.
151        std::mem::discriminant(self).hash(state);
152
153        // Then hash the contents.
154        match self {
155            Field(s) => s.hash(state),
156            Str(s) => s.hash(state),
157            I64(i) => i.hash(state),
158            F64(f) => OrderedFloat(*f).hash(state),
159            Bool(b) => b.hash(state),
160            Null => {},
161            Array(v) => v.hash(state),
162            FuncCall(name, args) => {
163                name.hash(state);
164                args.hash(state);
165            }
166            MethodCall(method, obj, args) => {
167                method.hash(state);
168                obj.hash(state);
169                args.hash(state);
170            }
171            Gt(l, r) => {
172                l.hash(state);
173                r.hash(state);
174            }
175            Lt(l, r) => {
176                l.hash(state);
177                r.hash(state);
178            }
179            Ge(l, r) => {
180                l.hash(state);
181                r.hash(state);
182            }
183            Le(l, r) => {
184                l.hash(state);
185                r.hash(state);
186            }
187            Eq(l, r) => {
188                l.hash(state);
189                r.hash(state);
190            }
191            Ne(l, r) => {
192                l.hash(state);
193                r.hash(state);
194            }
195            In(l, r) => {
196                l.hash(state);
197                r.hash(state);
198            }
199            And(v) => v.hash(state),
200            Or(v) => v.hash(state),
201            Not(e) => e.hash(state),
202        }
203    }
204}
205
206impl Expr {
207    pub fn field_<T: Into<String>>(field: T) -> Self {
208        Self::Field(field.into())
209    }
210
211    pub fn str_<T: Into<String>>(value: T) -> Self {
212        Self::Str(value.into())
213    }
214
215    pub fn i64_<T: Into<i64>>(value: T) -> Self {
216        Self::I64(value.into())
217    }
218
219    pub fn f64_<T: Into<f64>>(value: T) -> Self {
220        Self::F64(value.into())
221    }
222
223    pub fn bool_<T: Into<bool>>(value: T) -> Self {
224        Self::Bool(value.into())
225    }
226
227    pub fn null_() -> Self {
228        Self::Null
229    }
230
231    pub fn array_<T: Into<Vec<Expr>>>(value: T) -> Self {
232        Self::Array(value.into())
233    }
234
235    pub fn func_call_(func: impl Into<String>, args: Vec<Expr>) -> Self {
236        Self::FuncCall(func.into(), args)
237    }
238
239    pub fn method_call_(obj: Expr, method: impl Into<String>, args: Vec<Expr>) -> Self {
240        Self::MethodCall(method.into(), Box::new(obj), args)
241    }
242
243    pub fn gt_(left: Expr, right: Expr) -> Self {
244        Self::Gt(Box::new(left), Box::new(right))
245    }
246
247    pub fn lt_(left: Expr, right: Expr) -> Self {
248        Self::Lt(Box::new(left), Box::new(right))
249    }
250
251    pub fn ge_(left: Expr, right: Expr) -> Self {
252        Self::Ge(Box::new(left), Box::new(right))
253    }
254
255    pub fn le_(left: Expr, right: Expr) -> Self {
256        Self::Le(Box::new(left), Box::new(right))
257    }
258
259    pub fn eq_(left: Expr, right: Expr) -> Self {
260        Self::Eq(Box::new(left), Box::new(right))
261    }
262
263    pub fn ne_(left: Expr, right: Expr) -> Self {
264        Self::Ne(Box::new(left), Box::new(right))
265    }
266
267    pub fn in_(left: Expr, right: Expr) -> Self {
268        Self::In(Box::new(left), Box::new(right))
269    }
270
271    pub fn and_<T: Into<Vec<Expr>>>(value: T) -> Self {
272        Self::And(value.into())
273    }
274
275    pub fn or_<T: Into<Vec<Expr>>>(value: T) -> Self {
276        Self::Or(value.into())
277    }
278
279    pub fn not_(self) -> Self {
280        Self::Not(Box::new(self))
281    }
282}
283
284impl Expr {
285    /// Recursively transform an expression using the provided transformer.
286    ///
287    /// ```rust
288    /// use filter_expr::{Expr, Transform};
289    /// use async_trait::async_trait;
290    ///
291    /// struct MyTransformer;
292    ///
293    /// #[async_trait]
294    /// impl Transform for MyTransformer {
295    ///     async fn transform(&mut self, expr: Expr) -> Result<Expr, filter_expr::Error> {
296    ///         // Transform the expression before recursing
297    ///         Ok(match expr {
298    ///             Expr::Field(name) if name == "old_name" => {
299    ///                 Expr::Field("new_name".to_string())
300    ///             }
301    ///             other => other,
302    ///         })
303    ///     }
304    /// }
305    ///
306    /// # #[tokio::main]
307    /// # async fn main() {
308    /// let expr = Expr::Field("old_name".to_string());
309    /// let mut transformer = MyTransformer;
310    /// let result = expr.transform(&mut transformer).await.unwrap();
311    /// assert_eq!(result, Expr::Field("new_name".to_string()));
312    /// # }
313    /// ```
314    pub async fn transform<F: Transform>(self, transformer: &mut F) -> Result<Expr, Error> {
315        let ctx = TransformContext { depth: 0 };
316
317        return Self::transform_expr(transformer, self, ctx).await;
318    }
319
320    async fn transform_expr<F: Transform>(transformer: &mut F, expr: Expr, ctx: TransformContext) -> Result<Expr, Error> {
321        let this = transformer.transform(expr, ctx.clone()).await;
322
323        match this {
324            TransformResult::Continue(expr) => {
325                return Box::pin(Self::transform_children(transformer, expr, ctx)).await;
326            }
327            TransformResult::Stop(expr) => {
328                return Ok(expr);
329            }
330            TransformResult::Err(err) => {
331                return Err(Error::Transform(err));
332            }
333        }
334    }
335
336    async fn transform_children<F: Transform>(transformer: &mut F, expr: Expr, mut ctx: TransformContext) -> Result<Expr, Error> {
337        ctx.depth += 1;
338
339        Ok(match expr {
340            // Do nothing if the expression have no children.
341            Expr::Field(name) => Expr::Field(name),
342            Expr::Str(value) => Expr::Str(value),
343            Expr::I64(value) => Expr::I64(value),
344            Expr::F64(value) => Expr::F64(value),
345            Expr::Bool(value) => Expr::Bool(value),
346            Expr::Null => Expr::Null,
347            Expr::Array(value) => Expr::Array(value),
348
349            Expr::FuncCall(func, args) => {
350                let mut transformed_args = Vec::new();
351                for arg in args {
352                    transformed_args.push(Self::transform_expr(transformer, arg, ctx.clone()).await?);
353                }
354                Expr::FuncCall(func, transformed_args)
355            }
356            Expr::MethodCall(method, obj, args) => {
357                let obj = Box::new(Self::transform_expr(transformer, *obj, ctx.clone()).await?);
358                let mut transformed_args = Vec::new();
359                for arg in args {
360                    transformed_args.push(Self::transform_expr(transformer, arg, ctx.clone()).await?);
361                }
362                Expr::MethodCall(method, obj, transformed_args)
363            }
364
365            Expr::Gt(left, right) => {
366                let left = Box::new(Self::transform_expr(transformer, *left, ctx.clone()).await?);
367                let right = Box::new(Self::transform_expr(transformer, *right, ctx).await?);
368                Expr::Gt(left, right)
369            }
370            Expr::Lt(left, right) => {
371                let left = Box::new(Self::transform_expr(transformer, *left, ctx.clone()).await?);
372                let right = Box::new(Self::transform_expr(transformer, *right, ctx).await?);
373                Expr::Lt(left, right)
374            }
375            Expr::Ge(left, right) => {
376                let left = Box::new(Self::transform_expr(transformer, *left, ctx.clone()).await?);
377                let right = Box::new(Self::transform_expr(transformer, *right, ctx).await?);
378                Expr::Ge(left, right)
379            }
380            Expr::Le(left, right) => {
381                let left = Box::new(Self::transform_expr(transformer, *left, ctx.clone()).await?);
382                let right = Box::new(Self::transform_expr(transformer, *right, ctx).await?);
383                Expr::Le(left, right)
384            }
385            Expr::Eq(left, right) => {
386                let left = Box::new(Self::transform_expr(transformer, *left, ctx.clone()).await?);
387                let right = Box::new(Self::transform_expr(transformer, *right, ctx).await?);
388                Expr::Eq(left, right)
389            }
390            Expr::Ne(left, right) => {
391                let left = Box::new(Self::transform_expr(transformer, *left, ctx.clone()).await?);
392                let right = Box::new(Self::transform_expr(transformer, *right, ctx).await?);
393                Expr::Ne(left, right)
394            }
395            Expr::In(left, right) => {
396                let left = Box::new(Self::transform_expr(transformer, *left, ctx.clone()).await?);
397                let right = Box::new(Self::transform_expr(transformer, *right, ctx).await?);
398                Expr::In(left, right)
399            }
400            Expr::And(exprs) => {
401                let mut transformed_exprs = Vec::new();
402                for e in exprs {
403                    transformed_exprs.push(Self::transform_expr(transformer, e, ctx.clone()).await?);
404                }
405                Expr::And(transformed_exprs)
406            }
407            Expr::Or(exprs) => {
408                let mut transformed_exprs = Vec::new();
409                for e in exprs {
410                    transformed_exprs.push(Self::transform_expr(transformer, e, ctx.clone()).await?);
411                }
412                Expr::Or(transformed_exprs)
413            }
414            Expr::Not(expr) => {
415                let expr = Box::new(Self::transform_expr(transformer, *expr, ctx).await?);
416                Expr::Not(expr)
417            }
418        })
419    }
420}