// filter_expr/transform.rs

1use async_trait::async_trait;
2
3use crate::Expr;
4
/// Type-erased, thread-safe error carried by [`TransformResult::Err`] when a
/// transformation fails.
type BoxedError = Box<dyn std::error::Error + Send + Sync + 'static>;
6
7/// A trait for transforming AST expressions.
8///
9/// This trait allows you to recursively transform expressions by visiting
10/// all sub-expressions. The `transform` method is called recursively on all
11/// sub-expressions, allowing you to transform the AST in a composable way.
12///
13/// # Example
14///
15/// ```rust
16/// use filter_expr::{Expr, Transform};
17/// use async_trait::async_trait;
18///
19/// struct MyTransformer;
20///
21/// #[async_trait]
22/// impl Transform for MyTransformer {
23///     async fn transform(&mut self, expr: Expr) -> Result<Expr, filter_expr::Error> {
24///         // Transform the expression before recursing
25///         Ok(match expr {
26///             Expr::Field(name) if name == "old_name" => {
27///                 Expr::Field("new_name".to_string())
28///             }
29///             other => other,
30///         })
31///     }
32/// }
33/// ```
34#[async_trait]
35pub trait Transform {
36    /// Transform an expression.
37    async fn transform(&mut self, expr: Expr, ctx: TransformContext) -> TransformResult 
38    where
39        Self: Sized;
40}
41
42/// The result of transforming an expression.
43pub enum TransformResult {
44    /// Continue transforming the expression.  The children of the expression
45    /// will be transformed.
46    Continue(Expr),
47    /// Stop transforming the expression.  It means we will not transform the
48    /// children of the expression.
49    Stop(Expr),
50    /// An error occurred while transforming the expression.
51    Err(BoxedError),
52}
53
/// The context in which an expression is being transformed.
///
/// Passed to [`Transform::transform`] alongside each expression so the
/// transformer can tell where in the tree it currently is.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct TransformContext {
    /// The depth of the expression.  The root expression has a depth of 0,
    /// which is also the [`Default`] value.
    pub depth: usize,
}
60
#[cfg(test)]
mod tests {
    use super::*;

    /// `Continue` at the root lets the walk descend and rename a nested field.
    #[tokio::test]
    async fn test_transform_expr_to_rename_field() {
        mod rename_field {
            use super::*;

            /// Renames every `Expr::Field(old_name)` to `Expr::Field(new_name)`.
            pub struct RenameFieldTransformer {
                pub old_name: String,
                pub new_name: String,
            }

            #[async_trait::async_trait]
            impl Transform for RenameFieldTransformer {
                async fn transform(
                    &mut self,
                    expr: Expr,
                    _ctx: TransformContext,
                ) -> TransformResult {
                    TransformResult::Continue(match expr {
                        Expr::Field(name) if name == self.old_name => {
                            Expr::Field(self.new_name.clone())
                        }
                        _ => expr,
                    })
                }
            }
        }

        let expr = Expr::eq_(Expr::field_("old_name"), Expr::str_("value"));

        let mut transformer = rename_field::RenameFieldTransformer {
            old_name: "old_name".to_string(),
            new_name: "new_name".to_string(),
        };

        // The root `Eq` comes back unchanged via `Continue`, so the walk visits
        // its children and renames the field on the left-hand side.
        let result = expr.transform(&mut transformer).await.unwrap();
        assert_eq!(
            result,
            Expr::Eq(
                Box::new(Expr::field_("new_name")),
                Box::new(Expr::str_("value"))
            )
        );
    }

