datafusion_optimizer/
common_subexpr_eliminate.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`CommonSubexprEliminate`] to avoid redundant computation of common sub-expressions
19
20use std::collections::BTreeSet;
21use std::fmt::Debug;
22use std::sync::Arc;
23
24use crate::{OptimizerConfig, OptimizerRule};
25
26use crate::optimizer::ApplyOrder;
27use crate::utils::NamePreserver;
28use datafusion_common::alias::AliasGenerator;
29
30use datafusion_common::cse::{CSEController, FoundCommonNodes, CSE};
31use datafusion_common::tree_node::{Transformed, TreeNode};
32use datafusion_common::{qualified_name, Column, DFSchema, DFSchemaRef, Result};
33use datafusion_expr::expr::{Alias, ScalarFunction};
34use datafusion_expr::logical_plan::{
35    Aggregate, Filter, LogicalPlan, Projection, Sort, Window,
36};
37use datafusion_expr::{col, BinaryExpr, Case, Expr, Operator, SortExpr};
38
/// Prefix passed to the alias generator whenever this rule needs a fresh name
/// for an extracted common expression.
const CSE_PREFIX: &str = "__common_expr";
40
/// Performs Common Sub-expression Elimination optimization.
///
/// This optimization improves query performance by computing expressions that
/// appear more than once only a single time and reusing that result rather than
/// re-computing the same value.
///
/// Currently only common sub-expressions within a single `LogicalPlan` are
/// eliminated.
///
/// # Example
///
/// Given a projection that computes the same expensive expression
/// multiple times, such as parsing a string as a date with `to_date` twice:
///
/// ```text
/// ProjectionExec(expr=[extract (day from to_date(c1)), extract (year from to_date(c1))])
/// ```
///
/// This optimization will rewrite the plan to compute the common expression once
/// using a new `ProjectionExec` and then rewrite the original expressions to
/// refer to that new column.
///
/// ```text
/// ProjectionExec(exprs=[extract (day from new_col), extract (year from new_col)]) <-- reuse here
///   ProjectionExec(exprs=[to_date(c1) as new_col]) <-- compute to_date once
/// ```
#[derive(Debug)]
pub struct CommonSubexprEliminate {}
69
70impl CommonSubexprEliminate {
71    pub fn new() -> Self {
72        Self {}
73    }
74
75    fn try_optimize_proj(
76        &self,
77        projection: Projection,
78        config: &dyn OptimizerConfig,
79    ) -> Result<Transformed<LogicalPlan>> {
80        let Projection {
81            expr,
82            input,
83            schema,
84            ..
85        } = projection;
86        let input = Arc::unwrap_or_clone(input);
87        self.try_unary_plan(expr, input, config)?
88            .map_data(|(new_expr, new_input)| {
89                Projection::try_new_with_schema(new_expr, Arc::new(new_input), schema)
90                    .map(LogicalPlan::Projection)
91            })
92    }
93
94    fn try_optimize_sort(
95        &self,
96        sort: Sort,
97        config: &dyn OptimizerConfig,
98    ) -> Result<Transformed<LogicalPlan>> {
99        let Sort { expr, input, fetch } = sort;
100        let input = Arc::unwrap_or_clone(input);
101        let (sort_expressions, sort_params): (Vec<_>, Vec<(_, _)>) = expr
102            .into_iter()
103            .map(|sort| (sort.expr, (sort.asc, sort.nulls_first)))
104            .unzip();
105        let new_sort = self
106            .try_unary_plan(sort_expressions, input, config)?
107            .update_data(|(new_expr, new_input)| {
108                LogicalPlan::Sort(Sort {
109                    expr: new_expr
110                        .into_iter()
111                        .zip(sort_params)
112                        .map(|(expr, (asc, nulls_first))| SortExpr {
113                            expr,
114                            asc,
115                            nulls_first,
116                        })
117                        .collect(),
118                    input: Arc::new(new_input),
119                    fetch,
120                })
121            });
122        Ok(new_sort)
123    }
124
125    fn try_optimize_filter(
126        &self,
127        filter: Filter,
128        config: &dyn OptimizerConfig,
129    ) -> Result<Transformed<LogicalPlan>> {
130        let Filter {
131            predicate, input, ..
132        } = filter;
133        let input = Arc::unwrap_or_clone(input);
134        let expr = vec![predicate];
135        self.try_unary_plan(expr, input, config)?
136            .map_data(|(mut new_expr, new_input)| {
137                assert_eq!(new_expr.len(), 1); // passed in vec![predicate]
138                let new_predicate = new_expr.pop().unwrap();
139                Filter::try_new(new_predicate, Arc::new(new_input))
140                    .map(LogicalPlan::Filter)
141            })
142    }
143
144    fn try_optimize_window(
145        &self,
146        window: Window,
147        config: &dyn OptimizerConfig,
148    ) -> Result<Transformed<LogicalPlan>> {
149        // Collects window expressions from consecutive `LogicalPlan::Window` nodes into
150        // a list.
151        let (window_expr_list, window_schemas, input) =
152            get_consecutive_window_exprs(window);
153
154        // Extract common sub-expressions from the list.
155
156        match CSE::new(ExprCSEController::new(
157            config.alias_generator().as_ref(),
158            ExprMask::Normal,
159        ))
160        .extract_common_nodes(window_expr_list)?
161        {
162            // If there are common sub-expressions, then the insert a projection node
163            // with the common expressions between the new window nodes and the
164            // original input.
165            FoundCommonNodes::Yes {
166                common_nodes: common_exprs,
167                new_nodes_list: new_exprs_list,
168                original_nodes_list: original_exprs_list,
169            } => build_common_expr_project_plan(input, common_exprs).map(|new_input| {
170                Transformed::yes((new_exprs_list, new_input, Some(original_exprs_list)))
171            }),
172            FoundCommonNodes::No {
173                original_nodes_list: original_exprs_list,
174            } => Ok(Transformed::no((original_exprs_list, input, None))),
175        }?
176        // Recurse into the new input.
177        // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.)
178        .transform_data(|(new_window_expr_list, new_input, window_expr_list)| {
179            self.rewrite(new_input, config)?.map_data(|new_input| {
180                Ok((new_window_expr_list, new_input, window_expr_list))
181            })
182        })?
183        // Rebuild the consecutive window nodes.
184        .map_data(|(new_window_expr_list, new_input, window_expr_list)| {
185            // If there were common expressions extracted, then we need to make sure
186            // we restore the original column names.
187            // TODO: Although `find_common_exprs()` inserts aliases around extracted
188            //  common expressions this doesn't mean that the original column names
189            //  (schema) are preserved due to the inserted aliases are not always at
190            //  the top of the expression.
191            //  Let's consider improving `find_common_exprs()` to always keep column
192            //  names and get rid of additional name preserving logic here.
193            if let Some(window_expr_list) = window_expr_list {
194                let name_preserver = NamePreserver::new_for_projection();
195                let saved_names = window_expr_list
196                    .iter()
197                    .map(|exprs| {
198                        exprs
199                            .iter()
200                            .map(|expr| name_preserver.save(expr))
201                            .collect::<Vec<_>>()
202                    })
203                    .collect::<Vec<_>>();
204                new_window_expr_list.into_iter().zip(saved_names).try_rfold(
205                    new_input,
206                    |plan, (new_window_expr, saved_names)| {
207                        let new_window_expr = new_window_expr
208                            .into_iter()
209                            .zip(saved_names)
210                            .map(|(new_window_expr, saved_name)| {
211                                saved_name.restore(new_window_expr)
212                            })
213                            .collect::<Vec<_>>();
214                        Window::try_new(new_window_expr, Arc::new(plan))
215                            .map(LogicalPlan::Window)
216                    },
217                )
218            } else {
219                new_window_expr_list
220                    .into_iter()
221                    .zip(window_schemas)
222                    .try_rfold(new_input, |plan, (new_window_expr, schema)| {
223                        Window::try_new_with_schema(
224                            new_window_expr,
225                            Arc::new(plan),
226                            schema,
227                        )
228                        .map(LogicalPlan::Window)
229                    })
230            }
231        })
232    }
233
234    fn try_optimize_aggregate(
235        &self,
236        aggregate: Aggregate,
237        config: &dyn OptimizerConfig,
238    ) -> Result<Transformed<LogicalPlan>> {
239        let Aggregate {
240            group_expr,
241            aggr_expr,
242            input,
243            schema,
244            ..
245        } = aggregate;
246        let input = Arc::unwrap_or_clone(input);
247        // Extract common sub-expressions from the aggregate and grouping expressions.
248        match CSE::new(ExprCSEController::new(
249            config.alias_generator().as_ref(),
250            ExprMask::Normal,
251        ))
252        .extract_common_nodes(vec![group_expr, aggr_expr])?
253        {
254            // If there are common sub-expressions, then insert a projection node
255            // with the common expressions between the new aggregate node and the
256            // original input.
257            FoundCommonNodes::Yes {
258                common_nodes: common_exprs,
259                new_nodes_list: mut new_exprs_list,
260                original_nodes_list: mut original_exprs_list,
261            } => {
262                let new_aggr_expr = new_exprs_list.pop().unwrap();
263                let new_group_expr = new_exprs_list.pop().unwrap();
264
265                build_common_expr_project_plan(input, common_exprs).map(|new_input| {
266                    let aggr_expr = original_exprs_list.pop().unwrap();
267                    Transformed::yes((
268                        new_aggr_expr,
269                        new_group_expr,
270                        new_input,
271                        Some(aggr_expr),
272                    ))
273                })
274            }
275
276            FoundCommonNodes::No {
277                original_nodes_list: mut original_exprs_list,
278            } => {
279                let new_aggr_expr = original_exprs_list.pop().unwrap();
280                let new_group_expr = original_exprs_list.pop().unwrap();
281
282                Ok(Transformed::no((
283                    new_aggr_expr,
284                    new_group_expr,
285                    input,
286                    None,
287                )))
288            }
289        }?
290        // Recurse into the new input.
291        // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.)
292        .transform_data(|(new_aggr_expr, new_group_expr, new_input, aggr_expr)| {
293            self.rewrite(new_input, config)?.map_data(|new_input| {
294                Ok((
295                    new_aggr_expr,
296                    new_group_expr,
297                    aggr_expr,
298                    Arc::new(new_input),
299                ))
300            })
301        })?
302        // Try extracting common aggregate expressions and rebuild the aggregate node.
303        .transform_data(
304            |(new_aggr_expr, new_group_expr, aggr_expr, new_input)| {
305                // Extract common aggregate sub-expressions from the aggregate expressions.
306                match CSE::new(ExprCSEController::new(
307                    config.alias_generator().as_ref(),
308                    ExprMask::NormalAndAggregates,
309                ))
310                .extract_common_nodes(vec![new_aggr_expr])?
311                {
312                    FoundCommonNodes::Yes {
313                        common_nodes: common_exprs,
314                        new_nodes_list: mut new_exprs_list,
315                        original_nodes_list: mut original_exprs_list,
316                    } => {
317                        let rewritten_aggr_expr = new_exprs_list.pop().unwrap();
318                        let new_aggr_expr = original_exprs_list.pop().unwrap();
319
320                        let mut agg_exprs = common_exprs
321                            .into_iter()
322                            .map(|(expr, expr_alias)| expr.alias(expr_alias))
323                            .collect::<Vec<_>>();
324
325                        let mut proj_exprs = vec![];
326                        for expr in &new_group_expr {
327                            extract_expressions(expr, &mut proj_exprs)
328                        }
329                        for (expr_rewritten, expr_orig) in
330                            rewritten_aggr_expr.into_iter().zip(new_aggr_expr)
331                        {
332                            if expr_rewritten == expr_orig {
333                                if let Expr::Alias(Alias { expr, name, .. }) =
334                                    expr_rewritten
335                                {
336                                    agg_exprs.push(expr.alias(&name));
337                                    proj_exprs
338                                        .push(Expr::Column(Column::from_name(name)));
339                                } else {
340                                    let expr_alias =
341                                        config.alias_generator().next(CSE_PREFIX);
342                                    let (qualifier, field_name) =
343                                        expr_rewritten.qualified_name();
344                                    let out_name =
345                                        qualified_name(qualifier.as_ref(), &field_name);
346
347                                    agg_exprs.push(expr_rewritten.alias(&expr_alias));
348                                    proj_exprs.push(
349                                        Expr::Column(Column::from_name(expr_alias))
350                                            .alias(out_name),
351                                    );
352                                }
353                            } else {
354                                proj_exprs.push(expr_rewritten);
355                            }
356                        }
357
358                        let agg = LogicalPlan::Aggregate(Aggregate::try_new(
359                            new_input,
360                            new_group_expr,
361                            agg_exprs,
362                        )?);
363                        Projection::try_new(proj_exprs, Arc::new(agg))
364                            .map(|p| Transformed::yes(LogicalPlan::Projection(p)))
365                    }
366
367                    // If there aren't any common aggregate sub-expressions, then just
368                    // rebuild the aggregate node.
369                    FoundCommonNodes::No {
370                        original_nodes_list: mut original_exprs_list,
371                    } => {
372                        let rewritten_aggr_expr = original_exprs_list.pop().unwrap();
373
374                        // If there were common expressions extracted, then we need to
375                        // make sure we restore the original column names.
376                        // TODO: Although `find_common_exprs()` inserts aliases around
377                        //  extracted common expressions this doesn't mean that the
378                        //  original column names (schema) are preserved due to the
379                        //  inserted aliases are not always at the top of the
380                        //  expression.
381                        //  Let's consider improving `find_common_exprs()` to always
382                        //  keep column names and get rid of additional name
383                        //  preserving logic here.
384                        if let Some(aggr_expr) = aggr_expr {
385                            let name_preserver = NamePreserver::new_for_projection();
386                            let saved_names = aggr_expr
387                                .iter()
388                                .map(|expr| name_preserver.save(expr))
389                                .collect::<Vec<_>>();
390                            let new_aggr_expr = rewritten_aggr_expr
391                                .into_iter()
392                                .zip(saved_names)
393                                .map(|(new_expr, saved_name)| {
394                                    saved_name.restore(new_expr)
395                                })
396                                .collect::<Vec<Expr>>();
397
398                            // Since `group_expr` may have changed, schema may also.
399                            // Use `try_new()` method.
400                            Aggregate::try_new(new_input, new_group_expr, new_aggr_expr)
401                                .map(LogicalPlan::Aggregate)
402                                .map(Transformed::no)
403                        } else {
404                            Aggregate::try_new_with_schema(
405                                new_input,
406                                new_group_expr,
407                                rewritten_aggr_expr,
408                                schema,
409                            )
410                            .map(LogicalPlan::Aggregate)
411                            .map(Transformed::no)
412                        }
413                    }
414                }
415            },
416        )
417    }
418
419    /// Rewrites the expr list and input to remove common subexpressions
420    ///
421    /// # Parameters
422    ///
423    /// * `exprs`: List of expressions in the node
424    /// * `input`: input plan (that produces the columns referred to in `exprs`)
425    ///
426    /// # Return value
427    ///
428    ///  Returns `(rewritten_exprs, new_input)`. `new_input` is either:
429    ///
430    /// 1. The original `input` of no common subexpressions were extracted
431    /// 2. A newly added projection on top of the original input
432    ///    that computes the common subexpressions
433    fn try_unary_plan(
434        &self,
435        exprs: Vec<Expr>,
436        input: LogicalPlan,
437        config: &dyn OptimizerConfig,
438    ) -> Result<Transformed<(Vec<Expr>, LogicalPlan)>> {
439        // Extract common sub-expressions from the expressions.
440        match CSE::new(ExprCSEController::new(
441            config.alias_generator().as_ref(),
442            ExprMask::Normal,
443        ))
444        .extract_common_nodes(vec![exprs])?
445        {
446            FoundCommonNodes::Yes {
447                common_nodes: common_exprs,
448                new_nodes_list: mut new_exprs_list,
449                original_nodes_list: _,
450            } => {
451                let new_exprs = new_exprs_list.pop().unwrap();
452                build_common_expr_project_plan(input, common_exprs)
453                    .map(|new_input| Transformed::yes((new_exprs, new_input)))
454            }
455            FoundCommonNodes::No {
456                original_nodes_list: mut original_exprs_list,
457            } => {
458                let new_exprs = original_exprs_list.pop().unwrap();
459                Ok(Transformed::no((new_exprs, input)))
460            }
461        }?
462        // Recurse into the new input.
463        // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.)
464        .transform_data(|(new_exprs, new_input)| {
465            self.rewrite(new_input, config)?
466                .map_data(|new_input| Ok((new_exprs, new_input)))
467        })
468    }
469}
470
471/// Get all window expressions inside the consecutive window operators.
472///
473/// Returns the window expressions, and the input to the deepest child
474/// LogicalPlan.
475///
476/// For example, if the input window looks like
477///
478/// ```text
479///   LogicalPlan::Window(exprs=[a, b, c])
480///     LogicalPlan::Window(exprs=[d])
481///       InputPlan
482/// ```
483///
484/// Returns:
485/// *  `window_exprs`: `[[a, b, c], [d]]`
486/// * InputPlan
487///
488/// Consecutive window expressions may refer to same complex expression.
489///
490/// If same complex expression is referred more than once by subsequent
491/// `WindowAggr`s, we can cache complex expression by evaluating it with a
492/// projection before the first WindowAggr.
493///
494/// This enables us to cache complex expression "c3+c4" for following plan:
495///
496/// ```text
497/// WindowAggr: windowExpr=[[sum(c9) ORDER BY [c3 + c4] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
498/// --WindowAggr: windowExpr=[[sum(c9) ORDER BY [c3 + c4] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
499/// ```
500///
501/// where, it is referred once by each `WindowAggr` (total of 2) in the plan.
502fn get_consecutive_window_exprs(
503    window: Window,
504) -> (Vec<Vec<Expr>>, Vec<DFSchemaRef>, LogicalPlan) {
505    let mut window_expr_list = vec![];
506    let mut window_schemas = vec![];
507    let mut plan = LogicalPlan::Window(window);
508    while let LogicalPlan::Window(Window {
509        input,
510        window_expr,
511        schema,
512    }) = plan
513    {
514        window_expr_list.push(window_expr);
515        window_schemas.push(schema);
516
517        plan = Arc::unwrap_or_clone(input);
518    }
519    (window_expr_list, window_schemas, plan)
520}
521
522impl OptimizerRule for CommonSubexprEliminate {
523    fn supports_rewrite(&self) -> bool {
524        true
525    }
526
527    fn apply_order(&self) -> Option<ApplyOrder> {
528        // This rule handles recursion itself in a `ApplyOrder::TopDown` like manner.
529        // This is because in some cases adjacent nodes are collected (e.g. `Window`) and
530        // CSEd as a group, which can't be done in a simple `ApplyOrder::TopDown` rule.
531        None
532    }
533
534    #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
535    fn rewrite(
536        &self,
537        plan: LogicalPlan,
538        config: &dyn OptimizerConfig,
539    ) -> Result<Transformed<LogicalPlan>> {
540        let original_schema = Arc::clone(plan.schema());
541
542        let optimized_plan = match plan {
543            LogicalPlan::Projection(proj) => self.try_optimize_proj(proj, config)?,
544            LogicalPlan::Sort(sort) => self.try_optimize_sort(sort, config)?,
545            LogicalPlan::Filter(filter) => self.try_optimize_filter(filter, config)?,
546            LogicalPlan::Window(window) => self.try_optimize_window(window, config)?,
547            LogicalPlan::Aggregate(agg) => self.try_optimize_aggregate(agg, config)?,
548            LogicalPlan::Join(_)
549            | LogicalPlan::Repartition(_)
550            | LogicalPlan::Union(_)
551            | LogicalPlan::TableScan(_)
552            | LogicalPlan::Values(_)
553            | LogicalPlan::EmptyRelation(_)
554            | LogicalPlan::Subquery(_)
555            | LogicalPlan::SubqueryAlias(_)
556            | LogicalPlan::Limit(_)
557            | LogicalPlan::Ddl(_)
558            | LogicalPlan::Explain(_)
559            | LogicalPlan::Analyze(_)
560            | LogicalPlan::Statement(_)
561            | LogicalPlan::DescribeTable(_)
562            | LogicalPlan::Distinct(_)
563            | LogicalPlan::Extension(_)
564            | LogicalPlan::Dml(_)
565            | LogicalPlan::Copy(_)
566            | LogicalPlan::Unnest(_)
567            | LogicalPlan::RecursiveQuery(_) => {
568                // This rule handles recursion itself in a `ApplyOrder::TopDown` like
569                // manner.
570                plan.map_children(|c| self.rewrite(c, config))?
571            }
572        };
573
574        // If we rewrote the plan, ensure the schema stays the same
575        if optimized_plan.transformed && optimized_plan.data.schema() != &original_schema
576        {
577            optimized_plan.map_data(|optimized_plan| {
578                build_recover_project_plan(&original_schema, optimized_plan)
579            })
580        } else {
581            Ok(optimized_plan)
582        }
583    }
584
585    fn name(&self) -> &str {
586        "common_sub_expression_eliminate"
587    }
588}
589
/// Which type of [expressions](Expr) should be considered for rewriting?
///
/// The mask selects the set of expression kinds that are ignored, i.e. never
/// extracted as common sub-expressions.
#[derive(Debug, Clone, Copy)]
enum ExprMask {
    /// Ignores:
    ///
    /// - [`Literal`](Expr::Literal)
    /// - [`Columns`](Expr::Column)
    /// - [`ScalarVariable`](Expr::ScalarVariable)
    /// - [`Alias`](Expr::Alias)
    /// - [`Wildcard`](Expr::Wildcard)
    /// - [`AggregateFunction`](Expr::AggregateFunction)
    Normal,

