Skip to main content

filter_expr/
expr.rs

1use std::cmp::Ordering;
2use std::hash::{Hash, Hasher};
3
4use crate::{Error, Transform, TransformContext, TransformResult};
5
/// The expression.
///
/// It is an AST of the filter expression.
#[derive(Debug, Clone)]
pub enum Expr {
    /// A field referenced by name.
    Field(String),
    /// Access to a named field of another expression: `(object, field)`.
    FieldAccess(Box<Expr>, String),

    /// A string literal.
    Str(String),
    /// A signed 64-bit integer literal.
    I64(i64),
    /// A 64-bit floating point literal.
    F64(f64),
    /// A boolean literal.
    Bool(bool),
    /// The null literal.
    Null,

    /// An array whose elements are themselves expressions.
    Array(Vec<Expr>),

    /// A free function call: `(function name, arguments)`.
    FuncCall(String, Vec<Expr>),
    /// A method call, stored as `(method name, receiver, arguments)`.
    MethodCall(String, Box<Expr>, Vec<Expr>),

    /// Greater-than node: `(left, right)`.
    Gt(Box<Expr>, Box<Expr>),
    /// Less-than node: `(left, right)`.
    Lt(Box<Expr>, Box<Expr>),
    /// Greater-or-equal node: `(left, right)`.
    Ge(Box<Expr>, Box<Expr>),
    /// Less-or-equal node: `(left, right)`.
    Le(Box<Expr>, Box<Expr>),
    /// Equality node: `(left, right)`.
    Eq(Box<Expr>, Box<Expr>),
    /// Inequality node: `(left, right)`.
    Ne(Box<Expr>, Box<Expr>),
    /// `In` node: `(left, right)`.
    In(Box<Expr>, Box<Expr>),

