datafusion_optimizer/optimize_projections/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`OptimizeProjections`] identifies and eliminates unused columns
19
20mod required_indices;
21
22use crate::optimizer::ApplyOrder;
23use crate::{OptimizerConfig, OptimizerRule};
24use std::collections::HashSet;
25use std::sync::Arc;
26
27use datafusion_common::{
28    get_required_group_by_exprs_indices, internal_datafusion_err, internal_err, Column,
29    DFSchema, HashMap, JoinType, Result,
30};
31use datafusion_expr::expr::Alias;
32use datafusion_expr::{
33    logical_plan::LogicalPlan, Aggregate, Distinct, EmptyRelation, Expr, Projection,
34    TableScan, Unnest, Window,
35};
36
37use crate::optimize_projections::required_indices::RequiredIndices;
38use crate::utils::NamePreserver;
39use datafusion_common::tree_node::{
40    Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion,
41};
42
/// Optimizer rule to prune unnecessary columns from intermediate schemas
/// inside the [`LogicalPlan`]. This rule:
/// - Removes unnecessary columns that do not appear at the output and/or are
///   not used during any computation step.
/// - Adds projections to decrease table column size before operators that
///   benefit from a smaller memory footprint at their input.
/// - Removes unnecessary [`LogicalPlan::Projection`]s from the [`LogicalPlan`].
///
/// The rule analyzes the input logical plan, determines which column indices
/// are actually needed by downstream operations, and drops everything else.
/// This can improve query performance and reduce unnecessary data processing.
#[derive(Default, Debug)]
pub struct OptimizeProjections {}