    /// Like [`Normal`](Self::Normal), but does not ignore
    /// [`AggregateFunction`](Expr::AggregateFunction), so aggregate functions
    /// are considered for rewriting as well.
    NormalAndAggregates,
}
606
/// State used by this rule to drive common sub-expression elimination over
/// [`Expr`] trees.
struct ExprCSEController<'a> {
    // source of the names given to extracted common expressions
    alias_generator: &'a AliasGenerator,
    // which kinds of expressions are ignored during extraction
    mask: ExprMask,

    // how many aliases have we seen so far
    alias_counter: usize,
}
614
615impl<'a> ExprCSEController<'a> {
616    fn new(alias_generator: &'a AliasGenerator, mask: ExprMask) -> Self {
617        Self {
618            alias_generator,
619            mask,
620            alias_counter: 0,
621        }
622    }
623}
624
625impl CSEController for ExprCSEController<'_> {
626    type Node = Expr;
627
628    fn conditional_children(node: &Expr) -> Option<(Vec<&Expr>, Vec<&Expr>)> {
629        match node {
630            // In case of `ScalarFunction`s we don't know which children are surely
631            // executed so start visiting all children conditionally and stop the
632            // recursion with `TreeNodeRecursion::Jump`.
633            Expr::ScalarFunction(ScalarFunction { func, args })
634                if func.short_circuits() =>
635            {
636                Some((vec![], args.iter().collect()))
637            }
638
639            // In case of `And` and `Or` the first child is surely executed, but we
640            // account subexpressions as conditional in the second.
641            Expr::BinaryExpr(BinaryExpr {
642                left,
643                op: Operator::And | Operator::Or,
644                right,
645            }) => Some((vec![left.as_ref()], vec![right.as_ref()])),
646
647            // In case of `Case` the optional base expression and the first when
648            // expressions are surely executed, but we account subexpressions as
649            // conditional in the others.
650            Expr::Case(Case {
651                expr,
652                when_then_expr,
653                else_expr,
654            }) => Some((
655                expr.iter()
656                    .map(|e| e.as_ref())
657                    .chain(when_then_expr.iter().take(1).map(|(when, _)| when.as_ref()))
658                    .collect(),
659                when_then_expr
660                    .iter()
661                    .take(1)
662                    .map(|(_, then)| then.as_ref())
663                    .chain(
664                        when_then_expr
665                            .iter()
666                            .skip(1)
667                            .flat_map(|(when, then)| [when.as_ref(), then.as_ref()]),
668                    )
669                    .chain(else_expr.iter().map(|e| e.as_ref()))
670                    .collect(),
671            )),
672            _ => None,
673        }
674    }
675
676    fn is_valid(node: &Expr) -> bool {
677        !node.is_volatile_node()
678    }
679
680    fn is_ignored(&self, node: &Expr) -> bool {
681        // TODO: remove the next line after `Expr::Wildcard` is removed
682        #[expect(deprecated)]
683        let is_normal_minus_aggregates = matches!(
684            node,
685            Expr::Literal(..)
686                | Expr::Column(..)
687                | Expr::ScalarVariable(..)
688                | Expr::Alias(..)
689                | Expr::Wildcard { .. }
690        );
691
692        let is_aggr = matches!(node, Expr::AggregateFunction(..));
693
694        match self.mask {
695            ExprMask::Normal => is_normal_minus_aggregates || is_aggr,
696            ExprMask::NormalAndAggregates => is_normal_minus_aggregates,
697        }
698    }
699
700    fn generate_alias(&self) -> String {
701        self.alias_generator.next(CSE_PREFIX)
702    }
703
704    fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
705        // alias the expressions without an `Alias` ancestor node
706        if self.alias_counter > 0 {
707            col(alias)
708        } else {
709            self.alias_counter += 1;
710            col(alias).alias(node.schema_name().to_string())
711        }
712    }
713
714    fn rewrite_f_down(&mut self, node: &Expr) {
715        if matches!(node, Expr::Alias(_)) {
716            self.alias_counter += 1;
717        }
718    }
719    fn rewrite_f_up(&mut self, node: &Expr) {
720        if matches!(node, Expr::Alias(_)) {
721            self.alias_counter -= 1
722        }
723    }
724}
725
726impl Default for CommonSubexprEliminate {
727    fn default() -> Self {
728        Self::new()
729    }
730}
731
732/// Build the "intermediate" projection plan that evaluates the extracted common
733/// expressions.
734///
735/// # Arguments
736/// input: the input plan
737///
738/// common_exprs: which common subexpressions were used (and thus are added to
739/// intermediate projection)
740///
741/// expr_stats: the set of common subexpressions
742fn build_common_expr_project_plan(
743    input: LogicalPlan,
744    common_exprs: Vec<(Expr, String)>,
745) -> Result<LogicalPlan> {
746    let mut fields_set = BTreeSet::new();
747    let mut project_exprs = common_exprs
748        .into_iter()
749        .map(|(expr, expr_alias)| {
750            fields_set.insert(expr_alias.clone());
751            Ok(expr.alias(expr_alias))
752        })
753        .collect::<Result<Vec<_>>>()?;
754
755    for (qualifier, field) in input.schema().iter() {
756        if fields_set.insert(qualified_name(qualifier, field.name())) {
757            project_exprs.push(Expr::from((qualifier, field)));
758        }
759    }
760
761    Projection::try_new(project_exprs, Arc::new(input)).map(LogicalPlan::Projection)
762}
763
764/// Build the projection plan to eliminate unnecessary columns produced by
765/// the "intermediate" projection plan built in [build_common_expr_project_plan].
766///
767/// This is required to keep the schema the same for plans that pass the input
768/// on to the output, such as `Filter` or `Sort`.
769fn build_recover_project_plan(
770    schema: &DFSchema,
771    input: LogicalPlan,
772) -> Result<LogicalPlan> {
773    let col_exprs = schema.iter().map(Expr::from).collect();
774    Projection::try_new(col_exprs, Arc::new(input)).map(LogicalPlan::Projection)
775}
776
777fn extract_expressions(expr: &Expr, result: &mut Vec<Expr>) {
778    if let Expr::GroupingSet(groupings) = expr {
779        for e in groupings.distinct_expr() {
780            let (qualifier, field_name) = e.qualified_name();
781            let col = Column::new(qualifier, field_name);
782            result.push(Expr::Column(col))
783        }
784    } else {
785        let (qualifier, field_name) = expr.qualified_name();
786        let col = Column::new(qualifier, field_name);
787        result.push(Expr::Column(col));
788    }
789}
790
791#[cfg(test)]
792mod test {
793    use std::any::Any;
794    use std::iter;
795
796    use arrow::datatypes::{DataType, Field, Schema};
797    use datafusion_expr::logical_plan::{table_scan, JoinType};
798    use datafusion_expr::{
799        grouping_set, is_null, not, AccumulatorFactoryFunction, AggregateUDF,
800        ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
801        SimpleAggregateUDF, Volatility,
802    };
803    use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder};
804
805    use super::*;
806    use crate::optimizer::OptimizerContext;
807    use crate::test::*;
808    use crate::Optimizer;
809    use datafusion_expr::test::function_stub::{avg, sum};
810
811    fn assert_optimized_plan_eq(
812        expected: &str,
813        plan: LogicalPlan,
814        config: Option<&dyn OptimizerConfig>,
815    ) {
816        let optimizer =
817            Optimizer::with_rules(vec![Arc::new(CommonSubexprEliminate::new())]);
818        let default_config = OptimizerContext::new();
819        let config = config.unwrap_or(&default_config);
820        let optimized_plan = optimizer.optimize(plan, config, |_, _| ()).unwrap();
821        let formatted_plan = format!("{optimized_plan}");
822        assert_eq!(expected, formatted_plan);
823    }
824
825    #[test]
826    fn tpch_q1_simplified() -> Result<()> {
827        // SQL:
828        //  select
829        //      sum(a * (1 - b)),
830        //      sum(a * (1 - b) * (1 + c))
831        //  from T;
832        //
833        // The manual assembled logical plan don't contains the outermost `Projection`.
834
835        let table_scan = test_table_scan()?;
836
837        let plan = LogicalPlanBuilder::from(table_scan)
838            .aggregate(
839                iter::empty::<Expr>(),
840                vec![
841                    sum(col("a") * (lit(1) - col("b"))),
842                    sum((col("a") * (lit(1) - col("b"))) * (lit(1) + col("c"))),
843                ],
844            )?
845            .build()?;
846
847        let expected = "Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test.a * Int32(1) - test.b), sum(__common_expr_1 AS test.a * Int32(1) - test.b * (Int32(1) + test.c))]]\
848        \n  Projection: test.a * (Int32(1) - test.b) AS __common_expr_1, test.a, test.b, test.c\
849        \n    TableScan: test";
850
851        assert_optimized_plan_eq(expected, plan, None);
852
853        Ok(())
854    }
855
856    #[test]
857    fn nested_aliases() -> Result<()> {
858        let table_scan = test_table_scan()?;
859
860        let plan = LogicalPlanBuilder::from(table_scan)
861            .project(vec![
862                (col("a") + col("b") - col("c")).alias("alias1") * (col("a") + col("b")),
863                col("a") + col("b"),
864            ])?
865            .build()?;
866
867        let expected = "Projection: __common_expr_1 - test.c AS alias1 * __common_expr_1 AS test.a + test.b, __common_expr_1 AS test.a + test.b\
868        \n  Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\
869        \n    TableScan: test";
870
871        assert_optimized_plan_eq(expected, plan, None);
872
873        Ok(())
874    }
875
876    #[test]
877    fn aggregate() -> Result<()> {
878        let table_scan = test_table_scan()?;
879
880        let return_type = DataType::UInt32;
881        let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!());
882        let udf_agg = |inner: Expr| {
883            Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
884                Arc::new(AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
885                    "my_agg",
886                    Signature::exact(vec![DataType::UInt32], Volatility::Stable),
887                    return_type.clone(),
888                    Arc::clone(&accumulator),
889                    vec![Field::new("value", DataType::UInt32, true)],
890                ))),
891                vec![inner],
892                false,
893                None,
894                None,
895                None,
896            ))
897        };
898
899        // test: common aggregates
900        let plan = LogicalPlanBuilder::from(table_scan.clone())
901            .aggregate(
902                iter::empty::<Expr>(),
903                vec![
904                    // common: avg(col("a"))
905                    avg(col("a")).alias("col1"),
906                    avg(col("a")).alias("col2"),
907                    // no common
908                    avg(col("b")).alias("col3"),
909                    avg(col("c")),
910                    // common: udf_agg(col("a"))
911                    udf_agg(col("a")).alias("col4"),
912                    udf_agg(col("a")).alias("col5"),
913                    // no common
914                    udf_agg(col("b")).alias("col6"),
915                    udf_agg(col("c")),
916                ],
917            )?
918            .build()?;
919
920        let expected = "Projection: __common_expr_1 AS col1, __common_expr_1 AS col2, col3, __common_expr_3 AS avg(test.c), __common_expr_2 AS col4, __common_expr_2 AS col5, col6, __common_expr_4 AS my_agg(test.c)\
921        \n  Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2, avg(test.b) AS col3, avg(test.c) AS __common_expr_3, my_agg(test.b) AS col6, my_agg(test.c) AS __common_expr_4]]\
922        \n    TableScan: test";
923
924        assert_optimized_plan_eq(expected, plan, None);
925
926        // test: trafo after aggregate
927        let plan = LogicalPlanBuilder::from(table_scan.clone())
928            .aggregate(
929                iter::empty::<Expr>(),
930                vec![
931                    lit(1) + avg(col("a")),
932                    lit(1) - avg(col("a")),
933                    lit(1) + udf_agg(col("a")),
934                    lit(1) - udf_agg(col("a")),
935                ],
936            )?
937            .build()?;
938
939        let expected = "Projection: Int32(1) + __common_expr_1 AS avg(test.a), Int32(1) - __common_expr_1 AS avg(test.a), Int32(1) + __common_expr_2 AS my_agg(test.a), Int32(1) - __common_expr_2 AS my_agg(test.a)\
940        \n  Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2]]\
941        \n    TableScan: test";
942
943        assert_optimized_plan_eq(expected, plan, None);
944
945        // test: transformation before aggregate
946        let plan = LogicalPlanBuilder::from(table_scan.clone())
947            .aggregate(
948                iter::empty::<Expr>(),
949                vec![
950                    avg(lit(1u32) + col("a")).alias("col1"),
951                    udf_agg(lit(1u32) + col("a")).alias("col2"),
952                ],
953            )?
954            .build()?;
955
956        let expected = "Aggregate: groupBy=[[]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]\
957        \n  Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\
958        \n    TableScan: test";
959
960        assert_optimized_plan_eq(expected, plan, None);
961
962        // test: common between agg and group
963        let plan = LogicalPlanBuilder::from(table_scan.clone())
964            .aggregate(
965                vec![lit(1u32) + col("a")],
966                vec![
967                    avg(lit(1u32) + col("a")).alias("col1"),
968                    udf_agg(lit(1u32) + col("a")).alias("col2"),
969                ],
970            )?
971            .build()?;
972
973        let expected = "Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]\
974        \n  Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\
975        \n    TableScan: test";
976
977        assert_optimized_plan_eq(expected, plan, None);
978
979        // test: all mixed
980        let plan = LogicalPlanBuilder::from(table_scan)
981            .aggregate(
982                vec![lit(1u32) + col("a")],
983                vec![
984                    (lit(1u32) + avg(lit(1u32) + col("a"))).alias("col1"),
985                    (lit(1u32) - avg(lit(1u32) + col("a"))).alias("col2"),
986                    avg(lit(1u32) + col("a")),
987                    (lit(1u32) + udf_agg(lit(1u32) + col("a"))).alias("col3"),
988                    (lit(1u32) - udf_agg(lit(1u32) + col("a"))).alias("col4"),
989                    udf_agg(lit(1u32) + col("a")),
990                ],
991            )?
992            .build()?;
993
994        let expected = "Projection: UInt32(1) + test.a, UInt32(1) + __common_expr_2 AS col1, UInt32(1) - __common_expr_2 AS col2, __common_expr_4 AS avg(UInt32(1) + test.a), UInt32(1) + __common_expr_3 AS col3, UInt32(1) - __common_expr_3 AS col4, __common_expr_5 AS my_agg(UInt32(1) + test.a)\
995        \n  Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS __common_expr_2, my_agg(__common_expr_1) AS __common_expr_3, avg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_4, my_agg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_5]]\
996        \n    Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\
997        \n      TableScan: test";
998
999        assert_optimized_plan_eq(expected, plan, None);
1000
1001        Ok(())
1002    }
1003
1004    #[test]
1005    fn aggregate_with_relations_and_dots() -> Result<()> {
1006        let schema = Schema::new(vec![Field::new("col.a", DataType::UInt32, false)]);
1007        let table_scan = table_scan(Some("table.test"), &schema, None)?.build()?;
1008
1009        let col_a = Expr::Column(Column::new(Some("table.test"), "col.a"));
1010
1011        let plan = LogicalPlanBuilder::from(table_scan)
1012            .aggregate(
1013                vec![col_a.clone()],
1014                vec![
1015                    (lit(1u32) + avg(lit(1u32) + col_a.clone())),
1016                    avg(lit(1u32) + col_a),
1017                ],
1018            )?
1019            .build()?;
1020
1021        let expected = "Projection: table.test.col.a, UInt32(1) + __common_expr_2 AS avg(UInt32(1) + table.test.col.a), __common_expr_2 AS avg(UInt32(1) + table.test.col.a)\
1022        \n  Aggregate: groupBy=[[table.test.col.a]], aggr=[[avg(__common_expr_1 AS UInt32(1) + table.test.col.a) AS __common_expr_2]]\
1023        \n    Projection: UInt32(1) + table.test.col.a AS __common_expr_1, table.test.col.a\
1024        \n      TableScan: table.test";
1025
1026        assert_optimized_plan_eq(expected, plan, None);
1027
1028        Ok(())
1029    }
1030
1031    #[test]
1032    fn subexpr_in_same_order() -> Result<()> {
1033        let table_scan = test_table_scan()?;
1034
1035        let plan = LogicalPlanBuilder::from(table_scan)
1036            .project(vec![
1037                (lit(1) + col("a")).alias("first"),
1038                (lit(1) + col("a")).alias("second"),
1039            ])?
1040            .build()?;
1041
1042        let expected = "Projection: __common_expr_1 AS first, __common_expr_1 AS second\
1043        \n  Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\
1044        \n    TableScan: test";
1045
1046        assert_optimized_plan_eq(expected, plan, None);
1047
1048        Ok(())
1049    }
1050
1051    #[test]
1052    fn subexpr_in_different_order() -> Result<()> {
1053        let table_scan = test_table_scan()?;
1054
1055        let plan = LogicalPlanBuilder::from(table_scan)
1056            .project(vec![lit(1) + col("a"), col("a") + lit(1)])?
1057            .build()?;
1058
1059        let expected = "Projection: __common_expr_1 AS Int32(1) + test.a, __common_expr_1 AS test.a + Int32(1)\
1060        \n  Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\
1061        \n    TableScan: test";
1062
1063        assert_optimized_plan_eq(expected, plan, None);
1064
1065        Ok(())
1066    }
1067
1068    #[test]
1069    fn cross_plans_subexpr() -> Result<()> {
1070        let table_scan = test_table_scan()?;
1071
1072        let plan = LogicalPlanBuilder::from(table_scan)
1073            .project(vec![lit(1) + col("a"), col("a")])?
1074            .project(vec![lit(1) + col("a")])?
1075            .build()?;
1076
1077        let expected = "Projection: Int32(1) + test.a\
1078        \n  Projection: Int32(1) + test.a, test.a\
1079        \n    TableScan: test";
1080
1081        assert_optimized_plan_eq(expected, plan, None);
1082        Ok(())
1083    }
1084
1085    #[test]
1086    fn redundant_project_fields() {
1087        let table_scan = test_table_scan().unwrap();
1088        let c_plus_a = col("c") + col("a");
1089        let b_plus_a = col("b") + col("a");
1090        let common_exprs_1 = vec![
1091            (c_plus_a, format!("{CSE_PREFIX}_1")),
1092            (b_plus_a, format!("{CSE_PREFIX}_2")),
1093        ];
1094        let c_plus_a_2 = col(format!("{CSE_PREFIX}_1"));
1095        let b_plus_a_2 = col(format!("{CSE_PREFIX}_2"));
1096        let common_exprs_2 = vec![
1097            (c_plus_a_2, format!("{CSE_PREFIX}_3")),
1098            (b_plus_a_2, format!("{CSE_PREFIX}_4")),
1099        ];
1100        let project = build_common_expr_project_plan(table_scan, common_exprs_1).unwrap();
1101        let project_2 = build_common_expr_project_plan(project, common_exprs_2).unwrap();
1102
1103        let mut field_set = BTreeSet::new();
1104        for name in project_2.schema().field_names() {
1105            assert!(field_set.insert(name));
1106        }
1107    }
1108
1109    #[test]
1110    fn redundant_project_fields_join_input() {
1111        let table_scan_1 = test_table_scan_with_name("test1").unwrap();
1112        let table_scan_2 = test_table_scan_with_name("test2").unwrap();
1113        let join = LogicalPlanBuilder::from(table_scan_1)
1114            .join(table_scan_2, JoinType::Inner, (vec!["a"], vec!["a"]), None)
1115            .unwrap()
1116            .build()
1117            .unwrap();
1118        let c_plus_a = col("test1.c") + col("test1.a");
1119        let b_plus_a = col("test1.b") + col("test1.a");
1120        let common_exprs_1 = vec![
1121            (c_plus_a, format!("{CSE_PREFIX}_1")),
1122            (b_plus_a, format!("{CSE_PREFIX}_2")),
1123        ];
1124        let c_plus_a_2 = col(format!("{CSE_PREFIX}_1"));
1125        let b_plus_a_2 = col(format!("{CSE_PREFIX}_2"));
1126        let common_exprs_2 = vec![
1127            (c_plus_a_2, format!("{CSE_PREFIX}_3")),
1128            (b_plus_a_2, format!("{CSE_PREFIX}_4")),
1129        ];
1130        let project = build_common_expr_project_plan(join, common_exprs_1).unwrap();
1131        let project_2 = build_common_expr_project_plan(project, common_exprs_2).unwrap();
1132
1133        let mut field_set = BTreeSet::new();
1134        for name in project_2.schema().field_names() {
1135            assert!(field_set.insert(name));
1136        }
1137    }
1138
1139    #[test]
1140    fn eliminated_subexpr_datatype() {
1141        use datafusion_expr::cast;
1142
1143        let schema = Schema::new(vec![
1144            Field::new("a", DataType::UInt64, false),
1145            Field::new("b", DataType::UInt64, false),
1146            Field::new("c", DataType::UInt64, false),
1147        ]);
1148
1149        let plan = table_scan(Some("table"), &schema, None)
1150            .unwrap()
1151            .filter(
1152                cast(col("a"), DataType::Int64)
1153                    .lt(lit(1_i64))
1154                    .and(cast(col("a"), DataType::Int64).not_eq(lit(1_i64))),
1155            )
1156            .unwrap()
1157            .build()
1158            .unwrap();
1159        let rule = CommonSubexprEliminate::new();
1160        let optimized_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap();
1161        assert!(optimized_plan.transformed);
1162        let optimized_plan = optimized_plan.data;
1163
1164        let schema = optimized_plan.schema();
1165        let fields_with_datatypes: Vec<_> = schema
1166            .fields()
1167            .iter()
1168            .map(|field| (field.name(), field.data_type()))
1169            .collect();
1170        let formatted_fields_with_datatype = format!("{fields_with_datatypes:#?}");
1171        let expected = r#"[
1172    (
1173        "a",
1174        UInt64,
1175    ),
1176    (
1177        "b",
1178        UInt64,
1179    ),
1180    (
1181        "c",
1182        UInt64,
1183    ),
1184]"#;
1185        assert_eq!(expected, formatted_fields_with_datatype);
1186    }
1187
1188    #[test]
1189    fn filter_schema_changed() -> Result<()> {
1190        let table_scan = test_table_scan()?;
1191
1192        let plan = LogicalPlanBuilder::from(table_scan)
1193            .filter((lit(1) + col("a") - lit(10)).gt(lit(1) + col("a")))?
1194            .build()?;
1195
1196        let expected = "Projection: test.a, test.b, test.c\
1197        \n  Filter: __common_expr_1 - Int32(10) > __common_expr_1\
1198        \n    Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\
1199        \n      TableScan: test";
1200
1201        assert_optimized_plan_eq(expected, plan, None);
1202
1203        Ok(())
1204    }
1205
1206    #[test]
1207    fn test_extract_expressions_from_grouping_set() -> Result<()> {
1208        let mut result = Vec::with_capacity(3);
1209        let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("c")]]);
1210        extract_expressions(&grouping, &mut result);
1211
1212        assert!(result.len() == 3);
1213        Ok(())
1214    }
1215
1216    #[test]
1217    fn test_extract_expressions_from_grouping_set_with_identical_expr() -> Result<()> {
1218        let mut result = Vec::with_capacity(2);
1219        let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("a")]]);
1220        extract_expressions(&grouping, &mut result);
1221        assert!(result.len() == 2);
1222        Ok(())
1223    }
1224
1225    #[test]
1226    fn test_alias_collision() -> Result<()> {
1227        let table_scan = test_table_scan()?;
1228
1229        let config = &OptimizerContext::new();
1230        let common_expr_1 = config.alias_generator().next(CSE_PREFIX);
1231        let plan = LogicalPlanBuilder::from(table_scan.clone())
1232            .project(vec![
1233                (col("a") + col("b")).alias(common_expr_1.clone()),
1234                col("c"),
1235            ])?
1236            .project(vec![
1237                col(common_expr_1.clone()).alias("c1"),
1238                col(common_expr_1).alias("c2"),
1239                (col("c") + lit(2)).alias("c3"),
1240                (col("c") + lit(2)).alias("c4"),
1241            ])?
1242            .build()?;
1243
1244        let expected = "Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 AS c3, __common_expr_2 AS c4\
1245        \n  Projection: test.c + Int32(2) AS __common_expr_2, __common_expr_1, test.c\
1246        \n    Projection: test.a + test.b AS __common_expr_1, test.c\
1247        \n      TableScan: test";
1248
1249        assert_optimized_plan_eq(expected, plan, Some(config));
1250
1251        let config = &OptimizerContext::new();
1252        let _common_expr_1 = config.alias_generator().next(CSE_PREFIX);
1253        let common_expr_2 = config.alias_generator().next(CSE_PREFIX);
1254        let plan = LogicalPlanBuilder::from(table_scan)
1255            .project(vec![
1256                (col("a") + col("b")).alias(common_expr_2.clone()),
1257                col("c"),
1258            ])?
1259            .project(vec![
1260                col(common_expr_2.clone()).alias("c1"),
1261                col(common_expr_2).alias("c2"),
1262                (col("c") + lit(2)).alias("c3"),
1263                (col("c") + lit(2)).alias("c4"),
1264            ])?
1265            .build()?;
1266
1267        let expected = "Projection: __common_expr_2 AS c1, __common_expr_2 AS c2, __common_expr_3 AS c3, __common_expr_3 AS c4\
1268        \n  Projection: test.c + Int32(2) AS __common_expr_3, __common_expr_2, test.c\
1269        \n    Projection: test.a + test.b AS __common_expr_2, test.c\
1270        \n      TableScan: test";
1271
1272        assert_optimized_plan_eq(expected, plan, Some(config));
1273
1274        Ok(())
1275    }
1276
1277    #[test]
1278    fn test_extract_expressions_from_col() -> Result<()> {
1279        let mut result = Vec::with_capacity(1);
1280        extract_expressions(&col("a"), &mut result);
1281        assert!(result.len() == 1);
1282        Ok(())
1283    }
1284
1285    #[test]
1286    fn test_short_circuits() -> Result<()> {
1287        let table_scan = test_table_scan()?;
1288
1289        let extracted_short_circuit = col("a").eq(lit(0)).or(col("b").eq(lit(0)));
1290        let extracted_short_circuit_leg_1 = (col("a") + col("b")).eq(lit(0));
1291        let not_extracted_short_circuit_leg_2 = (col("a") - col("b")).eq(lit(0));
1292        let extracted_short_circuit_leg_3 = (col("a") * col("b")).eq(lit(0));
1293        let plan = LogicalPlanBuilder::from(table_scan)
1294            .project(vec![
1295                extracted_short_circuit.clone().alias("c1"),
1296                extracted_short_circuit.alias("c2"),
1297                extracted_short_circuit_leg_1
1298                    .clone()
1299                    .or(not_extracted_short_circuit_leg_2.clone())
1300                    .alias("c3"),
1301                extracted_short_circuit_leg_1
1302                    .and(not_extracted_short_circuit_leg_2)
1303                    .alias("c4"),
1304                extracted_short_circuit_leg_3
1305                    .clone()
1306                    .or(extracted_short_circuit_leg_3)
1307                    .alias("c5"),
1308            ])?
1309            .build()?;
1310
1311        let expected = "Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 OR test.a - test.b = Int32(0) AS c3, __common_expr_2 AND test.a - test.b = Int32(0) AS c4, __common_expr_3 OR __common_expr_3 AS c5\
1312        \n  Projection: test.a = Int32(0) OR test.b = Int32(0) AS __common_expr_1, test.a + test.b = Int32(0) AS __common_expr_2, test.a * test.b = Int32(0) AS __common_expr_3, test.a, test.b, test.c\
1313        \n    TableScan: test";
1314
1315        assert_optimized_plan_eq(expected, plan, None);
1316
1317        Ok(())
1318    }
1319
1320    #[test]
1321    fn test_volatile() -> Result<()> {
1322        let table_scan = test_table_scan()?;
1323
1324        let extracted_child = col("a") + col("b");
1325        let rand = rand_func().call(vec![]);
1326        let not_extracted_volatile = extracted_child + rand;
1327        let plan = LogicalPlanBuilder::from(table_scan)
1328            .project(vec![
1329                not_extracted_volatile.clone().alias("c1"),
1330                not_extracted_volatile.alias("c2"),
1331            ])?
1332            .build()?;
1333
1334        let expected = "Projection: __common_expr_1 + random() AS c1, __common_expr_1 + random() AS c2\
1335        \n  Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\
1336        \n    TableScan: test";
1337
1338        assert_optimized_plan_eq(expected, plan, None);
1339
1340        Ok(())
1341    }
1342
1343    #[test]
1344    fn test_volatile_short_circuits() -> Result<()> {
1345        let table_scan = test_table_scan()?;
1346
1347        let rand = rand_func().call(vec![]);
1348        let extracted_short_circuit_leg_1 = col("a").eq(lit(0));
1349        let not_extracted_volatile_short_circuit_1 =
1350            extracted_short_circuit_leg_1.or(rand.clone().eq(lit(0)));
1351        let not_extracted_short_circuit_leg_2 = col("b").eq(lit(0));
1352        let not_extracted_volatile_short_circuit_2 =
1353            rand.eq(lit(0)).or(not_extracted_short_circuit_leg_2);
1354        let plan = LogicalPlanBuilder::from(table_scan)
1355            .project(vec![
1356                not_extracted_volatile_short_circuit_1.clone().alias("c1"),
1357                not_extracted_volatile_short_circuit_1.alias("c2"),
1358                not_extracted_volatile_short_circuit_2.clone().alias("c3"),
1359                not_extracted_volatile_short_circuit_2.alias("c4"),
1360            ])?
1361            .build()?;
1362
1363        let expected = "Projection: __common_expr_1 OR random() = Int32(0) AS c1, __common_expr_1 OR random() = Int32(0) AS c2, random() = Int32(0) OR test.b = Int32(0) AS c3, random() = Int32(0) OR test.b = Int32(0) AS c4\
1364        \n  Projection: test.a = Int32(0) AS __common_expr_1, test.a, test.b, test.c\
1365        \n    TableScan: test";
1366
1367        assert_optimized_plan_eq(expected, plan, None);
1368
1369        Ok(())
1370    }
1371
1372    #[test]
1373    fn test_non_top_level_common_expression() -> Result<()> {
1374        let table_scan = test_table_scan()?;
1375
1376        let common_expr = col("a") + col("b");
1377        let plan = LogicalPlanBuilder::from(table_scan)
1378            .project(vec![
1379                common_expr.clone().alias("c1"),
1380                common_expr.alias("c2"),
1381            ])?
1382            .project(vec![col("c1"), col("c2")])?
1383            .build()?;
1384
1385        let expected = "Projection: c1, c2\
1386        \n  Projection: __common_expr_1 AS c1, __common_expr_1 AS c2\
1387        \n    Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\
1388        \n      TableScan: test";
1389
1390        assert_optimized_plan_eq(expected, plan, None);
1391
1392        Ok(())
1393    }
1394
1395    #[test]
1396    fn test_nested_common_expression() -> Result<()> {
1397        let table_scan = test_table_scan()?;
1398
1399        let nested_common_expr = col("a") + col("b");
1400        let common_expr = nested_common_expr.clone() * nested_common_expr;
1401        let plan = LogicalPlanBuilder::from(table_scan)
1402            .project(vec![
1403                common_expr.clone().alias("c1"),
1404                common_expr.alias("c2"),
1405            ])?
1406            .build()?;
1407
1408        let expected = "Projection: __common_expr_1 AS c1, __common_expr_1 AS c2\
1409        \n  Projection: __common_expr_2 * __common_expr_2 AS __common_expr_1, test.a, test.b, test.c\
1410        \n    Projection: test.a + test.b AS __common_expr_2, test.a, test.b, test.c\
1411        \n      TableScan: test";
1412
1413        assert_optimized_plan_eq(expected, plan, None);
1414
1415        Ok(())
1416    }
1417
1418    #[test]
1419    fn test_normalize_add_expression() -> Result<()> {
1420        // a + b <=> b + a
1421        let table_scan = test_table_scan()?;
1422        let expr = ((col("a") + col("b")) * (col("b") + col("a"))).eq(lit(30));
1423        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1424
1425        let expected = "Projection: test.a, test.b, test.c\
1426        \n  Filter: __common_expr_1 * __common_expr_1 = Int32(30)\
1427        \n    Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\
1428        \n      TableScan: test";
1429        assert_optimized_plan_eq(expected, plan, None);
1430
1431        Ok(())
1432    }
1433
1434    #[test]
1435    fn test_normalize_multi_expression() -> Result<()> {
1436        // a * b <=> b * a
1437        let table_scan = test_table_scan()?;
1438        let expr = ((col("a") * col("b")) + (col("b") * col("a"))).eq(lit(30));
1439        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1440
1441        let expected = "Projection: test.a, test.b, test.c\
1442        \n  Filter: __common_expr_1 + __common_expr_1 = Int32(30)\
1443        \n    Projection: test.a * test.b AS __common_expr_1, test.a, test.b, test.c\
1444        \n      TableScan: test";
1445        assert_optimized_plan_eq(expected, plan, None);
1446
1447        Ok(())
1448    }
1449
1450    #[test]
1451    fn test_normalize_bitset_and_expression() -> Result<()> {
1452        // a & b <=> b & a
1453        let table_scan = test_table_scan()?;
1454        let expr = ((col("a") & col("b")) + (col("b") & col("a"))).eq(lit(30));
1455        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1456
1457        let expected = "Projection: test.a, test.b, test.c\
1458        \n  Filter: __common_expr_1 + __common_expr_1 = Int32(30)\
1459        \n    Projection: test.a & test.b AS __common_expr_1, test.a, test.b, test.c\
1460        \n      TableScan: test";
1461        assert_optimized_plan_eq(expected, plan, None);
1462
1463        Ok(())
1464    }
1465
1466    #[test]
1467    fn test_normalize_bitset_or_expression() -> Result<()> {
1468        // a | b <=> b | a
1469        let table_scan = test_table_scan()?;
1470        let expr = ((col("a") | col("b")) + (col("b") | col("a"))).eq(lit(30));
1471        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1472
1473        let expected = "Projection: test.a, test.b, test.c\
1474        \n  Filter: __common_expr_1 + __common_expr_1 = Int32(30)\
1475        \n    Projection: test.a | test.b AS __common_expr_1, test.a, test.b, test.c\
1476        \n      TableScan: test";
1477        assert_optimized_plan_eq(expected, plan, None);
1478
1479        Ok(())
1480    }
1481
1482    #[test]
1483    fn test_normalize_bitset_xor_expression() -> Result<()> {
1484        // a # b <=> b # a
1485        let table_scan = test_table_scan()?;
1486        let expr = ((col("a") ^ col("b")) + (col("b") ^ col("a"))).eq(lit(30));
1487        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1488
1489        let expected = "Projection: test.a, test.b, test.c\
1490        \n  Filter: __common_expr_1 + __common_expr_1 = Int32(30)\
1491        \n    Projection: test.a BIT_XOR test.b AS __common_expr_1, test.a, test.b, test.c\
1492        \n      TableScan: test";
1493        assert_optimized_plan_eq(expected, plan, None);
1494
1495        Ok(())
1496    }
1497
1498    #[test]
1499    fn test_normalize_eq_expression() -> Result<()> {
1500        // a = b <=> b = a
1501        let table_scan = test_table_scan()?;
1502        let expr = (col("a").eq(col("b"))).and(col("b").eq(col("a")));
1503        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1504
1505        let expected = "Projection: test.a, test.b, test.c\
1506        \n  Filter: __common_expr_1 AND __common_expr_1\
1507        \n    Projection: test.a = test.b AS __common_expr_1, test.a, test.b, test.c\
1508        \n      TableScan: test";
1509        assert_optimized_plan_eq(expected, plan, None);
1510
1511        Ok(())
1512    }
1513
1514    #[test]
1515    fn test_normalize_ne_expression() -> Result<()> {
1516        // a != b <=> b != a
1517        let table_scan = test_table_scan()?;
1518        let expr = (col("a").not_eq(col("b"))).and(col("b").not_eq(col("a")));
1519        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1520
1521        let expected = "Projection: test.a, test.b, test.c\
1522        \n  Filter: __common_expr_1 AND __common_expr_1\
1523        \n    Projection: test.a != test.b AS __common_expr_1, test.a, test.b, test.c\
1524        \n      TableScan: test";
1525        assert_optimized_plan_eq(expected, plan, None);
1526
1527        Ok(())
1528    }
1529
1530    #[test]
1531    fn test_normalize_complex_expression() -> Result<()> {
1532        // case1: a + b * c <=> b * c + a
1533        let table_scan = test_table_scan()?;
1534        let expr = ((col("a") + col("b") * col("c")) - (col("b") * col("c") + col("a")))
1535            .eq(lit(30));
1536        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1537
1538        let expected = "Projection: test.a, test.b, test.c\
1539        \n  Filter: __common_expr_1 - __common_expr_1 = Int32(30)\
1540        \n    Projection: test.a + test.b * test.c AS __common_expr_1, test.a, test.b, test.c\
1541        \n      TableScan: test";
1542        assert_optimized_plan_eq(expected, plan, None);
1543
1544        // ((c1 + c2 / c3) * c3 <=> c3 * (c2 / c3 + c1))
1545        let table_scan = test_table_scan()?;
1546        let expr = (((col("a") + col("b") / col("c")) * col("c"))
1547            / (col("c") * (col("b") / col("c") + col("a")))
1548            + col("a"))
1549        .eq(lit(30));
1550        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1551        let expected = "Projection: test.a, test.b, test.c\
1552        \n  Filter: __common_expr_1 / __common_expr_1 + test.a = Int32(30)\
1553        \n    Projection: (test.a + test.b / test.c) * test.c AS __common_expr_1, test.a, test.b, test.c\
1554        \n      TableScan: test";
1555        assert_optimized_plan_eq(expected, plan, None);
1556
1557        // c2 / (c1 + c3) <=> c2 / (c3 + c1)
1558        let table_scan = test_table_scan()?;
1559        let expr = ((col("b") / (col("a") + col("c")))
1560            * (col("b") / (col("c") + col("a"))))
1561        .eq(lit(30));
1562        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1563        let expected = "Projection: test.a, test.b, test.c\
1564        \n  Filter: __common_expr_1 * __common_expr_1 = Int32(30)\
1565        \n    Projection: test.b / (test.a + test.c) AS __common_expr_1, test.a, test.b, test.c\
1566        \n      TableScan: test";
1567        assert_optimized_plan_eq(expected, plan, None);
1568
1569        Ok(())
1570    }
1571
1572    #[derive(Debug)]
1573    pub struct TestUdf {
1574        signature: Signature,
1575    }
1576
1577    impl TestUdf {
1578        pub fn new() -> Self {
1579            Self {
1580                signature: Signature::numeric(1, Volatility::Immutable),
1581            }
1582        }
1583    }
1584
1585    impl ScalarUDFImpl for TestUdf {
1586        fn as_any(&self) -> &dyn Any {
1587            self
1588        }
1589        fn name(&self) -> &str {
1590            "my_udf"
1591        }
1592
1593        fn signature(&self) -> &Signature {
1594            &self.signature
1595        }
1596
1597        fn return_type(&self, _: &[DataType]) -> Result<DataType> {
1598            Ok(DataType::Int32)
1599        }
1600
1601        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1602            panic!("not implemented")
1603        }
1604    }
1605
1606    #[test]
1607    fn test_normalize_inner_binary_expression() -> Result<()> {
1608        // Not(a == b) <=> Not(b == a)
1609        let table_scan = test_table_scan()?;
1610        let expr1 = not(col("a").eq(col("b")));
1611        let expr2 = not(col("b").eq(col("a")));
1612        let plan = LogicalPlanBuilder::from(table_scan)
1613            .project(vec![expr1, expr2])?
1614            .build()?;
1615        let expected = "Projection: __common_expr_1 AS NOT test.a = test.b, __common_expr_1 AS NOT test.b = test.a\
1616        \n  Projection: NOT test.a = test.b AS __common_expr_1, test.a, test.b, test.c\
1617        \n    TableScan: test";
1618        assert_optimized_plan_eq(expected, plan, None);
1619
1620        // is_null(a == b) <=> is_null(b == a)
1621        let table_scan = test_table_scan()?;
1622        let expr1 = is_null(col("a").eq(col("b")));
1623        let expr2 = is_null(col("b").eq(col("a")));
1624        let plan = LogicalPlanBuilder::from(table_scan)
1625            .project(vec![expr1, expr2])?
1626            .build()?;
1627        let expected = "Projection: __common_expr_1 AS test.a = test.b IS NULL, __common_expr_1 AS test.b = test.a IS NULL\
1628        \n  Projection: test.a = test.b IS NULL AS __common_expr_1, test.a, test.b, test.c\
1629        \n    TableScan: test";
1630        assert_optimized_plan_eq(expected, plan, None);
1631
1632        // a + b between 0 and 10 <=> b + a between 0 and 10
1633        let table_scan = test_table_scan()?;
1634        let expr1 = (col("a") + col("b")).between(lit(0), lit(10));
1635        let expr2 = (col("b") + col("a")).between(lit(0), lit(10));
1636        let plan = LogicalPlanBuilder::from(table_scan)
1637            .project(vec![expr1, expr2])?
1638            .build()?;
1639        let expected = "Projection: __common_expr_1 AS test.a + test.b BETWEEN Int32(0) AND Int32(10), __common_expr_1 AS test.b + test.a BETWEEN Int32(0) AND Int32(10)\
1640        \n  Projection: test.a + test.b BETWEEN Int32(0) AND Int32(10) AS __common_expr_1, test.a, test.b, test.c\
1641        \n    TableScan: test";
1642        assert_optimized_plan_eq(expected, plan, None);
1643
1644        // c between a + b and 10 <=> c between b + a and 10
1645        let table_scan = test_table_scan()?;
1646        let expr1 = col("c").between(col("a") + col("b"), lit(10));
1647        let expr2 = col("c").between(col("b") + col("a"), lit(10));
1648        let plan = LogicalPlanBuilder::from(table_scan)
1649            .project(vec![expr1, expr2])?
1650            .build()?;
1651        let expected = "Projection: __common_expr_1 AS test.c BETWEEN test.a + test.b AND Int32(10), __common_expr_1 AS test.c BETWEEN test.b + test.a AND Int32(10)\
1652        \n  Projection: test.c BETWEEN test.a + test.b AND Int32(10) AS __common_expr_1, test.a, test.b, test.c\
1653        \n    TableScan: test";
1654        assert_optimized_plan_eq(expected, plan, None);
1655
1656        // function call with argument <=> function call with argument
1657        let udf = ScalarUDF::from(TestUdf::new());
1658        let table_scan = test_table_scan()?;
1659        let expr1 = udf.call(vec![col("a") + col("b")]);
1660        let expr2 = udf.call(vec![col("b") + col("a")]);
1661        let plan = LogicalPlanBuilder::from(table_scan)
1662            .project(vec![expr1, expr2])?
1663            .build()?;
1664        let expected = "Projection: __common_expr_1 AS my_udf(test.a + test.b), __common_expr_1 AS my_udf(test.b + test.a)\
1665        \n  Projection: my_udf(test.a + test.b) AS __common_expr_1, test.a, test.b, test.c\
1666        \n    TableScan: test";
1667        assert_optimized_plan_eq(expected, plan, None);
1668        Ok(())
1669    }
1670
1671    /// returns a "random" function that is marked volatile (aka each invocation
1672    /// returns a different value)
1673    ///
1674    /// Does not use datafusion_functions::rand to avoid introducing a
1675    /// dependency on that crate.
1676    fn rand_func() -> ScalarUDF {
1677        ScalarUDF::new_from_impl(RandomStub::new())
1678    }
1679
1680    #[derive(Debug)]
1681    struct RandomStub {
1682        signature: Signature,
1683    }
1684
1685    impl RandomStub {
1686        fn new() -> Self {
1687            Self {
1688                signature: Signature::exact(vec![], Volatility::Volatile),
1689            }
1690        }
1691    }
1692    impl ScalarUDFImpl for RandomStub {
1693        fn as_any(&self) -> &dyn Any {
1694            self
1695        }
1696
1697        fn name(&self) -> &str {
1698            "random"
1699        }
1700
1701        fn signature(&self) -> &Signature {
1702            &self.signature
1703        }
1704
1705        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
1706            Ok(DataType::Float64)
1707        }
1708
1709        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1710            panic!("dummy - not implemented")
1711        }
1712    }
1713}