datafusion_optimizer/
common_subexpr_eliminate.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`CommonSubexprEliminate`] to avoid redundant computation of common sub-expressions
19
20use std::collections::BTreeSet;
21use std::fmt::Debug;
22use std::sync::Arc;
23
24use crate::{OptimizerConfig, OptimizerRule};
25
26use crate::optimizer::ApplyOrder;
27use crate::utils::NamePreserver;
28use datafusion_common::alias::AliasGenerator;
29
30use datafusion_common::cse::{CSE, CSEController, FoundCommonNodes};
31use datafusion_common::tree_node::{Transformed, TreeNode};
32use datafusion_common::{Column, DFSchema, DFSchemaRef, Result, qualified_name};
33use datafusion_expr::expr::{Alias, ScalarFunction};
34use datafusion_expr::logical_plan::{
35    Aggregate, Filter, LogicalPlan, Projection, Sort, Window,
36};
37use datafusion_expr::{BinaryExpr, Case, Expr, Operator, SortExpr, col};
38
/// Prefix passed to the alias generator when naming extracted common
/// sub-expressions.
const CSE_PREFIX: &str = "__common_expr";

/// Performs Common Sub-expression Elimination optimization.
///
/// This optimization improves query performance by computing expressions that
/// appear more than once and reusing those results rather than re-computing the
/// same value.
///
/// Currently only common sub-expressions within a single `LogicalPlan` are
/// eliminated (or, for `Window`, within a chain of consecutive `Window` nodes).
///
/// # Example
///
/// Given a projection that computes the same expensive expression
/// multiple times, such as parsing a string as a date with `to_date` twice:
///
/// ```text
/// Projection(exprs=[extract (day from to_date(c1)), extract (year from to_date(c1))])
/// ```
///
/// This optimization will rewrite the plan to compute the common expression once
/// using a new `Projection` and then rewrite the original expressions to
/// refer to that new column.
///
/// ```text
/// Projection(exprs=[extract (day from new_col), extract (year from new_col)]) <-- reuse here
///   Projection(exprs=[to_date(c1) as new_col]) <-- compute to_date once
/// ```
#[derive(Debug)]
pub struct CommonSubexprEliminate {}
69
70impl CommonSubexprEliminate {
71    pub fn new() -> Self {
72        Self {}
73    }
74
75    fn try_optimize_proj(
76        &self,
77        projection: Projection,
78        config: &dyn OptimizerConfig,
79    ) -> Result<Transformed<LogicalPlan>> {
80        let Projection {
81            expr,
82            input,
83            schema,
84            ..
85        } = projection;
86        let input = Arc::unwrap_or_clone(input);
87        self.try_unary_plan(expr, input, config)?
88            .map_data(|(new_expr, new_input)| {
89                Projection::try_new_with_schema(new_expr, Arc::new(new_input), schema)
90                    .map(LogicalPlan::Projection)
91            })
92    }
93
94    fn try_optimize_sort(
95        &self,
96        sort: Sort,
97        config: &dyn OptimizerConfig,
98    ) -> Result<Transformed<LogicalPlan>> {
99        let Sort { expr, input, fetch } = sort;
100        let input = Arc::unwrap_or_clone(input);
101        let (sort_expressions, sort_params): (Vec<_>, Vec<(_, _)>) = expr
102            .into_iter()
103            .map(|sort| (sort.expr, (sort.asc, sort.nulls_first)))
104            .unzip();
105        let new_sort = self
106            .try_unary_plan(sort_expressions, input, config)?
107            .update_data(|(new_expr, new_input)| {
108                LogicalPlan::Sort(Sort {
109                    expr: new_expr
110                        .into_iter()
111                        .zip(sort_params)
112                        .map(|(expr, (asc, nulls_first))| SortExpr {
113                            expr,
114                            asc,
115                            nulls_first,
116                        })
117                        .collect(),
118                    input: Arc::new(new_input),
119                    fetch,
120                })
121            });
122        Ok(new_sort)
123    }
124
125    fn try_optimize_filter(
126        &self,
127        filter: Filter,
128        config: &dyn OptimizerConfig,
129    ) -> Result<Transformed<LogicalPlan>> {
130        let Filter {
131            predicate, input, ..
132        } = filter;
133        let input = Arc::unwrap_or_clone(input);
134        let expr = vec![predicate];
135        self.try_unary_plan(expr, input, config)?
136            .map_data(|(mut new_expr, new_input)| {
137                assert_eq!(new_expr.len(), 1); // passed in vec![predicate]
138                let new_predicate = new_expr.pop().unwrap();
139                Filter::try_new(new_predicate, Arc::new(new_input))
140                    .map(LogicalPlan::Filter)
141            })
142    }
143
144    fn try_optimize_window(
145        &self,
146        window: Window,
147        config: &dyn OptimizerConfig,
148    ) -> Result<Transformed<LogicalPlan>> {
149        // Collects window expressions from consecutive `LogicalPlan::Window` nodes into
150        // a list.
151        let (window_expr_list, window_schemas, input) =
152            get_consecutive_window_exprs(window);
153
154        // Extract common sub-expressions from the list.
155
156        match CSE::new(ExprCSEController::new(
157            config.alias_generator().as_ref(),
158            ExprMask::Normal,
159        ))
160        .extract_common_nodes(window_expr_list)?
161        {
162            // If there are common sub-expressions, then the insert a projection node
163            // with the common expressions between the new window nodes and the
164            // original input.
165            FoundCommonNodes::Yes {
166                common_nodes: common_exprs,
167                new_nodes_list: new_exprs_list,
168                original_nodes_list: original_exprs_list,
169            } => build_common_expr_project_plan(input, common_exprs).map(|new_input| {
170                Transformed::yes((new_exprs_list, new_input, Some(original_exprs_list)))
171            }),
172            FoundCommonNodes::No {
173                original_nodes_list: original_exprs_list,
174            } => Ok(Transformed::no((original_exprs_list, input, None))),
175        }?
176        // Recurse into the new input.
177        // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.)
178        .transform_data(|(new_window_expr_list, new_input, window_expr_list)| {
179            self.rewrite(new_input, config)?.map_data(|new_input| {
180                Ok((new_window_expr_list, new_input, window_expr_list))
181            })
182        })?
183        // Rebuild the consecutive window nodes.
184        .map_data(|(new_window_expr_list, new_input, window_expr_list)| {
185            // If there were common expressions extracted, then we need to make sure
186            // we restore the original column names.
187            // TODO: Although `find_common_exprs()` inserts aliases around extracted
188            //  common expressions this doesn't mean that the original column names
189            //  (schema) are preserved due to the inserted aliases are not always at
190            //  the top of the expression.
191            //  Let's consider improving `find_common_exprs()` to always keep column
192            //  names and get rid of additional name preserving logic here.
193            if let Some(window_expr_list) = window_expr_list {
194                let name_preserver = NamePreserver::new_for_projection();
195                let saved_names = window_expr_list
196                    .iter()
197                    .map(|exprs| {
198                        exprs
199                            .iter()
200                            .map(|expr| name_preserver.save(expr))
201                            .collect::<Vec<_>>()
202                    })
203                    .collect::<Vec<_>>();
204                new_window_expr_list.into_iter().zip(saved_names).try_rfold(
205                    new_input,
206                    |plan, (new_window_expr, saved_names)| {
207                        let new_window_expr = new_window_expr
208                            .into_iter()
209                            .zip(saved_names)
210                            .map(|(new_window_expr, saved_name)| {
211                                saved_name.restore(new_window_expr)
212                            })
213                            .collect::<Vec<_>>();
214                        Window::try_new(new_window_expr, Arc::new(plan))
215                            .map(LogicalPlan::Window)
216                    },
217                )
218            } else {
219                new_window_expr_list
220                    .into_iter()
221                    .zip(window_schemas)
222                    .try_rfold(new_input, |plan, (new_window_expr, schema)| {
223                        Window::try_new_with_schema(
224                            new_window_expr,
225                            Arc::new(plan),
226                            schema,
227                        )
228                        .map(LogicalPlan::Window)
229                    })
230            }
231        })
232    }
233
234    fn try_optimize_aggregate(
235        &self,
236        aggregate: Aggregate,
237        config: &dyn OptimizerConfig,
238    ) -> Result<Transformed<LogicalPlan>> {
239        let Aggregate {
240            group_expr,
241            aggr_expr,
242            input,
243            schema,
244            ..
245        } = aggregate;
246        let input = Arc::unwrap_or_clone(input);
247        // Extract common sub-expressions from the aggregate and grouping expressions.
248        match CSE::new(ExprCSEController::new(
249            config.alias_generator().as_ref(),
250            ExprMask::Normal,
251        ))
252        .extract_common_nodes(vec![group_expr, aggr_expr])?
253        {
254            // If there are common sub-expressions, then insert a projection node
255            // with the common expressions between the new aggregate node and the
256            // original input.
257            FoundCommonNodes::Yes {
258                common_nodes: common_exprs,
259                new_nodes_list: mut new_exprs_list,
260                original_nodes_list: mut original_exprs_list,
261            } => {
262                let new_aggr_expr = new_exprs_list.pop().unwrap();
263                let new_group_expr = new_exprs_list.pop().unwrap();
264
265                build_common_expr_project_plan(input, common_exprs).map(|new_input| {
266                    let aggr_expr = original_exprs_list.pop().unwrap();
267                    Transformed::yes((
268                        new_aggr_expr,
269                        new_group_expr,
270                        new_input,
271                        Some(aggr_expr),
272                    ))
273                })
274            }
275
276            FoundCommonNodes::No {
277                original_nodes_list: mut original_exprs_list,
278            } => {
279                let new_aggr_expr = original_exprs_list.pop().unwrap();
280                let new_group_expr = original_exprs_list.pop().unwrap();
281
282                Ok(Transformed::no((
283                    new_aggr_expr,
284                    new_group_expr,
285                    input,
286                    None,
287                )))
288            }
289        }?
290        // Recurse into the new input.
291        // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.)
292        .transform_data(|(new_aggr_expr, new_group_expr, new_input, aggr_expr)| {
293            self.rewrite(new_input, config)?.map_data(|new_input| {
294                Ok((
295                    new_aggr_expr,
296                    new_group_expr,
297                    aggr_expr,
298                    Arc::new(new_input),
299                ))
300            })
301        })?
302        // Try extracting common aggregate expressions and rebuild the aggregate node.
303        .transform_data(
304            |(new_aggr_expr, new_group_expr, aggr_expr, new_input)| {
305                // Extract common aggregate sub-expressions from the aggregate expressions.
306                match CSE::new(ExprCSEController::new(
307                    config.alias_generator().as_ref(),
308                    ExprMask::NormalAndAggregates,
309                ))
310                .extract_common_nodes(vec![new_aggr_expr])?
311                {
312                    FoundCommonNodes::Yes {
313                        common_nodes: common_exprs,
314                        new_nodes_list: mut new_exprs_list,
315                        original_nodes_list: mut original_exprs_list,
316                    } => {
317                        let rewritten_aggr_expr = new_exprs_list.pop().unwrap();
318                        let new_aggr_expr = original_exprs_list.pop().unwrap();
319                        let saved_names = if let Some(aggr_expr) = aggr_expr {
320                            let name_preserver = NamePreserver::new_for_projection();
321                            aggr_expr
322                                .iter()
323                                .map(|expr| Some(name_preserver.save(expr)))
324                                .collect::<Vec<_>>()
325                        } else {
326                            new_aggr_expr
327                                .clone()
328                                .into_iter()
329                                .map(|_| None)
330                                .collect::<Vec<_>>()
331                        };
332
333                        let mut agg_exprs = common_exprs
334                            .into_iter()
335                            .map(|(expr, expr_alias)| expr.alias(expr_alias))
336                            .collect::<Vec<_>>();
337
338                        let mut proj_exprs = vec![];
339                        for expr in &new_group_expr {
340                            extract_expressions(expr, &mut proj_exprs)
341                        }
342                        for ((expr_rewritten, expr_orig), saved_name) in
343                            rewritten_aggr_expr
344                                .into_iter()
345                                .zip(new_aggr_expr)
346                                .zip(saved_names)
347                        {
348                            if expr_rewritten == expr_orig {
349                                let expr_rewritten = if let Some(saved_name) = saved_name
350                                {
351                                    saved_name.restore(expr_rewritten)
352                                } else {
353                                    expr_rewritten
354                                };
355                                if let Expr::Alias(Alias { expr, name, .. }) =
356                                    expr_rewritten
357                                {
358                                    agg_exprs.push(expr.alias(&name));
359                                    proj_exprs
360                                        .push(Expr::Column(Column::from_name(name)));
361                                } else {
362                                    let expr_alias =
363                                        config.alias_generator().next(CSE_PREFIX);
364                                    let (qualifier, field_name) =
365                                        expr_rewritten.qualified_name();
366                                    let out_name =
367                                        qualified_name(qualifier.as_ref(), &field_name);
368
369                                    agg_exprs.push(expr_rewritten.alias(&expr_alias));
370                                    proj_exprs.push(
371                                        Expr::Column(Column::from_name(expr_alias))
372                                            .alias(out_name),
373                                    );
374                                }
375                            } else {
376                                proj_exprs.push(expr_rewritten);
377                            }
378                        }
379
380                        let agg = LogicalPlan::Aggregate(Aggregate::try_new(
381                            new_input,
382                            new_group_expr,
383                            agg_exprs,
384                        )?);
385                        Projection::try_new(proj_exprs, Arc::new(agg))
386                            .map(|p| Transformed::yes(LogicalPlan::Projection(p)))
387                    }
388
389                    // If there aren't any common aggregate sub-expressions, then just
390                    // rebuild the aggregate node.
391                    FoundCommonNodes::No {
392                        original_nodes_list: mut original_exprs_list,
393                    } => {
394                        let rewritten_aggr_expr = original_exprs_list.pop().unwrap();
395
396                        // If there were common expressions extracted, then we need to
397                        // make sure we restore the original column names.
398                        // TODO: Although `find_common_exprs()` inserts aliases around
399                        //  extracted common expressions this doesn't mean that the
400                        //  original column names (schema) are preserved due to the
401                        //  inserted aliases are not always at the top of the
402                        //  expression.
403                        //  Let's consider improving `find_common_exprs()` to always
404                        //  keep column names and get rid of additional name
405                        //  preserving logic here.
406                        if let Some(aggr_expr) = aggr_expr {
407                            let name_preserver = NamePreserver::new_for_projection();
408                            let saved_names = aggr_expr
409                                .iter()
410                                .map(|expr| name_preserver.save(expr))
411                                .collect::<Vec<_>>();
412                            let new_aggr_expr = rewritten_aggr_expr
413                                .into_iter()
414                                .zip(saved_names)
415                                .map(|(new_expr, saved_name)| {
416                                    saved_name.restore(new_expr)
417                                })
418                                .collect::<Vec<Expr>>();
419
420                            // Since `group_expr` may have changed, schema may also.
421                            // Use `try_new()` method.
422                            Aggregate::try_new(new_input, new_group_expr, new_aggr_expr)
423                                .map(LogicalPlan::Aggregate)
424                                .map(Transformed::no)
425                        } else {
426                            Aggregate::try_new_with_schema(
427                                new_input,
428                                new_group_expr,
429                                rewritten_aggr_expr,
430                                schema,
431                            )
432                            .map(LogicalPlan::Aggregate)
433                            .map(Transformed::no)
434                        }
435                    }
436                }
437            },
438        )
439    }
440
441    /// Rewrites the expr list and input to remove common subexpressions
442    ///
443    /// # Parameters
444    ///
445    /// * `exprs`: List of expressions in the node
446    /// * `input`: input plan (that produces the columns referred to in `exprs`)
447    ///
448    /// # Return value
449    ///
450    ///  Returns `(rewritten_exprs, new_input)`. `new_input` is either:
451    ///
452    /// 1. The original `input` of no common subexpressions were extracted
453    /// 2. A newly added projection on top of the original input
454    ///    that computes the common subexpressions
455    fn try_unary_plan(
456        &self,
457        exprs: Vec<Expr>,
458        input: LogicalPlan,
459        config: &dyn OptimizerConfig,
460    ) -> Result<Transformed<(Vec<Expr>, LogicalPlan)>> {
461        // Extract common sub-expressions from the expressions.
462        match CSE::new(ExprCSEController::new(
463            config.alias_generator().as_ref(),
464            ExprMask::Normal,
465        ))
466        .extract_common_nodes(vec![exprs])?
467        {
468            FoundCommonNodes::Yes {
469                common_nodes: common_exprs,
470                new_nodes_list: mut new_exprs_list,
471                original_nodes_list: _,
472            } => {
473                let new_exprs = new_exprs_list.pop().unwrap();
474                build_common_expr_project_plan(input, common_exprs)
475                    .map(|new_input| Transformed::yes((new_exprs, new_input)))
476            }
477            FoundCommonNodes::No {
478                original_nodes_list: mut original_exprs_list,
479            } => {
480                let new_exprs = original_exprs_list.pop().unwrap();
481                Ok(Transformed::no((new_exprs, input)))
482            }
483        }?
484        // Recurse into the new input.
485        // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.)
486        .transform_data(|(new_exprs, new_input)| {
487            self.rewrite(new_input, config)?
488                .map_data(|new_input| Ok((new_exprs, new_input)))
489        })
490    }
491}
492
493/// Get all window expressions inside the consecutive window operators.
494///
495/// Returns the window expressions, and the input to the deepest child
496/// LogicalPlan.
497///
498/// For example, if the input window looks like
499///
500/// ```text
501///   LogicalPlan::Window(exprs=[a, b, c])
502///     LogicalPlan::Window(exprs=[d])
503///       InputPlan
504/// ```
505///
506/// Returns:
507/// *  `window_exprs`: `[[a, b, c], [d]]`
508/// * InputPlan
509///
510/// Consecutive window expressions may refer to same complex expression.
511///
512/// If same complex expression is referred more than once by subsequent
513/// `WindowAggr`s, we can cache complex expression by evaluating it with a
514/// projection before the first WindowAggr.
515///
516/// This enables us to cache complex expression "c3+c4" for following plan:
517///
518/// ```text
519/// WindowAggr: windowExpr=[[sum(c9) ORDER BY [c3 + c4] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
520/// --WindowAggr: windowExpr=[[sum(c9) ORDER BY [c3 + c4] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
521/// ```
522///
523/// where, it is referred once by each `WindowAggr` (total of 2) in the plan.
524fn get_consecutive_window_exprs(
525    window: Window,
526) -> (Vec<Vec<Expr>>, Vec<DFSchemaRef>, LogicalPlan) {
527    let mut window_expr_list = vec![];
528    let mut window_schemas = vec![];
529    let mut plan = LogicalPlan::Window(window);
530    while let LogicalPlan::Window(Window {
531        input,
532        window_expr,
533        schema,
534    }) = plan
535    {
536        window_expr_list.push(window_expr);
537        window_schemas.push(schema);
538
539        plan = Arc::unwrap_or_clone(input);
540    }
541    (window_expr_list, window_schemas, plan)
542}
543
544impl OptimizerRule for CommonSubexprEliminate {
545    fn supports_rewrite(&self) -> bool {
546        true
547    }
548
549    fn apply_order(&self) -> Option<ApplyOrder> {
550        // This rule handles recursion itself in a `ApplyOrder::TopDown` like manner.
551        // This is because in some cases adjacent nodes are collected (e.g. `Window`) and
552        // CSEd as a group, which can't be done in a simple `ApplyOrder::TopDown` rule.
553        None
554    }
555
556    #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
557    fn rewrite(
558        &self,
559        plan: LogicalPlan,
560        config: &dyn OptimizerConfig,
561    ) -> Result<Transformed<LogicalPlan>> {
562        let original_schema = Arc::clone(plan.schema());
563
564        let optimized_plan = match plan {
565            LogicalPlan::Projection(proj) => self.try_optimize_proj(proj, config)?,
566            LogicalPlan::Sort(sort) => self.try_optimize_sort(sort, config)?,
567            LogicalPlan::Filter(filter) => self.try_optimize_filter(filter, config)?,
568            LogicalPlan::Window(window) => self.try_optimize_window(window, config)?,
569            LogicalPlan::Aggregate(agg) => self.try_optimize_aggregate(agg, config)?,
570            LogicalPlan::Join(_)
571            | LogicalPlan::Repartition(_)
572            | LogicalPlan::Union(_)
573            | LogicalPlan::TableScan(_)
574            | LogicalPlan::Values(_)
575            | LogicalPlan::EmptyRelation(_)
576            | LogicalPlan::Subquery(_)
577            | LogicalPlan::SubqueryAlias(_)
578            | LogicalPlan::Limit(_)
579            | LogicalPlan::Ddl(_)
580            | LogicalPlan::Explain(_)
581            | LogicalPlan::Analyze(_)
582            | LogicalPlan::Statement(_)
583            | LogicalPlan::DescribeTable(_)
584            | LogicalPlan::Distinct(_)
585            | LogicalPlan::Extension(_)
586            | LogicalPlan::Dml(_)
587            | LogicalPlan::Copy(_)
588            | LogicalPlan::Unnest(_)
589            | LogicalPlan::RecursiveQuery(_) => {
590                // This rule handles recursion itself in a `ApplyOrder::TopDown` like
591                // manner.
592                plan.map_children(|c| self.rewrite(c, config))?
593            }
594        };
595
596        // If we rewrote the plan, ensure the schema stays the same
597        if optimized_plan.transformed && optimized_plan.data.schema() != &original_schema
598        {
599            optimized_plan.map_data(|optimized_plan| {
600                build_recover_project_plan(&original_schema, optimized_plan)
601            })
602        } else {
603            Ok(optimized_plan)
604        }
605    }
606
607    fn name(&self) -> &str {
608        "common_sub_expression_eliminate"
609    }
610}
611
/// Selects which kinds of [expressions](Expr) the CSE controller may extract.
#[derive(Clone, Copy, Debug)]
enum ExprMask {
    /// Never extracts these expression kinds:
    ///
    /// - [`Literal`](Expr::Literal)
    /// - [`Columns`](Expr::Column)
    /// - [`ScalarVariable`](Expr::ScalarVariable)
    /// - [`Alias`](Expr::Alias)
    /// - [`Wildcard`](Expr::Wildcard)
    /// - [`AggregateFunction`](Expr::AggregateFunction)
    Normal,

