datafusion_expr/
tree_node.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Tree node implementation for Logical Expressions
19
20use crate::expr::{
21    AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Cast,
22    GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest,
23    WindowFunction, WindowFunctionParams,
24};
25use crate::{Expr, ExprFunctionExt};
26
27use datafusion_common::tree_node::{
28    Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRefContainer,
29};
30use datafusion_common::Result;
31
32/// Implementation of the [`TreeNode`] trait
33///
34/// This allows logical expressions (`Expr`) to be traversed and transformed
35/// Facilitates tasks such as optimization and rewriting during query
36/// planning.
37impl TreeNode for Expr {
38    /// Applies a function `f` to each child expression of `self`.
39    ///
40    /// The function `f` determines whether to continue traversing the tree or to stop.
41    /// This method collects all child expressions and applies `f` to each.
42    fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
43        &'n self,
44        f: F,
45    ) -> Result<TreeNodeRecursion> {
46        match self {
47            Expr::Alias(Alias { expr, .. })
48            | Expr::Unnest(Unnest { expr })
49            | Expr::Not(expr)
50            | Expr::IsNotNull(expr)
51            | Expr::IsTrue(expr)
52            | Expr::IsFalse(expr)
53            | Expr::IsUnknown(expr)
54            | Expr::IsNotTrue(expr)
55            | Expr::IsNotFalse(expr)
56            | Expr::IsNotUnknown(expr)
57            | Expr::IsNull(expr)
58            | Expr::Negative(expr)
59            | Expr::Cast(Cast { expr, .. })
60            | Expr::TryCast(TryCast { expr, .. })
61            | Expr::InSubquery(InSubquery { expr, .. }) => expr.apply_elements(f),
62            Expr::GroupingSet(GroupingSet::Rollup(exprs))
63            | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.apply_elements(f),
64            Expr::ScalarFunction(ScalarFunction { args, .. }) => {
65                args.apply_elements(f)
66            }
67            Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
68                lists_of_exprs.apply_elements(f)
69            }
70            // TODO: remove the next line after `Expr::Wildcard` is removed
71            #[expect(deprecated)]
72            Expr::Column(_)
73            // Treat OuterReferenceColumn as a leaf expression
74            | Expr::OuterReferenceColumn(_, _)
75            | Expr::ScalarVariable(_, _)
76            | Expr::Literal(_, _)
77            | Expr::Exists { .. }
78            | Expr::ScalarSubquery(_)
79            | Expr::Wildcard { .. }
80            | Expr::Placeholder(_) => Ok(TreeNodeRecursion::Continue),
81            Expr::BinaryExpr(BinaryExpr { left, right, .. }) => {
82                (left, right).apply_ref_elements(f)
83            }
84            Expr::Like(Like { expr, pattern, .. })
85            | Expr::SimilarTo(Like { expr, pattern, .. }) => {
86                (expr, pattern).apply_ref_elements(f)
87            }
88            Expr::Between(Between {
89                              expr, low, high, ..
90                          }) => (expr, low, high).apply_ref_elements(f),
91            Expr::Case(Case { expr, when_then_expr, else_expr }) =>
92                (expr, when_then_expr, else_expr).apply_ref_elements(f),
93            Expr::AggregateFunction(AggregateFunction { params: AggregateFunctionParams { args, filter, order_by, ..}, .. }) =>
94                (args, filter, order_by).apply_ref_elements(f),
95            Expr::WindowFunction(window_fun) => {
96                let WindowFunctionParams {
97                    args,
98                    partition_by,
99                    order_by,
100                    ..
101                } = &window_fun.as_ref().params;
102                (args, partition_by, order_by).apply_ref_elements(f)
103            }
104
105            Expr::InList(InList { expr, list, .. }) => {
106                (expr, list).apply_ref_elements(f)
107            }
108        }
109    }
110
111    /// Maps each child of `self` using the provided closure `f`.
112    ///
113    /// The closure `f` takes ownership of an expression and returns a `Transformed` result,
114    /// indicating whether the expression was transformed or left unchanged.
115    fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
116        self,
117        mut f: F,
118    ) -> Result<Transformed<Self>> {
119        Ok(match self {
120            // TODO: remove the next line after `Expr::Wildcard` is removed
121            #[expect(deprecated)]
122            Expr::Column(_)
123            | Expr::Wildcard { .. }
124            | Expr::Placeholder(Placeholder { .. })
125            | Expr::OuterReferenceColumn(_, _)
126            | Expr::Exists { .. }
127            | Expr::ScalarSubquery(_)
128            | Expr::ScalarVariable(_, _)
129            | Expr::Literal(_, _) => Transformed::no(self),
130            Expr::Unnest(Unnest { expr, .. }) => expr
131                .map_elements(f)?
132                .update_data(|expr| Expr::Unnest(Unnest { expr })),
133            Expr::Alias(Alias {
134                expr,
135                relation,
136                name,
137                metadata,
138            }) => f(*expr)?.update_data(|e| {
139                e.alias_qualified_with_metadata(relation, name, metadata)
140            }),
141            Expr::InSubquery(InSubquery {
142                expr,
143                subquery,
144                negated,
145            }) => expr.map_elements(f)?.update_data(|be| {
146                Expr::InSubquery(InSubquery::new(be, subquery, negated))
147            }),
148            Expr::BinaryExpr(BinaryExpr { left, op, right }) => (left, right)
149                .map_elements(f)?
150                .update_data(|(new_left, new_right)| {
151                    Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right))
152                }),
153            Expr::Like(Like {
154                negated,
155                expr,
156                pattern,
157                escape_char,
158                case_insensitive,
159            }) => {
160                (expr, pattern)
161                    .map_elements(f)?
162                    .update_data(|(new_expr, new_pattern)| {
163                        Expr::Like(Like::new(
164                            negated,
165                            new_expr,
166                            new_pattern,
167                            escape_char,
168                            case_insensitive,
169                        ))
170                    })
171            }
172            Expr::SimilarTo(Like {
173                negated,
174                expr,
175                pattern,
176                escape_char,
177                case_insensitive,
178            }) => {
179                (expr, pattern)
180                    .map_elements(f)?
181                    .update_data(|(new_expr, new_pattern)| {
182                        Expr::SimilarTo(Like::new(
183                            negated,
184                            new_expr,
185                            new_pattern,
186                            escape_char,
187                            case_insensitive,
188                        ))
189                    })
190            }
191            Expr::Not(expr) => expr.map_elements(f)?.update_data(Expr::Not),
192            Expr::IsNotNull(expr) => expr.map_elements(f)?.update_data(Expr::IsNotNull),
193            Expr::IsNull(expr) => expr.map_elements(f)?.update_data(Expr::IsNull),
194            Expr::IsTrue(expr) => expr.map_elements(f)?.update_data(Expr::IsTrue),
195            Expr::IsFalse(expr) => expr.map_elements(f)?.update_data(Expr::IsFalse),
196            Expr::IsUnknown(expr) => expr.map_elements(f)?.update_data(Expr::IsUnknown),
197            Expr::IsNotTrue(expr) => expr.map_elements(f)?.update_data(Expr::IsNotTrue),
198            Expr::IsNotFalse(expr) => expr.map_elements(f)?.update_data(Expr::IsNotFalse),
199            Expr::IsNotUnknown(expr) => {
200                expr.map_elements(f)?.update_data(Expr::IsNotUnknown)
201            }
202            Expr::Negative(expr) => expr.map_elements(f)?.update_data(Expr::Negative),
203            Expr::Between(Between {
204                expr,
205                negated,
206                low,
207                high,
208            }) => (expr, low, high).map_elements(f)?.update_data(
209                |(new_expr, new_low, new_high)| {
210                    Expr::Between(Between::new(new_expr, negated, new_low, new_high))
211                },
212            ),
213            Expr::Case(Case {
214                expr,
215                when_then_expr,
216                else_expr,
217            }) => (expr, when_then_expr, else_expr)
218                .map_elements(f)?
219                .update_data(|(new_expr, new_when_then_expr, new_else_expr)| {
220                    Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr))
221                }),
222            Expr::Cast(Cast { expr, data_type }) => expr
223                .map_elements(f)?
224                .update_data(|be| Expr::Cast(Cast::new(be, data_type))),
225            Expr::TryCast(TryCast { expr, data_type }) => expr
226                .map_elements(f)?
227                .update_data(|be| Expr::TryCast(TryCast::new(be, data_type))),
228            Expr::ScalarFunction(ScalarFunction { func, args }) => {
229                args.map_elements(f)?.map_data(|new_args| {
230                    Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
231                        func, new_args,
232                    )))
233                })?
234            }
235            Expr::WindowFunction(window_fun) => {
236                let WindowFunction {
237                    fun,
238                    params:
239                        WindowFunctionParams {
240                            args,
241                            partition_by,
242                            order_by,
243                            window_frame,
244                            null_treatment,
245                        },
246                } = *window_fun;
247                (args, partition_by, order_by).map_elements(f)?.update_data(
248                    |(new_args, new_partition_by, new_order_by)| {
249                        Expr::from(WindowFunction::new(fun, new_args))
250                            .partition_by(new_partition_by)
251                            .order_by(new_order_by)
252                            .window_frame(window_frame)
253                            .null_treatment(null_treatment)
254                            .build()
255                            .unwrap()
256                    },
257                )
258            }
259            Expr::AggregateFunction(AggregateFunction {
260                func,
261                params:
262                    AggregateFunctionParams {
263                        args,
264                        distinct,
265                        filter,
266                        order_by,
267                        null_treatment,
268                    },
269            }) => (args, filter, order_by).map_elements(f)?.map_data(
270                |(new_args, new_filter, new_order_by)| {
271                    Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
272                        func,
273                        new_args,
274                        distinct,
275                        new_filter,
276                        new_order_by,
277                        null_treatment,
278                    )))
279                },
280            )?,
281            Expr::GroupingSet(grouping_set) => match grouping_set {
282                GroupingSet::Rollup(exprs) => exprs
283                    .map_elements(f)?
284                    .update_data(|ve| Expr::GroupingSet(GroupingSet::Rollup(ve))),
285                GroupingSet::Cube(exprs) => exprs
286                    .map_elements(f)?
287                    .update_data(|ve| Expr::GroupingSet(GroupingSet::Cube(ve))),
288                GroupingSet::GroupingSets(lists_of_exprs) => lists_of_exprs
289                    .map_elements(f)?
290                    .update_data(|new_lists_of_exprs| {
291                        Expr::GroupingSet(GroupingSet::GroupingSets(new_lists_of_exprs))
292                    }),
293            },
294            Expr::InList(InList {
295                expr,
296                list,
297                negated,
298            }) => (expr, list)
299                .map_elements(f)?
300                .update_data(|(new_expr, new_list)| {
301                    Expr::InList(InList::new(new_expr, new_list, negated))
302                }),
303        })
304    }
305}