    /// A transformer may `.await` (here: a simulated external lookup) and may
    /// fail with `TransformResult::Err`.
    #[tokio::test]
    async fn test_transform_expr_to_load_datas_from_external_datasource() {
        mod foo {
            use std::time::Duration;

            use super::*;

            /// Rewrites `is_not_bad(x)` into `x IN [..]`, with the list loaded
            /// asynchronously by `load_datas`.
            pub struct FooTransformer;

            #[derive(Debug, thiserror::Error)]
            pub enum Error {
                #[error("unexpected arguments length, expected {expected}, actual {actual}")]
                UnexpectedArgumentsLength { expected: usize, actual: usize },
            }

            #[async_trait::async_trait]
            impl Transform for FooTransformer {
                async fn transform(
                    &mut self,
                    expr: Expr,
                    _ctx: TransformContext,
                ) -> TransformResult {
                    TransformResult::Continue(match expr {
                        Expr::FuncCall(fn_name, args) if fn_name == "is_not_bad" => {
                            if args.len() != 1 {
                                return TransformResult::Err(Box::new(
                                    Error::UnexpectedArgumentsLength {
                                        expected: 1,
                                        actual: args.len(),
                                    },
                                ));
                            }
                            let datas = load_datas().await;
                            Expr::In(Box::new(args[0].clone()), Box::new(Expr::Array(datas)))
                        }
                        _ => expr,
                    })
                }
            }

            /// Stand-in for a slow external data source.
            async fn load_datas() -> Vec<Expr> {
                tokio::time::sleep(Duration::from_millis(100)).await;

                vec![Expr::Str("foo".to_string()), Expr::Str("bar".to_string())]
            }
        }

        let expr = Expr::and_([
            Expr::func_call_("is_not_bad", vec![Expr::field_("magic")]),
            Expr::func_call_("is_not_bad", vec![Expr::field_("foobar")]),
        ]);

        let mut transformer = foo::FooTransformer;
        let result = expr.transform(&mut transformer).await.unwrap();
        assert_eq!(
            result,
            Expr::and_([
                Expr::in_(
                    Expr::field_("magic"),
                    Expr::array_([Expr::str_("foo"), Expr::str_("bar")])
                ),
                Expr::in_(
                    Expr::field_("foobar"),
                    Expr::array_([Expr::str_("foo"), Expr::str_("bar")])
                ),
            ])
        );

        // A call with the wrong number of arguments takes the error branch,
        // which must surface as an `Err` from the walk.
        let bad = Expr::func_call_(
            "is_not_bad",
            vec![Expr::field_("magic"), Expr::field_("foobar")],
        );
        let mut transformer = foo::FooTransformer;
        assert!(bad.transform(&mut transformer).await.is_err());
    }

    /// Returning `Stop` at the root prevents the walk from visiting children.
    #[tokio::test]
    async fn test_early_stop() {
        mod early_stop {
            use super::*;

            /// Rewrites `magic = "foobar"` into `true`, but only for the root
            /// expression or the direct operands of a root `And`.
            pub struct EarlyStopTransformer;

            /// Replace `magic = "foobar"` with `true`; return anything else as is.
            fn transform_magic_eq_expr(expr: Expr) -> Expr {
                if let Expr::Eq(ref left, ref right) = expr {
                    let left_is_field_magic =
                        matches!(left.as_ref(), Expr::Field(field) if field == "magic");
                    let right_is_str_foobar =
                        matches!(right.as_ref(), Expr::Str(s) if s == "foobar");
                    if left_is_field_magic && right_is_str_foobar {
                        // Ignore the `magic = "foobar"` condition.
                        return Expr::bool_(true);
                    }
                }
                expr
            }

            #[async_trait::async_trait]
            impl Transform for EarlyStopTransformer {
                async fn transform(
                    &mut self,
                    expr: Expr,
                    ctx: TransformContext,
                ) -> TransformResult {
                    // Not root -- just ignore.
                    if ctx.depth != 0 {
                        return TransformResult::Stop(expr);
                    }

                    // Only the root is rewritten, and always with `Stop`, so
                    // the walk never descends into the children.
                    let rewritten = match expr {
                        Expr::Eq(..) => transform_magic_eq_expr(expr),
                        Expr::And(exprs) => {
                            let transformed_exprs: Vec<Expr> =
                                exprs.into_iter().map(transform_magic_eq_expr).collect();
                            Expr::and_(transformed_exprs)
                        }
                        other => other,
                    };
                    TransformResult::Stop(rewritten)
                }
            }
        }

        // Root `And`: its direct `Eq` operands are rewritten individually.
        let expr = Expr::and_([
            Expr::eq_(Expr::field_("magic"), Expr::str_("foobar")),
            Expr::eq_(Expr::field_("magic"), Expr::str_("baz")),
        ]);

        let mut transformer = early_stop::EarlyStopTransformer;
        let result = expr.transform(&mut transformer).await.unwrap();
        assert_eq!(
            result,
            Expr::and_([
                Expr::bool_(true),
                Expr::eq_(Expr::field_("magic"), Expr::str_("baz")),
            ])
        );

        // Root `Eq`: rewritten directly.
        let expr = Expr::eq_(Expr::field_("magic"), Expr::str_("foobar"));

        let mut transformer = early_stop::EarlyStopTransformer;
        let result = expr.transform(&mut transformer).await.unwrap();
        assert_eq!(result, Expr::bool_(true));

        // Root `Or`: returned unchanged with `Stop`, so the nested
        // `magic = "foobar"` is never visited.
        let expr = Expr::or_([
            Expr::eq_(Expr::field_("magic"), Expr::str_("foobar")),
            Expr::bool_(true),
        ]);

        let mut transformer = early_stop::EarlyStopTransformer;
        let result = expr.transform(&mut transformer).await.unwrap();

        assert_eq!(
            result,
            Expr::or_([
                Expr::eq_(Expr::field_("magic"), Expr::str_("foobar")),
                Expr::bool_(true),
            ])
        );
    }
}