Skip to main content

datafusion_expr/
tree_node.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Tree node implementation for Logical Expressions
19
20use crate::Expr;
21use crate::expr::{
22    AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Cast,
23    GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, SetComparison,
24    TryCast, Unnest, WindowFunction, WindowFunctionParams,
25};
26
27use datafusion_common::Result;
28use datafusion_common::tree_node::{
29    Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRefContainer,
30};
31
32/// Implementation of the [`TreeNode`] trait
33///
34/// This allows logical expressions (`Expr`) to be traversed and transformed
35/// Facilitates tasks such as optimization and rewriting during query
36/// planning.
37impl TreeNode for Expr {
38    /// Applies a function `f` to each child expression of `self`.
39    ///
40    /// The function `f` determines whether to continue traversing the tree or to stop.
41    /// This method collects all child expressions and applies `f` to each.
42    fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
43        &'n self,
44        f: F,
45    ) -> Result<TreeNodeRecursion> {
46        match self {
47            Expr::Alias(Alias { expr, .. })
48            | Expr::Unnest(Unnest { expr })
49            | Expr::Not(expr)
50            | Expr::IsNotNull(expr)
51            | Expr::IsTrue(expr)
52            | Expr::IsFalse(expr)
53            | Expr::IsUnknown(expr)
54            | Expr::IsNotTrue(expr)
55            | Expr::IsNotFalse(expr)
56            | Expr::IsNotUnknown(expr)
57            | Expr::IsNull(expr)
58            | Expr::Negative(expr)
59            | Expr::Cast(Cast { expr, .. })
60            | Expr::TryCast(TryCast { expr, .. })
61            | Expr::InSubquery(InSubquery { expr, .. })
62            | Expr::SetComparison(SetComparison { expr, .. }) => expr.apply_elements(f),
63            Expr::GroupingSet(GroupingSet::Rollup(exprs))
64            | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.apply_elements(f),
65            Expr::ScalarFunction(ScalarFunction { args, .. }) => {
66                args.apply_elements(f)
67            }
68            Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
69                lists_of_exprs.apply_elements(f)
70            }
71            // TODO: remove the next line after `Expr::Wildcard` is removed
72            #[expect(deprecated)]
73            Expr::Column(_)
74            // Treat OuterReferenceColumn as a leaf expression
75            | Expr::OuterReferenceColumn(_, _)
76            | Expr::ScalarVariable(_, _)
77            | Expr::Literal(_, _)
78            | Expr::Exists { .. }
79            | Expr::ScalarSubquery(_)
80            | Expr::Wildcard { .. }
81            | Expr::Placeholder(_) => Ok(TreeNodeRecursion::Continue),
82            Expr::BinaryExpr(BinaryExpr { left, right, .. }) => {
83                (left, right).apply_ref_elements(f)
84            }
85            Expr::Like(Like { expr, pattern, .. })
86            | Expr::SimilarTo(Like { expr, pattern, .. }) => {
87                (expr, pattern).apply_ref_elements(f)
88            }
89            Expr::Between(Between {
90                              expr, low, high, ..
91                          }) => (expr, low, high).apply_ref_elements(f),
92            Expr::Case(Case { expr, when_then_expr, else_expr }) =>
93                (expr, when_then_expr, else_expr).apply_ref_elements(f),
94            Expr::AggregateFunction(AggregateFunction { params: AggregateFunctionParams { args, filter, order_by, ..}, .. }) =>
95                (args, filter, order_by).apply_ref_elements(f),
96            Expr::WindowFunction(window_fun) => {
97                let WindowFunctionParams {
98                    args,
99                    partition_by,
100                    order_by,
101                    filter,
102                    ..
103                } = &window_fun.as_ref().params;
104                (args, partition_by, order_by, filter).apply_ref_elements(f)
105            }
106
107            Expr::InList(InList { expr, list, .. }) => {
108                (expr, list).apply_ref_elements(f)
109            }
110        }
111    }
112
113    /// Maps each child of `self` using the provided closure `f`.
114    ///
115    /// The closure `f` takes ownership of an expression and returns a `Transformed` result,
116    /// indicating whether the expression was transformed or left unchanged.
117    fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
118        self,
119        mut f: F,
120    ) -> Result<Transformed<Self>> {
121        Ok(match self {
122            // TODO: remove the next line after `Expr::Wildcard` is removed
123            #[expect(deprecated)]
124            Expr::Column(_)
125            | Expr::Wildcard { .. }
126            | Expr::Placeholder(Placeholder { .. })
127            | Expr::OuterReferenceColumn(_, _)
128            | Expr::Exists { .. }
129            | Expr::ScalarSubquery(_)
130            | Expr::ScalarVariable(_, _)
131            | Expr::Literal(_, _) => Transformed::no(self),
132            Expr::SetComparison(SetComparison {
133                expr,
134                subquery,
135                op,
136                quantifier,
137            }) => expr.map_elements(f)?.update_data(|expr| {
138                Expr::SetComparison(SetComparison {
139                    expr,
140                    subquery,
141                    op,
142                    quantifier,
143                })
144            }),
145            Expr::Unnest(Unnest { expr, .. }) => expr
146                .map_elements(f)?
147                .update_data(|expr| Expr::Unnest(Unnest { expr })),
148            Expr::Alias(Alias {
149                expr,
150                relation,
151                name,
152                metadata,
153            }) => f(*expr)?.update_data(|e| {
154                e.alias_qualified_with_metadata(relation, name, metadata)
155            }),
156            Expr::InSubquery(InSubquery {
157                expr,
158                subquery,
159                negated,
160            }) => expr.map_elements(f)?.update_data(|be| {
161                Expr::InSubquery(InSubquery::new(be, subquery, negated))
162            }),
163            Expr::BinaryExpr(BinaryExpr { left, op, right }) => (left, right)
164                .map_elements(f)?
165                .update_data(|(new_left, new_right)| {
166                    Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right))
167                }),
168            Expr::Like(Like {
169                negated,
170                expr,
171                pattern,
172                escape_char,
173                case_insensitive,
174            }) => {
175                (expr, pattern)
176                    .map_elements(f)?
177                    .update_data(|(new_expr, new_pattern)| {
178                        Expr::Like(Like::new(
179                            negated,
180                            new_expr,
181                            new_pattern,
182                            escape_char,
183                            case_insensitive,
184                        ))
185                    })
186            }
187            Expr::SimilarTo(Like {
188                negated,
189                expr,
190                pattern,
191                escape_char,
192                case_insensitive,
193            }) => {
194                (expr, pattern)
195                    .map_elements(f)?
196                    .update_data(|(new_expr, new_pattern)| {
197                        Expr::SimilarTo(Like::new(
198                            negated,
199                            new_expr,
200                            new_pattern,
201                            escape_char,
202                            case_insensitive,
203                        ))
204                    })
205            }
206            Expr::Not(expr) => expr.map_elements(f)?.update_data(Expr::Not),
207            Expr::IsNotNull(expr) => expr.map_elements(f)?.update_data(Expr::IsNotNull),
208            Expr::IsNull(expr) => expr.map_elements(f)?.update_data(Expr::IsNull),
209            Expr::IsTrue(expr) => expr.map_elements(f)?.update_data(Expr::IsTrue),
210            Expr::IsFalse(expr) => expr.map_elements(f)?.update_data(Expr::IsFalse),
211            Expr::IsUnknown(expr) => expr.map_elements(f)?.update_data(Expr::IsUnknown),
212            Expr::IsNotTrue(expr) => expr.map_elements(f)?.update_data(Expr::IsNotTrue),
213            Expr::IsNotFalse(expr) => expr.map_elements(f)?.update_data(Expr::IsNotFalse),
214            Expr::IsNotUnknown(expr) => {
215                expr.map_elements(f)?.update_data(Expr::IsNotUnknown)
216            }
217            Expr::Negative(expr) => expr.map_elements(f)?.update_data(Expr::Negative),
218            Expr::Between(Between {
219                expr,
220                negated,
221                low,
222                high,
223            }) => (expr, low, high).map_elements(f)?.update_data(
224                |(new_expr, new_low, new_high)| {
225                    Expr::Between(Between::new(new_expr, negated, new_low, new_high))
226                },
227            ),
228            Expr::Case(Case {
229                expr,
230                when_then_expr,
231                else_expr,
232            }) => (expr, when_then_expr, else_expr)
233                .map_elements(f)?
234                .update_data(|(new_expr, new_when_then_expr, new_else_expr)| {
235                    Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr))
236                }),
237            Expr::Cast(Cast { expr, data_type }) => expr
238                .map_elements(f)?
239                .update_data(|be| Expr::Cast(Cast::new(be, data_type))),
240            Expr::TryCast(TryCast { expr, data_type }) => expr
241                .map_elements(f)?
242                .update_data(|be| Expr::TryCast(TryCast::new(be, data_type))),
243            Expr::ScalarFunction(ScalarFunction { func, args }) => {
244                args.map_elements(f)?.map_data(|new_args| {
245                    Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
246                        func, new_args,
247                    )))
248                })?
249            }
250            Expr::WindowFunction(window_fun) => {
251                let WindowFunction {
252                    fun,
253                    params:
254                        WindowFunctionParams {
255                            args,
256                            partition_by,
257                            order_by,
258                            window_frame,
259                            filter,
260                            null_treatment,
261                            distinct,
262                        },
263                } = *window_fun;
264
265                (args, partition_by, order_by, filter)
266                    .map_elements(f)?
267                    .map_data(
268                        |(new_args, new_partition_by, new_order_by, new_filter)| {
269                            Ok(Expr::from(WindowFunction {
270                                fun,
271                                params: WindowFunctionParams {
272                                    args: new_args,
273                                    partition_by: new_partition_by,
274                                    order_by: new_order_by,
275                                    window_frame,
276                                    filter: new_filter,
277                                    null_treatment,
278                                    distinct,
279                                },
280                            }))
281                        },
282                    )?
283            }
284            Expr::AggregateFunction(AggregateFunction {
285                func,
286                params:
287                    AggregateFunctionParams {
288                        args,
289                        distinct,
290                        filter,
291                        order_by,
292                        null_treatment,
293                    },
294            }) => (args, filter, order_by).map_elements(f)?.map_data(
295                |(new_args, new_filter, new_order_by)| {
296                    Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
297                        func,
298                        new_args,
299                        distinct,
300                        new_filter,
301                        new_order_by,
302                        null_treatment,
303                    )))
304                },
305            )?,
306            Expr::GroupingSet(grouping_set) => match grouping_set {
307                GroupingSet::Rollup(exprs) => exprs
308                    .map_elements(f)?
309                    .update_data(|ve| Expr::GroupingSet(GroupingSet::Rollup(ve))),
310                GroupingSet::Cube(exprs) => exprs
311                    .map_elements(f)?
312                    .update_data(|ve| Expr::GroupingSet(GroupingSet::Cube(ve))),
313                GroupingSet::GroupingSets(lists_of_exprs) => lists_of_exprs
314                    .map_elements(f)?
315                    .update_data(|new_lists_of_exprs| {
316                        Expr::GroupingSet(GroupingSet::GroupingSets(new_lists_of_exprs))
317                    }),
318            },
319            Expr::InList(InList {
320                expr,
321                list,
322                negated,
323            }) => (expr, list)
324                .map_elements(f)?
325                .update_data(|(new_expr, new_list)| {
326                    Expr::InList(InList::new(new_expr, new_list, negated))
327                }),
328        })
329    }
330}