Skip to main content

datafusion_optimizer/
common_subexpr_eliminate.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`CommonSubexprEliminate`] to avoid redundant computation of common sub-expressions
19
20use std::collections::BTreeSet;
21use std::fmt::Debug;
22use std::sync::Arc;
23
24use crate::{OptimizerConfig, OptimizerRule};
25
26use crate::optimizer::ApplyOrder;
27use crate::utils::NamePreserver;
28use datafusion_common::alias::AliasGenerator;
29
30use datafusion_common::cse::{CSE, CSEController, FoundCommonNodes};
31use datafusion_common::tree_node::{Transformed, TreeNode};
32use datafusion_common::{Column, DFSchema, DFSchemaRef, Result, qualified_name};
33use datafusion_expr::expr::{Alias, HigherOrderFunction, ScalarFunction};
34use datafusion_expr::logical_plan::{
35    Aggregate, Filter, LogicalPlan, Projection, Sort, Window,
36};
37use datafusion_expr::{
38    BinaryExpr, Case, Expr, ExpressionPlacement, Operator, SortExpr, col,
39};
40
/// Prefix handed to the [`AliasGenerator`] when naming extracted common
/// sub-expressions (and, in the aggregate rewrite, the intermediate outputs of
/// aggregate expressions that are re-projected on top of the aggregate).
const CSE_PREFIX: &str = "__common_expr";
42
/// Performs Common Sub-expression Elimination optimization.
///
/// This optimization improves query performance by computing expressions that
/// appear more than once and reusing those results rather than re-computing the
/// same value.
///
/// Currently only common sub-expressions within a single `LogicalPlan` node are
/// eliminated. The one exception is `Window`: a chain of directly nested
/// `Window` nodes is treated as a group. The rule rewrites `Projection`, `Sort`,
/// `Filter`, `Window` and `Aggregate` nodes. All other nodes are only recursed
/// into.
///
/// # Example
///
/// Given a projection that computes the same expensive expression
/// multiple times, such as parsing a string as a date with `to_date` twice:
///
/// ```text
/// ProjectionExec(exprs=[extract (day from to_date(c1)), extract (year from to_date(c1))])
/// ```
///
/// This optimization will rewrite the plan to compute the common expression once
/// using a new `ProjectionExec` and then rewrite the original expressions to
/// refer to that new column.
///
/// ```text
/// ProjectionExec(exprs=[extract (day from new_col), extract (year from new_col)]) <-- reuse here
///   ProjectionExec(exprs=[to_date(c1) as new_col]) <-- compute to_date once
/// ```
#[derive(Debug)]
pub struct CommonSubexprEliminate {}
71
impl CommonSubexprEliminate {
    /// Creates the rule. It holds no state; everything it needs at optimization
    /// time comes from the [`OptimizerConfig`] passed to each call.
    pub fn new() -> Self {
        Self {}
    }

    /// Eliminates common sub-expressions among the expressions of a
    /// [`Projection`].
    ///
    /// The projection is rebuilt over the (possibly new) input with
    /// `Projection::try_new_with_schema`, reusing the node's original `schema`.
    fn try_optimize_proj(
        &self,
        projection: Projection,
        config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        let Projection {
            expr,
            input,
            schema,
            ..
        } = projection;
        let input = Arc::unwrap_or_clone(input);
        self.try_unary_plan(expr, input, config)?
            .map_data(|(new_expr, new_input)| {
                Projection::try_new_with_schema(new_expr, Arc::new(new_input), schema)
                    .map(LogicalPlan::Projection)
            })
    }

    /// Eliminates common sub-expressions among the sort keys of a [`Sort`].
    ///
    /// Only the key expressions go through CSE. Each key's `(asc, nulls_first)`
    /// options are split off first and zipped back afterwards in the original
    /// order. `fetch` is carried over unchanged.
    fn try_optimize_sort(
        &self,
        sort: Sort,
        config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        let Sort { expr, input, fetch } = sort;
        let input = Arc::unwrap_or_clone(input);
        // Separate the expressions (CSE candidates) from their ordering options.
        let (sort_expressions, sort_params): (Vec<_>, Vec<(_, _)>) = expr
            .into_iter()
            .map(|sort| (sort.expr, (sort.asc, sort.nulls_first)))
            .unzip();
        let new_sort = self
            .try_unary_plan(sort_expressions, input, config)?
            .update_data(|(new_expr, new_input)| {
                LogicalPlan::Sort(Sort {
                    expr: new_expr
                        .into_iter()
                        .zip(sort_params)
                        .map(|(expr, (asc, nulls_first))| SortExpr {
                            expr,
                            asc,
                            nulls_first,
                        })
                        .collect(),
                    input: Arc::new(new_input),
                    fetch,
                })
            });
        Ok(new_sort)
    }

    /// Eliminates common sub-expressions inside a [`Filter`] predicate.
    ///
    /// The predicate is passed to CSE as a single-element expression list.
    /// NOTE(review): fields matched by `..` below are not carried over, because
    /// the node is rebuilt with `Filter::try_new`. Confirm this is intended for
    /// every `Filter` field.
    fn try_optimize_filter(
        &self,
        filter: Filter,
        config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        let Filter {
            predicate, input, ..
        } = filter;
        let input = Arc::unwrap_or_clone(input);
        let expr = vec![predicate];
        self.try_unary_plan(expr, input, config)?
            .map_data(|(mut new_expr, new_input)| {
                assert_eq!(new_expr.len(), 1); // passed in vec![predicate]
                let new_predicate = new_expr.pop().unwrap();
                Filter::try_new(new_predicate, Arc::new(new_input))
                    .map(LogicalPlan::Filter)
            })
    }

    /// Eliminates common sub-expressions across a chain of consecutive
    /// [`Window`] nodes.
    ///
    /// All directly nested `Window` nodes are collected (see
    /// [`get_consecutive_window_exprs`]) and their expression lists are handled
    /// together. An expression repeated in different windows of the chain is
    /// therefore computed only once, in a projection inserted below the
    /// innermost window. The new input is then rewritten recursively, and the
    /// window chain is rebuilt in its original nesting order.
    fn try_optimize_window(
        &self,
        window: Window,
        config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        // Collects window expressions from consecutive `LogicalPlan::Window` nodes into
        // a list.
        let (window_expr_list, window_schemas, input) =
            get_consecutive_window_exprs(window);

        // Extract common sub-expressions from the list.

        match CSE::new(ExprCSEController::new(
            config.alias_generator().as_ref(),
            ExprMask::Normal,
        ))
        .extract_common_nodes(window_expr_list)?
        {
            // If there are common sub-expressions, then insert a projection node
            // with the common expressions between the new window nodes and the
            // original input.
            FoundCommonNodes::Yes {
                common_nodes: common_exprs,
                new_nodes_list: new_exprs_list,
                original_nodes_list: original_exprs_list,
            } => build_common_expr_project_plan(input, common_exprs).map(|new_input| {
                Transformed::yes((new_exprs_list, new_input, Some(original_exprs_list)))
            }),
            FoundCommonNodes::No {
                original_nodes_list: original_exprs_list,
            } => Ok(Transformed::no((original_exprs_list, input, None))),
        }?
        // Recurse into the new input.
        // (This is similar to what an `ApplyOrder::TopDown` optimizer rule would do.)
        .transform_data(|(new_window_expr_list, new_input, window_expr_list)| {
            self.rewrite(new_input, config)?.map_data(|new_input| {
                Ok((new_window_expr_list, new_input, window_expr_list))
            })
        })?
        // Rebuild the consecutive window nodes. `try_rfold` starts from the last
        // (innermost) expression list, so the original nesting is preserved.
        .map_data(|(new_window_expr_list, new_input, window_expr_list)| {
            // If there were common expressions extracted, then we need to make sure
            // we restore the original column names.
            // TODO: Although the CSE rewrite (`ExprCSEController::rewrite`) inserts
            //  aliases around extracted common expressions, this doesn't mean that
            //  the original column names (schema) are preserved, because the
            //  inserted aliases are not always at the top of the expression.
            //  Let's consider improving the rewrite to always keep column
            //  names and get rid of additional name preserving logic here.
            if let Some(window_expr_list) = window_expr_list {
                let name_preserver = NamePreserver::new_for_projection();
                let saved_names = window_expr_list
                    .iter()
                    .map(|exprs| {
                        exprs
                            .iter()
                            .map(|expr| name_preserver.save(expr))
                            .collect::<Vec<_>>()
                    })
                    .collect::<Vec<_>>();
                new_window_expr_list.into_iter().zip(saved_names).try_rfold(
                    new_input,
                    |plan, (new_window_expr, saved_names)| {
                        let new_window_expr = new_window_expr
                            .into_iter()
                            .zip(saved_names)
                            .map(|(new_window_expr, saved_name)| {
                                saved_name.restore(new_window_expr)
                            })
                            .collect::<Vec<_>>();
                        // Expressions changed, so let `try_new` recompute the schema.
                        Window::try_new(new_window_expr, Arc::new(plan))
                            .map(LogicalPlan::Window)
                    },
                )
            } else {
                // Nothing was extracted: reuse each window's original schema.
                new_window_expr_list
                    .into_iter()
                    .zip(window_schemas)
                    .try_rfold(new_input, |plan, (new_window_expr, schema)| {
                        Window::try_new_with_schema(
                            new_window_expr,
                            Arc::new(plan),
                            schema,
                        )
                        .map(LogicalPlan::Window)
                    })
            }
        })
    }

