datafusion_optimizer/optimize_projections/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`OptimizeProjections`] identifies and eliminates unused columns
19
20mod required_indices;
21
22use crate::optimizer::ApplyOrder;
23use crate::{OptimizerConfig, OptimizerRule};
24use std::collections::HashSet;
25use std::sync::Arc;
26
27use datafusion_common::{
28    get_required_group_by_exprs_indices, internal_datafusion_err, internal_err, Column,
29    HashMap, JoinType, Result,
30};
31use datafusion_expr::expr::Alias;
32use datafusion_expr::{
33    logical_plan::LogicalPlan, Aggregate, Distinct, Expr, Projection, TableScan, Unnest,
34    Window,
35};
36
37use crate::optimize_projections::required_indices::RequiredIndices;
38use crate::utils::NamePreserver;
39use datafusion_common::tree_node::{
40    Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion,
41};
42
/// Optimizer rule that prunes columns which are never used from the
/// intermediate schemas of a [`LogicalPlan`].
///
/// Starting from the columns required at the plan's output, the rule works
/// out which column indices each operator actually needs from its inputs
/// and then:
/// - drops columns that neither reach the output nor take part in any
///   computation step,
/// - inserts projections in front of operators that profit from narrower
///   (smaller memory footprint) inputs, and
/// - removes [`LogicalPlan::Projection`]s that have become unnecessary.
///
/// Fewer columns flowing through the plan means less data to process, which
/// can improve query performance.
#[derive(Debug, Default)]
pub struct OptimizeProjections {}
60
61impl OptimizeProjections {
62    #[allow(missing_docs)]
63    pub fn new() -> Self {
64        Self {}
65    }
66}
67
68impl OptimizerRule for OptimizeProjections {
69    fn name(&self) -> &str {
70        "optimize_projections"
71    }
72
73    fn apply_order(&self) -> Option<ApplyOrder> {
74        None
75    }
76
77    fn supports_rewrite(&self) -> bool {
78        true
79    }
80
81    fn rewrite(
82        &self,
83        plan: LogicalPlan,
84        config: &dyn OptimizerConfig,
85    ) -> Result<Transformed<LogicalPlan>> {
86        // All output fields are necessary:
87        let indices = RequiredIndices::new_for_all_exprs(&plan);
88        optimize_projections(plan, config, indices)
89    }
90}
91
/// Removes unnecessary columns (e.g. columns that do not appear in the output
/// schema and/or are not used during any computation step such as expression
/// evaluation) from the logical plan and its inputs.
///
/// Projections, aggregates, windows and table scans are rewritten directly
/// (first `match`). For every other node the indices required from each child
/// are computed (second `match`) and the children are rewritten recursively.
///
/// # Parameters
///
/// - `plan`: The input `LogicalPlan` to optimize (taken by value).
/// - `config`: A reference to the optimizer configuration.
/// - `indices`: The column indices of `plan`'s output that are necessary for
///   downstream (parent) plan nodes.
///
/// # Returns
///
/// A `Result` object with the following semantics:
///
/// - `Ok(transformed)`: The resulting `LogicalPlan` wrapped in a
///   [`Transformed`], whose `transformed` flag reports whether the plan was
///   changed.
/// - `Err(error)`: An error occurred during the optimization process.
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
fn optimize_projections(
    plan: LogicalPlan,
    config: &dyn OptimizerConfig,
    indices: RequiredIndices,
) -> Result<Transformed<LogicalPlan>> {
    // Recursively rewrite any nodes that may be able to avoid computation given
    // their parents' required indices.
    match plan {
        LogicalPlan::Projection(proj) => {
            // First try to fold this projection into a projection directly
            // below it, then prune it according to the parent requirements.
            return merge_consecutive_projections(proj)?.transform_data(|proj| {
                rewrite_projection_given_requirements(proj, config, &indices)
            })
        }
        LogicalPlan::Aggregate(aggregate) => {
            // Split parent requirements to GROUP BY and aggregate sections:
            let n_group_exprs = aggregate.group_expr_len()?;
            // Offset aggregate indices so that they point to valid indices at
            // `aggregate.aggr_expr`:
            let (group_by_reqs, aggregate_reqs) = indices.split_off(n_group_exprs);

            // Get absolutely necessary GROUP BY fields:
            let group_by_expr_existing = aggregate
                .group_expr
                .iter()
                .map(|group_by_expr| group_by_expr.schema_name().to_string())
                .collect::<Vec<_>>();

            let new_group_bys = if let Some(simplest_groupby_indices) =
                get_required_group_by_exprs_indices(
                    aggregate.input.schema(),
                    &group_by_expr_existing,
                ) {
                // Some of the fields in the GROUP BY may be required by the
                // parent even if these fields are unnecessary in terms of
                // functional dependency.
                group_by_reqs
                    .append(&simplest_groupby_indices)
                    .get_at_indices(&aggregate.group_expr)
            } else {
                aggregate.group_expr
            };

            // Only use the absolutely necessary aggregate expressions required
            // by the parent:
            let mut new_aggr_expr = aggregate_reqs.get_at_indices(&aggregate.aggr_expr);

            // Aggregations always need at least one aggregate expression.
            // With a nested count, we don't require any column as input, but
            // still need to create a correct aggregate, which may be optimized
            // out later. As an example, consider the following query:
            //
            // SELECT count(*) FROM (SELECT count(*) FROM [...])
            //
            // which always returns 1.
            if new_aggr_expr.is_empty()
                && new_group_bys.is_empty()
                && !aggregate.aggr_expr.is_empty()
            {
                // take the old, first aggregate expression
                new_aggr_expr = aggregate.aggr_expr;
                // The vector is non-empty (checked above), so resizing to one
                // element can only truncate; the fill closure is never called.
                new_aggr_expr.resize_with(1, || unreachable!());
            }

            let all_exprs_iter = new_group_bys.iter().chain(new_aggr_expr.iter());
            let schema = aggregate.input.schema();
            let necessary_indices =
                RequiredIndices::new().with_exprs(schema, all_exprs_iter);
            let necessary_exprs = necessary_indices.get_required_exprs(schema);

            return optimize_projections(
                Arc::unwrap_or_clone(aggregate.input),
                config,
                necessary_indices,
            )?
            .transform_data(|aggregate_input| {
                // Simplify the input of the aggregation by adding a projection so
                // that its input only contains absolutely necessary columns for
                // the aggregate expressions. Note that necessary_indices refer to
                // fields in `aggregate.input.schema()`.
                add_projection_on_top_if_helpful(aggregate_input, necessary_exprs)
            })?
            .map_data(|aggregate_input| {
                // Create a new aggregate plan with the updated input and only the
                // absolutely necessary fields:
                Aggregate::try_new(
                    Arc::new(aggregate_input),
                    new_group_bys,
                    new_aggr_expr,
                )
                .map(LogicalPlan::Aggregate)
            });
        }
        LogicalPlan::Window(window) => {
            let input_schema = Arc::clone(window.input.schema());
            // Split parent requirements to child and window expression sections:
            let n_input_fields = input_schema.fields().len();
            // Offset window expression indices so that they point to valid
            // indices at `window.window_expr`:
            let (child_reqs, window_reqs) = indices.split_off(n_input_fields);

            // Only use window expressions that are absolutely necessary according
            // to parent requirements:
            let new_window_expr = window_reqs.get_at_indices(&window.window_expr);

            // Get all the required column indices at the input, either by the
            // parent or window expression requirements.
            let required_indices = child_reqs.with_exprs(&input_schema, &new_window_expr);

            return optimize_projections(
                Arc::unwrap_or_clone(window.input),
                config,
                required_indices.clone(),
            )?
            .transform_data(|window_child| {
                if new_window_expr.is_empty() {
                    // When no window expression is necessary, use the input directly:
                    Ok(Transformed::no(window_child))
                } else {
                    // Calculate required expressions at the input of the window.
                    // Please note that we use `input_schema`, because `required_indices`
                    // refers to that schema
                    let required_exprs =
                        required_indices.get_required_exprs(&input_schema);
                    let window_child =
                        add_projection_on_top_if_helpful(window_child, required_exprs)?
                            .data;
                    Window::try_new(new_window_expr, Arc::new(window_child))
                        .map(LogicalPlan::Window)
                        .map(Transformed::yes)
                }
            });
        }
        LogicalPlan::TableScan(table_scan) => {
            let TableScan {
                table_name,
                source,
                projection,
                filters,
                fetch,
                projected_schema: _,
            } = table_scan;

            // Get indices referred to in the original (schema with all fields)
            // given projected indices.
            let projection = match &projection {
                Some(projection) => indices.into_mapped_indices(|idx| projection[idx]),
                None => indices.into_inner(),
            };
            // The scan is rebuilt with the new projection and is always
            // reported as transformed.
            return TableScan::try_new(
                table_name,
                source,
                Some(projection),
                filters,
                fetch,
            )
            .map(LogicalPlan::TableScan)
            .map(Transformed::yes);
        }
        // Other node types are handled below
        _ => {}
    };

    // For other plan node types, calculate indices for columns they use and
    // try to rewrite their children
    let mut child_required_indices: Vec<RequiredIndices> = match &plan {
        LogicalPlan::Sort(_)
        | LogicalPlan::Filter(_)
        | LogicalPlan::Repartition(_)
        | LogicalPlan::Union(_)
        | LogicalPlan::SubqueryAlias(_)
        | LogicalPlan::Distinct(Distinct::On(_)) => {
            // Pass index requirements from the parent as well as column indices
            // that appear in this plan's expressions to its child. All these
            // operators benefit from "small" inputs, so the projection_beneficial
            // flag is `true`.
            plan.inputs()
                .into_iter()
                .map(|input| {
                    indices
                        .clone()
                        .with_projection_beneficial()
                        .with_plan_exprs(&plan, input.schema())
                })
                .collect::<Result<_>>()?
        }
        LogicalPlan::Limit(_) => {
            // Pass index requirements from the parent as well as column indices
            // that appear in this plan's expressions to its child. These operators
            // do not benefit from "small" inputs, so the projection_beneficial
            // flag is `false`.
            plan.inputs()
                .into_iter()
                .map(|input| indices.clone().with_plan_exprs(&plan, input.schema()))
                .collect::<Result<_>>()?
        }
        LogicalPlan::Copy(_)
        | LogicalPlan::Ddl(_)
        | LogicalPlan::Dml(_)
        | LogicalPlan::Explain(_)
        | LogicalPlan::Analyze(_)
        | LogicalPlan::Subquery(_)
        | LogicalPlan::Statement(_)
        | LogicalPlan::Distinct(Distinct::All(_)) => {
            // These plans require all their fields, and their children should
            // be treated as final plans -- otherwise, we may have a schema
            // mismatch.
            // TODO: For some subquery variants (e.g. a subquery arising from an
            //       EXISTS expression), we may not need to require all indices.
            plan.inputs()
                .into_iter()
                .map(RequiredIndices::new_for_all_exprs)
                .collect()
        }
        LogicalPlan::Extension(extension) => {
            let Some(necessary_children_indices) =
                extension.node.necessary_children_exprs(indices.indices())
            else {
                // Requirements from parent cannot be routed down to user defined logical plan safely
                return Ok(Transformed::no(plan));
            };
            let children = extension.node.inputs();
            if children.len() != necessary_children_indices.len() {
                return internal_err!("Inconsistent length between children and necessary children indices. \
                Make sure `.necessary_children_exprs` implementation of the `UserDefinedLogicalNode` is \
                consistent with actual children length for the node.");
            }
            // Combine the per-child indices reported by the node with the
            // columns referenced by the node's own expressions.
            children
                .into_iter()
                .zip(necessary_children_indices)
                .map(|(child, necessary_indices)| {
                    RequiredIndices::new_from_indices(necessary_indices)
                        .with_plan_exprs(&plan, child.schema())
                })
                .collect::<Result<Vec<_>>>()?
        }
        LogicalPlan::EmptyRelation(_)
        | LogicalPlan::RecursiveQuery(_)
        | LogicalPlan::Values(_)
        | LogicalPlan::DescribeTable(_) => {
            // These operators have no inputs, so stop the optimization process.
            return Ok(Transformed::no(plan));
        }
        LogicalPlan::Join(join) => {
            // Route the parent requirements to the left and/or right child
            // depending on the join type, then add the columns used by the
            // join's own expressions.
            let left_len = join.left.schema().fields().len();
            let (left_req_indices, right_req_indices) =
                split_join_requirements(left_len, indices, &join.join_type);
            let left_indices =
                left_req_indices.with_plan_exprs(&plan, join.left.schema())?;
            let right_indices =
                right_req_indices.with_plan_exprs(&plan, join.right.schema())?;
            // Joins benefit from "small" input tables (lower memory usage).
            // Therefore, each child benefits from projection:
            vec![
                left_indices.with_projection_beneficial(),
                right_indices.with_projection_beneficial(),
            ]
        }
        // these nodes are explicitly rewritten in the match statement above
        LogicalPlan::Projection(_)
        | LogicalPlan::Aggregate(_)
        | LogicalPlan::Window(_)
        | LogicalPlan::TableScan(_) => {
            return internal_err!(
                "OptimizeProjection: should have handled in the match statement above"
            );
        }
        LogicalPlan::Unnest(Unnest {
            input,
            dependency_indices,
            ..
        }) => {
            // at least provide the indices for the exec-columns as a starting point
            let required_indices =
                RequiredIndices::new().with_plan_exprs(&plan, input.schema())?;

            // Add additional required indices from the parent: each required
            // output index is looked up in `dependency_indices`; output
            // indices without an entry are ignored.
            let mut additional_necessary_child_indices = Vec::new();
            indices.indices().iter().for_each(|idx| {
                if let Some(index) = dependency_indices.get(*idx) {
                    additional_necessary_child_indices.push(*index);
                }
            });
            vec![required_indices.append(&additional_necessary_child_indices)]
        }
    };

    // Required indices are currently ordered (child0, child1, ...)
    // but the loop pops off the last element, so we need to reverse the order
    child_required_indices.reverse();
    if child_required_indices.len() != plan.inputs().len() {
        return internal_err!(
            "OptimizeProjection: child_required_indices length mismatch with plan inputs"
        );
    }

    // Rewrite children of the plan
    let transformed_plan = plan.map_children(|child| {
        let required_indices = child_required_indices.pop().ok_or_else(|| {
            internal_datafusion_err!(
                "Unexpected number of required_indices in OptimizeProjections rule"
            )
        })?;

        // Both values are derived before `child` and `required_indices` are
        // moved into the recursive call below.
        let projection_beneficial = required_indices.projection_beneficial();
        let project_exprs = required_indices.get_required_exprs(child.schema());

        optimize_projections(child, config, required_indices)?.transform_data(
            |new_input| {
                if projection_beneficial {
                    add_projection_on_top_if_helpful(new_input, project_exprs)
                } else {
                    Ok(Transformed::no(new_input))
                }
            },
        )
    })?;

    // If any of the children are transformed, we need to potentially update the plan's schema
    if transformed_plan.transformed {
        transformed_plan.map_data(|plan| plan.recompute_schema())
    } else {
        Ok(transformed_plan)
    }
}
436
437/// Merges consecutive projections.
438///
439/// Given a projection `proj`, this function attempts to merge it with a previous
440/// projection if it exists and if merging is beneficial. Merging is considered
441/// beneficial when expressions in the current projection are non-trivial and
442/// appear more than once in its input fields. This can act as a caching mechanism
443/// for non-trivial computations.
444///
445/// # Parameters
446///
447/// * `proj` - A reference to the `Projection` to be merged.
448///
449/// # Returns
450///
451/// A `Result` object with the following semantics:
452///
453/// - `Ok(Some(Projection))`: Merge was beneficial and successful. Contains the
454///   merged projection.
455/// - `Ok(None)`: Signals that merge is not beneficial (and has not taken place).
456/// - `Err(error)`: An error occurred during the function call.
457fn merge_consecutive_projections(proj: Projection) -> Result<Transformed<Projection>> {
458    let Projection {
459        expr,
460        input,
461        schema,
462        ..
463    } = proj;
464    let LogicalPlan::Projection(prev_projection) = input.as_ref() else {
465        return Projection::try_new_with_schema(expr, input, schema).map(Transformed::no);
466    };
467
468    // A fast path: if the previous projection is same as the current projection
469    // we can directly remove the current projection and return child projection.
470    if prev_projection.expr == expr {
471        return Projection::try_new_with_schema(
472            expr,
473            Arc::clone(&prev_projection.input),
474            schema,
475        )
476        .map(Transformed::yes);
477    }
478
479    // Count usages (referrals) of each projection expression in its input fields:
480    let mut column_referral_map = HashMap::<&Column, usize>::new();
481    expr.iter()
482        .for_each(|expr| expr.add_column_ref_counts(&mut column_referral_map));
483
484    // If an expression is non-trivial and appears more than once, do not merge
485    // them as consecutive projections will benefit from a compute-once approach.
486    // For details, see: https://github.com/apache/datafusion/issues/8296
487    if column_referral_map.into_iter().any(|(col, usage)| {
488        usage > 1
489            && !is_expr_trivial(
490                &prev_projection.expr
491                    [prev_projection.schema.index_of_column(col).unwrap()],
492            )
493    }) {
494        // no change
495        return Projection::try_new_with_schema(expr, input, schema).map(Transformed::no);
496    }
497
498    let LogicalPlan::Projection(prev_projection) = Arc::unwrap_or_clone(input) else {
499        // We know it is a `LogicalPlan::Projection` from check above
500        unreachable!();
501    };
502
503    // Try to rewrite the expressions in the current projection using the
504    // previous projection as input:
505    let name_preserver = NamePreserver::new_for_projection();
506    let mut original_names = vec![];
507    let new_exprs = expr.map_elements(|expr| {
508        original_names.push(name_preserver.save(&expr));
509
510        // do not rewrite top level Aliases (rewriter will remove all aliases within exprs)
511        match expr {
512            Expr::Alias(Alias {
513                expr,
514                relation,
515                name,
516                metadata,
517            }) => rewrite_expr(*expr, &prev_projection).map(|result| {
518                result.update_data(|expr| {
519                    Expr::Alias(Alias::new(expr, relation, name).with_metadata(metadata))
520                })
521            }),
522            e => rewrite_expr(e, &prev_projection),
523        }
524    })?;
525
526    // if the expressions could be rewritten, create a new projection with the
527    // new expressions
528    if new_exprs.transformed {
529        // Add any needed aliases back to the expressions
530        let new_exprs = new_exprs
531            .data
532            .into_iter()
533            .zip(original_names)
534            .map(|(expr, original_name)| original_name.restore(expr))
535            .collect::<Vec<_>>();
536        Projection::try_new(new_exprs, prev_projection.input).map(Transformed::yes)
537    } else {
538        // not rewritten, so put the projection back together
539        let input = Arc::new(LogicalPlan::Projection(prev_projection));
540        Projection::try_new_with_schema(new_exprs.data, input, schema)
541            .map(Transformed::no)
542    }
543}
544
545// Check whether `expr` is trivial; i.e. it doesn't imply any computation.
546fn is_expr_trivial(expr: &Expr) -> bool {
547    matches!(expr, Expr::Column(_) | Expr::Literal(_, _))
548}
549
550/// Rewrites a projection expression using the projection before it (i.e. its input)
551/// This is a subroutine to the `merge_consecutive_projections` function.
552///
553/// # Parameters
554///
555/// * `expr` - A reference to the expression to rewrite.
556/// * `input` - A reference to the input of the projection expression (itself
557///   a projection).
558///
559/// # Returns
560///
561/// A `Result` object with the following semantics:
562///
563/// - `Ok(Some(Expr))`: Rewrite was successful. Contains the rewritten result.
564/// - `Ok(None)`: Signals that `expr` can not be rewritten.
565/// - `Err(error)`: An error occurred during the function call.
566///
567/// # Notes
568/// This rewrite also removes any unnecessary layers of aliasing.
569///
570/// Without trimming, we can end up with unnecessary indirections inside expressions
571/// during projection merges.
572///
573/// Consider:
574///
575/// ```text
576/// Projection(a1 + b1 as sum1)
577/// --Projection(a as a1, b as b1)
578/// ----Source(a, b)
579/// ```
580///
581/// After merge, we want to produce:
582///
583/// ```text
584/// Projection(a + b as sum1)
585/// --Source(a, b)
586/// ```
587///
588/// Without trimming, we would end up with:
589///
590/// ```text
591/// Projection((a as a1 + b as b1) as sum1)
592/// --Source(a, b)
593/// ```
594fn rewrite_expr(expr: Expr, input: &Projection) -> Result<Transformed<Expr>> {
595    expr.transform_up(|expr| {
596        match expr {
597            //  remove any intermediate aliases if they do not carry metadata
598            Expr::Alias(alias) => {
599                match alias
600                    .metadata
601                    .as_ref()
602                    .map(|h| h.is_empty())
603                    .unwrap_or(true)
604                {
605                    true => Ok(Transformed::yes(*alias.expr)),
606                    false => Ok(Transformed::no(Expr::Alias(alias))),
607                }
608            }
609            Expr::Column(col) => {
610                // Find index of column:
611                let idx = input.schema.index_of_column(&col)?;
612                // get the corresponding unaliased input expression
613                //
614                // For example:
615                // * the input projection is [`a + b` as c, `d + e` as f]
616                // * the current column is an expression "f"
617                //
618                // return the expression `d + e` (not `d + e` as f)
619                let input_expr = input.expr[idx].clone().unalias_nested().data;
620                Ok(Transformed::yes(input_expr))
621            }
622            // Unsupported type for consecutive projection merge analysis.
623            _ => Ok(Transformed::no(expr)),
624        }
625    })
626}
627
628/// Accumulates outer-referenced columns by the
629/// given expression, `expr`.
630///
631/// # Parameters
632///
633/// * `expr` - The expression to analyze for outer-referenced columns.
634/// * `columns` - A mutable reference to a `HashSet<Column>` where detected
635///   columns are collected.
636fn outer_columns<'a>(expr: &'a Expr, columns: &mut HashSet<&'a Column>) {
637    // inspect_expr_pre doesn't handle subquery references, so find them explicitly
638    expr.apply(|expr| {
639        match expr {
640            Expr::OuterReferenceColumn(_, col) => {
641                columns.insert(col);
642            }
643            Expr::ScalarSubquery(subquery) => {
644                outer_columns_helper_multi(&subquery.outer_ref_columns, columns);
645            }
646            Expr::Exists(exists) => {
647                outer_columns_helper_multi(&exists.subquery.outer_ref_columns, columns);
648            }
649            Expr::InSubquery(insubquery) => {
650                outer_columns_helper_multi(
651                    &insubquery.subquery.outer_ref_columns,
652                    columns,
653                );
654            }
655            _ => {}
656        };
657        Ok(TreeNodeRecursion::Continue)
658    })
659    // unwrap: closure above never returns Err, so can not be Err here
660    .unwrap();
661}
662
663/// A recursive subroutine that accumulates outer-referenced columns by the
664/// given expressions (`exprs`).
665///
666/// # Parameters
667///
668/// * `exprs` - The expressions to analyze for outer-referenced columns.
669/// * `columns` - A mutable reference to a `HashSet<Column>` where detected
670///   columns are collected.
671fn outer_columns_helper_multi<'a, 'b>(
672    exprs: impl IntoIterator<Item = &'a Expr>,
673    columns: &'b mut HashSet<&'a Column>,
674) {
675    exprs.into_iter().for_each(|e| outer_columns(e, columns));
676}
677
678/// Splits requirement indices for a join into left and right children based on
679/// the join type.
680///
681/// This function takes the length of the left child, a slice of requirement
682/// indices, and the type of join (e.g. `INNER`, `LEFT`, `RIGHT`) as arguments.
683/// Depending on the join type, it divides the requirement indices into those
684/// that apply to the left child and those that apply to the right child.
685///
686/// - For `INNER`, `LEFT`, `RIGHT`, `FULL`, `LEFTMARK`, and `RIGHTMARK` joins,
687///   the requirements are split between left and right children. The right
688///   child indices are adjusted to point to valid positions within the right
689///   child by subtracting the length of the left child.
690///
691/// - For `LEFT ANTI`, `LEFT SEMI`, `RIGHT SEMI` and `RIGHT ANTI` joins, all
692///   requirements are re-routed to either the left child or the right child
693///   directly, depending on the join type.
694///
695/// # Parameters
696///
697/// * `left_len` - The length of the left child.
698/// * `indices` - A slice of requirement indices.
699/// * `join_type` - The type of join (e.g. `INNER`, `LEFT`, `RIGHT`).
700///
701/// # Returns
702///
703/// A tuple containing two vectors of `usize` indices: The first vector represents
704/// the requirements for the left child, and the second vector represents the
705/// requirements for the right child. The indices are appropriately split and
706/// adjusted based on the join type.
707fn split_join_requirements(
708    left_len: usize,
709    indices: RequiredIndices,
710    join_type: &JoinType,
711) -> (RequiredIndices, RequiredIndices) {
712    match join_type {
713        // In these cases requirements are split between left/right children:
714        JoinType::Inner
715        | JoinType::Left
716        | JoinType::Right
717        | JoinType::Full
718        | JoinType::LeftMark
719        | JoinType::RightMark => {
720            // Decrease right side indices by `left_len` so that they point to valid
721            // positions within the right child:
722            indices.split_off(left_len)
723        }
724        // All requirements can be re-routed to left child directly.
725        JoinType::LeftAnti | JoinType::LeftSemi => (indices, RequiredIndices::new()),
726        // All requirements can be re-routed to right side directly.
727        // No need to change index, join schema is right child schema.
728        JoinType::RightSemi | JoinType::RightAnti => (RequiredIndices::new(), indices),
729    }
730}
731
732/// Adds a projection on top of a logical plan if doing so reduces the number
733/// of columns for the parent operator.
734///
735/// This function takes a `LogicalPlan` and a list of projection expressions.
736/// If the projection is beneficial (it reduces the number of columns in the
737/// plan) a new `LogicalPlan` with the projection is created and returned, along
738/// with a `true` flag. If the projection doesn't reduce the number of columns,
739/// the original plan is returned with a `false` flag.
740///
741/// # Parameters
742///
743/// * `plan` - The input `LogicalPlan` to potentially add a projection to.
744/// * `project_exprs` - A list of expressions for the projection.
745///
746/// # Returns
747///
748/// A `Transformed` indicating if a projection was added
749fn add_projection_on_top_if_helpful(
750    plan: LogicalPlan,
751    project_exprs: Vec<Expr>,
752) -> Result<Transformed<LogicalPlan>> {
753    // Make sure projection decreases the number of columns, otherwise it is unnecessary.
754    if project_exprs.len() >= plan.schema().fields().len() {
755        Ok(Transformed::no(plan))
756    } else {
757        Projection::try_new(project_exprs, Arc::new(plan))
758            .map(LogicalPlan::Projection)
759            .map(Transformed::yes)
760    }
761}
762
763/// Rewrite the given projection according to the fields required by its
764/// ancestors.
765///
766/// # Parameters
767///
768/// * `proj` - A reference to the original projection to rewrite.
769/// * `config` - A reference to the optimizer configuration.
770/// * `indices` - A slice of indices representing the columns required by the
771///   ancestors of the given projection.
772///
773/// # Returns
774///
775/// A `Result` object with the following semantics:
776///
777/// - `Ok(Some(LogicalPlan))`: Contains the rewritten projection
778/// - `Ok(None)`: No rewrite necessary.
779/// - `Err(error)`: An error occurred during the function call.
780fn rewrite_projection_given_requirements(
781    proj: Projection,
782    config: &dyn OptimizerConfig,
783    indices: &RequiredIndices,
784) -> Result<Transformed<LogicalPlan>> {
785    let Projection { expr, input, .. } = proj;
786
787    let exprs_used = indices.get_at_indices(&expr);
788
789    let required_indices =
790        RequiredIndices::new().with_exprs(input.schema(), exprs_used.iter());
791
792    // rewrite the children projection, and if they are changed rewrite the
793    // projection down
794    optimize_projections(Arc::unwrap_or_clone(input), config, required_indices)?
795        .transform_data(|input| {
796            if is_projection_unnecessary(&input, &exprs_used)? {
797                Ok(Transformed::yes(input))
798            } else {
799                Projection::try_new(exprs_used, Arc::new(input))
800                    .map(LogicalPlan::Projection)
801                    .map(Transformed::yes)
802            }
803        })
804}
805
806/// Projection is unnecessary, when
807/// - input schema of the projection, output schema of the projection are same, and
808/// - all projection expressions are either Column or Literal
809pub fn is_projection_unnecessary(
810    input: &LogicalPlan,
811    proj_exprs: &[Expr],
812) -> Result<bool> {
813    // First check if the number of expressions is equal to the number of fields in the input schema.
814    if proj_exprs.len() != input.schema().fields().len() {
815        return Ok(false);
816    }
817    Ok(input.schema().iter().zip(proj_exprs.iter()).all(
818        |((field_relation, field_name), expr)| {
819            // Check if the expression is a column and if it matches the field name
820            if let Expr::Column(col) = expr {
821                col.relation.as_ref() == field_relation && col.name.eq(field_name.name())
822            } else {
823                false
824            }
825        },
826    ))
827}
828
829#[cfg(test)]
830mod tests {
831    use std::cmp::Ordering;
832    use std::collections::HashMap;
833    use std::fmt::Formatter;
834    use std::ops::Add;
835    use std::sync::Arc;
836    use std::vec;
837
838    use crate::optimize_projections::OptimizeProjections;
839    use crate::optimizer::Optimizer;
840    use crate::test::{
841        assert_fields_eq, scan_empty, test_table_scan, test_table_scan_fields,
842        test_table_scan_with_name,
843    };
844    use crate::{OptimizerContext, OptimizerRule};
845    use arrow::datatypes::{DataType, Field, Schema};
846    use datafusion_common::{
847        Column, DFSchema, DFSchemaRef, JoinType, Result, TableReference,
848    };
849    use datafusion_expr::ExprFunctionExt;
850    use datafusion_expr::{
851        binary_expr, build_join_schema,
852        builder::table_scan_with_filters,
853        col,
854        expr::{self, Cast},
855        lit,
856        logical_plan::{builder::LogicalPlanBuilder, table_scan},
857        not, try_cast, when, BinaryExpr, Expr, Extension, Like, LogicalPlan, Operator,
858        Projection, UserDefinedLogicalNodeCore, WindowFunctionDefinition,
859    };
860    use insta::assert_snapshot;
861
862    use crate::assert_optimized_plan_eq_snapshot;
863    use datafusion_functions_aggregate::count::count_udaf;
864    use datafusion_functions_aggregate::expr_fn::{count, max, min};
865    use datafusion_functions_aggregate::min_max::max_udaf;
866
867    macro_rules! assert_optimized_plan_equal {
868        (
869            $plan:expr,
870            @ $expected:literal $(,)?
871        ) => {{
872            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
873            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(OptimizeProjections::new())];
874            assert_optimized_plan_eq_snapshot!(
875                optimizer_ctx,
876                rules,
877                $plan,
878                @ $expected,
879            )
880        }};
881    }
882
883    #[derive(Debug, Hash, PartialEq, Eq)]
884    struct NoOpUserDefined {
885        exprs: Vec<Expr>,
886        schema: DFSchemaRef,
887        input: Arc<LogicalPlan>,
888    }
889
890    impl NoOpUserDefined {
891        fn new(schema: DFSchemaRef, input: Arc<LogicalPlan>) -> Self {
892            Self {
893                exprs: vec![],
894                schema,
895                input,
896            }
897        }
898
899        fn with_exprs(mut self, exprs: Vec<Expr>) -> Self {
900            self.exprs = exprs;
901            self
902        }
903    }
904
905    // Manual implementation needed because of `schema` field. Comparison excludes this field.
906    impl PartialOrd for NoOpUserDefined {
907        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
908            match self.exprs.partial_cmp(&other.exprs) {
909                Some(Ordering::Equal) => self.input.partial_cmp(&other.input),
910                cmp => cmp,
911            }
912        }
913    }
914
915    impl UserDefinedLogicalNodeCore for NoOpUserDefined {
916        fn name(&self) -> &str {
917            "NoOpUserDefined"
918        }
919
920        fn inputs(&self) -> Vec<&LogicalPlan> {
921            vec![&self.input]
922        }
923
924        fn schema(&self) -> &DFSchemaRef {
925            &self.schema
926        }
927
928        fn expressions(&self) -> Vec<Expr> {
929            self.exprs.clone()
930        }
931
932        fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
933            write!(f, "NoOpUserDefined")
934        }
935
936        fn with_exprs_and_inputs(
937            &self,
938            exprs: Vec<Expr>,
939            mut inputs: Vec<LogicalPlan>,
940        ) -> Result<Self> {
941            Ok(Self {
942                exprs,
943                input: Arc::new(inputs.swap_remove(0)),
944                schema: Arc::clone(&self.schema),
945            })
946        }
947
948        fn necessary_children_exprs(
949            &self,
950            output_columns: &[usize],
951        ) -> Option<Vec<Vec<usize>>> {
952            // Since schema is same. Output columns requires their corresponding version in the input columns.
953            Some(vec![output_columns.to_vec()])
954        }
955
956        fn supports_limit_pushdown(&self) -> bool {
957            false // Disallow limit push-down by default
958        }
959    }
960
961    #[derive(Debug, Hash, PartialEq, Eq)]
962    struct UserDefinedCrossJoin {
963        exprs: Vec<Expr>,
964        schema: DFSchemaRef,
965        left_child: Arc<LogicalPlan>,
966        right_child: Arc<LogicalPlan>,
967    }
968
969    impl UserDefinedCrossJoin {
970        fn new(left_child: Arc<LogicalPlan>, right_child: Arc<LogicalPlan>) -> Self {
971            let left_schema = left_child.schema();
972            let right_schema = right_child.schema();
973            let schema = Arc::new(
974                build_join_schema(left_schema, right_schema, &JoinType::Inner).unwrap(),
975            );
976            Self {
977                exprs: vec![],
978                schema,
979                left_child,
980                right_child,
981            }
982        }
983    }
984
985    // Manual implementation needed because of `schema` field. Comparison excludes this field.
986    impl PartialOrd for UserDefinedCrossJoin {
987        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
988            match self.exprs.partial_cmp(&other.exprs) {
989                Some(Ordering::Equal) => {
990                    match self.left_child.partial_cmp(&other.left_child) {
991                        Some(Ordering::Equal) => {
992                            self.right_child.partial_cmp(&other.right_child)
993                        }
994                        cmp => cmp,
995                    }
996                }
997                cmp => cmp,
998            }
999        }
1000    }
1001
1002    impl UserDefinedLogicalNodeCore for UserDefinedCrossJoin {
1003        fn name(&self) -> &str {
1004            "UserDefinedCrossJoin"
1005        }
1006
1007        fn inputs(&self) -> Vec<&LogicalPlan> {
1008            vec![&self.left_child, &self.right_child]
1009        }
1010
1011        fn schema(&self) -> &DFSchemaRef {
1012            &self.schema
1013        }
1014
1015        fn expressions(&self) -> Vec<Expr> {
1016            self.exprs.clone()
1017        }
1018
1019        fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
1020            write!(f, "UserDefinedCrossJoin")
1021        }
1022
1023        fn with_exprs_and_inputs(
1024            &self,
1025            exprs: Vec<Expr>,
1026            mut inputs: Vec<LogicalPlan>,
1027        ) -> Result<Self> {
1028            assert_eq!(inputs.len(), 2);
1029            Ok(Self {
1030                exprs,
1031                left_child: Arc::new(inputs.remove(0)),
1032                right_child: Arc::new(inputs.remove(0)),
1033                schema: Arc::clone(&self.schema),
1034            })
1035        }
1036
1037        fn necessary_children_exprs(
1038            &self,
1039            output_columns: &[usize],
1040        ) -> Option<Vec<Vec<usize>>> {
1041            let left_child_len = self.left_child.schema().fields().len();
1042            let mut left_reqs = vec![];
1043            let mut right_reqs = vec![];
1044            for &out_idx in output_columns {
1045                if out_idx < left_child_len {
1046                    left_reqs.push(out_idx);
1047                } else {
1048                    // Output indices further than the left_child_len
1049                    // comes from right children
1050                    right_reqs.push(out_idx - left_child_len)
1051                }
1052            }
1053            Some(vec![left_reqs, right_reqs])
1054        }
1055
1056        fn supports_limit_pushdown(&self) -> bool {
1057            false // Disallow limit push-down by default
1058        }
1059    }
1060
1061    #[test]
1062    fn merge_two_projection() -> Result<()> {
1063        let table_scan = test_table_scan()?;
1064        let plan = LogicalPlanBuilder::from(table_scan)
1065            .project(vec![col("a")])?
1066            .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])?
1067            .build()?;
1068
1069        assert_optimized_plan_equal!(
1070            plan,
1071            @r"
1072        Projection: Int32(1) + test.a
1073          TableScan: test projection=[a]
1074        "
1075        )
1076    }
1077
1078    #[test]
1079    fn merge_three_projection() -> Result<()> {
1080        let table_scan = test_table_scan()?;
1081        let plan = LogicalPlanBuilder::from(table_scan)
1082            .project(vec![col("a"), col("b")])?
1083            .project(vec![col("a")])?
1084            .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])?
1085            .build()?;
1086
1087        assert_optimized_plan_equal!(
1088            plan,
1089            @r"
1090        Projection: Int32(1) + test.a
1091          TableScan: test projection=[a]
1092        "
1093        )
1094    }
1095
1096    #[test]
1097    fn merge_alias() -> Result<()> {
1098        let table_scan = test_table_scan()?;
1099        let plan = LogicalPlanBuilder::from(table_scan)
1100            .project(vec![col("a")])?
1101            .project(vec![col("a").alias("alias")])?
1102            .build()?;
1103
1104        assert_optimized_plan_equal!(
1105            plan,
1106            @r"
1107        Projection: test.a AS alias
1108          TableScan: test projection=[a]
1109        "
1110        )
1111    }
1112
1113    #[test]
1114    fn merge_nested_alias() -> Result<()> {
1115        let table_scan = test_table_scan()?;
1116        let plan = LogicalPlanBuilder::from(table_scan)
1117            .project(vec![col("a").alias("alias1").alias("alias2")])?
1118            .project(vec![col("alias2").alias("alias")])?
1119            .build()?;
1120
1121        assert_optimized_plan_equal!(
1122            plan,
1123            @r"
1124        Projection: test.a AS alias
1125          TableScan: test projection=[a]
1126        "
1127        )
1128    }
1129
1130    #[test]
1131    fn test_nested_count() -> Result<()> {
1132        let schema = Schema::new(vec![Field::new("foo", DataType::Int32, false)]);
1133
1134        let groups: Vec<Expr> = vec![];
1135
1136        let plan = table_scan(TableReference::none(), &schema, None)
1137            .unwrap()
1138            .aggregate(groups.clone(), vec![count(lit(1))])
1139            .unwrap()
1140            .aggregate(groups, vec![count(lit(1))])
1141            .unwrap()
1142            .build()
1143            .unwrap();
1144
1145        assert_optimized_plan_equal!(
1146            plan,
1147            @r"
1148        Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]]
1149          Projection:
1150            Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]]
1151              TableScan: ?table? projection=[]
1152        "
1153        )
1154    }
1155
1156    #[test]
1157    fn test_neg_push_down() -> Result<()> {
1158        let table_scan = test_table_scan()?;
1159        let plan = LogicalPlanBuilder::from(table_scan)
1160            .project(vec![-col("a")])?
1161            .build()?;
1162
1163        assert_optimized_plan_equal!(
1164            plan,
1165            @r"
1166        Projection: (- test.a)
1167          TableScan: test projection=[a]
1168        "
1169        )
1170    }
1171
1172    #[test]
1173    fn test_is_null() -> Result<()> {
1174        let table_scan = test_table_scan()?;
1175        let plan = LogicalPlanBuilder::from(table_scan)
1176            .project(vec![col("a").is_null()])?
1177            .build()?;
1178
1179        assert_optimized_plan_equal!(
1180            plan,
1181            @r"
1182        Projection: test.a IS NULL
1183          TableScan: test projection=[a]
1184        "
1185        )
1186    }
1187
1188    #[test]
1189    fn test_is_not_null() -> Result<()> {
1190        let table_scan = test_table_scan()?;
1191        let plan = LogicalPlanBuilder::from(table_scan)
1192            .project(vec![col("a").is_not_null()])?
1193            .build()?;
1194
1195        assert_optimized_plan_equal!(
1196            plan,
1197            @r"
1198        Projection: test.a IS NOT NULL
1199          TableScan: test projection=[a]
1200        "
1201        )
1202    }
1203
1204    #[test]
1205    fn test_is_true() -> Result<()> {
1206        let table_scan = test_table_scan()?;
1207        let plan = LogicalPlanBuilder::from(table_scan)
1208            .project(vec![col("a").is_true()])?
1209            .build()?;
1210
1211        assert_optimized_plan_equal!(
1212            plan,
1213            @r"
1214        Projection: test.a IS TRUE
1215          TableScan: test projection=[a]
1216        "
1217        )
1218    }
1219
1220    #[test]
1221    fn test_is_not_true() -> Result<()> {
1222        let table_scan = test_table_scan()?;
1223        let plan = LogicalPlanBuilder::from(table_scan)
1224            .project(vec![col("a").is_not_true()])?
1225            .build()?;
1226
1227        assert_optimized_plan_equal!(
1228            plan,
1229            @r"
1230        Projection: test.a IS NOT TRUE
1231          TableScan: test projection=[a]
1232        "
1233        )
1234    }
1235
1236    #[test]
1237    fn test_is_false() -> Result<()> {
1238        let table_scan = test_table_scan()?;
1239        let plan = LogicalPlanBuilder::from(table_scan)
1240            .project(vec![col("a").is_false()])?
1241            .build()?;
1242
1243        assert_optimized_plan_equal!(
1244            plan,
1245            @r"
1246        Projection: test.a IS FALSE
1247          TableScan: test projection=[a]
1248        "
1249        )
1250    }
1251
1252    #[test]
1253    fn test_is_not_false() -> Result<()> {
1254        let table_scan = test_table_scan()?;
1255        let plan = LogicalPlanBuilder::from(table_scan)
1256            .project(vec![col("a").is_not_false()])?
1257            .build()?;
1258
1259        assert_optimized_plan_equal!(
1260            plan,
1261            @r"
1262        Projection: test.a IS NOT FALSE
1263          TableScan: test projection=[a]
1264        "
1265        )
1266    }
1267
1268    #[test]
1269    fn test_is_unknown() -> Result<()> {
1270        let table_scan = test_table_scan()?;
1271        let plan = LogicalPlanBuilder::from(table_scan)
1272            .project(vec![col("a").is_unknown()])?
1273            .build()?;
1274
1275        assert_optimized_plan_equal!(
1276            plan,
1277            @r"
1278        Projection: test.a IS UNKNOWN
1279          TableScan: test projection=[a]
1280        "
1281        )
1282    }
1283
1284    #[test]
1285    fn test_is_not_unknown() -> Result<()> {
1286        let table_scan = test_table_scan()?;
1287        let plan = LogicalPlanBuilder::from(table_scan)
1288            .project(vec![col("a").is_not_unknown()])?
1289            .build()?;
1290
1291        assert_optimized_plan_equal!(
1292            plan,
1293            @r"
1294        Projection: test.a IS NOT UNKNOWN
1295          TableScan: test projection=[a]
1296        "
1297        )
1298    }
1299
1300    #[test]
1301    fn test_not() -> Result<()> {
1302        let table_scan = test_table_scan()?;
1303        let plan = LogicalPlanBuilder::from(table_scan)
1304            .project(vec![not(col("a"))])?
1305            .build()?;
1306
1307        assert_optimized_plan_equal!(
1308            plan,
1309            @r"
1310        Projection: NOT test.a
1311          TableScan: test projection=[a]
1312        "
1313        )
1314    }
1315
1316    #[test]
1317    fn test_try_cast() -> Result<()> {
1318        let table_scan = test_table_scan()?;
1319        let plan = LogicalPlanBuilder::from(table_scan)
1320            .project(vec![try_cast(col("a"), DataType::Float64)])?
1321            .build()?;
1322
1323        assert_optimized_plan_equal!(
1324            plan,
1325            @r"
1326        Projection: TRY_CAST(test.a AS Float64)
1327          TableScan: test projection=[a]
1328        "
1329        )
1330    }
1331
1332    #[test]
1333    fn test_similar_to() -> Result<()> {
1334        let table_scan = test_table_scan()?;
1335        let expr = Box::new(col("a"));
1336        let pattern = Box::new(lit("[0-9]"));
1337        let similar_to_expr =
1338            Expr::SimilarTo(Like::new(false, expr, pattern, None, false));
1339        let plan = LogicalPlanBuilder::from(table_scan)
1340            .project(vec![similar_to_expr])?
1341            .build()?;
1342
1343        assert_optimized_plan_equal!(
1344            plan,
1345            @r#"
1346        Projection: test.a SIMILAR TO Utf8("[0-9]")
1347          TableScan: test projection=[a]
1348        "#
1349        )
1350    }
1351
1352    #[test]
1353    fn test_between() -> Result<()> {
1354        let table_scan = test_table_scan()?;
1355        let plan = LogicalPlanBuilder::from(table_scan)
1356            .project(vec![col("a").between(lit(1), lit(3))])?
1357            .build()?;
1358
1359        assert_optimized_plan_equal!(
1360            plan,
1361            @r"
1362        Projection: test.a BETWEEN Int32(1) AND Int32(3)
1363          TableScan: test projection=[a]
1364        "
1365        )
1366    }
1367
1368    // Test Case expression
1369    #[test]
1370    fn test_case_merged() -> Result<()> {
1371        let table_scan = test_table_scan()?;
1372        let plan = LogicalPlanBuilder::from(table_scan)
1373            .project(vec![col("a"), lit(0).alias("d")])?
1374            .project(vec![
1375                col("a"),
1376                when(col("a").eq(lit(1)), lit(10))
1377                    .otherwise(col("d"))?
1378                    .alias("d"),
1379            ])?
1380            .build()?;
1381
1382        assert_optimized_plan_equal!(
1383            plan,
1384            @r"
1385        Projection: test.a, CASE WHEN test.a = Int32(1) THEN Int32(10) ELSE Int32(0) END AS d
1386          TableScan: test projection=[a]
1387        "
1388        )
1389    }
1390
1391    // Test outer projection isn't discarded despite the same schema as inner
1392    // https://github.com/apache/datafusion/issues/8942
1393    #[test]
1394    fn test_derived_column() -> Result<()> {
1395        let table_scan = test_table_scan()?;
1396        let plan = LogicalPlanBuilder::from(table_scan)
1397            .project(vec![col("a").add(lit(1)).alias("a"), lit(0).alias("d")])?
1398            .project(vec![
1399                col("a"),
1400                when(col("a").eq(lit(1)), lit(10))
1401                    .otherwise(col("d"))?
1402                    .alias("d"),
1403            ])?
1404            .build()?;
1405
1406        assert_optimized_plan_equal!(
1407            plan,
1408            @r"
1409        Projection: a, CASE WHEN a = Int32(1) THEN Int32(10) ELSE d END AS d
1410          Projection: test.a + Int32(1) AS a, Int32(0) AS d
1411            TableScan: test projection=[a]
1412        "
1413        )
1414    }
1415
1416    // Since only column `a` is referred at the output. Scan should only contain projection=[a].
1417    // User defined node should be able to propagate necessary expressions by its parent to its child.
1418    #[test]
1419    fn test_user_defined_logical_plan_node() -> Result<()> {
1420        let table_scan = test_table_scan()?;
1421        let custom_plan = LogicalPlan::Extension(Extension {
1422            node: Arc::new(NoOpUserDefined::new(
1423                Arc::clone(table_scan.schema()),
1424                Arc::new(table_scan.clone()),
1425            )),
1426        });
1427        let plan = LogicalPlanBuilder::from(custom_plan)
1428            .project(vec![col("a"), lit(0).alias("d")])?
1429            .build()?;
1430
1431        assert_optimized_plan_equal!(
1432            plan,
1433            @r"
1434        Projection: test.a, Int32(0) AS d
1435          NoOpUserDefined
1436            TableScan: test projection=[a]
1437        "
1438        )
1439    }
1440
1441    // Only column `a` is referred at the output. However, User defined node itself uses column `b`
1442    // during its operation. Hence, scan should contain projection=[a, b].
1443    // User defined node should be able to propagate necessary expressions by its parent, as well as its own
1444    // required expressions.
1445    #[test]
1446    fn test_user_defined_logical_plan_node2() -> Result<()> {
1447        let table_scan = test_table_scan()?;
1448        let exprs = vec![Expr::Column(Column::from_qualified_name("b"))];
1449        let custom_plan = LogicalPlan::Extension(Extension {
1450            node: Arc::new(
1451                NoOpUserDefined::new(
1452                    Arc::clone(table_scan.schema()),
1453                    Arc::new(table_scan.clone()),
1454                )
1455                .with_exprs(exprs),
1456            ),
1457        });
1458        let plan = LogicalPlanBuilder::from(custom_plan)
1459            .project(vec![col("a"), lit(0).alias("d")])?
1460            .build()?;
1461
1462        assert_optimized_plan_equal!(
1463            plan,
1464            @r"
1465        Projection: test.a, Int32(0) AS d
1466          NoOpUserDefined
1467            TableScan: test projection=[a, b]
1468        "
1469        )
1470    }
1471
1472    // Only column `a` is referred at the output. However, User defined node itself uses expression `b+c`
1473    // during its operation. Hence, scan should contain projection=[a, b, c].
1474    // User defined node should be able to propagate necessary expressions by its parent, as well as its own
1475    // required expressions. Expressions doesn't have to be just column. Requirements from complex expressions
1476    // should be propagated also.
1477    #[test]
1478    fn test_user_defined_logical_plan_node3() -> Result<()> {
1479        let table_scan = test_table_scan()?;
1480        let left_expr = Expr::Column(Column::from_qualified_name("b"));
1481        let right_expr = Expr::Column(Column::from_qualified_name("c"));
1482        let binary_expr = Expr::BinaryExpr(BinaryExpr::new(
1483            Box::new(left_expr),
1484            Operator::Plus,
1485            Box::new(right_expr),
1486        ));
1487        let exprs = vec![binary_expr];
1488        let custom_plan = LogicalPlan::Extension(Extension {
1489            node: Arc::new(
1490                NoOpUserDefined::new(
1491                    Arc::clone(table_scan.schema()),
1492                    Arc::new(table_scan.clone()),
1493                )
1494                .with_exprs(exprs),
1495            ),
1496        });
1497        let plan = LogicalPlanBuilder::from(custom_plan)
1498            .project(vec![col("a"), lit(0).alias("d")])?
1499            .build()?;
1500
1501        assert_optimized_plan_equal!(
1502            plan,
1503            @r"
1504        Projection: test.a, Int32(0) AS d
1505          NoOpUserDefined
1506            TableScan: test projection=[a, b, c]
1507        "
1508        )
1509    }
1510
1511    // Columns `l.a`, `l.c`, `r.a` is referred at the output.
1512    // User defined node should be able to propagate necessary expressions by its parent, to its children.
1513    // Even if it has multiple children.
1514    // left child should have `projection=[a, c]`, and right side should have `projection=[a]`.
1515    #[test]
1516    fn test_user_defined_logical_plan_node4() -> Result<()> {
1517        let left_table = test_table_scan_with_name("l")?;
1518        let right_table = test_table_scan_with_name("r")?;
1519        let custom_plan = LogicalPlan::Extension(Extension {
1520            node: Arc::new(UserDefinedCrossJoin::new(
1521                Arc::new(left_table),
1522                Arc::new(right_table),
1523            )),
1524        });
1525        let plan = LogicalPlanBuilder::from(custom_plan)
1526            .project(vec![col("l.a"), col("l.c"), col("r.a"), lit(0).alias("d")])?
1527            .build()?;
1528
1529        assert_optimized_plan_equal!(
1530            plan,
1531            @r"
1532        Projection: l.a, l.c, r.a, Int32(0) AS d
1533          UserDefinedCrossJoin
1534            TableScan: l projection=[a, c]
1535            TableScan: r projection=[a]
1536        "
1537        )
1538    }
1539
1540    #[test]
1541    fn aggregate_no_group_by() -> Result<()> {
1542        let table_scan = test_table_scan()?;
1543
1544        let plan = LogicalPlanBuilder::from(table_scan)
1545            .aggregate(Vec::<Expr>::new(), vec![max(col("b"))])?
1546            .build()?;
1547
1548        assert_optimized_plan_equal!(
1549            plan,
1550            @r"
1551        Aggregate: groupBy=[[]], aggr=[[max(test.b)]]
1552          TableScan: test projection=[b]
1553        "
1554        )
1555    }
1556
1557    #[test]
1558    fn aggregate_group_by() -> Result<()> {
1559        let table_scan = test_table_scan()?;
1560
1561        let plan = LogicalPlanBuilder::from(table_scan)
1562            .aggregate(vec![col("c")], vec![max(col("b"))])?
1563            .build()?;
1564
1565        assert_optimized_plan_equal!(
1566            plan,
1567            @r"
1568        Aggregate: groupBy=[[test.c]], aggr=[[max(test.b)]]
1569          TableScan: test projection=[b, c]
1570        "
1571        )
1572    }
1573
1574    #[test]
1575    fn aggregate_group_by_with_table_alias() -> Result<()> {
1576        let table_scan = test_table_scan()?;
1577
1578        let plan = LogicalPlanBuilder::from(table_scan)
1579            .alias("a")?
1580            .aggregate(vec![col("c")], vec![max(col("b"))])?
1581            .build()?;
1582
1583        assert_optimized_plan_equal!(
1584            plan,
1585            @r"
1586        Aggregate: groupBy=[[a.c]], aggr=[[max(a.b)]]
1587          SubqueryAlias: a
1588            TableScan: test projection=[b, c]
1589        "
1590        )
1591    }
1592
1593    #[test]
1594    fn aggregate_no_group_by_with_filter() -> Result<()> {
1595        let table_scan = test_table_scan()?;
1596
1597        let plan = LogicalPlanBuilder::from(table_scan)
1598            .filter(col("c").gt(lit(1)))?
1599            .aggregate(Vec::<Expr>::new(), vec![max(col("b"))])?
1600            .build()?;
1601
1602        assert_optimized_plan_equal!(
1603            plan,
1604            @r"
1605        Aggregate: groupBy=[[]], aggr=[[max(test.b)]]
1606          Projection: test.b
1607            Filter: test.c > Int32(1)
1608              TableScan: test projection=[b, c]
1609        "
1610        )
1611    }
1612
1613    #[test]
1614    fn aggregate_with_periods() -> Result<()> {
1615        let schema = Schema::new(vec![Field::new("tag.one", DataType::Utf8, false)]);
1616
1617        // Build a plan that looks as follows (note "tag.one" is a column named
1618        // "tag.one", not a column named "one" in a table named "tag"):
1619        //
1620        // Projection: tag.one
1621        //   Aggregate: groupBy=[], aggr=[max("tag.one") AS "tag.one"]
1622        //    TableScan
1623        let plan = table_scan(Some("m4"), &schema, None)?
1624            .aggregate(
1625                Vec::<Expr>::new(),
1626                vec![max(col(Column::new_unqualified("tag.one"))).alias("tag.one")],
1627            )?
1628            .project([col(Column::new_unqualified("tag.one"))])?
1629            .build()?;
1630
1631        assert_optimized_plan_equal!(
1632            plan,
1633            @r"
1634        Aggregate: groupBy=[[]], aggr=[[max(m4.tag.one) AS tag.one]]
1635          TableScan: m4 projection=[tag.one]
1636        "
1637        )
1638    }
1639
1640    #[test]
1641    fn redundant_project() -> Result<()> {
1642        let table_scan = test_table_scan()?;
1643
1644        let plan = LogicalPlanBuilder::from(table_scan)
1645            .project(vec![col("a"), col("b"), col("c")])?
1646            .project(vec![col("a"), col("c"), col("b")])?
1647            .build()?;
1648        assert_optimized_plan_equal!(
1649            plan,
1650            @r"
1651        Projection: test.a, test.c, test.b
1652          TableScan: test projection=[a, b, c]
1653        "
1654        )
1655    }
1656
1657    #[test]
1658    fn reorder_scan() -> Result<()> {
1659        let schema = Schema::new(test_table_scan_fields());
1660
1661        let plan = table_scan(Some("test"), &schema, Some(vec![1, 0, 2]))?.build()?;
1662        assert_optimized_plan_equal!(
1663            plan,
1664            @"TableScan: test projection=[b, a, c]"
1665        )
1666    }
1667
1668    #[test]
1669    fn reorder_scan_projection() -> Result<()> {
1670        let schema = Schema::new(test_table_scan_fields());
1671
1672        let plan = table_scan(Some("test"), &schema, Some(vec![1, 0, 2]))?
1673            .project(vec![col("a"), col("b")])?
1674            .build()?;
1675        assert_optimized_plan_equal!(
1676            plan,
1677            @r"
1678        Projection: test.a, test.b
1679          TableScan: test projection=[b, a]
1680        "
1681        )
1682    }
1683
1684    #[test]
1685    fn reorder_projection() -> Result<()> {
1686        let table_scan = test_table_scan()?;
1687
1688        let plan = LogicalPlanBuilder::from(table_scan)
1689            .project(vec![col("c"), col("b"), col("a")])?
1690            .build()?;
1691        assert_optimized_plan_equal!(
1692            plan,
1693            @r"
1694        Projection: test.c, test.b, test.a
1695          TableScan: test projection=[a, b, c]
1696        "
1697        )
1698    }
1699
1700    #[test]
1701    fn noncontinuous_redundant_projection() -> Result<()> {
1702        let table_scan = test_table_scan()?;
1703
1704        let plan = LogicalPlanBuilder::from(table_scan)
1705            .project(vec![col("c"), col("b"), col("a")])?
1706            .filter(col("c").gt(lit(1)))?
1707            .project(vec![col("c"), col("a"), col("b")])?
1708            .filter(col("b").gt(lit(1)))?
1709            .filter(col("a").gt(lit(1)))?
1710            .project(vec![col("a"), col("c"), col("b")])?
1711            .build()?;
1712        assert_optimized_plan_equal!(
1713            plan,
1714            @r"
1715        Projection: test.a, test.c, test.b
1716          Filter: test.a > Int32(1)
1717            Filter: test.b > Int32(1)
1718              Projection: test.c, test.a, test.b
1719                Filter: test.c > Int32(1)
1720                  Projection: test.c, test.b, test.a
1721                    TableScan: test projection=[a, b, c]
1722        "
1723        )
1724    }
1725
1726    #[test]
1727    fn join_schema_trim_full_join_column_projection() -> Result<()> {
1728        let table_scan = test_table_scan()?;
1729
1730        let schema = Schema::new(vec![Field::new("c1", DataType::UInt32, false)]);
1731        let table2_scan = scan_empty(Some("test2"), &schema, None)?.build()?;
1732
1733        let plan = LogicalPlanBuilder::from(table_scan)
1734            .join(table2_scan, JoinType::Left, (vec!["a"], vec!["c1"]), None)?
1735            .project(vec![col("a"), col("b"), col("c1")])?
1736            .build()?;
1737
1738        let optimized_plan = optimize(plan)?;
1739
1740        // make sure projections are pushed down to both table scans
1741        assert_snapshot!(
1742            optimized_plan.clone(),
1743            @r"
1744        Left Join: test.a = test2.c1
1745          TableScan: test projection=[a, b]
1746          TableScan: test2 projection=[c1]
1747        "
1748        );
1749
1750        // make sure schema for join node include both join columns
1751        let optimized_join = optimized_plan;
1752        assert_eq!(
1753            **optimized_join.schema(),
1754            DFSchema::new_with_metadata(
1755                vec![
1756                    (
1757                        Some("test".into()),
1758                        Arc::new(Field::new("a", DataType::UInt32, false))
1759                    ),
1760                    (
1761                        Some("test".into()),
1762                        Arc::new(Field::new("b", DataType::UInt32, false))
1763                    ),
1764                    (
1765                        Some("test2".into()),
1766                        Arc::new(Field::new("c1", DataType::UInt32, true))
1767                    ),
1768                ],
1769                HashMap::new()
1770            )?,
1771        );
1772
1773        Ok(())
1774    }
1775
1776    #[test]
1777    fn join_schema_trim_partial_join_column_projection() -> Result<()> {
1778        // test join column push down without explicit column projections
1779
1780        let table_scan = test_table_scan()?;
1781
1782        let schema = Schema::new(vec![Field::new("c1", DataType::UInt32, false)]);
1783        let table2_scan = scan_empty(Some("test2"), &schema, None)?.build()?;
1784
1785        let plan = LogicalPlanBuilder::from(table_scan)
1786            .join(table2_scan, JoinType::Left, (vec!["a"], vec!["c1"]), None)?
1787            // projecting joined column `a` should push the right side column `c1` projection as
1788            // well into test2 table even though `c1` is not referenced in projection.
1789            .project(vec![col("a"), col("b")])?
1790            .build()?;
1791
1792        let optimized_plan = optimize(plan)?;
1793
1794        // make sure projections are pushed down to both table scans
1795        assert_snapshot!(
1796            optimized_plan.clone(),
1797            @r"
1798        Projection: test.a, test.b
1799          Left Join: test.a = test2.c1
1800            TableScan: test projection=[a, b]
1801            TableScan: test2 projection=[c1]
1802        "
1803        );
1804
1805        // make sure schema for join node include both join columns
1806        let optimized_join = optimized_plan.inputs()[0];
1807        assert_eq!(
1808            **optimized_join.schema(),
1809            DFSchema::new_with_metadata(
1810                vec![
1811                    (
1812                        Some("test".into()),
1813                        Arc::new(Field::new("a", DataType::UInt32, false))
1814                    ),
1815                    (
1816                        Some("test".into()),
1817                        Arc::new(Field::new("b", DataType::UInt32, false))
1818                    ),
1819                    (
1820                        Some("test2".into()),
1821                        Arc::new(Field::new("c1", DataType::UInt32, true))
1822                    ),
1823                ],
1824                HashMap::new()
1825            )?,
1826        );
1827
1828        Ok(())
1829    }
1830
1831    #[test]
1832    fn join_schema_trim_using_join() -> Result<()> {
1833        // shared join columns from using join should be pushed to both sides
1834
1835        let table_scan = test_table_scan()?;
1836
1837        let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]);
1838        let table2_scan = scan_empty(Some("test2"), &schema, None)?.build()?;
1839
1840        let plan = LogicalPlanBuilder::from(table_scan)
1841            .join_using(table2_scan, JoinType::Left, vec!["a".into()])?
1842            .project(vec![col("a"), col("b")])?
1843            .build()?;
1844
1845        let optimized_plan = optimize(plan)?;
1846
1847        // make sure projections are pushed down to table scan
1848        assert_snapshot!(
1849            optimized_plan.clone(),
1850            @r"
1851        Projection: test.a, test.b
1852          Left Join: Using test.a = test2.a
1853            TableScan: test projection=[a, b]
1854            TableScan: test2 projection=[a]
1855        "
1856        );
1857
1858        // make sure schema for join node include both join columns
1859        let optimized_join = optimized_plan.inputs()[0];
1860        assert_eq!(
1861            **optimized_join.schema(),
1862            DFSchema::new_with_metadata(
1863                vec![
1864                    (
1865                        Some("test".into()),
1866                        Arc::new(Field::new("a", DataType::UInt32, false))
1867                    ),
1868                    (
1869                        Some("test".into()),
1870                        Arc::new(Field::new("b", DataType::UInt32, false))
1871                    ),
1872                    (
1873                        Some("test2".into()),
1874                        Arc::new(Field::new("a", DataType::UInt32, true))
1875                    ),
1876                ],
1877                HashMap::new()
1878            )?,
1879        );
1880
1881        Ok(())
1882    }
1883
1884    #[test]
1885    fn cast() -> Result<()> {
1886        let table_scan = test_table_scan()?;
1887
1888        let plan = LogicalPlanBuilder::from(table_scan)
1889            .project(vec![Expr::Cast(Cast::new(
1890                Box::new(col("c")),
1891                DataType::Float64,
1892            ))])?
1893            .build()?;
1894
1895        assert_optimized_plan_equal!(
1896            plan,
1897            @r"
1898        Projection: CAST(test.c AS Float64)
1899          TableScan: test projection=[c]
1900        "
1901        )
1902    }
1903
1904    #[test]
1905    fn table_scan_projected_schema() -> Result<()> {
1906        let table_scan = test_table_scan()?;
1907        let plan = LogicalPlanBuilder::from(test_table_scan()?)
1908            .project(vec![col("a"), col("b")])?
1909            .build()?;
1910
1911        assert_eq!(3, table_scan.schema().fields().len());
1912        assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
1913        assert_fields_eq(&plan, vec!["a", "b"]);
1914
1915        assert_optimized_plan_equal!(
1916            plan,
1917            @"TableScan: test projection=[a, b]"
1918        )
1919    }
1920
1921    #[test]
1922    fn table_scan_projected_schema_non_qualified_relation() -> Result<()> {
1923        let table_scan = test_table_scan()?;
1924        let input_schema = table_scan.schema();
1925        assert_eq!(3, input_schema.fields().len());
1926        assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
1927
1928        // Build the LogicalPlan directly (don't use PlanBuilder), so
1929        // that the Column references are unqualified (e.g. their
1930        // relation is `None`). PlanBuilder resolves the expressions
1931        let expr = vec![col("test.a"), col("test.b")];
1932        let plan =
1933            LogicalPlan::Projection(Projection::try_new(expr, Arc::new(table_scan))?);
1934
1935        assert_fields_eq(&plan, vec!["a", "b"]);
1936
1937        assert_optimized_plan_equal!(
1938            plan,
1939            @"TableScan: test projection=[a, b]"
1940        )
1941    }
1942
1943    #[test]
1944    fn table_limit() -> Result<()> {
1945        let table_scan = test_table_scan()?;
1946        assert_eq!(3, table_scan.schema().fields().len());
1947        assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
1948
1949        let plan = LogicalPlanBuilder::from(table_scan)
1950            .project(vec![col("c"), col("a")])?
1951            .limit(0, Some(5))?
1952            .build()?;
1953
1954        assert_fields_eq(&plan, vec!["c", "a"]);
1955
1956        assert_optimized_plan_equal!(
1957            plan,
1958            @r"
1959        Limit: skip=0, fetch=5
1960          Projection: test.c, test.a
1961            TableScan: test projection=[a, c]
1962        "
1963        )
1964    }
1965
1966    #[test]
1967    fn table_scan_without_projection() -> Result<()> {
1968        let table_scan = test_table_scan()?;
1969        let plan = LogicalPlanBuilder::from(table_scan).build()?;
1970        // should expand projection to all columns without projection
1971        assert_optimized_plan_equal!(
1972            plan,
1973            @"TableScan: test projection=[a, b, c]"
1974        )
1975    }
1976
1977    #[test]
1978    fn table_scan_with_literal_projection() -> Result<()> {
1979        let table_scan = test_table_scan()?;
1980        let plan = LogicalPlanBuilder::from(table_scan)
1981            .project(vec![lit(1_i64), lit(2_i64)])?
1982            .build()?;
1983        assert_optimized_plan_equal!(
1984            plan,
1985            @r"
1986        Projection: Int64(1), Int64(2)
1987          TableScan: test projection=[]
1988        "
1989        )
1990    }
1991
1992    /// tests that it removes unused columns in projections
1993    #[test]
1994    fn table_unused_column() -> Result<()> {
1995        let table_scan = test_table_scan()?;
1996        assert_eq!(3, table_scan.schema().fields().len());
1997        assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
1998
1999        // we never use "b" in the first projection => remove it
2000        let plan = LogicalPlanBuilder::from(table_scan)
2001            .project(vec![col("c"), col("a"), col("b")])?
2002            .filter(col("c").gt(lit(1)))?
2003            .aggregate(vec![col("c")], vec![max(col("a"))])?
2004            .build()?;
2005
2006        assert_fields_eq(&plan, vec!["c", "max(test.a)"]);
2007
2008        let plan = optimize(plan).expect("failed to optimize plan");
2009        assert_optimized_plan_equal!(
2010            plan,
2011            @r"
2012        Aggregate: groupBy=[[test.c]], aggr=[[max(test.a)]]
2013          Filter: test.c > Int32(1)
2014            Projection: test.c, test.a
2015              TableScan: test projection=[a, c]
2016        "
2017        )
2018    }
2019
2020    /// tests that it removes un-needed projections
2021    #[test]
2022    fn table_unused_projection() -> Result<()> {
2023        let table_scan = test_table_scan()?;
2024        assert_eq!(3, table_scan.schema().fields().len());
2025        assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
2026
2027        // there is no need for the first projection
2028        let plan = LogicalPlanBuilder::from(table_scan)
2029            .project(vec![col("b")])?
2030            .project(vec![lit(1).alias("a")])?
2031            .build()?;
2032
2033        assert_fields_eq(&plan, vec!["a"]);
2034
2035        assert_optimized_plan_equal!(
2036            plan,
2037            @r"
2038        Projection: Int32(1) AS a
2039          TableScan: test projection=[]
2040        "
2041        )
2042    }
2043
2044    #[test]
2045    fn table_full_filter_pushdown() -> Result<()> {
2046        let schema = Schema::new(test_table_scan_fields());
2047
2048        let table_scan = table_scan_with_filters(
2049            Some("test"),
2050            &schema,
2051            None,
2052            vec![col("b").eq(lit(1))],
2053        )?
2054        .build()?;
2055        assert_eq!(3, table_scan.schema().fields().len());
2056        assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
2057
2058        // there is no need for the first projection
2059        let plan = LogicalPlanBuilder::from(table_scan)
2060            .project(vec![col("b")])?
2061            .project(vec![lit(1).alias("a")])?
2062            .build()?;
2063
2064        assert_fields_eq(&plan, vec!["a"]);
2065
2066        assert_optimized_plan_equal!(
2067            plan,
2068            @r"
2069        Projection: Int32(1) AS a
2070          TableScan: test projection=[], full_filters=[b = Int32(1)]
2071        "
2072        )
2073    }
2074
2075    /// tests that optimizing twice yields same plan
2076    #[test]
2077    fn test_double_optimization() -> Result<()> {
2078        let table_scan = test_table_scan()?;
2079
2080        let plan = LogicalPlanBuilder::from(table_scan)
2081            .project(vec![col("b")])?
2082            .project(vec![lit(1).alias("a")])?
2083            .build()?;
2084
2085        let optimized_plan1 = optimize(plan).expect("failed to optimize plan");
2086        let optimized_plan2 =
2087            optimize(optimized_plan1.clone()).expect("failed to optimize plan");
2088
2089        let formatted_plan1 = format!("{optimized_plan1:?}");
2090        let formatted_plan2 = format!("{optimized_plan2:?}");
2091        assert_eq!(formatted_plan1, formatted_plan2);
2092        Ok(())
2093    }
2094
2095    /// tests that it removes an aggregate is never used downstream
2096    #[test]
2097    fn table_unused_aggregate() -> Result<()> {
2098        let table_scan = test_table_scan()?;
2099        assert_eq!(3, table_scan.schema().fields().len());
2100        assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
2101
2102        // we never use "min(b)" => remove it
2103        let plan = LogicalPlanBuilder::from(table_scan)
2104            .aggregate(vec![col("a"), col("c")], vec![max(col("b")), min(col("b"))])?
2105            .filter(col("c").gt(lit(1)))?
2106            .project(vec![col("c"), col("a"), col("max(test.b)")])?
2107            .build()?;
2108
2109        assert_fields_eq(&plan, vec!["c", "a", "max(test.b)"]);
2110
2111        assert_optimized_plan_equal!(
2112            plan,
2113            @r"
2114        Projection: test.c, test.a, max(test.b)
2115          Filter: test.c > Int32(1)
2116            Aggregate: groupBy=[[test.a, test.c]], aggr=[[max(test.b)]]
2117              TableScan: test projection=[a, b, c]
2118        "
2119        )
2120    }
2121
2122    #[test]
2123    fn aggregate_filter_pushdown() -> Result<()> {
2124        let table_scan = test_table_scan()?;
2125        let aggr_with_filter = count_udaf()
2126            .call(vec![col("b")])
2127            .filter(col("c").gt(lit(42)))
2128            .build()?;
2129        let plan = LogicalPlanBuilder::from(table_scan)
2130            .aggregate(
2131                vec![col("a")],
2132                vec![count(col("b")), aggr_with_filter.alias("count2")],
2133            )?
2134            .build()?;
2135
2136        assert_optimized_plan_equal!(
2137            plan,
2138            @r"
2139        Aggregate: groupBy=[[test.a]], aggr=[[count(test.b), count(test.b) FILTER (WHERE test.c > Int32(42)) AS count2]]
2140          TableScan: test projection=[a, b, c]
2141        "
2142        )
2143    }
2144
2145    #[test]
2146    fn pushdown_through_distinct() -> Result<()> {
2147        let table_scan = test_table_scan()?;
2148
2149        let plan = LogicalPlanBuilder::from(table_scan)
2150            .project(vec![col("a"), col("b")])?
2151            .distinct()?
2152            .project(vec![col("a")])?
2153            .build()?;
2154
2155        assert_optimized_plan_equal!(
2156            plan,
2157            @r"
2158        Projection: test.a
2159          Distinct:
2160            TableScan: test projection=[a, b]
2161        "
2162        )
2163    }
2164
2165    #[test]
2166    fn test_window() -> Result<()> {
2167        let table_scan = test_table_scan()?;
2168
2169        let max1 = Expr::from(expr::WindowFunction::new(
2170            WindowFunctionDefinition::AggregateUDF(max_udaf()),
2171            vec![col("test.a")],
2172        ))
2173        .partition_by(vec![col("test.b")])
2174        .build()
2175        .unwrap();
2176
2177        let max2 = Expr::from(expr::WindowFunction::new(
2178            WindowFunctionDefinition::AggregateUDF(max_udaf()),
2179            vec![col("test.b")],
2180        ));
2181        let col1 = col(max1.schema_name().to_string());
2182        let col2 = col(max2.schema_name().to_string());
2183
2184        let plan = LogicalPlanBuilder::from(table_scan)
2185            .window(vec![max1])?
2186            .window(vec![max2])?
2187            .project(vec![col1, col2])?
2188            .build()?;
2189
2190        assert_optimized_plan_equal!(
2191            plan,
2192            @r"
2193        Projection: max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, max(test.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
2194          WindowAggr: windowExpr=[[max(test.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
2195            Projection: test.b, max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
2196              WindowAggr: windowExpr=[[max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
2197                TableScan: test projection=[a, b]
2198        "
2199        )
2200    }
2201
2202    fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
2203
2204    fn optimize(plan: LogicalPlan) -> Result<LogicalPlan> {
2205        let optimizer = Optimizer::with_rules(vec![Arc::new(OptimizeProjections::new())]);
2206        let optimized_plan =
2207            optimizer.optimize(plan, &OptimizerContext::new(), observe)?;
2208        Ok(optimized_plan)
2209    }
2210}