datafusion_optimizer/
common_subexpr_eliminate.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`CommonSubexprEliminate`] to avoid redundant computation of common sub-expressions
19
20use std::collections::BTreeSet;
21use std::fmt::Debug;
22use std::sync::Arc;
23
24use crate::{OptimizerConfig, OptimizerRule};
25
26use crate::optimizer::ApplyOrder;
27use crate::utils::NamePreserver;
28use datafusion_common::alias::AliasGenerator;
29
30use datafusion_common::cse::{CSEController, FoundCommonNodes, CSE};
31use datafusion_common::tree_node::{Transformed, TreeNode};
32use datafusion_common::{qualified_name, Column, DFSchema, DFSchemaRef, Result};
33use datafusion_expr::expr::{Alias, ScalarFunction};
34use datafusion_expr::logical_plan::{
35    Aggregate, Filter, LogicalPlan, Projection, Sort, Window,
36};
37use datafusion_expr::{col, BinaryExpr, Case, Expr, Operator, SortExpr};
38
/// Prefix handed to the alias generator whenever this rule needs a fresh name
/// for an extracted common sub-expression.
const CSE_PREFIX: &str = "__common_expr";
40
/// Performs Common Sub-expression Elimination optimization.
///
/// This optimization improves query performance by computing expressions that
/// appear more than once and reusing those results rather than re-computing the
/// same value.
///
/// Currently only common sub-expressions within a single `LogicalPlan` are
/// eliminated.
///
/// # Example
///
/// Given a projection that computes the same expensive expression
/// multiple times, such as parsing a string as a date with `to_date` twice:
///
/// ```text
/// ProjectionExec(expr=[extract (day from to_date(c1)), extract (year from to_date(c1))])
/// ```
///
/// This optimization will rewrite the plan to compute the common expression once
/// using a new `ProjectionExec` and then rewrite the original expressions to
/// refer to that new column.
///
/// ```text
/// ProjectionExec(exprs=[extract (day from new_col), extract (year from new_col)]) <-- reuse here
///   ProjectionExec(exprs=[to_date(c1) as new_col]) <-- compute to_date once
/// ```
#[derive(Debug)]
pub struct CommonSubexprEliminate {}
69
70impl CommonSubexprEliminate {
    /// Creates a new instance of the rule. The rule carries no state.
    pub fn new() -> Self {
        Self {}
    }
74
75    fn try_optimize_proj(
76        &self,
77        projection: Projection,
78        config: &dyn OptimizerConfig,
79    ) -> Result<Transformed<LogicalPlan>> {
80        let Projection {
81            expr,
82            input,
83            schema,
84            ..
85        } = projection;
86        let input = Arc::unwrap_or_clone(input);
87        self.try_unary_plan(expr, input, config)?
88            .map_data(|(new_expr, new_input)| {
89                Projection::try_new_with_schema(new_expr, Arc::new(new_input), schema)
90                    .map(LogicalPlan::Projection)
91            })
92    }
93
94    fn try_optimize_sort(
95        &self,
96        sort: Sort,
97        config: &dyn OptimizerConfig,
98    ) -> Result<Transformed<LogicalPlan>> {
99        let Sort { expr, input, fetch } = sort;
100        let input = Arc::unwrap_or_clone(input);
101        let (sort_expressions, sort_params): (Vec<_>, Vec<(_, _)>) = expr
102            .into_iter()
103            .map(|sort| (sort.expr, (sort.asc, sort.nulls_first)))
104            .unzip();
105        let new_sort = self
106            .try_unary_plan(sort_expressions, input, config)?
107            .update_data(|(new_expr, new_input)| {
108                LogicalPlan::Sort(Sort {
109                    expr: new_expr
110                        .into_iter()
111                        .zip(sort_params)
112                        .map(|(expr, (asc, nulls_first))| SortExpr {
113                            expr,
114                            asc,
115                            nulls_first,
116                        })
117                        .collect(),
118                    input: Arc::new(new_input),
119                    fetch,
120                })
121            });
122        Ok(new_sort)
123    }
124
125    fn try_optimize_filter(
126        &self,
127        filter: Filter,
128        config: &dyn OptimizerConfig,
129    ) -> Result<Transformed<LogicalPlan>> {
130        let Filter {
131            predicate, input, ..
132        } = filter;
133        let input = Arc::unwrap_or_clone(input);
134        let expr = vec![predicate];
135        self.try_unary_plan(expr, input, config)?
136            .map_data(|(mut new_expr, new_input)| {
137                assert_eq!(new_expr.len(), 1); // passed in vec![predicate]
138                let new_predicate = new_expr.pop().unwrap();
139                Filter::try_new(new_predicate, Arc::new(new_input))
140                    .map(LogicalPlan::Filter)
141            })
142    }
143
    /// Runs common sub-expression elimination over a run of consecutive
    /// `Window` nodes that starts at `window`.
    ///
    /// The window expressions of all consecutive `Window` nodes are collected
    /// and searched for common sub-expressions as one group. When some are
    /// found, a projection computing them is placed on top of the input of the
    /// deepest `Window` node. The (possibly new) input is then rewritten
    /// recursively and finally the `Window` nodes are rebuilt on top of it.
    fn try_optimize_window(
        &self,
        window: Window,
        config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        // Collects window expressions from consecutive `LogicalPlan::Window` nodes into
        // a list. The lists are ordered from the outermost window to the deepest one.
        let (window_expr_list, window_schemas, input) =
            get_consecutive_window_exprs(window);

        // Extract common sub-expressions from the list.

        match CSE::new(ExprCSEController::new(
            config.alias_generator().as_ref(),
            ExprMask::Normal,
        ))
        .extract_common_nodes(window_expr_list)?
        {
            // If there are common sub-expressions, then insert a projection node
            // with the common expressions between the new window nodes and the
            // original input.
            FoundCommonNodes::Yes {
                common_nodes: common_exprs,
                new_nodes_list: new_exprs_list,
                original_nodes_list: original_exprs_list,
            } => build_common_expr_project_plan(input, common_exprs).map(|new_input| {
                // The original expressions are kept so that their names can be
                // restored when the window nodes are rebuilt below.
                Transformed::yes((new_exprs_list, new_input, Some(original_exprs_list)))
            }),
            // Nothing was extracted: keep the expressions and the input as they are.
            FoundCommonNodes::No {
                original_nodes_list: original_exprs_list,
            } => Ok(Transformed::no((original_exprs_list, input, None))),
        }?
        // Recurse into the new input.
        // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.)
        .transform_data(|(new_window_expr_list, new_input, window_expr_list)| {
            self.rewrite(new_input, config)?.map_data(|new_input| {
                Ok((new_window_expr_list, new_input, window_expr_list))
            })
        })?
        // Rebuild the consecutive window nodes.
        .map_data(|(new_window_expr_list, new_input, window_expr_list)| {
            // If there were common expressions extracted, then we need to make sure
            // we restore the original column names.
            // TODO: Although `find_common_exprs()` inserts aliases around extracted
            //  common expressions this doesn't mean that the original column names
            //  (schema) are preserved due to the inserted aliases are not always at
            //  the top of the expression.
            //  Let's consider improving `find_common_exprs()` to always keep column
            //  names and get rid of additional name preserving logic here.
            if let Some(window_expr_list) = window_expr_list {
                let name_preserver = NamePreserver::new_for_projection();
                // Record the name of every original window expression, keeping the
                // per-window grouping.
                let saved_names = window_expr_list
                    .iter()
                    .map(|exprs| {
                        exprs
                            .iter()
                            .map(|expr| name_preserver.save(expr))
                            .collect::<Vec<_>>()
                    })
                    .collect::<Vec<_>>();
                // Fold from the back: the last list belongs to the deepest window,
                // which has to be built first, directly on top of `new_input`.
                new_window_expr_list.into_iter().zip(saved_names).try_rfold(
                    new_input,
                    |plan, (new_window_expr, saved_names)| {
                        let new_window_expr = new_window_expr
                            .into_iter()
                            .zip(saved_names)
                            .map(|(new_window_expr, saved_name)| {
                                saved_name.restore(new_window_expr)
                            })
                            .collect::<Vec<_>>();
                        Window::try_new(new_window_expr, Arc::new(plan))
                            .map(LogicalPlan::Window)
                    },
                )
            } else {
                // Nothing was extracted, so every window node is rebuilt with the
                // schema it had originally (again starting from the deepest one).
                new_window_expr_list
                    .into_iter()
                    .zip(window_schemas)
                    .try_rfold(new_input, |plan, (new_window_expr, schema)| {
                        Window::try_new_with_schema(
                            new_window_expr,
                            Arc::new(plan),
                            schema,
                        )
                        .map(LogicalPlan::Window)
                    })
            }
        })
    }
233
    /// Runs common sub-expression elimination over an `Aggregate` node.
    ///
    /// The work is done in three chained stages:
    ///
    /// 1. Common sub-expressions shared by the grouping and aggregate
    ///    expressions are extracted (aggregate functions themselves are
    ///    ignored in this pass) and, if any are found, computed by a projection
    ///    placed on top of the original input.
    /// 2. The (possibly new) input is rewritten recursively.
    /// 3. The aggregate expressions are searched again, this time treating
    ///    aggregate functions as candidates too. If common aggregate
    ///    expressions are found, the node is rebuilt as an `Aggregate` with a
    ///    `Projection` on top of it; otherwise only the `Aggregate` is rebuilt.
    fn try_optimize_aggregate(
        &self,
        aggregate: Aggregate,
        config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        let Aggregate {
            group_expr,
            aggr_expr,
            input,
            schema,
            ..
        } = aggregate;
        let input = Arc::unwrap_or_clone(input);
        // Extract common sub-expressions from the aggregate and grouping expressions.
        match CSE::new(ExprCSEController::new(
            config.alias_generator().as_ref(),
            ExprMask::Normal,
        ))
        .extract_common_nodes(vec![group_expr, aggr_expr])?
        {
            // If there are common sub-expressions, then insert a projection node
            // with the common expressions between the new aggregate node and the
            // original input.
            FoundCommonNodes::Yes {
                common_nodes: common_exprs,
                new_nodes_list: mut new_exprs_list,
                original_nodes_list: mut original_exprs_list,
            } => {
                // The lists were passed in as `[group_expr, aggr_expr]`, so popping
                // yields the aggregate expressions first and the grouping
                // expressions second.
                let new_aggr_expr = new_exprs_list.pop().unwrap();
                let new_group_expr = new_exprs_list.pop().unwrap();

                build_common_expr_project_plan(input, common_exprs).map(|new_input| {
                    // Keep the original aggregate expressions; they are used in the
                    // last stage to restore the original column names.
                    let aggr_expr = original_exprs_list.pop().unwrap();
                    Transformed::yes((
                        new_aggr_expr,
                        new_group_expr,
                        new_input,
                        Some(aggr_expr),
                    ))
                })
            }

            FoundCommonNodes::No {
                original_nodes_list: mut original_exprs_list,
            } => {
                // Same pop order as above: aggregate expressions, then grouping
                // expressions.
                let new_aggr_expr = original_exprs_list.pop().unwrap();
                let new_group_expr = original_exprs_list.pop().unwrap();

                Ok(Transformed::no((
                    new_aggr_expr,
                    new_group_expr,
                    input,
                    None,
                )))
            }
        }?
        // Recurse into the new input.
        // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.)
        .transform_data(|(new_aggr_expr, new_group_expr, new_input, aggr_expr)| {
            self.rewrite(new_input, config)?.map_data(|new_input| {
                // Note: the tuple is re-ordered here; the next stage destructures it
                // as `(new_aggr_expr, new_group_expr, aggr_expr, new_input)`.
                Ok((
                    new_aggr_expr,
                    new_group_expr,
                    aggr_expr,
                    Arc::new(new_input),
                ))
            })
        })?
        // Try extracting common aggregate expressions and rebuild the aggregate node.
        .transform_data(
            |(new_aggr_expr, new_group_expr, aggr_expr, new_input)| {
                // Extract common aggregate sub-expressions from the aggregate expressions.
                match CSE::new(ExprCSEController::new(
                    config.alias_generator().as_ref(),
                    ExprMask::NormalAndAggregates,
                ))
                .extract_common_nodes(vec![new_aggr_expr])?
                {
                    FoundCommonNodes::Yes {
                        common_nodes: common_exprs,
                        new_nodes_list: mut new_exprs_list,
                        original_nodes_list: mut original_exprs_list,
                    } => {
                        let rewritten_aggr_expr = new_exprs_list.pop().unwrap();
                        let new_aggr_expr = original_exprs_list.pop().unwrap();

                        // The new aggregate node starts out with the extracted
                        // common expressions, each under its generated alias.
                        let mut agg_exprs = common_exprs
                            .into_iter()
                            .map(|(expr, expr_alias)| expr.alias(expr_alias))
                            .collect::<Vec<_>>();

                        // The projection on top first passes through the grouping
                        // columns ...
                        let mut proj_exprs = vec![];
                        for expr in &new_group_expr {
                            extract_expressions(expr, &mut proj_exprs)
                        }
                        // ... followed by one expression per original aggregate
                        // expression.
                        for (expr_rewritten, expr_orig) in
                            rewritten_aggr_expr.into_iter().zip(new_aggr_expr)
                        {
                            if expr_rewritten == expr_orig {
                                // The expression was not touched by the extraction,
                                // so it is still computed by the aggregate node and
                                // only referenced by name from the projection.
                                if let Expr::Alias(Alias { expr, name, .. }) =
                                    expr_rewritten
                                {
                                    agg_exprs.push(expr.alias(&name));
                                    proj_exprs
                                        .push(Expr::Column(Column::from_name(name)));
                                } else {
                                    // No alias to reuse: compute the expression under
                                    // a generated alias and rename it back to its
                                    // original output name in the projection.
                                    let expr_alias =
                                        config.alias_generator().next(CSE_PREFIX);
                                    let (qualifier, field_name) =
                                        expr_rewritten.qualified_name();
                                    let out_name =
                                        qualified_name(qualifier.as_ref(), &field_name);

                                    agg_exprs.push(expr_rewritten.alias(&expr_alias));
                                    proj_exprs.push(
                                        Expr::Column(Column::from_name(expr_alias))
                                            .alias(out_name),
                                    );
                                }
                            } else {
                                // The expression was rewritten during extraction, so
                                // the rewritten form goes into the projection.
                                proj_exprs.push(expr_rewritten);
                            }
                        }

                        let agg = LogicalPlan::Aggregate(Aggregate::try_new(
                            new_input,
                            new_group_expr,
                            agg_exprs,
                        )?);
                        Projection::try_new(proj_exprs, Arc::new(agg))
                            .map(|p| Transformed::yes(LogicalPlan::Projection(p)))
                    }

                    // If there aren't any common aggregate sub-expressions, then just
                    // rebuild the aggregate node.
                    FoundCommonNodes::No {
                        original_nodes_list: mut original_exprs_list,
                    } => {
                        let rewritten_aggr_expr = original_exprs_list.pop().unwrap();

                        // If there were common expressions extracted, then we need to
                        // make sure we restore the original column names.
                        // TODO: Although `find_common_exprs()` inserts aliases around
                        //  extracted common expressions this doesn't mean that the
                        //  original column names (schema) are preserved due to the
                        //  inserted aliases are not always at the top of the
                        //  expression.
                        //  Let's consider improving `find_common_exprs()` to always
                        //  keep column names and get rid of additional name
                        //  preserving logic here.
                        if let Some(aggr_expr) = aggr_expr {
                            let name_preserver = NamePreserver::new_for_projection();
                            let saved_names = aggr_expr
                                .iter()
                                .map(|expr| name_preserver.save(expr))
                                .collect::<Vec<_>>();
                            let new_aggr_expr = rewritten_aggr_expr
                                .into_iter()
                                .zip(saved_names)
                                .map(|(new_expr, saved_name)| {
                                    saved_name.restore(new_expr)
                                })
                                .collect::<Vec<Expr>>();

                            // Since `group_expr` may have changed, schema may also.
                            // Use `try_new()` method.
                            Aggregate::try_new(new_input, new_group_expr, new_aggr_expr)
                                .map(LogicalPlan::Aggregate)
                                .map(Transformed::no)
                        } else {
                            // Nothing was extracted in the first stage either, so the
                            // original schema is reused.
                            Aggregate::try_new_with_schema(
                                new_input,
                                new_group_expr,
                                rewritten_aggr_expr,
                                schema,
                            )
                            .map(LogicalPlan::Aggregate)
                            .map(Transformed::no)
                        }
                    }
                }
            },
        )
    }
418
419    /// Rewrites the expr list and input to remove common subexpressions
420    ///
421    /// # Parameters
422    ///
423    /// * `exprs`: List of expressions in the node
424    /// * `input`: input plan (that produces the columns referred to in `exprs`)
425    ///
426    /// # Return value
427    ///
428    ///  Returns `(rewritten_exprs, new_input)`. `new_input` is either:
429    ///
430    /// 1. The original `input` of no common subexpressions were extracted
431    /// 2. A newly added projection on top of the original input
432    ///    that computes the common subexpressions
433    fn try_unary_plan(
434        &self,
435        exprs: Vec<Expr>,
436        input: LogicalPlan,
437        config: &dyn OptimizerConfig,
438    ) -> Result<Transformed<(Vec<Expr>, LogicalPlan)>> {
439        // Extract common sub-expressions from the expressions.
440        match CSE::new(ExprCSEController::new(
441            config.alias_generator().as_ref(),
442            ExprMask::Normal,
443        ))
444        .extract_common_nodes(vec![exprs])?
445        {
446            FoundCommonNodes::Yes {
447                common_nodes: common_exprs,
448                new_nodes_list: mut new_exprs_list,
449                original_nodes_list: _,
450            } => {
451                let new_exprs = new_exprs_list.pop().unwrap();
452                build_common_expr_project_plan(input, common_exprs)
453                    .map(|new_input| Transformed::yes((new_exprs, new_input)))
454            }
455            FoundCommonNodes::No {
456                original_nodes_list: mut original_exprs_list,
457            } => {
458                let new_exprs = original_exprs_list.pop().unwrap();
459                Ok(Transformed::no((new_exprs, input)))
460            }
461        }?
462        // Recurse into the new input.
463        // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.)
464        .transform_data(|(new_exprs, new_input)| {
465            self.rewrite(new_input, config)?
466                .map_data(|new_input| Ok((new_exprs, new_input)))
467        })
468    }
469}
470
471/// Get all window expressions inside the consecutive window operators.
472///
473/// Returns the window expressions, and the input to the deepest child
474/// LogicalPlan.
475///
476/// For example, if the input window looks like
477///
478/// ```text
479///   LogicalPlan::Window(exprs=[a, b, c])
480///     LogicalPlan::Window(exprs=[d])
481///       InputPlan
482/// ```
483///
484/// Returns:
485/// *  `window_exprs`: `[[a, b, c], [d]]`
486/// * InputPlan
487///
488/// Consecutive window expressions may refer to same complex expression.
489///
490/// If same complex expression is referred more than once by subsequent
491/// `WindowAggr`s, we can cache complex expression by evaluating it with a
492/// projection before the first WindowAggr.
493///
494/// This enables us to cache complex expression "c3+c4" for following plan:
495///
496/// ```text
497/// WindowAggr: windowExpr=[[sum(c9) ORDER BY [c3 + c4] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
498/// --WindowAggr: windowExpr=[[sum(c9) ORDER BY [c3 + c4] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
499/// ```
500///
501/// where, it is referred once by each `WindowAggr` (total of 2) in the plan.
502fn get_consecutive_window_exprs(
503    window: Window,
504) -> (Vec<Vec<Expr>>, Vec<DFSchemaRef>, LogicalPlan) {
505    let mut window_expr_list = vec![];
506    let mut window_schemas = vec![];
507    let mut plan = LogicalPlan::Window(window);
508    while let LogicalPlan::Window(Window {
509        input,
510        window_expr,
511        schema,
512    }) = plan
513    {
514        window_expr_list.push(window_expr);
515        window_schemas.push(schema);
516
517        plan = Arc::unwrap_or_clone(input);
518    }
519    (window_expr_list, window_schemas, plan)
520}
521
522impl OptimizerRule for CommonSubexprEliminate {
523    fn supports_rewrite(&self) -> bool {
524        true
525    }
526
527    fn apply_order(&self) -> Option<ApplyOrder> {
528        // This rule handles recursion itself in a `ApplyOrder::TopDown` like manner.
529        // This is because in some cases adjacent nodes are collected (e.g. `Window`) and
530        // CSEd as a group, which can't be done in a simple `ApplyOrder::TopDown` rule.
531        None
532    }
533
534    #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
535    fn rewrite(
536        &self,
537        plan: LogicalPlan,
538        config: &dyn OptimizerConfig,
539    ) -> Result<Transformed<LogicalPlan>> {
540        let original_schema = Arc::clone(plan.schema());
541
542        let optimized_plan = match plan {
543            LogicalPlan::Projection(proj) => self.try_optimize_proj(proj, config)?,
544            LogicalPlan::Sort(sort) => self.try_optimize_sort(sort, config)?,
545            LogicalPlan::Filter(filter) => self.try_optimize_filter(filter, config)?,
546            LogicalPlan::Window(window) => self.try_optimize_window(window, config)?,
547            LogicalPlan::Aggregate(agg) => self.try_optimize_aggregate(agg, config)?,
548            LogicalPlan::Join(_)
549            | LogicalPlan::Repartition(_)
550            | LogicalPlan::Union(_)
551            | LogicalPlan::TableScan(_)
552            | LogicalPlan::Values(_)
553            | LogicalPlan::EmptyRelation(_)
554            | LogicalPlan::Subquery(_)
555            | LogicalPlan::SubqueryAlias(_)
556            | LogicalPlan::Limit(_)
557            | LogicalPlan::Ddl(_)
558            | LogicalPlan::Explain(_)
559            | LogicalPlan::Analyze(_)
560            | LogicalPlan::Statement(_)
561            | LogicalPlan::DescribeTable(_)
562            | LogicalPlan::Distinct(_)
563            | LogicalPlan::Extension(_)
564            | LogicalPlan::Dml(_)
565            | LogicalPlan::Copy(_)
566            | LogicalPlan::Unnest(_)
567            | LogicalPlan::RecursiveQuery(_) => {
568                // This rule handles recursion itself in a `ApplyOrder::TopDown` like
569                // manner.
570                plan.map_children(|c| self.rewrite(c, config))?
571            }
572        };
573
574        // If we rewrote the plan, ensure the schema stays the same
575        if optimized_plan.transformed && optimized_plan.data.schema() != &original_schema
576        {
577            optimized_plan.map_data(|optimized_plan| {
578                build_recover_project_plan(&original_schema, optimized_plan)
579            })
580        } else {
581            Ok(optimized_plan)
582        }
583    }
584
585    fn name(&self) -> &str {
586        "common_sub_expression_eliminate"
587    }
588}
589
/// Which type of [expressions](Expr) should be considered for rewriting?
///
/// The mask is consulted by `ExprCSEController::is_ignored` to decide whether
/// a node is skipped as a common sub-expression candidate.
#[derive(Debug, Clone, Copy)]
enum ExprMask {
    /// Ignores:
    ///
    /// - [`Literal`](Expr::Literal)
    /// - [`Columns`](Expr::Column)
    /// - [`ScalarVariable`](Expr::ScalarVariable)
    /// - [`Alias`](Expr::Alias)
    /// - [`Wildcard`](Expr::Wildcard)
    /// - [`AggregateFunction`](Expr::AggregateFunction)
    Normal,

