datafusion_optimizer/
common_subexpr_eliminate.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`CommonSubexprEliminate`] to avoid redundant computation of common sub-expressions
19
20use std::collections::BTreeSet;
21use std::fmt::Debug;
22use std::sync::Arc;
23
24use crate::{OptimizerConfig, OptimizerRule};
25
26use crate::optimizer::ApplyOrder;
27use crate::utils::NamePreserver;
28use datafusion_common::alias::AliasGenerator;
29
30use datafusion_common::cse::{CSEController, FoundCommonNodes, CSE};
31use datafusion_common::tree_node::{Transformed, TreeNode};
32use datafusion_common::{qualified_name, Column, DFSchema, DFSchemaRef, Result};
33use datafusion_expr::expr::{Alias, ScalarFunction};
34use datafusion_expr::logical_plan::{
35    Aggregate, Filter, LogicalPlan, Projection, Sort, Window,
36};
37use datafusion_expr::{col, BinaryExpr, Case, Expr, Operator, SortExpr};
38
/// Prefix of the aliases generated for extracted common sub-expressions
/// (passed to the optimizer's `AliasGenerator::next`).
const CSE_PREFIX: &str = "__common_expr";

/// Performs Common Sub-expression Elimination optimization.
///
/// This optimization improves query performance by computing expressions that
/// appear more than once and reusing those results rather than re-computing the
/// same value.
///
/// Currently only common sub-expressions within a single `LogicalPlan` are
/// eliminated, with one exception: the window expressions of consecutive
/// `Window` nodes are processed together.
///
/// # Example
///
/// Given a projection that computes the same expensive expression
/// multiple times, such as parsing a string as a date with `to_date` twice:
///
/// ```text
/// ProjectionExec(exprs=[extract (day from to_date(c1)), extract (year from to_date(c1))])
/// ```
///
/// This optimization will rewrite the plan to compute the common expression once
/// using a new `ProjectionExec` and then rewrite the original expressions to
/// refer to that new column.
///
/// ```text
/// ProjectionExec(exprs=[extract (day from new_col), extract (year from new_col)]) <-- reuse here
///   ProjectionExec(exprs=[to_date(c1) as new_col]) <-- compute to_date once
/// ```
#[derive(Debug)]
pub struct CommonSubexprEliminate {}
69
70impl CommonSubexprEliminate {
71    pub fn new() -> Self {
72        Self {}
73    }
74
75    fn try_optimize_proj(
76        &self,
77        projection: Projection,
78        config: &dyn OptimizerConfig,
79    ) -> Result<Transformed<LogicalPlan>> {
80        let Projection {
81            expr,
82            input,
83            schema,
84            ..
85        } = projection;
86        let input = Arc::unwrap_or_clone(input);
87        self.try_unary_plan(expr, input, config)?
88            .map_data(|(new_expr, new_input)| {
89                Projection::try_new_with_schema(new_expr, Arc::new(new_input), schema)
90                    .map(LogicalPlan::Projection)
91            })
92    }
93
94    fn try_optimize_sort(
95        &self,
96        sort: Sort,
97        config: &dyn OptimizerConfig,
98    ) -> Result<Transformed<LogicalPlan>> {
99        let Sort { expr, input, fetch } = sort;
100        let input = Arc::unwrap_or_clone(input);
101        let (sort_expressions, sort_params): (Vec<_>, Vec<(_, _)>) = expr
102            .into_iter()
103            .map(|sort| (sort.expr, (sort.asc, sort.nulls_first)))
104            .unzip();
105        let new_sort = self
106            .try_unary_plan(sort_expressions, input, config)?
107            .update_data(|(new_expr, new_input)| {
108                LogicalPlan::Sort(Sort {
109                    expr: new_expr
110                        .into_iter()
111                        .zip(sort_params)
112                        .map(|(expr, (asc, nulls_first))| SortExpr {
113                            expr,
114                            asc,
115                            nulls_first,
116                        })
117                        .collect(),
118                    input: Arc::new(new_input),
119                    fetch,
120                })
121            });
122        Ok(new_sort)
123    }
124
125    fn try_optimize_filter(
126        &self,
127        filter: Filter,
128        config: &dyn OptimizerConfig,
129    ) -> Result<Transformed<LogicalPlan>> {
130        let Filter {
131            predicate, input, ..
132        } = filter;
133        let input = Arc::unwrap_or_clone(input);
134        let expr = vec![predicate];
135        self.try_unary_plan(expr, input, config)?
136            .map_data(|(mut new_expr, new_input)| {
137                assert_eq!(new_expr.len(), 1); // passed in vec![predicate]
138                let new_predicate = new_expr.pop().unwrap();
139                Filter::try_new(new_predicate, Arc::new(new_input))
140                    .map(LogicalPlan::Filter)
141            })
142    }
143
144    fn try_optimize_window(
145        &self,
146        window: Window,
147        config: &dyn OptimizerConfig,
148    ) -> Result<Transformed<LogicalPlan>> {
149        // Collects window expressions from consecutive `LogicalPlan::Window` nodes into
150        // a list.
151        let (window_expr_list, window_schemas, input) =
152            get_consecutive_window_exprs(window);
153
154        // Extract common sub-expressions from the list.
155
156        match CSE::new(ExprCSEController::new(
157            config.alias_generator().as_ref(),
158            ExprMask::Normal,
159        ))
160        .extract_common_nodes(window_expr_list)?
161        {
162            // If there are common sub-expressions, then the insert a projection node
163            // with the common expressions between the new window nodes and the
164            // original input.
165            FoundCommonNodes::Yes {
166                common_nodes: common_exprs,
167                new_nodes_list: new_exprs_list,
168                original_nodes_list: original_exprs_list,
169            } => build_common_expr_project_plan(input, common_exprs).map(|new_input| {
170                Transformed::yes((new_exprs_list, new_input, Some(original_exprs_list)))
171            }),
172            FoundCommonNodes::No {
173                original_nodes_list: original_exprs_list,
174            } => Ok(Transformed::no((original_exprs_list, input, None))),
175        }?
176        // Recurse into the new input.
177        // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.)
178        .transform_data(|(new_window_expr_list, new_input, window_expr_list)| {
179            self.rewrite(new_input, config)?.map_data(|new_input| {
180                Ok((new_window_expr_list, new_input, window_expr_list))
181            })
182        })?
183        // Rebuild the consecutive window nodes.
184        .map_data(|(new_window_expr_list, new_input, window_expr_list)| {
185            // If there were common expressions extracted, then we need to make sure
186            // we restore the original column names.
187            // TODO: Although `find_common_exprs()` inserts aliases around extracted
188            //  common expressions this doesn't mean that the original column names
189            //  (schema) are preserved due to the inserted aliases are not always at
190            //  the top of the expression.
191            //  Let's consider improving `find_common_exprs()` to always keep column
192            //  names and get rid of additional name preserving logic here.
193            if let Some(window_expr_list) = window_expr_list {
194                let name_preserver = NamePreserver::new_for_projection();
195                let saved_names = window_expr_list
196                    .iter()
197                    .map(|exprs| {
198                        exprs
199                            .iter()
200                            .map(|expr| name_preserver.save(expr))
201                            .collect::<Vec<_>>()
202                    })
203                    .collect::<Vec<_>>();
204                new_window_expr_list.into_iter().zip(saved_names).try_rfold(
205                    new_input,
206                    |plan, (new_window_expr, saved_names)| {
207                        let new_window_expr = new_window_expr
208                            .into_iter()
209                            .zip(saved_names)
210                            .map(|(new_window_expr, saved_name)| {
211                                saved_name.restore(new_window_expr)
212                            })
213                            .collect::<Vec<_>>();
214                        Window::try_new(new_window_expr, Arc::new(plan))
215                            .map(LogicalPlan::Window)
216                    },
217                )
218            } else {
219                new_window_expr_list
220                    .into_iter()
221                    .zip(window_schemas)
222                    .try_rfold(new_input, |plan, (new_window_expr, schema)| {
223                        Window::try_new_with_schema(
224                            new_window_expr,
225                            Arc::new(plan),
226                            schema,
227                        )
228                        .map(LogicalPlan::Window)
229                    })
230            }
231        })
232    }
233
234    fn try_optimize_aggregate(
235        &self,
236        aggregate: Aggregate,
237        config: &dyn OptimizerConfig,
238    ) -> Result<Transformed<LogicalPlan>> {
239        let Aggregate {
240            group_expr,
241            aggr_expr,
242            input,
243            schema,
244            ..
245        } = aggregate;
246        let input = Arc::unwrap_or_clone(input);
247        // Extract common sub-expressions from the aggregate and grouping expressions.
248        match CSE::new(ExprCSEController::new(
249            config.alias_generator().as_ref(),
250            ExprMask::Normal,
251        ))
252        .extract_common_nodes(vec![group_expr, aggr_expr])?
253        {
254            // If there are common sub-expressions, then insert a projection node
255            // with the common expressions between the new aggregate node and the
256            // original input.
257            FoundCommonNodes::Yes {
258                common_nodes: common_exprs,
259                new_nodes_list: mut new_exprs_list,
260                original_nodes_list: mut original_exprs_list,
261            } => {
262                let new_aggr_expr = new_exprs_list.pop().unwrap();
263                let new_group_expr = new_exprs_list.pop().unwrap();
264
265                build_common_expr_project_plan(input, common_exprs).map(|new_input| {
266                    let aggr_expr = original_exprs_list.pop().unwrap();
267                    Transformed::yes((
268                        new_aggr_expr,
269                        new_group_expr,
270                        new_input,
271                        Some(aggr_expr),
272                    ))
273                })
274            }
275
276            FoundCommonNodes::No {
277                original_nodes_list: mut original_exprs_list,
278            } => {
279                let new_aggr_expr = original_exprs_list.pop().unwrap();
280                let new_group_expr = original_exprs_list.pop().unwrap();
281
282                Ok(Transformed::no((
283                    new_aggr_expr,
284                    new_group_expr,
285                    input,
286                    None,
287                )))
288            }
289        }?
290        // Recurse into the new input.
291        // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.)
292        .transform_data(|(new_aggr_expr, new_group_expr, new_input, aggr_expr)| {
293            self.rewrite(new_input, config)?.map_data(|new_input| {
294                Ok((
295                    new_aggr_expr,
296                    new_group_expr,
297                    aggr_expr,
298                    Arc::new(new_input),
299                ))
300            })
301        })?
302        // Try extracting common aggregate expressions and rebuild the aggregate node.
303        .transform_data(
304            |(new_aggr_expr, new_group_expr, aggr_expr, new_input)| {
305                // Extract common aggregate sub-expressions from the aggregate expressions.
306                match CSE::new(ExprCSEController::new(
307                    config.alias_generator().as_ref(),
308                    ExprMask::NormalAndAggregates,
309                ))
310                .extract_common_nodes(vec![new_aggr_expr])?
311                {
312                    FoundCommonNodes::Yes {
313                        common_nodes: common_exprs,
314                        new_nodes_list: mut new_exprs_list,
315                        original_nodes_list: mut original_exprs_list,
316                    } => {
317                        let rewritten_aggr_expr = new_exprs_list.pop().unwrap();
318                        let new_aggr_expr = original_exprs_list.pop().unwrap();
319                        let saved_names = if let Some(aggr_expr) = aggr_expr {
320                            let name_preserver = NamePreserver::new_for_projection();
321                            aggr_expr
322                                .iter()
323                                .map(|expr| Some(name_preserver.save(expr)))
324                                .collect::<Vec<_>>()
325                        } else {
326                            new_aggr_expr
327                                .clone()
328                                .into_iter()
329                                .map(|_| None)
330                                .collect::<Vec<_>>()
331                        };
332
333                        let mut agg_exprs = common_exprs
334                            .into_iter()
335                            .map(|(expr, expr_alias)| expr.alias(expr_alias))
336                            .collect::<Vec<_>>();
337
338                        let mut proj_exprs = vec![];
339                        for expr in &new_group_expr {
340                            extract_expressions(expr, &mut proj_exprs)
341                        }
342                        for ((expr_rewritten, expr_orig), saved_name) in
343                            rewritten_aggr_expr
344                                .into_iter()
345                                .zip(new_aggr_expr)
346                                .zip(saved_names)
347                        {
348                            if expr_rewritten == expr_orig {
349                                let expr_rewritten = if let Some(saved_name) = saved_name
350                                {
351                                    saved_name.restore(expr_rewritten)
352                                } else {
353                                    expr_rewritten
354                                };
355                                if let Expr::Alias(Alias { expr, name, .. }) =
356                                    expr_rewritten
357                                {
358                                    agg_exprs.push(expr.alias(&name));
359                                    proj_exprs
360                                        .push(Expr::Column(Column::from_name(name)));
361                                } else {
362                                    let expr_alias =
363                                        config.alias_generator().next(CSE_PREFIX);
364                                    let (qualifier, field_name) =
365                                        expr_rewritten.qualified_name();
366                                    let out_name =
367                                        qualified_name(qualifier.as_ref(), &field_name);
368
369                                    agg_exprs.push(expr_rewritten.alias(&expr_alias));
370                                    proj_exprs.push(
371                                        Expr::Column(Column::from_name(expr_alias))
372                                            .alias(out_name),
373                                    );
374                                }
375                            } else {
376                                proj_exprs.push(expr_rewritten);
377                            }
378                        }
379
380                        let agg = LogicalPlan::Aggregate(Aggregate::try_new(
381                            new_input,
382                            new_group_expr,
383                            agg_exprs,
384                        )?);
385                        Projection::try_new(proj_exprs, Arc::new(agg))
386                            .map(|p| Transformed::yes(LogicalPlan::Projection(p)))
387                    }
388
389                    // If there aren't any common aggregate sub-expressions, then just
390                    // rebuild the aggregate node.
391                    FoundCommonNodes::No {
392                        original_nodes_list: mut original_exprs_list,
393                    } => {
394                        let rewritten_aggr_expr = original_exprs_list.pop().unwrap();
395
396                        // If there were common expressions extracted, then we need to
397                        // make sure we restore the original column names.
398                        // TODO: Although `find_common_exprs()` inserts aliases around
399                        //  extracted common expressions this doesn't mean that the
400                        //  original column names (schema) are preserved due to the
401                        //  inserted aliases are not always at the top of the
402                        //  expression.
403                        //  Let's consider improving `find_common_exprs()` to always
404                        //  keep column names and get rid of additional name
405                        //  preserving logic here.
406                        if let Some(aggr_expr) = aggr_expr {
407                            let name_preserver = NamePreserver::new_for_projection();
408                            let saved_names = aggr_expr
409                                .iter()
410                                .map(|expr| name_preserver.save(expr))
411                                .collect::<Vec<_>>();
412                            let new_aggr_expr = rewritten_aggr_expr
413                                .into_iter()
414                                .zip(saved_names)
415                                .map(|(new_expr, saved_name)| {
416                                    saved_name.restore(new_expr)
417                                })
418                                .collect::<Vec<Expr>>();
419
420                            // Since `group_expr` may have changed, schema may also.
421                            // Use `try_new()` method.
422                            Aggregate::try_new(new_input, new_group_expr, new_aggr_expr)
423                                .map(LogicalPlan::Aggregate)
424                                .map(Transformed::no)
425                        } else {
426                            Aggregate::try_new_with_schema(
427                                new_input,
428                                new_group_expr,
429                                rewritten_aggr_expr,
430                                schema,
431                            )
432                            .map(LogicalPlan::Aggregate)
433                            .map(Transformed::no)
434                        }
435                    }
436                }
437            },
438        )
439    }
440
441    /// Rewrites the expr list and input to remove common subexpressions
442    ///
443    /// # Parameters
444    ///
445    /// * `exprs`: List of expressions in the node
446    /// * `input`: input plan (that produces the columns referred to in `exprs`)
447    ///
448    /// # Return value
449    ///
450    ///  Returns `(rewritten_exprs, new_input)`. `new_input` is either:
451    ///
452    /// 1. The original `input` of no common subexpressions were extracted
453    /// 2. A newly added projection on top of the original input
454    ///    that computes the common subexpressions
455    fn try_unary_plan(
456        &self,
457        exprs: Vec<Expr>,
458        input: LogicalPlan,
459        config: &dyn OptimizerConfig,
460    ) -> Result<Transformed<(Vec<Expr>, LogicalPlan)>> {
461        // Extract common sub-expressions from the expressions.
462        match CSE::new(ExprCSEController::new(
463            config.alias_generator().as_ref(),
464            ExprMask::Normal,
465        ))
466        .extract_common_nodes(vec![exprs])?
467        {
468            FoundCommonNodes::Yes {
469                common_nodes: common_exprs,
470                new_nodes_list: mut new_exprs_list,
471                original_nodes_list: _,
472            } => {
473                let new_exprs = new_exprs_list.pop().unwrap();
474                build_common_expr_project_plan(input, common_exprs)
475                    .map(|new_input| Transformed::yes((new_exprs, new_input)))
476            }
477            FoundCommonNodes::No {
478                original_nodes_list: mut original_exprs_list,
479            } => {
480                let new_exprs = original_exprs_list.pop().unwrap();
481                Ok(Transformed::no((new_exprs, input)))
482            }
483        }?
484        // Recurse into the new input.
485        // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.)
486        .transform_data(|(new_exprs, new_input)| {
487            self.rewrite(new_input, config)?
488                .map_data(|new_input| Ok((new_exprs, new_input)))
489        })
490    }
491}
492
493/// Get all window expressions inside the consecutive window operators.
494///
495/// Returns the window expressions, and the input to the deepest child
496/// LogicalPlan.
497///
498/// For example, if the input window looks like
499///
500/// ```text
501///   LogicalPlan::Window(exprs=[a, b, c])
502///     LogicalPlan::Window(exprs=[d])
503///       InputPlan
504/// ```
505///
506/// Returns:
507/// *  `window_exprs`: `[[a, b, c], [d]]`
508/// * InputPlan
509///
510/// Consecutive window expressions may refer to same complex expression.
511///
512/// If same complex expression is referred more than once by subsequent
513/// `WindowAggr`s, we can cache complex expression by evaluating it with a
514/// projection before the first WindowAggr.
515///
516/// This enables us to cache complex expression "c3+c4" for following plan:
517///
518/// ```text
519/// WindowAggr: windowExpr=[[sum(c9) ORDER BY [c3 + c4] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
520/// --WindowAggr: windowExpr=[[sum(c9) ORDER BY [c3 + c4] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
521/// ```
522///
523/// where, it is referred once by each `WindowAggr` (total of 2) in the plan.
524fn get_consecutive_window_exprs(
525    window: Window,
526) -> (Vec<Vec<Expr>>, Vec<DFSchemaRef>, LogicalPlan) {
527    let mut window_expr_list = vec![];
528    let mut window_schemas = vec![];
529    let mut plan = LogicalPlan::Window(window);
530    while let LogicalPlan::Window(Window {
531        input,
532        window_expr,
533        schema,
534    }) = plan
535    {
536        window_expr_list.push(window_expr);
537        window_schemas.push(schema);
538
539        plan = Arc::unwrap_or_clone(input);
540    }
541    (window_expr_list, window_schemas, plan)
542}
543
544impl OptimizerRule for CommonSubexprEliminate {
545    fn supports_rewrite(&self) -> bool {
546        true
547    }
548
549    fn apply_order(&self) -> Option<ApplyOrder> {
550        // This rule handles recursion itself in a `ApplyOrder::TopDown` like manner.
551        // This is because in some cases adjacent nodes are collected (e.g. `Window`) and
552        // CSEd as a group, which can't be done in a simple `ApplyOrder::TopDown` rule.
553        None
554    }
555
556    #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
557    fn rewrite(
558        &self,
559        plan: LogicalPlan,
560        config: &dyn OptimizerConfig,
561    ) -> Result<Transformed<LogicalPlan>> {
562        let original_schema = Arc::clone(plan.schema());
563
564        let optimized_plan = match plan {
565            LogicalPlan::Projection(proj) => self.try_optimize_proj(proj, config)?,
566            LogicalPlan::Sort(sort) => self.try_optimize_sort(sort, config)?,
567            LogicalPlan::Filter(filter) => self.try_optimize_filter(filter, config)?,
568            LogicalPlan::Window(window) => self.try_optimize_window(window, config)?,
569            LogicalPlan::Aggregate(agg) => self.try_optimize_aggregate(agg, config)?,
570            LogicalPlan::Join(_)
571            | LogicalPlan::Repartition(_)
572            | LogicalPlan::Union(_)
573            | LogicalPlan::TableScan(_)
574            | LogicalPlan::Values(_)
575            | LogicalPlan::EmptyRelation(_)
576            | LogicalPlan::Subquery(_)
577            | LogicalPlan::SubqueryAlias(_)
578            | LogicalPlan::Limit(_)
579            | LogicalPlan::Ddl(_)
580            | LogicalPlan::Explain(_)
581            | LogicalPlan::Analyze(_)
582            | LogicalPlan::Statement(_)
583            | LogicalPlan::DescribeTable(_)
584            | LogicalPlan::Distinct(_)
585            | LogicalPlan::Extension(_)
586            | LogicalPlan::Dml(_)
587            | LogicalPlan::Copy(_)
588            | LogicalPlan::Unnest(_)
589            | LogicalPlan::RecursiveQuery(_) => {
590                // This rule handles recursion itself in a `ApplyOrder::TopDown` like
591                // manner.
592                plan.map_children(|c| self.rewrite(c, config))?
593            }
594        };
595
596        // If we rewrote the plan, ensure the schema stays the same
597        if optimized_plan.transformed && optimized_plan.data.schema() != &original_schema
598        {
599            optimized_plan.map_data(|optimized_plan| {
600                build_recover_project_plan(&original_schema, optimized_plan)
601            })
602        } else {
603            Ok(optimized_plan)
604        }
605    }
606
607    fn name(&self) -> &str {
608        "common_sub_expression_eliminate"
609    }
610}
611
/// Selects which kinds of [expressions](Expr) are skipped, i.e. never treated
/// as candidates for common sub-expression elimination.
#[derive(Clone, Copy, Debug)]
enum ExprMask {
    /// Skips the following expression kinds:
    ///
    /// - [`Literal`](Expr::Literal)
    /// - [`Columns`](Expr::Column)
    /// - [`ScalarVariable`](Expr::ScalarVariable)
    /// - [`Alias`](Expr::Alias)
    /// - [`Wildcard`](Expr::Wildcard)
    /// - [`AggregateFunction`](Expr::AggregateFunction)
    Normal,

