datafusion_optimizer/optimize_projections/mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`OptimizeProjections`] identifies and eliminates unused columns
19
20mod required_indices;
21
22use crate::optimizer::ApplyOrder;
23use crate::{OptimizerConfig, OptimizerRule};
24use std::collections::HashSet;
25use std::sync::Arc;
26
27use datafusion_common::{
28    get_required_group_by_exprs_indices, internal_datafusion_err, internal_err, Column,
29    DFSchema, HashMap, JoinType, Result,
30};
31use datafusion_expr::expr::Alias;
32use datafusion_expr::{
33    logical_plan::LogicalPlan, Aggregate, Distinct, EmptyRelation, Expr, Projection,
34    TableScan, Unnest, Window,
35};
36
37use crate::optimize_projections::required_indices::RequiredIndices;
38use crate::utils::NamePreserver;
39use datafusion_common::tree_node::{
40    Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion,
41};
42
43/// Optimizer rule to prune unnecessary columns from intermediate schemas
44/// inside the [`LogicalPlan`]. This rule:
45/// - Removes unnecessary columns that do not appear at the output and/or are
46///   not used during any computation step.
/// - Adds projections to reduce the number of columns before operators that
///   benefit from a smaller memory footprint at their inputs.
49/// - Removes unnecessary [`LogicalPlan::Projection`]s from the [`LogicalPlan`].
50///
51/// `OptimizeProjections` is an optimizer rule that identifies and eliminates
52/// columns from a logical plan that are not used by downstream operations.
53/// This can improve query performance and reduce unnecessary data processing.
54///
55/// The rule analyzes the input logical plan, determines the necessary column
56/// indices, and then removes any unnecessary columns. It also removes any
57/// unnecessary projections from the plan tree.
58///
59/// ## Schema, Field Properties, and Metadata Handling
60///
61/// The `OptimizeProjections` rule preserves schema and field metadata in most optimization scenarios:
62///
63/// **Schema-level metadata preservation by plan type**:
64/// - **Window and Aggregate plans**: Schema metadata is preserved
65/// - **Projection plans**: Schema metadata is preserved per [`projection_schema`](datafusion_expr::logical_plan::projection_schema).
66/// - **Other logical plans**: Schema metadata is preserved unless [`LogicalPlan::recompute_schema`]
67///   is called on plan types that drop metadata
68///
69/// **Field-level properties and metadata**: Individual field properties are preserved when fields
70/// are retained in the optimized plan, determined by [`exprlist_to_fields`](datafusion_expr::utils::exprlist_to_fields)
71/// and [`ExprSchemable::to_field`](datafusion_expr::expr_schema::ExprSchemable::to_field).
72///
73/// **Field precedence**: When the same field appears multiple times, the optimizer
74/// maintains one occurrence and removes duplicates (refer to `RequiredIndices::compact()`),
75/// preserving the properties and metadata of that occurrence.
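///
/// # Example
///
/// A sketch of the rule's effect, assuming a hypothetical table `test` with
/// columns `a`, `b`, `c` of which only `a` is ultimately used:
///
/// ```text
/// Before:
/// Projection: Int32(1) + test.a
///   Projection: test.a, test.b
///     TableScan: test
///
/// After:
/// Projection: Int32(1) + test.a
///   TableScan: test projection=[a]
/// ```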
76#[derive(Default, Debug)]
77pub struct OptimizeProjections {}
78
79impl OptimizeProjections {
80    #[allow(missing_docs)]
81    pub fn new() -> Self {
82        Self {}
83    }
84}
85
86impl OptimizerRule for OptimizeProjections {
87    fn name(&self) -> &str {
88        "optimize_projections"
89    }
90
91    fn apply_order(&self) -> Option<ApplyOrder> {
92        None
93    }
94
95    fn supports_rewrite(&self) -> bool {
96        true
97    }
98
99    fn rewrite(
100        &self,
101        plan: LogicalPlan,
102        config: &dyn OptimizerConfig,
103    ) -> Result<Transformed<LogicalPlan>> {
104        // All output fields are necessary:
105        let indices = RequiredIndices::new_for_all_exprs(&plan);
106        optimize_projections(plan, config, indices)
107    }
108}
109
110/// Removes unnecessary columns (e.g. columns that do not appear in the output
111/// schema and/or are not used during any computation step such as expression
112/// evaluation) from the logical plan and its inputs.
113///
114/// # Parameters
115///
/// - `plan`: The input `LogicalPlan` to optimize.
/// - `config`: A reference to the optimizer configuration.
/// - `indices`: The column indices required by the downstream (parent) plan
///   nodes.
120///
121/// # Returns
122///
123/// A `Result` object with the following semantics:
124///
/// - `Ok(Transformed::yes(plan))`: An optimized `LogicalPlan` without
///   unnecessary columns.
/// - `Ok(Transformed::no(plan))`: The given logical plan did not require any
///   change.
128/// - `Err(error)`: An error occurred during the optimization process.
129#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
130fn optimize_projections(
131    plan: LogicalPlan,
132    config: &dyn OptimizerConfig,
133    indices: RequiredIndices,
134) -> Result<Transformed<LogicalPlan>> {
135    // Recursively rewrite any nodes that may be able to avoid computation given
136    // their parents' required indices.
137    match plan {
138        LogicalPlan::Projection(proj) => {
139            return merge_consecutive_projections(proj)?.transform_data(|proj| {
140                rewrite_projection_given_requirements(proj, config, &indices)
141            })
142        }
143        LogicalPlan::Aggregate(aggregate) => {
144            // Split parent requirements to GROUP BY and aggregate sections:
145            let n_group_exprs = aggregate.group_expr_len()?;
146            // Offset aggregate indices so that they point to valid indices at
147            // `aggregate.aggr_expr`:
148            let (group_by_reqs, aggregate_reqs) = indices.split_off(n_group_exprs);
149
150            // Get absolutely necessary GROUP BY fields:
151            let group_by_expr_existing = aggregate
152                .group_expr
153                .iter()
154                .map(|group_by_expr| group_by_expr.schema_name().to_string())
155                .collect::<Vec<_>>();
156
157            let new_group_bys = if let Some(simplest_groupby_indices) =
158                get_required_group_by_exprs_indices(
159                    aggregate.input.schema(),
160                    &group_by_expr_existing,
161                ) {
162                // Some of the fields in the GROUP BY may be required by the
163                // parent even if these fields are unnecessary in terms of
164                // functional dependency.
165                group_by_reqs
166                    .append(&simplest_groupby_indices)
167                    .get_at_indices(&aggregate.group_expr)
168            } else {
169                aggregate.group_expr
170            };
171
172            // Only use the absolutely necessary aggregate expressions required
173            // by the parent:
174            let new_aggr_expr = aggregate_reqs.get_at_indices(&aggregate.aggr_expr);
175
176            if new_group_bys.is_empty() && new_aggr_expr.is_empty() {
177                // Global aggregation with no aggregate functions always produces 1 row and no columns.
178                return Ok(Transformed::yes(LogicalPlan::EmptyRelation(
179                    EmptyRelation {
180                        produce_one_row: true,
181                        schema: Arc::new(DFSchema::empty()),
182                    },
183                )));
184            }
185
186            let all_exprs_iter = new_group_bys.iter().chain(new_aggr_expr.iter());
187            let schema = aggregate.input.schema();
188            let necessary_indices =
189                RequiredIndices::new().with_exprs(schema, all_exprs_iter);
190            let necessary_exprs = necessary_indices.get_required_exprs(schema);
191
192            return optimize_projections(
193                Arc::unwrap_or_clone(aggregate.input),
194                config,
195                necessary_indices,
196            )?
197            .transform_data(|aggregate_input| {
198                // Simplify the input of the aggregation by adding a projection so
199                // that its input only contains absolutely necessary columns for
200                // the aggregate expressions. Note that necessary_indices refer to
201                // fields in `aggregate.input.schema()`.
202                add_projection_on_top_if_helpful(aggregate_input, necessary_exprs)
203            })?
204            .map_data(|aggregate_input| {
205                // Create a new aggregate plan with the updated input and only the
206                // absolutely necessary fields:
207                Aggregate::try_new(
208                    Arc::new(aggregate_input),
209                    new_group_bys,
210                    new_aggr_expr,
211                )
212                .map(LogicalPlan::Aggregate)
213            });
214        }
215        LogicalPlan::Window(window) => {
216            let input_schema = Arc::clone(window.input.schema());
217            // Split parent requirements to child and window expression sections:
218            let n_input_fields = input_schema.fields().len();
219            // Offset window expression indices so that they point to valid
220            // indices at `window.window_expr`:
221            let (child_reqs, window_reqs) = indices.split_off(n_input_fields);
222
223            // Only use window expressions that are absolutely necessary according
224            // to parent requirements:
225            let new_window_expr = window_reqs.get_at_indices(&window.window_expr);
226
227            // Get all the required column indices at the input, either by the
228            // parent or window expression requirements.
229            let required_indices = child_reqs.with_exprs(&input_schema, &new_window_expr);
230
231            return optimize_projections(
232                Arc::unwrap_or_clone(window.input),
233                config,
234                required_indices.clone(),
235            )?
236            .transform_data(|window_child| {
237                if new_window_expr.is_empty() {
238                    // When no window expression is necessary, use the input directly:
239                    Ok(Transformed::no(window_child))
240                } else {
241                    // Calculate required expressions at the input of the window.
242                    // Please note that we use `input_schema`, because `required_indices`
243                    // refers to that schema
244                    let required_exprs =
245                        required_indices.get_required_exprs(&input_schema);
246                    let window_child =
247                        add_projection_on_top_if_helpful(window_child, required_exprs)?
248                            .data;
249                    Window::try_new(new_window_expr, Arc::new(window_child))
250                        .map(LogicalPlan::Window)
251                        .map(Transformed::yes)
252                }
253            });
254        }
255        LogicalPlan::TableScan(table_scan) => {
256            let TableScan {
257                table_name,
258                source,
259                projection,
260                filters,
261                fetch,
262                projected_schema: _,
263            } = table_scan;
264
            // Map the required indices back into the original schema (the one
            // with all fields), given the existing projected indices.
267            let projection = match &projection {
268                Some(projection) => indices.into_mapped_indices(|idx| projection[idx]),
269                None => indices.into_inner(),
270            };
271            return TableScan::try_new(
272                table_name,
273                source,
274                Some(projection),
275                filters,
276                fetch,
277            )
278            .map(LogicalPlan::TableScan)
279            .map(Transformed::yes);
280        }
281        // Other node types are handled below
282        _ => {}
283    };
284
285    // For other plan node types, calculate indices for columns they use and
286    // try to rewrite their children
287    let mut child_required_indices: Vec<RequiredIndices> = match &plan {
288        LogicalPlan::Sort(_)
289        | LogicalPlan::Filter(_)
290        | LogicalPlan::Repartition(_)
291        | LogicalPlan::Union(_)
292        | LogicalPlan::SubqueryAlias(_)
293        | LogicalPlan::Distinct(Distinct::On(_)) => {
294            // Pass index requirements from the parent as well as column indices
295            // that appear in this plan's expressions to its child. All these
296            // operators benefit from "small" inputs, so the projection_beneficial
297            // flag is `true`.
298            plan.inputs()
299                .into_iter()
300                .map(|input| {
301                    indices
302                        .clone()
303                        .with_projection_beneficial()
304                        .with_plan_exprs(&plan, input.schema())
305                })
306                .collect::<Result<_>>()?
307        }
308        LogicalPlan::Limit(_) => {
309            // Pass index requirements from the parent as well as column indices
310            // that appear in this plan's expressions to its child. These operators
311            // do not benefit from "small" inputs, so the projection_beneficial
312            // flag is `false`.
313            plan.inputs()
314                .into_iter()
315                .map(|input| indices.clone().with_plan_exprs(&plan, input.schema()))
316                .collect::<Result<_>>()?
317        }
318        LogicalPlan::Copy(_)
319        | LogicalPlan::Ddl(_)
320        | LogicalPlan::Dml(_)
321        | LogicalPlan::Explain(_)
322        | LogicalPlan::Analyze(_)
323        | LogicalPlan::Subquery(_)
324        | LogicalPlan::Statement(_)
325        | LogicalPlan::Distinct(Distinct::All(_)) => {
            // These plans require all their fields, and their children should
            // be treated as final plans -- otherwise, we may have a schema
            // mismatch.
329            // TODO: For some subquery variants (e.g. a subquery arising from an
330            //       EXISTS expression), we may not need to require all indices.
331            plan.inputs()
332                .into_iter()
333                .map(RequiredIndices::new_for_all_exprs)
334                .collect()
335        }
336        LogicalPlan::Extension(extension) => {
337            let Some(necessary_children_indices) =
338                extension.node.necessary_children_exprs(indices.indices())
339            else {
340                // Requirements from parent cannot be routed down to user defined logical plan safely
341                return Ok(Transformed::no(plan));
342            };
343            let children = extension.node.inputs();
344            if children.len() != necessary_children_indices.len() {
345                return internal_err!("Inconsistent length between children and necessary children indices. \
346                Make sure `.necessary_children_exprs` implementation of the `UserDefinedLogicalNode` is \
347                consistent with actual children length for the node.");
348            }
349            children
350                .into_iter()
351                .zip(necessary_children_indices)
352                .map(|(child, necessary_indices)| {
353                    RequiredIndices::new_from_indices(necessary_indices)
354                        .with_plan_exprs(&plan, child.schema())
355                })
356                .collect::<Result<Vec<_>>>()?
357        }
358        LogicalPlan::EmptyRelation(_)
359        | LogicalPlan::Values(_)
360        | LogicalPlan::DescribeTable(_) => {
361            // These operators have no inputs, so stop the optimization process.
362            return Ok(Transformed::no(plan));
363        }
364        LogicalPlan::RecursiveQuery(recursive) => {
365            // Only allow subqueries that reference the current CTE; nested subqueries are not yet
366            // supported for projection pushdown for simplicity.
367            // TODO: be able to do projection pushdown on recursive CTEs with subqueries
368            if plan_contains_other_subqueries(
369                recursive.static_term.as_ref(),
370                &recursive.name,
371            ) || plan_contains_other_subqueries(
372                recursive.recursive_term.as_ref(),
373                &recursive.name,
374            ) {
375                return Ok(Transformed::no(plan));
376            }
377
378            plan.inputs()
379                .into_iter()
380                .map(|input| {
381                    indices
382                        .clone()
383                        .with_projection_beneficial()
384                        .with_plan_exprs(&plan, input.schema())
385                })
386                .collect::<Result<Vec<_>>>()?
387        }
388        LogicalPlan::Join(join) => {
389            let left_len = join.left.schema().fields().len();
390            let (left_req_indices, right_req_indices) =
391                split_join_requirements(left_len, indices, &join.join_type);
392            let left_indices =
393                left_req_indices.with_plan_exprs(&plan, join.left.schema())?;
394            let right_indices =
395                right_req_indices.with_plan_exprs(&plan, join.right.schema())?;
396            // Joins benefit from "small" input tables (lower memory usage).
397            // Therefore, each child benefits from projection:
398            vec![
399                left_indices.with_projection_beneficial(),
400                right_indices.with_projection_beneficial(),
401            ]
402        }
403        // these nodes are explicitly rewritten in the match statement above
404        LogicalPlan::Projection(_)
405        | LogicalPlan::Aggregate(_)
406        | LogicalPlan::Window(_)
407        | LogicalPlan::TableScan(_) => {
408            return internal_err!(
409                "OptimizeProjection: should have handled in the match statement above"
410            );
411        }
412        LogicalPlan::Unnest(Unnest {
413            input,
414            dependency_indices,
415            ..
416        }) => {
417            // at least provide the indices for the exec-columns as a starting point
418            let required_indices =
419                RequiredIndices::new().with_plan_exprs(&plan, input.schema())?;
420
421            // Add additional required indices from the parent
422            let mut additional_necessary_child_indices = Vec::new();
423            indices.indices().iter().for_each(|idx| {
424                if let Some(index) = dependency_indices.get(*idx) {
425                    additional_necessary_child_indices.push(*index);
426                }
427            });
428            vec![required_indices.append(&additional_necessary_child_indices)]
429        }
430    };
431
432    // Required indices are currently ordered (child0, child1, ...)
433    // but the loop pops off the last element, so we need to reverse the order
434    child_required_indices.reverse();
435    if child_required_indices.len() != plan.inputs().len() {
436        return internal_err!(
437            "OptimizeProjection: child_required_indices length mismatch with plan inputs"
438        );
439    }
440
441    // Rewrite children of the plan
442    let transformed_plan = plan.map_children(|child| {
443        let required_indices = child_required_indices.pop().ok_or_else(|| {
444            internal_datafusion_err!(
445                "Unexpected number of required_indices in OptimizeProjections rule"
446            )
447        })?;
448
449        let projection_beneficial = required_indices.projection_beneficial();
450        let project_exprs = required_indices.get_required_exprs(child.schema());
451
452        optimize_projections(child, config, required_indices)?.transform_data(
453            |new_input| {
454                if projection_beneficial {
455                    add_projection_on_top_if_helpful(new_input, project_exprs)
456                } else {
457                    Ok(Transformed::no(new_input))
458                }
459            },
460        )
461    })?;
462
463    // If any of the children are transformed, we need to potentially update the plan's schema
464    if transformed_plan.transformed {
465        transformed_plan.map_data(|plan| plan.recompute_schema())
466    } else {
467        Ok(transformed_plan)
468    }
469}
470
471/// Merges consecutive projections.
472///
/// Given a projection `proj`, this function attempts to merge it with its
/// input if that input is also a projection and merging is beneficial.
/// Merging is skipped when an expression of the previous projection is
/// non-trivial and is referenced more than once by the current projection:
/// keeping the projections separate then acts as a caching mechanism for the
/// non-trivial computation.
478///
479/// ## Metadata Handling During Projection Merging
480///
481/// **Alias metadata preservation**: When merging projections, alias metadata from both
482/// the current and previous projections is carefully preserved. The presence of metadata
483/// precludes alias trimming.
484///
485/// **Schema, Fields, and metadata**: If a projection is rewritten, the schema and metadata
/// are preserved. Individual field properties and metadata flow through expression rewriting
487/// and are preserved when fields are referenced in the merged projection.
488/// Refer to [`projection_schema`](datafusion_expr::logical_plan::projection_schema)
489/// for more details.
490///
491/// # Parameters
492///
/// * `proj` - The `Projection` to be merged with its input.
494///
495/// # Returns
496///
497/// A `Result` object with the following semantics:
498///
/// - `Ok(Transformed::yes(projection))`: Merge was beneficial and successful;
///   contains the merged projection.
/// - `Ok(Transformed::no(projection))`: Signals that the merge was not
///   beneficial (and has not taken place).
502/// - `Err(error)`: An error occurred during the function call.
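///
/// # Example
///
/// A sketch of when merging is skipped (hypothetical expressions): if the
/// previous projection computes a non-trivial expression that the current
/// projection references more than once, the projections are kept separate so
/// the expression is computed only once:
///
/// ```text
/// Projection: x * x, x + Int32(1)
///   Projection: a + b AS x        -- kept: `x` is non-trivial and referenced twice
///     Source: a, b
/// ```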
503fn merge_consecutive_projections(proj: Projection) -> Result<Transformed<Projection>> {
504    let Projection {
505        expr,
506        input,
507        schema,
508        ..
509    } = proj;
510    let LogicalPlan::Projection(prev_projection) = input.as_ref() else {
511        return Projection::try_new_with_schema(expr, input, schema).map(Transformed::no);
512    };
513
    // A fast path: if the previous projection is the same as the current
    // projection, we can directly remove the current projection and return the
    // child projection.
516    if prev_projection.expr == expr {
517        return Projection::try_new_with_schema(
518            expr,
519            Arc::clone(&prev_projection.input),
520            schema,
521        )
522        .map(Transformed::yes);
523    }
524
    // Count how many times each input column is referenced by the projection expressions:
526    let mut column_referral_map = HashMap::<&Column, usize>::new();
527    expr.iter()
528        .for_each(|expr| expr.add_column_ref_counts(&mut column_referral_map));
529
    // If an expression is non-trivial and appears more than once, do not merge
    // the projections; keeping them separate benefits from a compute-once approach.
532    // For details, see: https://github.com/apache/datafusion/issues/8296
533    if column_referral_map.into_iter().any(|(col, usage)| {
534        usage > 1
535            && !is_expr_trivial(
536                &prev_projection.expr
537                    [prev_projection.schema.index_of_column(col).unwrap()],
538            )
539    }) {
540        // no change
541        return Projection::try_new_with_schema(expr, input, schema).map(Transformed::no);
542    }
543
544    let LogicalPlan::Projection(prev_projection) = Arc::unwrap_or_clone(input) else {
545        // We know it is a `LogicalPlan::Projection` from check above
546        unreachable!();
547    };
548
549    // Try to rewrite the expressions in the current projection using the
550    // previous projection as input:
551    let name_preserver = NamePreserver::new_for_projection();
552    let mut original_names = vec![];
553    let new_exprs = expr.map_elements(|expr| {
554        original_names.push(name_preserver.save(&expr));
555
556        // do not rewrite top level Aliases (rewriter will remove all aliases within exprs)
557        match expr {
558            Expr::Alias(Alias {
559                expr,
560                relation,
561                name,
562                metadata,
563            }) => rewrite_expr(*expr, &prev_projection).map(|result| {
564                result.update_data(|expr| {
565                    Expr::Alias(Alias::new(expr, relation, name).with_metadata(metadata))
566                })
567            }),
568            e => rewrite_expr(e, &prev_projection),
569        }
570    })?;
571
572    // if the expressions could be rewritten, create a new projection with the
573    // new expressions
574    if new_exprs.transformed {
575        // Add any needed aliases back to the expressions
576        let new_exprs = new_exprs
577            .data
578            .into_iter()
579            .zip(original_names)
580            .map(|(expr, original_name)| original_name.restore(expr))
581            .collect::<Vec<_>>();
582        Projection::try_new(new_exprs, prev_projection.input).map(Transformed::yes)
583    } else {
584        // not rewritten, so put the projection back together
585        let input = Arc::new(LogicalPlan::Projection(prev_projection));
586        Projection::try_new_with_schema(new_exprs.data, input, schema)
587            .map(Transformed::no)
588    }
589}
590
591// Check whether `expr` is trivial; i.e. it doesn't imply any computation.
592fn is_expr_trivial(expr: &Expr) -> bool {
593    matches!(expr, Expr::Column(_) | Expr::Literal(_, _))
594}
595
596/// Rewrites a projection expression using the projection before it (i.e. its input)
597/// This is a subroutine to the `merge_consecutive_projections` function.
598///
599/// # Parameters
600///
/// * `expr` - The expression to rewrite.
602/// * `input` - A reference to the input of the projection expression (itself
603///   a projection).
604///
605/// # Returns
606///
607/// A `Result` object with the following semantics:
608///
/// - `Ok(Transformed::yes(expr))`: Rewrite was successful; contains the
///   rewritten result.
/// - `Ok(Transformed::no(expr))`: Signals that `expr` cannot be rewritten.
611/// - `Err(error)`: An error occurred during the function call.
612///
613/// # Notes
614/// This rewrite also removes any unnecessary layers of aliasing. "Unnecessary" is
615/// defined as not contributing new information, such as metadata.
616///
617/// Without trimming, we can end up with unnecessary indirections inside expressions
618/// during projection merges.
619///
620/// Consider:
621///
622/// ```text
623/// Projection(a1 + b1 as sum1)
624/// --Projection(a as a1, b as b1)
625/// ----Source(a, b)
626/// ```
627///
628/// After merge, we want to produce:
629///
630/// ```text
631/// Projection(a + b as sum1)
632/// --Source(a, b)
633/// ```
634///
635/// Without trimming, we would end up with:
636///
637/// ```text
638/// Projection((a as a1 + b as b1) as sum1)
639/// --Source(a, b)
640/// ```
641fn rewrite_expr(expr: Expr, input: &Projection) -> Result<Transformed<Expr>> {
642    expr.transform_up(|expr| {
643        match expr {
644            //  remove any intermediate aliases if they do not carry metadata
645            Expr::Alias(alias) => {
646                match alias
647                    .metadata
648                    .as_ref()
649                    .map(|h| h.is_empty())
650                    .unwrap_or(true)
651                {
652                    true => Ok(Transformed::yes(*alias.expr)),
653                    false => Ok(Transformed::no(Expr::Alias(alias))),
654                }
655            }
656            Expr::Column(col) => {
657                // Find index of column:
658                let idx = input.schema.index_of_column(&col)?;
659                // get the corresponding unaliased input expression
660                //
661                // For example:
662                // * the input projection is [`a + b` as c, `d + e` as f]
663                // * the current column is an expression "f"
664                //
665                // return the expression `d + e` (not `d + e` as f)
666                let input_expr = input.expr[idx].clone().unalias_nested().data;
667                Ok(Transformed::yes(input_expr))
668            }
669            // Unsupported type for consecutive projection merge analysis.
670            _ => Ok(Transformed::no(expr)),
671        }
672    })
673}
674
/// Accumulates the columns that are outer-referenced by the given expression,
/// `expr`.
677///
678/// # Parameters
679///
680/// * `expr` - The expression to analyze for outer-referenced columns.
681/// * `columns` - A mutable reference to a `HashSet<Column>` where detected
682///   columns are collected.
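///
/// For example, given a predicate like `t.a > (SELECT max(x) FROM s WHERE s.y = b)`
/// where `b` comes from the enclosing query, the outer-referenced column `b`
/// is collected into `columns`.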
683fn outer_columns<'a>(expr: &'a Expr, columns: &mut HashSet<&'a Column>) {
684    // inspect_expr_pre doesn't handle subquery references, so find them explicitly
685    expr.apply(|expr| {
686        match expr {
687            Expr::OuterReferenceColumn(_, col) => {
688                columns.insert(col);
689            }
690            Expr::ScalarSubquery(subquery) => {
691                outer_columns_helper_multi(&subquery.outer_ref_columns, columns);
692            }
693            Expr::Exists(exists) => {
694                outer_columns_helper_multi(&exists.subquery.outer_ref_columns, columns);
695            }
696            Expr::InSubquery(insubquery) => {
697                outer_columns_helper_multi(
698                    &insubquery.subquery.outer_ref_columns,
699                    columns,
700                );
701            }
702            _ => {}
703        };
704        Ok(TreeNodeRecursion::Continue)
705    })
    // unwrap: closure above never returns Err, so this cannot be Err
707    .unwrap();
708}
709
/// A recursive subroutine that accumulates the columns that are
/// outer-referenced by the given expressions (`exprs`).
712///
713/// # Parameters
714///
715/// * `exprs` - The expressions to analyze for outer-referenced columns.
716/// * `columns` - A mutable reference to a `HashSet<Column>` where detected
717///   columns are collected.
718fn outer_columns_helper_multi<'a, 'b>(
719    exprs: impl IntoIterator<Item = &'a Expr>,
720    columns: &'b mut HashSet<&'a Column>,
721) {
722    exprs.into_iter().for_each(|e| outer_columns(e, columns));
723}
724
725/// Splits requirement indices for a join into left and right children based on
726/// the join type.
727///
728/// This function takes the length of the left child, a slice of requirement
729/// indices, and the type of join (e.g. `INNER`, `LEFT`, `RIGHT`) as arguments.
730/// Depending on the join type, it divides the requirement indices into those
731/// that apply to the left child and those that apply to the right child.
732///
733/// - For `INNER`, `LEFT`, `RIGHT`, `FULL`, `LEFTMARK`, and `RIGHTMARK` joins,
734///   the requirements are split between left and right children. The right
735///   child indices are adjusted to point to valid positions within the right
736///   child by subtracting the length of the left child.
737///
738/// - For `LEFT ANTI`, `LEFT SEMI`, `RIGHT SEMI` and `RIGHT ANTI` joins, all
739///   requirements are re-routed to either the left child or the right child
740///   directly, depending on the join type.
741///
742/// # Parameters
743///
744/// * `left_len` - The length of the left child.
/// * `indices` - The requirement indices from the parent.
746/// * `join_type` - The type of join (e.g. `INNER`, `LEFT`, `RIGHT`).
747///
748/// # Returns
749///
/// A tuple of two [`RequiredIndices`]: the first holds the requirements for
/// the left child, and the second holds the requirements for the right child,
/// split and adjusted according to the join type.
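///
/// # Example
///
/// As an illustration (hypothetical numbers), with `left_len = 3` and parent
/// requirements `[0, 2, 4]`:
///
/// ```text
/// Inner/Left/Right/Full/LeftMark/RightMark: left = [0, 2], right = [1] (4 - 3)
/// LeftSemi/LeftAnti:                        left = [0, 2, 4], right = []
/// RightSemi/RightAnti:                      left = [], right = [0, 2, 4]
/// ```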
754fn split_join_requirements(
755    left_len: usize,
756    indices: RequiredIndices,
757    join_type: &JoinType,
758) -> (RequiredIndices, RequiredIndices) {
759    match join_type {
760        // In these cases requirements are split between left/right children:
761        JoinType::Inner
762        | JoinType::Left
763        | JoinType::Right
764        | JoinType::Full
765        | JoinType::LeftMark
766        | JoinType::RightMark => {
767            // Decrease right side indices by `left_len` so that they point to valid
768            // positions within the right child:
769            indices.split_off(left_len)
770        }
771        // All requirements can be re-routed to left child directly.
772        JoinType::LeftAnti | JoinType::LeftSemi => (indices, RequiredIndices::new()),
773        // All requirements can be re-routed to right side directly.
774        // No need to change index, join schema is right child schema.
775        JoinType::RightSemi | JoinType::RightAnti => (RequiredIndices::new(), indices),
776    }
777}
778
779/// Adds a projection on top of a logical plan if doing so reduces the number
780/// of columns for the parent operator.
781///
782/// This function takes a `LogicalPlan` and a list of projection expressions.
783/// If the projection is beneficial (it reduces the number of columns in the
784/// plan) a new `LogicalPlan` with the projection is created and returned, along
785/// with a `true` flag. If the projection doesn't reduce the number of columns,
786/// the original plan is returned with a `false` flag.
787///
788/// # Parameters
789///
790/// * `plan` - The input `LogicalPlan` to potentially add a projection to.
791/// * `project_exprs` - A list of expressions for the projection.
792///
793/// # Returns
794///
795/// A `Transformed` indicating if a projection was added
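///
/// # Example
///
/// A sketch (hypothetical columns): if the input exposes `[a, b, c]` but only
/// `a` is required, a projection is added; requiring all three columns leaves
/// the plan unchanged:
///
/// ```text
/// required = [a]        =>  Projection: a
///                             TableScan: test projection=[a, b, c]
/// required = [a, b, c]  =>  TableScan: test projection=[a, b, c]
/// ```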
796fn add_projection_on_top_if_helpful(
797    plan: LogicalPlan,
798    project_exprs: Vec<Expr>,
799) -> Result<Transformed<LogicalPlan>> {
800    // Make sure projection decreases the number of columns, otherwise it is unnecessary.
801    if project_exprs.len() >= plan.schema().fields().len() {
802        Ok(Transformed::no(plan))
803    } else {
804        Projection::try_new(project_exprs, Arc::new(plan))
805            .map(LogicalPlan::Projection)
806            .map(Transformed::yes)
807    }
808}
809
810/// Rewrite the given projection according to the fields required by its
811/// ancestors.
812///
813/// # Parameters
814///
/// * `proj` - The original projection to rewrite.
/// * `config` - A reference to the optimizer configuration.
/// * `indices` - The column indices required by the ancestors of the given
///   projection.
819///
820/// # Returns
821///
822/// A `Result` object with the following semantics:
823///
/// - `Ok(Transformed::yes(plan))`: Contains the rewritten projection.
/// - `Ok(Transformed::no(plan))`: No rewrite was necessary.
826/// - `Err(error)`: An error occurred during the function call.
827fn rewrite_projection_given_requirements(
828    proj: Projection,
829    config: &dyn OptimizerConfig,
830    indices: &RequiredIndices,
831) -> Result<Transformed<LogicalPlan>> {
832    let Projection { expr, input, .. } = proj;
833
834    let exprs_used = indices.get_at_indices(&expr);
835
836    let required_indices =
837        RequiredIndices::new().with_exprs(input.schema(), exprs_used.iter());
838
    // Rewrite the child of the projection, and then rebuild the projection on
    // top of it (or drop it if it has become unnecessary).
841    optimize_projections(Arc::unwrap_or_clone(input), config, required_indices)?
842        .transform_data(|input| {
843            if is_projection_unnecessary(&input, &exprs_used)? {
844                Ok(Transformed::yes(input))
845            } else {
846                Projection::try_new(exprs_used, Arc::new(input))
847                    .map(LogicalPlan::Projection)
848                    .map(Transformed::yes)
849            }
850        })
851}
852
/// A projection is unnecessary when it is a no-op, i.e. it has exactly as many
/// expressions as its input has fields and every expression is a bare `Column`
/// referring to the corresponding input field (same relation and name).
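///
/// # Example
///
/// A sketch with a hypothetical table `test(a, b)`:
///
/// ```text
/// Projection: test.a, test.b     -- unnecessary (identical to the input)
/// Projection: test.b, test.a     -- necessary (reorders columns)
/// Projection: test.a + Int32(1)  -- necessary (computes a new expression)
/// ```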
856pub fn is_projection_unnecessary(
857    input: &LogicalPlan,
858    proj_exprs: &[Expr],
859) -> Result<bool> {
860    // First check if the number of expressions is equal to the number of fields in the input schema.
861    if proj_exprs.len() != input.schema().fields().len() {
862        return Ok(false);
863    }
864    Ok(input.schema().iter().zip(proj_exprs.iter()).all(
865        |((field_relation, field_name), expr)| {
866            // Check if the expression is a column and if it matches the field name
867            if let Expr::Column(col) = expr {
868                col.relation.as_ref() == field_relation && col.name.eq(field_name.name())
869            } else {
870                false
871            }
872        },
873    ))
874}
875
876/// Returns true if the plan subtree contains any subqueries that are not the
877/// CTE reference itself. This treats any non-CTE [`LogicalPlan::SubqueryAlias`]
878/// node (including aliased relations) as a blocker, along with expression-level
879/// subqueries like scalar, EXISTS, or IN. These cases prevent projection
880/// pushdown for now because we cannot safely reason about their column usage.
881fn plan_contains_other_subqueries(plan: &LogicalPlan, cte_name: &str) -> bool {
882    if let LogicalPlan::SubqueryAlias(alias) = plan {
883        if alias.alias.table() != cte_name
884            && !subquery_alias_targets_recursive_cte(alias.input.as_ref(), cte_name)
885        {
886            return true;
887        }
888    }
889
890    let mut found = false;
891    plan.apply_expressions(|expr| {
892        if expr_contains_subquery(expr) {
893            found = true;
894            Ok(TreeNodeRecursion::Stop)
895        } else {
896            Ok(TreeNodeRecursion::Continue)
897        }
898    })
899    .expect("expression traversal never fails");
900    if found {
901        return true;
902    }
903
904    plan.inputs()
905        .into_iter()
906        .any(|child| plan_contains_other_subqueries(child, cte_name))
907}
908
909fn expr_contains_subquery(expr: &Expr) -> bool {
910    expr.exists(|e| match e {
911        Expr::ScalarSubquery(_) | Expr::Exists(_) | Expr::InSubquery(_) => Ok(true),
912        _ => Ok(false),
913    })
914    // Safe unwrap since we are doing a simple boolean check
915    .unwrap()
916}
917
918fn subquery_alias_targets_recursive_cte(plan: &LogicalPlan, cte_name: &str) -> bool {
919    match plan {
920        LogicalPlan::TableScan(scan) => scan.table_name.table() == cte_name,
921        LogicalPlan::SubqueryAlias(alias) => {
922            subquery_alias_targets_recursive_cte(alias.input.as_ref(), cte_name)
923        }
924        _ => {
925            let inputs = plan.inputs();
926            if inputs.len() == 1 {
927                subquery_alias_targets_recursive_cte(inputs[0], cte_name)
928            } else {
929                false
930            }
931        }
932    }
933}
934
935#[cfg(test)]
936mod tests {
937    use std::cmp::Ordering;
938    use std::collections::HashMap;
939    use std::fmt::Formatter;
940    use std::ops::Add;
941    use std::sync::Arc;
942    use std::vec;
943
944    use crate::optimize_projections::OptimizeProjections;
945    use crate::optimizer::Optimizer;
946    use crate::test::{
947        assert_fields_eq, scan_empty, test_table_scan, test_table_scan_fields,
948        test_table_scan_with_name,
949    };
950    use crate::{OptimizerContext, OptimizerRule};
951    use arrow::datatypes::{DataType, Field, Schema};
952    use datafusion_common::{
953        Column, DFSchema, DFSchemaRef, JoinType, Result, TableReference,
954    };
955    use datafusion_expr::ExprFunctionExt;
956    use datafusion_expr::{
957        binary_expr, build_join_schema,
958        builder::table_scan_with_filters,
959        col,
960        expr::{self, Cast},
961        lit,
962        logical_plan::{builder::LogicalPlanBuilder, table_scan},
963        not, try_cast, when, BinaryExpr, Expr, Extension, Like, LogicalPlan, Operator,
964        Projection, UserDefinedLogicalNodeCore, WindowFunctionDefinition,
965    };
966    use insta::assert_snapshot;
967
968    use crate::assert_optimized_plan_eq_snapshot;
969    use datafusion_functions_aggregate::count::count_udaf;
970    use datafusion_functions_aggregate::expr_fn::{count, max, min};
971    use datafusion_functions_aggregate::min_max::max_udaf;
972
973    macro_rules! assert_optimized_plan_equal {
974        (
975            $plan:expr,
976            @ $expected:literal $(,)?
977        ) => {{
978            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
979            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(OptimizeProjections::new())];
980            assert_optimized_plan_eq_snapshot!(
981                optimizer_ctx,
982                rules,
983                $plan,
984                @ $expected,
985            )
986        }};
987    }
988
989    #[derive(Debug, Hash, PartialEq, Eq)]
990    struct NoOpUserDefined {
991        exprs: Vec<Expr>,
992        schema: DFSchemaRef,
993        input: Arc<LogicalPlan>,
994    }
995
996    impl NoOpUserDefined {
997        fn new(schema: DFSchemaRef, input: Arc<LogicalPlan>) -> Self {
998            Self {
999                exprs: vec![],
1000                schema,
1001                input,
1002            }
1003        }
1004
1005        fn with_exprs(mut self, exprs: Vec<Expr>) -> Self {
1006            self.exprs = exprs;
1007            self
1008        }
1009    }
1010
1011    // Manual implementation needed because of `schema` field. Comparison excludes this field.
1012    impl PartialOrd for NoOpUserDefined {
1013        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
1014            match self.exprs.partial_cmp(&other.exprs) {
1015                Some(Ordering::Equal) => self.input.partial_cmp(&other.input),
1016                cmp => cmp,
1017            }
1018            // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields
1019            .filter(|cmp| *cmp != Ordering::Equal || self == other)
1020        }
1021    }
1022
1023    impl UserDefinedLogicalNodeCore for NoOpUserDefined {
1024        fn name(&self) -> &str {
1025            "NoOpUserDefined"
1026        }
1027
1028        fn inputs(&self) -> Vec<&LogicalPlan> {
1029            vec![&self.input]
1030        }
1031
1032        fn schema(&self) -> &DFSchemaRef {
1033            &self.schema
1034        }
1035
1036        fn expressions(&self) -> Vec<Expr> {
1037            self.exprs.clone()
1038        }
1039
1040        fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
1041            write!(f, "NoOpUserDefined")
1042        }
1043
1044        fn with_exprs_and_inputs(
1045            &self,
1046            exprs: Vec<Expr>,
1047            mut inputs: Vec<LogicalPlan>,
1048        ) -> Result<Self> {
1049            Ok(Self {
1050                exprs,
1051                input: Arc::new(inputs.swap_remove(0)),
1052                schema: Arc::clone(&self.schema),
1053            })
1054        }
1055
1056        fn necessary_children_exprs(
1057            &self,
1058            output_columns: &[usize],
1059        ) -> Option<Vec<Vec<usize>>> {
            // Since the schema is the same, each output column requires the
            // column at the same index in the input.
1061            Some(vec![output_columns.to_vec()])
1062        }
1063
1064        fn supports_limit_pushdown(&self) -> bool {
1065            false // Disallow limit push-down by default
1066        }
1067    }
1068
1069    #[derive(Debug, Hash, PartialEq, Eq)]
1070    struct UserDefinedCrossJoin {
1071        exprs: Vec<Expr>,
1072        schema: DFSchemaRef,
1073        left_child: Arc<LogicalPlan>,
1074        right_child: Arc<LogicalPlan>,
1075    }
1076
1077    impl UserDefinedCrossJoin {
1078        fn new(left_child: Arc<LogicalPlan>, right_child: Arc<LogicalPlan>) -> Self {
1079            let left_schema = left_child.schema();
1080            let right_schema = right_child.schema();
1081            let schema = Arc::new(
1082                build_join_schema(left_schema, right_schema, &JoinType::Inner).unwrap(),
1083            );
1084            Self {
1085                exprs: vec![],
1086                schema,
1087                left_child,
1088                right_child,
1089            }
1090        }
1091    }
1092
1093    // Manual implementation needed because of `schema` field. Comparison excludes this field.
1094    impl PartialOrd for UserDefinedCrossJoin {
1095        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
1096            match self.exprs.partial_cmp(&other.exprs) {
1097                Some(Ordering::Equal) => {
1098                    match self.left_child.partial_cmp(&other.left_child) {
1099                        Some(Ordering::Equal) => {
1100                            self.right_child.partial_cmp(&other.right_child)
1101                        }
1102                        cmp => cmp,
1103                    }
1104                }
1105                cmp => cmp,
1106            }
1107            // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields
1108            .filter(|cmp| *cmp != Ordering::Equal || self == other)
1109        }
1110    }
1111
1112    impl UserDefinedLogicalNodeCore for UserDefinedCrossJoin {
1113        fn name(&self) -> &str {
1114            "UserDefinedCrossJoin"
1115        }
1116
1117        fn inputs(&self) -> Vec<&LogicalPlan> {
1118            vec![&self.left_child, &self.right_child]
1119        }
1120
1121        fn schema(&self) -> &DFSchemaRef {
1122            &self.schema
1123        }
1124
1125        fn expressions(&self) -> Vec<Expr> {
1126            self.exprs.clone()
1127        }
1128
1129        fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
1130            write!(f, "UserDefinedCrossJoin")
1131        }
1132
1133        fn with_exprs_and_inputs(
1134            &self,
1135            exprs: Vec<Expr>,
1136            mut inputs: Vec<LogicalPlan>,
1137        ) -> Result<Self> {
1138            assert_eq!(inputs.len(), 2);
1139            Ok(Self {
1140                exprs,
1141                left_child: Arc::new(inputs.remove(0)),
1142                right_child: Arc::new(inputs.remove(0)),
1143                schema: Arc::clone(&self.schema),
1144            })
1145        }
1146
1147        fn necessary_children_exprs(
1148            &self,
1149            output_columns: &[usize],
1150        ) -> Option<Vec<Vec<usize>>> {
1151            let left_child_len = self.left_child.schema().fields().len();
1152            let mut left_reqs = vec![];
1153            let mut right_reqs = vec![];
1154            for &out_idx in output_columns {
1155                if out_idx < left_child_len {
1156                    left_reqs.push(out_idx);
1157                } else {
                    // Output indices at or beyond left_child_len
                    // come from the right child
1160                    right_reqs.push(out_idx - left_child_len)
1161                }
1162            }
1163            Some(vec![left_reqs, right_reqs])
1164        }
1165
1166        fn supports_limit_pushdown(&self) -> bool {
1167            false // Disallow limit push-down by default
1168        }
1169    }
1170
1171    #[test]
1172    fn merge_two_projection() -> Result<()> {
1173        let table_scan = test_table_scan()?;
1174        let plan = LogicalPlanBuilder::from(table_scan)
1175            .project(vec![col("a")])?
1176            .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])?
1177            .build()?;
1178
1179        assert_optimized_plan_equal!(
1180            plan,
1181            @r"
1182        Projection: Int32(1) + test.a
1183          TableScan: test projection=[a]
1184        "
1185        )
1186    }
1187
1188    #[test]
1189    fn merge_three_projection() -> Result<()> {
1190        let table_scan = test_table_scan()?;
1191        let plan = LogicalPlanBuilder::from(table_scan)
1192            .project(vec![col("a"), col("b")])?
1193            .project(vec![col("a")])?
1194            .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])?
1195            .build()?;
1196
1197        assert_optimized_plan_equal!(
1198            plan,
1199            @r"
1200        Projection: Int32(1) + test.a
1201          TableScan: test projection=[a]
1202        "
1203        )
1204    }
1205
1206    #[test]
1207    fn merge_alias() -> Result<()> {
1208        let table_scan = test_table_scan()?;
1209        let plan = LogicalPlanBuilder::from(table_scan)
1210            .project(vec![col("a")])?
1211            .project(vec![col("a").alias("alias")])?
1212            .build()?;
1213
1214        assert_optimized_plan_equal!(
1215            plan,
1216            @r"
1217        Projection: test.a AS alias
1218          TableScan: test projection=[a]
1219        "
1220        )
1221    }
1222
1223    #[test]
1224    fn merge_nested_alias() -> Result<()> {
1225        let table_scan = test_table_scan()?;
1226        let plan = LogicalPlanBuilder::from(table_scan)
1227            .project(vec![col("a").alias("alias1").alias("alias2")])?
1228            .project(vec![col("alias2").alias("alias")])?
1229            .build()?;
1230
1231        assert_optimized_plan_equal!(
1232            plan,
1233            @r"
1234        Projection: test.a AS alias
1235          TableScan: test projection=[a]
1236        "
1237        )
1238    }
1239
1240    #[test]
1241    fn test_nested_count() -> Result<()> {
1242        let schema = Schema::new(vec![Field::new("foo", DataType::Int32, false)]);
1243
1244        let groups: Vec<Expr> = vec![];
1245
1246        let plan = table_scan(TableReference::none(), &schema, None)
1247            .unwrap()
1248            .aggregate(groups.clone(), vec![count(lit(1))])
1249            .unwrap()
1250            .aggregate(groups, vec![count(lit(1))])
1251            .unwrap()
1252            .build()
1253            .unwrap();
1254
1255        assert_optimized_plan_equal!(
1256            plan,
1257            @r"
1258        Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]]
1259          EmptyRelation: rows=1
1260        "
1261        )
1262    }
1263
1264    #[test]
1265    fn test_neg_push_down() -> Result<()> {
1266        let table_scan = test_table_scan()?;
1267        let plan = LogicalPlanBuilder::from(table_scan)
1268            .project(vec![-col("a")])?
1269            .build()?;
1270
1271        assert_optimized_plan_equal!(
1272            plan,
1273            @r"
1274        Projection: (- test.a)
1275          TableScan: test projection=[a]
1276        "
1277        )
1278    }
1279
1280    #[test]
1281    fn test_is_null() -> Result<()> {
1282        let table_scan = test_table_scan()?;
1283        let plan = LogicalPlanBuilder::from(table_scan)
1284            .project(vec![col("a").is_null()])?
1285            .build()?;
1286
1287        assert_optimized_plan_equal!(
1288            plan,
1289            @r"
1290        Projection: test.a IS NULL
1291          TableScan: test projection=[a]
1292        "
1293        )
1294    }
1295
1296    #[test]
1297    fn test_is_not_null() -> Result<()> {
1298        let table_scan = test_table_scan()?;
1299        let plan = LogicalPlanBuilder::from(table_scan)
1300            .project(vec![col("a").is_not_null()])?
1301            .build()?;
1302
1303        assert_optimized_plan_equal!(
1304            plan,
1305            @r"
1306        Projection: test.a IS NOT NULL
1307          TableScan: test projection=[a]
1308        "
1309        )
1310    }
1311
1312    #[test]
1313    fn test_is_true() -> Result<()> {
1314        let table_scan = test_table_scan()?;
1315        let plan = LogicalPlanBuilder::from(table_scan)
1316            .project(vec![col("a").is_true()])?
1317            .build()?;
1318
1319        assert_optimized_plan_equal!(
1320            plan,
1321            @r"
1322        Projection: test.a IS TRUE
1323          TableScan: test projection=[a]
1324        "
1325        )
1326    }
1327
1328    #[test]
1329    fn test_is_not_true() -> Result<()> {
1330        let table_scan = test_table_scan()?;
1331        let plan = LogicalPlanBuilder::from(table_scan)
1332            .project(vec![col("a").is_not_true()])?
1333            .build()?;
1334
1335        assert_optimized_plan_equal!(
1336            plan,
1337            @r"
1338        Projection: test.a IS NOT TRUE
1339          TableScan: test projection=[a]
1340        "
1341        )
1342    }
1343
1344    #[test]
1345    fn test_is_false() -> Result<()> {
1346        let table_scan = test_table_scan()?;
1347        let plan = LogicalPlanBuilder::from(table_scan)
1348            .project(vec![col("a").is_false()])?
1349            .build()?;
1350
1351        assert_optimized_plan_equal!(
1352            plan,
1353            @r"
1354        Projection: test.a IS FALSE
1355          TableScan: test projection=[a]
1356        "
1357        )
1358    }
1359
1360    #[test]
1361    fn test_is_not_false() -> Result<()> {
1362        let table_scan = test_table_scan()?;
1363        let plan = LogicalPlanBuilder::from(table_scan)
1364            .project(vec![col("a").is_not_false()])?
1365            .build()?;
1366
1367        assert_optimized_plan_equal!(
1368            plan,
1369            @r"
1370        Projection: test.a IS NOT FALSE
1371          TableScan: test projection=[a]
1372        "
1373        )
1374    }
1375
1376    #[test]
1377    fn test_is_unknown() -> Result<()> {
1378        let table_scan = test_table_scan()?;
1379        let plan = LogicalPlanBuilder::from(table_scan)
1380            .project(vec![col("a").is_unknown()])?
1381            .build()?;
1382
1383        assert_optimized_plan_equal!(
1384            plan,
1385            @r"
1386        Projection: test.a IS UNKNOWN
1387          TableScan: test projection=[a]
1388        "
1389        )
1390    }
1391
1392    #[test]
1393    fn test_is_not_unknown() -> Result<()> {
1394        let table_scan = test_table_scan()?;
1395        let plan = LogicalPlanBuilder::from(table_scan)
1396            .project(vec![col("a").is_not_unknown()])?
1397            .build()?;
1398
1399        assert_optimized_plan_equal!(
1400            plan,
1401            @r"
1402        Projection: test.a IS NOT UNKNOWN
1403          TableScan: test projection=[a]
1404        "
1405        )
1406    }
1407
1408    #[test]
1409    fn test_not() -> Result<()> {
1410        let table_scan = test_table_scan()?;
1411        let plan = LogicalPlanBuilder::from(table_scan)
1412            .project(vec![not(col("a"))])?
1413            .build()?;
1414
1415        assert_optimized_plan_equal!(
1416            plan,
1417            @r"
1418        Projection: NOT test.a
1419          TableScan: test projection=[a]
1420        "
1421        )
1422    }
1423
1424    #[test]
1425    fn test_try_cast() -> Result<()> {
1426        let table_scan = test_table_scan()?;
1427        let plan = LogicalPlanBuilder::from(table_scan)
1428            .project(vec![try_cast(col("a"), DataType::Float64)])?
1429            .build()?;
1430
1431        assert_optimized_plan_equal!(
1432            plan,
1433            @r"
1434        Projection: TRY_CAST(test.a AS Float64)
1435          TableScan: test projection=[a]
1436        "
1437        )
1438    }
1439
1440    #[test]
1441    fn test_similar_to() -> Result<()> {
1442        let table_scan = test_table_scan()?;
1443        let expr = Box::new(col("a"));
1444        let pattern = Box::new(lit("[0-9]"));
1445        let similar_to_expr =
1446            Expr::SimilarTo(Like::new(false, expr, pattern, None, false));
1447        let plan = LogicalPlanBuilder::from(table_scan)
1448            .project(vec![similar_to_expr])?
1449            .build()?;
1450
1451        assert_optimized_plan_equal!(
1452            plan,
1453            @r#"
1454        Projection: test.a SIMILAR TO Utf8("[0-9]")
1455          TableScan: test projection=[a]
1456        "#
1457        )
1458    }
1459
1460    #[test]
1461    fn test_between() -> Result<()> {
1462        let table_scan = test_table_scan()?;
1463        let plan = LogicalPlanBuilder::from(table_scan)
1464            .project(vec![col("a").between(lit(1), lit(3))])?
1465            .build()?;
1466
1467        assert_optimized_plan_equal!(
1468            plan,
1469            @r"
1470        Projection: test.a BETWEEN Int32(1) AND Int32(3)
1471          TableScan: test projection=[a]
1472        "
1473        )
1474    }
1475
1476    // Test Case expression
1477    #[test]
1478    fn test_case_merged() -> Result<()> {
1479        let table_scan = test_table_scan()?;
1480        let plan = LogicalPlanBuilder::from(table_scan)
1481            .project(vec![col("a"), lit(0).alias("d")])?
1482            .project(vec![
1483                col("a"),
1484                when(col("a").eq(lit(1)), lit(10))
1485                    .otherwise(col("d"))?
1486                    .alias("d"),
1487            ])?
1488            .build()?;
1489
1490        assert_optimized_plan_equal!(
1491            plan,
1492            @r"
1493        Projection: test.a, CASE WHEN test.a = Int32(1) THEN Int32(10) ELSE Int32(0) END AS d
1494          TableScan: test projection=[a]
1495        "
1496        )
1497    }
1498
1499    // Test outer projection isn't discarded despite the same schema as inner
1500    // https://github.com/apache/datafusion/issues/8942
1501    #[test]
1502    fn test_derived_column() -> Result<()> {
1503        let table_scan = test_table_scan()?;
1504        let plan = LogicalPlanBuilder::from(table_scan)
1505            .project(vec![col("a").add(lit(1)).alias("a"), lit(0).alias("d")])?
1506            .project(vec![
1507                col("a"),
1508                when(col("a").eq(lit(1)), lit(10))
1509                    .otherwise(col("d"))?
1510                    .alias("d"),
1511            ])?
1512            .build()?;
1513
1514        assert_optimized_plan_equal!(
1515            plan,
1516            @r"
1517        Projection: a, CASE WHEN a = Int32(1) THEN Int32(10) ELSE d END AS d
1518          Projection: test.a + Int32(1) AS a, Int32(0) AS d
1519            TableScan: test projection=[a]
1520        "
1521        )
1522    }
1523
1524    // Since only column `a` is referred to at the output, the scan should only contain projection=[a].
1525    // The user defined node should propagate the expressions required by its parent down to its child.
1526    #[test]
1527    fn test_user_defined_logical_plan_node() -> Result<()> {
1528        let table_scan = test_table_scan()?;
1529        let custom_plan = LogicalPlan::Extension(Extension {
1530            node: Arc::new(NoOpUserDefined::new(
1531                Arc::clone(table_scan.schema()),
1532                Arc::new(table_scan.clone()),
1533            )),
1534        });
1535        let plan = LogicalPlanBuilder::from(custom_plan)
1536            .project(vec![col("a"), lit(0).alias("d")])?
1537            .build()?;
1538
1539        assert_optimized_plan_equal!(
1540            plan,
1541            @r"
1542        Projection: test.a, Int32(0) AS d
1543          NoOpUserDefined
1544            TableScan: test projection=[a]
1545        "
1546        )
1547    }
1548
1549    // Only column `a` is referred to at the output. However, the user defined node itself uses column `b`
1550    // during its operation. Hence, the scan should contain projection=[a, b].
1551    // The user defined node should propagate the expressions required by its parent, as well as its own
1552    // required expressions.
1553    #[test]
1554    fn test_user_defined_logical_plan_node2() -> Result<()> {
1555        let table_scan = test_table_scan()?;
1556        let exprs = vec![Expr::Column(Column::from_qualified_name("b"))];
1557        let custom_plan = LogicalPlan::Extension(Extension {
1558            node: Arc::new(
1559                NoOpUserDefined::new(
1560                    Arc::clone(table_scan.schema()),
1561                    Arc::new(table_scan.clone()),
1562                )
1563                .with_exprs(exprs),
1564            ),
1565        });
1566        let plan = LogicalPlanBuilder::from(custom_plan)
1567            .project(vec![col("a"), lit(0).alias("d")])?
1568            .build()?;
1569
1570        assert_optimized_plan_equal!(
1571            plan,
1572            @r"
1573        Projection: test.a, Int32(0) AS d
1574          NoOpUserDefined
1575            TableScan: test projection=[a, b]
1576        "
1577        )
1578    }
1579
1580    // Only column `a` is referred to at the output. However, the user defined node itself uses the
1581    // expression `b + c` during its operation. Hence, the scan should contain projection=[a, b, c].
1582    // The user defined node should propagate the expressions required by its parent, as well as its own
1583    // required expressions. Requirements are not limited to plain columns: columns referenced by complex
1584    // expressions must be propagated as well.
1585    #[test]
1586    fn test_user_defined_logical_plan_node3() -> Result<()> {
1587        let table_scan = test_table_scan()?;
1588        let left_expr = Expr::Column(Column::from_qualified_name("b"));
1589        let right_expr = Expr::Column(Column::from_qualified_name("c"));
1590        let binary_expr = Expr::BinaryExpr(BinaryExpr::new(
1591            Box::new(left_expr),
1592            Operator::Plus,
1593            Box::new(right_expr),
1594        ));
1595        let exprs = vec![binary_expr];
1596        let custom_plan = LogicalPlan::Extension(Extension {
1597            node: Arc::new(
1598                NoOpUserDefined::new(
1599                    Arc::clone(table_scan.schema()),
1600                    Arc::new(table_scan.clone()),
1601                )
1602                .with_exprs(exprs),
1603            ),
1604        });
1605        let plan = LogicalPlanBuilder::from(custom_plan)
1606            .project(vec![col("a"), lit(0).alias("d")])?
1607            .build()?;
1608
1609        assert_optimized_plan_equal!(
1610            plan,
1611            @r"
1612        Projection: test.a, Int32(0) AS d
1613          NoOpUserDefined
1614            TableScan: test projection=[a, b, c]
1615        "
1616        )
1617    }
1618
1619    // Columns `l.a`, `l.c` and `r.a` are referred to at the output.
1620    // The user defined node should propagate the expressions required by its parent to its children,
1621    // even when it has multiple children.
1622    // The left child should have `projection=[a, c]`, and the right child should have `projection=[a]`.
1623    #[test]
1624    fn test_user_defined_logical_plan_node4() -> Result<()> {
1625        let left_table = test_table_scan_with_name("l")?;
1626        let right_table = test_table_scan_with_name("r")?;
1627        let custom_plan = LogicalPlan::Extension(Extension {
1628            node: Arc::new(UserDefinedCrossJoin::new(
1629                Arc::new(left_table),
1630                Arc::new(right_table),
1631            )),
1632        });
1633        let plan = LogicalPlanBuilder::from(custom_plan)
1634            .project(vec![col("l.a"), col("l.c"), col("r.a"), lit(0).alias("d")])?
1635            .build()?;
1636
1637        assert_optimized_plan_equal!(
1638            plan,
1639            @r"
1640        Projection: l.a, l.c, r.a, Int32(0) AS d
1641          UserDefinedCrossJoin
1642            TableScan: l projection=[a, c]
1643            TableScan: r projection=[a]
1644        "
1645        )
1646    }
1647
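    // Even without any GROUP BY expressions, only the columns used by the aggregate expressions
    // are required, so the scan below is pruned to projection=[b].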
1648    #[test]
1649    fn aggregate_no_group_by() -> Result<()> {
1650        let table_scan = test_table_scan()?;
1651
1652        let plan = LogicalPlanBuilder::from(table_scan)
1653            .aggregate(Vec::<Expr>::new(), vec![max(col("b"))])?
1654            .build()?;
1655
1656        assert_optimized_plan_equal!(
1657            plan,
1658            @r"
1659        Aggregate: groupBy=[[]], aggr=[[max(test.b)]]
1660          TableScan: test projection=[b]
1661        "
1662        )
1663    }
1664
1665    #[test]
1666    fn aggregate_group_by() -> Result<()> {
1667        let table_scan = test_table_scan()?;
1668
1669        let plan = LogicalPlanBuilder::from(table_scan)
1670            .aggregate(vec![col("c")], vec![max(col("b"))])?
1671            .build()?;
1672
1673        assert_optimized_plan_equal!(
1674            plan,
1675            @r"
1676        Aggregate: groupBy=[[test.c]], aggr=[[max(test.b)]]
1677          TableScan: test projection=[b, c]
1678        "
1679        )
1680    }
1681
1682    #[test]
1683    fn aggregate_group_by_with_table_alias() -> Result<()> {
1684        let table_scan = test_table_scan()?;
1685
1686        let plan = LogicalPlanBuilder::from(table_scan)
1687            .alias("a")?
1688            .aggregate(vec![col("c")], vec![max(col("b"))])?
1689            .build()?;
1690
1691        assert_optimized_plan_equal!(
1692            plan,
1693            @r"
1694        Aggregate: groupBy=[[a.c]], aggr=[[max(a.b)]]
1695          SubqueryAlias: a
1696            TableScan: test projection=[b, c]
1697        "
1698        )
1699    }
1700
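    // The filter needs `c` but the aggregate only needs `b`, so a projection is inserted above
    // the filter to drop `c`, and the scan reads projection=[b, c].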
1701    #[test]
1702    fn aggregate_no_group_by_with_filter() -> Result<()> {
1703        let table_scan = test_table_scan()?;
1704
1705        let plan = LogicalPlanBuilder::from(table_scan)
1706            .filter(col("c").gt(lit(1)))?
1707            .aggregate(Vec::<Expr>::new(), vec![max(col("b"))])?
1708            .build()?;
1709
1710        assert_optimized_plan_equal!(
1711            plan,
1712            @r"
1713        Aggregate: groupBy=[[]], aggr=[[max(test.b)]]
1714          Projection: test.b
1715            Filter: test.c > Int32(1)
1716              TableScan: test projection=[b, c]
1717        "
1718        )
1719    }
1720
1721    #[test]
1722    fn aggregate_with_periods() -> Result<()> {
1723        let schema = Schema::new(vec![Field::new("tag.one", DataType::Utf8, false)]);
1724
1725        // Build a plan that looks as follows (note "tag.one" is a column named
1726        // "tag.one", not a column named "one" in a table named "tag"):
1727        //
1728        // Projection: tag.one
1729        //   Aggregate: groupBy=[], aggr=[max("tag.one") AS "tag.one"]
1730        //    TableScan
1731        let plan = table_scan(Some("m4"), &schema, None)?
1732            .aggregate(
1733                Vec::<Expr>::new(),
1734                vec![max(col(Column::new_unqualified("tag.one"))).alias("tag.one")],
1735            )?
1736            .project([col(Column::new_unqualified("tag.one"))])?
1737            .build()?;
1738
1739        assert_optimized_plan_equal!(
1740            plan,
1741            @r"
1742        Aggregate: groupBy=[[]], aggr=[[max(m4.tag.one) AS tag.one]]
1743          TableScan: m4 projection=[tag.one]
1744        "
1745        )
1746    }
1747
1748    #[test]
1749    fn redundant_project() -> Result<()> {
1750        let table_scan = test_table_scan()?;
1751
1752        let plan = LogicalPlanBuilder::from(table_scan)
1753            .project(vec![col("a"), col("b"), col("c")])?
1754            .project(vec![col("a"), col("c"), col("b")])?
1755            .build()?;
1756        assert_optimized_plan_equal!(
1757            plan,
1758            @r"
1759        Projection: test.a, test.c, test.b
1760          TableScan: test projection=[a, b, c]
1761        "
1762        )
1763    }
1764
1765    #[test]
1766    fn reorder_scan() -> Result<()> {
1767        let schema = Schema::new(test_table_scan_fields());
1768
1769        let plan = table_scan(Some("test"), &schema, Some(vec![1, 0, 2]))?.build()?;
1770        assert_optimized_plan_equal!(
1771            plan,
1772            @"TableScan: test projection=[b, a, c]"
1773        )
1774    }
1775
1776    #[test]
1777    fn reorder_scan_projection() -> Result<()> {
1778        let schema = Schema::new(test_table_scan_fields());
1779
1780        let plan = table_scan(Some("test"), &schema, Some(vec![1, 0, 2]))?
1781            .project(vec![col("a"), col("b")])?
1782            .build()?;
1783        assert_optimized_plan_equal!(
1784            plan,
1785            @r"
1786        Projection: test.a, test.b
1787          TableScan: test projection=[b, a]
1788        "
1789        )
1790    }
1791
1792    #[test]
1793    fn reorder_projection() -> Result<()> {
1794        let table_scan = test_table_scan()?;
1795
1796        let plan = LogicalPlanBuilder::from(table_scan)
1797            .project(vec![col("c"), col("b"), col("a")])?
1798            .build()?;
1799        assert_optimized_plan_equal!(
1800            plan,
1801            @r"
1802        Projection: test.c, test.b, test.a
1803          TableScan: test projection=[a, b, c]
1804        "
1805        )
1806    }
1807
1808    #[test]
1809    fn noncontinuous_redundant_projection() -> Result<()> {
1810        let table_scan = test_table_scan()?;
1811
1812        let plan = LogicalPlanBuilder::from(table_scan)
1813            .project(vec![col("c"), col("b"), col("a")])?
1814            .filter(col("c").gt(lit(1)))?
1815            .project(vec![col("c"), col("a"), col("b")])?
1816            .filter(col("b").gt(lit(1)))?
1817            .filter(col("a").gt(lit(1)))?
1818            .project(vec![col("a"), col("c"), col("b")])?
1819            .build()?;
1820        assert_optimized_plan_equal!(
1821            plan,
1822            @r"
1823        Projection: test.a, test.c, test.b
1824          Filter: test.a > Int32(1)
1825            Filter: test.b > Int32(1)
1826              Projection: test.c, test.a, test.b
1827                Filter: test.c > Int32(1)
1828                  Projection: test.c, test.b, test.a
1829                    TableScan: test projection=[a, b, c]
1830        "
1831        )
1832    }
1833
1834    #[test]
1835    fn join_schema_trim_full_join_column_projection() -> Result<()> {
1836        let table_scan = test_table_scan()?;
1837
1838        let schema = Schema::new(vec![Field::new("c1", DataType::UInt32, false)]);
1839        let table2_scan = scan_empty(Some("test2"), &schema, None)?.build()?;
1840
1841        let plan = LogicalPlanBuilder::from(table_scan)
1842            .join(table2_scan, JoinType::Left, (vec!["a"], vec!["c1"]), None)?
1843            .project(vec![col("a"), col("b"), col("c1")])?
1844            .build()?;
1845
1846        let optimized_plan = optimize(plan)?;
1847
1848        // make sure projections are pushed down to both table scans
1849        assert_snapshot!(
1850            optimized_plan.clone(),
1851            @r"
1852        Left Join: test.a = test2.c1
1853          TableScan: test projection=[a, b]
1854          TableScan: test2 projection=[c1]
1855        "
1856        );
1857
1858        // make sure the schema for the join node includes both join columns
1859        let optimized_join = optimized_plan;
1860        assert_eq!(
1861            **optimized_join.schema(),
1862            DFSchema::new_with_metadata(
1863                vec![
1864                    (
1865                        Some("test".into()),
1866                        Arc::new(Field::new("a", DataType::UInt32, false))
1867                    ),
1868                    (
1869                        Some("test".into()),
1870                        Arc::new(Field::new("b", DataType::UInt32, false))
1871                    ),
1872                    (
1873                        Some("test2".into()),
1874                        Arc::new(Field::new("c1", DataType::UInt32, true))
1875                    ),
1876                ],
1877                HashMap::new()
1878            )?,
1879        );
1880
1881        Ok(())
1882    }
1883
1884    #[test]
1885    fn join_schema_trim_partial_join_column_projection() -> Result<()> {
1886        // test join column push down without explicit column projections
1887
1888        let table_scan = test_table_scan()?;
1889
1890        let schema = Schema::new(vec![Field::new("c1", DataType::UInt32, false)]);
1891        let table2_scan = scan_empty(Some("test2"), &schema, None)?.build()?;
1892
1893        let plan = LogicalPlanBuilder::from(table_scan)
1894            .join(table2_scan, JoinType::Left, (vec!["a"], vec!["c1"]), None)?
1895            // projecting the joined column `a` should also push the `c1` column projection
1896            // down into the test2 table, even though `c1` is not referenced in the projection.
1897            .project(vec![col("a"), col("b")])?
1898            .build()?;
1899
1900        let optimized_plan = optimize(plan)?;
1901
1902        // make sure projections are pushed down to both table scans
1903        assert_snapshot!(
1904            optimized_plan.clone(),
1905            @r"
1906        Projection: test.a, test.b
1907          Left Join: test.a = test2.c1
1908            TableScan: test projection=[a, b]
1909            TableScan: test2 projection=[c1]
1910        "
1911        );
1912
1913        // make sure the schema for the join node includes both join columns
1914        let optimized_join = optimized_plan.inputs()[0];
1915        assert_eq!(
1916            **optimized_join.schema(),
1917            DFSchema::new_with_metadata(
1918                vec![
1919                    (
1920                        Some("test".into()),
1921                        Arc::new(Field::new("a", DataType::UInt32, false))
1922                    ),
1923                    (
1924                        Some("test".into()),
1925                        Arc::new(Field::new("b", DataType::UInt32, false))
1926                    ),
1927                    (
1928                        Some("test2".into()),
1929                        Arc::new(Field::new("c1", DataType::UInt32, true))
1930                    ),
1931                ],
1932                HashMap::new()
1933            )?,
1934        );
1935
1936        Ok(())
1937    }
1938
1939    #[test]
1940    fn join_schema_trim_using_join() -> Result<()> {
1941        // shared join columns from a USING join should be pushed to both sides
1942
1943        let table_scan = test_table_scan()?;
1944
1945        let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]);
1946        let table2_scan = scan_empty(Some("test2"), &schema, None)?.build()?;
1947
1948        let plan = LogicalPlanBuilder::from(table_scan)
1949            .join_using(table2_scan, JoinType::Left, vec!["a".into()])?
1950            .project(vec![col("a"), col("b")])?
1951            .build()?;
1952
1953        let optimized_plan = optimize(plan)?;
1954
1955        // make sure projections are pushed down to both table scans
1956        assert_snapshot!(
1957            optimized_plan.clone(),
1958            @r"
1959        Projection: test.a, test.b
1960          Left Join: Using test.a = test2.a
1961            TableScan: test projection=[a, b]
1962            TableScan: test2 projection=[a]
1963        "
1964        );
1965
1966        // make sure the schema for the join node includes both join columns
1967        let optimized_join = optimized_plan.inputs()[0];
1968        assert_eq!(
1969            **optimized_join.schema(),
1970            DFSchema::new_with_metadata(
1971                vec![
1972                    (
1973                        Some("test".into()),
1974                        Arc::new(Field::new("a", DataType::UInt32, false))
1975                    ),
1976                    (
1977                        Some("test".into()),
1978                        Arc::new(Field::new("b", DataType::UInt32, false))
1979                    ),
1980                    (
1981                        Some("test2".into()),
1982                        Arc::new(Field::new("a", DataType::UInt32, true))
1983                    ),
1984                ],
1985                HashMap::new()
1986            )?,
1987        );
1988
1989        Ok(())
1990    }
1991
1992    #[test]
1993    fn cast() -> Result<()> {
1994        let table_scan = test_table_scan()?;
1995
1996        let plan = LogicalPlanBuilder::from(table_scan)
1997            .project(vec![Expr::Cast(Cast::new(
1998                Box::new(col("c")),
1999                DataType::Float64,
2000            ))])?
2001            .build()?;
2002
2003        assert_optimized_plan_equal!(
2004            plan,
2005            @r"
2006        Projection: CAST(test.c AS Float64)
2007          TableScan: test projection=[c]
2008        "
2009        )
2010    }
2011
2012    #[test]
2013    fn table_scan_projected_schema() -> Result<()> {
2014        let table_scan = test_table_scan()?;
2015        let plan = LogicalPlanBuilder::from(test_table_scan()?)
2016            .project(vec![col("a"), col("b")])?
2017            .build()?;
2018
2019        assert_eq!(3, table_scan.schema().fields().len());
2020        assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
2021        assert_fields_eq(&plan, vec!["a", "b"]);
2022
2023        assert_optimized_plan_equal!(
2024            plan,
2025            @"TableScan: test projection=[a, b]"
2026        )
2027    }
2028
2029    #[test]
2030    fn table_scan_projected_schema_non_qualified_relation() -> Result<()> {
2031        let table_scan = test_table_scan()?;
2032        let input_schema = table_scan.schema();
2033        assert_eq!(3, input_schema.fields().len());
2034        assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
2035
2036        // Build the LogicalPlan directly (don't use LogicalPlanBuilder), so
2037        // that the Column references are unqualified (e.g. their
2038        // relation is `None`); LogicalPlanBuilder would otherwise resolve the expressions.
2039        let expr = vec![col("test.a"), col("test.b")];
2040        let plan =
2041            LogicalPlan::Projection(Projection::try_new(expr, Arc::new(table_scan))?);
2042
2043        assert_fields_eq(&plan, vec!["a", "b"]);
2044
2045        assert_optimized_plan_equal!(
2046            plan,
2047            @"TableScan: test projection=[a, b]"
2048        )
2049    }
2050
2051    #[test]
2052    fn table_limit() -> Result<()> {
2053        let table_scan = test_table_scan()?;
2054        assert_eq!(3, table_scan.schema().fields().len());
2055        assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
2056
2057        let plan = LogicalPlanBuilder::from(table_scan)
2058            .project(vec![col("c"), col("a")])?
2059            .limit(0, Some(5))?
2060            .build()?;
2061
2062        assert_fields_eq(&plan, vec!["c", "a"]);
2063
2064        assert_optimized_plan_equal!(
2065            plan,
2066            @r"
2067        Limit: skip=0, fetch=5
2068          Projection: test.c, test.a
2069            TableScan: test projection=[a, c]
2070        "
2071        )
2072    }
2073
2074    #[test]
2075    fn table_scan_without_projection() -> Result<()> {
2076        let table_scan = test_table_scan()?;
2077        let plan = LogicalPlanBuilder::from(table_scan).build()?;
2078        // a scan without an explicit projection should expand to all columns
2079        assert_optimized_plan_equal!(
2080            plan,
2081            @"TableScan: test projection=[a, b, c]"
2082        )
2083    }
2084
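    // A projection consisting only of literals needs no input columns, so the scan projection is empty.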
2085    #[test]
2086    fn table_scan_with_literal_projection() -> Result<()> {
2087        let table_scan = test_table_scan()?;
2088        let plan = LogicalPlanBuilder::from(table_scan)
2089            .project(vec![lit(1_i64), lit(2_i64)])?
2090            .build()?;
2091        assert_optimized_plan_equal!(
2092            plan,
2093            @r"
2094        Projection: Int64(1), Int64(2)
2095          TableScan: test projection=[]
2096        "
2097        )
2098    }
2099
2100    /// tests that it removes unused columns in projections
2101    #[test]
2102    fn table_unused_column() -> Result<()> {
2103        let table_scan = test_table_scan()?;
2104        assert_eq!(3, table_scan.schema().fields().len());
2105        assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
2106
2107        // "b" from the first projection is never used downstream => remove it
2108        let plan = LogicalPlanBuilder::from(table_scan)
2109            .project(vec![col("c"), col("a"), col("b")])?
2110            .filter(col("c").gt(lit(1)))?
2111            .aggregate(vec![col("c")], vec![max(col("a"))])?
2112            .build()?;
2113
2114        assert_fields_eq(&plan, vec!["c", "max(test.a)"]);
2115
2116        let plan = optimize(plan).expect("failed to optimize plan");
2117        assert_optimized_plan_equal!(
2118            plan,
2119            @r"
2120        Aggregate: groupBy=[[test.c]], aggr=[[max(test.a)]]
2121          Filter: test.c > Int32(1)
2122            Projection: test.c, test.a
2123              TableScan: test projection=[a, c]
2124        "
2125        )
2126    }
2127
2128    /// tests that it removes unneeded projections
2129    #[test]
2130    fn table_unused_projection() -> Result<()> {
2131        let table_scan = test_table_scan()?;
2132        assert_eq!(3, table_scan.schema().fields().len());
2133        assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
2134
2135        // there is no need for the first projection
2136        let plan = LogicalPlanBuilder::from(table_scan)
2137            .project(vec![col("b")])?
2138            .project(vec![lit(1).alias("a")])?
2139            .build()?;
2140
2141        assert_fields_eq(&plan, vec!["a"]);
2142
2143        assert_optimized_plan_equal!(
2144            plan,
2145            @r"
2146        Projection: Int32(1) AS a
2147          TableScan: test projection=[]
2148        "
2149        )
2150    }
2151
2152    #[test]
2153    fn table_full_filter_pushdown() -> Result<()> {
2154        let schema = Schema::new(test_table_scan_fields());
2155
2156        let table_scan = table_scan_with_filters(
2157            Some("test"),
2158            &schema,
2159            None,
2160            vec![col("b").eq(lit(1))],
2161        )?
2162        .build()?;
2163        assert_eq!(3, table_scan.schema().fields().len());
2164        assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
2165
2166        // there is no need for the first projection
2167        let plan = LogicalPlanBuilder::from(table_scan)
2168            .project(vec![col("b")])?
2169            .project(vec![lit(1).alias("a")])?
2170            .build()?;
2171
2172        assert_fields_eq(&plan, vec!["a"]);
2173
2174        assert_optimized_plan_equal!(
2175            plan,
2176            @r"
2177        Projection: Int32(1) AS a
2178          TableScan: test projection=[], full_filters=[b = Int32(1)]
2179        "
2180        )
2181    }
2182
2183    /// tests that optimizing twice yields the same plan
2184    #[test]
2185    fn test_double_optimization() -> Result<()> {
2186        let table_scan = test_table_scan()?;
2187
2188        let plan = LogicalPlanBuilder::from(table_scan)
2189            .project(vec![col("b")])?
2190            .project(vec![lit(1).alias("a")])?
2191            .build()?;
2192
2193        let optimized_plan1 = optimize(plan).expect("failed to optimize plan");
2194        let optimized_plan2 =
2195            optimize(optimized_plan1.clone()).expect("failed to optimize plan");
2196
2197        let formatted_plan1 = format!("{optimized_plan1:?}");
2198        let formatted_plan2 = format!("{optimized_plan2:?}");
2199        assert_eq!(formatted_plan1, formatted_plan2);
2200        Ok(())
2201    }
2202
2203    /// tests that it removes an aggregate that is never used downstream
2204    #[test]
2205    fn table_unused_aggregate() -> Result<()> {
2206        let table_scan = test_table_scan()?;
2207        assert_eq!(3, table_scan.schema().fields().len());
2208        assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
2209
2210        // we never use "min(b)" => remove it
2211        let plan = LogicalPlanBuilder::from(table_scan)
2212            .aggregate(vec![col("a"), col("c")], vec![max(col("b")), min(col("b"))])?
2213            .filter(col("c").gt(lit(1)))?
2214            .project(vec![col("c"), col("a"), col("max(test.b)")])?
2215            .build()?;
2216
2217        assert_fields_eq(&plan, vec!["c", "a", "max(test.b)"]);
2218
2219        assert_optimized_plan_equal!(
2220            plan,
2221            @r"
2222        Projection: test.c, test.a, max(test.b)
2223          Filter: test.c > Int32(1)
2224            Aggregate: groupBy=[[test.a, test.c]], aggr=[[max(test.b)]]
2225              TableScan: test projection=[a, b, c]
2226        "
2227        )
2228    }
2229
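    // The aggregate's FILTER clause references `c`, so the scan must keep projection=[a, b, c].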
2230    #[test]
2231    fn aggregate_filter_pushdown() -> Result<()> {
2232        let table_scan = test_table_scan()?;
2233        let aggr_with_filter = count_udaf()
2234            .call(vec![col("b")])
2235            .filter(col("c").gt(lit(42)))
2236            .build()?;
2237        let plan = LogicalPlanBuilder::from(table_scan)
2238            .aggregate(
2239                vec![col("a")],
2240                vec![count(col("b")), aggr_with_filter.alias("count2")],
2241            )?
2242            .build()?;
2243
2244        assert_optimized_plan_equal!(
2245            plan,
2246            @r"
2247        Aggregate: groupBy=[[test.a]], aggr=[[count(test.b), count(test.b) FILTER (WHERE test.c > Int32(42)) AS count2]]
2248          TableScan: test projection=[a, b, c]
2249        "
2250        )
2251    }
2252
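    // The final projection of `a` cannot be pushed below the Distinct, since distinctness is
    // computed over both `a` and `b`; the scan still needs projection=[a, b].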
2253    #[test]
2254    fn pushdown_through_distinct() -> Result<()> {
2255        let table_scan = test_table_scan()?;
2256
2257        let plan = LogicalPlanBuilder::from(table_scan)
2258            .project(vec![col("a"), col("b")])?
2259            .distinct()?
2260            .project(vec![col("a")])?
2261            .build()?;
2262
2263        assert_optimized_plan_equal!(
2264            plan,
2265            @r"
2266        Projection: test.a
2267          Distinct:
2268            TableScan: test projection=[a, b]
2269        "
2270        )
2271    }
2272
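    // Each window operator only requires the columns referenced by its window expressions, so a
    // projection is inserted between the two WindowAggr nodes and the scan reads projection=[a, b].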
2273    #[test]
2274    fn test_window() -> Result<()> {
2275        let table_scan = test_table_scan()?;
2276
2277        let max1 = Expr::from(expr::WindowFunction::new(
2278            WindowFunctionDefinition::AggregateUDF(max_udaf()),
2279            vec![col("test.a")],
2280        ))
2281        .partition_by(vec![col("test.b")])
2282        .build()
2283        .unwrap();
2284
2285        let max2 = Expr::from(expr::WindowFunction::new(
2286            WindowFunctionDefinition::AggregateUDF(max_udaf()),
2287            vec![col("test.b")],
2288        ));
2289        let col1 = col(max1.schema_name().to_string());
2290        let col2 = col(max2.schema_name().to_string());
2291
2292        let plan = LogicalPlanBuilder::from(table_scan)
2293            .window(vec![max1])?
2294            .window(vec![max2])?
2295            .project(vec![col1, col2])?
2296            .build()?;
2297
2298        assert_optimized_plan_equal!(
2299            plan,
2300            @r"
2301        Projection: max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, max(test.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
2302          WindowAggr: windowExpr=[[max(test.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
2303            Projection: test.b, max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
2304              WindowAggr: windowExpr=[[max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
2305                TableScan: test projection=[a, b]
2306        "
2307        )
2308    }
2309
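    // Test helpers: `observe` is a no-op optimizer callback and `optimize` runs the plan through
    // only the `OptimizeProjections` rule.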
2310    fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
2311
2312    fn optimize(plan: LogicalPlan) -> Result<LogicalPlan> {
2313        let optimizer = Optimizer::with_rules(vec![Arc::new(OptimizeProjections::new())]);
2314        let optimized_plan =
2315            optimizer.optimize(plan, &OptimizerContext::new(), observe)?;
2316        Ok(optimized_plan)
2317    }
2318}