datafusion_expr/
tree_node.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Tree node implementation for Logical Expressions
19
20use crate::expr::{
21    AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Cast,
22    GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest,
23    WindowFunction, WindowFunctionParams,
24};
25use crate::Expr;
26
27use datafusion_common::tree_node::{
28    Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRefContainer,
29};
30use datafusion_common::Result;
31
32/// Implementation of the [`TreeNode`] trait
33///
34/// This allows logical expressions (`Expr`) to be traversed and transformed
35/// Facilitates tasks such as optimization and rewriting during query
36/// planning.
37impl TreeNode for Expr {
38    /// Applies a function `f` to each child expression of `self`.
39    ///
40    /// The function `f` determines whether to continue traversing the tree or to stop.
41    /// This method collects all child expressions and applies `f` to each.
42    fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
43        &'n self,
44        f: F,
45    ) -> Result<TreeNodeRecursion> {
46        match self {
47            Expr::Alias(Alias { expr, .. })
48            | Expr::Unnest(Unnest { expr })
49            | Expr::Not(expr)
50            | Expr::IsNotNull(expr)
51            | Expr::IsTrue(expr)
52            | Expr::IsFalse(expr)
53            | Expr::IsUnknown(expr)
54            | Expr::IsNotTrue(expr)
55            | Expr::IsNotFalse(expr)
56            | Expr::IsNotUnknown(expr)
57            | Expr::IsNull(expr)
58            | Expr::Negative(expr)
59            | Expr::Cast(Cast { expr, .. })
60            | Expr::TryCast(TryCast { expr, .. })
61            | Expr::InSubquery(InSubquery { expr, .. }) => expr.apply_elements(f),
62            Expr::GroupingSet(GroupingSet::Rollup(exprs))
63            | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.apply_elements(f),
64            Expr::ScalarFunction(ScalarFunction { args, .. }) => {
65                args.apply_elements(f)
66            }
67            Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
68                lists_of_exprs.apply_elements(f)
69            }
70            // TODO: remove the next line after `Expr::Wildcard` is removed
71            #[expect(deprecated)]
72            Expr::Column(_)
73            // Treat OuterReferenceColumn as a leaf expression
74            | Expr::OuterReferenceColumn(_, _)
75            | Expr::ScalarVariable(_, _)
76            | Expr::Literal(_, _)
77            | Expr::Exists { .. }
78            | Expr::ScalarSubquery(_)
79            | Expr::Wildcard { .. }
80            | Expr::Placeholder(_) => Ok(TreeNodeRecursion::Continue),
81            Expr::BinaryExpr(BinaryExpr { left, right, .. }) => {
82                (left, right).apply_ref_elements(f)
83            }
84            Expr::Like(Like { expr, pattern, .. })
85            | Expr::SimilarTo(Like { expr, pattern, .. }) => {
86                (expr, pattern).apply_ref_elements(f)
87            }
88            Expr::Between(Between {
89                              expr, low, high, ..
90                          }) => (expr, low, high).apply_ref_elements(f),
91            Expr::Case(Case { expr, when_then_expr, else_expr }) =>
92                (expr, when_then_expr, else_expr).apply_ref_elements(f),
93            Expr::AggregateFunction(AggregateFunction { params: AggregateFunctionParams { args, filter, order_by, ..}, .. }) =>
94                (args, filter, order_by).apply_ref_elements(f),
95            Expr::WindowFunction(window_fun) => {
96                let WindowFunctionParams {
97                    args,
98                    partition_by,
99                    order_by,
100                    filter,
101                    ..
102                } = &window_fun.as_ref().params;
103                (args, partition_by, order_by, filter).apply_ref_elements(f)
104            }
105
106            Expr::InList(InList { expr, list, .. }) => {
107                (expr, list).apply_ref_elements(f)
108            }
109        }
110    }
111
112    /// Maps each child of `self` using the provided closure `f`.
113    ///
114    /// The closure `f` takes ownership of an expression and returns a `Transformed` result,
115    /// indicating whether the expression was transformed or left unchanged.
116    fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
117        self,
118        mut f: F,
119    ) -> Result<Transformed<Self>> {
120        Ok(match self {
121            // TODO: remove the next line after `Expr::Wildcard` is removed
122            #[expect(deprecated)]
123            Expr::Column(_)
124            | Expr::Wildcard { .. }
125            | Expr::Placeholder(Placeholder { .. })
126            | Expr::OuterReferenceColumn(_, _)
127            | Expr::Exists { .. }
128            | Expr::ScalarSubquery(_)
129            | Expr::ScalarVariable(_, _)
130            | Expr::Literal(_, _) => Transformed::no(self),
131            Expr::Unnest(Unnest { expr, .. }) => expr
132                .map_elements(f)?
133                .update_data(|expr| Expr::Unnest(Unnest { expr })),
134            Expr::Alias(Alias {
135                expr,
136                relation,
137                name,
138                metadata,
139            }) => f(*expr)?.update_data(|e| {
140                e.alias_qualified_with_metadata(relation, name, metadata)
141            }),
142            Expr::InSubquery(InSubquery {
143                expr,
144                subquery,
145                negated,
146            }) => expr.map_elements(f)?.update_data(|be| {
147                Expr::InSubquery(InSubquery::new(be, subquery, negated))
148            }),
149            Expr::BinaryExpr(BinaryExpr { left, op, right }) => (left, right)
150                .map_elements(f)?
151                .update_data(|(new_left, new_right)| {
152                    Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right))
153                }),
154            Expr::Like(Like {
155                negated,
156                expr,
157                pattern,
158                escape_char,
159                case_insensitive,
160            }) => {
161                (expr, pattern)
162                    .map_elements(f)?
163                    .update_data(|(new_expr, new_pattern)| {
164                        Expr::Like(Like::new(
165                            negated,
166                            new_expr,
167                            new_pattern,
168                            escape_char,
169                            case_insensitive,
170                        ))
171                    })
172            }
173            Expr::SimilarTo(Like {
174                negated,
175                expr,
176                pattern,
177                escape_char,
178                case_insensitive,
179            }) => {
180                (expr, pattern)
181                    .map_elements(f)?
182                    .update_data(|(new_expr, new_pattern)| {
183                        Expr::SimilarTo(Like::new(
184                            negated,
185                            new_expr,
186                            new_pattern,
187                            escape_char,
188                            case_insensitive,
189                        ))
190                    })
191            }
192            Expr::Not(expr) => expr.map_elements(f)?.update_data(Expr::Not),
193            Expr::IsNotNull(expr) => expr.map_elements(f)?.update_data(Expr::IsNotNull),
194            Expr::IsNull(expr) => expr.map_elements(f)?.update_data(Expr::IsNull),
195            Expr::IsTrue(expr) => expr.map_elements(f)?.update_data(Expr::IsTrue),
196            Expr::IsFalse(expr) => expr.map_elements(f)?.update_data(Expr::IsFalse),
197            Expr::IsUnknown(expr) => expr.map_elements(f)?.update_data(Expr::IsUnknown),
198            Expr::IsNotTrue(expr) => expr.map_elements(f)?.update_data(Expr::IsNotTrue),
199            Expr::IsNotFalse(expr) => expr.map_elements(f)?.update_data(Expr::IsNotFalse),
200            Expr::IsNotUnknown(expr) => {
201                expr.map_elements(f)?.update_data(Expr::IsNotUnknown)
202            }
203            Expr::Negative(expr) => expr.map_elements(f)?.update_data(Expr::Negative),
204            Expr::Between(Between {
205                expr,
206                negated,
207                low,
208                high,
209            }) => (expr, low, high).map_elements(f)?.update_data(
210                |(new_expr, new_low, new_high)| {
211                    Expr::Between(Between::new(new_expr, negated, new_low, new_high))
212                },
213            ),
214            Expr::Case(Case {
215                expr,
216                when_then_expr,
217                else_expr,
218            }) => (expr, when_then_expr, else_expr)
219                .map_elements(f)?
220                .update_data(|(new_expr, new_when_then_expr, new_else_expr)| {
221                    Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr))
222                }),
223            Expr::Cast(Cast { expr, data_type }) => expr
224                .map_elements(f)?
225                .update_data(|be| Expr::Cast(Cast::new(be, data_type))),
226            Expr::TryCast(TryCast { expr, data_type }) => expr
227                .map_elements(f)?
228                .update_data(|be| Expr::TryCast(TryCast::new(be, data_type))),
229            Expr::ScalarFunction(ScalarFunction { func, args }) => {
230                args.map_elements(f)?.map_data(|new_args| {
231                    Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
232                        func, new_args,
233                    )))
234                })?
235            }
236            Expr::WindowFunction(window_fun) => {
237                let WindowFunction {
238                    fun,
239                    params:
240                        WindowFunctionParams {
241                            args,
242                            partition_by,
243                            order_by,
244                            window_frame,
245                            filter,
246                            null_treatment,
247                            distinct,
248                        },
249                } = *window_fun;
250
251                (args, partition_by, order_by, filter)
252                    .map_elements(f)?
253                    .map_data(
254                        |(new_args, new_partition_by, new_order_by, new_filter)| {
255                            Ok(Expr::from(WindowFunction {
256                                fun,
257                                params: WindowFunctionParams {
258                                    args: new_args,
259                                    partition_by: new_partition_by,
260                                    order_by: new_order_by,
261                                    window_frame,
262                                    filter: new_filter,
263                                    null_treatment,
264                                    distinct,
265                                },
266                            }))
267                        },
268                    )?
269            }
270            Expr::AggregateFunction(AggregateFunction {
271                func,
272                params:
273                    AggregateFunctionParams {
274                        args,
275                        distinct,
276                        filter,
277                        order_by,
278                        null_treatment,
279                    },
280            }) => (args, filter, order_by).map_elements(f)?.map_data(
281                |(new_args, new_filter, new_order_by)| {
282                    Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
283                        func,
284                        new_args,
285                        distinct,
286                        new_filter,
287                        new_order_by,
288                        null_treatment,
289                    )))
290                },
291            )?,
292            Expr::GroupingSet(grouping_set) => match grouping_set {
293                GroupingSet::Rollup(exprs) => exprs
294                    .map_elements(f)?
295                    .update_data(|ve| Expr::GroupingSet(GroupingSet::Rollup(ve))),
296                GroupingSet::Cube(exprs) => exprs
297                    .map_elements(f)?
298                    .update_data(|ve| Expr::GroupingSet(GroupingSet::Cube(ve))),
299                GroupingSet::GroupingSets(lists_of_exprs) => lists_of_exprs
300                    .map_elements(f)?
301                    .update_data(|new_lists_of_exprs| {
302                        Expr::GroupingSet(GroupingSet::GroupingSets(new_lists_of_exprs))
303                    }),
304            },
305            Expr::InList(InList {
306                expr,
307                list,
308                negated,
309            }) => (expr, list)
310                .map_elements(f)?
311                .update_data(|(new_expr, new_list)| {
312                    Expr::InList(InList::new(new_expr, new_list, negated))
313                }),
314        })
315    }
316}