    /// Same as [`Normal`](Self::Normal), except that
    /// [`AggregateFunction`](Expr::AggregateFunction)s are candidates too.
    NormalAndAggregates,
}
628
629struct ExprCSEController<'a> {
630    alias_generator: &'a AliasGenerator,
631    mask: ExprMask,
632
633    // how many aliases have we seen so far
634    alias_counter: usize,
635}
636
637impl<'a> ExprCSEController<'a> {
638    fn new(alias_generator: &'a AliasGenerator, mask: ExprMask) -> Self {
639        Self {
640            alias_generator,
641            mask,
642            alias_counter: 0,
643        }
644    }
645}
646
impl CSEController for ExprCSEController<'_> {
    type Node = Expr;

    /// Splits the children of short-circuiting expressions into two groups.
    /// The first vector holds the children that are surely evaluated; the
    /// second holds the children that are accounted as conditionally evaluated.
    ///
    /// Returns `None` for every other kind of expression.
    fn conditional_children(node: &Expr) -> Option<(Vec<&Expr>, Vec<&Expr>)> {
        match node {
            // In case of `ScalarFunction`s we don't know which children are surely
            // executed so start visiting all children conditionally and stop the
            // recursion with `TreeNodeRecursion::Jump`.
            Expr::ScalarFunction(ScalarFunction { func, args })
                if func.short_circuits() =>
            {
                Some((vec![], args.iter().collect()))
            }

            // In case of `And` and `Or` the first child is surely executed, but we
            // account subexpressions as conditional in the second.
            Expr::BinaryExpr(BinaryExpr {
                left,
                op: Operator::And | Operator::Or,
                right,
            }) => Some((vec![left.as_ref()], vec![right.as_ref()])),

            // In case of `Case` the optional base expression and the first `when`
            // expression are surely executed, but we account subexpressions as
            // conditional in the others (the first `then`, all later
            // `when`/`then` pairs, and the optional `else`).
            Expr::Case(Case {
                expr,
                when_then_expr,
                else_expr,
            }) => Some((
                expr.iter()
                    .map(|e| e.as_ref())
                    .chain(when_then_expr.iter().take(1).map(|(when, _)| when.as_ref()))
                    .collect(),
                when_then_expr
                    .iter()
                    .take(1)
                    .map(|(_, then)| then.as_ref())
                    .chain(
                        when_then_expr
                            .iter()
                            .skip(1)
                            .flat_map(|(when, then)| [when.as_ref(), then.as_ref()]),
                    )
                    .chain(else_expr.iter().map(|e| e.as_ref()))
                    .collect(),
            )),
            _ => None,
        }
    }

    /// A node is a valid CSE candidate only if it is not volatile.
    /// Presumably this is because repeated evaluations of a volatile node may
    /// differ, so they must not be merged.
    fn is_valid(node: &Expr) -> bool {
        !node.is_volatile_node()
    }

    /// Returns whether `node` should be skipped as a CSE candidate under this
    /// controller's [`ExprMask`].
    ///
    /// Literals, columns, scalar variables, aliases and wildcards are always
    /// ignored. Aggregate functions are ignored only with [`ExprMask::Normal`].
    fn is_ignored(&self, node: &Expr) -> bool {
        // TODO: remove the next line after `Expr::Wildcard` is removed
        #[expect(deprecated)]
        let is_normal_minus_aggregates = matches!(
            node,
            Expr::Literal(..)
                | Expr::Column(..)
                | Expr::ScalarVariable(..)
                | Expr::Alias(..)
                | Expr::Wildcard { .. }
        );

        let is_aggr = matches!(node, Expr::AggregateFunction(..));

        match self.mask {
            ExprMask::Normal => is_normal_minus_aggregates || is_aggr,
            ExprMask::NormalAndAggregates => is_normal_minus_aggregates,
        }
    }

    /// Returns a fresh alias, prefixed with `CSE_PREFIX`, for an extracted
    /// common expression.
    fn generate_alias(&self) -> String {
        self.alias_generator.next(CSE_PREFIX)
    }

    /// Replaces an extracted `node` by a column reference to `alias`.
    ///
    /// If no `Alias` node encloses the current position (`alias_counter == 0`),
    /// the column is additionally aliased with `node`'s original schema name, so
    /// the expression keeps its output name. In that case the counter is also
    /// incremented. NOTE(review): this presumably relies on `rewrite_f_up` later
    /// seeing the returned `Alias` node and decrementing the counter again;
    /// confirm against `CSE`.
    fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
        // alias the expressions without an `Alias` ancestor node
        if self.alias_counter > 0 {
            col(alias)
        } else {
            self.alias_counter += 1;
            col(alias).alias(node.schema_name().to_string())
        }
    }

    /// Entering an `Alias` node: one more enclosing alias.
    fn rewrite_f_down(&mut self, node: &Expr) {
        if matches!(node, Expr::Alias(_)) {
            self.alias_counter += 1;
        }
    }

    /// Leaving an `Alias` node: one fewer enclosing alias.
    fn rewrite_f_up(&mut self, node: &Expr) {
        if matches!(node, Expr::Alias(_)) {
            self.alias_counter -= 1
        }
    }
}
747
748impl Default for CommonSubexprEliminate {
749    fn default() -> Self {
750        Self::new()
751    }
752}
753
754/// Build the "intermediate" projection plan that evaluates the extracted common
755/// expressions.
756///
757/// # Arguments
758/// input: the input plan
759///
760/// common_exprs: which common subexpressions were used (and thus are added to
761/// intermediate projection)
762///
763/// expr_stats: the set of common subexpressions
764fn build_common_expr_project_plan(
765    input: LogicalPlan,
766    common_exprs: Vec<(Expr, String)>,
767) -> Result<LogicalPlan> {
768    let mut fields_set = BTreeSet::new();
769    let mut project_exprs = common_exprs
770        .into_iter()
771        .map(|(expr, expr_alias)| {
772            fields_set.insert(expr_alias.clone());
773            Ok(expr.alias(expr_alias))
774        })
775        .collect::<Result<Vec<_>>>()?;
776
777    for (qualifier, field) in input.schema().iter() {
778        if fields_set.insert(qualified_name(qualifier, field.name())) {
779            project_exprs.push(Expr::from((qualifier, field)));
780        }
781    }
782
783    Projection::try_new(project_exprs, Arc::new(input)).map(LogicalPlan::Projection)
784}
785
786/// Build the projection plan to eliminate unnecessary columns produced by
787/// the "intermediate" projection plan built in [build_common_expr_project_plan].
788///
789/// This is required to keep the schema the same for plans that pass the input
790/// on to the output, such as `Filter` or `Sort`.
791fn build_recover_project_plan(
792    schema: &DFSchema,
793    input: LogicalPlan,
794) -> Result<LogicalPlan> {
795    let col_exprs = schema.iter().map(Expr::from).collect();
796    Projection::try_new(col_exprs, Arc::new(input)).map(LogicalPlan::Projection)
797}
798
799fn extract_expressions(expr: &Expr, result: &mut Vec<Expr>) {
800    if let Expr::GroupingSet(groupings) = expr {
801        for e in groupings.distinct_expr() {
802            let (qualifier, field_name) = e.qualified_name();
803            let col = Column::new(qualifier, field_name);
804            result.push(Expr::Column(col))
805        }
806    } else {
807        let (qualifier, field_name) = expr.qualified_name();
808        let col = Column::new(qualifier, field_name);
809        result.push(Expr::Column(col));
810    }
811}
812
813#[cfg(test)]
814mod test {
815    use std::any::Any;
816    use std::iter;
817
818    use arrow::datatypes::{DataType, Field, Schema};
819    use datafusion_expr::logical_plan::{table_scan, JoinType};
820    use datafusion_expr::{
821        grouping_set, is_null, not, AccumulatorFactoryFunction, AggregateUDF,
822        ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
823        SimpleAggregateUDF, Volatility,
824    };
825    use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder};
826
827    use super::*;
828    use crate::assert_optimized_plan_eq_snapshot;
829    use crate::optimizer::OptimizerContext;
830    use crate::test::*;
831    use datafusion_expr::test::function_stub::{avg, sum};
832
833    macro_rules! assert_optimized_plan_equal {
834        (
835            $config:expr,
836            $plan:expr,
837            @ $expected:literal $(,)?
838        ) => {{
839            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(CommonSubexprEliminate::new())];
840            assert_optimized_plan_eq_snapshot!(
841                $config,
842                rules,
843                $plan,
844                @ $expected,
845            )
846        }};
847
848        (
849            $plan:expr,
850            @ $expected:literal $(,)?
851        ) => {{
852            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(CommonSubexprEliminate::new())];
853            let optimizer_ctx = OptimizerContext::new();
854            assert_optimized_plan_eq_snapshot!(
855                optimizer_ctx,
856                rules,
857                $plan,
858                @ $expected,
859            )
860        }};
861    }
862
863    #[test]
864    fn tpch_q1_simplified() -> Result<()> {
865        // SQL:
866        //  select
867        //      sum(a * (1 - b)),
868        //      sum(a * (1 - b) * (1 + c))
869        //  from T;
870        //
871        // The manual assembled logical plan don't contains the outermost `Projection`.
872
873        let table_scan = test_table_scan()?;
874
875        let plan = LogicalPlanBuilder::from(table_scan)
876            .aggregate(
877                iter::empty::<Expr>(),
878                vec![
879                    sum(col("a") * (lit(1) - col("b"))),
880                    sum((col("a") * (lit(1) - col("b"))) * (lit(1) + col("c"))),
881                ],
882            )?
883            .build()?;
884
885        assert_optimized_plan_equal!(
886            plan,
887            @ r"
888        Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test.a * Int32(1) - test.b), sum(__common_expr_1 AS test.a * Int32(1) - test.b * (Int32(1) + test.c))]]
889          Projection: test.a * (Int32(1) - test.b) AS __common_expr_1, test.a, test.b, test.c
890            TableScan: test
891        "
892        )
893    }
894
895    #[test]
896    fn nested_aliases() -> Result<()> {
897        let table_scan = test_table_scan()?;
898
899        let plan = LogicalPlanBuilder::from(table_scan)
900            .project(vec![
901                (col("a") + col("b") - col("c")).alias("alias1") * (col("a") + col("b")),
902                col("a") + col("b"),
903            ])?
904            .build()?;
905
906        assert_optimized_plan_equal!(
907            plan,
908            @ r"
909        Projection: __common_expr_1 - test.c AS alias1 * __common_expr_1 AS test.a + test.b, __common_expr_1 AS test.a + test.b
910          Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
911            TableScan: test
912        "
913        )
914    }
915
916    #[test]
917    fn aggregate() -> Result<()> {
918        let table_scan = test_table_scan()?;
919
920        let return_type = DataType::UInt32;
921        let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!());
922        let udf_agg = |inner: Expr| {
923            Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
924                Arc::new(AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
925                    "my_agg",
926                    Signature::exact(vec![DataType::UInt32], Volatility::Stable),
927                    return_type.clone(),
928                    Arc::clone(&accumulator),
929                    vec![Field::new("value", DataType::UInt32, true).into()],
930                ))),
931                vec![inner],
932                false,
933                None,
934                vec![],
935                None,
936            ))
937        };
938
939        // test: common aggregates
940        let plan = LogicalPlanBuilder::from(table_scan.clone())
941            .aggregate(
942                iter::empty::<Expr>(),
943                vec![
944                    // common: avg(col("a"))
945                    avg(col("a")).alias("col1"),
946                    avg(col("a")).alias("col2"),
947                    // no common
948                    avg(col("b")).alias("col3"),
949                    avg(col("c")),
950                    // common: udf_agg(col("a"))
951                    udf_agg(col("a")).alias("col4"),
952                    udf_agg(col("a")).alias("col5"),
953                    // no common
954                    udf_agg(col("b")).alias("col6"),
955                    udf_agg(col("c")),
956                ],
957            )?
958            .build()?;
959
960        assert_optimized_plan_equal!(
961            plan,
962            @ r"
963        Projection: __common_expr_1 AS col1, __common_expr_1 AS col2, col3, __common_expr_3 AS avg(test.c), __common_expr_2 AS col4, __common_expr_2 AS col5, col6, __common_expr_4 AS my_agg(test.c)
964          Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2, avg(test.b) AS col3, avg(test.c) AS __common_expr_3, my_agg(test.b) AS col6, my_agg(test.c) AS __common_expr_4]]
965            TableScan: test
966        "
967        )?;
968
969        // test: trafo after aggregate
970        let plan = LogicalPlanBuilder::from(table_scan.clone())
971            .aggregate(
972                iter::empty::<Expr>(),
973                vec![
974                    lit(1) + avg(col("a")),
975                    lit(1) - avg(col("a")),
976                    lit(1) + udf_agg(col("a")),
977                    lit(1) - udf_agg(col("a")),
978                ],
979            )?
980            .build()?;
981
982        assert_optimized_plan_equal!(
983            plan,
984            @ r"
985        Projection: Int32(1) + __common_expr_1 AS avg(test.a), Int32(1) - __common_expr_1 AS avg(test.a), Int32(1) + __common_expr_2 AS my_agg(test.a), Int32(1) - __common_expr_2 AS my_agg(test.a)
986          Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2]]
987            TableScan: test
988        "
989        )?;
990
991        // test: transformation before aggregate
992        let plan = LogicalPlanBuilder::from(table_scan.clone())
993            .aggregate(
994                iter::empty::<Expr>(),
995                vec![
996                    avg(lit(1u32) + col("a")).alias("col1"),
997                    udf_agg(lit(1u32) + col("a")).alias("col2"),
998                ],
999            )?
1000            .build()?;
1001
1002        assert_optimized_plan_equal!(
1003            plan,
1004            @ r"
1005        Aggregate: groupBy=[[]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]
1006          Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1007            TableScan: test
1008        "
1009        )?;
1010
1011        // test: common between agg and group
1012        let plan = LogicalPlanBuilder::from(table_scan.clone())
1013            .aggregate(
1014                vec![lit(1u32) + col("a")],
1015                vec![
1016                    avg(lit(1u32) + col("a")).alias("col1"),
1017                    udf_agg(lit(1u32) + col("a")).alias("col2"),
1018                ],
1019            )?
1020            .build()?;
1021
1022        assert_optimized_plan_equal!(
1023            plan,
1024            @ r"
1025        Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]
1026          Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1027            TableScan: test
1028        "
1029        )?;
1030
1031        // test: all mixed
1032        let plan = LogicalPlanBuilder::from(table_scan)
1033            .aggregate(
1034                vec![lit(1u32) + col("a")],
1035                vec![
1036                    (lit(1u32) + avg(lit(1u32) + col("a"))).alias("col1"),
1037                    (lit(1u32) - avg(lit(1u32) + col("a"))).alias("col2"),
1038                    avg(lit(1u32) + col("a")),
1039                    (lit(1u32) + udf_agg(lit(1u32) + col("a"))).alias("col3"),
1040                    (lit(1u32) - udf_agg(lit(1u32) + col("a"))).alias("col4"),
1041                    udf_agg(lit(1u32) + col("a")),
1042                ],
1043            )?
1044            .build()?;
1045
1046        assert_optimized_plan_equal!(
1047            plan,
1048            @ r"
1049        Projection: UInt32(1) + test.a, UInt32(1) + __common_expr_2 AS col1, UInt32(1) - __common_expr_2 AS col2, __common_expr_4 AS avg(UInt32(1) + test.a), UInt32(1) + __common_expr_3 AS col3, UInt32(1) - __common_expr_3 AS col4, __common_expr_5 AS my_agg(UInt32(1) + test.a)
1050          Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS __common_expr_2, my_agg(__common_expr_1) AS __common_expr_3, avg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_4, my_agg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_5]]
1051            Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1052              TableScan: test
1053        "
1054        )
1055    }
1056
1057    #[test]
1058    fn aggregate_with_relations_and_dots() -> Result<()> {
1059        let schema = Schema::new(vec![Field::new("col.a", DataType::UInt32, false)]);
1060        let table_scan = table_scan(Some("table.test"), &schema, None)?.build()?;
1061
1062        let col_a = Expr::Column(Column::new(Some("table.test"), "col.a"));
1063
1064        let plan = LogicalPlanBuilder::from(table_scan)
1065            .aggregate(
1066                vec![col_a.clone()],
1067                vec![
1068                    (lit(1u32) + avg(lit(1u32) + col_a.clone())),
1069                    avg(lit(1u32) + col_a),
1070                ],
1071            )?
1072            .build()?;
1073
1074        assert_optimized_plan_equal!(
1075            plan,
1076            @ r"
1077        Projection: table.test.col.a, UInt32(1) + __common_expr_2 AS avg(UInt32(1) + table.test.col.a), __common_expr_2 AS avg(UInt32(1) + table.test.col.a)
1078          Aggregate: groupBy=[[table.test.col.a]], aggr=[[avg(__common_expr_1 AS UInt32(1) + table.test.col.a) AS __common_expr_2]]
1079            Projection: UInt32(1) + table.test.col.a AS __common_expr_1, table.test.col.a
1080              TableScan: table.test
1081        "
1082        )
1083    }
1084
1085    #[test]
1086    fn subexpr_in_same_order() -> Result<()> {
1087        let table_scan = test_table_scan()?;
1088
1089        let plan = LogicalPlanBuilder::from(table_scan)
1090            .project(vec![
1091                (lit(1) + col("a")).alias("first"),
1092                (lit(1) + col("a")).alias("second"),
1093            ])?
1094            .build()?;
1095
1096        assert_optimized_plan_equal!(
1097            plan,
1098            @ r"
1099        Projection: __common_expr_1 AS first, __common_expr_1 AS second
1100          Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1101            TableScan: test
1102        "
1103        )
1104    }
1105
1106    #[test]
1107    fn subexpr_in_different_order() -> Result<()> {
1108        let table_scan = test_table_scan()?;
1109
1110        let plan = LogicalPlanBuilder::from(table_scan)
1111            .project(vec![lit(1) + col("a"), col("a") + lit(1)])?
1112            .build()?;
1113
1114        assert_optimized_plan_equal!(
1115            plan,
1116            @ r"
1117        Projection: __common_expr_1 AS Int32(1) + test.a, __common_expr_1 AS test.a + Int32(1)
1118          Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1119            TableScan: test
1120        "
1121        )
1122    }
1123
1124    #[test]
1125    fn cross_plans_subexpr() -> Result<()> {
1126        let table_scan = test_table_scan()?;
1127
1128        let plan = LogicalPlanBuilder::from(table_scan)
1129            .project(vec![lit(1) + col("a"), col("a")])?
1130            .project(vec![lit(1) + col("a")])?
1131            .build()?;
1132
1133        assert_optimized_plan_equal!(
1134            plan,
1135            @ r"
1136        Projection: Int32(1) + test.a
1137          Projection: Int32(1) + test.a, test.a
1138            TableScan: test
1139        "
1140        )
1141    }
1142
1143    #[test]
1144    fn redundant_project_fields() {
1145        let table_scan = test_table_scan().unwrap();
1146        let c_plus_a = col("c") + col("a");
1147        let b_plus_a = col("b") + col("a");
1148        let common_exprs_1 = vec![
1149            (c_plus_a, format!("{CSE_PREFIX}_1")),
1150            (b_plus_a, format!("{CSE_PREFIX}_2")),
1151        ];
1152        let c_plus_a_2 = col(format!("{CSE_PREFIX}_1"));
1153        let b_plus_a_2 = col(format!("{CSE_PREFIX}_2"));
1154        let common_exprs_2 = vec![
1155            (c_plus_a_2, format!("{CSE_PREFIX}_3")),
1156            (b_plus_a_2, format!("{CSE_PREFIX}_4")),
1157        ];
1158        let project = build_common_expr_project_plan(table_scan, common_exprs_1).unwrap();
1159        let project_2 = build_common_expr_project_plan(project, common_exprs_2).unwrap();
1160
1161        let mut field_set = BTreeSet::new();
1162        for name in project_2.schema().field_names() {
1163            assert!(field_set.insert(name));
1164        }
1165    }
1166
1167    #[test]
1168    fn redundant_project_fields_join_input() {
1169        let table_scan_1 = test_table_scan_with_name("test1").unwrap();
1170        let table_scan_2 = test_table_scan_with_name("test2").unwrap();
1171        let join = LogicalPlanBuilder::from(table_scan_1)
1172            .join(table_scan_2, JoinType::Inner, (vec!["a"], vec!["a"]), None)
1173            .unwrap()
1174            .build()
1175            .unwrap();
1176        let c_plus_a = col("test1.c") + col("test1.a");
1177        let b_plus_a = col("test1.b") + col("test1.a");
1178        let common_exprs_1 = vec![
1179            (c_plus_a, format!("{CSE_PREFIX}_1")),
1180            (b_plus_a, format!("{CSE_PREFIX}_2")),
1181        ];
1182        let c_plus_a_2 = col(format!("{CSE_PREFIX}_1"));
1183        let b_plus_a_2 = col(format!("{CSE_PREFIX}_2"));
1184        let common_exprs_2 = vec![
1185            (c_plus_a_2, format!("{CSE_PREFIX}_3")),
1186            (b_plus_a_2, format!("{CSE_PREFIX}_4")),
1187        ];
1188        let project = build_common_expr_project_plan(join, common_exprs_1).unwrap();
1189        let project_2 = build_common_expr_project_plan(project, common_exprs_2).unwrap();
1190
1191        let mut field_set = BTreeSet::new();
1192        for name in project_2.schema().field_names() {
1193            assert!(field_set.insert(name));
1194        }
1195    }
1196
1197    #[test]
1198    fn eliminated_subexpr_datatype() {
1199        use datafusion_expr::cast;
1200
1201        let schema = Schema::new(vec![
1202            Field::new("a", DataType::UInt64, false),
1203            Field::new("b", DataType::UInt64, false),
1204            Field::new("c", DataType::UInt64, false),
1205        ]);
1206
1207        let plan = table_scan(Some("table"), &schema, None)
1208            .unwrap()
1209            .filter(
1210                cast(col("a"), DataType::Int64)
1211                    .lt(lit(1_i64))
1212                    .and(cast(col("a"), DataType::Int64).not_eq(lit(1_i64))),
1213            )
1214            .unwrap()
1215            .build()
1216            .unwrap();
1217        let rule = CommonSubexprEliminate::new();
1218        let optimized_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap();
1219        assert!(optimized_plan.transformed);
1220        let optimized_plan = optimized_plan.data;
1221
1222        let schema = optimized_plan.schema();
1223        let fields_with_datatypes: Vec<_> = schema
1224            .fields()
1225            .iter()
1226            .map(|field| (field.name(), field.data_type()))
1227            .collect();
1228        let formatted_fields_with_datatype = format!("{fields_with_datatypes:#?}");
1229        let expected = r#"[
1230    (
1231        "a",
1232        UInt64,
1233    ),
1234    (
1235        "b",
1236        UInt64,
1237    ),
1238    (
1239        "c",
1240        UInt64,
1241    ),
1242]"#;
1243        assert_eq!(expected, formatted_fields_with_datatype);
1244    }
1245
1246    #[test]
1247    fn filter_schema_changed() -> Result<()> {
1248        let table_scan = test_table_scan()?;
1249
1250        let plan = LogicalPlanBuilder::from(table_scan)
1251            .filter((lit(1) + col("a") - lit(10)).gt(lit(1) + col("a")))?
1252            .build()?;
1253
1254        assert_optimized_plan_equal!(
1255            plan,
1256            @ r"
1257        Projection: test.a, test.b, test.c
1258          Filter: __common_expr_1 - Int32(10) > __common_expr_1
1259            Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1260              TableScan: test
1261        "
1262        )
1263    }
1264
1265    #[test]
1266    fn test_extract_expressions_from_grouping_set() -> Result<()> {
1267        let mut result = Vec::with_capacity(3);
1268        let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("c")]]);
1269        extract_expressions(&grouping, &mut result);
1270
1271        assert!(result.len() == 3);
1272        Ok(())
1273    }
1274
1275    #[test]
1276    fn test_extract_expressions_from_grouping_set_with_identical_expr() -> Result<()> {
1277        let mut result = Vec::with_capacity(2);
1278        let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("a")]]);
1279        extract_expressions(&grouping, &mut result);
1280        assert!(result.len() == 2);
1281        Ok(())
1282    }
1283
1284    #[test]
1285    fn test_alias_collision() -> Result<()> {
1286        let table_scan = test_table_scan()?;
1287
1288        let config = OptimizerContext::new();
1289        let common_expr_1 = config.alias_generator().next(CSE_PREFIX);
1290        let plan = LogicalPlanBuilder::from(table_scan.clone())
1291            .project(vec![
1292                (col("a") + col("b")).alias(common_expr_1.clone()),
1293                col("c"),
1294            ])?
1295            .project(vec![
1296                col(common_expr_1.clone()).alias("c1"),
1297                col(common_expr_1).alias("c2"),
1298                (col("c") + lit(2)).alias("c3"),
1299                (col("c") + lit(2)).alias("c4"),
1300            ])?
1301            .build()?;
1302
1303        assert_optimized_plan_equal!(
1304            config,
1305            plan,
1306            @ r"
1307        Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 AS c3, __common_expr_2 AS c4
1308          Projection: test.c + Int32(2) AS __common_expr_2, __common_expr_1, test.c
1309            Projection: test.a + test.b AS __common_expr_1, test.c
1310              TableScan: test
1311        "
1312        )?;
1313
1314        let config = OptimizerContext::new();
1315        let _common_expr_1 = config.alias_generator().next(CSE_PREFIX);
1316        let common_expr_2 = config.alias_generator().next(CSE_PREFIX);
1317        let plan = LogicalPlanBuilder::from(table_scan)
1318            .project(vec![
1319                (col("a") + col("b")).alias(common_expr_2.clone()),
1320                col("c"),
1321            ])?
1322            .project(vec![
1323                col(common_expr_2.clone()).alias("c1"),
1324                col(common_expr_2).alias("c2"),
1325                (col("c") + lit(2)).alias("c3"),
1326                (col("c") + lit(2)).alias("c4"),
1327            ])?
1328            .build()?;
1329
1330        assert_optimized_plan_equal!(
1331            config,
1332            plan,
1333            @ r"
1334        Projection: __common_expr_2 AS c1, __common_expr_2 AS c2, __common_expr_3 AS c3, __common_expr_3 AS c4
1335          Projection: test.c + Int32(2) AS __common_expr_3, __common_expr_2, test.c
1336            Projection: test.a + test.b AS __common_expr_2, test.c
1337              TableScan: test
1338        "
1339        )?;
1340
1341        Ok(())
1342    }
1343
1344    #[test]
1345    fn test_extract_expressions_from_col() -> Result<()> {
1346        let mut result = Vec::with_capacity(1);
1347        extract_expressions(&col("a"), &mut result);
1348        assert!(result.len() == 1);
1349        Ok(())
1350    }
1351
1352    #[test]
1353    fn test_short_circuits() -> Result<()> {
1354        let table_scan = test_table_scan()?;
1355
1356        let extracted_short_circuit = col("a").eq(lit(0)).or(col("b").eq(lit(0)));
1357        let extracted_short_circuit_leg_1 = (col("a") + col("b")).eq(lit(0));
1358        let not_extracted_short_circuit_leg_2 = (col("a") - col("b")).eq(lit(0));
1359        let extracted_short_circuit_leg_3 = (col("a") * col("b")).eq(lit(0));
1360        let plan = LogicalPlanBuilder::from(table_scan)
1361            .project(vec![
1362                extracted_short_circuit.clone().alias("c1"),
1363                extracted_short_circuit.alias("c2"),
1364                extracted_short_circuit_leg_1
1365                    .clone()
1366                    .or(not_extracted_short_circuit_leg_2.clone())
1367                    .alias("c3"),
1368                extracted_short_circuit_leg_1
1369                    .and(not_extracted_short_circuit_leg_2)
1370                    .alias("c4"),
1371                extracted_short_circuit_leg_3
1372                    .clone()
1373                    .or(extracted_short_circuit_leg_3)
1374                    .alias("c5"),
1375            ])?
1376            .build()?;
1377
1378        assert_optimized_plan_equal!(
1379            plan,
1380            @ r"
1381        Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 OR test.a - test.b = Int32(0) AS c3, __common_expr_2 AND test.a - test.b = Int32(0) AS c4, __common_expr_3 OR __common_expr_3 AS c5
1382          Projection: test.a = Int32(0) OR test.b = Int32(0) AS __common_expr_1, test.a + test.b = Int32(0) AS __common_expr_2, test.a * test.b = Int32(0) AS __common_expr_3, test.a, test.b, test.c
1383            TableScan: test
1384        "
1385        )
1386    }
1387
1388    #[test]
1389    fn test_volatile() -> Result<()> {
1390        let table_scan = test_table_scan()?;
1391
1392        let extracted_child = col("a") + col("b");
1393        let rand = rand_func().call(vec![]);
1394        let not_extracted_volatile = extracted_child + rand;
1395        let plan = LogicalPlanBuilder::from(table_scan)
1396            .project(vec![
1397                not_extracted_volatile.clone().alias("c1"),
1398                not_extracted_volatile.alias("c2"),
1399            ])?
1400            .build()?;
1401
1402        assert_optimized_plan_equal!(
1403            plan,
1404            @ r"
1405        Projection: __common_expr_1 + random() AS c1, __common_expr_1 + random() AS c2
1406          Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
1407            TableScan: test
1408        "
1409        )
1410    }
1411
1412    #[test]
1413    fn test_volatile_short_circuits() -> Result<()> {
1414        let table_scan = test_table_scan()?;
1415
1416        let rand = rand_func().call(vec![]);
1417        let extracted_short_circuit_leg_1 = col("a").eq(lit(0));
1418        let not_extracted_volatile_short_circuit_1 =
1419            extracted_short_circuit_leg_1.or(rand.clone().eq(lit(0)));
1420        let not_extracted_short_circuit_leg_2 = col("b").eq(lit(0));
1421        let not_extracted_volatile_short_circuit_2 =
1422            rand.eq(lit(0)).or(not_extracted_short_circuit_leg_2);
1423        let plan = LogicalPlanBuilder::from(table_scan)
1424            .project(vec![
1425                not_extracted_volatile_short_circuit_1.clone().alias("c1"),
1426                not_extracted_volatile_short_circuit_1.alias("c2"),
1427                not_extracted_volatile_short_circuit_2.clone().alias("c3"),
1428                not_extracted_volatile_short_circuit_2.alias("c4"),
1429            ])?
1430            .build()?;
1431
1432        assert_optimized_plan_equal!(
1433            plan,
1434            @ r"
1435        Projection: __common_expr_1 OR random() = Int32(0) AS c1, __common_expr_1 OR random() = Int32(0) AS c2, random() = Int32(0) OR test.b = Int32(0) AS c3, random() = Int32(0) OR test.b = Int32(0) AS c4
1436          Projection: test.a = Int32(0) AS __common_expr_1, test.a, test.b, test.c
1437            TableScan: test
1438        "
1439        )
1440    }
1441
1442    #[test]
1443    fn test_non_top_level_common_expression() -> Result<()> {
1444        let table_scan = test_table_scan()?;
1445
1446        let common_expr = col("a") + col("b");
1447        let plan = LogicalPlanBuilder::from(table_scan)
1448            .project(vec![
1449                common_expr.clone().alias("c1"),
1450                common_expr.alias("c2"),
1451            ])?
1452            .project(vec![col("c1"), col("c2")])?
1453            .build()?;
1454
1455        assert_optimized_plan_equal!(
1456            plan,
1457            @ r"
1458        Projection: c1, c2
1459          Projection: __common_expr_1 AS c1, __common_expr_1 AS c2
1460            Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
1461              TableScan: test
1462        "
1463        )
1464    }
1465
1466    #[test]
1467    fn test_nested_common_expression() -> Result<()> {
1468        let table_scan = test_table_scan()?;
1469
1470        let nested_common_expr = col("a") + col("b");
1471        let common_expr = nested_common_expr.clone() * nested_common_expr;
1472        let plan = LogicalPlanBuilder::from(table_scan)
1473            .project(vec![
1474                common_expr.clone().alias("c1"),
1475                common_expr.alias("c2"),
1476            ])?
1477            .build()?;
1478
1479        assert_optimized_plan_equal!(
1480            plan,
1481            @ r"
1482        Projection: __common_expr_1 AS c1, __common_expr_1 AS c2
1483          Projection: __common_expr_2 * __common_expr_2 AS __common_expr_1, test.a, test.b, test.c
1484            Projection: test.a + test.b AS __common_expr_2, test.a, test.b, test.c
1485              TableScan: test
1486        "
1487        )
1488    }
1489
1490    #[test]
1491    fn test_normalize_add_expression() -> Result<()> {
1492        // a + b <=> b + a
1493        let table_scan = test_table_scan()?;
1494        let expr = ((col("a") + col("b")) * (col("b") + col("a"))).eq(lit(30));
1495        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1496
1497        assert_optimized_plan_equal!(
1498            plan,
1499            @ r"
1500        Projection: test.a, test.b, test.c
1501          Filter: __common_expr_1 * __common_expr_1 = Int32(30)
1502            Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
1503              TableScan: test
1504        "
1505        )
1506    }
1507
1508    #[test]
1509    fn test_normalize_multi_expression() -> Result<()> {
1510        // a * b <=> b * a
1511        let table_scan = test_table_scan()?;
1512        let expr = ((col("a") * col("b")) + (col("b") * col("a"))).eq(lit(30));
1513        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1514
1515        assert_optimized_plan_equal!(
1516            plan,
1517            @ r"
1518        Projection: test.a, test.b, test.c
1519          Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1520            Projection: test.a * test.b AS __common_expr_1, test.a, test.b, test.c
1521              TableScan: test
1522        "
1523        )
1524    }
1525
1526    #[test]
1527    fn test_normalize_bitset_and_expression() -> Result<()> {
1528        // a & b <=> b & a
1529        let table_scan = test_table_scan()?;
1530        let expr = ((col("a") & col("b")) + (col("b") & col("a"))).eq(lit(30));
1531        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1532
1533        assert_optimized_plan_equal!(
1534            plan,
1535            @ r"
1536        Projection: test.a, test.b, test.c
1537          Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1538            Projection: test.a & test.b AS __common_expr_1, test.a, test.b, test.c
1539              TableScan: test
1540        "
1541        )
1542    }
1543
1544    #[test]
1545    fn test_normalize_bitset_or_expression() -> Result<()> {
1546        // a | b <=> b | a
1547        let table_scan = test_table_scan()?;
1548        let expr = ((col("a") | col("b")) + (col("b") | col("a"))).eq(lit(30));
1549        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1550
1551        assert_optimized_plan_equal!(
1552            plan,
1553            @ r"
1554        Projection: test.a, test.b, test.c
1555          Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1556            Projection: test.a | test.b AS __common_expr_1, test.a, test.b, test.c
1557              TableScan: test
1558        "
1559        )
1560    }
1561
1562    #[test]
1563    fn test_normalize_bitset_xor_expression() -> Result<()> {
1564        // a # b <=> b # a
1565        let table_scan = test_table_scan()?;
1566        let expr = ((col("a") ^ col("b")) + (col("b") ^ col("a"))).eq(lit(30));
1567        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1568
1569        assert_optimized_plan_equal!(
1570            plan,
1571            @ r"
1572        Projection: test.a, test.b, test.c
1573          Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1574            Projection: test.a BIT_XOR test.b AS __common_expr_1, test.a, test.b, test.c
1575              TableScan: test
1576        "
1577        )
1578    }
1579
1580    #[test]
1581    fn test_normalize_eq_expression() -> Result<()> {
1582        // a = b <=> b = a
1583        let table_scan = test_table_scan()?;
1584        let expr = (col("a").eq(col("b"))).and(col("b").eq(col("a")));
1585        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1586
1587        assert_optimized_plan_equal!(
1588            plan,
1589            @ r"
1590        Projection: test.a, test.b, test.c
1591          Filter: __common_expr_1 AND __common_expr_1
1592            Projection: test.a = test.b AS __common_expr_1, test.a, test.b, test.c
1593              TableScan: test
1594        "
1595        )
1596    }
1597
1598    #[test]
1599    fn test_normalize_ne_expression() -> Result<()> {
1600        // a != b <=> b != a
1601        let table_scan = test_table_scan()?;
1602        let expr = (col("a").not_eq(col("b"))).and(col("b").not_eq(col("a")));
1603        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1604
1605        assert_optimized_plan_equal!(
1606            plan,
1607            @ r"
1608        Projection: test.a, test.b, test.c
1609          Filter: __common_expr_1 AND __common_expr_1
1610            Projection: test.a != test.b AS __common_expr_1, test.a, test.b, test.c
1611              TableScan: test
1612        "
1613        )
1614    }
1615
1616    #[test]
1617    fn test_normalize_complex_expression() -> Result<()> {
1618        // case1: a + b * c <=> b * c + a
1619        let table_scan = test_table_scan()?;
1620        let expr = ((col("a") + col("b") * col("c")) - (col("b") * col("c") + col("a")))
1621            .eq(lit(30));
1622        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1623
1624        assert_optimized_plan_equal!(
1625            plan,
1626            @ r"
1627        Projection: test.a, test.b, test.c
1628          Filter: __common_expr_1 - __common_expr_1 = Int32(30)
1629            Projection: test.a + test.b * test.c AS __common_expr_1, test.a, test.b, test.c
1630              TableScan: test
1631        "
1632        )?;
1633
1634        // ((c1 + c2 / c3) * c3 <=> c3 * (c2 / c3 + c1))
1635        let table_scan = test_table_scan()?;
1636        let expr = (((col("a") + col("b") / col("c")) * col("c"))
1637            / (col("c") * (col("b") / col("c") + col("a")))
1638            + col("a"))
1639        .eq(lit(30));
1640        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1641
1642        assert_optimized_plan_equal!(
1643            plan,
1644            @ r"
1645        Projection: test.a, test.b, test.c
1646          Filter: __common_expr_1 / __common_expr_1 + test.a = Int32(30)
1647            Projection: (test.a + test.b / test.c) * test.c AS __common_expr_1, test.a, test.b, test.c
1648              TableScan: test
1649        "
1650        )?;
1651
1652        // c2 / (c1 + c3) <=> c2 / (c3 + c1)
1653        let table_scan = test_table_scan()?;
1654        let expr = ((col("b") / (col("a") + col("c")))
1655            * (col("b") / (col("c") + col("a"))))
1656        .eq(lit(30));
1657        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1658        assert_optimized_plan_equal!(
1659            plan,
1660            @ r"
1661        Projection: test.a, test.b, test.c
1662          Filter: __common_expr_1 * __common_expr_1 = Int32(30)
1663            Projection: test.b / (test.a + test.c) AS __common_expr_1, test.a, test.b, test.c
1664              TableScan: test
1665        "
1666        )?;
1667
1668        Ok(())
1669    }
1670
1671    #[derive(Debug, PartialEq, Eq, Hash)]
1672    pub struct TestUdf {
1673        signature: Signature,
1674    }
1675
1676    impl TestUdf {
1677        pub fn new() -> Self {
1678            Self {
1679                signature: Signature::numeric(1, Volatility::Immutable),
1680            }
1681        }
1682    }
1683
1684    impl ScalarUDFImpl for TestUdf {
1685        fn as_any(&self) -> &dyn Any {
1686            self
1687        }
1688        fn name(&self) -> &str {
1689            "my_udf"
1690        }
1691
1692        fn signature(&self) -> &Signature {
1693            &self.signature
1694        }
1695
1696        fn return_type(&self, _: &[DataType]) -> Result<DataType> {
1697            Ok(DataType::Int32)
1698        }
1699
1700        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1701            panic!("not implemented")
1702        }
1703    }
1704
1705    #[test]
1706    fn test_normalize_inner_binary_expression() -> Result<()> {
1707        // Not(a == b) <=> Not(b == a)
1708        let table_scan = test_table_scan()?;
1709        let expr1 = not(col("a").eq(col("b")));
1710        let expr2 = not(col("b").eq(col("a")));
1711        let plan = LogicalPlanBuilder::from(table_scan)
1712            .project(vec![expr1, expr2])?
1713            .build()?;
1714        assert_optimized_plan_equal!(
1715            plan,
1716            @ r"
1717        Projection: __common_expr_1 AS NOT test.a = test.b, __common_expr_1 AS NOT test.b = test.a
1718          Projection: NOT test.a = test.b AS __common_expr_1, test.a, test.b, test.c
1719            TableScan: test
1720        "
1721        )?;
1722
1723        // is_null(a == b) <=> is_null(b == a)
1724        let table_scan = test_table_scan()?;
1725        let expr1 = is_null(col("a").eq(col("b")));
1726        let expr2 = is_null(col("b").eq(col("a")));
1727        let plan = LogicalPlanBuilder::from(table_scan)
1728            .project(vec![expr1, expr2])?
1729            .build()?;
1730        assert_optimized_plan_equal!(
1731            plan,
1732            @ r"
1733        Projection: __common_expr_1 AS test.a = test.b IS NULL, __common_expr_1 AS test.b = test.a IS NULL
1734          Projection: test.a = test.b IS NULL AS __common_expr_1, test.a, test.b, test.c
1735            TableScan: test
1736        "
1737        )?;
1738
1739        // a + b between 0 and 10 <=> b + a between 0 and 10
1740        let table_scan = test_table_scan()?;
1741        let expr1 = (col("a") + col("b")).between(lit(0), lit(10));
1742        let expr2 = (col("b") + col("a")).between(lit(0), lit(10));
1743        let plan = LogicalPlanBuilder::from(table_scan)
1744            .project(vec![expr1, expr2])?
1745            .build()?;
1746        assert_optimized_plan_equal!(
1747            plan,
1748            @ r"
1749        Projection: __common_expr_1 AS test.a + test.b BETWEEN Int32(0) AND Int32(10), __common_expr_1 AS test.b + test.a BETWEEN Int32(0) AND Int32(10)
1750          Projection: test.a + test.b BETWEEN Int32(0) AND Int32(10) AS __common_expr_1, test.a, test.b, test.c
1751            TableScan: test
1752        "
1753        )?;
1754
1755        // c between a + b and 10 <=> c between b + a and 10
1756        let table_scan = test_table_scan()?;
1757        let expr1 = col("c").between(col("a") + col("b"), lit(10));
1758        let expr2 = col("c").between(col("b") + col("a"), lit(10));
1759        let plan = LogicalPlanBuilder::from(table_scan)
1760            .project(vec![expr1, expr2])?
1761            .build()?;
1762        assert_optimized_plan_equal!(
1763            plan,
1764            @ r"
1765        Projection: __common_expr_1 AS test.c BETWEEN test.a + test.b AND Int32(10), __common_expr_1 AS test.c BETWEEN test.b + test.a AND Int32(10)
1766          Projection: test.c BETWEEN test.a + test.b AND Int32(10) AS __common_expr_1, test.a, test.b, test.c
1767            TableScan: test
1768        "
1769        )?;
1770
1771        // function call with argument <=> function call with argument
1772        let udf = ScalarUDF::from(TestUdf::new());
1773        let table_scan = test_table_scan()?;
1774        let expr1 = udf.call(vec![col("a") + col("b")]);
1775        let expr2 = udf.call(vec![col("b") + col("a")]);
1776        let plan = LogicalPlanBuilder::from(table_scan)
1777            .project(vec![expr1, expr2])?
1778            .build()?;
1779        assert_optimized_plan_equal!(
1780            plan,
1781            @ r"
1782        Projection: __common_expr_1 AS my_udf(test.a + test.b), __common_expr_1 AS my_udf(test.b + test.a)
1783          Projection: my_udf(test.a + test.b) AS __common_expr_1, test.a, test.b, test.c
1784            TableScan: test
1785        "
1786        )
1787    }
1788
1789    /// returns a "random" function that is marked volatile (aka each invocation
1790    /// returns a different value)
1791    ///
1792    /// Does not use datafusion_functions::rand to avoid introducing a
1793    /// dependency on that crate.
1794    fn rand_func() -> ScalarUDF {
1795        ScalarUDF::new_from_impl(RandomStub::new())
1796    }
1797
1798    #[derive(Debug, PartialEq, Eq, Hash)]
1799    struct RandomStub {
1800        signature: Signature,
1801    }
1802
1803    impl RandomStub {
1804        fn new() -> Self {
1805            Self {
1806                signature: Signature::exact(vec![], Volatility::Volatile),
1807            }
1808        }
1809    }
1810    impl ScalarUDFImpl for RandomStub {
1811        fn as_any(&self) -> &dyn Any {
1812            self
1813        }
1814
1815        fn name(&self) -> &str {
1816            "random"
1817        }
1818
1819        fn signature(&self) -> &Signature {
1820            &self.signature
1821        }
1822
1823        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
1824            Ok(DataType::Float64)
1825        }
1826
1827        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1828            panic!("dummy - not implemented")
1829        }
1830    }
1831}