    /// Like [`Normal`](Self::Normal), but includes [`AggregateFunction`](Expr::AggregateFunction).
    NormalAndAggregates,
}
606
/// [`CSEController`] implementation for [`Expr`] trees.
struct ExprCSEController<'a> {
    /// Source of the aliases given to extracted common sub-expressions.
    alias_generator: &'a AliasGenerator,
    /// Selects which kinds of expressions are ignored as candidates.
    mask: ExprMask,

    /// Depth counter for `Alias` nodes: incremented in `rewrite_f_down` and
    /// decremented in `rewrite_f_up` for every `Alias` node, and also
    /// incremented in `rewrite` when that method wraps its result in an alias.
    /// `rewrite` only adds an alias while this counter is zero.
    alias_counter: usize,
}
614
impl<'a> ExprCSEController<'a> {
    /// Creates a controller that draws aliases from `alias_generator` and
    /// ignores expressions according to `mask`. The alias depth counter starts
    /// at zero.
    fn new(alias_generator: &'a AliasGenerator, mask: ExprMask) -> Self {
        Self {
            alias_generator,
            mask,
            alias_counter: 0,
        }
    }
}
624
625impl CSEController for ExprCSEController<'_> {
626    type Node = Expr;
627
628    fn conditional_children(node: &Expr) -> Option<(Vec<&Expr>, Vec<&Expr>)> {
629        match node {
630            // In case of `ScalarFunction`s we don't know which children are surely
631            // executed so start visiting all children conditionally and stop the
632            // recursion with `TreeNodeRecursion::Jump`.
633            Expr::ScalarFunction(ScalarFunction { func, args })
634                if func.short_circuits() =>
635            {
636                Some((vec![], args.iter().collect()))
637            }
638
639            // In case of `And` and `Or` the first child is surely executed, but we
640            // account subexpressions as conditional in the second.
641            Expr::BinaryExpr(BinaryExpr {
642                left,
643                op: Operator::And | Operator::Or,
644                right,
645            }) => Some((vec![left.as_ref()], vec![right.as_ref()])),
646
647            // In case of `Case` the optional base expression and the first when
648            // expressions are surely executed, but we account subexpressions as
649            // conditional in the others.
650            Expr::Case(Case {
651                expr,
652                when_then_expr,
653                else_expr,
654            }) => Some((
655                expr.iter()
656                    .map(|e| e.as_ref())
657                    .chain(when_then_expr.iter().take(1).map(|(when, _)| when.as_ref()))
658                    .collect(),
659                when_then_expr
660                    .iter()
661                    .take(1)
662                    .map(|(_, then)| then.as_ref())
663                    .chain(
664                        when_then_expr
665                            .iter()
666                            .skip(1)
667                            .flat_map(|(when, then)| [when.as_ref(), then.as_ref()]),
668                    )
669                    .chain(else_expr.iter().map(|e| e.as_ref()))
670                    .collect(),
671            )),
672            _ => None,
673        }
674    }
675
676    fn is_valid(node: &Expr) -> bool {
677        !node.is_volatile_node()
678    }
679
680    fn is_ignored(&self, node: &Expr) -> bool {
681        // TODO: remove the next line after `Expr::Wildcard` is removed
682        #[expect(deprecated)]
683        let is_normal_minus_aggregates = matches!(
684            node,
685            Expr::Literal(..)
686                | Expr::Column(..)
687                | Expr::ScalarVariable(..)
688                | Expr::Alias(..)
689                | Expr::Wildcard { .. }
690        );
691
692        let is_aggr = matches!(node, Expr::AggregateFunction(..));
693
694        match self.mask {
695            ExprMask::Normal => is_normal_minus_aggregates || is_aggr,
696            ExprMask::NormalAndAggregates => is_normal_minus_aggregates,
697        }
698    }
699
700    fn generate_alias(&self) -> String {
701        self.alias_generator.next(CSE_PREFIX)
702    }
703
704    fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
705        // alias the expressions without an `Alias` ancestor node
706        if self.alias_counter > 0 {
707            col(alias)
708        } else {
709            self.alias_counter += 1;
710            col(alias).alias(node.schema_name().to_string())
711        }
712    }
713
714    fn rewrite_f_down(&mut self, node: &Expr) {
715        if matches!(node, Expr::Alias(_)) {
716            self.alias_counter += 1;
717        }
718    }
719    fn rewrite_f_up(&mut self, node: &Expr) {
720        if matches!(node, Expr::Alias(_)) {
721            self.alias_counter -= 1
722        }
723    }
724}
725
impl Default for CommonSubexprEliminate {
    /// Equivalent to [`CommonSubexprEliminate::new`].
    fn default() -> Self {
        Self::new()
    }
}
731
732/// Build the "intermediate" projection plan that evaluates the extracted common
733/// expressions.
734///
735/// # Arguments
736/// input: the input plan
737///
738/// common_exprs: which common subexpressions were used (and thus are added to
739/// intermediate projection)
740///
741/// expr_stats: the set of common subexpressions
742fn build_common_expr_project_plan(
743    input: LogicalPlan,
744    common_exprs: Vec<(Expr, String)>,
745) -> Result<LogicalPlan> {
746    let mut fields_set = BTreeSet::new();
747    let mut project_exprs = common_exprs
748        .into_iter()
749        .map(|(expr, expr_alias)| {
750            fields_set.insert(expr_alias.clone());
751            Ok(expr.alias(expr_alias))
752        })
753        .collect::<Result<Vec<_>>>()?;
754
755    for (qualifier, field) in input.schema().iter() {
756        if fields_set.insert(qualified_name(qualifier, field.name())) {
757            project_exprs.push(Expr::from((qualifier, field)));
758        }
759    }
760
761    Projection::try_new(project_exprs, Arc::new(input)).map(LogicalPlan::Projection)
762}
763
764/// Build the projection plan to eliminate unnecessary columns produced by
765/// the "intermediate" projection plan built in [build_common_expr_project_plan].
766///
767/// This is required to keep the schema the same for plans that pass the input
768/// on to the output, such as `Filter` or `Sort`.
769fn build_recover_project_plan(
770    schema: &DFSchema,
771    input: LogicalPlan,
772) -> Result<LogicalPlan> {
773    let col_exprs = schema.iter().map(Expr::from).collect();
774    Projection::try_new(col_exprs, Arc::new(input)).map(LogicalPlan::Projection)
775}
776
777fn extract_expressions(expr: &Expr, result: &mut Vec<Expr>) {
778    if let Expr::GroupingSet(groupings) = expr {
779        for e in groupings.distinct_expr() {
780            let (qualifier, field_name) = e.qualified_name();
781            let col = Column::new(qualifier, field_name);
782            result.push(Expr::Column(col))
783        }
784    } else {
785        let (qualifier, field_name) = expr.qualified_name();
786        let col = Column::new(qualifier, field_name);
787        result.push(Expr::Column(col));
788    }
789}
790
791#[cfg(test)]
792mod test {
793    use std::any::Any;
794    use std::iter;
795
796    use arrow::datatypes::{DataType, Field, Schema};
797    use datafusion_expr::logical_plan::{table_scan, JoinType};
798    use datafusion_expr::{
799        grouping_set, is_null, not, AccumulatorFactoryFunction, AggregateUDF,
800        ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
801        SimpleAggregateUDF, Volatility,
802    };
803    use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder};
804
805    use super::*;
806    use crate::assert_optimized_plan_eq_snapshot;
807    use crate::optimizer::OptimizerContext;
808    use crate::test::*;
809    use datafusion_expr::test::function_stub::{avg, sum};
810
    // Runs the `CommonSubexprEliminate` rule alone over a plan and checks the
    // result against an inline snapshot via `assert_optimized_plan_eq_snapshot!`.
    //
    // Two call shapes are accepted:
    //   assert_optimized_plan_equal!(config, plan, @ "expected")
    //   assert_optimized_plan_equal!(plan, @ "expected")
    macro_rules! assert_optimized_plan_equal {
        // Arm 1: the caller supplies the optimizer config.
        (
            $config:expr,
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(CommonSubexprEliminate::new())];
            assert_optimized_plan_eq_snapshot!(
                $config,
                rules,
                $plan,
                @ $expected,
            )
        }};

        // Arm 2: no config given, so a fresh `OptimizerContext` is created.
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(CommonSubexprEliminate::new())];
            let optimizer_ctx = OptimizerContext::new();
            assert_optimized_plan_eq_snapshot!(
                optimizer_ctx,
                rules,
                $plan,
                @ $expected,
            )
        }};
    }