    /// Eliminates common sub-expressions in an [`Aggregate`] node.
    ///
    /// This runs in two passes:
    ///
    /// 1. Group-by and aggregate expressions are searched together with
    ///    [`ExprMask::Normal`], which does not consider aggregate function calls
    ///    as candidates. Common sub-expressions found here are computed by a new
    ///    projection below the aggregate. The new input is then rewritten
    ///    recursively.
    /// 2. The (possibly rewritten) aggregate expressions are searched again with
    ///    [`ExprMask::NormalAndAggregates`], so repeated aggregate calls can be
    ///    extracted. If any are found, the aggregate node computes each common
    ///    aggregate once (aliased), and a [`Projection`] on top rebuilds the
    ///    original output expressions from those columns.
    fn try_optimize_aggregate(
        &self,
        aggregate: Aggregate,
        config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        let Aggregate {
            group_expr,
            aggr_expr,
            input,
            schema,
            ..
        } = aggregate;
        let input = Arc::unwrap_or_clone(input);
        // Extract common sub-expressions from the aggregate and grouping expressions.
        match CSE::new(ExprCSEController::new(
            config.alias_generator().as_ref(),
            ExprMask::Normal,
        ))
        .extract_common_nodes(vec![group_expr, aggr_expr])?
        {
            // If there are common sub-expressions, then insert a projection node
            // with the common expressions between the new aggregate node and the
            // original input.
            FoundCommonNodes::Yes {
                common_nodes: common_exprs,
                new_nodes_list: mut new_exprs_list,
                original_nodes_list: mut original_exprs_list,
            } => {
                // The code relies on the lists coming back in the order they were
                // passed in (`[group_expr, aggr_expr]`), so the first `pop()` yields
                // the aggregate expressions.
                let new_aggr_expr = new_exprs_list.pop().unwrap();
                let new_group_expr = new_exprs_list.pop().unwrap();

                build_common_expr_project_plan(input, common_exprs).map(|new_input| {
                    // Keep the original aggregate expressions so their output names
                    // can be restored after the rewrite.
                    let aggr_expr = original_exprs_list.pop().unwrap();
                    Transformed::yes((
                        new_aggr_expr,
                        new_group_expr,
                        new_input,
                        Some(aggr_expr),
                    ))
                })
            }

            FoundCommonNodes::No {
                original_nodes_list: mut original_exprs_list,
            } => {
                let new_aggr_expr = original_exprs_list.pop().unwrap();
                let new_group_expr = original_exprs_list.pop().unwrap();

                Ok(Transformed::no((
                    new_aggr_expr,
                    new_group_expr,
                    input,
                    None,
                )))
            }
        }?
        // Recurse into the new input.
        // (This is similar to what an `ApplyOrder::TopDown` optimizer rule would do.)
        .transform_data(|(new_aggr_expr, new_group_expr, new_input, aggr_expr)| {
            self.rewrite(new_input, config)?.map_data(|new_input| {
                Ok((
                    new_aggr_expr,
                    new_group_expr,
                    aggr_expr,
                    Arc::new(new_input),
                ))
            })
        })?
        // Try extracting common aggregate expressions and rebuild the aggregate node.
        .transform_data(
            |(new_aggr_expr, new_group_expr, aggr_expr, new_input)| {
                // Extract common aggregate sub-expressions from the aggregate expressions.
                match CSE::new(ExprCSEController::new(
                    config.alias_generator().as_ref(),
                    ExprMask::NormalAndAggregates,
                ))
                .extract_common_nodes(vec![new_aggr_expr])?
                {
                    FoundCommonNodes::Yes {
                        common_nodes: common_exprs,
                        new_nodes_list: mut new_exprs_list,
                        original_nodes_list: mut original_exprs_list,
                    } => {
                        let rewritten_aggr_expr = new_exprs_list.pop().unwrap();
                        let new_aggr_expr = original_exprs_list.pop().unwrap();
                        // One optional saved name per aggregate expression. Names
                        // exist only if the first pass changed the expressions.
                        let saved_names = if let Some(aggr_expr) = aggr_expr {
                            let name_preserver = NamePreserver::new_for_projection();
                            aggr_expr
                                .iter()
                                .map(|expr| Some(name_preserver.save(expr)))
                                .collect::<Vec<_>>()
                        } else {
                            (0..new_aggr_expr.len()).map(|_| None).collect()
                        };

                        // The aggregate node computes each common aggregate once,
                        // under its generated alias.
                        let mut agg_exprs = common_exprs
                            .into_iter()
                            .map(|(expr, expr_alias)| expr.alias(expr_alias))
                            .collect::<Vec<_>>();

                        // The projection on top starts with the group-by outputs.
                        // (`extract_expressions` is defined outside this view;
                        // presumably it adds column references for them.)
                        let mut proj_exprs = vec![];
                        for expr in &new_group_expr {
                            extract_expressions(expr, &mut proj_exprs)
                        }
                        for ((expr_rewritten, expr_orig), saved_name) in
                            rewritten_aggr_expr
                                .into_iter()
                                .zip(new_aggr_expr)
                                .zip(saved_names)
                        {
                            if expr_rewritten == expr_orig {
                                // Untouched by the second pass: evaluate it inside the
                                // aggregate and have the projection reference its
                                // output column under the original name.
                                let expr_rewritten = if let Some(saved_name) = saved_name
                                {
                                    saved_name.restore(expr_rewritten)
                                } else {
                                    expr_rewritten
                                };
                                if let Expr::Alias(Alias { expr, name, .. }) =
                                    expr_rewritten
                                {
                                    agg_exprs.push(expr.alias(&name));
                                    proj_exprs
                                        .push(Expr::Column(Column::from_name(name)));
                                } else {
                                    let expr_alias =
                                        config.alias_generator().next(CSE_PREFIX);
                                    let (qualifier, field_name) =
                                        expr_rewritten.qualified_name();
                                    let out_name =
                                        qualified_name(qualifier.as_ref(), &field_name);

                                    agg_exprs.push(expr_rewritten.alias(&expr_alias));
                                    proj_exprs.push(
                                        Expr::Column(Column::from_name(expr_alias))
                                            .alias(out_name),
                                    );
                                }
                            } else {
                                // Rewritten in terms of common aggregate columns: it is
                                // evaluated by the projection above the aggregate.
                                proj_exprs.push(expr_rewritten);
                            }
                        }

                        let agg = LogicalPlan::Aggregate(Aggregate::try_new(
                            new_input,
                            new_group_expr,
                            agg_exprs,
                        )?);
                        Projection::try_new(proj_exprs, Arc::new(agg))
                            .map(|p| Transformed::yes(LogicalPlan::Projection(p)))
                    }

                    // If there aren't any common aggregate sub-expressions, then just
                    // rebuild the aggregate node.
                    FoundCommonNodes::No {
                        original_nodes_list: mut original_exprs_list,
                    } => {
                        let rewritten_aggr_expr = original_exprs_list.pop().unwrap();

                        // If there were common expressions extracted, then we need to
                        // make sure we restore the original column names.
                        // TODO: Although the CSE rewrite (`ExprCSEController::rewrite`)
                        //  inserts aliases around extracted common expressions, this
                        //  doesn't mean that the original column names (schema) are
                        //  preserved, because the inserted aliases are not always at
                        //  the top of the expression.
                        //  Let's consider improving the rewrite to always keep column
                        //  names and get rid of additional name preserving logic here.
                        if let Some(aggr_expr) = aggr_expr {
                            let name_preserver = NamePreserver::new_for_projection();
                            let saved_names = aggr_expr
                                .iter()
                                .map(|expr| name_preserver.save(expr))
                                .collect::<Vec<_>>();
                            let new_aggr_expr = rewritten_aggr_expr
                                .into_iter()
                                .zip(saved_names)
                                .map(|(new_expr, saved_name)| {
                                    saved_name.restore(new_expr)
                                })
                                .collect::<Vec<Expr>>();

                            // Since `group_expr` may have changed, schema may also.
                            // Use `try_new()` method.
                            Aggregate::try_new(new_input, new_group_expr, new_aggr_expr)
                                .map(LogicalPlan::Aggregate)
                                .map(Transformed::no)
                        } else {
                            Aggregate::try_new_with_schema(
                                new_input,
                                new_group_expr,
                                rewritten_aggr_expr,
                                schema,
                            )
                            .map(LogicalPlan::Aggregate)
                            .map(Transformed::no)
                        }
                    }
                }
            },
        )
    }