impl OptimizeProjections {
    /// Creates a new [`OptimizeProjections`] rule.
    ///
    /// The rule holds no state, so this is equivalent to
    /// `OptimizeProjections::default()`.
    pub fn new() -> Self {
        Self {}
    }
}
67
68impl OptimizerRule for OptimizeProjections {
69    fn name(&self) -> &str {
70        "optimize_projections"
71    }
72
73    fn apply_order(&self) -> Option<ApplyOrder> {
74        None
75    }
76
77    fn supports_rewrite(&self) -> bool {
78        true
79    }
80
81    fn rewrite(
82        &self,
83        plan: LogicalPlan,
84        config: &dyn OptimizerConfig,
85    ) -> Result<Transformed<LogicalPlan>> {
86        // All output fields are necessary:
87        let indices = RequiredIndices::new_for_all_exprs(&plan);
88        optimize_projections(plan, config, indices)
89    }
90}
91
/// Removes unnecessary columns (e.g. columns that do not appear in the output
/// schema and/or are not used during any computation step such as expression
/// evaluation) from the logical plan and its inputs.
///
/// `Projection`, `Aggregate`, `Window` and `TableScan` nodes are rewritten
/// directly in the first `match`. For every other node type the function
/// computes the indices each child must provide and then recurses into the
/// children.
///
/// # Parameters
///
/// - `plan`: The input `LogicalPlan` to optimize (taken by value).
/// - `config`: A reference to the optimizer configuration.
/// - `indices`: The column indices of `plan`'s output that are required by
///   downstream (parent) plan nodes.
///
/// # Returns
///
/// A `Result` object with the following semantics:
///
/// - `Ok(Transformed<LogicalPlan>)`: The resulting plan, built with
///   `Transformed::yes` when a rewrite is reported and `Transformed::no` when
///   the plan is handed back as-is.
/// - `Err(error)`: An error occurred during the optimization process.
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
fn optimize_projections(
    plan: LogicalPlan,
    config: &dyn OptimizerConfig,
    indices: RequiredIndices,
) -> Result<Transformed<LogicalPlan>> {
    // Recursively rewrite any nodes that may be able to avoid computation given
    // their parents' required indices.
    match plan {
        LogicalPlan::Projection(proj) => {
            // First try to merge with a projection directly below, then rewrite
            // the (possibly merged) projection according to parent requirements.
            return merge_consecutive_projections(proj)?.transform_data(|proj| {
                rewrite_projection_given_requirements(proj, config, &indices)
            })
        }
        LogicalPlan::Aggregate(aggregate) => {
            // Split parent requirements to GROUP BY and aggregate sections:
            let n_group_exprs = aggregate.group_expr_len()?;
            // Offset aggregate indices so that they point to valid indices at
            // `aggregate.aggr_expr`:
            let (group_by_reqs, aggregate_reqs) = indices.split_off(n_group_exprs);

            // Get absolutely necessary GROUP BY fields:
            let group_by_expr_existing = aggregate
                .group_expr
                .iter()
                .map(|group_by_expr| group_by_expr.schema_name().to_string())
                .collect::<Vec<_>>();

            let new_group_bys = if let Some(simplest_groupby_indices) =
                get_required_group_by_exprs_indices(
                    aggregate.input.schema(),
                    &group_by_expr_existing,
                ) {
                // Some of the fields in the GROUP BY may be required by the
                // parent even if these fields are unnecessary in terms of
                // functional dependency.
                group_by_reqs
                    .append(&simplest_groupby_indices)
                    .get_at_indices(&aggregate.group_expr)
            } else {
                aggregate.group_expr
            };

            // Only use the absolutely necessary aggregate expressions required
            // by the parent:
            let new_aggr_expr = aggregate_reqs.get_at_indices(&aggregate.aggr_expr);

            if new_group_bys.is_empty() && new_aggr_expr.is_empty() {
                // Global aggregation with no aggregate functions always produces 1 row and no columns.
                return Ok(Transformed::yes(LogicalPlan::EmptyRelation(
                    EmptyRelation {
                        produce_one_row: true,
                        schema: Arc::new(DFSchema::empty()),
                    },
                )));
            }

            // Columns of the aggregate's input that the surviving GROUP BY and
            // aggregate expressions refer to:
            let all_exprs_iter = new_group_bys.iter().chain(new_aggr_expr.iter());
            let schema = aggregate.input.schema();
            let necessary_indices =
                RequiredIndices::new().with_exprs(schema, all_exprs_iter);
            let necessary_exprs = necessary_indices.get_required_exprs(schema);

            return optimize_projections(
                Arc::unwrap_or_clone(aggregate.input),
                config,
                necessary_indices,
            )?
            .transform_data(|aggregate_input| {
                // Simplify the input of the aggregation by adding a projection so
                // that its input only contains absolutely necessary columns for
                // the aggregate expressions. Note that necessary_indices refer to
                // fields in `aggregate.input.schema()`.
                add_projection_on_top_if_helpful(aggregate_input, necessary_exprs)
            })?
            .map_data(|aggregate_input| {
                // Create a new aggregate plan with the updated input and only the
                // absolutely necessary fields:
                Aggregate::try_new(
                    Arc::new(aggregate_input),
                    new_group_bys,
                    new_aggr_expr,
                )
                .map(LogicalPlan::Aggregate)
            });
        }
        LogicalPlan::Window(window) => {
            let input_schema = Arc::clone(window.input.schema());
            // Split parent requirements to child and window expression sections:
            let n_input_fields = input_schema.fields().len();
            // Offset window expression indices so that they point to valid
            // indices at `window.window_expr`:
            let (child_reqs, window_reqs) = indices.split_off(n_input_fields);

            // Only use window expressions that are absolutely necessary according
            // to parent requirements:
            let new_window_expr = window_reqs.get_at_indices(&window.window_expr);

            // Get all the required column indices at the input, either by the
            // parent or window expression requirements.
            let required_indices = child_reqs.with_exprs(&input_schema, &new_window_expr);

            // `required_indices` is cloned because it is used again inside the
            // closure below, after the recursive call has consumed its copy.
            return optimize_projections(
                Arc::unwrap_or_clone(window.input),
                config,
                required_indices.clone(),
            )?
            .transform_data(|window_child| {
                if new_window_expr.is_empty() {
                    // When no window expression is necessary, use the input directly:
                    Ok(Transformed::no(window_child))
                } else {
                    // Calculate required expressions at the input of the window.
                    // Please note that we use `input_schema`, because `required_indices`
                    // refers to that schema
                    let required_exprs =
                        required_indices.get_required_exprs(&input_schema);
                    let window_child =
                        add_projection_on_top_if_helpful(window_child, required_exprs)?
                            .data;
                    Window::try_new(new_window_expr, Arc::new(window_child))
                        .map(LogicalPlan::Window)
                        .map(Transformed::yes)
                }
            });
        }
        LogicalPlan::TableScan(table_scan) => {
            let TableScan {
                table_name,
                source,
                projection,
                filters,
                fetch,
                projected_schema: _,
            } = table_scan;

            // Get indices referred to in the original (schema with all fields)
            // given projected indices.
            let projection = match &projection {
                Some(projection) => indices.into_mapped_indices(|idx| projection[idx]),
                None => indices.into_inner(),
            };
            // The rebuilt scan is always reported as transformed.
            return TableScan::try_new(
                table_name,
                source,
                Some(projection),
                filters,
                fetch,
            )
            .map(LogicalPlan::TableScan)
            .map(Transformed::yes);
        }
        // Other node types are handled below
        _ => {}
    };

    // For other plan node types, calculate indices for columns they use and
    // try to rewrite their children
    let mut child_required_indices: Vec<RequiredIndices> = match &plan {
        LogicalPlan::Sort(_)
        | LogicalPlan::Filter(_)
        | LogicalPlan::Repartition(_)
        | LogicalPlan::Union(_)
        | LogicalPlan::SubqueryAlias(_)
        | LogicalPlan::Distinct(Distinct::On(_)) => {
            // Pass index requirements from the parent as well as column indices
            // that appear in this plan's expressions to its child. All these
            // operators benefit from "small" inputs, so the projection_beneficial
            // flag is `true`.
            plan.inputs()
                .into_iter()
                .map(|input| {
                    indices
                        .clone()
                        .with_projection_beneficial()
                        .with_plan_exprs(&plan, input.schema())
                })
                .collect::<Result<_>>()?
        }
        LogicalPlan::Limit(_) => {
            // Pass index requirements from the parent as well as column indices
            // that appear in this plan's expressions to its child. These operators
            // do not benefit from "small" inputs, so the projection_beneficial
            // flag is `false`.
            plan.inputs()
                .into_iter()
                .map(|input| indices.clone().with_plan_exprs(&plan, input.schema()))
                .collect::<Result<_>>()?
        }
        LogicalPlan::Copy(_)
        | LogicalPlan::Ddl(_)
        | LogicalPlan::Dml(_)
        | LogicalPlan::Explain(_)
        | LogicalPlan::Analyze(_)
        | LogicalPlan::Subquery(_)
        | LogicalPlan::Statement(_)
        | LogicalPlan::Distinct(Distinct::All(_)) => {
            // These plans require all their fields, and their children should
            // be treated as final plans -- otherwise, we may have a schema
            // mismatch.
            // TODO: For some subquery variants (e.g. a subquery arising from an
            //       EXISTS expression), we may not need to require all indices.
            plan.inputs()
                .into_iter()
                .map(RequiredIndices::new_for_all_exprs)
                .collect()
        }
        LogicalPlan::Extension(extension) => {
            let Some(necessary_children_indices) =
                extension.node.necessary_children_exprs(indices.indices())
            else {
                // Requirements from parent cannot be routed down to user defined logical plan safely
                return Ok(Transformed::no(plan));
            };
            let children = extension.node.inputs();
            if children.len() != necessary_children_indices.len() {
                return internal_err!("Inconsistent length between children and necessary children indices. \
                Make sure `.necessary_children_exprs` implementation of the `UserDefinedLogicalNode` is \
                consistent with actual children length for the node.");
            }
            // Pair every child with the indices the node reported for it, and
            // add the columns used by the node's own expressions.
            children
                .into_iter()
                .zip(necessary_children_indices)
                .map(|(child, necessary_indices)| {
                    RequiredIndices::new_from_indices(necessary_indices)
                        .with_plan_exprs(&plan, child.schema())
                })
                .collect::<Result<Vec<_>>>()?
        }
        LogicalPlan::EmptyRelation(_)
        | LogicalPlan::RecursiveQuery(_)
        | LogicalPlan::Values(_)
        | LogicalPlan::DescribeTable(_) => {
            // These operators have no inputs, so stop the optimization process.
            return Ok(Transformed::no(plan));
        }
        LogicalPlan::Join(join) => {
            let left_len = join.left.schema().fields().len();
            let (left_req_indices, right_req_indices) =
                split_join_requirements(left_len, indices, &join.join_type);
            let left_indices =
                left_req_indices.with_plan_exprs(&plan, join.left.schema())?;
            let right_indices =
                right_req_indices.with_plan_exprs(&plan, join.right.schema())?;
            // Joins benefit from "small" input tables (lower memory usage).
            // Therefore, each child benefits from projection:
            vec![
                left_indices.with_projection_beneficial(),
                right_indices.with_projection_beneficial(),
            ]
        }
        // these nodes are explicitly rewritten in the match statement above
        LogicalPlan::Projection(_)
        | LogicalPlan::Aggregate(_)
        | LogicalPlan::Window(_)
        | LogicalPlan::TableScan(_) => {
            return internal_err!(
                "OptimizeProjection: should have handled in the match statement above"
            );
        }
        LogicalPlan::Unnest(Unnest {
            input,
            dependency_indices,
            ..
        }) => {
            // at least provide the indices for the exec-columns as a starting point
            let required_indices =
                RequiredIndices::new().with_plan_exprs(&plan, input.schema())?;

            // Add additional required indices from the parent. Parent indices
            // with no entry in `dependency_indices` are skipped.
            let mut additional_necessary_child_indices = Vec::new();
            indices.indices().iter().for_each(|idx| {
                if let Some(index) = dependency_indices.get(*idx) {
                    additional_necessary_child_indices.push(*index);
                }
            });
            vec![required_indices.append(&additional_necessary_child_indices)]
        }
    };

    // Required indices are currently ordered (child0, child1, ...)
    // but the loop pops off the last element, so we need to reverse the order
    child_required_indices.reverse();
    if child_required_indices.len() != plan.inputs().len() {
        return internal_err!(
            "OptimizeProjection: child_required_indices length mismatch with plan inputs"
        );
    }

    // Rewrite children of the plan
    let transformed_plan = plan.map_children(|child| {
        let required_indices = child_required_indices.pop().ok_or_else(|| {
            internal_datafusion_err!(
                "Unexpected number of required_indices in OptimizeProjections rule"
            )
        })?;

        // Both values are read before `required_indices` is moved into the
        // recursive call below.
        let projection_beneficial = required_indices.projection_beneficial();
        let project_exprs = required_indices.get_required_exprs(child.schema());

        optimize_projections(child, config, required_indices)?.transform_data(
            |new_input| {
                if projection_beneficial {
                    add_projection_on_top_if_helpful(new_input, project_exprs)
                } else {
                    Ok(Transformed::no(new_input))
                }
            },
        )
    })?;

    // If any of the children are transformed, we need to potentially update the plan's schema
    if transformed_plan.transformed {
        transformed_plan.map_data(|plan| plan.recompute_schema())
    } else {
        Ok(transformed_plan)
    }
}
429
430/// Merges consecutive projections.
431///
432/// Given a projection `proj`, this function attempts to merge it with a previous
433/// projection if it exists and if merging is beneficial. Merging is considered
434/// beneficial when expressions in the current projection are non-trivial and
435/// appear more than once in its input fields. This can act as a caching mechanism
436/// for non-trivial computations.
437///
438/// # Parameters
439///
440/// * `proj` - A reference to the `Projection` to be merged.
441///
442/// # Returns
443///
444/// A `Result` object with the following semantics:
445///
446/// - `Ok(Some(Projection))`: Merge was beneficial and successful. Contains the
447///   merged projection.
448/// - `Ok(None)`: Signals that merge is not beneficial (and has not taken place).
449/// - `Err(error)`: An error occurred during the function call.
450fn merge_consecutive_projections(proj: Projection) -> Result<Transformed<Projection>> {
451    let Projection {
452        expr,
453        input,
454        schema,
455        ..
456    } = proj;
457    let LogicalPlan::Projection(prev_projection) = input.as_ref() else {
458        return Projection::try_new_with_schema(expr, input, schema).map(Transformed::no);
459    };
460
461    // A fast path: if the previous projection is same as the current projection
462    // we can directly remove the current projection and return child projection.
463    if prev_projection.expr == expr {
464        return Projection::try_new_with_schema(
465            expr,
466            Arc::clone(&prev_projection.input),
467            schema,
468        )
469        .map(Transformed::yes);
470    }
471
472    // Count usages (referrals) of each projection expression in its input fields:
473    let mut column_referral_map = HashMap::<&Column, usize>::new();
474    expr.iter()
475        .for_each(|expr| expr.add_column_ref_counts(&mut column_referral_map));
476
477    // If an expression is non-trivial and appears more than once, do not merge
478    // them as consecutive projections will benefit from a compute-once approach.
479    // For details, see: https://github.com/apache/datafusion/issues/8296
480    if column_referral_map.into_iter().any(|(col, usage)| {
481        usage > 1
482            && !is_expr_trivial(
483                &prev_projection.expr
484                    [prev_projection.schema.index_of_column(col).unwrap()],
485            )
486    }) {
487        // no change
488        return Projection::try_new_with_schema(expr, input, schema).map(Transformed::no);
489    }
490
491    let LogicalPlan::Projection(prev_projection) = Arc::unwrap_or_clone(input) else {
492        // We know it is a `LogicalPlan::Projection` from check above
493        unreachable!();
494    };
495
496    // Try to rewrite the expressions in the current projection using the
497    // previous projection as input:
498    let name_preserver = NamePreserver::new_for_projection();
499    let mut original_names = vec![];
500    let new_exprs = expr.map_elements(|expr| {
501        original_names.push(name_preserver.save(&expr));
502
503        // do not rewrite top level Aliases (rewriter will remove all aliases within exprs)
504        match expr {
505            Expr::Alias(Alias {
506                expr,
507                relation,
508                name,
509                metadata,
510            }) => rewrite_expr(*expr, &prev_projection).map(|result| {
511                result.update_data(|expr| {
512                    Expr::Alias(Alias::new(expr, relation, name).with_metadata(metadata))
513                })
514            }),
515            e => rewrite_expr(e, &prev_projection),
516        }
517    })?;
518
519    // if the expressions could be rewritten, create a new projection with the
520    // new expressions
521    if new_exprs.transformed {
522        // Add any needed aliases back to the expressions
523        let new_exprs = new_exprs
524            .data
525            .into_iter()
526            .zip(original_names)
527            .map(|(expr, original_name)| original_name.restore(expr))
528            .collect::<Vec<_>>();
529        Projection::try_new(new_exprs, prev_projection.input).map(Transformed::yes)
530    } else {
531        // not rewritten, so put the projection back together
532        let input = Arc::new(LogicalPlan::Projection(prev_projection));
533        Projection::try_new_with_schema(new_exprs.data, input, schema)
534            .map(Transformed::no)
535    }
536}
537
538// Check whether `expr` is trivial; i.e. it doesn't imply any computation.
539fn is_expr_trivial(expr: &Expr) -> bool {
540    matches!(expr, Expr::Column(_) | Expr::Literal(_, _))
541}
542
543/// Rewrites a projection expression using the projection before it (i.e. its input)
544/// This is a subroutine to the `merge_consecutive_projections` function.
545///
546/// # Parameters
547///
548/// * `expr` - A reference to the expression to rewrite.
549/// * `input` - A reference to the input of the projection expression (itself
550///   a projection).
551///
552/// # Returns
553///
554/// A `Result` object with the following semantics:
555///
556/// - `Ok(Some(Expr))`: Rewrite was successful. Contains the rewritten result.
557/// - `Ok(None)`: Signals that `expr` can not be rewritten.
558/// - `Err(error)`: An error occurred during the function call.
559///
560/// # Notes
561/// This rewrite also removes any unnecessary layers of aliasing.
562///
563/// Without trimming, we can end up with unnecessary indirections inside expressions
564/// during projection merges.
565///
566/// Consider:
567///
568/// ```text
569/// Projection(a1 + b1 as sum1)
570/// --Projection(a as a1, b as b1)
571/// ----Source(a, b)
572/// ```
573///
574/// After merge, we want to produce:
575///
576/// ```text
577/// Projection(a + b as sum1)
578/// --Source(a, b)
579/// ```
580///
581/// Without trimming, we would end up with:
582///
583/// ```text
584/// Projection((a as a1 + b as b1) as sum1)
585/// --Source(a, b)
586/// ```
587fn rewrite_expr(expr: Expr, input: &Projection) -> Result<Transformed<Expr>> {
588    expr.transform_up(|expr| {
589        match expr {
590            //  remove any intermediate aliases if they do not carry metadata
591            Expr::Alias(alias) => {
592                match alias
593                    .metadata
594                    .as_ref()
595                    .map(|h| h.is_empty())
596                    .unwrap_or(true)
597                {
598                    true => Ok(Transformed::yes(*alias.expr)),
599                    false => Ok(Transformed::no(Expr::Alias(alias))),
600                }
601            }
602            Expr::Column(col) => {
603                // Find index of column:
604                let idx = input.schema.index_of_column(&col)?;
605                // get the corresponding unaliased input expression
606                //
607                // For example:
608                // * the input projection is [`a + b` as c, `d + e` as f]
609                // * the current column is an expression "f"
610                //
611                // return the expression `d + e` (not `d + e` as f)
612                let input_expr = input.expr[idx].clone().unalias_nested().data;
613                Ok(Transformed::yes(input_expr))
614            }
615            // Unsupported type for consecutive projection merge analysis.
616            _ => Ok(Transformed::no(expr)),
617        }
618    })
619}
620
621/// Accumulates outer-referenced columns by the
622/// given expression, `expr`.
623///
624/// # Parameters
625///
626/// * `expr` - The expression to analyze for outer-referenced columns.
627/// * `columns` - A mutable reference to a `HashSet<Column>` where detected
628///   columns are collected.
629fn outer_columns<'a>(expr: &'a Expr, columns: &mut HashSet<&'a Column>) {
630    // inspect_expr_pre doesn't handle subquery references, so find them explicitly
631    expr.apply(|expr| {
632        match expr {
633            Expr::OuterReferenceColumn(_, col) => {
634                columns.insert(col);
635            }
636            Expr::ScalarSubquery(subquery) => {
637                outer_columns_helper_multi(&subquery.outer_ref_columns, columns);
638            }
639            Expr::Exists(exists) => {
640                outer_columns_helper_multi(&exists.subquery.outer_ref_columns, columns);
641            }
642            Expr::InSubquery(insubquery) => {
643                outer_columns_helper_multi(
644                    &insubquery.subquery.outer_ref_columns,
645                    columns,
646                );
647            }
648            _ => {}
649        };
650        Ok(TreeNodeRecursion::Continue)
651    })
652    // unwrap: closure above never returns Err, so can not be Err here
653    .unwrap();
654}
655
656/// A recursive subroutine that accumulates outer-referenced columns by the
657/// given expressions (`exprs`).
658///
659/// # Parameters
660///
661/// * `exprs` - The expressions to analyze for outer-referenced columns.
662/// * `columns` - A mutable reference to a `HashSet<Column>` where detected
663///   columns are collected.
664fn outer_columns_helper_multi<'a, 'b>(
665    exprs: impl IntoIterator<Item = &'a Expr>,
666    columns: &'b mut HashSet<&'a Column>,
667) {
668    exprs.into_iter().for_each(|e| outer_columns(e, columns));
669}
670
671/// Splits requirement indices for a join into left and right children based on
672/// the join type.
673///
674/// This function takes the length of the left child, a slice of requirement
675/// indices, and the type of join (e.g. `INNER`, `LEFT`, `RIGHT`) as arguments.
676/// Depending on the join type, it divides the requirement indices into those
677/// that apply to the left child and those that apply to the right child.
678///
679/// - For `INNER`, `LEFT`, `RIGHT`, `FULL`, `LEFTMARK`, and `RIGHTMARK` joins,
680///   the requirements are split between left and right children. The right
681///   child indices are adjusted to point to valid positions within the right
682///   child by subtracting the length of the left child.
683///
684/// - For `LEFT ANTI`, `LEFT SEMI`, `RIGHT SEMI` and `RIGHT ANTI` joins, all
685///   requirements are re-routed to either the left child or the right child
686///   directly, depending on the join type.
687///
688/// # Parameters
689///
690/// * `left_len` - The length of the left child.
691/// * `indices` - A slice of requirement indices.
692/// * `join_type` - The type of join (e.g. `INNER`, `LEFT`, `RIGHT`).
693///
694/// # Returns
695///
696/// A tuple containing two vectors of `usize` indices: The first vector represents
697/// the requirements for the left child, and the second vector represents the
698/// requirements for the right child. The indices are appropriately split and
699/// adjusted based on the join type.
700fn split_join_requirements(
701    left_len: usize,
702    indices: RequiredIndices,
703    join_type: &JoinType,
704) -> (RequiredIndices, RequiredIndices) {
705    match join_type {
706        // In these cases requirements are split between left/right children:
707        JoinType::Inner
708        | JoinType::Left
709        | JoinType::Right
710        | JoinType::Full
711        | JoinType::LeftMark
712        | JoinType::RightMark => {
713            // Decrease right side indices by `left_len` so that they point to valid
714            // positions within the right child:
715            indices.split_off(left_len)
716        }
717        // All requirements can be re-routed to left child directly.
718        JoinType::LeftAnti | JoinType::LeftSemi => (indices, RequiredIndices::new()),
719        // All requirements can be re-routed to right side directly.
720        // No need to change index, join schema is right child schema.
721        JoinType::RightSemi | JoinType::RightAnti => (RequiredIndices::new(), indices),
722    }
723}
724
725/// Adds a projection on top of a logical plan if doing so reduces the number
726/// of columns for the parent operator.
727///
728/// This function takes a `LogicalPlan` and a list of projection expressions.
729/// If the projection is beneficial (it reduces the number of columns in the
730/// plan) a new `LogicalPlan` with the projection is created and returned, along
731/// with a `true` flag. If the projection doesn't reduce the number of columns,
732/// the original plan is returned with a `false` flag.
733///
734/// # Parameters
735///
736/// * `plan` - The input `LogicalPlan` to potentially add a projection to.
737/// * `project_exprs` - A list of expressions for the projection.
738///
739/// # Returns
740///
741/// A `Transformed` indicating if a projection was added
742fn add_projection_on_top_if_helpful(
743    plan: LogicalPlan,
744    project_exprs: Vec<Expr>,
745) -> Result<Transformed<LogicalPlan>> {
746    // Make sure projection decreases the number of columns, otherwise it is unnecessary.
747    if project_exprs.len() >= plan.schema().fields().len() {
748        Ok(Transformed::no(plan))
749    } else {
750        Projection::try_new(project_exprs, Arc::new(plan))
751            .map(LogicalPlan::Projection)
752            .map(Transformed::yes)
753    }
754}
755
756/// Rewrite the given projection according to the fields required by its
757/// ancestors.
758///
759/// # Parameters
760///
761/// * `proj` - A reference to the original projection to rewrite.
762/// * `config` - A reference to the optimizer configuration.
763/// * `indices` - A slice of indices representing the columns required by the
764///   ancestors of the given projection.
765///
766/// # Returns
767///
768/// A `Result` object with the following semantics:
769///
770/// - `Ok(Some(LogicalPlan))`: Contains the rewritten projection
771/// - `Ok(None)`: No rewrite necessary.
772/// - `Err(error)`: An error occurred during the function call.
773fn rewrite_projection_given_requirements(
774    proj: Projection,
775    config: &dyn OptimizerConfig,
776    indices: &RequiredIndices,
777) -> Result<Transformed<LogicalPlan>> {
778    let Projection { expr, input, .. } = proj;
779
780    let exprs_used = indices.get_at_indices(&expr);
781
782    let required_indices =
783        RequiredIndices::new().with_exprs(input.schema(), exprs_used.iter());
784
785    // rewrite the children projection, and if they are changed rewrite the
786    // projection down
787    optimize_projections(Arc::unwrap_or_clone(input), config, required_indices)?
788        .transform_data(|input| {
789            if is_projection_unnecessary(&input, &exprs_used)? {
790                Ok(Transformed::yes(input))
791            } else {
792                Projection::try_new(exprs_used, Arc::new(input))
793                    .map(LogicalPlan::Projection)
794                    .map(Transformed::yes)
795            }
796        })
797}
798
799/// Projection is unnecessary, when
800/// - input schema of the projection, output schema of the projection are same, and
801/// - all projection expressions are either Column or Literal
802pub fn is_projection_unnecessary(
803    input: &LogicalPlan,
804    proj_exprs: &[Expr],
805) -> Result<bool> {
806    // First check if the number of expressions is equal to the number of fields in the input schema.
807    if proj_exprs.len() != input.schema().fields().len() {
808        return Ok(false);
809    }
810    Ok(input.schema().iter().zip(proj_exprs.iter()).all(
811        |((field_relation, field_name), expr)| {
812            // Check if the expression is a column and if it matches the field name
813            if let Expr::Column(col) = expr {
814                col.relation.as_ref() == field_relation && col.name.eq(field_name.name())
815            } else {
816                false
817            }
818        },
819    ))
820}
821
822#[cfg(test)]
823mod tests {
824    use std::cmp::Ordering;
825    use std::collections::HashMap;
826    use std::fmt::Formatter;
827    use std::ops::Add;
828    use std::sync::Arc;
829    use std::vec;
830
831    use crate::optimize_projections::OptimizeProjections;
832    use crate::optimizer::Optimizer;
833    use crate::test::{
834        assert_fields_eq, scan_empty, test_table_scan, test_table_scan_fields,
835        test_table_scan_with_name,
836    };
837    use crate::{OptimizerContext, OptimizerRule};
838    use arrow::datatypes::{DataType, Field, Schema};
839    use datafusion_common::{
840        Column, DFSchema, DFSchemaRef, JoinType, Result, TableReference,
841    };
842    use datafusion_expr::ExprFunctionExt;
843    use datafusion_expr::{
844        binary_expr, build_join_schema,
845        builder::table_scan_with_filters,
846        col,
847        expr::{self, Cast},
848        lit,
849        logical_plan::{builder::LogicalPlanBuilder, table_scan},
850        not, try_cast, when, BinaryExpr, Expr, Extension, Like, LogicalPlan, Operator,
851        Projection, UserDefinedLogicalNodeCore, WindowFunctionDefinition,
852    };
853    use insta::assert_snapshot;
854
855    use crate::assert_optimized_plan_eq_snapshot;
856    use datafusion_functions_aggregate::count::count_udaf;
857    use datafusion_functions_aggregate::expr_fn::{count, max, min};
858    use datafusion_functions_aggregate::min_max::max_udaf;
859
860    macro_rules! assert_optimized_plan_equal {
861        (
862            $plan:expr,
863            @ $expected:literal $(,)?
864        ) => {{
865            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
866            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(OptimizeProjections::new())];
867            assert_optimized_plan_eq_snapshot!(
868                optimizer_ctx,
869                rules,
870                $plan,
871                @ $expected,
872            )
873        }};
874    }
875
876    #[derive(Debug, Hash, PartialEq, Eq)]
877    struct NoOpUserDefined {
878        exprs: Vec<Expr>,
879        schema: DFSchemaRef,
880        input: Arc<LogicalPlan>,
881    }
882
883    impl NoOpUserDefined {
884        fn new(schema: DFSchemaRef, input: Arc<LogicalPlan>) -> Self {
885            Self {
886                exprs: vec![],
887                schema,
888                input,
889            }
890        }
891
892        fn with_exprs(mut self, exprs: Vec<Expr>) -> Self {
893            self.exprs = exprs;
894            self
895        }
896    }
897
898    // Manual implementation needed because of `schema` field. Comparison excludes this field.
899    impl PartialOrd for NoOpUserDefined {
900        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
901            match self.exprs.partial_cmp(&other.exprs) {
902                Some(Ordering::Equal) => self.input.partial_cmp(&other.input),
903                cmp => cmp,
904            }
905        }
906    }
907
908    impl UserDefinedLogicalNodeCore for NoOpUserDefined {
909        fn name(&self) -> &str {
910            "NoOpUserDefined"
911        }
912
913        fn inputs(&self) -> Vec<&LogicalPlan> {
914            vec![&self.input]
915        }
916
917        fn schema(&self) -> &DFSchemaRef {
918            &self.schema
919        }
920
921        fn expressions(&self) -> Vec<Expr> {
922            self.exprs.clone()
923        }
924
925        fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
926            write!(f, "NoOpUserDefined")
927        }
928
929        fn with_exprs_and_inputs(
930            &self,
931            exprs: Vec<Expr>,
932            mut inputs: Vec<LogicalPlan>,
933        ) -> Result<Self> {
934            Ok(Self {
935                exprs,
936                input: Arc::new(inputs.swap_remove(0)),
937                schema: Arc::clone(&self.schema),
938            })
939        }
940
941        fn necessary_children_exprs(
942            &self,
943            output_columns: &[usize],
944        ) -> Option<Vec<Vec<usize>>> {
945            // Since schema is same. Output columns requires their corresponding version in the input columns.
946            Some(vec![output_columns.to_vec()])
947        }
948
949        fn supports_limit_pushdown(&self) -> bool {
950            false // Disallow limit push-down by default
951        }
952    }
953
954    #[derive(Debug, Hash, PartialEq, Eq)]
955    struct UserDefinedCrossJoin {
956        exprs: Vec<Expr>,
957        schema: DFSchemaRef,
958        left_child: Arc<LogicalPlan>,
959        right_child: Arc<LogicalPlan>,
960    }
961
962    impl UserDefinedCrossJoin {
963        fn new(left_child: Arc<LogicalPlan>, right_child: Arc<LogicalPlan>) -> Self {
964            let left_schema = left_child.schema();
965            let right_schema = right_child.schema();
966            let schema = Arc::new(
967                build_join_schema(left_schema, right_schema, &JoinType::Inner).unwrap(),
968            );
969            Self {
970                exprs: vec![],
971                schema,
972                left_child,
973                right_child,
974            }
975        }
976    }
977
978    // Manual implementation needed because of `schema` field. Comparison excludes this field.
979    impl PartialOrd for UserDefinedCrossJoin {
980        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
981            match self.exprs.partial_cmp(&other.exprs) {
982                Some(Ordering::Equal) => {
983                    match self.left_child.partial_cmp(&other.left_child) {
984                        Some(Ordering::Equal) => {
985                            self.right_child.partial_cmp(&other.right_child)
986                        }
987                        cmp => cmp,
988                    }
989                }
990                cmp => cmp,
991            }
992        }
993    }
994
995    impl UserDefinedLogicalNodeCore for UserDefinedCrossJoin {
996        fn name(&self) -> &str {
997            "UserDefinedCrossJoin"
998        }
999
1000        fn inputs(&self) -> Vec<&LogicalPlan> {
1001            vec![&self.left_child, &self.right_child]
1002        }
1003
1004        fn schema(&self) -> &DFSchemaRef {
1005            &self.schema
1006        }
1007
1008        fn expressions(&self) -> Vec<Expr> {
1009            self.exprs.clone()
1010        }
1011
1012        fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
1013            write!(f, "UserDefinedCrossJoin")
1014        }
1015
1016        fn with_exprs_and_inputs(
1017            &self,
1018            exprs: Vec<Expr>,
1019            mut inputs: Vec<LogicalPlan>,
1020        ) -> Result<Self> {
1021            assert_eq!(inputs.len(), 2);
1022            Ok(Self {
1023                exprs,
1024                left_child: Arc::new(inputs.remove(0)),
1025                right_child: Arc::new(inputs.remove(0)),
1026                schema: Arc::clone(&self.schema),
1027            })
1028        }
1029
1030        fn necessary_children_exprs(
1031            &self,
1032            output_columns: &[usize],
1033        ) -> Option<Vec<Vec<usize>>> {
1034            let left_child_len = self.left_child.schema().fields().len();
1035            let mut left_reqs = vec![];
1036            let mut right_reqs = vec![];
1037            for &out_idx in output_columns {
1038                if out_idx < left_child_len {
1039                    left_reqs.push(out_idx);
1040                } else {
1041                    // Output indices further than the left_child_len
1042                    // comes from right children
1043                    right_reqs.push(out_idx - left_child_len)
1044                }
1045            }
1046            Some(vec![left_reqs, right_reqs])
1047        }
1048
1049        fn supports_limit_pushdown(&self) -> bool {
1050            false // Disallow limit push-down by default
1051        }
1052    }
1053
1054    #[test]
1055    fn merge_two_projection() -> Result<()> {
1056        let table_scan = test_table_scan()?;
1057        let plan = LogicalPlanBuilder::from(table_scan)
1058            .project(vec![col("a")])?
1059            .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])?
1060            .build()?;
1061
1062        assert_optimized_plan_equal!(
1063            plan,
1064            @r"
1065        Projection: Int32(1) + test.a
1066          TableScan: test projection=[a]
1067        "
1068        )
1069    }
1070
1071    #[test]
1072    fn merge_three_projection() -> Result<()> {
1073        let table_scan = test_table_scan()?;
1074        let plan = LogicalPlanBuilder::from(table_scan)
1075            .project(vec![col("a"), col("b")])?
1076            .project(vec![col("a")])?
1077            .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])?
1078            .build()?;
1079
1080        assert_optimized_plan_equal!(
1081            plan,
1082            @r"
1083        Projection: Int32(1) + test.a
1084          TableScan: test projection=[a]
1085        "
1086        )
1087    }
1088
1089    #[test]
1090    fn merge_alias() -> Result<()> {
1091        let table_scan = test_table_scan()?;
1092        let plan = LogicalPlanBuilder::from(table_scan)
1093            .project(vec![col("a")])?
1094            .project(vec![col("a").alias("alias")])?
1095            .build()?;
1096
1097        assert_optimized_plan_equal!(
1098            plan,
1099            @r"
1100        Projection: test.a AS alias
1101          TableScan: test projection=[a]
1102        "
1103        )
1104    }
1105
1106    #[test]
1107    fn merge_nested_alias() -> Result<()> {
1108        let table_scan = test_table_scan()?;
1109        let plan = LogicalPlanBuilder::from(table_scan)
1110            .project(vec![col("a").alias("alias1").alias("alias2")])?
1111            .project(vec![col("alias2").alias("alias")])?
1112            .build()?;
1113
1114        assert_optimized_plan_equal!(
1115            plan,
1116            @r"
1117        Projection: test.a AS alias
1118          TableScan: test projection=[a]
1119        "
1120        )
1121    }
1122
1123    #[test]
1124    fn test_nested_count() -> Result<()> {
1125        let schema = Schema::new(vec![Field::new("foo", DataType::Int32, false)]);
1126
1127        let groups: Vec<Expr> = vec![];
1128
1129        let plan = table_scan(TableReference::none(), &schema, None)
1130            .unwrap()
1131            .aggregate(groups.clone(), vec![count(lit(1))])
1132            .unwrap()
1133            .aggregate(groups, vec![count(lit(1))])
1134            .unwrap()
1135            .build()
1136            .unwrap();
1137
1138        assert_optimized_plan_equal!(
1139            plan,
1140            @r"
1141        Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]]
1142          EmptyRelation: rows=1
1143        "
1144        )
1145    }
1146
1147    #[test]
1148    fn test_neg_push_down() -> Result<()> {
1149        let table_scan = test_table_scan()?;
1150        let plan = LogicalPlanBuilder::from(table_scan)
1151            .project(vec![-col("a")])?
1152            .build()?;
1153
1154        assert_optimized_plan_equal!(
1155            plan,
1156            @r"
1157        Projection: (- test.a)
1158          TableScan: test projection=[a]
1159        "
1160        )
1161    }
1162
1163    #[test]
1164    fn test_is_null() -> Result<()> {
1165        let table_scan = test_table_scan()?;
1166        let plan = LogicalPlanBuilder::from(table_scan)
1167            .project(vec![col("a").is_null()])?
1168            .build()?;
1169
1170        assert_optimized_plan_equal!(
1171            plan,
1172            @r"
1173        Projection: test.a IS NULL
1174          TableScan: test projection=[a]
1175        "
1176        )
1177    }
1178
1179    #[test]
1180    fn test_is_not_null() -> Result<()> {
1181        let table_scan = test_table_scan()?;
1182        let plan = LogicalPlanBuilder::from(table_scan)
1183            .project(vec![col("a").is_not_null()])?
1184            .build()?;
1185
1186        assert_optimized_plan_equal!(
1187            plan,
1188            @r"
1189        Projection: test.a IS NOT NULL
1190          TableScan: test projection=[a]
1191        "
1192        )
1193    }
1194
1195    #[test]
1196    fn test_is_true() -> Result<()> {
1197        let table_scan = test_table_scan()?;
1198        let plan = LogicalPlanBuilder::from(table_scan)
1199            .project(vec![col("a").is_true()])?
1200            .build()?;
1201
1202        assert_optimized_plan_equal!(
1203            plan,
1204            @r"
1205        Projection: test.a IS TRUE
1206          TableScan: test projection=[a]
1207        "
1208        )
1209    }
1210
1211    #[test]
1212    fn test_is_not_true() -> Result<()> {
1213        let table_scan = test_table_scan()?;
1214        let plan = LogicalPlanBuilder::from(table_scan)
1215            .project(vec![col("a").is_not_true()])?
1216            .build()?;
1217
1218        assert_optimized_plan_equal!(
1219            plan,
1220            @r"
1221        Projection: test.a IS NOT TRUE
1222          TableScan: test projection=[a]
1223        "
1224        )
1225    }
1226
1227    #[test]
1228    fn test_is_false() -> Result<()> {
1229        let table_scan = test_table_scan()?;
1230        let plan = LogicalPlanBuilder::from(table_scan)
1231            .project(vec![col("a").is_false()])?
1232            .build()?;
1233
1234        assert_optimized_plan_equal!(
1235            plan,
1236            @r"
1237        Projection: test.a IS FALSE
1238          TableScan: test projection=[a]
1239        "
1240        )
1241    }
1242
1243    #[test]
1244    fn test_is_not_false() -> Result<()> {
1245        let table_scan = test_table_scan()?;
1246        let plan = LogicalPlanBuilder::from(table_scan)
1247            .project(vec![col("a").is_not_false()])?
1248            .build()?;
1249
1250        assert_optimized_plan_equal!(
1251            plan,
1252            @r"
1253        Projection: test.a IS NOT FALSE
1254          TableScan: test projection=[a]
1255        "
1256        )
1257    }
1258
1259    #[test]
1260    fn test_is_unknown() -> Result<()> {
1261        let table_scan = test_table_scan()?;
1262        let plan = LogicalPlanBuilder::from(table_scan)
1263            .project(vec![col("a").is_unknown()])?
1264            .build()?;
1265
1266        assert_optimized_plan_equal!(
1267            plan,
1268            @r"
1269        Projection: test.a IS UNKNOWN
1270          TableScan: test projection=[a]
1271        "
1272        )
1273    }
1274
1275    #[test]
1276    fn test_is_not_unknown() -> Result<()> {
1277        let table_scan = test_table_scan()?;
1278        let plan = LogicalPlanBuilder::from(table_scan)
1279            .project(vec![col("a").is_not_unknown()])?
1280            .build()?;
1281
1282        assert_optimized_plan_equal!(
1283            plan,
1284            @r"
1285        Projection: test.a IS NOT UNKNOWN
1286          TableScan: test projection=[a]
1287        "
1288        )
1289    }
1290
1291    #[test]
1292    fn test_not() -> Result<()> {
1293        let table_scan = test_table_scan()?;
1294        let plan = LogicalPlanBuilder::from(table_scan)
1295            .project(vec![not(col("a"))])?
1296            .build()?;
1297
1298        assert_optimized_plan_equal!(
1299            plan,
1300            @r"
1301        Projection: NOT test.a
1302          TableScan: test projection=[a]
1303        "
1304        )
1305    }
1306
1307    #[test]
1308    fn test_try_cast() -> Result<()> {
1309        let table_scan = test_table_scan()?;
1310        let plan = LogicalPlanBuilder::from(table_scan)
1311            .project(vec![try_cast(col("a"), DataType::Float64)])?
1312            .build()?;
1313
1314        assert_optimized_plan_equal!(
1315            plan,
1316            @r"
1317        Projection: TRY_CAST(test.a AS Float64)
1318          TableScan: test projection=[a]
1319        "
1320        )
1321    }
1322
1323    #[test]
1324    fn test_similar_to() -> Result<()> {
1325        let table_scan = test_table_scan()?;
1326        let expr = Box::new(col("a"));
1327        let pattern = Box::new(lit("[0-9]"));
1328        let similar_to_expr =
1329            Expr::SimilarTo(Like::new(false, expr, pattern, None, false));
1330        let plan = LogicalPlanBuilder::from(table_scan)
1331            .project(vec![similar_to_expr])?
1332            .build()?;
1333
1334        assert_optimized_plan_equal!(
1335            plan,
1336            @r#"
1337        Projection: test.a SIMILAR TO Utf8("[0-9]")
1338          TableScan: test projection=[a]
1339        "#
1340        )
1341    }
1342
1343    #[test]
1344    fn test_between() -> Result<()> {
1345        let table_scan = test_table_scan()?;
1346        let plan = LogicalPlanBuilder::from(table_scan)
1347            .project(vec![col("a").between(lit(1), lit(3))])?
1348            .build()?;
1349
1350        assert_optimized_plan_equal!(
1351            plan,
1352            @r"
1353        Projection: test.a BETWEEN Int32(1) AND Int32(3)
1354          TableScan: test projection=[a]
1355        "
1356        )
1357    }
1358
1359    // Test Case expression
1360    #[test]
1361    fn test_case_merged() -> Result<()> {
1362        let table_scan = test_table_scan()?;
1363        let plan = LogicalPlanBuilder::from(table_scan)
1364            .project(vec![col("a"), lit(0).alias("d")])?
1365            .project(vec![
1366                col("a"),
1367                when(col("a").eq(lit(1)), lit(10))
1368                    .otherwise(col("d"))?
1369                    .alias("d"),
1370            ])?
1371            .build()?;
1372
1373        assert_optimized_plan_equal!(
1374            plan,
1375            @r"
1376        Projection: test.a, CASE WHEN test.a = Int32(1) THEN Int32(10) ELSE Int32(0) END AS d
1377          TableScan: test projection=[a]
1378        "
1379        )
1380    }
1381
1382    // Test outer projection isn't discarded despite the same schema as inner
1383    // https://github.com/apache/datafusion/issues/8942
1384    #[test]
1385    fn test_derived_column() -> Result<()> {
1386        let table_scan = test_table_scan()?;
1387        let plan = LogicalPlanBuilder::from(table_scan)
1388            .project(vec![col("a").add(lit(1)).alias("a"), lit(0).alias("d")])?
1389            .project(vec![
1390                col("a"),
1391                when(col("a").eq(lit(1)), lit(10))
1392                    .otherwise(col("d"))?
1393                    .alias("d"),
1394            ])?
1395            .build()?;
1396
1397        assert_optimized_plan_equal!(
1398            plan,
1399            @r"
1400        Projection: a, CASE WHEN a = Int32(1) THEN Int32(10) ELSE d END AS d
1401          Projection: test.a + Int32(1) AS a, Int32(0) AS d
1402            TableScan: test projection=[a]
1403        "
1404        )
1405    }
1406
1407    // Since only column `a` is referred at the output. Scan should only contain projection=[a].
1408    // User defined node should be able to propagate necessary expressions by its parent to its child.
1409    #[test]
1410    fn test_user_defined_logical_plan_node() -> Result<()> {
1411        let table_scan = test_table_scan()?;
1412        let custom_plan = LogicalPlan::Extension(Extension {
1413            node: Arc::new(NoOpUserDefined::new(
1414                Arc::clone(table_scan.schema()),
1415                Arc::new(table_scan.clone()),
1416            )),
1417        });
1418        let plan = LogicalPlanBuilder::from(custom_plan)
1419            .project(vec![col("a"), lit(0).alias("d")])?
1420            .build()?;
1421
1422        assert_optimized_plan_equal!(
1423            plan,
1424            @r"
1425        Projection: test.a, Int32(0) AS d
1426          NoOpUserDefined
1427            TableScan: test projection=[a]
1428        "
1429        )
1430    }
1431
1432    // Only column `a` is referred at the output. However, User defined node itself uses column `b`
1433    // during its operation. Hence, scan should contain projection=[a, b].
1434    // User defined node should be able to propagate necessary expressions by its parent, as well as its own
1435    // required expressions.
1436    #[test]
1437    fn test_user_defined_logical_plan_node2() -> Result<()> {
1438        let table_scan = test_table_scan()?;
1439        let exprs = vec![Expr::Column(Column::from_qualified_name("b"))];
1440        let custom_plan = LogicalPlan::Extension(Extension {
1441            node: Arc::new(
1442                NoOpUserDefined::new(
1443                    Arc::clone(table_scan.schema()),
1444                    Arc::new(table_scan.clone()),
1445                )
1446                .with_exprs(exprs),
1447            ),
1448        });
1449        let plan = LogicalPlanBuilder::from(custom_plan)
1450            .project(vec![col("a"), lit(0).alias("d")])?
1451            .build()?;
1452
1453        assert_optimized_plan_equal!(
1454            plan,
1455            @r"
1456        Projection: test.a, Int32(0) AS d
1457          NoOpUserDefined
1458            TableScan: test projection=[a, b]
1459        "
1460        )
1461    }
1462
1463    // Only column `a` is referred at the output. However, User defined node itself uses expression `b+c`
1464    // during its operation. Hence, scan should contain projection=[a, b, c].
1465    // User defined node should be able to propagate necessary expressions by its parent, as well as its own
1466    // required expressions. Expressions doesn't have to be just column. Requirements from complex expressions
1467    // should be propagated also.
1468    #[test]
1469    fn test_user_defined_logical_plan_node3() -> Result<()> {
1470        let table_scan = test_table_scan()?;
1471        let left_expr = Expr::Column(Column::from_qualified_name("b"));
1472        let right_expr = Expr::Column(Column::from_qualified_name("c"));
1473        let binary_expr = Expr::BinaryExpr(BinaryExpr::new(
1474            Box::new(left_expr),
1475            Operator::Plus,
1476            Box::new(right_expr),
1477        ));
1478        let exprs = vec![binary_expr];
1479        let custom_plan = LogicalPlan::Extension(Extension {
1480            node: Arc::new(
1481                NoOpUserDefined::new(
1482                    Arc::clone(table_scan.schema()),
1483                    Arc::new(table_scan.clone()),
1484                )
1485                .with_exprs(exprs),
1486            ),
1487        });
1488        let plan = LogicalPlanBuilder::from(custom_plan)
1489            .project(vec![col("a"), lit(0).alias("d")])?
1490            .build()?;
1491
1492        assert_optimized_plan_equal!(
1493            plan,
1494            @r"
1495        Projection: test.a, Int32(0) AS d
1496          NoOpUserDefined
1497            TableScan: test projection=[a, b, c]
1498        "
1499        )
1500    }
1501
1502    // Columns `l.a`, `l.c`, `r.a` is referred at the output.
1503    // User defined node should be able to propagate necessary expressions by its parent, to its children.
1504    // Even if it has multiple children.
1505    // left child should have `projection=[a, c]`, and right side should have `projection=[a]`.
1506    #[test]
1507    fn test_user_defined_logical_plan_node4() -> Result<()> {
1508        let left_table = test_table_scan_with_name("l")?;
1509        let right_table = test_table_scan_with_name("r")?;
1510        let custom_plan = LogicalPlan::Extension(Extension {
1511            node: Arc::new(UserDefinedCrossJoin::new(
1512                Arc::new(left_table),
1513                Arc::new(right_table),
1514            )),
1515        });
1516        let plan = LogicalPlanBuilder::from(custom_plan)
1517            .project(vec![col("l.a"), col("l.c"), col("r.a"), lit(0).alias("d")])?
1518            .build()?;
1519
1520        assert_optimized_plan_equal!(
1521            plan,
1522            @r"
1523        Projection: l.a, l.c, r.a, Int32(0) AS d
1524          UserDefinedCrossJoin
1525            TableScan: l projection=[a, c]
1526            TableScan: r projection=[a]
1527        "
1528        )
1529    }
1530
1531    #[test]
1532    fn aggregate_no_group_by() -> Result<()> {
1533        let table_scan = test_table_scan()?;
1534
1535        let plan = LogicalPlanBuilder::from(table_scan)
1536            .aggregate(Vec::<Expr>::new(), vec![max(col("b"))])?
1537            .build()?;
1538
1539        assert_optimized_plan_equal!(
1540            plan,
1541            @r"
1542        Aggregate: groupBy=[[]], aggr=[[max(test.b)]]
1543          TableScan: test projection=[b]
1544        "
1545        )
1546    }
1547
1548    #[test]
1549    fn aggregate_group_by() -> Result<()> {
1550        let table_scan = test_table_scan()?;
1551
1552        let plan = LogicalPlanBuilder::from(table_scan)
1553            .aggregate(vec![col("c")], vec![max(col("b"))])?
1554            .build()?;
1555
1556        assert_optimized_plan_equal!(
1557            plan,
1558            @r"
1559        Aggregate: groupBy=[[test.c]], aggr=[[max(test.b)]]
1560          TableScan: test projection=[b, c]
1561        "
1562        )
1563    }
1564
1565    #[test]
1566    fn aggregate_group_by_with_table_alias() -> Result<()> {
1567        let table_scan = test_table_scan()?;
1568
1569        let plan = LogicalPlanBuilder::from(table_scan)
1570            .alias("a")?
1571            .aggregate(vec![col("c")], vec![max(col("b"))])?
1572            .build()?;
1573
1574        assert_optimized_plan_equal!(
1575            plan,
1576            @r"
1577        Aggregate: groupBy=[[a.c]], aggr=[[max(a.b)]]
1578          SubqueryAlias: a
1579            TableScan: test projection=[b, c]
1580        "
1581        )
1582    }
1583
1584    #[test]
1585    fn aggregate_no_group_by_with_filter() -> Result<()> {
1586        let table_scan = test_table_scan()?;
1587
1588        let plan = LogicalPlanBuilder::from(table_scan)
1589            .filter(col("c").gt(lit(1)))?
1590            .aggregate(Vec::<Expr>::new(), vec![max(col("b"))])?
1591            .build()?;
1592
1593        assert_optimized_plan_equal!(
1594            plan,
1595            @r"
1596        Aggregate: groupBy=[[]], aggr=[[max(test.b)]]
1597          Projection: test.b
1598            Filter: test.c > Int32(1)
1599              TableScan: test projection=[b, c]
1600        "
1601        )
1602    }
1603
1604    #[test]
1605    fn aggregate_with_periods() -> Result<()> {
1606        let schema = Schema::new(vec![Field::new("tag.one", DataType::Utf8, false)]);
1607
1608        // Build a plan that looks as follows (note "tag.one" is a column named
1609        // "tag.one", not a column named "one" in a table named "tag"):
1610        //
1611        // Projection: tag.one
1612        //   Aggregate: groupBy=[], aggr=[max("tag.one") AS "tag.one"]
1613        //    TableScan
1614        let plan = table_scan(Some("m4"), &schema, None)?
1615            .aggregate(
1616                Vec::<Expr>::new(),
1617                vec![max(col(Column::new_unqualified("tag.one"))).alias("tag.one")],
1618            )?
1619            .project([col(Column::new_unqualified("tag.one"))])?
1620            .build()?;
1621
1622        assert_optimized_plan_equal!(
1623            plan,
1624            @r"
1625        Aggregate: groupBy=[[]], aggr=[[max(m4.tag.one) AS tag.one]]
1626          TableScan: m4 projection=[tag.one]
1627        "
1628        )
1629    }
1630
1631    #[test]
1632    fn redundant_project() -> Result<()> {
1633        let table_scan = test_table_scan()?;
1634
1635        let plan = LogicalPlanBuilder::from(table_scan)
1636            .project(vec![col("a"), col("b"), col("c")])?
1637            .project(vec![col("a"), col("c"), col("b")])?
1638            .build()?;
1639        assert_optimized_plan_equal!(
1640            plan,
1641            @r"
1642        Projection: test.a, test.c, test.b
1643          TableScan: test projection=[a, b, c]
1644        "
1645        )
1646    }
1647
1648    #[test]
1649    fn reorder_scan() -> Result<()> {
1650        let schema = Schema::new(test_table_scan_fields());
1651
1652        let plan = table_scan(Some("test"), &schema, Some(vec![1, 0, 2]))?.build()?;
1653        assert_optimized_plan_equal!(
1654            plan,
1655            @"TableScan: test projection=[b, a, c]"
1656        )
1657    }
1658
1659    #[test]
1660    fn reorder_scan_projection() -> Result<()> {
1661        let schema = Schema::new(test_table_scan_fields());
1662
1663        let plan = table_scan(Some("test"), &schema, Some(vec![1, 0, 2]))?
1664            .project(vec![col("a"), col("b")])?
1665            .build()?;
1666        assert_optimized_plan_equal!(
1667            plan,
1668            @r"
1669        Projection: test.a, test.b
1670          TableScan: test projection=[b, a]
1671        "
1672        )
1673    }
1674
1675    #[test]
1676    fn reorder_projection() -> Result<()> {
1677        let table_scan = test_table_scan()?;
1678
1679        let plan = LogicalPlanBuilder::from(table_scan)
1680            .project(vec![col("c"), col("b"), col("a")])?
1681            .build()?;
1682        assert_optimized_plan_equal!(
1683            plan,
1684            @r"
1685        Projection: test.c, test.b, test.a
1686          TableScan: test projection=[a, b, c]
1687        "
1688        )
1689    }
1690
1691    #[test]
1692    fn noncontinuous_redundant_projection() -> Result<()> {
1693        let table_scan = test_table_scan()?;
1694
1695        let plan = LogicalPlanBuilder::from(table_scan)
1696            .project(vec![col("c"), col("b"), col("a")])?
1697            .filter(col("c").gt(lit(1)))?
1698            .project(vec![col("c"), col("a"), col("b")])?
1699            .filter(col("b").gt(lit(1)))?
1700            .filter(col("a").gt(lit(1)))?
1701            .project(vec![col("a"), col("c"), col("b")])?
1702            .build()?;
1703        assert_optimized_plan_equal!(
1704            plan,
1705            @r"
1706        Projection: test.a, test.c, test.b
1707          Filter: test.a > Int32(1)
1708            Filter: test.b > Int32(1)
1709              Projection: test.c, test.a, test.b
1710                Filter: test.c > Int32(1)
1711                  Projection: test.c, test.b, test.a
1712                    TableScan: test projection=[a, b, c]
1713        "
1714        )
1715    }
1716
1717    #[test]
1718    fn join_schema_trim_full_join_column_projection() -> Result<()> {
1719        let table_scan = test_table_scan()?;
1720
1721        let schema = Schema::new(vec![Field::new("c1", DataType::UInt32, false)]);
1722        let table2_scan = scan_empty(Some("test2"), &schema, None)?.build()?;
1723
1724        let plan = LogicalPlanBuilder::from(table_scan)
1725            .join(table2_scan, JoinType::Left, (vec!["a"], vec!["c1"]), None)?
1726            .project(vec![col("a"), col("b"), col("c1")])?
1727            .build()?;
1728
1729        let optimized_plan = optimize(plan)?;
1730
1731        // make sure projections are pushed down to both table scans
1732        assert_snapshot!(
1733            optimized_plan.clone(),
1734            @r"
1735        Left Join: test.a = test2.c1
1736          TableScan: test projection=[a, b]
1737          TableScan: test2 projection=[c1]
1738        "
1739        );
1740
1741        // make sure schema for join node include both join columns
1742        let optimized_join = optimized_plan;
1743        assert_eq!(
1744            **optimized_join.schema(),
1745            DFSchema::new_with_metadata(
1746                vec![
1747                    (
1748                        Some("test".into()),
1749                        Arc::new(Field::new("a", DataType::UInt32, false))
1750                    ),
1751                    (
1752                        Some("test".into()),
1753                        Arc::new(Field::new("b", DataType::UInt32, false))
1754                    ),
1755                    (
1756                        Some("test2".into()),
1757                        Arc::new(Field::new("c1", DataType::UInt32, true))
1758                    ),
1759                ],
1760                HashMap::new()
1761            )?,
1762        );
1763
1764        Ok(())
1765    }
1766
1767    #[test]
1768    fn join_schema_trim_partial_join_column_projection() -> Result<()> {
1769        // test join column push down without explicit column projections
1770
1771        let table_scan = test_table_scan()?;
1772
1773        let schema = Schema::new(vec![Field::new("c1", DataType::UInt32, false)]);
1774        let table2_scan = scan_empty(Some("test2"), &schema, None)?.build()?;
1775
1776        let plan = LogicalPlanBuilder::from(table_scan)
1777            .join(table2_scan, JoinType::Left, (vec!["a"], vec!["c1"]), None)?
1778            // projecting joined column `a` should push the right side column `c1` projection as
1779            // well into test2 table even though `c1` is not referenced in projection.
1780            .project(vec![col("a"), col("b")])?
1781            .build()?;
1782
1783        let optimized_plan = optimize(plan)?;
1784
1785        // make sure projections are pushed down to both table scans
1786        assert_snapshot!(
1787            optimized_plan.clone(),
1788            @r"
1789        Projection: test.a, test.b
1790          Left Join: test.a = test2.c1
1791            TableScan: test projection=[a, b]
1792            TableScan: test2 projection=[c1]
1793        "
1794        );
1795
1796        // make sure schema for join node include both join columns
1797        let optimized_join = optimized_plan.inputs()[0];
1798        assert_eq!(
1799            **optimized_join.schema(),
1800            DFSchema::new_with_metadata(
1801                vec![
1802                    (
1803                        Some("test".into()),
1804                        Arc::new(Field::new("a", DataType::UInt32, false))
1805                    ),
1806                    (
1807                        Some("test".into()),
1808                        Arc::new(Field::new("b", DataType::UInt32, false))
1809                    ),
1810                    (
1811                        Some("test2".into()),
1812                        Arc::new(Field::new("c1", DataType::UInt32, true))
1813                    ),
1814                ],
1815                HashMap::new()
1816            )?,
1817        );
1818
1819        Ok(())
1820    }
1821
1822    #[test]
1823    fn join_schema_trim_using_join() -> Result<()> {
1824        // shared join columns from using join should be pushed to both sides
1825
1826        let table_scan = test_table_scan()?;
1827
1828        let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]);
1829        let table2_scan = scan_empty(Some("test2"), &schema, None)?.build()?;
1830
1831        let plan = LogicalPlanBuilder::from(table_scan)
1832            .join_using(table2_scan, JoinType::Left, vec!["a".into()])?
1833            .project(vec![col("a"), col("b")])?
1834            .build()?;
1835
1836        let optimized_plan = optimize(plan)?;
1837
1838        // make sure projections are pushed down to table scan
1839        assert_snapshot!(
1840            optimized_plan.clone(),
1841            @r"
1842        Projection: test.a, test.b
1843          Left Join: Using test.a = test2.a
1844            TableScan: test projection=[a, b]
1845            TableScan: test2 projection=[a]
1846        "
1847        );
1848
1849        // make sure schema for join node include both join columns
1850        let optimized_join = optimized_plan.inputs()[0];
1851        assert_eq!(
1852            **optimized_join.schema(),
1853            DFSchema::new_with_metadata(
1854                vec![
1855                    (
1856                        Some("test".into()),
1857                        Arc::new(Field::new("a", DataType::UInt32, false))
1858                    ),
1859                    (
1860                        Some("test".into()),
1861                        Arc::new(Field::new("b", DataType::UInt32, false))
1862                    ),
1863                    (
1864                        Some("test2".into()),
1865                        Arc::new(Field::new("a", DataType::UInt32, true))
1866                    ),
1867                ],
1868                HashMap::new()
1869            )?,
1870        );
1871
1872        Ok(())
1873    }
1874
1875    #[test]
1876    fn cast() -> Result<()> {
1877        let table_scan = test_table_scan()?;
1878
1879        let plan = LogicalPlanBuilder::from(table_scan)
1880            .project(vec![Expr::Cast(Cast::new(
1881                Box::new(col("c")),
1882                DataType::Float64,
1883            ))])?
1884            .build()?;
1885
1886        assert_optimized_plan_equal!(
1887            plan,
1888            @r"
1889        Projection: CAST(test.c AS Float64)
1890          TableScan: test projection=[c]
1891        "
1892        )
1893    }
1894
1895    #[test]
1896    fn table_scan_projected_schema() -> Result<()> {
1897        let table_scan = test_table_scan()?;
1898        let plan = LogicalPlanBuilder::from(test_table_scan()?)
1899            .project(vec![col("a"), col("b")])?
1900            .build()?;
1901
1902        assert_eq!(3, table_scan.schema().fields().len());
1903        assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
1904        assert_fields_eq(&plan, vec!["a", "b"]);
1905
1906        assert_optimized_plan_equal!(
1907            plan,
1908            @"TableScan: test projection=[a, b]"
1909        )
1910    }
1911
1912    #[test]
1913    fn table_scan_projected_schema_non_qualified_relation() -> Result<()> {
1914        let table_scan = test_table_scan()?;
1915        let input_schema = table_scan.schema();
1916        assert_eq!(3, input_schema.fields().len());
1917        assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
1918
1919        // Build the LogicalPlan directly (don't use PlanBuilder), so
1920        // that the Column references are unqualified (e.g. their
1921        // relation is `None`). PlanBuilder resolves the expressions
1922        let expr = vec![col("test.a"), col("test.b")];
1923        let plan =
1924            LogicalPlan::Projection(Projection::try_new(expr, Arc::new(table_scan))?);
1925
1926        assert_fields_eq(&plan, vec!["a", "b"]);
1927
1928        assert_optimized_plan_equal!(
1929            plan,
1930            @"TableScan: test projection=[a, b]"
1931        )
1932    }
1933
1934    #[test]
1935    fn table_limit() -> Result<()> {
1936        let table_scan = test_table_scan()?;
1937        assert_eq!(3, table_scan.schema().fields().len());
1938        assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
1939
1940        let plan = LogicalPlanBuilder::from(table_scan)
1941            .project(vec![col("c"), col("a")])?
1942            .limit(0, Some(5))?
1943            .build()?;
1944
1945        assert_fields_eq(&plan, vec!["c", "a"]);
1946
1947        assert_optimized_plan_equal!(
1948            plan,
1949            @r"
1950        Limit: skip=0, fetch=5
1951          Projection: test.c, test.a
1952            TableScan: test projection=[a, c]
1953        "
1954        )
1955    }
1956
1957    #[test]
1958    fn table_scan_without_projection() -> Result<()> {
1959        let table_scan = test_table_scan()?;
1960        let plan = LogicalPlanBuilder::from(table_scan).build()?;
1961        // should expand projection to all columns without projection
1962        assert_optimized_plan_equal!(
1963            plan,
1964            @"TableScan: test projection=[a, b, c]"
1965        )
1966    }
1967
1968    #[test]
1969    fn table_scan_with_literal_projection() -> Result<()> {
1970        let table_scan = test_table_scan()?;
1971        let plan = LogicalPlanBuilder::from(table_scan)
1972            .project(vec![lit(1_i64), lit(2_i64)])?
1973            .build()?;
1974        assert_optimized_plan_equal!(
1975            plan,
1976            @r"
1977        Projection: Int64(1), Int64(2)
1978          TableScan: test projection=[]
1979        "
1980        )
1981    }
1982
1983    /// tests that it removes unused columns in projections
1984    #[test]
1985    fn table_unused_column() -> Result<()> {
1986        let table_scan = test_table_scan()?;
1987        assert_eq!(3, table_scan.schema().fields().len());
1988        assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
1989
1990        // we never use "b" in the first projection => remove it
1991        let plan = LogicalPlanBuilder::from(table_scan)
1992            .project(vec![col("c"), col("a"), col("b")])?
1993            .filter(col("c").gt(lit(1)))?
1994            .aggregate(vec![col("c")], vec![max(col("a"))])?
1995            .build()?;
1996
1997        assert_fields_eq(&plan, vec!["c", "max(test.a)"]);
1998
1999        let plan = optimize(plan).expect("failed to optimize plan");
2000        assert_optimized_plan_equal!(
2001            plan,
2002            @r"
2003        Aggregate: groupBy=[[test.c]], aggr=[[max(test.a)]]
2004          Filter: test.c > Int32(1)
2005            Projection: test.c, test.a
2006              TableScan: test projection=[a, c]
2007        "
2008        )
2009    }
2010
2011    /// tests that it removes un-needed projections
2012    #[test]
2013    fn table_unused_projection() -> Result<()> {
2014        let table_scan = test_table_scan()?;
2015        assert_eq!(3, table_scan.schema().fields().len());
2016        assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
2017
2018        // there is no need for the first projection
2019        let plan = LogicalPlanBuilder::from(table_scan)
2020            .project(vec![col("b")])?
2021            .project(vec![lit(1).alias("a")])?
2022            .build()?;
2023
2024        assert_fields_eq(&plan, vec!["a"]);
2025
2026        assert_optimized_plan_equal!(
2027            plan,
2028            @r"
2029        Projection: Int32(1) AS a
2030          TableScan: test projection=[]
2031        "
2032        )
2033    }
2034
2035    #[test]
2036    fn table_full_filter_pushdown() -> Result<()> {
2037        let schema = Schema::new(test_table_scan_fields());
2038
2039        let table_scan = table_scan_with_filters(
2040            Some("test"),
2041            &schema,
2042            None,
2043            vec![col("b").eq(lit(1))],
2044        )?
2045        .build()?;
2046        assert_eq!(3, table_scan.schema().fields().len());
2047        assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
2048
2049        // there is no need for the first projection
2050        let plan = LogicalPlanBuilder::from(table_scan)
2051            .project(vec![col("b")])?
2052            .project(vec![lit(1).alias("a")])?
2053            .build()?;
2054
2055        assert_fields_eq(&plan, vec!["a"]);
2056
2057        assert_optimized_plan_equal!(
2058            plan,
2059            @r"
2060        Projection: Int32(1) AS a
2061          TableScan: test projection=[], full_filters=[b = Int32(1)]
2062        "
2063        )
2064    }
2065
2066    /// tests that optimizing twice yields same plan
2067    #[test]
2068    fn test_double_optimization() -> Result<()> {
2069        let table_scan = test_table_scan()?;
2070
2071        let plan = LogicalPlanBuilder::from(table_scan)
2072            .project(vec![col("b")])?
2073            .project(vec![lit(1).alias("a")])?
2074            .build()?;
2075
2076        let optimized_plan1 = optimize(plan).expect("failed to optimize plan");
2077        let optimized_plan2 =
2078            optimize(optimized_plan1.clone()).expect("failed to optimize plan");
2079
2080        let formatted_plan1 = format!("{optimized_plan1:?}");
2081        let formatted_plan2 = format!("{optimized_plan2:?}");
2082        assert_eq!(formatted_plan1, formatted_plan2);
2083        Ok(())
2084    }
2085
2086    /// tests that it removes an aggregate is never used downstream
2087    #[test]
2088    fn table_unused_aggregate() -> Result<()> {
2089        let table_scan = test_table_scan()?;
2090        assert_eq!(3, table_scan.schema().fields().len());
2091        assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
2092
2093        // we never use "min(b)" => remove it
2094        let plan = LogicalPlanBuilder::from(table_scan)
2095            .aggregate(vec![col("a"), col("c")], vec![max(col("b")), min(col("b"))])?
2096            .filter(col("c").gt(lit(1)))?
2097            .project(vec![col("c"), col("a"), col("max(test.b)")])?
2098            .build()?;
2099
2100        assert_fields_eq(&plan, vec!["c", "a", "max(test.b)"]);
2101
2102        assert_optimized_plan_equal!(
2103            plan,
2104            @r"
2105        Projection: test.c, test.a, max(test.b)
2106          Filter: test.c > Int32(1)
2107            Aggregate: groupBy=[[test.a, test.c]], aggr=[[max(test.b)]]
2108              TableScan: test projection=[a, b, c]
2109        "
2110        )
2111    }
2112
2113    #[test]
2114    fn aggregate_filter_pushdown() -> Result<()> {
2115        let table_scan = test_table_scan()?;
2116        let aggr_with_filter = count_udaf()
2117            .call(vec![col("b")])
2118            .filter(col("c").gt(lit(42)))
2119            .build()?;
2120        let plan = LogicalPlanBuilder::from(table_scan)
2121            .aggregate(
2122                vec![col("a")],
2123                vec![count(col("b")), aggr_with_filter.alias("count2")],
2124            )?
2125            .build()?;
2126
2127        assert_optimized_plan_equal!(
2128            plan,
2129            @r"
2130        Aggregate: groupBy=[[test.a]], aggr=[[count(test.b), count(test.b) FILTER (WHERE test.c > Int32(42)) AS count2]]
2131          TableScan: test projection=[a, b, c]
2132        "
2133        )
2134    }
2135
2136    #[test]
2137    fn pushdown_through_distinct() -> Result<()> {
2138        let table_scan = test_table_scan()?;
2139
2140        let plan = LogicalPlanBuilder::from(table_scan)
2141            .project(vec![col("a"), col("b")])?
2142            .distinct()?
2143            .project(vec![col("a")])?
2144            .build()?;
2145
2146        assert_optimized_plan_equal!(
2147            plan,
2148            @r"
2149        Projection: test.a
2150          Distinct:
2151            TableScan: test projection=[a, b]
2152        "
2153        )
2154    }
2155
2156    #[test]
2157    fn test_window() -> Result<()> {
2158        let table_scan = test_table_scan()?;
2159
2160        let max1 = Expr::from(expr::WindowFunction::new(
2161            WindowFunctionDefinition::AggregateUDF(max_udaf()),
2162            vec![col("test.a")],
2163        ))
2164        .partition_by(vec![col("test.b")])
2165        .build()
2166        .unwrap();
2167
2168        let max2 = Expr::from(expr::WindowFunction::new(
2169            WindowFunctionDefinition::AggregateUDF(max_udaf()),
2170            vec![col("test.b")],
2171        ));
2172        let col1 = col(max1.schema_name().to_string());
2173        let col2 = col(max2.schema_name().to_string());
2174
2175        let plan = LogicalPlanBuilder::from(table_scan)
2176            .window(vec![max1])?
2177            .window(vec![max2])?
2178            .project(vec![col1, col2])?
2179            .build()?;
2180
2181        assert_optimized_plan_equal!(
2182            plan,
2183            @r"
2184        Projection: max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, max(test.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
2185          WindowAggr: windowExpr=[[max(test.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
2186            Projection: test.b, max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
2187              WindowAggr: windowExpr=[[max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
2188                TableScan: test projection=[a, b]
2189        "
2190        )
2191    }
2192
2193    fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
2194
2195    fn optimize(plan: LogicalPlan) -> Result<LogicalPlan> {
2196        let optimizer = Optimizer::with_rules(vec![Arc::new(OptimizeProjections::new())]);
2197        let optimized_plan =
2198            optimizer.optimize(plan, &OptimizerContext::new(), observe)?;
2199        Ok(optimized_plan)
2200    }
2201}