    /// Conjunction over a list of operand expressions.
    And(Vec<Expr>),
    /// Disjunction over a list of operand expressions.
    Or(Vec<Expr>),
    /// Negation of a single operand expression.
    Not(Box<Expr>),
}
37
38impl PartialOrd for Expr {
39    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
40        Some(self.cmp(other))
41    }
42}
43
44impl Ord for Expr {
45    fn cmp(&self, other: &Self) -> Ordering {
46        use Expr::*;
47        use ordered_float::OrderedFloat;
48
49        // Helper function to get variant order.
50        fn variant_order(expr: &Expr) -> u16 {
51            use Expr::*;
52            match expr {
53                Field(..) => 100,
54                FieldAccess(..) => 101,
55
56                Str(..) => 200,
57                I64(..) => 201,
58                F64(..) => 202,
59                Bool(..) => 203,
60                Null => 204,
61
62                Array(..) => 300,
63
64                FuncCall(..) => 400,
65                MethodCall(..) => 401,
66
67                Gt(..) => 500,
68                Lt(..) => 501,
69                Ge(..) => 502,
70                Le(..) => 503,
71                Eq(..) => 504,
72                Ne(..) => 505,
73                In(..) => 506,
74
75                And(..) => 600,
76                Or(..) => 601,
77                Not(..) => 602,
78            }
79        }
80
81        // First compare by variant order.
82        match variant_order(self).cmp(&variant_order(other)) {
83            Ordering::Equal => {
84                // Same variant, compare contents.
85                match (self, other) {
86                    (Field(a), Field(b)) => a.cmp(b),
87                    (FieldAccess(o1, f1), FieldAccess(o2, f2)) => match o1.cmp(o2) {
88                        Ordering::Equal => f1.cmp(f2),
89                        other => other,
90                    },
91
92                    (Str(a), Str(b)) => a.cmp(b),
93                    (I64(a), I64(b)) => a.cmp(b),
94                    (F64(a), F64(b)) => OrderedFloat(*a).cmp(&OrderedFloat(*b)),
95                    (Bool(a), Bool(b)) => a.cmp(b),
96                    (Null, Null) => Ordering::Equal,
97
98                    (Array(a), Array(b)) => a.cmp(b),
99
100                    (FuncCall(f1, a1), FuncCall(f2, a2)) => match f1.cmp(f2) {
101                        Ordering::Equal => a1.cmp(a2),
102                        other => other,
103                    },
104                    (MethodCall(m1, o1, a1), MethodCall(m2, o2, a2)) => match m1.cmp(m2) {
105                        Ordering::Equal => match o1.cmp(o2) {
106                            Ordering::Equal => a1.cmp(a2),
107                            other => other,
108                        },
109                        other => other,
110                    },
111
112                    (Gt(l1, r1), Gt(l2, r2)) => match l1.cmp(l2) {
113                        Ordering::Equal => r1.cmp(r2),
114                        other => other,
115                    },
116                    (Lt(l1, r1), Lt(l2, r2)) => match l1.cmp(l2) {
117                        Ordering::Equal => r1.cmp(r2),
118                        other => other,
119                    },
120                    (Ge(l1, r1), Ge(l2, r2)) => match l1.cmp(l2) {
121                        Ordering::Equal => r1.cmp(r2),
122                        other => other,
123                    },
124                    (Le(l1, r1), Le(l2, r2)) => match l1.cmp(l2) {
125                        Ordering::Equal => r1.cmp(r2),
126                        other => other,
127                    },
128                    (Eq(l1, r1), Eq(l2, r2)) => match l1.cmp(l2) {
129                        Ordering::Equal => r1.cmp(r2),
130                        other => other,
131                    },
132                    (Ne(l1, r1), Ne(l2, r2)) => match l1.cmp(l2) {
133                        Ordering::Equal => r1.cmp(r2),
134                        other => other,
135                    },
136                    (In(l1, r1), In(l2, r2)) => match l1.cmp(l2) {
137                        Ordering::Equal => r1.cmp(r2),
138                        other => other,
139                    },
140
141                    (And(a), And(b)) => a.cmp(b),
142                    (Or(a), Or(b)) => a.cmp(b),
143                    (Not(a), Not(b)) => a.cmp(b),
144
145                    _ => unreachable!(),
146                }
147            }
148            other => other,
149        }
150    }
151}
152
153impl PartialEq for Expr {
154    fn eq(&self, other: &Self) -> bool {
155        self.cmp(other) == Ordering::Equal
156    }
157}
158
159impl Eq for Expr {}
160
161impl Hash for Expr {
162    fn hash<H: Hasher>(&self, state: &mut H) {
163        use Expr::*;
164        use ordered_float::OrderedFloat;
165
166        // Hash the discriminant first.
167        std::mem::discriminant(self).hash(state);
168
169        // Then hash the contents.
170        match self {
171            Field(s) => s.hash(state),
172            FieldAccess(obj, field) => {
173                obj.hash(state);
174                field.hash(state);
175            }
176
177            Str(s) => s.hash(state),
178            I64(i) => i.hash(state),
179            F64(f) => OrderedFloat(*f).hash(state),
180            Bool(b) => b.hash(state),
181            Null => {},
182            Array(v) => v.hash(state),
183            FuncCall(name, args) => {
184                name.hash(state);
185                args.hash(state);
186            }
187            MethodCall(method, obj, args) => {
188                method.hash(state);
189                obj.hash(state);
190                args.hash(state);
191            }
192            Gt(l, r) => {
193                l.hash(state);
194                r.hash(state);
195            }
196            Lt(l, r) => {
197                l.hash(state);
198                r.hash(state);
199            }
200            Ge(l, r) => {
201                l.hash(state);
202                r.hash(state);
203            }
204            Le(l, r) => {
205                l.hash(state);
206                r.hash(state);
207            }
208            Eq(l, r) => {
209                l.hash(state);
210                r.hash(state);
211            }
212            Ne(l, r) => {
213                l.hash(state);
214                r.hash(state);
215            }
216            In(l, r) => {
217                l.hash(state);
218                r.hash(state);
219            }
220            And(v) => v.hash(state),
221            Or(v) => v.hash(state),
222            Not(e) => e.hash(state),
223        }
224    }
225}
226
227impl Expr {
228    pub fn field_<T: Into<String>>(field: T) -> Self {
229        Self::Field(field.into())
230    }
231
232    pub fn field_access_(obj: Expr, field: impl Into<String>) -> Self {
233        Self::FieldAccess(Box::new(obj), field.into())
234    }
235
236    pub fn str_<T: Into<String>>(value: T) -> Self {
237        Self::Str(value.into())
238    }
239
240    pub fn i64_<T: Into<i64>>(value: T) -> Self {
241        Self::I64(value.into())
242    }
243
244    pub fn f64_<T: Into<f64>>(value: T) -> Self {
245        Self::F64(value.into())
246    }
247
248    pub fn bool_<T: Into<bool>>(value: T) -> Self {
249        Self::Bool(value.into())
250    }
251
252    pub fn null_() -> Self {
253        Self::Null
254    }
255
256    pub fn array_<T: Into<Vec<Expr>>>(value: T) -> Self {
257        Self::Array(value.into())
258    }
259
260    pub fn func_call_(func: impl Into<String>, args: Vec<Expr>) -> Self {
261        Self::FuncCall(func.into(), args)
262    }
263
264    pub fn method_call_(obj: Expr, method: impl Into<String>, args: Vec<Expr>) -> Self {
265        Self::MethodCall(method.into(), Box::new(obj), args)
266    }
267
268    pub fn gt_(left: Expr, right: Expr) -> Self {
269        Self::Gt(Box::new(left), Box::new(right))
270    }
271
272    pub fn lt_(left: Expr, right: Expr) -> Self {
273        Self::Lt(Box::new(left), Box::new(right))
274    }
275
276    pub fn ge_(left: Expr, right: Expr) -> Self {
277        Self::Ge(Box::new(left), Box::new(right))
278    }
279
280    pub fn le_(left: Expr, right: Expr) -> Self {
281        Self::Le(Box::new(left), Box::new(right))
282    }
283
284    pub fn eq_(left: Expr, right: Expr) -> Self {
285        Self::Eq(Box::new(left), Box::new(right))
286    }
287
288    pub fn ne_(left: Expr, right: Expr) -> Self {
289        Self::Ne(Box::new(left), Box::new(right))
290    }
291
292    pub fn in_(left: Expr, right: Expr) -> Self {
293        Self::In(Box::new(left), Box::new(right))
294    }
295
296    pub fn and_<T: Into<Vec<Expr>>>(value: T) -> Self {
297        Self::And(value.into())
298    }
299
300    pub fn or_<T: Into<Vec<Expr>>>(value: T) -> Self {
301        Self::Or(value.into())
302    }
303
304    pub fn not_(self) -> Self {
305        Self::Not(Box::new(self))
306    }
307}
308
309impl Expr {
310    /// Recursively transform an expression using the provided transformer.
311    ///
312    /// ```rust
313    /// use filter_expr::{Expr, Transform};
314    /// use async_trait::async_trait;
315    ///
316    /// struct MyTransformer;
317    ///
318    /// #[async_trait]
319    /// impl Transform for MyTransformer {
320    ///     async fn transform(&mut self, expr: Expr) -> Result<Expr, filter_expr::Error> {
321    ///         // Transform the expression before recursing
322    ///         Ok(match expr {
323    ///             Expr::Field(name) if name == "old_name" => {
324    ///                 Expr::Field("new_name".to_string())
325    ///             }
326    ///             other => other,
327    ///         })
328    ///     }
329    /// }
330    ///
331    /// # #[tokio::main]
332    /// # async fn main() {
333    /// let expr = Expr::Field("old_name".to_string());
334    /// let mut transformer = MyTransformer;
335    /// let result = expr.transform(&mut transformer).await.unwrap();
336    /// assert_eq!(result, Expr::Field("new_name".to_string()));
337    /// # }
338    /// ```
339    pub async fn transform<F: Transform>(self, transformer: &mut F) -> Result<Expr, Error> {
340        let ctx = TransformContext { depth: 0 };
341
342        return Self::transform_expr(transformer, self, ctx).await;
343    }
344
345    async fn transform_expr<F: Transform>(transformer: &mut F, expr: Expr, ctx: TransformContext) -> Result<Expr, Error> {
346        let this = transformer.transform(expr, ctx.clone()).await;
347
348        match this {
349            TransformResult::Continue(expr) => {
350                return Box::pin(Self::transform_children(transformer, expr, ctx)).await;
351            }
352            TransformResult::Stop(expr) => {
353                return Ok(expr);
354            }
355            TransformResult::Err(err) => {
356                return Err(Error::Transform(err));
357            }
358        }
359    }
360
361    async fn transform_children<F: Transform>(transformer: &mut F, expr: Expr, mut ctx: TransformContext) -> Result<Expr, Error> {
362        ctx.depth += 1;
363
364        Ok(match expr {
365            // Do nothing if the expression have no children.
366            Expr::Field(name) => Expr::Field(name),
367            Expr::FieldAccess(obj, field) => {
368                let obj = Box::new(Self::transform_expr(transformer, *obj, ctx.clone()).await?);
369                Expr::FieldAccess(obj, field)
370            }
371
372            Expr::Str(value) => Expr::Str(value),
373            Expr::I64(value) => Expr::I64(value),
374            Expr::F64(value) => Expr::F64(value),
375            Expr::Bool(value) => Expr::Bool(value),
376            Expr::Null => Expr::Null,
377            Expr::Array(value) => Expr::Array(value),
378
379            Expr::FuncCall(func, args) => {
380                let mut transformed_args = Vec::new();
381                for arg in args {
382                    transformed_args.push(Self::transform_expr(transformer, arg, ctx.clone()).await?);
383                }
384                Expr::FuncCall(func, transformed_args)
385            }
386            Expr::MethodCall(method, obj, args) => {
387                let obj = Box::new(Self::transform_expr(transformer, *obj, ctx.clone()).await?);
388                let mut transformed_args = Vec::new();
389                for arg in args {
390                    transformed_args.push(Self::transform_expr(transformer, arg, ctx.clone()).await?);
391                }
392                Expr::MethodCall(method, obj, transformed_args)
393            }
394
395            Expr::Gt(left, right) => {
396                let left = Box::new(Self::transform_expr(transformer, *left, ctx.clone()).await?);
397                let right = Box::new(Self::transform_expr(transformer, *right, ctx).await?);
398                Expr::Gt(left, right)
399            }
400            Expr::Lt(left, right) => {
401                let left = Box::new(Self::transform_expr(transformer, *left, ctx.clone()).await?);
402                let right = Box::new(Self::transform_expr(transformer, *right, ctx).await?);
403                Expr::Lt(left, right)
404            }
405            Expr::Ge(left, right) => {
406                let left = Box::new(Self::transform_expr(transformer, *left, ctx.clone()).await?);
407                let right = Box::new(Self::transform_expr(transformer, *right, ctx).await?);
408                Expr::Ge(left, right)
409            }
410            Expr::Le(left, right) => {
411                let left = Box::new(Self::transform_expr(transformer, *left, ctx.clone()).await?);
412                let right = Box::new(Self::transform_expr(transformer, *right, ctx).await?);
413                Expr::Le(left, right)
414            }
415            Expr::Eq(left, right) => {
416                let left = Box::new(Self::transform_expr(transformer, *left, ctx.clone()).await?);
417                let right = Box::new(Self::transform_expr(transformer, *right, ctx).await?);
418                Expr::Eq(left, right)
419            }
420            Expr::Ne(left, right) => {
421                let left = Box::new(Self::transform_expr(transformer, *left, ctx.clone()).await?);
422                let right = Box::new(Self::transform_expr(transformer, *right, ctx).await?);
423                Expr::Ne(left, right)
424            }
425            Expr::In(left, right) => {
426                let left = Box::new(Self::transform_expr(transformer, *left, ctx.clone()).await?);
427                let right = Box::new(Self::transform_expr(transformer, *right, ctx).await?);
428                Expr::In(left, right)
429            }
430            Expr::And(exprs) => {
431                let mut transformed_exprs = Vec::new();
432                for e in exprs {
433                    transformed_exprs.push(Self::transform_expr(transformer, e, ctx.clone()).await?);
434                }
435                Expr::And(transformed_exprs)
436            }
437            Expr::Or(exprs) => {
438                let mut transformed_exprs = Vec::new();
439                for e in exprs {
440                    transformed_exprs.push(Self::transform_expr(transformer, e, ctx.clone()).await?);
441                }
442                Expr::Or(transformed_exprs)
443            }
444            Expr::Not(expr) => {
445                let expr = Box::new(Self::transform_expr(transformer, *expr, ctx).await?);
446                Expr::Not(expr)
447            }
448        })
449    }
450}