    /// Same exclusions as [`Normal`](Self::Normal), except that
    /// [`AggregateFunction`](Expr::AggregateFunction)s may also be extracted.
    NormalAndAggregates,
}
628
629struct ExprCSEController<'a> {
630    alias_generator: &'a AliasGenerator,
631    mask: ExprMask,
632
633    // how many aliases have we seen so far
634    alias_counter: usize,
635}
636
637impl<'a> ExprCSEController<'a> {
638    fn new(alias_generator: &'a AliasGenerator, mask: ExprMask) -> Self {
639        Self {
640            alias_generator,
641            mask,
642            alias_counter: 0,
643        }
644    }
645}
646
647impl CSEController for ExprCSEController<'_> {
648    type Node = Expr;
649
650    fn conditional_children(node: &Expr) -> Option<(Vec<&Expr>, Vec<&Expr>)> {
651        match node {
652            // In case of `ScalarFunction`s we don't know which children are surely
653            // executed so start visiting all children conditionally and stop the
654            // recursion with `TreeNodeRecursion::Jump`.
655            Expr::ScalarFunction(ScalarFunction { func, args }) => {
656                func.conditional_arguments(args)
657            }
658
659            // In case of `And` and `Or` the first child is surely executed, but we
660            // account subexpressions as conditional in the second.
661            Expr::BinaryExpr(BinaryExpr {
662                left,
663                op: Operator::And | Operator::Or,
664                right,
665            }) => Some((vec![left.as_ref()], vec![right.as_ref()])),
666
667            // In case of `Case` the optional base expression and the first when
668            // expressions are surely executed, but we account subexpressions as
669            // conditional in the others.
670            Expr::Case(Case {
671                expr,
672                when_then_expr,
673                else_expr,
674            }) => Some((
675                expr.iter()
676                    .map(|e| e.as_ref())
677                    .chain(when_then_expr.iter().take(1).map(|(when, _)| when.as_ref()))
678                    .collect(),
679                when_then_expr
680                    .iter()
681                    .take(1)
682                    .map(|(_, then)| then.as_ref())
683                    .chain(
684                        when_then_expr
685                            .iter()
686                            .skip(1)
687                            .flat_map(|(when, then)| [when.as_ref(), then.as_ref()]),
688                    )
689                    .chain(else_expr.iter().map(|e| e.as_ref()))
690                    .collect(),
691            )),
692            _ => None,
693        }
694    }
695
696    fn is_valid(node: &Expr) -> bool {
697        !node.is_volatile_node()
698    }
699
700    fn is_ignored(&self, node: &Expr) -> bool {
701        // TODO: remove the next line after `Expr::Wildcard` is removed
702        #[expect(deprecated)]
703        let is_normal_minus_aggregates = matches!(
704            node,
705            Expr::Literal(..)
706                | Expr::Column(..)
707                | Expr::ScalarVariable(..)
708                | Expr::Alias(..)
709                | Expr::Wildcard { .. }
710        );
711
712        let is_aggr = matches!(node, Expr::AggregateFunction(..));
713
714        match self.mask {
715            ExprMask::Normal => is_normal_minus_aggregates || is_aggr,
716            ExprMask::NormalAndAggregates => is_normal_minus_aggregates,
717        }
718    }
719
720    fn generate_alias(&self) -> String {
721        self.alias_generator.next(CSE_PREFIX)
722    }
723
724    fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
725        // alias the expressions without an `Alias` ancestor node
726        if self.alias_counter > 0 {
727            col(alias)
728        } else {
729            self.alias_counter += 1;
730            col(alias).alias(node.schema_name().to_string())
731        }
732    }
733
734    fn rewrite_f_down(&mut self, node: &Expr) {
735        if matches!(node, Expr::Alias(_)) {
736            self.alias_counter += 1;
737        }
738    }
739    fn rewrite_f_up(&mut self, node: &Expr) {
740        if matches!(node, Expr::Alias(_)) {
741            self.alias_counter -= 1
742        }
743    }
744}
745
746impl Default for CommonSubexprEliminate {
747    fn default() -> Self {
748        Self::new()
749    }
750}
751
752/// Build the "intermediate" projection plan that evaluates the extracted common
753/// expressions.
754///
755/// # Arguments
756/// input: the input plan
757///
758/// common_exprs: which common subexpressions were used (and thus are added to
759/// intermediate projection)
760///
761/// expr_stats: the set of common subexpressions
762fn build_common_expr_project_plan(
763    input: LogicalPlan,
764    common_exprs: Vec<(Expr, String)>,
765) -> Result<LogicalPlan> {
766    let mut fields_set = BTreeSet::new();
767    let mut project_exprs = common_exprs
768        .into_iter()
769        .map(|(expr, expr_alias)| {
770            fields_set.insert(expr_alias.clone());
771            Ok(expr.alias(expr_alias))
772        })
773        .collect::<Result<Vec<_>>>()?;
774
775    for (qualifier, field) in input.schema().iter() {
776        if fields_set.insert(qualified_name(qualifier, field.name())) {
777            project_exprs.push(Expr::from((qualifier, field)));
778        }
779    }
780
781    Projection::try_new(project_exprs, Arc::new(input)).map(LogicalPlan::Projection)
782}
783
784/// Build the projection plan to eliminate unnecessary columns produced by
785/// the "intermediate" projection plan built in [build_common_expr_project_plan].
786///
787/// This is required to keep the schema the same for plans that pass the input
788/// on to the output, such as `Filter` or `Sort`.
789fn build_recover_project_plan(
790    schema: &DFSchema,
791    input: LogicalPlan,
792) -> Result<LogicalPlan> {
793    let col_exprs = schema.iter().map(Expr::from).collect();
794    Projection::try_new(col_exprs, Arc::new(input)).map(LogicalPlan::Projection)
795}
796
797fn extract_expressions(expr: &Expr, result: &mut Vec<Expr>) {
798    if let Expr::GroupingSet(groupings) = expr {
799        for e in groupings.distinct_expr() {
800            let (qualifier, field_name) = e.qualified_name();
801            let col = Column::new(qualifier, field_name);
802            result.push(Expr::Column(col))
803        }
804    } else {
805        let (qualifier, field_name) = expr.qualified_name();
806        let col = Column::new(qualifier, field_name);
807        result.push(Expr::Column(col));
808    }
809}
810
811#[cfg(test)]
812mod test {
813    use std::any::Any;
814    use std::iter;
815
816    use arrow::datatypes::{DataType, Field, Schema};
817    use datafusion_expr::logical_plan::{JoinType, table_scan};
818    use datafusion_expr::{
819        AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarFunctionArgs,
820        ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, Volatility,
821        grouping_set, is_null, not,
822    };
823    use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder};
824
825    use super::*;
826    use crate::assert_optimized_plan_eq_snapshot;
827    use crate::optimizer::OptimizerContext;
828    use crate::test::*;
829    use datafusion_expr::test::function_stub::{avg, sum};
830
831    macro_rules! assert_optimized_plan_equal {
832        (
833            $config:expr,
834            $plan:expr,
835            @ $expected:literal $(,)?
836        ) => {{
837            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(CommonSubexprEliminate::new())];
838            assert_optimized_plan_eq_snapshot!(
839                $config,
840                rules,
841                $plan,
842                @ $expected,
843            )
844        }};
845
846        (
847            $plan:expr,
848            @ $expected:literal $(,)?
849        ) => {{
850            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(CommonSubexprEliminate::new())];
851            let optimizer_ctx = OptimizerContext::new();
852            assert_optimized_plan_eq_snapshot!(
853                optimizer_ctx,
854                rules,
855                $plan,
856                @ $expected,
857            )
858        }};
859    }
860
861    #[test]
862    fn tpch_q1_simplified() -> Result<()> {
863        // SQL:
864        //  select
865        //      sum(a * (1 - b)),
866        //      sum(a * (1 - b) * (1 + c))
867        //  from T;
868        //
869        // The manual assembled logical plan don't contains the outermost `Projection`.
870
871        let table_scan = test_table_scan()?;
872
873        let plan = LogicalPlanBuilder::from(table_scan)
874            .aggregate(
875                iter::empty::<Expr>(),
876                vec![
877                    sum(col("a") * (lit(1) - col("b"))),
878                    sum((col("a") * (lit(1) - col("b"))) * (lit(1) + col("c"))),
879                ],
880            )?
881            .build()?;
882
883        assert_optimized_plan_equal!(
884            plan,
885            @ r"
886        Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test.a * Int32(1) - test.b), sum(__common_expr_1 AS test.a * Int32(1) - test.b * (Int32(1) + test.c))]]
887          Projection: test.a * (Int32(1) - test.b) AS __common_expr_1, test.a, test.b, test.c
888            TableScan: test
889        "
890        )
891    }
892
893    #[test]
894    fn nested_aliases() -> Result<()> {
895        let table_scan = test_table_scan()?;
896
897        let plan = LogicalPlanBuilder::from(table_scan)
898            .project(vec![
899                (col("a") + col("b") - col("c")).alias("alias1") * (col("a") + col("b")),
900                col("a") + col("b"),
901            ])?
902            .build()?;
903
904        assert_optimized_plan_equal!(
905            plan,
906            @ r"
907        Projection: __common_expr_1 - test.c AS alias1 * __common_expr_1 AS test.a + test.b, __common_expr_1 AS test.a + test.b
908          Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
909            TableScan: test
910        "
911        )
912    }
913
914    #[test]
915    fn aggregate() -> Result<()> {
916        let table_scan = test_table_scan()?;
917
918        let return_type = DataType::UInt32;
919        let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!());
920        let udf_agg = |inner: Expr| {
921            Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
922                Arc::new(AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
923                    "my_agg",
924                    Signature::exact(vec![DataType::UInt32], Volatility::Stable),
925                    return_type.clone(),
926                    Arc::clone(&accumulator),
927                    vec![Field::new("value", DataType::UInt32, true).into()],
928                ))),
929                vec![inner],
930                false,
931                None,
932                vec![],
933                None,
934            ))
935        };
936
937        // test: common aggregates
938        let plan = LogicalPlanBuilder::from(table_scan.clone())
939            .aggregate(
940                iter::empty::<Expr>(),
941                vec![
942                    // common: avg(col("a"))
943                    avg(col("a")).alias("col1"),
944                    avg(col("a")).alias("col2"),
945                    // no common
946                    avg(col("b")).alias("col3"),
947                    avg(col("c")),
948                    // common: udf_agg(col("a"))
949                    udf_agg(col("a")).alias("col4"),
950                    udf_agg(col("a")).alias("col5"),
951                    // no common
952                    udf_agg(col("b")).alias("col6"),
953                    udf_agg(col("c")),
954                ],
955            )?
956            .build()?;
957
958        assert_optimized_plan_equal!(
959            plan,
960            @ r"
961        Projection: __common_expr_1 AS col1, __common_expr_1 AS col2, col3, __common_expr_3 AS avg(test.c), __common_expr_2 AS col4, __common_expr_2 AS col5, col6, __common_expr_4 AS my_agg(test.c)
962          Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2, avg(test.b) AS col3, avg(test.c) AS __common_expr_3, my_agg(test.b) AS col6, my_agg(test.c) AS __common_expr_4]]
963            TableScan: test
964        "
965        )?;
966
967        // test: trafo after aggregate
968        let plan = LogicalPlanBuilder::from(table_scan.clone())
969            .aggregate(
970                iter::empty::<Expr>(),
971                vec![
972                    lit(1) + avg(col("a")),
973                    lit(1) - avg(col("a")),
974                    lit(1) + udf_agg(col("a")),
975                    lit(1) - udf_agg(col("a")),
976                ],
977            )?
978            .build()?;
979
980        assert_optimized_plan_equal!(
981            plan,
982            @ r"
983        Projection: Int32(1) + __common_expr_1 AS avg(test.a), Int32(1) - __common_expr_1 AS avg(test.a), Int32(1) + __common_expr_2 AS my_agg(test.a), Int32(1) - __common_expr_2 AS my_agg(test.a)
984          Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2]]
985            TableScan: test
986        "
987        )?;
988
989        // test: transformation before aggregate
990        let plan = LogicalPlanBuilder::from(table_scan.clone())
991            .aggregate(
992                iter::empty::<Expr>(),
993                vec![
994                    avg(lit(1u32) + col("a")).alias("col1"),
995                    udf_agg(lit(1u32) + col("a")).alias("col2"),
996                ],
997            )?
998            .build()?;
999
1000        assert_optimized_plan_equal!(
1001            plan,
1002            @ r"
1003        Aggregate: groupBy=[[]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]
1004          Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1005            TableScan: test
1006        "
1007        )?;
1008
1009        // test: common between agg and group
1010        let plan = LogicalPlanBuilder::from(table_scan.clone())
1011            .aggregate(
1012                vec![lit(1u32) + col("a")],
1013                vec![
1014                    avg(lit(1u32) + col("a")).alias("col1"),
1015                    udf_agg(lit(1u32) + col("a")).alias("col2"),
1016                ],
1017            )?
1018            .build()?;
1019
1020        assert_optimized_plan_equal!(
1021            plan,
1022            @ r"
1023        Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]
1024          Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1025            TableScan: test
1026        "
1027        )?;
1028
1029        // test: all mixed
1030        let plan = LogicalPlanBuilder::from(table_scan)
1031            .aggregate(
1032                vec![lit(1u32) + col("a")],
1033                vec![
1034                    (lit(1u32) + avg(lit(1u32) + col("a"))).alias("col1"),
1035                    (lit(1u32) - avg(lit(1u32) + col("a"))).alias("col2"),
1036                    avg(lit(1u32) + col("a")),
1037                    (lit(1u32) + udf_agg(lit(1u32) + col("a"))).alias("col3"),
1038                    (lit(1u32) - udf_agg(lit(1u32) + col("a"))).alias("col4"),
1039                    udf_agg(lit(1u32) + col("a")),
1040                ],
1041            )?
1042            .build()?;
1043
1044        assert_optimized_plan_equal!(
1045            plan,
1046            @ r"
1047        Projection: UInt32(1) + test.a, UInt32(1) + __common_expr_2 AS col1, UInt32(1) - __common_expr_2 AS col2, __common_expr_4 AS avg(UInt32(1) + test.a), UInt32(1) + __common_expr_3 AS col3, UInt32(1) - __common_expr_3 AS col4, __common_expr_5 AS my_agg(UInt32(1) + test.a)
1048          Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS __common_expr_2, my_agg(__common_expr_1) AS __common_expr_3, avg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_4, my_agg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_5]]
1049            Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1050              TableScan: test
1051        "
1052        )
1053    }
1054
1055    #[test]
1056    fn aggregate_with_relations_and_dots() -> Result<()> {
1057        let schema = Schema::new(vec![Field::new("col.a", DataType::UInt32, false)]);
1058        let table_scan = table_scan(Some("table.test"), &schema, None)?.build()?;
1059
1060        let col_a = Expr::Column(Column::new(Some("table.test"), "col.a"));
1061
1062        let plan = LogicalPlanBuilder::from(table_scan)
1063            .aggregate(
1064                vec![col_a.clone()],
1065                vec![
1066                    (lit(1u32) + avg(lit(1u32) + col_a.clone())),
1067                    avg(lit(1u32) + col_a),
1068                ],
1069            )?
1070            .build()?;
1071
1072        assert_optimized_plan_equal!(
1073            plan,
1074            @ r"
1075        Projection: table.test.col.a, UInt32(1) + __common_expr_2 AS avg(UInt32(1) + table.test.col.a), __common_expr_2 AS avg(UInt32(1) + table.test.col.a)
1076          Aggregate: groupBy=[[table.test.col.a]], aggr=[[avg(__common_expr_1 AS UInt32(1) + table.test.col.a) AS __common_expr_2]]
1077            Projection: UInt32(1) + table.test.col.a AS __common_expr_1, table.test.col.a
1078              TableScan: table.test
1079        "
1080        )
1081    }
1082
1083    #[test]
1084    fn subexpr_in_same_order() -> Result<()> {
1085        let table_scan = test_table_scan()?;
1086
1087        let plan = LogicalPlanBuilder::from(table_scan)
1088            .project(vec![
1089                (lit(1) + col("a")).alias("first"),
1090                (lit(1) + col("a")).alias("second"),
1091            ])?
1092            .build()?;
1093
1094        assert_optimized_plan_equal!(
1095            plan,
1096            @ r"
1097        Projection: __common_expr_1 AS first, __common_expr_1 AS second
1098          Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1099            TableScan: test
1100        "
1101        )
1102    }
1103
1104    #[test]
1105    fn subexpr_in_different_order() -> Result<()> {
1106        let table_scan = test_table_scan()?;
1107
1108        let plan = LogicalPlanBuilder::from(table_scan)
1109            .project(vec![lit(1) + col("a"), col("a") + lit(1)])?
1110            .build()?;
1111
1112        assert_optimized_plan_equal!(
1113            plan,
1114            @ r"
1115        Projection: __common_expr_1 AS Int32(1) + test.a, __common_expr_1 AS test.a + Int32(1)
1116          Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1117            TableScan: test
1118        "
1119        )
1120    }
1121
1122    #[test]
1123    fn cross_plans_subexpr() -> Result<()> {
1124        let table_scan = test_table_scan()?;
1125
1126        let plan = LogicalPlanBuilder::from(table_scan)
1127            .project(vec![lit(1) + col("a"), col("a")])?
1128            .project(vec![lit(1) + col("a")])?
1129            .build()?;
1130
1131        assert_optimized_plan_equal!(
1132            plan,
1133            @ r"
1134        Projection: Int32(1) + test.a
1135          Projection: Int32(1) + test.a, test.a
1136            TableScan: test
1137        "
1138        )
1139    }
1140
1141    #[test]
1142    fn redundant_project_fields() {
1143        let table_scan = test_table_scan().unwrap();
1144        let c_plus_a = col("c") + col("a");
1145        let b_plus_a = col("b") + col("a");
1146        let common_exprs_1 = vec![
1147            (c_plus_a, format!("{CSE_PREFIX}_1")),
1148            (b_plus_a, format!("{CSE_PREFIX}_2")),
1149        ];
1150        let c_plus_a_2 = col(format!("{CSE_PREFIX}_1"));
1151        let b_plus_a_2 = col(format!("{CSE_PREFIX}_2"));
1152        let common_exprs_2 = vec![
1153            (c_plus_a_2, format!("{CSE_PREFIX}_3")),
1154            (b_plus_a_2, format!("{CSE_PREFIX}_4")),
1155        ];
1156        let project = build_common_expr_project_plan(table_scan, common_exprs_1).unwrap();
1157        let project_2 = build_common_expr_project_plan(project, common_exprs_2).unwrap();
1158
1159        let mut field_set = BTreeSet::new();
1160        for name in project_2.schema().field_names() {
1161            assert!(field_set.insert(name));
1162        }
1163    }
1164
1165    #[test]
1166    fn redundant_project_fields_join_input() {
1167        let table_scan_1 = test_table_scan_with_name("test1").unwrap();
1168        let table_scan_2 = test_table_scan_with_name("test2").unwrap();
1169        let join = LogicalPlanBuilder::from(table_scan_1)
1170            .join(table_scan_2, JoinType::Inner, (vec!["a"], vec!["a"]), None)
1171            .unwrap()
1172            .build()
1173            .unwrap();
1174        let c_plus_a = col("test1.c") + col("test1.a");
1175        let b_plus_a = col("test1.b") + col("test1.a");
1176        let common_exprs_1 = vec![
1177            (c_plus_a, format!("{CSE_PREFIX}_1")),
1178            (b_plus_a, format!("{CSE_PREFIX}_2")),
1179        ];
1180        let c_plus_a_2 = col(format!("{CSE_PREFIX}_1"));
1181        let b_plus_a_2 = col(format!("{CSE_PREFIX}_2"));
1182        let common_exprs_2 = vec![
1183            (c_plus_a_2, format!("{CSE_PREFIX}_3")),
1184            (b_plus_a_2, format!("{CSE_PREFIX}_4")),
1185        ];
1186        let project = build_common_expr_project_plan(join, common_exprs_1).unwrap();
1187        let project_2 = build_common_expr_project_plan(project, common_exprs_2).unwrap();
1188
1189        let mut field_set = BTreeSet::new();
1190        for name in project_2.schema().field_names() {
1191            assert!(field_set.insert(name));
1192        }
1193    }
1194
1195    #[test]
1196    fn eliminated_subexpr_datatype() {
1197        use datafusion_expr::cast;
1198
1199        let schema = Schema::new(vec![
1200            Field::new("a", DataType::UInt64, false),
1201            Field::new("b", DataType::UInt64, false),
1202            Field::new("c", DataType::UInt64, false),
1203        ]);
1204
1205        let plan = table_scan(Some("table"), &schema, None)
1206            .unwrap()
1207            .filter(
1208                cast(col("a"), DataType::Int64)
1209                    .lt(lit(1_i64))
1210                    .and(cast(col("a"), DataType::Int64).not_eq(lit(1_i64))),
1211            )
1212            .unwrap()
1213            .build()
1214            .unwrap();
1215        let rule = CommonSubexprEliminate::new();
1216        let optimized_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap();
1217        assert!(optimized_plan.transformed);
1218        let optimized_plan = optimized_plan.data;
1219
1220        let schema = optimized_plan.schema();
1221        let fields_with_datatypes: Vec<_> = schema
1222            .fields()
1223            .iter()
1224            .map(|field| (field.name(), field.data_type()))
1225            .collect();
1226        let formatted_fields_with_datatype = format!("{fields_with_datatypes:#?}");
1227        let expected = r#"[
1228    (
1229        "a",
1230        UInt64,
1231    ),
1232    (
1233        "b",
1234        UInt64,
1235    ),
1236    (
1237        "c",
1238        UInt64,
1239    ),
1240]"#;
1241        assert_eq!(expected, formatted_fields_with_datatype);
1242    }
1243
1244    #[test]
1245    fn filter_schema_changed() -> Result<()> {
1246        let table_scan = test_table_scan()?;
1247
1248        let plan = LogicalPlanBuilder::from(table_scan)
1249            .filter((lit(1) + col("a") - lit(10)).gt(lit(1) + col("a")))?
1250            .build()?;
1251
1252        assert_optimized_plan_equal!(
1253            plan,
1254            @ r"
1255        Projection: test.a, test.b, test.c
1256          Filter: __common_expr_1 - Int32(10) > __common_expr_1
1257            Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1258              TableScan: test
1259        "
1260        )
1261    }
1262
1263    #[test]
1264    fn test_extract_expressions_from_grouping_set() -> Result<()> {
1265        let mut result = Vec::with_capacity(3);
1266        let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("c")]]);
1267        extract_expressions(&grouping, &mut result);
1268
1269        assert!(result.len() == 3);
1270        Ok(())
1271    }
1272
1273    #[test]
1274    fn test_extract_expressions_from_grouping_set_with_identical_expr() -> Result<()> {
1275        let mut result = Vec::with_capacity(2);
1276        let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("a")]]);
1277        extract_expressions(&grouping, &mut result);
1278        assert!(result.len() == 2);
1279        Ok(())
1280    }
1281
1282    #[test]
1283    fn test_alias_collision() -> Result<()> {
1284        let table_scan = test_table_scan()?;
1285
1286        let config = OptimizerContext::new();
1287        let common_expr_1 = config.alias_generator().next(CSE_PREFIX);
1288        let plan = LogicalPlanBuilder::from(table_scan.clone())
1289            .project(vec![
1290                (col("a") + col("b")).alias(common_expr_1.clone()),
1291                col("c"),
1292            ])?
1293            .project(vec![
1294                col(common_expr_1.clone()).alias("c1"),
1295                col(common_expr_1).alias("c2"),
1296                (col("c") + lit(2)).alias("c3"),
1297                (col("c") + lit(2)).alias("c4"),
1298            ])?
1299            .build()?;
1300
1301        assert_optimized_plan_equal!(
1302            config,
1303            plan,
1304            @ r"
1305        Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 AS c3, __common_expr_2 AS c4
1306          Projection: test.c + Int32(2) AS __common_expr_2, __common_expr_1, test.c
1307            Projection: test.a + test.b AS __common_expr_1, test.c
1308              TableScan: test
1309        "
1310        )?;
1311
1312        let config = OptimizerContext::new();
1313        let _common_expr_1 = config.alias_generator().next(CSE_PREFIX);
1314        let common_expr_2 = config.alias_generator().next(CSE_PREFIX);
1315        let plan = LogicalPlanBuilder::from(table_scan)
1316            .project(vec![
1317                (col("a") + col("b")).alias(common_expr_2.clone()),
1318                col("c"),
1319            ])?
1320            .project(vec![
1321                col(common_expr_2.clone()).alias("c1"),
1322                col(common_expr_2).alias("c2"),
1323                (col("c") + lit(2)).alias("c3"),
1324                (col("c") + lit(2)).alias("c4"),
1325            ])?
1326            .build()?;
1327
1328        assert_optimized_plan_equal!(
1329            config,
1330            plan,
1331            @ r"
1332        Projection: __common_expr_2 AS c1, __common_expr_2 AS c2, __common_expr_3 AS c3, __common_expr_3 AS c4
1333          Projection: test.c + Int32(2) AS __common_expr_3, __common_expr_2, test.c
1334            Projection: test.a + test.b AS __common_expr_2, test.c
1335              TableScan: test
1336        "
1337        )?;
1338
1339        Ok(())
1340    }
1341
1342    #[test]
1343    fn test_extract_expressions_from_col() -> Result<()> {
1344        let mut result = Vec::with_capacity(1);
1345        extract_expressions(&col("a"), &mut result);
1346        assert!(result.len() == 1);
1347        Ok(())
1348    }
1349
1350    #[test]
1351    fn test_short_circuits() -> Result<()> {
1352        let table_scan = test_table_scan()?;
1353
1354        let extracted_short_circuit = col("a").eq(lit(0)).or(col("b").eq(lit(0)));
1355        let extracted_short_circuit_leg_1 = (col("a") + col("b")).eq(lit(0));
1356        let not_extracted_short_circuit_leg_2 = (col("a") - col("b")).eq(lit(0));
1357        let extracted_short_circuit_leg_3 = (col("a") * col("b")).eq(lit(0));
1358        let plan = LogicalPlanBuilder::from(table_scan)
1359            .project(vec![
1360                extracted_short_circuit.clone().alias("c1"),
1361                extracted_short_circuit.alias("c2"),
1362                extracted_short_circuit_leg_1
1363                    .clone()
1364                    .or(not_extracted_short_circuit_leg_2.clone())
1365                    .alias("c3"),
1366                extracted_short_circuit_leg_1
1367                    .and(not_extracted_short_circuit_leg_2)
1368                    .alias("c4"),
1369                extracted_short_circuit_leg_3
1370                    .clone()
1371                    .or(extracted_short_circuit_leg_3)
1372                    .alias("c5"),
1373            ])?
1374            .build()?;
1375
1376        assert_optimized_plan_equal!(
1377            plan,
1378            @ r"
1379        Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 OR test.a - test.b = Int32(0) AS c3, __common_expr_2 AND test.a - test.b = Int32(0) AS c4, __common_expr_3 OR __common_expr_3 AS c5
1380          Projection: test.a = Int32(0) OR test.b = Int32(0) AS __common_expr_1, test.a + test.b = Int32(0) AS __common_expr_2, test.a * test.b = Int32(0) AS __common_expr_3, test.a, test.b, test.c
1381            TableScan: test
1382        "
1383        )
1384    }
1385
1386    #[test]
1387    fn test_volatile() -> Result<()> {
1388        let table_scan = test_table_scan()?;
1389
1390        let extracted_child = col("a") + col("b");
1391        let rand = rand_func().call(vec![]);
1392        let not_extracted_volatile = extracted_child + rand;
1393        let plan = LogicalPlanBuilder::from(table_scan)
1394            .project(vec![
1395                not_extracted_volatile.clone().alias("c1"),
1396                not_extracted_volatile.alias("c2"),
1397            ])?
1398            .build()?;
1399
1400        assert_optimized_plan_equal!(
1401            plan,
1402            @ r"
1403        Projection: __common_expr_1 + random() AS c1, __common_expr_1 + random() AS c2
1404          Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
1405            TableScan: test
1406        "
1407        )
1408    }
1409
1410    #[test]
1411    fn test_volatile_short_circuits() -> Result<()> {
1412        let table_scan = test_table_scan()?;
1413
1414        let rand = rand_func().call(vec![]);
1415        let extracted_short_circuit_leg_1 = col("a").eq(lit(0));
1416        let not_extracted_volatile_short_circuit_1 =
1417            extracted_short_circuit_leg_1.or(rand.clone().eq(lit(0)));
1418        let not_extracted_short_circuit_leg_2 = col("b").eq(lit(0));
1419        let not_extracted_volatile_short_circuit_2 =
1420            rand.eq(lit(0)).or(not_extracted_short_circuit_leg_2);
1421        let plan = LogicalPlanBuilder::from(table_scan)
1422            .project(vec![
1423                not_extracted_volatile_short_circuit_1.clone().alias("c1"),
1424                not_extracted_volatile_short_circuit_1.alias("c2"),
1425                not_extracted_volatile_short_circuit_2.clone().alias("c3"),
1426                not_extracted_volatile_short_circuit_2.alias("c4"),
1427            ])?
1428            .build()?;
1429
1430        assert_optimized_plan_equal!(
1431            plan,
1432            @ r"
1433        Projection: __common_expr_1 OR random() = Int32(0) AS c1, __common_expr_1 OR random() = Int32(0) AS c2, random() = Int32(0) OR test.b = Int32(0) AS c3, random() = Int32(0) OR test.b = Int32(0) AS c4
1434          Projection: test.a = Int32(0) AS __common_expr_1, test.a, test.b, test.c
1435            TableScan: test
1436        "
1437        )
1438    }
1439
1440    #[test]
1441    fn test_non_top_level_common_expression() -> Result<()> {
1442        let table_scan = test_table_scan()?;
1443
1444        let common_expr = col("a") + col("b");
1445        let plan = LogicalPlanBuilder::from(table_scan)
1446            .project(vec![
1447                common_expr.clone().alias("c1"),
1448                common_expr.alias("c2"),
1449            ])?
1450            .project(vec![col("c1"), col("c2")])?
1451            .build()?;
1452
1453        assert_optimized_plan_equal!(
1454            plan,
1455            @ r"
1456        Projection: c1, c2
1457          Projection: __common_expr_1 AS c1, __common_expr_1 AS c2
1458            Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
1459              TableScan: test
1460        "
1461        )
1462    }
1463
1464    #[test]
1465    fn test_nested_common_expression() -> Result<()> {
1466        let table_scan = test_table_scan()?;
1467
1468        let nested_common_expr = col("a") + col("b");
1469        let common_expr = nested_common_expr.clone() * nested_common_expr;
1470        let plan = LogicalPlanBuilder::from(table_scan)
1471            .project(vec![
1472                common_expr.clone().alias("c1"),
1473                common_expr.alias("c2"),
1474            ])?
1475            .build()?;
1476
1477        assert_optimized_plan_equal!(
1478            plan,
1479            @ r"
1480        Projection: __common_expr_1 AS c1, __common_expr_1 AS c2
1481          Projection: __common_expr_2 * __common_expr_2 AS __common_expr_1, test.a, test.b, test.c
1482            Projection: test.a + test.b AS __common_expr_2, test.a, test.b, test.c
1483              TableScan: test
1484        "
1485        )
1486    }
1487
1488    #[test]
1489    fn test_normalize_add_expression() -> Result<()> {
1490        // a + b <=> b + a
1491        let table_scan = test_table_scan()?;
1492        let expr = ((col("a") + col("b")) * (col("b") + col("a"))).eq(lit(30));
1493        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1494
1495        assert_optimized_plan_equal!(
1496            plan,
1497            @ r"
1498        Projection: test.a, test.b, test.c
1499          Filter: __common_expr_1 * __common_expr_1 = Int32(30)
1500            Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
1501              TableScan: test
1502        "
1503        )
1504    }
1505
1506    #[test]
1507    fn test_normalize_multi_expression() -> Result<()> {
1508        // a * b <=> b * a
1509        let table_scan = test_table_scan()?;
1510        let expr = ((col("a") * col("b")) + (col("b") * col("a"))).eq(lit(30));
1511        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1512
1513        assert_optimized_plan_equal!(
1514            plan,
1515            @ r"
1516        Projection: test.a, test.b, test.c
1517          Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1518            Projection: test.a * test.b AS __common_expr_1, test.a, test.b, test.c
1519              TableScan: test
1520        "
1521        )
1522    }
1523
1524    #[test]
1525    fn test_normalize_bitset_and_expression() -> Result<()> {
1526        // a & b <=> b & a
1527        let table_scan = test_table_scan()?;
1528        let expr = ((col("a") & col("b")) + (col("b") & col("a"))).eq(lit(30));
1529        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1530
1531        assert_optimized_plan_equal!(
1532            plan,
1533            @ r"
1534        Projection: test.a, test.b, test.c
1535          Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1536            Projection: test.a & test.b AS __common_expr_1, test.a, test.b, test.c
1537              TableScan: test
1538        "
1539        )
1540    }
1541
1542    #[test]
1543    fn test_normalize_bitset_or_expression() -> Result<()> {
1544        // a | b <=> b | a
1545        let table_scan = test_table_scan()?;
1546        let expr = ((col("a") | col("b")) + (col("b") | col("a"))).eq(lit(30));
1547        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1548
1549        assert_optimized_plan_equal!(
1550            plan,
1551            @ r"
1552        Projection: test.a, test.b, test.c
1553          Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1554            Projection: test.a | test.b AS __common_expr_1, test.a, test.b, test.c
1555              TableScan: test
1556        "
1557        )
1558    }
1559
1560    #[test]
1561    fn test_normalize_bitset_xor_expression() -> Result<()> {
1562        // a # b <=> b # a
1563        let table_scan = test_table_scan()?;
1564        let expr = ((col("a") ^ col("b")) + (col("b") ^ col("a"))).eq(lit(30));
1565        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1566
1567        assert_optimized_plan_equal!(
1568            plan,
1569            @ r"
1570        Projection: test.a, test.b, test.c
1571          Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1572            Projection: test.a BIT_XOR test.b AS __common_expr_1, test.a, test.b, test.c
1573              TableScan: test
1574        "
1575        )
1576    }
1577
1578    #[test]
1579    fn test_normalize_eq_expression() -> Result<()> {
1580        // a = b <=> b = a
1581        let table_scan = test_table_scan()?;
1582        let expr = (col("a").eq(col("b"))).and(col("b").eq(col("a")));
1583        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1584
1585        assert_optimized_plan_equal!(
1586            plan,
1587            @ r"
1588        Projection: test.a, test.b, test.c
1589          Filter: __common_expr_1 AND __common_expr_1
1590            Projection: test.a = test.b AS __common_expr_1, test.a, test.b, test.c
1591              TableScan: test
1592        "
1593        )
1594    }
1595
1596    #[test]
1597    fn test_normalize_ne_expression() -> Result<()> {
1598        // a != b <=> b != a
1599        let table_scan = test_table_scan()?;
1600        let expr = (col("a").not_eq(col("b"))).and(col("b").not_eq(col("a")));
1601        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1602
1603        assert_optimized_plan_equal!(
1604            plan,
1605            @ r"
1606        Projection: test.a, test.b, test.c
1607          Filter: __common_expr_1 AND __common_expr_1
1608            Projection: test.a != test.b AS __common_expr_1, test.a, test.b, test.c
1609              TableScan: test
1610        "
1611        )
1612    }
1613
1614    #[test]
1615    fn test_normalize_complex_expression() -> Result<()> {
1616        // case1: a + b * c <=> b * c + a
1617        let table_scan = test_table_scan()?;
1618        let expr = ((col("a") + col("b") * col("c")) - (col("b") * col("c") + col("a")))
1619            .eq(lit(30));
1620        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1621
1622        assert_optimized_plan_equal!(
1623            plan,
1624            @ r"
1625        Projection: test.a, test.b, test.c
1626          Filter: __common_expr_1 - __common_expr_1 = Int32(30)
1627            Projection: test.a + test.b * test.c AS __common_expr_1, test.a, test.b, test.c
1628              TableScan: test
1629        "
1630        )?;
1631
1632        // ((c1 + c2 / c3) * c3 <=> c3 * (c2 / c3 + c1))
1633        let table_scan = test_table_scan()?;
1634        let expr = (((col("a") + col("b") / col("c")) * col("c"))
1635            / (col("c") * (col("b") / col("c") + col("a")))
1636            + col("a"))
1637        .eq(lit(30));
1638        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1639
1640        assert_optimized_plan_equal!(
1641            plan,
1642            @ r"
1643        Projection: test.a, test.b, test.c
1644          Filter: __common_expr_1 / __common_expr_1 + test.a = Int32(30)
1645            Projection: (test.a + test.b / test.c) * test.c AS __common_expr_1, test.a, test.b, test.c
1646              TableScan: test
1647        "
1648        )?;
1649
1650        // c2 / (c1 + c3) <=> c2 / (c3 + c1)
1651        let table_scan = test_table_scan()?;
1652        let expr = ((col("b") / (col("a") + col("c")))
1653            * (col("b") / (col("c") + col("a"))))
1654        .eq(lit(30));
1655        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1656        assert_optimized_plan_equal!(
1657            plan,
1658            @ r"
1659        Projection: test.a, test.b, test.c
1660          Filter: __common_expr_1 * __common_expr_1 = Int32(30)
1661            Projection: test.b / (test.a + test.c) AS __common_expr_1, test.a, test.b, test.c
1662              TableScan: test
1663        "
1664        )?;
1665
1666        Ok(())
1667    }
1668
1669    #[derive(Debug, PartialEq, Eq, Hash)]
1670    pub struct TestUdf {
1671        signature: Signature,
1672    }
1673
1674    impl TestUdf {
1675        pub fn new() -> Self {
1676            Self {
1677                signature: Signature::numeric(1, Volatility::Immutable),
1678            }
1679        }
1680    }
1681
1682    impl ScalarUDFImpl for TestUdf {
1683        fn as_any(&self) -> &dyn Any {
1684            self
1685        }
1686        fn name(&self) -> &str {
1687            "my_udf"
1688        }
1689
1690        fn signature(&self) -> &Signature {
1691            &self.signature
1692        }
1693
1694        fn return_type(&self, _: &[DataType]) -> Result<DataType> {
1695            Ok(DataType::Int32)
1696        }
1697
1698        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1699            panic!("not implemented")
1700        }
1701    }
1702
1703    #[test]
1704    fn test_normalize_inner_binary_expression() -> Result<()> {
1705        // Not(a == b) <=> Not(b == a)
1706        let table_scan = test_table_scan()?;
1707        let expr1 = not(col("a").eq(col("b")));
1708        let expr2 = not(col("b").eq(col("a")));
1709        let plan = LogicalPlanBuilder::from(table_scan)
1710            .project(vec![expr1, expr2])?
1711            .build()?;
1712        assert_optimized_plan_equal!(
1713            plan,
1714            @ r"
1715        Projection: __common_expr_1 AS NOT test.a = test.b, __common_expr_1 AS NOT test.b = test.a
1716          Projection: NOT test.a = test.b AS __common_expr_1, test.a, test.b, test.c
1717            TableScan: test
1718        "
1719        )?;
1720
1721        // is_null(a == b) <=> is_null(b == a)
1722        let table_scan = test_table_scan()?;
1723        let expr1 = is_null(col("a").eq(col("b")));
1724        let expr2 = is_null(col("b").eq(col("a")));
1725        let plan = LogicalPlanBuilder::from(table_scan)
1726            .project(vec![expr1, expr2])?
1727            .build()?;
1728        assert_optimized_plan_equal!(
1729            plan,
1730            @ r"
1731        Projection: __common_expr_1 AS test.a = test.b IS NULL, __common_expr_1 AS test.b = test.a IS NULL
1732          Projection: test.a = test.b IS NULL AS __common_expr_1, test.a, test.b, test.c
1733            TableScan: test
1734        "
1735        )?;
1736
1737        // a + b between 0 and 10 <=> b + a between 0 and 10
1738        let table_scan = test_table_scan()?;
1739        let expr1 = (col("a") + col("b")).between(lit(0), lit(10));
1740        let expr2 = (col("b") + col("a")).between(lit(0), lit(10));
1741        let plan = LogicalPlanBuilder::from(table_scan)
1742            .project(vec![expr1, expr2])?
1743            .build()?;
1744        assert_optimized_plan_equal!(
1745            plan,
1746            @ r"
1747        Projection: __common_expr_1 AS test.a + test.b BETWEEN Int32(0) AND Int32(10), __common_expr_1 AS test.b + test.a BETWEEN Int32(0) AND Int32(10)
1748          Projection: test.a + test.b BETWEEN Int32(0) AND Int32(10) AS __common_expr_1, test.a, test.b, test.c
1749            TableScan: test
1750        "
1751        )?;
1752
1753        // c between a + b and 10 <=> c between b + a and 10
1754        let table_scan = test_table_scan()?;
1755        let expr1 = col("c").between(col("a") + col("b"), lit(10));
1756        let expr2 = col("c").between(col("b") + col("a"), lit(10));
1757        let plan = LogicalPlanBuilder::from(table_scan)
1758            .project(vec![expr1, expr2])?
1759            .build()?;
1760        assert_optimized_plan_equal!(
1761            plan,
1762            @ r"
1763        Projection: __common_expr_1 AS test.c BETWEEN test.a + test.b AND Int32(10), __common_expr_1 AS test.c BETWEEN test.b + test.a AND Int32(10)
1764          Projection: test.c BETWEEN test.a + test.b AND Int32(10) AS __common_expr_1, test.a, test.b, test.c
1765            TableScan: test
1766        "
1767        )?;
1768
1769        // function call with argument <=> function call with argument
1770        let udf = ScalarUDF::from(TestUdf::new());
1771        let table_scan = test_table_scan()?;
1772        let expr1 = udf.call(vec![col("a") + col("b")]);
1773        let expr2 = udf.call(vec![col("b") + col("a")]);
1774        let plan = LogicalPlanBuilder::from(table_scan)
1775            .project(vec![expr1, expr2])?
1776            .build()?;
1777        assert_optimized_plan_equal!(
1778            plan,
1779            @ r"
1780        Projection: __common_expr_1 AS my_udf(test.a + test.b), __common_expr_1 AS my_udf(test.b + test.a)
1781          Projection: my_udf(test.a + test.b) AS __common_expr_1, test.a, test.b, test.c
1782            TableScan: test
1783        "
1784        )
1785    }
1786
1787    /// returns a "random" function that is marked volatile (aka each invocation
1788    /// returns a different value)
1789    ///
1790    /// Does not use datafusion_functions::rand to avoid introducing a
1791    /// dependency on that crate.
1792    fn rand_func() -> ScalarUDF {
1793        ScalarUDF::new_from_impl(RandomStub::new())
1794    }
1795
1796    #[derive(Debug, PartialEq, Eq, Hash)]
1797    struct RandomStub {
1798        signature: Signature,
1799    }
1800
1801    impl RandomStub {
1802        fn new() -> Self {
1803            Self {
1804                signature: Signature::exact(vec![], Volatility::Volatile),
1805            }
1806        }
1807    }
1808    impl ScalarUDFImpl for RandomStub {
1809        fn as_any(&self) -> &dyn Any {
1810            self
1811        }
1812
1813        fn name(&self) -> &str {
1814            "random"
1815        }
1816
1817        fn signature(&self) -> &Signature {
1818            &self.signature
1819        }
1820
1821        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
1822            Ok(DataType::Float64)
1823        }
1824
1825        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1826            panic!("dummy - not implemented")
1827        }
1828    }
1829}