Skip to main content

datafusion_expr/
tree_node.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Tree node implementation for Logical Expressions
19
20use crate::{
21    Expr,
22    expr::{
23        AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr, Case,
24        Cast, GroupingSet, HigherOrderFunction, InList, InSubquery, Lambda, Like,
25        Placeholder, ScalarFunction, SetComparison, TryCast, Unnest, WindowFunction,
26        WindowFunctionParams,
27    },
28};
29use datafusion_common::{
30    Result,
31    tree_node::{
32        Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRefContainer,
33    },
34};
35
36/// Implementation of the [`TreeNode`] trait
37///
38/// This allows logical expressions (`Expr`) to be traversed and transformed
39/// Facilitates tasks such as optimization and rewriting during query
40/// planning.
41impl TreeNode for Expr {
42    /// Applies a function `f` to each child expression of `self`.
43    ///
44    /// The function `f` determines whether to continue traversing the tree or to stop.
45    /// This method collects all child expressions and applies `f` to each.
46    fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
47        &'n self,
48        f: F,
49    ) -> Result<TreeNodeRecursion> {
50        match self {
51            Expr::Alias(Alias { expr, .. })
52            | Expr::Unnest(Unnest { expr })
53            | Expr::Not(expr)
54            | Expr::IsNotNull(expr)
55            | Expr::IsTrue(expr)
56            | Expr::IsFalse(expr)
57            | Expr::IsUnknown(expr)
58            | Expr::IsNotTrue(expr)
59            | Expr::IsNotFalse(expr)
60            | Expr::IsNotUnknown(expr)
61            | Expr::IsNull(expr)
62            | Expr::Negative(expr)
63            | Expr::Cast(Cast { expr, .. })
64            | Expr::TryCast(TryCast { expr, .. })
65            | Expr::InSubquery(InSubquery { expr, .. })
66            | Expr::SetComparison(SetComparison { expr, .. }) => expr.apply_elements(f),
67            Expr::GroupingSet(GroupingSet::Rollup(exprs))
68            | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.apply_elements(f),
69            Expr::ScalarFunction(ScalarFunction { args, .. }) => {
70                args.apply_elements(f)
71            }
72            Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
73                lists_of_exprs.apply_elements(f)
74            }
75            // TODO: remove the next line after `Expr::Wildcard` is removed
76            #[expect(deprecated)]
77            Expr::Column(_)
78            // Treat OuterReferenceColumn as a leaf expression
79            | Expr::OuterReferenceColumn(_, _)
80            | Expr::ScalarVariable(_, _)
81            | Expr::Literal(_, _)
82            | Expr::Exists { .. }
83            | Expr::ScalarSubquery(_)
84            | Expr::Wildcard { .. }
85            | Expr::Placeholder(_)
86            | Expr::LambdaVariable(_) => Ok(TreeNodeRecursion::Continue),
87            Expr::BinaryExpr(BinaryExpr { left, right, .. }) => {
88                (left, right).apply_ref_elements(f)
89            }
90            Expr::Like(Like { expr, pattern, .. })
91            | Expr::SimilarTo(Like { expr, pattern, .. }) => {
92                (expr, pattern).apply_ref_elements(f)
93            }
94            Expr::Between(Between {
95                              expr, low, high, ..
96                          }) => (expr, low, high).apply_ref_elements(f),
97            Expr::Case(Case { expr, when_then_expr, else_expr }) =>
98                (expr, when_then_expr, else_expr).apply_ref_elements(f),
99            Expr::AggregateFunction(AggregateFunction { params: AggregateFunctionParams { args, filter, order_by, ..}, .. }) =>
100                (args, filter, order_by).apply_ref_elements(f),
101            Expr::WindowFunction(window_fun) => {
102                let WindowFunctionParams {
103                    args,
104                    partition_by,
105                    order_by,
106                    filter,
107                    ..
108                } = &window_fun.as_ref().params;
109                (args, partition_by, order_by, filter).apply_ref_elements(f)
110            }
111
112            Expr::InList(InList { expr, list, .. }) => {
113                (expr, list).apply_ref_elements(f)
114            }
115            Expr::HigherOrderFunction(HigherOrderFunction { func: _, args}) => args.apply_elements(f),
116            Expr::Lambda (Lambda{ params: _, body}) => body.apply_elements(f)
117        }
118    }
119
120    /// Maps each child of `self` using the provided closure `f`.
121    ///
122    /// The closure `f` takes ownership of an expression and returns a `Transformed` result,
123    /// indicating whether the expression was transformed or left unchanged.
124    fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
125        self,
126        f: F,
127    ) -> Result<Transformed<Self>> {
128        Ok(match self {
129            // TODO: remove the next line after `Expr::Wildcard` is removed
130            #[expect(deprecated)]
131            Expr::Column(_)
132            | Expr::Wildcard { .. }
133            | Expr::Placeholder(Placeholder { .. })
134            | Expr::OuterReferenceColumn(_, _)
135            | Expr::Exists { .. }
136            | Expr::ScalarSubquery(_)
137            | Expr::ScalarVariable(_, _)
138            | Expr::Literal(_, _)
139            | Expr::LambdaVariable(_) => Transformed::no(self),
140            Expr::SetComparison(SetComparison {
141                expr,
142                subquery,
143                op,
144                quantifier,
145            }) => expr.map_elements(f)?.update_data(|expr| {
146                Expr::SetComparison(SetComparison {
147                    expr,
148                    subquery,
149                    op,
150                    quantifier,
151                })
152            }),
153            Expr::Unnest(Unnest { expr, .. }) => expr
154                .map_elements(f)?
155                .update_data(|expr| Expr::Unnest(Unnest { expr })),
156            Expr::Alias(Alias {
157                expr,
158                relation,
159                name,
160                metadata,
161            }) => expr.map_elements(f)?.update_data(|expr| {
162                Expr::Alias(Alias {
163                    expr,
164                    relation,
165                    name,
166                    metadata,
167                })
168            }),
169            Expr::InSubquery(InSubquery {
170                expr,
171                subquery,
172                negated,
173            }) => expr.map_elements(f)?.update_data(|be| {
174                Expr::InSubquery(InSubquery::new(be, subquery, negated))
175            }),
176            Expr::BinaryExpr(BinaryExpr { left, op, right }) => (left, right)
177                .map_elements(f)?
178                .update_data(|(new_left, new_right)| {
179                    Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right))
180                }),
181            Expr::Like(Like {
182                negated,
183                expr,
184                pattern,
185                escape_char,
186                case_insensitive,
187            }) => {
188                (expr, pattern)
189                    .map_elements(f)?
190                    .update_data(|(new_expr, new_pattern)| {
191                        Expr::Like(Like::new(
192                            negated,
193                            new_expr,
194                            new_pattern,
195                            escape_char,
196                            case_insensitive,
197                        ))
198                    })
199            }
200            Expr::SimilarTo(Like {
201                negated,
202                expr,
203                pattern,
204                escape_char,
205                case_insensitive,
206            }) => {
207                (expr, pattern)
208                    .map_elements(f)?
209                    .update_data(|(new_expr, new_pattern)| {
210                        Expr::SimilarTo(Like::new(
211                            negated,
212                            new_expr,
213                            new_pattern,
214                            escape_char,
215                            case_insensitive,
216                        ))
217                    })
218            }
219            Expr::Not(expr) => expr.map_elements(f)?.update_data(Expr::Not),
220            Expr::IsNotNull(expr) => expr.map_elements(f)?.update_data(Expr::IsNotNull),
221            Expr::IsNull(expr) => expr.map_elements(f)?.update_data(Expr::IsNull),
222            Expr::IsTrue(expr) => expr.map_elements(f)?.update_data(Expr::IsTrue),
223            Expr::IsFalse(expr) => expr.map_elements(f)?.update_data(Expr::IsFalse),
224            Expr::IsUnknown(expr) => expr.map_elements(f)?.update_data(Expr::IsUnknown),
225            Expr::IsNotTrue(expr) => expr.map_elements(f)?.update_data(Expr::IsNotTrue),
226            Expr::IsNotFalse(expr) => expr.map_elements(f)?.update_data(Expr::IsNotFalse),
227            Expr::IsNotUnknown(expr) => {
228                expr.map_elements(f)?.update_data(Expr::IsNotUnknown)
229            }
230            Expr::Negative(expr) => expr.map_elements(f)?.update_data(Expr::Negative),
231            Expr::Between(Between {
232                expr,
233                negated,
234                low,
235                high,
236            }) => (expr, low, high).map_elements(f)?.update_data(
237                |(new_expr, new_low, new_high)| {
238                    Expr::Between(Between::new(new_expr, negated, new_low, new_high))
239                },
240            ),
241            Expr::Case(Case {
242                expr,
243                when_then_expr,
244                else_expr,
245            }) => (expr, when_then_expr, else_expr)
246                .map_elements(f)?
247                .update_data(|(new_expr, new_when_then_expr, new_else_expr)| {
248                    Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr))
249                }),
250            Expr::Cast(Cast { expr, field }) => expr
251                .map_elements(f)?
252                .update_data(|be| Expr::Cast(Cast::new_from_field(be, field))),
253            Expr::TryCast(TryCast { expr, field }) => expr
254                .map_elements(f)?
255                .update_data(|be| Expr::TryCast(TryCast::new_from_field(be, field))),
256            Expr::ScalarFunction(ScalarFunction { func, args }) => {
257                args.map_elements(f)?.map_data(|new_args| {
258                    Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
259                        func, new_args,
260                    )))
261                })?
262            }
263            Expr::WindowFunction(window_fun) => {
264                let WindowFunction {
265                    fun,
266                    params:
267                        WindowFunctionParams {
268                            args,
269                            partition_by,
270                            order_by,
271                            window_frame,
272                            filter,
273                            null_treatment,
274                            distinct,
275                        },
276                } = *window_fun;
277
278                (args, partition_by, order_by, filter)
279                    .map_elements(f)?
280                    .map_data(
281                        |(new_args, new_partition_by, new_order_by, new_filter)| {
282                            Ok(Expr::from(WindowFunction {
283                                fun,
284                                params: WindowFunctionParams {
285                                    args: new_args,
286                                    partition_by: new_partition_by,
287                                    order_by: new_order_by,
288                                    window_frame,
289                                    filter: new_filter,
290                                    null_treatment,
291                                    distinct,
292                                },
293                            }))
294                        },
295                    )?
296            }
297            Expr::AggregateFunction(AggregateFunction {
298                func,
299                params:
300                    AggregateFunctionParams {
301                        args,
302                        distinct,
303                        filter,
304                        order_by,
305                        null_treatment,
306                    },
307            }) => (args, filter, order_by).map_elements(f)?.map_data(
308                |(new_args, new_filter, new_order_by)| {
309                    Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
310                        func,
311                        new_args,
312                        distinct,
313                        new_filter,
314                        new_order_by,
315                        null_treatment,
316                    )))
317                },
318            )?,
319            Expr::GroupingSet(grouping_set) => match grouping_set {
320                GroupingSet::Rollup(exprs) => exprs
321                    .map_elements(f)?
322                    .update_data(|ve| Expr::GroupingSet(GroupingSet::Rollup(ve))),
323                GroupingSet::Cube(exprs) => exprs
324                    .map_elements(f)?
325                    .update_data(|ve| Expr::GroupingSet(GroupingSet::Cube(ve))),
326                GroupingSet::GroupingSets(lists_of_exprs) => lists_of_exprs
327                    .map_elements(f)?
328                    .update_data(|new_lists_of_exprs| {
329                        Expr::GroupingSet(GroupingSet::GroupingSets(new_lists_of_exprs))
330                    }),
331            },
332            Expr::InList(InList {
333                expr,
334                list,
335                negated,
336            }) => (expr, list)
337                .map_elements(f)?
338                .update_data(|(new_expr, new_list)| {
339                    Expr::InList(InList::new(new_expr, new_list, negated))
340                }),
341            Expr::HigherOrderFunction(HigherOrderFunction { func, args }) => {
342                args.map_elements(f)?.update_data(|args| {
343                    Expr::HigherOrderFunction(HigherOrderFunction { func, args })
344                })
345            }
346            Expr::Lambda(Lambda { params, body }) => body
347                .map_elements(f)?
348                .update_data(|body| Expr::Lambda(Lambda { params, body })),
349        })
350    }
351}