840
841    #[test]
842    fn tpch_q1_simplified() -> Result<()> {
843        // SQL:
844        //  select
845        //      sum(a * (1 - b)),
846        //      sum(a * (1 - b) * (1 + c))
847        //  from T;
848        //
849        // The manual assembled logical plan don't contains the outermost `Projection`.
850
851        let table_scan = test_table_scan()?;
852
853        let plan = LogicalPlanBuilder::from(table_scan)
854            .aggregate(
855                iter::empty::<Expr>(),
856                vec![
857                    sum(col("a") * (lit(1) - col("b"))),
858                    sum((col("a") * (lit(1) - col("b"))) * (lit(1) + col("c"))),
859                ],
860            )?
861            .build()?;
862
863        assert_optimized_plan_equal!(
864            plan,
865            @ r"
866        Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test.a * Int32(1) - test.b), sum(__common_expr_1 AS test.a * Int32(1) - test.b * (Int32(1) + test.c))]]
867          Projection: test.a * (Int32(1) - test.b) AS __common_expr_1, test.a, test.b, test.c
868            TableScan: test
869        "
870        )
871    }
872
873    #[test]
874    fn nested_aliases() -> Result<()> {
875        let table_scan = test_table_scan()?;
876
877        let plan = LogicalPlanBuilder::from(table_scan)
878            .project(vec![
879                (col("a") + col("b") - col("c")).alias("alias1") * (col("a") + col("b")),
880                col("a") + col("b"),
881            ])?
882            .build()?;
883
884        assert_optimized_plan_equal!(
885            plan,
886            @ r"
887        Projection: __common_expr_1 - test.c AS alias1 * __common_expr_1 AS test.a + test.b, __common_expr_1 AS test.a + test.b
888          Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
889            TableScan: test
890        "
891        )
892    }
893
894    #[test]
895    fn aggregate() -> Result<()> {
896        let table_scan = test_table_scan()?;
897
898        let return_type = DataType::UInt32;
899        let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!());
900        let udf_agg = |inner: Expr| {
901            Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
902                Arc::new(AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
903                    "my_agg",
904                    Signature::exact(vec![DataType::UInt32], Volatility::Stable),
905                    return_type.clone(),
906                    Arc::clone(&accumulator),
907                    vec![Field::new("value", DataType::UInt32, true).into()],
908                ))),
909                vec![inner],
910                false,
911                None,
912                None,
913                None,
914            ))
915        };
916
917        // test: common aggregates
918        let plan = LogicalPlanBuilder::from(table_scan.clone())
919            .aggregate(
920                iter::empty::<Expr>(),
921                vec![
922                    // common: avg(col("a"))
923                    avg(col("a")).alias("col1"),
924                    avg(col("a")).alias("col2"),
925                    // no common
926                    avg(col("b")).alias("col3"),
927                    avg(col("c")),
928                    // common: udf_agg(col("a"))
929                    udf_agg(col("a")).alias("col4"),
930                    udf_agg(col("a")).alias("col5"),
931                    // no common
932                    udf_agg(col("b")).alias("col6"),
933                    udf_agg(col("c")),
934                ],
935            )?
936            .build()?;
937
938        assert_optimized_plan_equal!(
939            plan,
940            @ r"
941        Projection: __common_expr_1 AS col1, __common_expr_1 AS col2, col3, __common_expr_3 AS avg(test.c), __common_expr_2 AS col4, __common_expr_2 AS col5, col6, __common_expr_4 AS my_agg(test.c)
942          Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2, avg(test.b) AS col3, avg(test.c) AS __common_expr_3, my_agg(test.b) AS col6, my_agg(test.c) AS __common_expr_4]]
943            TableScan: test
944        "
945        )?;
946
947        // test: trafo after aggregate
948        let plan = LogicalPlanBuilder::from(table_scan.clone())
949            .aggregate(
950                iter::empty::<Expr>(),
951                vec![
952                    lit(1) + avg(col("a")),
953                    lit(1) - avg(col("a")),
954                    lit(1) + udf_agg(col("a")),
955                    lit(1) - udf_agg(col("a")),
956                ],
957            )?
958            .build()?;
959
960        assert_optimized_plan_equal!(
961            plan,
962            @ r"
963        Projection: Int32(1) + __common_expr_1 AS avg(test.a), Int32(1) - __common_expr_1 AS avg(test.a), Int32(1) + __common_expr_2 AS my_agg(test.a), Int32(1) - __common_expr_2 AS my_agg(test.a)
964          Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2]]
965            TableScan: test
966        "
967        )?;
968
969        // test: transformation before aggregate
970        let plan = LogicalPlanBuilder::from(table_scan.clone())
971            .aggregate(
972                iter::empty::<Expr>(),
973                vec![
974                    avg(lit(1u32) + col("a")).alias("col1"),
975                    udf_agg(lit(1u32) + col("a")).alias("col2"),
976                ],
977            )?
978            .build()?;
979
980        assert_optimized_plan_equal!(
981            plan,
982            @ r"
983        Aggregate: groupBy=[[]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]
984          Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
985            TableScan: test
986        "
987        )?;
988
989        // test: common between agg and group
990        let plan = LogicalPlanBuilder::from(table_scan.clone())
991            .aggregate(
992                vec![lit(1u32) + col("a")],
993                vec![
994                    avg(lit(1u32) + col("a")).alias("col1"),
995                    udf_agg(lit(1u32) + col("a")).alias("col2"),
996                ],
997            )?
998            .build()?;
999
1000        assert_optimized_plan_equal!(
1001            plan,
1002            @ r"
1003        Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]
1004          Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1005            TableScan: test
1006        "
1007        )?;
1008
1009        // test: all mixed
1010        let plan = LogicalPlanBuilder::from(table_scan)
1011            .aggregate(
1012                vec![lit(1u32) + col("a")],
1013                vec![
1014                    (lit(1u32) + avg(lit(1u32) + col("a"))).alias("col1"),
1015                    (lit(1u32) - avg(lit(1u32) + col("a"))).alias("col2"),
1016                    avg(lit(1u32) + col("a")),
1017                    (lit(1u32) + udf_agg(lit(1u32) + col("a"))).alias("col3"),
1018                    (lit(1u32) - udf_agg(lit(1u32) + col("a"))).alias("col4"),
1019                    udf_agg(lit(1u32) + col("a")),
1020                ],
1021            )?
1022            .build()?;
1023
1024        assert_optimized_plan_equal!(
1025            plan,
1026            @ r"
1027        Projection: UInt32(1) + test.a, UInt32(1) + __common_expr_2 AS col1, UInt32(1) - __common_expr_2 AS col2, __common_expr_4 AS avg(UInt32(1) + test.a), UInt32(1) + __common_expr_3 AS col3, UInt32(1) - __common_expr_3 AS col4, __common_expr_5 AS my_agg(UInt32(1) + test.a)
1028          Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS __common_expr_2, my_agg(__common_expr_1) AS __common_expr_3, avg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_4, my_agg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_5]]
1029            Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1030              TableScan: test
1031        "
1032        )
1033    }
1034
1035    #[test]
1036    fn aggregate_with_relations_and_dots() -> Result<()> {
1037        let schema = Schema::new(vec![Field::new("col.a", DataType::UInt32, false)]);
1038        let table_scan = table_scan(Some("table.test"), &schema, None)?.build()?;
1039
1040        let col_a = Expr::Column(Column::new(Some("table.test"), "col.a"));
1041
1042        let plan = LogicalPlanBuilder::from(table_scan)
1043            .aggregate(
1044                vec![col_a.clone()],
1045                vec![
1046                    (lit(1u32) + avg(lit(1u32) + col_a.clone())),
1047                    avg(lit(1u32) + col_a),
1048                ],
1049            )?
1050            .build()?;
1051
1052        assert_optimized_plan_equal!(
1053            plan,
1054            @ r"
1055        Projection: table.test.col.a, UInt32(1) + __common_expr_2 AS avg(UInt32(1) + table.test.col.a), __common_expr_2 AS avg(UInt32(1) + table.test.col.a)
1056          Aggregate: groupBy=[[table.test.col.a]], aggr=[[avg(__common_expr_1 AS UInt32(1) + table.test.col.a) AS __common_expr_2]]
1057            Projection: UInt32(1) + table.test.col.a AS __common_expr_1, table.test.col.a
1058              TableScan: table.test
1059        "
1060        )
1061    }
1062
1063    #[test]
1064    fn subexpr_in_same_order() -> Result<()> {
1065        let table_scan = test_table_scan()?;
1066
1067        let plan = LogicalPlanBuilder::from(table_scan)
1068            .project(vec![
1069                (lit(1) + col("a")).alias("first"),
1070                (lit(1) + col("a")).alias("second"),
1071            ])?
1072            .build()?;
1073
1074        assert_optimized_plan_equal!(
1075            plan,
1076            @ r"
1077        Projection: __common_expr_1 AS first, __common_expr_1 AS second
1078          Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1079            TableScan: test
1080        "
1081        )
1082    }
1083
1084    #[test]
1085    fn subexpr_in_different_order() -> Result<()> {
1086        let table_scan = test_table_scan()?;
1087
1088        let plan = LogicalPlanBuilder::from(table_scan)
1089            .project(vec![lit(1) + col("a"), col("a") + lit(1)])?
1090            .build()?;
1091
1092        assert_optimized_plan_equal!(
1093            plan,
1094            @ r"
1095        Projection: __common_expr_1 AS Int32(1) + test.a, __common_expr_1 AS test.a + Int32(1)
1096          Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1097            TableScan: test
1098        "
1099        )
1100    }
1101
1102    #[test]
1103    fn cross_plans_subexpr() -> Result<()> {
1104        let table_scan = test_table_scan()?;
1105
1106        let plan = LogicalPlanBuilder::from(table_scan)
1107            .project(vec![lit(1) + col("a"), col("a")])?
1108            .project(vec![lit(1) + col("a")])?
1109            .build()?;
1110
1111        assert_optimized_plan_equal!(
1112            plan,
1113            @ r"
1114        Projection: Int32(1) + test.a
1115          Projection: Int32(1) + test.a, test.a
1116            TableScan: test
1117        "
1118        )
1119    }
1120
1121    #[test]
1122    fn redundant_project_fields() {
1123        let table_scan = test_table_scan().unwrap();
1124        let c_plus_a = col("c") + col("a");
1125        let b_plus_a = col("b") + col("a");
1126        let common_exprs_1 = vec![
1127            (c_plus_a, format!("{CSE_PREFIX}_1")),
1128            (b_plus_a, format!("{CSE_PREFIX}_2")),
1129        ];
1130        let c_plus_a_2 = col(format!("{CSE_PREFIX}_1"));
1131        let b_plus_a_2 = col(format!("{CSE_PREFIX}_2"));
1132        let common_exprs_2 = vec![
1133            (c_plus_a_2, format!("{CSE_PREFIX}_3")),
1134            (b_plus_a_2, format!("{CSE_PREFIX}_4")),
1135        ];
1136        let project = build_common_expr_project_plan(table_scan, common_exprs_1).unwrap();
1137        let project_2 = build_common_expr_project_plan(project, common_exprs_2).unwrap();
1138
1139        let mut field_set = BTreeSet::new();
1140        for name in project_2.schema().field_names() {
1141            assert!(field_set.insert(name));
1142        }
1143    }
1144
1145    #[test]
1146    fn redundant_project_fields_join_input() {
1147        let table_scan_1 = test_table_scan_with_name("test1").unwrap();
1148        let table_scan_2 = test_table_scan_with_name("test2").unwrap();
1149        let join = LogicalPlanBuilder::from(table_scan_1)
1150            .join(table_scan_2, JoinType::Inner, (vec!["a"], vec!["a"]), None)
1151            .unwrap()
1152            .build()
1153            .unwrap();
1154        let c_plus_a = col("test1.c") + col("test1.a");
1155        let b_plus_a = col("test1.b") + col("test1.a");
1156        let common_exprs_1 = vec![
1157            (c_plus_a, format!("{CSE_PREFIX}_1")),
1158            (b_plus_a, format!("{CSE_PREFIX}_2")),
1159        ];
1160        let c_plus_a_2 = col(format!("{CSE_PREFIX}_1"));
1161        let b_plus_a_2 = col(format!("{CSE_PREFIX}_2"));
1162        let common_exprs_2 = vec![
1163            (c_plus_a_2, format!("{CSE_PREFIX}_3")),
1164            (b_plus_a_2, format!("{CSE_PREFIX}_4")),
1165        ];
1166        let project = build_common_expr_project_plan(join, common_exprs_1).unwrap();
1167        let project_2 = build_common_expr_project_plan(project, common_exprs_2).unwrap();
1168
1169        let mut field_set = BTreeSet::new();
1170        for name in project_2.schema().field_names() {
1171            assert!(field_set.insert(name));
1172        }
1173    }
1174
1175    #[test]
1176    fn eliminated_subexpr_datatype() {
1177        use datafusion_expr::cast;
1178
1179        let schema = Schema::new(vec![
1180            Field::new("a", DataType::UInt64, false),
1181            Field::new("b", DataType::UInt64, false),
1182            Field::new("c", DataType::UInt64, false),
1183        ]);
1184
1185        let plan = table_scan(Some("table"), &schema, None)
1186            .unwrap()
1187            .filter(
1188                cast(col("a"), DataType::Int64)
1189                    .lt(lit(1_i64))
1190                    .and(cast(col("a"), DataType::Int64).not_eq(lit(1_i64))),
1191            )
1192            .unwrap()
1193            .build()
1194            .unwrap();
1195        let rule = CommonSubexprEliminate::new();
1196        let optimized_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap();
1197        assert!(optimized_plan.transformed);
1198        let optimized_plan = optimized_plan.data;
1199
1200        let schema = optimized_plan.schema();
1201        let fields_with_datatypes: Vec<_> = schema
1202            .fields()
1203            .iter()
1204            .map(|field| (field.name(), field.data_type()))
1205            .collect();
1206        let formatted_fields_with_datatype = format!("{fields_with_datatypes:#?}");
1207        let expected = r#"[
1208    (
1209        "a",
1210        UInt64,
1211    ),
1212    (
1213        "b",
1214        UInt64,
1215    ),
1216    (
1217        "c",
1218        UInt64,
1219    ),
1220]"#;
1221        assert_eq!(expected, formatted_fields_with_datatype);
1222    }
1223
1224    #[test]
1225    fn filter_schema_changed() -> Result<()> {
1226        let table_scan = test_table_scan()?;
1227
1228        let plan = LogicalPlanBuilder::from(table_scan)
1229            .filter((lit(1) + col("a") - lit(10)).gt(lit(1) + col("a")))?
1230            .build()?;
1231
1232        assert_optimized_plan_equal!(
1233            plan,
1234            @ r"
1235        Projection: test.a, test.b, test.c
1236          Filter: __common_expr_1 - Int32(10) > __common_expr_1
1237            Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1238              TableScan: test
1239        "
1240        )
1241    }
1242
1243    #[test]
1244    fn test_extract_expressions_from_grouping_set() -> Result<()> {
1245        let mut result = Vec::with_capacity(3);
1246        let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("c")]]);
1247        extract_expressions(&grouping, &mut result);
1248
1249        assert!(result.len() == 3);
1250        Ok(())
1251    }
1252
1253    #[test]
1254    fn test_extract_expressions_from_grouping_set_with_identical_expr() -> Result<()> {
1255        let mut result = Vec::with_capacity(2);
1256        let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("a")]]);
1257        extract_expressions(&grouping, &mut result);
1258        assert!(result.len() == 2);
1259        Ok(())
1260    }
1261
1262    #[test]
1263    fn test_alias_collision() -> Result<()> {
1264        let table_scan = test_table_scan()?;
1265
1266        let config = OptimizerContext::new();
1267        let common_expr_1 = config.alias_generator().next(CSE_PREFIX);
1268        let plan = LogicalPlanBuilder::from(table_scan.clone())
1269            .project(vec![
1270                (col("a") + col("b")).alias(common_expr_1.clone()),
1271                col("c"),
1272            ])?
1273            .project(vec![
1274                col(common_expr_1.clone()).alias("c1"),
1275                col(common_expr_1).alias("c2"),
1276                (col("c") + lit(2)).alias("c3"),
1277                (col("c") + lit(2)).alias("c4"),
1278            ])?
1279            .build()?;
1280
1281        assert_optimized_plan_equal!(
1282            config,
1283            plan,
1284            @ r"
1285        Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 AS c3, __common_expr_2 AS c4
1286          Projection: test.c + Int32(2) AS __common_expr_2, __common_expr_1, test.c
1287            Projection: test.a + test.b AS __common_expr_1, test.c
1288              TableScan: test
1289        "
1290        )?;
1291
1292        let config = OptimizerContext::new();
1293        let _common_expr_1 = config.alias_generator().next(CSE_PREFIX);
1294        let common_expr_2 = config.alias_generator().next(CSE_PREFIX);
1295        let plan = LogicalPlanBuilder::from(table_scan)
1296            .project(vec![
1297                (col("a") + col("b")).alias(common_expr_2.clone()),
1298                col("c"),
1299            ])?
1300            .project(vec![
1301                col(common_expr_2.clone()).alias("c1"),
1302                col(common_expr_2).alias("c2"),
1303                (col("c") + lit(2)).alias("c3"),
1304                (col("c") + lit(2)).alias("c4"),
1305            ])?
1306            .build()?;
1307
1308        assert_optimized_plan_equal!(
1309            config,
1310            plan,
1311            @ r"
1312        Projection: __common_expr_2 AS c1, __common_expr_2 AS c2, __common_expr_3 AS c3, __common_expr_3 AS c4
1313          Projection: test.c + Int32(2) AS __common_expr_3, __common_expr_2, test.c
1314            Projection: test.a + test.b AS __common_expr_2, test.c
1315              TableScan: test
1316        "
1317        )?;
1318
1319        Ok(())
1320    }
1321
1322    #[test]
1323    fn test_extract_expressions_from_col() -> Result<()> {
1324        let mut result = Vec::with_capacity(1);
1325        extract_expressions(&col("a"), &mut result);
1326        assert!(result.len() == 1);
1327        Ok(())
1328    }
1329
1330    #[test]
1331    fn test_short_circuits() -> Result<()> {
1332        let table_scan = test_table_scan()?;
1333
1334        let extracted_short_circuit = col("a").eq(lit(0)).or(col("b").eq(lit(0)));
1335        let extracted_short_circuit_leg_1 = (col("a") + col("b")).eq(lit(0));
1336        let not_extracted_short_circuit_leg_2 = (col("a") - col("b")).eq(lit(0));
1337        let extracted_short_circuit_leg_3 = (col("a") * col("b")).eq(lit(0));
1338        let plan = LogicalPlanBuilder::from(table_scan)
1339            .project(vec![
1340                extracted_short_circuit.clone().alias("c1"),
1341                extracted_short_circuit.alias("c2"),
1342                extracted_short_circuit_leg_1
1343                    .clone()
1344                    .or(not_extracted_short_circuit_leg_2.clone())
1345                    .alias("c3"),
1346                extracted_short_circuit_leg_1
1347                    .and(not_extracted_short_circuit_leg_2)
1348                    .alias("c4"),
1349                extracted_short_circuit_leg_3
1350                    .clone()
1351                    .or(extracted_short_circuit_leg_3)
1352                    .alias("c5"),
1353            ])?
1354            .build()?;
1355
1356        assert_optimized_plan_equal!(
1357            plan,
1358            @ r"
1359        Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 OR test.a - test.b = Int32(0) AS c3, __common_expr_2 AND test.a - test.b = Int32(0) AS c4, __common_expr_3 OR __common_expr_3 AS c5
1360          Projection: test.a = Int32(0) OR test.b = Int32(0) AS __common_expr_1, test.a + test.b = Int32(0) AS __common_expr_2, test.a * test.b = Int32(0) AS __common_expr_3, test.a, test.b, test.c
1361            TableScan: test
1362        "
1363        )
1364    }
1365
1366    #[test]
1367    fn test_volatile() -> Result<()> {
1368        let table_scan = test_table_scan()?;
1369
1370        let extracted_child = col("a") + col("b");
1371        let rand = rand_func().call(vec![]);
1372        let not_extracted_volatile = extracted_child + rand;
1373        let plan = LogicalPlanBuilder::from(table_scan)
1374            .project(vec![
1375                not_extracted_volatile.clone().alias("c1"),
1376                not_extracted_volatile.alias("c2"),
1377            ])?
1378            .build()?;
1379
1380        assert_optimized_plan_equal!(
1381            plan,
1382            @ r"
1383        Projection: __common_expr_1 + random() AS c1, __common_expr_1 + random() AS c2
1384          Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
1385            TableScan: test
1386        "
1387        )
1388    }
1389
1390    #[test]
1391    fn test_volatile_short_circuits() -> Result<()> {
1392        let table_scan = test_table_scan()?;
1393
1394        let rand = rand_func().call(vec![]);
1395        let extracted_short_circuit_leg_1 = col("a").eq(lit(0));
1396        let not_extracted_volatile_short_circuit_1 =
1397            extracted_short_circuit_leg_1.or(rand.clone().eq(lit(0)));
1398        let not_extracted_short_circuit_leg_2 = col("b").eq(lit(0));
1399        let not_extracted_volatile_short_circuit_2 =
1400            rand.eq(lit(0)).or(not_extracted_short_circuit_leg_2);
1401        let plan = LogicalPlanBuilder::from(table_scan)
1402            .project(vec![
1403                not_extracted_volatile_short_circuit_1.clone().alias("c1"),
1404                not_extracted_volatile_short_circuit_1.alias("c2"),
1405                not_extracted_volatile_short_circuit_2.clone().alias("c3"),
1406                not_extracted_volatile_short_circuit_2.alias("c4"),
1407            ])?
1408            .build()?;
1409
1410        assert_optimized_plan_equal!(
1411            plan,
1412            @ r"
1413        Projection: __common_expr_1 OR random() = Int32(0) AS c1, __common_expr_1 OR random() = Int32(0) AS c2, random() = Int32(0) OR test.b = Int32(0) AS c3, random() = Int32(0) OR test.b = Int32(0) AS c4
1414          Projection: test.a = Int32(0) AS __common_expr_1, test.a, test.b, test.c
1415            TableScan: test
1416        "
1417        )
1418    }
1419
1420    #[test]
1421    fn test_non_top_level_common_expression() -> Result<()> {
1422        let table_scan = test_table_scan()?;
1423
1424        let common_expr = col("a") + col("b");
1425        let plan = LogicalPlanBuilder::from(table_scan)
1426            .project(vec![
1427                common_expr.clone().alias("c1"),
1428                common_expr.alias("c2"),
1429            ])?
1430            .project(vec![col("c1"), col("c2")])?
1431            .build()?;
1432
1433        assert_optimized_plan_equal!(
1434            plan,
1435            @ r"
1436        Projection: c1, c2
1437          Projection: __common_expr_1 AS c1, __common_expr_1 AS c2
1438            Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
1439              TableScan: test
1440        "
1441        )
1442    }
1443
1444    #[test]
1445    fn test_nested_common_expression() -> Result<()> {
1446        let table_scan = test_table_scan()?;
1447
1448        let nested_common_expr = col("a") + col("b");
1449        let common_expr = nested_common_expr.clone() * nested_common_expr;
1450        let plan = LogicalPlanBuilder::from(table_scan)
1451            .project(vec![
1452                common_expr.clone().alias("c1"),
1453                common_expr.alias("c2"),
1454            ])?
1455            .build()?;
1456
1457        assert_optimized_plan_equal!(
1458            plan,
1459            @ r"
1460        Projection: __common_expr_1 AS c1, __common_expr_1 AS c2
1461          Projection: __common_expr_2 * __common_expr_2 AS __common_expr_1, test.a, test.b, test.c
1462            Projection: test.a + test.b AS __common_expr_2, test.a, test.b, test.c
1463              TableScan: test
1464        "
1465        )
1466    }
1467
1468    #[test]
1469    fn test_normalize_add_expression() -> Result<()> {
1470        // a + b <=> b + a
1471        let table_scan = test_table_scan()?;
1472        let expr = ((col("a") + col("b")) * (col("b") + col("a"))).eq(lit(30));
1473        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1474
1475        assert_optimized_plan_equal!(
1476            plan,
1477            @ r"
1478        Projection: test.a, test.b, test.c
1479          Filter: __common_expr_1 * __common_expr_1 = Int32(30)
1480            Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
1481              TableScan: test
1482        "
1483        )
1484    }
1485
1486    #[test]
1487    fn test_normalize_multi_expression() -> Result<()> {
1488        // a * b <=> b * a
1489        let table_scan = test_table_scan()?;
1490        let expr = ((col("a") * col("b")) + (col("b") * col("a"))).eq(lit(30));
1491        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1492
1493        assert_optimized_plan_equal!(
1494            plan,
1495            @ r"
1496        Projection: test.a, test.b, test.c
1497          Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1498            Projection: test.a * test.b AS __common_expr_1, test.a, test.b, test.c
1499              TableScan: test
1500        "
1501        )
1502    }
1503
1504    #[test]
1505    fn test_normalize_bitset_and_expression() -> Result<()> {
1506        // a & b <=> b & a
1507        let table_scan = test_table_scan()?;
1508        let expr = ((col("a") & col("b")) + (col("b") & col("a"))).eq(lit(30));
1509        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1510
1511        assert_optimized_plan_equal!(
1512            plan,
1513            @ r"
1514        Projection: test.a, test.b, test.c
1515          Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1516            Projection: test.a & test.b AS __common_expr_1, test.a, test.b, test.c
1517              TableScan: test
1518        "
1519        )
1520    }
1521
1522    #[test]
1523    fn test_normalize_bitset_or_expression() -> Result<()> {
1524        // a | b <=> b | a
1525        let table_scan = test_table_scan()?;
1526        let expr = ((col("a") | col("b")) + (col("b") | col("a"))).eq(lit(30));
1527        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1528
1529        assert_optimized_plan_equal!(
1530            plan,
1531            @ r"
1532        Projection: test.a, test.b, test.c
1533          Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1534            Projection: test.a | test.b AS __common_expr_1, test.a, test.b, test.c
1535              TableScan: test
1536        "
1537        )
1538    }
1539
1540    #[test]
1541    fn test_normalize_bitset_xor_expression() -> Result<()> {
1542        // a # b <=> b # a
1543        let table_scan = test_table_scan()?;
1544        let expr = ((col("a") ^ col("b")) + (col("b") ^ col("a"))).eq(lit(30));
1545        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1546
1547        assert_optimized_plan_equal!(
1548            plan,
1549            @ r"
1550        Projection: test.a, test.b, test.c
1551          Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1552            Projection: test.a BIT_XOR test.b AS __common_expr_1, test.a, test.b, test.c
1553              TableScan: test
1554        "
1555        )
1556    }
1557
1558    #[test]
1559    fn test_normalize_eq_expression() -> Result<()> {
1560        // a = b <=> b = a
1561        let table_scan = test_table_scan()?;
1562        let expr = (col("a").eq(col("b"))).and(col("b").eq(col("a")));
1563        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1564
1565        assert_optimized_plan_equal!(
1566            plan,
1567            @ r"
1568        Projection: test.a, test.b, test.c
1569          Filter: __common_expr_1 AND __common_expr_1
1570            Projection: test.a = test.b AS __common_expr_1, test.a, test.b, test.c
1571              TableScan: test
1572        "
1573        )
1574    }
1575
1576    #[test]
1577    fn test_normalize_ne_expression() -> Result<()> {
1578        // a != b <=> b != a
1579        let table_scan = test_table_scan()?;
1580        let expr = (col("a").not_eq(col("b"))).and(col("b").not_eq(col("a")));
1581        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1582
1583        assert_optimized_plan_equal!(
1584            plan,
1585            @ r"
1586        Projection: test.a, test.b, test.c
1587          Filter: __common_expr_1 AND __common_expr_1
1588            Projection: test.a != test.b AS __common_expr_1, test.a, test.b, test.c
1589              TableScan: test
1590        "
1591        )
1592    }
1593
1594    #[test]
1595    fn test_normalize_complex_expression() -> Result<()> {
1596        // case1: a + b * c <=> b * c + a
1597        let table_scan = test_table_scan()?;
1598        let expr = ((col("a") + col("b") * col("c")) - (col("b") * col("c") + col("a")))
1599            .eq(lit(30));
1600        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1601
1602        assert_optimized_plan_equal!(
1603            plan,
1604            @ r"
1605        Projection: test.a, test.b, test.c
1606          Filter: __common_expr_1 - __common_expr_1 = Int32(30)
1607            Projection: test.a + test.b * test.c AS __common_expr_1, test.a, test.b, test.c
1608              TableScan: test
1609        "
1610        )?;
1611
1612        // ((c1 + c2 / c3) * c3 <=> c3 * (c2 / c3 + c1))
1613        let table_scan = test_table_scan()?;
1614        let expr = (((col("a") + col("b") / col("c")) * col("c"))
1615            / (col("c") * (col("b") / col("c") + col("a")))
1616            + col("a"))
1617        .eq(lit(30));
1618        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1619
1620        assert_optimized_plan_equal!(
1621            plan,
1622            @ r"
1623        Projection: test.a, test.b, test.c
1624          Filter: __common_expr_1 / __common_expr_1 + test.a = Int32(30)
1625            Projection: (test.a + test.b / test.c) * test.c AS __common_expr_1, test.a, test.b, test.c
1626              TableScan: test
1627        "
1628        )?;
1629
1630        // c2 / (c1 + c3) <=> c2 / (c3 + c1)
1631        let table_scan = test_table_scan()?;
1632        let expr = ((col("b") / (col("a") + col("c")))
1633            * (col("b") / (col("c") + col("a"))))
1634        .eq(lit(30));
1635        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1636        assert_optimized_plan_equal!(
1637            plan,
1638            @ r"
1639        Projection: test.a, test.b, test.c
1640          Filter: __common_expr_1 * __common_expr_1 = Int32(30)
1641            Projection: test.b / (test.a + test.c) AS __common_expr_1, test.a, test.b, test.c
1642              TableScan: test
1643        "
1644        )?;
1645
1646        Ok(())
1647    }
1648
1649    #[derive(Debug)]
1650    pub struct TestUdf {
1651        signature: Signature,
1652    }
1653
1654    impl TestUdf {
1655        pub fn new() -> Self {
1656            Self {
1657                signature: Signature::numeric(1, Volatility::Immutable),
1658            }
1659        }
1660    }
1661
1662    impl ScalarUDFImpl for TestUdf {
1663        fn as_any(&self) -> &dyn Any {
1664            self
1665        }
1666        fn name(&self) -> &str {
1667            "my_udf"
1668        }
1669
1670        fn signature(&self) -> &Signature {
1671            &self.signature
1672        }
1673
1674        fn return_type(&self, _: &[DataType]) -> Result<DataType> {
1675            Ok(DataType::Int32)
1676        }
1677
1678        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1679            panic!("not implemented")
1680        }
1681    }
1682
1683    #[test]
1684    fn test_normalize_inner_binary_expression() -> Result<()> {
1685        // Not(a == b) <=> Not(b == a)
1686        let table_scan = test_table_scan()?;
1687        let expr1 = not(col("a").eq(col("b")));
1688        let expr2 = not(col("b").eq(col("a")));
1689        let plan = LogicalPlanBuilder::from(table_scan)
1690            .project(vec![expr1, expr2])?
1691            .build()?;
1692        assert_optimized_plan_equal!(
1693            plan,
1694            @ r"
1695        Projection: __common_expr_1 AS NOT test.a = test.b, __common_expr_1 AS NOT test.b = test.a
1696          Projection: NOT test.a = test.b AS __common_expr_1, test.a, test.b, test.c
1697            TableScan: test
1698        "
1699        )?;
1700
1701        // is_null(a == b) <=> is_null(b == a)
1702        let table_scan = test_table_scan()?;
1703        let expr1 = is_null(col("a").eq(col("b")));
1704        let expr2 = is_null(col("b").eq(col("a")));
1705        let plan = LogicalPlanBuilder::from(table_scan)
1706            .project(vec![expr1, expr2])?
1707            .build()?;
1708        assert_optimized_plan_equal!(
1709            plan,
1710            @ r"
1711        Projection: __common_expr_1 AS test.a = test.b IS NULL, __common_expr_1 AS test.b = test.a IS NULL
1712          Projection: test.a = test.b IS NULL AS __common_expr_1, test.a, test.b, test.c
1713            TableScan: test
1714        "
1715        )?;
1716
1717        // a + b between 0 and 10 <=> b + a between 0 and 10
1718        let table_scan = test_table_scan()?;
1719        let expr1 = (col("a") + col("b")).between(lit(0), lit(10));
1720        let expr2 = (col("b") + col("a")).between(lit(0), lit(10));
1721        let plan = LogicalPlanBuilder::from(table_scan)
1722            .project(vec![expr1, expr2])?
1723            .build()?;
1724        assert_optimized_plan_equal!(
1725            plan,
1726            @ r"
1727        Projection: __common_expr_1 AS test.a + test.b BETWEEN Int32(0) AND Int32(10), __common_expr_1 AS test.b + test.a BETWEEN Int32(0) AND Int32(10)
1728          Projection: test.a + test.b BETWEEN Int32(0) AND Int32(10) AS __common_expr_1, test.a, test.b, test.c
1729            TableScan: test
1730        "
1731        )?;
1732
1733        // c between a + b and 10 <=> c between b + a and 10
1734        let table_scan = test_table_scan()?;
1735        let expr1 = col("c").between(col("a") + col("b"), lit(10));
1736        let expr2 = col("c").between(col("b") + col("a"), lit(10));
1737        let plan = LogicalPlanBuilder::from(table_scan)
1738            .project(vec![expr1, expr2])?
1739            .build()?;
1740        assert_optimized_plan_equal!(
1741            plan,
1742            @ r"
1743        Projection: __common_expr_1 AS test.c BETWEEN test.a + test.b AND Int32(10), __common_expr_1 AS test.c BETWEEN test.b + test.a AND Int32(10)
1744          Projection: test.c BETWEEN test.a + test.b AND Int32(10) AS __common_expr_1, test.a, test.b, test.c
1745            TableScan: test
1746        "
1747        )?;
1748
1749        // function call with argument <=> function call with argument
1750        let udf = ScalarUDF::from(TestUdf::new());
1751        let table_scan = test_table_scan()?;
1752        let expr1 = udf.call(vec![col("a") + col("b")]);
1753        let expr2 = udf.call(vec![col("b") + col("a")]);
1754        let plan = LogicalPlanBuilder::from(table_scan)
1755            .project(vec![expr1, expr2])?
1756            .build()?;
1757        assert_optimized_plan_equal!(
1758            plan,
1759            @ r"
1760        Projection: __common_expr_1 AS my_udf(test.a + test.b), __common_expr_1 AS my_udf(test.b + test.a)
1761          Projection: my_udf(test.a + test.b) AS __common_expr_1, test.a, test.b, test.c
1762            TableScan: test
1763        "
1764        )
1765    }
1766
1767    /// returns a "random" function that is marked volatile (aka each invocation
1768    /// returns a different value)
1769    ///
1770    /// Does not use datafusion_functions::rand to avoid introducing a
1771    /// dependency on that crate.
1772    fn rand_func() -> ScalarUDF {
1773        ScalarUDF::new_from_impl(RandomStub::new())
1774    }
1775
1776    #[derive(Debug)]
1777    struct RandomStub {
1778        signature: Signature,
1779    }
1780
1781    impl RandomStub {
1782        fn new() -> Self {
1783            Self {
1784                signature: Signature::exact(vec![], Volatility::Volatile),
1785            }
1786        }
1787    }
1788    impl ScalarUDFImpl for RandomStub {
1789        fn as_any(&self) -> &dyn Any {
1790            self
1791        }
1792
1793        fn name(&self) -> &str {
1794            "random"
1795        }
1796
1797        fn signature(&self) -> &Signature {
1798            &self.signature
1799        }
1800
1801        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
1802            Ok(DataType::Float64)
1803        }
1804
1805        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1806            panic!("dummy - not implemented")
1807        }
1808    }
1809}