    /// Rewrites the expr list and input to remove common subexpressions
    ///
    /// # Parameters
    ///
    /// * `exprs`: List of expressions in the node
    /// * `input`: input plan (that produces the columns referred to in `exprs`)
    ///
    /// # Return value
    ///
    ///  Returns `(rewritten_exprs, new_input)`. `new_input` is either:
    ///
    /// 1. The original `input` if no common subexpressions were extracted
    /// 2. A newly added projection on top of the original input
    ///    that computes the common subexpressions
    ///
    /// In both cases `new_input` has already been passed recursively through
    /// this rule.
    fn try_unary_plan(
        &self,
        exprs: Vec<Expr>,
        input: LogicalPlan,
        config: &dyn OptimizerConfig,
    ) -> Result<Transformed<(Vec<Expr>, LogicalPlan)>> {
        // Extract common sub-expressions from the expressions.
        match CSE::new(ExprCSEController::new(
            config.alias_generator().as_ref(),
            ExprMask::Normal,
        ))
        .extract_common_nodes(vec![exprs])?
        {
            FoundCommonNodes::Yes {
                common_nodes: common_exprs,
                new_nodes_list: mut new_exprs_list,
                original_nodes_list: _,
            } => {
                let new_exprs = new_exprs_list.pop().unwrap();
                build_common_expr_project_plan(input, common_exprs)
                    .map(|new_input| Transformed::yes((new_exprs, new_input)))
            }
            FoundCommonNodes::No {
                original_nodes_list: mut original_exprs_list,
            } => {
                let new_exprs = original_exprs_list.pop().unwrap();
                Ok(Transformed::no((new_exprs, input)))
            }
        }?
        // Recurse into the new input.
        // (This is similar to what an `ApplyOrder::TopDown` optimizer rule would do.)
        .transform_data(|(new_exprs, new_input)| {
            self.rewrite(new_input, config)?
                .map_data(|new_input| Ok((new_exprs, new_input)))
        })
    }
}
490
491/// Get all window expressions inside the consecutive window operators.
492///
493/// Returns the window expressions, and the input to the deepest child
494/// LogicalPlan.
495///
496/// For example, if the input window looks like
497///
498/// ```text
499///   LogicalPlan::Window(exprs=[a, b, c])
500///     LogicalPlan::Window(exprs=[d])
501///       InputPlan
502/// ```
503///
504/// Returns:
505/// *  `window_exprs`: `[[a, b, c], [d]]`
506/// * InputPlan
507///
508/// Consecutive window expressions may refer to same complex expression.
509///
510/// If same complex expression is referred more than once by subsequent
511/// `WindowAggr`s, we can cache complex expression by evaluating it with a
512/// projection before the first WindowAggr.
513///
514/// This enables us to cache complex expression "c3+c4" for following plan:
515///
516/// ```text
517/// WindowAggr: windowExpr=[[sum(c9) ORDER BY [c3 + c4] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
518/// --WindowAggr: windowExpr=[[sum(c9) ORDER BY [c3 + c4] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
519/// ```
520///
521/// where, it is referred once by each `WindowAggr` (total of 2) in the plan.
522fn get_consecutive_window_exprs(
523    window: Window,
524) -> (Vec<Vec<Expr>>, Vec<DFSchemaRef>, LogicalPlan) {
525    let mut window_expr_list = vec![];
526    let mut window_schemas = vec![];
527    let mut plan = LogicalPlan::Window(window);
528    while let LogicalPlan::Window(Window {
529        input,
530        window_expr,
531        schema,
532    }) = plan
533    {
534        window_expr_list.push(window_expr);
535        window_schemas.push(schema);
536
537        plan = Arc::unwrap_or_clone(input);
538    }
539    (window_expr_list, window_schemas, plan)
540}
541
542impl OptimizerRule for CommonSubexprEliminate {
543    fn supports_rewrite(&self) -> bool {
544        true
545    }
546
547    fn apply_order(&self) -> Option<ApplyOrder> {
548        // This rule handles recursion itself in a `ApplyOrder::TopDown` like manner.
549        // This is because in some cases adjacent nodes are collected (e.g. `Window`) and
550        // CSEd as a group, which can't be done in a simple `ApplyOrder::TopDown` rule.
551        None
552    }
553
554    #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
555    fn rewrite(
556        &self,
557        plan: LogicalPlan,
558        config: &dyn OptimizerConfig,
559    ) -> Result<Transformed<LogicalPlan>> {
560        let original_schema = Arc::clone(plan.schema());
561
562        let optimized_plan = match plan {
563            LogicalPlan::Projection(proj) => self.try_optimize_proj(proj, config)?,
564            LogicalPlan::Sort(sort) => self.try_optimize_sort(sort, config)?,
565            LogicalPlan::Filter(filter) => self.try_optimize_filter(filter, config)?,
566            LogicalPlan::Window(window) => self.try_optimize_window(window, config)?,
567            LogicalPlan::Aggregate(agg) => self.try_optimize_aggregate(agg, config)?,
568            LogicalPlan::Join(_)
569            | LogicalPlan::Repartition(_)
570            | LogicalPlan::Union(_)
571            | LogicalPlan::TableScan(_)
572            | LogicalPlan::Values(_)
573            | LogicalPlan::EmptyRelation(_)
574            | LogicalPlan::Subquery(_)
575            | LogicalPlan::SubqueryAlias(_)
576            | LogicalPlan::Limit(_)
577            | LogicalPlan::Ddl(_)
578            | LogicalPlan::Explain(_)
579            | LogicalPlan::Analyze(_)
580            | LogicalPlan::Statement(_)
581            | LogicalPlan::DescribeTable(_)
582            | LogicalPlan::Distinct(_)
583            | LogicalPlan::Extension(_)
584            | LogicalPlan::Dml(_)
585            | LogicalPlan::Copy(_)
586            | LogicalPlan::Unnest(_)
587            | LogicalPlan::RecursiveQuery(_) => {
588                // This rule handles recursion itself in a `ApplyOrder::TopDown` like
589                // manner. Process uncorrelated subqueries in expressions
590                // (e.g., Expr::ScalarSubquery), then direct children.
591                plan.map_uncorrelated_subqueries(|c| self.rewrite(c, config))?
592                    .transform_sibling(|plan| {
593                        plan.map_children(|c| self.rewrite(c, config))
594                    })?
595            }
596        };
597
598        // If we rewrote the plan, ensure the schema stays the same
599        if optimized_plan.transformed && optimized_plan.data.schema() != &original_schema
600        {
601            optimized_plan.map_data(|optimized_plan| {
602                build_recover_project_plan(&original_schema, optimized_plan)
603            })
604        } else {
605            Ok(optimized_plan)
606        }
607    }
608
609    fn name(&self) -> &str {
610        "common_sub_expression_eliminate"
611    }
612}
613
/// Which type of [expressions](Expr) should be considered for rewriting?
///
/// The mask is consulted by `ExprCSEController::is_ignored`. Independently of the
/// mask, that method also ignores every expression whose placement is
/// `ExpressionPlacement::MoveTowardsLeafNodes` (cheap leaf-ward expressions such
/// as `get_field`).
#[derive(Debug, Clone, Copy)]
enum ExprMask {
    /// Ignores:
    ///
    /// - [`Literal`](Expr::Literal)
    /// - [`Column`](Expr::Column)
    /// - [`ScalarVariable`](Expr::ScalarVariable)
    /// - [`Alias`](Expr::Alias)
    /// - [`Wildcard`](Expr::Wildcard)
    /// - [`Lambda`](Expr::Lambda)
    /// - [`LambdaVariable`](Expr::LambdaVariable)
    /// - [`AggregateFunction`](Expr::AggregateFunction)
    /// - any expression placed as `MoveTowardsLeafNodes`
    Normal,

    /// Like [`Normal`](Self::Normal), but includes [`AggregateFunction`](Expr::AggregateFunction)
    /// as an extraction candidate, so repeated aggregate calls can be shared.
    NormalAndAggregates,
}
630
631struct ExprCSEController<'a> {
632    alias_generator: &'a AliasGenerator,
633    mask: ExprMask,
634
635    // how many aliases have we seen so far
636    alias_counter: usize,
637}
638
639impl<'a> ExprCSEController<'a> {
640    fn new(alias_generator: &'a AliasGenerator, mask: ExprMask) -> Self {
641        Self {
642            alias_generator,
643            mask,
644            alias_counter: 0,
645        }
646    }
647}
648
649impl CSEController for ExprCSEController<'_> {
650    type Node = Expr;
651
652    fn conditional_children(node: &Expr) -> Option<(Vec<&Expr>, Vec<&Expr>)> {
653        match node {
654            // In case of `ScalarFunction`s and `HigherOrderFunction`s we don't know which children are surely
655            // executed so start visiting all children conditionally and stop the
656            // recursion with `TreeNodeRecursion::Jump`.
657            Expr::ScalarFunction(ScalarFunction { func, args }) => {
658                func.conditional_arguments(args)
659            }
660            Expr::HigherOrderFunction(HigherOrderFunction { func, args }) => {
661                func.conditional_arguments(args)
662            }
663
664            // In case of `And` and `Or` the first child is surely executed, but we
665            // account subexpressions as conditional in the second.
666            Expr::BinaryExpr(BinaryExpr {
667                left,
668                op: Operator::And | Operator::Or,
669                right,
670            }) => Some((vec![left.as_ref()], vec![right.as_ref()])),
671
672            // In case of `Case` the optional base expression and the first when
673            // expressions are surely executed, but we account subexpressions as
674            // conditional in the others.
675            Expr::Case(Case {
676                expr,
677                when_then_expr,
678                else_expr,
679            }) => Some((
680                expr.iter()
681                    .map(|e| e.as_ref())
682                    .chain(when_then_expr.iter().take(1).map(|(when, _)| when.as_ref()))
683                    .collect(),
684                when_then_expr
685                    .iter()
686                    .take(1)
687                    .map(|(_, then)| then.as_ref())
688                    .chain(
689                        when_then_expr
690                            .iter()
691                            .skip(1)
692                            .flat_map(|(when, then)| [when.as_ref(), then.as_ref()]),
693                    )
694                    .chain(else_expr.iter().map(|e| e.as_ref()))
695                    .collect(),
696            )),
697            _ => None,
698        }
699    }
700
701    fn is_valid(node: &Expr) -> bool {
702        !node.is_volatile_node()
703            && !matches!(node, Expr::Lambda(_) | Expr::LambdaVariable(_))
704    }
705
706    fn is_ignored(&self, node: &Expr) -> bool {
707        // MoveTowardsLeafNodes expressions (e.g. get_field) are cheap struct
708        // field accesses that the ExtractLeafExpressions / PushDownLeafProjections
709        // rules deliberately duplicate when needed (one copy for a filter
710        // predicate, another for an output column). CSE deduplicating them
711        // creates intermediate projections that fight with those rules,
712        // causing optimizer instability — ExtractLeafExpressions will undo
713        // the dedup, creating an infinite loop that runs until the iteration
714        // limit is hit. Skip them.
715        if node.placement() == ExpressionPlacement::MoveTowardsLeafNodes {
716            return true;
717        }
718
719        // TODO: remove the next line after `Expr::Wildcard` is removed
720        #[expect(deprecated)]
721        let is_normal_minus_aggregates = matches!(
722            node,
723            // TODO: there's an argument for removing `Literal` from here,
724            // maybe using `Expr::placemement().should_push_to_leaves()` instead
725            // so that we extract common literals and don't broadcast them to num_batch_rows multiple times.
726            // However that currently breaks things like `percentile_cont()` which expect literal arguments
727            // (and would instead be getting `col(__common_expr_n)`).
728            Expr::Literal(..)
729                | Expr::Column(..)
730                | Expr::ScalarVariable(..)
731                | Expr::Alias(..)
732                | Expr::Wildcard { .. }
733                | Expr::Lambda(_)
734                | Expr::LambdaVariable(_)
735        );
736
737        let is_aggr = matches!(node, Expr::AggregateFunction(..));
738
739        match self.mask {
740            ExprMask::Normal => is_normal_minus_aggregates || is_aggr,
741            ExprMask::NormalAndAggregates => is_normal_minus_aggregates,
742        }
743    }
744
745    fn generate_alias(&self) -> String {
746        self.alias_generator.next(CSE_PREFIX)
747    }
748
749    fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
750        // alias the expressions without an `Alias` ancestor node
751        if self.alias_counter > 0 {
752            col(alias)
753        } else {
754            self.alias_counter += 1;
755            col(alias).alias(node.schema_name().to_string())
756        }
757    }
758
759    fn rewrite_f_down(&mut self, node: &Expr) {
760        if matches!(node, Expr::Alias(_)) {
761            self.alias_counter += 1;
762        }
763    }
764    fn rewrite_f_up(&mut self, node: &Expr) {
765        if matches!(node, Expr::Alias(_)) {
766            self.alias_counter -= 1
767        }
768    }
769}
770
771impl Default for CommonSubexprEliminate {
772    fn default() -> Self {
773        Self::new()
774    }
775}
776
777/// Build the "intermediate" projection plan that evaluates the extracted common
778/// expressions.
779///
780/// # Arguments
781/// input: the input plan
782///
783/// common_exprs: which common subexpressions were used (and thus are added to
784/// intermediate projection)
785///
786/// expr_stats: the set of common subexpressions
787fn build_common_expr_project_plan(
788    input: LogicalPlan,
789    common_exprs: Vec<(Expr, String)>,
790) -> Result<LogicalPlan> {
791    let mut fields_set = BTreeSet::new();
792    let mut project_exprs = common_exprs
793        .into_iter()
794        .map(|(expr, expr_alias)| {
795            fields_set.insert(expr_alias.clone());
796            Ok(expr.alias(expr_alias))
797        })
798        .collect::<Result<Vec<_>>>()?;
799
800    for (qualifier, field) in input.schema().iter() {
801        if fields_set.insert(qualified_name(qualifier, field.name())) {
802            project_exprs.push(Expr::from((qualifier, field)));
803        }
804    }
805
806    Projection::try_new(project_exprs, Arc::new(input)).map(LogicalPlan::Projection)
807}
808
809/// Build the projection plan to eliminate unnecessary columns produced by
810/// the "intermediate" projection plan built in [build_common_expr_project_plan].
811///
812/// This is required to keep the schema the same for plans that pass the input
813/// on to the output, such as `Filter` or `Sort`.
814fn build_recover_project_plan(
815    schema: &DFSchema,
816    input: LogicalPlan,
817) -> Result<LogicalPlan> {
818    let col_exprs = schema.iter().map(Expr::from).collect();
819    Projection::try_new(col_exprs, Arc::new(input)).map(LogicalPlan::Projection)
820}
821
822fn extract_expressions(expr: &Expr, result: &mut Vec<Expr>) {
823    if let Expr::GroupingSet(groupings) = expr {
824        for e in groupings.distinct_expr() {
825            let (qualifier, field_name) = e.qualified_name();
826            let col = Column::new(qualifier, field_name);
827            result.push(Expr::Column(col))
828        }
829    } else {
830        let (qualifier, field_name) = expr.qualified_name();
831        let col = Column::new(qualifier, field_name);
832        result.push(Expr::Column(col));
833    }
834}
835
836#[cfg(test)]
837mod test {
838
839    use std::iter;
840
841    use arrow::datatypes::{DataType, Field, Schema};
842    use datafusion_expr::logical_plan::{JoinType, table_scan};
843    use datafusion_expr::{
844        AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarFunctionArgs,
845        ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, Volatility,
846        grouping_set, is_null, not,
847    };
848    use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder};
849
850    use super::*;
851    use crate::assert_optimized_plan_eq_snapshot;
852    use crate::optimizer::OptimizerContext;
853    use crate::test::udfs::leaf_udf_expr;
854    use crate::test::*;
855    use datafusion_expr::test::function_stub::{avg, sum};
856
857    macro_rules! assert_optimized_plan_equal {
858        (
859            $config:expr,
860            $plan:expr,
861            @ $expected:literal $(,)?
862        ) => {{
863            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(CommonSubexprEliminate::new())];
864            assert_optimized_plan_eq_snapshot!(
865                $config,
866                rules,
867                $plan,
868                @ $expected,
869            )
870        }};
871
872        (
873            $plan:expr,
874            @ $expected:literal $(,)?
875        ) => {{
876            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(CommonSubexprEliminate::new())];
877            let optimizer_ctx = OptimizerContext::new();
878            assert_optimized_plan_eq_snapshot!(
879                optimizer_ctx,
880                rules,
881                $plan,
882                @ $expected,
883            )
884        }};
885    }
886
887    #[test]
888    fn tpch_q1_simplified() -> Result<()> {
889        // SQL:
890        //  select
891        //      sum(a * (1 - b)),
892        //      sum(a * (1 - b) * (1 + c))
893        //  from T;
894        //
895        // The manual assembled logical plan don't contains the outermost `Projection`.
896
897        let table_scan = test_table_scan()?;
898
899        let plan = LogicalPlanBuilder::from(table_scan)
900            .aggregate(
901                iter::empty::<Expr>(),
902                vec![
903                    sum(col("a") * (lit(1) - col("b"))),
904                    sum((col("a") * (lit(1) - col("b"))) * (lit(1) + col("c"))),
905                ],
906            )?
907            .build()?;
908
909        assert_optimized_plan_equal!(
910            plan,
911            @ r"
912        Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test.a * Int32(1) - test.b), sum(__common_expr_1 AS test.a * Int32(1) - test.b * (Int32(1) + test.c))]]
913          Projection: test.a * (Int32(1) - test.b) AS __common_expr_1, test.a, test.b, test.c
914            TableScan: test
915        "
916        )
917    }
918
919    #[test]
920    fn nested_aliases() -> Result<()> {
921        let table_scan = test_table_scan()?;
922
923        let plan = LogicalPlanBuilder::from(table_scan)
924            .project(vec![
925                (col("a") + col("b") - col("c")).alias("alias1") * (col("a") + col("b")),
926                col("a") + col("b"),
927            ])?
928            .build()?;
929
930        assert_optimized_plan_equal!(
931            plan,
932            @ r"
933        Projection: __common_expr_1 - test.c AS alias1 * __common_expr_1 AS test.a + test.b, __common_expr_1 AS test.a + test.b
934          Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
935            TableScan: test
936        "
937        )
938    }
939
940    #[test]
941    fn aggregate() -> Result<()> {
942        let table_scan = test_table_scan()?;
943
944        let return_type = DataType::UInt32;
945        let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!());
946        let udf_agg = |inner: Expr| {
947            Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
948                Arc::new(AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
949                    "my_agg",
950                    Signature::exact(vec![DataType::UInt32], Volatility::Stable),
951                    return_type.clone(),
952                    Arc::clone(&accumulator),
953                    vec![Field::new("value", DataType::UInt32, true).into()],
954                ))),
955                vec![inner],
956                false,
957                None,
958                vec![],
959                None,
960            ))
961        };
962
963        // test: common aggregates
964        let plan = LogicalPlanBuilder::from(table_scan.clone())
965            .aggregate(
966                iter::empty::<Expr>(),
967                vec![
968                    // common: avg(col("a"))
969                    avg(col("a")).alias("col1"),
970                    avg(col("a")).alias("col2"),
971                    // no common
972                    avg(col("b")).alias("col3"),
973                    avg(col("c")),
974                    // common: udf_agg(col("a"))
975                    udf_agg(col("a")).alias("col4"),
976                    udf_agg(col("a")).alias("col5"),
977                    // no common
978                    udf_agg(col("b")).alias("col6"),
979                    udf_agg(col("c")),
980                ],
981            )?
982            .build()?;
983
984        assert_optimized_plan_equal!(
985            plan,
986            @ r"
987        Projection: __common_expr_1 AS col1, __common_expr_1 AS col2, col3, __common_expr_3 AS avg(test.c), __common_expr_2 AS col4, __common_expr_2 AS col5, col6, __common_expr_4 AS my_agg(test.c)
988          Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2, avg(test.b) AS col3, avg(test.c) AS __common_expr_3, my_agg(test.b) AS col6, my_agg(test.c) AS __common_expr_4]]
989            TableScan: test
990        "
991        )?;
992
993        // test: trafo after aggregate
994        let plan = LogicalPlanBuilder::from(table_scan.clone())
995            .aggregate(
996                iter::empty::<Expr>(),
997                vec![
998                    lit(1) + avg(col("a")),
999                    lit(1) - avg(col("a")),
1000                    lit(1) + udf_agg(col("a")),
1001                    lit(1) - udf_agg(col("a")),
1002                ],
1003            )?
1004            .build()?;
1005
1006        assert_optimized_plan_equal!(
1007            plan,
1008            @ r"
1009        Projection: Int32(1) + __common_expr_1 AS avg(test.a), Int32(1) - __common_expr_1 AS avg(test.a), Int32(1) + __common_expr_2 AS my_agg(test.a), Int32(1) - __common_expr_2 AS my_agg(test.a)
1010          Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2]]
1011            TableScan: test
1012        "
1013        )?;
1014
1015        // test: transformation before aggregate
1016        let plan = LogicalPlanBuilder::from(table_scan.clone())
1017            .aggregate(
1018                iter::empty::<Expr>(),
1019                vec![
1020                    avg(lit(1u32) + col("a")).alias("col1"),
1021                    udf_agg(lit(1u32) + col("a")).alias("col2"),
1022                ],
1023            )?
1024            .build()?;
1025
1026        assert_optimized_plan_equal!(
1027            plan,
1028            @ r"
1029        Aggregate: groupBy=[[]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]
1030          Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1031            TableScan: test
1032        "
1033        )?;
1034
1035        // test: common between agg and group
1036        let plan = LogicalPlanBuilder::from(table_scan.clone())
1037            .aggregate(
1038                vec![lit(1u32) + col("a")],
1039                vec![
1040                    avg(lit(1u32) + col("a")).alias("col1"),
1041                    udf_agg(lit(1u32) + col("a")).alias("col2"),
1042                ],
1043            )?
1044            .build()?;
1045
1046        assert_optimized_plan_equal!(
1047            plan,
1048            @ r"
1049        Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]
1050          Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1051            TableScan: test
1052        "
1053        )?;
1054
1055        // test: all mixed
1056        let plan = LogicalPlanBuilder::from(table_scan)
1057            .aggregate(
1058                vec![lit(1u32) + col("a")],
1059                vec![
1060                    (lit(1u32) + avg(lit(1u32) + col("a"))).alias("col1"),
1061                    (lit(1u32) - avg(lit(1u32) + col("a"))).alias("col2"),
1062                    avg(lit(1u32) + col("a")),
1063                    (lit(1u32) + udf_agg(lit(1u32) + col("a"))).alias("col3"),
1064                    (lit(1u32) - udf_agg(lit(1u32) + col("a"))).alias("col4"),
1065                    udf_agg(lit(1u32) + col("a")),
1066                ],
1067            )?
1068            .build()?;
1069
1070        assert_optimized_plan_equal!(
1071            plan,
1072            @ r"
1073        Projection: UInt32(1) + test.a, UInt32(1) + __common_expr_2 AS col1, UInt32(1) - __common_expr_2 AS col2, __common_expr_4 AS avg(UInt32(1) + test.a), UInt32(1) + __common_expr_3 AS col3, UInt32(1) - __common_expr_3 AS col4, __common_expr_5 AS my_agg(UInt32(1) + test.a)
1074          Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS __common_expr_2, my_agg(__common_expr_1) AS __common_expr_3, avg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_4, my_agg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_5]]
1075            Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1076              TableScan: test
1077        "
1078        )
1079    }
1080
1081    #[test]
1082    fn aggregate_with_relations_and_dots() -> Result<()> {
1083        let schema = Schema::new(vec![Field::new("col.a", DataType::UInt32, false)]);
1084        let table_scan = table_scan(Some("table.test"), &schema, None)?.build()?;
1085
1086        let col_a = Expr::Column(Column::new(Some("table.test"), "col.a"));
1087
1088        let plan = LogicalPlanBuilder::from(table_scan)
1089            .aggregate(
1090                vec![col_a.clone()],
1091                vec![
1092                    (lit(1u32) + avg(lit(1u32) + col_a.clone())),
1093                    avg(lit(1u32) + col_a),
1094                ],
1095            )?
1096            .build()?;
1097
1098        assert_optimized_plan_equal!(
1099            plan,
1100            @ r"
1101        Projection: table.test.col.a, UInt32(1) + __common_expr_2 AS avg(UInt32(1) + table.test.col.a), __common_expr_2 AS avg(UInt32(1) + table.test.col.a)
1102          Aggregate: groupBy=[[table.test.col.a]], aggr=[[avg(__common_expr_1 AS UInt32(1) + table.test.col.a) AS __common_expr_2]]
1103            Projection: UInt32(1) + table.test.col.a AS __common_expr_1, table.test.col.a
1104              TableScan: table.test
1105        "
1106        )
1107    }
1108
1109    #[test]
1110    fn subexpr_in_same_order() -> Result<()> {
1111        let table_scan = test_table_scan()?;
1112
1113        let plan = LogicalPlanBuilder::from(table_scan)
1114            .project(vec![
1115                (lit(1) + col("a")).alias("first"),
1116                (lit(1) + col("a")).alias("second"),
1117            ])?
1118            .build()?;
1119
1120        assert_optimized_plan_equal!(
1121            plan,
1122            @ r"
1123        Projection: __common_expr_1 AS first, __common_expr_1 AS second
1124          Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1125            TableScan: test
1126        "
1127        )
1128    }
1129
1130    #[test]
1131    fn subexpr_in_different_order() -> Result<()> {
1132        let table_scan = test_table_scan()?;
1133
1134        let plan = LogicalPlanBuilder::from(table_scan)
1135            .project(vec![lit(1) + col("a"), col("a") + lit(1)])?
1136            .build()?;
1137
1138        assert_optimized_plan_equal!(
1139            plan,
1140            @ r"
1141        Projection: __common_expr_1 AS Int32(1) + test.a, __common_expr_1 AS test.a + Int32(1)
1142          Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1143            TableScan: test
1144        "
1145        )
1146    }
1147
1148    #[test]
1149    fn cross_plans_subexpr() -> Result<()> {
1150        let table_scan = test_table_scan()?;
1151
1152        let plan = LogicalPlanBuilder::from(table_scan)
1153            .project(vec![lit(1) + col("a"), col("a")])?
1154            .project(vec![lit(1) + col("a")])?
1155            .build()?;
1156
1157        assert_optimized_plan_equal!(
1158            plan,
1159            @ r"
1160        Projection: Int32(1) + test.a
1161          Projection: Int32(1) + test.a, test.a
1162            TableScan: test
1163        "
1164        )
1165    }
1166
1167    #[test]
1168    fn redundant_project_fields() {
1169        let table_scan = test_table_scan().unwrap();
1170        let c_plus_a = col("c") + col("a");
1171        let b_plus_a = col("b") + col("a");
1172        let common_exprs_1 = vec![
1173            (c_plus_a, format!("{CSE_PREFIX}_1")),
1174            (b_plus_a, format!("{CSE_PREFIX}_2")),
1175        ];
1176        let c_plus_a_2 = col(format!("{CSE_PREFIX}_1"));
1177        let b_plus_a_2 = col(format!("{CSE_PREFIX}_2"));
1178        let common_exprs_2 = vec![
1179            (c_plus_a_2, format!("{CSE_PREFIX}_3")),
1180            (b_plus_a_2, format!("{CSE_PREFIX}_4")),
1181        ];
1182        let project = build_common_expr_project_plan(table_scan, common_exprs_1).unwrap();
1183        let project_2 = build_common_expr_project_plan(project, common_exprs_2).unwrap();
1184
1185        let mut field_set = BTreeSet::new();
1186        for name in project_2.schema().field_names() {
1187            assert!(field_set.insert(name));
1188        }
1189    }
1190
1191    #[test]
1192    fn redundant_project_fields_join_input() {
1193        let table_scan_1 = test_table_scan_with_name("test1").unwrap();
1194        let table_scan_2 = test_table_scan_with_name("test2").unwrap();
1195        let join = LogicalPlanBuilder::from(table_scan_1)
1196            .join(table_scan_2, JoinType::Inner, (vec!["a"], vec!["a"]), None)
1197            .unwrap()
1198            .build()
1199            .unwrap();
1200        let c_plus_a = col("test1.c") + col("test1.a");
1201        let b_plus_a = col("test1.b") + col("test1.a");
1202        let common_exprs_1 = vec![
1203            (c_plus_a, format!("{CSE_PREFIX}_1")),
1204            (b_plus_a, format!("{CSE_PREFIX}_2")),
1205        ];
1206        let c_plus_a_2 = col(format!("{CSE_PREFIX}_1"));
1207        let b_plus_a_2 = col(format!("{CSE_PREFIX}_2"));
1208        let common_exprs_2 = vec![
1209            (c_plus_a_2, format!("{CSE_PREFIX}_3")),
1210            (b_plus_a_2, format!("{CSE_PREFIX}_4")),
1211        ];
1212        let project = build_common_expr_project_plan(join, common_exprs_1).unwrap();
1213        let project_2 = build_common_expr_project_plan(project, common_exprs_2).unwrap();
1214
1215        let mut field_set = BTreeSet::new();
1216        for name in project_2.schema().field_names() {
1217            assert!(field_set.insert(name));
1218        }
1219    }
1220
1221    #[test]
1222    fn eliminated_subexpr_datatype() {
1223        use datafusion_expr::cast;
1224
1225        let schema = Schema::new(vec![
1226            Field::new("a", DataType::UInt64, false),
1227            Field::new("b", DataType::UInt64, false),
1228            Field::new("c", DataType::UInt64, false),
1229        ]);
1230
1231        let plan = table_scan(Some("table"), &schema, None)
1232            .unwrap()
1233            .filter(
1234                cast(col("a"), DataType::Int64)
1235                    .lt(lit(1_i64))
1236                    .and(cast(col("a"), DataType::Int64).not_eq(lit(1_i64))),
1237            )
1238            .unwrap()
1239            .build()
1240            .unwrap();
1241        let rule = CommonSubexprEliminate::new();
1242        let optimized_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap();
1243        assert!(optimized_plan.transformed);
1244        let optimized_plan = optimized_plan.data;
1245
1246        let schema = optimized_plan.schema();
1247        let fields_with_datatypes: Vec<_> = schema
1248            .fields()
1249            .iter()
1250            .map(|field| (field.name(), field.data_type()))
1251            .collect();
1252        let formatted_fields_with_datatype = format!("{fields_with_datatypes:#?}");
1253        let expected = r#"[
1254    (
1255        "a",
1256        UInt64,
1257    ),
1258    (
1259        "b",
1260        UInt64,
1261    ),
1262    (
1263        "c",
1264        UInt64,
1265    ),
1266]"#;
1267        assert_eq!(expected, formatted_fields_with_datatype);
1268    }
1269
1270    #[test]
1271    fn filter_schema_changed() -> Result<()> {
1272        let table_scan = test_table_scan()?;
1273
1274        let plan = LogicalPlanBuilder::from(table_scan)
1275            .filter((lit(1) + col("a") - lit(10)).gt(lit(1) + col("a")))?
1276            .build()?;
1277
1278        assert_optimized_plan_equal!(
1279            plan,
1280            @ r"
1281        Projection: test.a, test.b, test.c
1282          Filter: __common_expr_1 - Int32(10) > __common_expr_1
1283            Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1284              TableScan: test
1285        "
1286        )
1287    }
1288
1289    #[test]
1290    fn test_extract_expressions_from_grouping_set() -> Result<()> {
1291        let mut result = Vec::with_capacity(3);
1292        let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("c")]]);
1293        extract_expressions(&grouping, &mut result);
1294
1295        assert!(result.len() == 3);
1296        Ok(())
1297    }
1298
1299    #[test]
1300    fn test_extract_expressions_from_grouping_set_with_identical_expr() -> Result<()> {
1301        let mut result = Vec::with_capacity(2);
1302        let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("a")]]);
1303        extract_expressions(&grouping, &mut result);
1304        assert!(result.len() == 2);
1305        Ok(())
1306    }
1307
1308    #[test]
1309    fn test_alias_collision() -> Result<()> {
1310        let table_scan = test_table_scan()?;
1311
1312        let config = OptimizerContext::new();
1313        let common_expr_1 = config.alias_generator().next(CSE_PREFIX);
1314        let plan = LogicalPlanBuilder::from(table_scan.clone())
1315            .project(vec![
1316                (col("a") + col("b")).alias(common_expr_1.clone()),
1317                col("c"),
1318            ])?
1319            .project(vec![
1320                col(common_expr_1.clone()).alias("c1"),
1321                col(common_expr_1).alias("c2"),
1322                (col("c") + lit(2)).alias("c3"),
1323                (col("c") + lit(2)).alias("c4"),
1324            ])?
1325            .build()?;
1326
1327        assert_optimized_plan_equal!(
1328            config,
1329            plan,
1330            @ r"
1331        Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 AS c3, __common_expr_2 AS c4
1332          Projection: test.c + Int32(2) AS __common_expr_2, __common_expr_1, test.c
1333            Projection: test.a + test.b AS __common_expr_1, test.c
1334              TableScan: test
1335        "
1336        )?;
1337
1338        let config = OptimizerContext::new();
1339        let _common_expr_1 = config.alias_generator().next(CSE_PREFIX);
1340        let common_expr_2 = config.alias_generator().next(CSE_PREFIX);
1341        let plan = LogicalPlanBuilder::from(table_scan)
1342            .project(vec![
1343                (col("a") + col("b")).alias(common_expr_2.clone()),
1344                col("c"),
1345            ])?
1346            .project(vec![
1347                col(common_expr_2.clone()).alias("c1"),
1348                col(common_expr_2).alias("c2"),
1349                (col("c") + lit(2)).alias("c3"),
1350                (col("c") + lit(2)).alias("c4"),
1351            ])?
1352            .build()?;
1353
1354        assert_optimized_plan_equal!(
1355            config,
1356            plan,
1357            @ r"
1358        Projection: __common_expr_2 AS c1, __common_expr_2 AS c2, __common_expr_3 AS c3, __common_expr_3 AS c4
1359          Projection: test.c + Int32(2) AS __common_expr_3, __common_expr_2, test.c
1360            Projection: test.a + test.b AS __common_expr_2, test.c
1361              TableScan: test
1362        "
1363        )?;
1364
1365        Ok(())
1366    }
1367
1368    #[test]
1369    fn test_extract_expressions_from_col() -> Result<()> {
1370        let mut result = Vec::with_capacity(1);
1371        extract_expressions(&col("a"), &mut result);
1372        assert!(result.len() == 1);
1373        Ok(())
1374    }
1375
1376    #[test]
1377    fn test_short_circuits() -> Result<()> {
1378        let table_scan = test_table_scan()?;
1379
1380        let extracted_short_circuit = col("a").eq(lit(0)).or(col("b").eq(lit(0)));
1381        let extracted_short_circuit_leg_1 = (col("a") + col("b")).eq(lit(0));
1382        let not_extracted_short_circuit_leg_2 = (col("a") - col("b")).eq(lit(0));
1383        let extracted_short_circuit_leg_3 = (col("a") * col("b")).eq(lit(0));
1384        let plan = LogicalPlanBuilder::from(table_scan)
1385            .project(vec![
1386                extracted_short_circuit.clone().alias("c1"),
1387                extracted_short_circuit.alias("c2"),
1388                extracted_short_circuit_leg_1
1389                    .clone()
1390                    .or(not_extracted_short_circuit_leg_2.clone())
1391                    .alias("c3"),
1392                extracted_short_circuit_leg_1
1393                    .and(not_extracted_short_circuit_leg_2)
1394                    .alias("c4"),
1395                extracted_short_circuit_leg_3
1396                    .clone()
1397                    .or(extracted_short_circuit_leg_3)
1398                    .alias("c5"),
1399            ])?
1400            .build()?;
1401
1402        assert_optimized_plan_equal!(
1403            plan,
1404            @ r"
1405        Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 OR test.a - test.b = Int32(0) AS c3, __common_expr_2 AND test.a - test.b = Int32(0) AS c4, __common_expr_3 OR __common_expr_3 AS c5
1406          Projection: test.a = Int32(0) OR test.b = Int32(0) AS __common_expr_1, test.a + test.b = Int32(0) AS __common_expr_2, test.a * test.b = Int32(0) AS __common_expr_3, test.a, test.b, test.c
1407            TableScan: test
1408        "
1409        )
1410    }
1411
1412    #[test]
1413    fn test_volatile() -> Result<()> {
1414        let table_scan = test_table_scan()?;
1415
1416        let extracted_child = col("a") + col("b");
1417        let rand = rand_func().call(vec![]);
1418        let not_extracted_volatile = extracted_child + rand;
1419        let plan = LogicalPlanBuilder::from(table_scan)
1420            .project(vec![
1421                not_extracted_volatile.clone().alias("c1"),
1422                not_extracted_volatile.alias("c2"),
1423            ])?
1424            .build()?;
1425
1426        assert_optimized_plan_equal!(
1427            plan,
1428            @ r"
1429        Projection: __common_expr_1 + random() AS c1, __common_expr_1 + random() AS c2
1430          Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
1431            TableScan: test
1432        "
1433        )
1434    }
1435
1436    #[test]
1437    fn test_volatile_short_circuits() -> Result<()> {
1438        let table_scan = test_table_scan()?;
1439
1440        let rand = rand_func().call(vec![]);
1441        let extracted_short_circuit_leg_1 = col("a").eq(lit(0));
1442        let not_extracted_volatile_short_circuit_1 =
1443            extracted_short_circuit_leg_1.or(rand.clone().eq(lit(0)));
1444        let not_extracted_short_circuit_leg_2 = col("b").eq(lit(0));
1445        let not_extracted_volatile_short_circuit_2 =
1446            rand.eq(lit(0)).or(not_extracted_short_circuit_leg_2);
1447        let plan = LogicalPlanBuilder::from(table_scan)
1448            .project(vec![
1449                not_extracted_volatile_short_circuit_1.clone().alias("c1"),
1450                not_extracted_volatile_short_circuit_1.alias("c2"),
1451                not_extracted_volatile_short_circuit_2.clone().alias("c3"),
1452                not_extracted_volatile_short_circuit_2.alias("c4"),
1453            ])?
1454            .build()?;
1455
1456        assert_optimized_plan_equal!(
1457            plan,
1458            @ r"
1459        Projection: __common_expr_1 OR random() = Int32(0) AS c1, __common_expr_1 OR random() = Int32(0) AS c2, random() = Int32(0) OR test.b = Int32(0) AS c3, random() = Int32(0) OR test.b = Int32(0) AS c4
1460          Projection: test.a = Int32(0) AS __common_expr_1, test.a, test.b, test.c
1461            TableScan: test
1462        "
1463        )
1464    }
1465
1466    #[test]
1467    fn test_non_top_level_common_expression() -> Result<()> {
1468        let table_scan = test_table_scan()?;
1469
1470        let common_expr = col("a") + col("b");
1471        let plan = LogicalPlanBuilder::from(table_scan)
1472            .project(vec![
1473                common_expr.clone().alias("c1"),
1474                common_expr.alias("c2"),
1475            ])?
1476            .project(vec![col("c1"), col("c2")])?
1477            .build()?;
1478
1479        assert_optimized_plan_equal!(
1480            plan,
1481            @ r"
1482        Projection: c1, c2
1483          Projection: __common_expr_1 AS c1, __common_expr_1 AS c2
1484            Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
1485              TableScan: test
1486        "
1487        )
1488    }
1489
1490    #[test]
1491    fn test_nested_common_expression() -> Result<()> {
1492        let table_scan = test_table_scan()?;
1493
1494        let nested_common_expr = col("a") + col("b");
1495        let common_expr = nested_common_expr.clone() * nested_common_expr;
1496        let plan = LogicalPlanBuilder::from(table_scan)
1497            .project(vec![
1498                common_expr.clone().alias("c1"),
1499                common_expr.alias("c2"),
1500            ])?
1501            .build()?;
1502
1503        assert_optimized_plan_equal!(
1504            plan,
1505            @ r"
1506        Projection: __common_expr_1 AS c1, __common_expr_1 AS c2
1507          Projection: __common_expr_2 * __common_expr_2 AS __common_expr_1, test.a, test.b, test.c
1508            Projection: test.a + test.b AS __common_expr_2, test.a, test.b, test.c
1509              TableScan: test
1510        "
1511        )
1512    }
1513
1514    #[test]
1515    fn test_normalize_add_expression() -> Result<()> {
1516        // a + b <=> b + a
1517        let table_scan = test_table_scan()?;
1518        let expr = ((col("a") + col("b")) * (col("b") + col("a"))).eq(lit(30));
1519        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1520
1521        assert_optimized_plan_equal!(
1522            plan,
1523            @ r"
1524        Projection: test.a, test.b, test.c
1525          Filter: __common_expr_1 * __common_expr_1 = Int32(30)
1526            Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
1527              TableScan: test
1528        "
1529        )
1530    }
1531
1532    #[test]
1533    fn test_normalize_multi_expression() -> Result<()> {
1534        // a * b <=> b * a
1535        let table_scan = test_table_scan()?;
1536        let expr = ((col("a") * col("b")) + (col("b") * col("a"))).eq(lit(30));
1537        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1538
1539        assert_optimized_plan_equal!(
1540            plan,
1541            @ r"
1542        Projection: test.a, test.b, test.c
1543          Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1544            Projection: test.a * test.b AS __common_expr_1, test.a, test.b, test.c
1545              TableScan: test
1546        "
1547        )
1548    }
1549
1550    #[test]
1551    fn test_normalize_bitset_and_expression() -> Result<()> {
1552        // a & b <=> b & a
1553        let table_scan = test_table_scan()?;
1554        let expr = ((col("a") & col("b")) + (col("b") & col("a"))).eq(lit(30));
1555        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1556
1557        assert_optimized_plan_equal!(
1558            plan,
1559            @ r"
1560        Projection: test.a, test.b, test.c
1561          Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1562            Projection: test.a & test.b AS __common_expr_1, test.a, test.b, test.c
1563              TableScan: test
1564        "
1565        )
1566    }
1567
1568    #[test]
1569    fn test_normalize_bitset_or_expression() -> Result<()> {
1570        // a | b <=> b | a
1571        let table_scan = test_table_scan()?;
1572        let expr = ((col("a") | col("b")) + (col("b") | col("a"))).eq(lit(30));
1573        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1574
1575        assert_optimized_plan_equal!(
1576            plan,
1577            @ r"
1578        Projection: test.a, test.b, test.c
1579          Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1580            Projection: test.a | test.b AS __common_expr_1, test.a, test.b, test.c
1581              TableScan: test
1582        "
1583        )
1584    }
1585
1586    #[test]
1587    fn test_normalize_bitset_xor_expression() -> Result<()> {
1588        // a # b <=> b # a
1589        let table_scan = test_table_scan()?;
1590        let expr = ((col("a") ^ col("b")) + (col("b") ^ col("a"))).eq(lit(30));
1591        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1592
1593        assert_optimized_plan_equal!(
1594            plan,
1595            @ r"
1596        Projection: test.a, test.b, test.c
1597          Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1598            Projection: test.a BIT_XOR test.b AS __common_expr_1, test.a, test.b, test.c
1599              TableScan: test
1600        "
1601        )
1602    }
1603
1604    #[test]
1605    fn test_normalize_eq_expression() -> Result<()> {
1606        // a = b <=> b = a
1607        let table_scan = test_table_scan()?;
1608        let expr = (col("a").eq(col("b"))).and(col("b").eq(col("a")));
1609        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1610
1611        assert_optimized_plan_equal!(
1612            plan,
1613            @ r"
1614        Projection: test.a, test.b, test.c
1615          Filter: __common_expr_1 AND __common_expr_1
1616            Projection: test.a = test.b AS __common_expr_1, test.a, test.b, test.c
1617              TableScan: test
1618        "
1619        )
1620    }
1621
1622    #[test]
1623    fn test_normalize_ne_expression() -> Result<()> {
1624        // a != b <=> b != a
1625        let table_scan = test_table_scan()?;
1626        let expr = (col("a").not_eq(col("b"))).and(col("b").not_eq(col("a")));
1627        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1628
1629        assert_optimized_plan_equal!(
1630            plan,
1631            @ r"
1632        Projection: test.a, test.b, test.c
1633          Filter: __common_expr_1 AND __common_expr_1
1634            Projection: test.a != test.b AS __common_expr_1, test.a, test.b, test.c
1635              TableScan: test
1636        "
1637        )
1638    }
1639
1640    #[test]
1641    fn test_normalize_complex_expression() -> Result<()> {
1642        // case1: a + b * c <=> b * c + a
1643        let table_scan = test_table_scan()?;
1644        let expr = ((col("a") + col("b") * col("c")) - (col("b") * col("c") + col("a")))
1645            .eq(lit(30));
1646        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1647
1648        assert_optimized_plan_equal!(
1649            plan,
1650            @ r"
1651        Projection: test.a, test.b, test.c
1652          Filter: __common_expr_1 - __common_expr_1 = Int32(30)
1653            Projection: test.a + test.b * test.c AS __common_expr_1, test.a, test.b, test.c
1654              TableScan: test
1655        "
1656        )?;
1657
1658        // ((c1 + c2 / c3) * c3 <=> c3 * (c2 / c3 + c1))
1659        let table_scan = test_table_scan()?;
1660        let expr = (((col("a") + col("b") / col("c")) * col("c"))
1661            / (col("c") * (col("b") / col("c") + col("a")))
1662            + col("a"))
1663        .eq(lit(30));
1664        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1665
1666        assert_optimized_plan_equal!(
1667            plan,
1668            @ r"
1669        Projection: test.a, test.b, test.c
1670          Filter: __common_expr_1 / __common_expr_1 + test.a = Int32(30)
1671            Projection: (test.a + test.b / test.c) * test.c AS __common_expr_1, test.a, test.b, test.c
1672              TableScan: test
1673        "
1674        )?;
1675
1676        // c2 / (c1 + c3) <=> c2 / (c3 + c1)
1677        let table_scan = test_table_scan()?;
1678        let expr = ((col("b") / (col("a") + col("c")))
1679            * (col("b") / (col("c") + col("a"))))
1680        .eq(lit(30));
1681        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1682        assert_optimized_plan_equal!(
1683            plan,
1684            @ r"
1685        Projection: test.a, test.b, test.c
1686          Filter: __common_expr_1 * __common_expr_1 = Int32(30)
1687            Projection: test.b / (test.a + test.c) AS __common_expr_1, test.a, test.b, test.c
1688              TableScan: test
1689        "
1690        )?;
1691
1692        Ok(())
1693    }
1694
1695    #[derive(Debug, PartialEq, Eq, Hash)]
1696    pub struct TestUdf {
1697        signature: Signature,
1698    }
1699
1700    impl TestUdf {
1701        pub fn new() -> Self {
1702            Self {
1703                signature: Signature::numeric(1, Volatility::Immutable),
1704            }
1705        }
1706    }
1707
1708    impl ScalarUDFImpl for TestUdf {
1709        fn name(&self) -> &str {
1710            "my_udf"
1711        }
1712
1713        fn signature(&self) -> &Signature {
1714            &self.signature
1715        }
1716
1717        fn return_type(&self, _: &[DataType]) -> Result<DataType> {
1718            Ok(DataType::Int32)
1719        }
1720
1721        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1722            panic!("not implemented")
1723        }
1724    }
1725
1726    #[test]
1727    fn test_normalize_inner_binary_expression() -> Result<()> {
1728        // Not(a == b) <=> Not(b == a)
1729        let table_scan = test_table_scan()?;
1730        let expr1 = not(col("a").eq(col("b")));
1731        let expr2 = not(col("b").eq(col("a")));
1732        let plan = LogicalPlanBuilder::from(table_scan)
1733            .project(vec![expr1, expr2])?
1734            .build()?;
1735        assert_optimized_plan_equal!(
1736            plan,
1737            @ r"
1738        Projection: __common_expr_1 AS NOT test.a = test.b, __common_expr_1 AS NOT test.b = test.a
1739          Projection: NOT test.a = test.b AS __common_expr_1, test.a, test.b, test.c
1740            TableScan: test
1741        "
1742        )?;
1743
1744        // is_null(a == b) <=> is_null(b == a)
1745        let table_scan = test_table_scan()?;
1746        let expr1 = is_null(col("a").eq(col("b")));
1747        let expr2 = is_null(col("b").eq(col("a")));
1748        let plan = LogicalPlanBuilder::from(table_scan)
1749            .project(vec![expr1, expr2])?
1750            .build()?;
1751        assert_optimized_plan_equal!(
1752            plan,
1753            @ r"
1754        Projection: __common_expr_1 AS test.a = test.b IS NULL, __common_expr_1 AS test.b = test.a IS NULL
1755          Projection: test.a = test.b IS NULL AS __common_expr_1, test.a, test.b, test.c
1756            TableScan: test
1757        "
1758        )?;
1759
1760        // a + b between 0 and 10 <=> b + a between 0 and 10
1761        let table_scan = test_table_scan()?;
1762        let expr1 = (col("a") + col("b")).between(lit(0), lit(10));
1763        let expr2 = (col("b") + col("a")).between(lit(0), lit(10));
1764        let plan = LogicalPlanBuilder::from(table_scan)
1765            .project(vec![expr1, expr2])?
1766            .build()?;
1767        assert_optimized_plan_equal!(
1768            plan,
1769            @ r"
1770        Projection: __common_expr_1 AS test.a + test.b BETWEEN Int32(0) AND Int32(10), __common_expr_1 AS test.b + test.a BETWEEN Int32(0) AND Int32(10)
1771          Projection: test.a + test.b BETWEEN Int32(0) AND Int32(10) AS __common_expr_1, test.a, test.b, test.c
1772            TableScan: test
1773        "
1774        )?;
1775
1776        // c between a + b and 10 <=> c between b + a and 10
1777        let table_scan = test_table_scan()?;
1778        let expr1 = col("c").between(col("a") + col("b"), lit(10));
1779        let expr2 = col("c").between(col("b") + col("a"), lit(10));
1780        let plan = LogicalPlanBuilder::from(table_scan)
1781            .project(vec![expr1, expr2])?
1782            .build()?;
1783        assert_optimized_plan_equal!(
1784            plan,
1785            @ r"
1786        Projection: __common_expr_1 AS test.c BETWEEN test.a + test.b AND Int32(10), __common_expr_1 AS test.c BETWEEN test.b + test.a AND Int32(10)
1787          Projection: test.c BETWEEN test.a + test.b AND Int32(10) AS __common_expr_1, test.a, test.b, test.c
1788            TableScan: test
1789        "
1790        )?;
1791
1792        // function call with argument <=> function call with argument
1793        let udf = ScalarUDF::from(TestUdf::new());
1794        let table_scan = test_table_scan()?;
1795        let expr1 = udf.call(vec![col("a") + col("b")]);
1796        let expr2 = udf.call(vec![col("b") + col("a")]);
1797        let plan = LogicalPlanBuilder::from(table_scan)
1798            .project(vec![expr1, expr2])?
1799            .build()?;
1800        assert_optimized_plan_equal!(
1801            plan,
1802            @ r"
1803        Projection: __common_expr_1 AS my_udf(test.a + test.b), __common_expr_1 AS my_udf(test.b + test.a)
1804          Projection: my_udf(test.a + test.b) AS __common_expr_1, test.a, test.b, test.c
1805            TableScan: test
1806        "
1807        )
1808    }
1809
1810    /// returns a "random" function that is marked volatile (aka each invocation
1811    /// returns a different value)
1812    ///
1813    /// Does not use datafusion_functions::rand to avoid introducing a
1814    /// dependency on that crate.
1815    fn rand_func() -> ScalarUDF {
1816        ScalarUDF::new_from_impl(RandomStub::new())
1817    }
1818
1819    #[derive(Debug, PartialEq, Eq, Hash)]
1820    struct RandomStub {
1821        signature: Signature,
1822    }
1823
1824    impl RandomStub {
1825        fn new() -> Self {
1826            Self {
1827                signature: Signature::exact(vec![], Volatility::Volatile),
1828            }
1829        }
1830    }
1831    impl ScalarUDFImpl for RandomStub {
1832        fn name(&self) -> &str {
1833            "random"
1834        }
1835
1836        fn signature(&self) -> &Signature {
1837            &self.signature
1838        }
1839
1840        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
1841            Ok(DataType::Float64)
1842        }
1843
1844        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1845            panic!("dummy - not implemented")
1846        }
1847    }
1848
1849    /// Identical MoveTowardsLeafNodes expressions should NOT be deduplicated
1850    /// by CSE — they are cheap (e.g. struct field access) and the extraction
1851    /// rules deliberately duplicate them. Deduplicating causes optimizer
1852    /// instability where one optimizer rule will undo the work of another,
1853    /// resulting in an infinite optimization loop until the
1854    /// we hit the max iteration limit and then give up.
1855    #[test]
1856    fn test_leaf_expression_not_extracted() -> Result<()> {
1857        let table_scan = test_table_scan()?;
1858
1859        let leaf = leaf_udf_expr(col("a"));
1860        let plan = LogicalPlanBuilder::from(table_scan)
1861            .project(vec![leaf.clone().alias("c1"), leaf.alias("c2")])?
1862            .build()?;
1863
1864        // Plan should be unchanged — no __common_expr introduced
1865        assert_optimized_plan_equal!(
1866            plan,
1867            @r"
1868        Projection: leaf_udf(test.a) AS c1, leaf_udf(test.a) AS c2
1869          TableScan: test
1870        "
1871        )
1872    }
1873
1874    /// When a MoveTowardsLeafNodes expression appears as a sub-expression of
1875    /// a larger expression that IS duplicated, only the outer expression gets
1876    /// deduplicated; the leaf sub-expression stays inline.
1877    #[test]
1878    fn test_leaf_subexpression_not_extracted() -> Result<()> {
1879        let table_scan = test_table_scan()?;
1880
1881        // leaf_udf(a) + b appears twice — the outer `+` is a common
1882        // sub-expression, but leaf_udf(a) by itself is MoveTowardsLeafNodes
1883        // and should not be extracted separately.
1884        let common = leaf_udf_expr(col("a")) + col("b");
1885        let plan = LogicalPlanBuilder::from(table_scan)
1886            .project(vec![common.clone().alias("c1"), common.alias("c2")])?
1887            .build()?;
1888
1889        // The whole `leaf_udf(a) + b` gets deduplicated as __common_expr_1,
1890        // but leaf_udf(a) alone is NOT pulled out.
1891        assert_optimized_plan_equal!(
1892            plan,
1893            @r"
1894        Projection: __common_expr_1 AS c1, __common_expr_1 AS c2
1895          Projection: leaf_udf(test.a) + test.b AS __common_expr_1, test.a, test.b, test.c
1896            TableScan: test
1897        "
1898        )
1899    }
1900}