datafusion_optimizer/optimize_projections/mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`OptimizeProjections`] identifies and eliminates unused columns
19
20mod required_indices;
21
22use crate::optimizer::ApplyOrder;
23use crate::{OptimizerConfig, OptimizerRule};
24use std::collections::HashSet;
25use std::sync::Arc;
26
27use datafusion_common::{
28    Column, DFSchema, HashMap, JoinType, Result, assert_eq_or_internal_err,
29    get_required_group_by_exprs_indices, internal_datafusion_err, internal_err,
30};
31use datafusion_expr::expr::Alias;
32use datafusion_expr::{
33    Aggregate, Distinct, EmptyRelation, Expr, Projection, TableScan, Unnest, Window,
34    logical_plan::LogicalPlan,
35};
36
37use crate::optimize_projections::required_indices::RequiredIndices;
38use crate::utils::NamePreserver;
39use datafusion_common::tree_node::{
40    Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion,
41};
42
43/// Optimizer rule to prune unnecessary columns from intermediate schemas
44/// inside the [`LogicalPlan`]. This rule:
45/// - Removes unnecessary columns that do not appear at the output and/or are
46///   not used during any computation step.
47/// - Adds projections to decrease the number of columns before operators that
48///   benefit from a smaller memory footprint at their inputs.
49/// - Removes unnecessary [`LogicalPlan::Projection`]s from the [`LogicalPlan`].
50///
51/// `OptimizeProjections` is an optimizer rule that identifies and eliminates
52/// columns from a logical plan that are not used by downstream operations.
53/// This can improve query performance and reduce unnecessary data processing.
54///
55/// The rule analyzes the input logical plan, determines the necessary column
56/// indices, and then removes any unnecessary columns. It also removes any
57/// unnecessary projections from the plan tree.
58///
59/// ## Schema, Field Properties, and Metadata Handling
60///
61/// The `OptimizeProjections` rule preserves schema and field metadata in most optimization scenarios:
62///
63/// **Schema-level metadata preservation by plan type**:
64/// - **Window and Aggregate plans**: Schema metadata is preserved
65/// - **Projection plans**: Schema metadata is preserved per [`projection_schema`](datafusion_expr::logical_plan::projection_schema).
66/// - **Other logical plans**: Schema metadata is preserved unless [`LogicalPlan::recompute_schema`]
67///   is called on plan types that drop metadata
68///
69/// **Field-level properties and metadata**: Individual field properties are preserved when fields
70/// are retained in the optimized plan, determined by [`exprlist_to_fields`](datafusion_expr::utils::exprlist_to_fields)
71/// and [`ExprSchemable::to_field`](datafusion_expr::expr_schema::ExprSchemable::to_field).
72///
73/// **Field precedence**: When the same field appears multiple times, the optimizer
74/// maintains one occurrence and removes duplicates (refer to `RequiredIndices::compact()`),
75/// preserving the properties and metadata of that occurrence.
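///
/// ## Example (illustrative)
///
/// A sketch of the rule's effect, using plan shapes from the
/// `merge_three_projection` test at the bottom of this module (a `test` table
/// with columns `a`, `b`, `c`):
///
/// ```text
/// Projection: Int32(1) + test.a
///   Projection: test.a
///     Projection: test.a, test.b
///       TableScan: test
/// ```
///
/// becomes, after merging the projections and pruning the scan:
///
/// ```text
/// Projection: Int32(1) + test.a
///   TableScan: test projection=[a]
/// ```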
76#[derive(Default, Debug)]
77pub struct OptimizeProjections {}
78
79impl OptimizeProjections {
80    #[expect(missing_docs)]
81    pub fn new() -> Self {
82        Self {}
83    }
84}
85
86impl OptimizerRule for OptimizeProjections {
87    fn name(&self) -> &str {
88        "optimize_projections"
89    }
90
91    fn apply_order(&self) -> Option<ApplyOrder> {
92        None
93    }
94
95    fn supports_rewrite(&self) -> bool {
96        true
97    }
98
99    fn rewrite(
100        &self,
101        plan: LogicalPlan,
102        config: &dyn OptimizerConfig,
103    ) -> Result<Transformed<LogicalPlan>> {
104        // All output fields are necessary:
105        let indices = RequiredIndices::new_for_all_exprs(&plan);
106        optimize_projections(plan, config, indices)
107    }
108}
109
110/// Removes unnecessary columns (e.g. columns that do not appear in the output
111/// schema and/or are not used during any computation step such as expression
112/// evaluation) from the logical plan and its inputs.
113///
114/// # Parameters
115///
116/// - `plan`: The input `LogicalPlan` to optimize.
117/// - `config`: A reference to the optimizer configuration.
118/// - `indices`: The column indices (as `RequiredIndices`) required by
119///   downstream (parent) plan nodes.
120///
121/// # Returns
122///
123/// A `Result` object with the following semantics:
124///
125/// - `Ok(Transformed::yes(plan))`: An optimized `LogicalPlan` without unnecessary
126///   columns.
127/// - `Ok(Transformed::no(plan))`: The given logical plan did not require any change.
128/// - `Err(error)`: An error occurred during the optimization process.
129#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
130fn optimize_projections(
131    plan: LogicalPlan,
132    config: &dyn OptimizerConfig,
133    indices: RequiredIndices,
134) -> Result<Transformed<LogicalPlan>> {
135    // Recursively rewrite any nodes that may be able to avoid computation given
136    // their parents' required indices.
137    match plan {
138        LogicalPlan::Projection(proj) => {
139            return merge_consecutive_projections(proj)?.transform_data(|proj| {
140                rewrite_projection_given_requirements(proj, config, &indices)
141            });
142        }
143        LogicalPlan::Aggregate(aggregate) => {
144            // Split parent requirements to GROUP BY and aggregate sections:
145            let n_group_exprs = aggregate.group_expr_len()?;
146            // Offset aggregate indices so that they point to valid indices at
147            // `aggregate.aggr_expr`:
148            let (group_by_reqs, aggregate_reqs) = indices.split_off(n_group_exprs);
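            // Illustrative example (hypothetical numbers): with 2 GROUP BY
            // expressions and parent requirements {0, 1, 3}, `split_off(2)` yields
            // group_by_reqs = {0, 1} and aggregate_reqs = {1} (i.e. 3 - 2), which
            // points into `aggregate.aggr_expr`.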
149
150            // Get absolutely necessary GROUP BY fields:
151            let group_by_expr_existing = aggregate
152                .group_expr
153                .iter()
154                .map(|group_by_expr| group_by_expr.schema_name().to_string())
155                .collect::<Vec<_>>();
156
157            let new_group_bys = if let Some(simplest_groupby_indices) =
158                get_required_group_by_exprs_indices(
159                    aggregate.input.schema(),
160                    &group_by_expr_existing,
161                ) {
162                // Some of the fields in the GROUP BY may be required by the
163                // parent even if these fields are unnecessary in terms of
164                // functional dependency.
165                group_by_reqs
166                    .append(&simplest_groupby_indices)
167                    .get_at_indices(&aggregate.group_expr)
168            } else {
169                aggregate.group_expr
170            };
171
172            // Only use the absolutely necessary aggregate expressions required
173            // by the parent:
174            let new_aggr_expr = aggregate_reqs.get_at_indices(&aggregate.aggr_expr);
175
176            if new_group_bys.is_empty() && new_aggr_expr.is_empty() {
177                // Global aggregation with no aggregate functions always produces 1 row and no columns.
178                return Ok(Transformed::yes(LogicalPlan::EmptyRelation(
179                    EmptyRelation {
180                        produce_one_row: true,
181                        schema: Arc::new(DFSchema::empty()),
182                    },
183                )));
184            }
185
186            let all_exprs_iter = new_group_bys.iter().chain(new_aggr_expr.iter());
187            let schema = aggregate.input.schema();
188            let necessary_indices =
189                RequiredIndices::new().with_exprs(schema, all_exprs_iter);
190            let necessary_exprs = necessary_indices.get_required_exprs(schema);
191
192            return optimize_projections(
193                Arc::unwrap_or_clone(aggregate.input),
194                config,
195                necessary_indices,
196            )?
197            .transform_data(|aggregate_input| {
198                // Simplify the input of the aggregation by adding a projection so
199                // that its input only contains absolutely necessary columns for
200                // the aggregate expressions. Note that necessary_indices refer to
201                // fields in `aggregate.input.schema()`.
202                add_projection_on_top_if_helpful(aggregate_input, necessary_exprs)
203            })?
204            .map_data(|aggregate_input| {
205                // Create a new aggregate plan with the updated input and only the
206                // absolutely necessary fields:
207                Aggregate::try_new(
208                    Arc::new(aggregate_input),
209                    new_group_bys,
210                    new_aggr_expr,
211                )
212                .map(LogicalPlan::Aggregate)
213            });
214        }
215        LogicalPlan::Window(window) => {
216            let input_schema = Arc::clone(window.input.schema());
217            // Split parent requirements to child and window expression sections:
218            let n_input_fields = input_schema.fields().len();
219            // Offset window expression indices so that they point to valid
220            // indices at `window.window_expr`:
221            let (child_reqs, window_reqs) = indices.split_off(n_input_fields);
222
223            // Only use window expressions that are absolutely necessary according
224            // to parent requirements:
225            let new_window_expr = window_reqs.get_at_indices(&window.window_expr);
226
227            // Get all the required column indices at the input, either by the
228            // parent or window expression requirements.
229            let required_indices = child_reqs.with_exprs(&input_schema, &new_window_expr);
230
231            return optimize_projections(
232                Arc::unwrap_or_clone(window.input),
233                config,
234                required_indices.clone(),
235            )?
236            .transform_data(|window_child| {
237                if new_window_expr.is_empty() {
238                    // When no window expression is necessary, use the input directly:
239                    Ok(Transformed::no(window_child))
240                } else {
241                    // Calculate required expressions at the input of the window.
242                    // Please note that we use `input_schema`, because `required_indices`
243                    // refers to that schema
244                    let required_exprs =
245                        required_indices.get_required_exprs(&input_schema);
246                    let window_child =
247                        add_projection_on_top_if_helpful(window_child, required_exprs)?
248                            .data;
249                    Window::try_new(new_window_expr, Arc::new(window_child))
250                        .map(LogicalPlan::Window)
251                        .map(Transformed::yes)
252                }
253            });
254        }
255        LogicalPlan::TableScan(table_scan) => {
256            let TableScan {
257                table_name,
258                source,
259                projection,
260                filters,
261                fetch,
262                projected_schema: _,
263            } = table_scan;
264
265            // Map the required indices back to the original schema (with all
266            // fields), given the scan's existing projected indices.
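            // Illustrative example (hypothetical numbers): if the scan already
            // projects the original columns [2, 4, 6] and the parent requires the
            // projected indices [0, 2], the new projection becomes [2, 6].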
267            let projection = match &projection {
268                Some(projection) => indices.into_mapped_indices(|idx| projection[idx]),
269                None => indices.into_inner(),
270            };
271            return TableScan::try_new(
272                table_name,
273                source,
274                Some(projection),
275                filters,
276                fetch,
277            )
278            .map(LogicalPlan::TableScan)
279            .map(Transformed::yes);
280        }
281        // Other node types are handled below
282        _ => {}
283    };
284
285    // For other plan node types, calculate indices for columns they use and
286    // try to rewrite their children
287    let mut child_required_indices: Vec<RequiredIndices> = match &plan {
288        LogicalPlan::Sort(_)
289        | LogicalPlan::Filter(_)
290        | LogicalPlan::Repartition(_)
291        | LogicalPlan::Union(_)
292        | LogicalPlan::SubqueryAlias(_)
293        | LogicalPlan::Distinct(Distinct::On(_)) => {
294            // Pass index requirements from the parent as well as column indices
295            // that appear in this plan's expressions to its child. All these
296            // operators benefit from "small" inputs, so the projection_beneficial
297            // flag is `true`.
298            plan.inputs()
299                .into_iter()
300                .map(|input| {
301                    indices
302                        .clone()
303                        .with_projection_beneficial()
304                        .with_plan_exprs(&plan, input.schema())
305                })
306                .collect::<Result<_>>()?
307        }
308        LogicalPlan::Limit(_) => {
309            // Pass index requirements from the parent as well as column indices
310            // that appear in this plan's expressions to its child. These operators
311            // do not benefit from "small" inputs, so the projection_beneficial
312            // flag is `false`.
313            plan.inputs()
314                .into_iter()
315                .map(|input| indices.clone().with_plan_exprs(&plan, input.schema()))
316                .collect::<Result<_>>()?
317        }
318        LogicalPlan::Copy(_)
319        | LogicalPlan::Ddl(_)
320        | LogicalPlan::Dml(_)
321        | LogicalPlan::Explain(_)
322        | LogicalPlan::Analyze(_)
323        | LogicalPlan::Subquery(_)
324        | LogicalPlan::Statement(_)
325        | LogicalPlan::Distinct(Distinct::All(_)) => {
326            // These plans require all their fields, and their children should
327            // be treated as final plans -- otherwise, we may end up with a
328            // schema mismatch.
329            // TODO: For some subquery variants (e.g. a subquery arising from an
330            //       EXISTS expression), we may not need to require all indices.
331            plan.inputs()
332                .into_iter()
333                .map(RequiredIndices::new_for_all_exprs)
334                .collect()
335        }
336        LogicalPlan::Extension(extension) => {
337            let Some(necessary_children_indices) =
338                extension.node.necessary_children_exprs(indices.indices())
339            else {
340                // Requirements from the parent cannot be safely routed down to the user-defined logical plan
341                return Ok(Transformed::no(plan));
342            };
343            let children = extension.node.inputs();
344            assert_eq_or_internal_err!(
345                children.len(),
346                necessary_children_indices.len(),
347                "Inconsistent length between children and necessary children indices. \
348                Make sure `.necessary_children_exprs` implementation of the \
349                `UserDefinedLogicalNode` is consistent with actual children length \
350                for the node."
351            );
352            children
353                .into_iter()
354                .zip(necessary_children_indices)
355                .map(|(child, necessary_indices)| {
356                    RequiredIndices::new_from_indices(necessary_indices)
357                        .with_plan_exprs(&plan, child.schema())
358                })
359                .collect::<Result<Vec<_>>>()?
360        }
361        LogicalPlan::EmptyRelation(_)
362        | LogicalPlan::Values(_)
363        | LogicalPlan::DescribeTable(_) => {
364            // These operators have no inputs, so stop the optimization process.
365            return Ok(Transformed::no(plan));
366        }
367        LogicalPlan::RecursiveQuery(recursive) => {
368            // Only allow subqueries that reference the current CTE; nested subqueries are not yet
369            // supported for projection pushdown for simplicity.
370            // TODO: be able to do projection pushdown on recursive CTEs with subqueries
371            if plan_contains_other_subqueries(
372                recursive.static_term.as_ref(),
373                &recursive.name,
374            ) || plan_contains_other_subqueries(
375                recursive.recursive_term.as_ref(),
376                &recursive.name,
377            ) {
378                return Ok(Transformed::no(plan));
379            }
380
381            plan.inputs()
382                .into_iter()
383                .map(|input| {
384                    indices
385                        .clone()
386                        .with_projection_beneficial()
387                        .with_plan_exprs(&plan, input.schema())
388                })
389                .collect::<Result<Vec<_>>>()?
390        }
391        LogicalPlan::Join(join) => {
392            let left_len = join.left.schema().fields().len();
393            let (left_req_indices, right_req_indices) =
394                split_join_requirements(left_len, indices, &join.join_type);
395            let left_indices =
396                left_req_indices.with_plan_exprs(&plan, join.left.schema())?;
397            let right_indices =
398                right_req_indices.with_plan_exprs(&plan, join.right.schema())?;
399            // Joins benefit from "small" input tables (lower memory usage).
400            // Therefore, each child benefits from projection:
401            vec![
402                left_indices.with_projection_beneficial(),
403                right_indices.with_projection_beneficial(),
404            ]
405        }
406        // these nodes are explicitly rewritten in the match statement above
407        LogicalPlan::Projection(_)
408        | LogicalPlan::Aggregate(_)
409        | LogicalPlan::Window(_)
410        | LogicalPlan::TableScan(_) => {
411            return internal_err!(
412                "OptimizeProjection: should have been handled in the match statement above"
413            );
414        }
415        LogicalPlan::Unnest(Unnest {
416            input,
417            dependency_indices,
418            ..
419        }) => {
420            // at least provide the indices for the exec-columns as a starting point
421            let required_indices =
422                RequiredIndices::new().with_plan_exprs(&plan, input.schema())?;
423
424            // Add additional required indices from the parent
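            // Illustrative example (hypothetical values): with
            // `dependency_indices = [0, 0, 1]` and parent requirements {0, 2},
            // the child additionally needs input indices [0, 1].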
425            let mut additional_necessary_child_indices = Vec::new();
426            indices.indices().iter().for_each(|idx| {
427                if let Some(index) = dependency_indices.get(*idx) {
428                    additional_necessary_child_indices.push(*index);
429                }
430            });
431            vec![required_indices.append(&additional_necessary_child_indices)]
432        }
433    };
434
435    // Required indices are currently ordered (child0, child1, ...)
436    // but the loop pops off the last element, so we need to reverse the order
437    child_required_indices.reverse();
438    assert_eq_or_internal_err!(
439        child_required_indices.len(),
440        plan.inputs().len(),
441        "OptimizeProjection: child_required_indices length mismatch with plan inputs"
442    );
443
444    // Rewrite children of the plan
445    let transformed_plan = plan.map_children(|child| {
446        let required_indices = child_required_indices.pop().ok_or_else(|| {
447            internal_datafusion_err!(
448                "Unexpected number of required_indices in OptimizeProjections rule"
449            )
450        })?;
451
452        let projection_beneficial = required_indices.projection_beneficial();
453        let project_exprs = required_indices.get_required_exprs(child.schema());
454
455        optimize_projections(child, config, required_indices)?.transform_data(
456            |new_input| {
457                if projection_beneficial {
458                    add_projection_on_top_if_helpful(new_input, project_exprs)
459                } else {
460                    Ok(Transformed::no(new_input))
461                }
462            },
463        )
464    })?;
465
466    // If any of the children are transformed, we need to potentially update the plan's schema
467    if transformed_plan.transformed {
468        transformed_plan.map_data(|plan| plan.recompute_schema())
469    } else {
470        Ok(transformed_plan)
471    }
472}
473
474/// Merges consecutive projections.
475///
476/// Given a projection `proj`, this function attempts to merge it with a previous
477/// projection if it exists and if merging is beneficial. Merging is skipped when
478/// a non-trivial expression from the previous projection is referenced more than
479/// once by the current projection; keeping the previous projection then acts as a
480/// caching mechanism for the non-trivial computation.
481///
482/// ## Metadata Handling During Projection Merging
483///
484/// **Alias metadata preservation**: When merging projections, alias metadata from both
485/// the current and previous projections is carefully preserved. The presence of metadata
486/// precludes alias trimming.
487///
488/// **Schema, fields, and metadata**: If a projection is rewritten, the schema and metadata
489/// are preserved. Individual field properties and metadata flow through expression rewriting
490/// and are preserved when fields are referenced in the merged projection.
491/// Refer to [`projection_schema`](datafusion_expr::logical_plan::projection_schema)
492/// for more details.
493///
494/// # Parameters
495///
496/// * `proj` - A reference to the `Projection` to be merged.
497///
498/// # Returns
499///
500/// A `Result` object with the following semantics:
501///
502/// - `Ok(Transformed::yes(projection))`: Merge was beneficial and successful. Contains the
503///   merged projection.
504/// - `Ok(Transformed::no(projection))`: Merge was not beneficial (and has not taken place).
505/// - `Err(error)`: An error occurred during the function call.
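///
/// # Example
///
/// An illustrative sketch (hypothetical column names) of a case where merging is
/// skipped: `sum` is non-trivial and referenced twice, so the lower projection is
/// kept as a compute-once cache:
///
/// ```text
/// Projection(sum * sum as prod)
/// --Projection(a + b as sum)
/// ----Source(a, b)
/// ```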
506fn merge_consecutive_projections(proj: Projection) -> Result<Transformed<Projection>> {
507    let Projection {
508        expr,
509        input,
510        schema,
511        ..
512    } = proj;
513    let LogicalPlan::Projection(prev_projection) = input.as_ref() else {
514        return Projection::try_new_with_schema(expr, input, schema).map(Transformed::no);
515    };
516
517    // A fast path: if the previous projection is the same as the current one,
518    // we can remove the current projection and return the child projection.
519    if prev_projection.expr == expr {
520        return Projection::try_new_with_schema(
521            expr,
522            Arc::clone(&prev_projection.input),
523            schema,
524        )
525        .map(Transformed::yes);
526    }
527
528    // Count usages (referrals) of each projection expression in its input fields:
529    let mut column_referral_map = HashMap::<&Column, usize>::new();
530    expr.iter()
531        .for_each(|expr| expr.add_column_ref_counts(&mut column_referral_map));
532
533    // If a non-trivial expression appears more than once, do not merge the
534    // projections: keeping them separate benefits from a compute-once approach.
535    // For details, see: https://github.com/apache/datafusion/issues/8296
536    if column_referral_map.into_iter().any(|(col, usage)| {
537        usage > 1
538            && !is_expr_trivial(
539                &prev_projection.expr
540                    [prev_projection.schema.index_of_column(col).unwrap()],
541            )
542    }) {
543        // no change
544        return Projection::try_new_with_schema(expr, input, schema).map(Transformed::no);
545    }
546
547    let LogicalPlan::Projection(prev_projection) = Arc::unwrap_or_clone(input) else {
548        // We know it is a `LogicalPlan::Projection` from check above
549        unreachable!();
550    };
551
552    // Try to rewrite the expressions in the current projection using the
553    // previous projection as input:
554    let name_preserver = NamePreserver::new_for_projection();
555    let mut original_names = vec![];
556    let new_exprs = expr.map_elements(|expr| {
557        original_names.push(name_preserver.save(&expr));
558
559        // do not rewrite top level Aliases (rewriter will remove all aliases within exprs)
560        match expr {
561            Expr::Alias(Alias {
562                expr,
563                relation,
564                name,
565                metadata,
566            }) => rewrite_expr(*expr, &prev_projection).map(|result| {
567                result.update_data(|expr| {
568                    Expr::Alias(Alias::new(expr, relation, name).with_metadata(metadata))
569                })
570            }),
571            e => rewrite_expr(e, &prev_projection),
572        }
573    })?;
574
575    // if the expressions could be rewritten, create a new projection with the
576    // new expressions
577    if new_exprs.transformed {
578        // Add any needed aliases back to the expressions
579        let new_exprs = new_exprs
580            .data
581            .into_iter()
582            .zip(original_names)
583            .map(|(expr, original_name)| original_name.restore(expr))
584            .collect::<Vec<_>>();
585        Projection::try_new(new_exprs, prev_projection.input).map(Transformed::yes)
586    } else {
587        // not rewritten, so put the projection back together
588        let input = Arc::new(LogicalPlan::Projection(prev_projection));
589        Projection::try_new_with_schema(new_exprs.data, input, schema)
590            .map(Transformed::no)
591    }
592}
593
594// Check whether `expr` is trivial; i.e. it doesn't imply any computation.
595fn is_expr_trivial(expr: &Expr) -> bool {
596    matches!(expr, Expr::Column(_) | Expr::Literal(_, _))
597}
598
599/// Rewrites a projection expression using the projection before it (i.e. its input)
600/// This is a subroutine to the `merge_consecutive_projections` function.
601///
602/// # Parameters
603///
604/// * `expr` - A reference to the expression to rewrite.
605/// * `input` - A reference to the input of the projection expression (itself
606///   a projection).
607///
608/// # Returns
609///
610/// A `Result` object with the following semantics:
611///
612/// - `Ok(Transformed::yes(expr))`: Rewrite was successful. Contains the rewritten result.
613/// - `Ok(Transformed::no(expr))`: Signals that `expr` was not rewritten.
614/// - `Err(error)`: An error occurred during the function call.
615///
616/// # Notes
617/// This rewrite also removes any unnecessary layers of aliasing. "Unnecessary" is
618/// defined as not contributing new information, such as metadata.
619///
620/// Without trimming, we can end up with unnecessary indirections inside expressions
621/// during projection merges.
622///
623/// Consider:
624///
625/// ```text
626/// Projection(a1 + b1 as sum1)
627/// --Projection(a as a1, b as b1)
628/// ----Source(a, b)
629/// ```
630///
631/// After merge, we want to produce:
632///
633/// ```text
634/// Projection(a + b as sum1)
635/// --Source(a, b)
636/// ```
637///
638/// Without trimming, we would end up with:
639///
640/// ```text
641/// Projection((a as a1 + b as b1) as sum1)
642/// --Source(a, b)
643/// ```
644fn rewrite_expr(expr: Expr, input: &Projection) -> Result<Transformed<Expr>> {
645    expr.transform_up(|expr| {
646        match expr {
647            //  remove any intermediate aliases if they do not carry metadata
648            Expr::Alias(alias) => {
649                match alias
650                    .metadata
651                    .as_ref()
652                    .map(|h| h.is_empty())
653                    .unwrap_or(true)
654                {
655                    true => Ok(Transformed::yes(*alias.expr)),
656                    false => Ok(Transformed::no(Expr::Alias(alias))),
657                }
658            }
659            Expr::Column(col) => {
660                // Find index of column:
661                let idx = input.schema.index_of_column(&col)?;
662                // get the corresponding unaliased input expression
663                //
664                // For example:
665                // * the input projection is [`a + b` as c, `d + e` as f]
666                // * the current column is an expression "f"
667                //
668                // return the expression `d + e` (not `d + e` as f)
669                let input_expr = input.expr[idx].clone().unalias_nested().data;
670                Ok(Transformed::yes(input_expr))
671            }
672            // Unsupported type for consecutive projection merge analysis.
673            _ => Ok(Transformed::no(expr)),
674        }
675    })
676}
677
678/// Accumulates outer-referenced columns by the
679/// given expression, `expr`.
680///
681/// # Parameters
682///
683/// * `expr` - The expression to analyze for outer-referenced columns.
684/// * `columns` - A mutable reference to a `HashSet<Column>` where detected
685///   columns are collected.
686fn outer_columns<'a>(expr: &'a Expr, columns: &mut HashSet<&'a Column>) {
687    // inspect_expr_pre doesn't handle subquery references, so find them explicitly
688    expr.apply(|expr| {
689        match expr {
690            Expr::OuterReferenceColumn(_, col) => {
691                columns.insert(col);
692            }
693            Expr::ScalarSubquery(subquery) => {
694                outer_columns_helper_multi(&subquery.outer_ref_columns, columns);
695            }
696            Expr::Exists(exists) => {
697                outer_columns_helper_multi(&exists.subquery.outer_ref_columns, columns);
698            }
699            Expr::InSubquery(insubquery) => {
700                outer_columns_helper_multi(
701                    &insubquery.subquery.outer_ref_columns,
702                    columns,
703                );
704            }
705            _ => {}
706        };
707        Ok(TreeNodeRecursion::Continue)
708    })
709    // unwrap: closure above never returns Err, so can not be Err here
710    .unwrap();
711}
712
713/// A recursive subroutine that accumulates outer-referenced columns by the
714/// given expressions (`exprs`).
715///
716/// # Parameters
717///
718/// * `exprs` - The expressions to analyze for outer-referenced columns.
719/// * `columns` - A mutable reference to a `HashSet<Column>` where detected
720///   columns are collected.
721fn outer_columns_helper_multi<'a, 'b>(
722    exprs: impl IntoIterator<Item = &'a Expr>,
723    columns: &'b mut HashSet<&'a Column>,
724) {
725    exprs.into_iter().for_each(|e| outer_columns(e, columns));
726}
727
728/// Splits requirement indices for a join into left and right children based on
729/// the join type.
730///
731/// This function takes the length of the left child, a slice of requirement
732/// indices, and the type of join (e.g. `INNER`, `LEFT`, `RIGHT`) as arguments.
733/// Depending on the join type, it divides the requirement indices into those
734/// that apply to the left child and those that apply to the right child.
735///
736/// - For `INNER`, `LEFT`, `RIGHT`, `FULL`, `LEFTMARK`, and `RIGHTMARK` joins,
737///   the requirements are split between left and right children. The right
738///   child indices are adjusted to point to valid positions within the right
739///   child by subtracting the length of the left child.
740///
741/// - For `LEFT ANTI`, `LEFT SEMI`, `RIGHT SEMI` and `RIGHT ANTI` joins, all
742///   requirements are re-routed to either the left child or the right child
743///   directly, depending on the join type.
744///
745/// # Parameters
746///
747/// * `left_len` - The length of the left child.
748/// * `indices` - A slice of requirement indices.
749/// * `join_type` - The type of join (e.g. `INNER`, `LEFT`, `RIGHT`).
750///
751/// # Returns
752///
753/// A tuple containing two vectors of `usize` indices: The first vector represents
754/// the requirements for the left child, and the second vector represents the
755/// requirements for the right child. The indices are appropriately split and
756/// adjusted based on the join type.
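///
/// For example (illustrative numbers): for an `INNER` join with `left_len = 3` and
/// requirements `{1, 4}`, the left child receives `{1}` and the right child
/// receives `{1}` (i.e. `4 - 3`). For a `LEFT SEMI` join, the full `{1, 4}` is
/// routed to the left child unchanged.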
757fn split_join_requirements(
758    left_len: usize,
759    indices: RequiredIndices,
760    join_type: &JoinType,
761) -> (RequiredIndices, RequiredIndices) {
762    match join_type {
763        // In these cases requirements are split between left/right children:
764        JoinType::Inner
765        | JoinType::Left
766        | JoinType::Right
767        | JoinType::Full
768        | JoinType::LeftMark
769        | JoinType::RightMark => {
770            // Decrease right side indices by `left_len` so that they point to valid
771            // positions within the right child:
772            indices.split_off(left_len)
773        }
774        // All requirements can be re-routed to left child directly.
775        JoinType::LeftAnti | JoinType::LeftSemi => (indices, RequiredIndices::new()),
776        // All requirements can be re-routed to right side directly.
777        // No need to change index, join schema is right child schema.
778        JoinType::RightSemi | JoinType::RightAnti => (RequiredIndices::new(), indices),
779    }
780}
781
782/// Adds a projection on top of a logical plan if doing so reduces the number
783/// of columns for the parent operator.
784///
785/// This function takes a `LogicalPlan` and a list of projection expressions.
786/// If the projection is beneficial (it reduces the number of columns in the
787/// plan), a new `LogicalPlan` with the projection is created and returned as
788/// `Transformed::yes`. If the projection doesn't reduce the number of columns,
789/// the original plan is returned as `Transformed::no`.
790///
791/// # Parameters
792///
793/// * `plan` - The input `LogicalPlan` to potentially add a projection to.
794/// * `project_exprs` - A list of expressions for the projection.
795///
796/// # Returns
797///
798/// A `Transformed` indicating if a projection was added
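///
/// For example (illustrative): if the input plan exposes fields `[a, b, c]` and
/// `project_exprs` is `[a, b]`, a projection is added; if `project_exprs` covers
/// all three fields, the plan is returned unchanged.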
799fn add_projection_on_top_if_helpful(
800    plan: LogicalPlan,
801    project_exprs: Vec<Expr>,
802) -> Result<Transformed<LogicalPlan>> {
803    // Make sure projection decreases the number of columns, otherwise it is unnecessary.
804    if project_exprs.len() >= plan.schema().fields().len() {
805        Ok(Transformed::no(plan))
806    } else {
807        Projection::try_new(project_exprs, Arc::new(plan))
808            .map(LogicalPlan::Projection)
809            .map(Transformed::yes)
810    }
811}
812
813/// Rewrite the given projection according to the fields required by its
814/// ancestors.
815///
816/// # Parameters
817///
818/// * `proj` - A reference to the original projection to rewrite.
819/// * `config` - A reference to the optimizer configuration.
820/// * `indices` - A slice of indices representing the columns required by the
821///   ancestors of the given projection.
822///
823/// # Returns
824///
825/// A `Result` object with the following semantics:
826///
827/// - `Ok(Transformed::yes(plan))`: Contains the rewritten projection.
828/// - `Ok(Transformed::no(plan))`: No rewrite was necessary.
829/// - `Err(error)`: An error occurred during the function call.
830fn rewrite_projection_given_requirements(
831    proj: Projection,
832    config: &dyn OptimizerConfig,
833    indices: &RequiredIndices,
834) -> Result<Transformed<LogicalPlan>> {
835    let Projection { expr, input, .. } = proj;
836
837    let exprs_used = indices.get_at_indices(&expr);
838
839    let required_indices =
840        RequiredIndices::new().with_exprs(input.schema(), exprs_used.iter());
841
842    // Rewrite the projection's input; then rebuild the projection on top of the
843    // (possibly rewritten) input, dropping it entirely if it became unnecessary
844    optimize_projections(Arc::unwrap_or_clone(input), config, required_indices)?
845        .transform_data(|input| {
846            if is_projection_unnecessary(&input, &exprs_used)? {
847                Ok(Transformed::yes(input))
848            } else {
849                Projection::try_new(exprs_used, Arc::new(input))
850                    .map(LogicalPlan::Projection)
851                    .map(Transformed::yes)
852            }
853        })
854}
855
856/// A projection is unnecessary when
857/// - the input schema and the output schema of the projection are the same, and
858/// - every projection expression is a bare column reference to the corresponding input field
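///
/// For example (illustrative): over an input with fields `[a, b]`, the projection
/// `[a, b]` is unnecessary, while `[b, a]` or `[a]` is not, since the expressions
/// must match the input fields one-to-one and in order.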
859pub fn is_projection_unnecessary(
860    input: &LogicalPlan,
861    proj_exprs: &[Expr],
862) -> Result<bool> {
863    // First check if the number of expressions is equal to the number of fields in the input schema.
864    if proj_exprs.len() != input.schema().fields().len() {
865        return Ok(false);
866    }
867    Ok(input.schema().iter().zip(proj_exprs.iter()).all(
868        |((field_relation, field_name), expr)| {
869            // Check if the expression is a column and if it matches the field name
870            if let Expr::Column(col) = expr {
871                col.relation.as_ref() == field_relation && col.name.eq(field_name.name())
872            } else {
873                false
874            }
875        },
876    ))
877}
878
879/// Returns true if the plan subtree contains any subqueries that are not the
880/// CTE reference itself. This treats any non-CTE [`LogicalPlan::SubqueryAlias`]
881/// node (including aliased relations) as a blocker, along with expression-level
882/// subqueries like scalar, EXISTS, or IN. These cases prevent projection
883/// pushdown for now because we cannot safely reason about their column usage.
884fn plan_contains_other_subqueries(plan: &LogicalPlan, cte_name: &str) -> bool {
885    if let LogicalPlan::SubqueryAlias(alias) = plan
886        && alias.alias.table() != cte_name
887        && !subquery_alias_targets_recursive_cte(alias.input.as_ref(), cte_name)
888    {
889        return true;
890    }
891
892    let mut found = false;
893    plan.apply_expressions(|expr| {
894        if expr_contains_subquery(expr) {
895            found = true;
896            Ok(TreeNodeRecursion::Stop)
897        } else {
898            Ok(TreeNodeRecursion::Continue)
899        }
900    })
901    .expect("expression traversal never fails");
902    if found {
903        return true;
904    }
905
906    plan.inputs()
907        .into_iter()
908        .any(|child| plan_contains_other_subqueries(child, cte_name))
909}
910
911fn expr_contains_subquery(expr: &Expr) -> bool {
912    expr.exists(|e| match e {
913        Expr::ScalarSubquery(_) | Expr::Exists(_) | Expr::InSubquery(_) => Ok(true),
914        _ => Ok(false),
915    })
916    // Safe unwrap since we are doing a simple boolean check
917    .unwrap()
918}
919
920fn subquery_alias_targets_recursive_cte(plan: &LogicalPlan, cte_name: &str) -> bool {
921    match plan {
922        LogicalPlan::TableScan(scan) => scan.table_name.table() == cte_name,
923        LogicalPlan::SubqueryAlias(alias) => {
924            subquery_alias_targets_recursive_cte(alias.input.as_ref(), cte_name)
925        }
926        _ => {
927            let inputs = plan.inputs();
928            if inputs.len() == 1 {
929                subquery_alias_targets_recursive_cte(inputs[0], cte_name)
930            } else {
931                false
932            }
933        }
934    }
935}
936
937#[cfg(test)]
938mod tests {
939    use std::cmp::Ordering;
940    use std::collections::HashMap;
941    use std::fmt::Formatter;
942    use std::ops::Add;
943    use std::sync::Arc;
944    use std::vec;
945
946    use crate::optimize_projections::OptimizeProjections;
947    use crate::optimizer::Optimizer;
948    use crate::test::{
949        assert_fields_eq, scan_empty, test_table_scan, test_table_scan_fields,
950        test_table_scan_with_name,
951    };
952    use crate::{OptimizerContext, OptimizerRule};
953    use arrow::datatypes::{DataType, Field, Schema};
954    use datafusion_common::{
955        Column, DFSchema, DFSchemaRef, JoinType, Result, TableReference,
956    };
957    use datafusion_expr::ExprFunctionExt;
958    use datafusion_expr::{
959        BinaryExpr, Expr, Extension, Like, LogicalPlan, Operator, Projection,
960        UserDefinedLogicalNodeCore, WindowFunctionDefinition, binary_expr,
961        build_join_schema,
962        builder::table_scan_with_filters,
963        col,
964        expr::{self, Cast},
965        lit,
966        logical_plan::{builder::LogicalPlanBuilder, table_scan},
967        not, try_cast, when,
968    };
969    use insta::assert_snapshot;
970
971    use crate::assert_optimized_plan_eq_snapshot;
972    use datafusion_functions_aggregate::count::count_udaf;
973    use datafusion_functions_aggregate::expr_fn::{count, max, min};
974    use datafusion_functions_aggregate::min_max::max_udaf;
975
976    macro_rules! assert_optimized_plan_equal {
977        (
978            $plan:expr,
979            @ $expected:literal $(,)?
980        ) => {{
981            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
982            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(OptimizeProjections::new())];
983            assert_optimized_plan_eq_snapshot!(
984                optimizer_ctx,
985                rules,
986                $plan,
987                @ $expected,
988            )
989        }};
990    }
991
992    #[derive(Debug, Hash, PartialEq, Eq)]
993    struct NoOpUserDefined {
994        exprs: Vec<Expr>,
995        schema: DFSchemaRef,
996        input: Arc<LogicalPlan>,
997    }
998
999    impl NoOpUserDefined {
1000        fn new(schema: DFSchemaRef, input: Arc<LogicalPlan>) -> Self {
1001            Self {
1002                exprs: vec![],
1003                schema,
1004                input,
1005            }
1006        }
1007
1008        fn with_exprs(mut self, exprs: Vec<Expr>) -> Self {
1009            self.exprs = exprs;
1010            self
1011        }
1012    }
1013
1014    // Manual implementation needed because of `schema` field. Comparison excludes this field.
1015    impl PartialOrd for NoOpUserDefined {
1016        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
1017            match self.exprs.partial_cmp(&other.exprs) {
1018                Some(Ordering::Equal) => self.input.partial_cmp(&other.input),
1019                cmp => cmp,
1020            }
1021            // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields
1022            .filter(|cmp| *cmp != Ordering::Equal || self == other)
1023        }
1024    }
1025
1026    impl UserDefinedLogicalNodeCore for NoOpUserDefined {
1027        fn name(&self) -> &str {
1028            "NoOpUserDefined"
1029        }
1030
1031        fn inputs(&self) -> Vec<&LogicalPlan> {
1032            vec![&self.input]
1033        }
1034
1035        fn schema(&self) -> &DFSchemaRef {
1036            &self.schema
1037        }
1038
1039        fn expressions(&self) -> Vec<Expr> {
1040            self.exprs.clone()
1041        }
1042
1043        fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
1044            write!(f, "NoOpUserDefined")
1045        }
1046
1047        fn with_exprs_and_inputs(
1048            &self,
1049            exprs: Vec<Expr>,
1050            mut inputs: Vec<LogicalPlan>,
1051        ) -> Result<Self> {
1052            Ok(Self {
1053                exprs,
1054                input: Arc::new(inputs.swap_remove(0)),
1055                schema: Arc::clone(&self.schema),
1056            })
1057        }
1058
1059        fn necessary_children_exprs(
1060            &self,
1061            output_columns: &[usize],
1062        ) -> Option<Vec<Vec<usize>>> {
1063            // Since the schema is the same, output columns require their corresponding columns in the input.
1064            Some(vec![output_columns.to_vec()])
1065        }
1066
1067        fn supports_limit_pushdown(&self) -> bool {
1068            false // Disallow limit push-down by default
1069        }
1070    }
1071
1072    #[derive(Debug, Hash, PartialEq, Eq)]
1073    struct UserDefinedCrossJoin {
1074        exprs: Vec<Expr>,
1075        schema: DFSchemaRef,
1076        left_child: Arc<LogicalPlan>,
1077        right_child: Arc<LogicalPlan>,
1078    }
1079
1080    impl UserDefinedCrossJoin {
1081        fn new(left_child: Arc<LogicalPlan>, right_child: Arc<LogicalPlan>) -> Self {
1082            let left_schema = left_child.schema();
1083            let right_schema = right_child.schema();
1084            let schema = Arc::new(
1085                build_join_schema(left_schema, right_schema, &JoinType::Inner).unwrap(),
1086            );
1087            Self {
1088                exprs: vec![],
1089                schema,
1090                left_child,
1091                right_child,
1092            }
1093        }
1094    }
1095
1096    // Manual implementation needed because of `schema` field. Comparison excludes this field.
1097    impl PartialOrd for UserDefinedCrossJoin {
1098        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
1099            match self.exprs.partial_cmp(&other.exprs) {
1100                Some(Ordering::Equal) => {
1101                    match self.left_child.partial_cmp(&other.left_child) {
1102                        Some(Ordering::Equal) => {
1103                            self.right_child.partial_cmp(&other.right_child)
1104                        }
1105                        cmp => cmp,
1106                    }
1107                }
1108                cmp => cmp,
1109            }
1110            // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields
1111            .filter(|cmp| *cmp != Ordering::Equal || self == other)
1112        }
1113    }
1114
1115    impl UserDefinedLogicalNodeCore for UserDefinedCrossJoin {
1116        fn name(&self) -> &str {
1117            "UserDefinedCrossJoin"
1118        }
1119
1120        fn inputs(&self) -> Vec<&LogicalPlan> {
1121            vec![&self.left_child, &self.right_child]
1122        }
1123
1124        fn schema(&self) -> &DFSchemaRef {
1125            &self.schema
1126        }
1127
1128        fn expressions(&self) -> Vec<Expr> {
1129            self.exprs.clone()
1130        }
1131
1132        fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
1133            write!(f, "UserDefinedCrossJoin")
1134        }
1135
1136        fn with_exprs_and_inputs(
1137            &self,
1138            exprs: Vec<Expr>,
1139            mut inputs: Vec<LogicalPlan>,
1140        ) -> Result<Self> {
1141            assert_eq!(inputs.len(), 2);
1142            Ok(Self {
1143                exprs,
1144                left_child: Arc::new(inputs.remove(0)),
1145                right_child: Arc::new(inputs.remove(0)),
1146                schema: Arc::clone(&self.schema),
1147            })
1148        }
1149
1150        fn necessary_children_exprs(
1151            &self,
1152            output_columns: &[usize],
1153        ) -> Option<Vec<Vec<usize>>> {
1154            let left_child_len = self.left_child.schema().fields().len();
1155            let mut left_reqs = vec![];
1156            let mut right_reqs = vec![];
1157            for &out_idx in output_columns {
1158                if out_idx < left_child_len {
1159                    left_reqs.push(out_idx);
1160                } else {
1161                    // Output indices at or beyond left_child_len
1162                    // come from the right child
1163                    right_reqs.push(out_idx - left_child_len)
1164                }
1165            }
1166            Some(vec![left_reqs, right_reqs])
1167        }
1168
1169        fn supports_limit_pushdown(&self) -> bool {
1170            false // Disallow limit push-down by default
1171        }
1172    }
1173
1174    #[test]
1175    fn merge_two_projection() -> Result<()> {
1176        let table_scan = test_table_scan()?;
1177        let plan = LogicalPlanBuilder::from(table_scan)
1178            .project(vec![col("a")])?
1179            .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])?
1180            .build()?;
1181
1182        assert_optimized_plan_equal!(
1183            plan,
1184            @r"
1185        Projection: Int32(1) + test.a
1186          TableScan: test projection=[a]
1187        "
1188        )
1189    }
1190
1191    #[test]
1192    fn merge_three_projection() -> Result<()> {
1193        let table_scan = test_table_scan()?;
1194        let plan = LogicalPlanBuilder::from(table_scan)
1195            .project(vec![col("a"), col("b")])?
1196            .project(vec![col("a")])?
1197            .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])?
1198            .build()?;
1199
1200        assert_optimized_plan_equal!(
1201            plan,
1202            @r"
1203        Projection: Int32(1) + test.a
1204          TableScan: test projection=[a]
1205        "
1206        )
1207    }
1208
1209    #[test]
1210    fn merge_alias() -> Result<()> {
1211        let table_scan = test_table_scan()?;
1212        let plan = LogicalPlanBuilder::from(table_scan)
1213            .project(vec![col("a")])?
1214            .project(vec![col("a").alias("alias")])?
1215            .build()?;
1216
1217        assert_optimized_plan_equal!(
1218            plan,
1219            @r"
1220        Projection: test.a AS alias
1221          TableScan: test projection=[a]
1222        "
1223        )
1224    }
1225
1226    #[test]
1227    fn merge_nested_alias() -> Result<()> {
1228        let table_scan = test_table_scan()?;
1229        let plan = LogicalPlanBuilder::from(table_scan)
1230            .project(vec![col("a").alias("alias1").alias("alias2")])?
1231            .project(vec![col("alias2").alias("alias")])?
1232            .build()?;
1233
1234        assert_optimized_plan_equal!(
1235            plan,
1236            @r"
1237        Projection: test.a AS alias
1238          TableScan: test projection=[a]
1239        "
1240        )
1241    }
1242
1243    #[test]
1244    fn test_nested_count() -> Result<()> {
1245        let schema = Schema::new(vec![Field::new("foo", DataType::Int32, false)]);
1246
1247        let groups: Vec<Expr> = vec![];
1248
1249        let plan = table_scan(TableReference::none(), &schema, None)
1250            .unwrap()
1251            .aggregate(groups.clone(), vec![count(lit(1))])
1252            .unwrap()
1253            .aggregate(groups, vec![count(lit(1))])
1254            .unwrap()
1255            .build()
1256            .unwrap();
1257
1258        assert_optimized_plan_equal!(
1259            plan,
1260            @r"
1261        Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]]
1262          EmptyRelation: rows=1
1263        "
1264        )
1265    }
1266
1267    #[test]
1268    fn test_neg_push_down() -> Result<()> {
1269        let table_scan = test_table_scan()?;
1270        let plan = LogicalPlanBuilder::from(table_scan)
1271            .project(vec![-col("a")])?
1272            .build()?;
1273
1274        assert_optimized_plan_equal!(
1275            plan,
1276            @r"
1277        Projection: (- test.a)
1278          TableScan: test projection=[a]
1279        "
1280        )
1281    }
1282
1283    #[test]
1284    fn test_is_null() -> Result<()> {
1285        let table_scan = test_table_scan()?;
1286        let plan = LogicalPlanBuilder::from(table_scan)
1287            .project(vec![col("a").is_null()])?
1288            .build()?;
1289
1290        assert_optimized_plan_equal!(
1291            plan,
1292            @r"
1293        Projection: test.a IS NULL
1294          TableScan: test projection=[a]
1295        "
1296        )
1297    }
1298
1299    #[test]
1300    fn test_is_not_null() -> Result<()> {
1301        let table_scan = test_table_scan()?;
1302        let plan = LogicalPlanBuilder::from(table_scan)
1303            .project(vec![col("a").is_not_null()])?
1304            .build()?;
1305
1306        assert_optimized_plan_equal!(
1307            plan,
1308            @r"
1309        Projection: test.a IS NOT NULL
1310          TableScan: test projection=[a]
1311        "
1312        )
1313    }
1314
1315    #[test]
1316    fn test_is_true() -> Result<()> {
1317        let table_scan = test_table_scan()?;
1318        let plan = LogicalPlanBuilder::from(table_scan)
1319            .project(vec![col("a").is_true()])?
1320            .build()?;
1321
1322        assert_optimized_plan_equal!(
1323            plan,
1324            @r"
1325        Projection: test.a IS TRUE
1326          TableScan: test projection=[a]
1327        "
1328        )
1329    }
1330
1331    #[test]
1332    fn test_is_not_true() -> Result<()> {
1333        let table_scan = test_table_scan()?;
1334        let plan = LogicalPlanBuilder::from(table_scan)
1335            .project(vec![col("a").is_not_true()])?
1336            .build()?;
1337
1338        assert_optimized_plan_equal!(
1339            plan,
1340            @r"
1341        Projection: test.a IS NOT TRUE
1342          TableScan: test projection=[a]
1343        "
1344        )
1345    }
1346
1347    #[test]
1348    fn test_is_false() -> Result<()> {
1349        let table_scan = test_table_scan()?;
1350        let plan = LogicalPlanBuilder::from(table_scan)
1351            .project(vec![col("a").is_false()])?
1352            .build()?;
1353
1354        assert_optimized_plan_equal!(
1355            plan,
1356            @r"
1357        Projection: test.a IS FALSE
1358          TableScan: test projection=[a]
1359        "
1360        )
1361    }
1362
1363    #[test]
1364    fn test_is_not_false() -> Result<()> {
1365        let table_scan = test_table_scan()?;
1366        let plan = LogicalPlanBuilder::from(table_scan)
1367            .project(vec![col("a").is_not_false()])?
1368            .build()?;
1369
1370        assert_optimized_plan_equal!(
1371            plan,
1372            @r"
1373        Projection: test.a IS NOT FALSE
1374          TableScan: test projection=[a]
1375        "
1376        )
1377    }
1378
1379    #[test]
1380    fn test_is_unknown() -> Result<()> {
1381        let table_scan = test_table_scan()?;
1382        let plan = LogicalPlanBuilder::from(table_scan)
1383            .project(vec![col("a").is_unknown()])?
1384            .build()?;
1385
1386        assert_optimized_plan_equal!(
1387            plan,
1388            @r"
1389        Projection: test.a IS UNKNOWN
1390          TableScan: test projection=[a]
1391        "
1392        )
1393    }
1394
1395    #[test]
1396    fn test_is_not_unknown() -> Result<()> {
1397        let table_scan = test_table_scan()?;
1398        let plan = LogicalPlanBuilder::from(table_scan)
1399            .project(vec![col("a").is_not_unknown()])?
1400            .build()?;
1401
1402        assert_optimized_plan_equal!(
1403            plan,
1404            @r"
1405        Projection: test.a IS NOT UNKNOWN
1406          TableScan: test projection=[a]
1407        "
1408        )
1409    }
1410
1411    #[test]
1412    fn test_not() -> Result<()> {
1413        let table_scan = test_table_scan()?;
1414        let plan = LogicalPlanBuilder::from(table_scan)
1415            .project(vec![not(col("a"))])?
1416            .build()?;
1417
1418        assert_optimized_plan_equal!(
1419            plan,
1420            @r"
1421        Projection: NOT test.a
1422          TableScan: test projection=[a]
1423        "
1424        )
1425    }
1426
1427    #[test]
1428    fn test_try_cast() -> Result<()> {
1429        let table_scan = test_table_scan()?;
1430        let plan = LogicalPlanBuilder::from(table_scan)
1431            .project(vec![try_cast(col("a"), DataType::Float64)])?
1432            .build()?;
1433
1434        assert_optimized_plan_equal!(
1435            plan,
1436            @r"
1437        Projection: TRY_CAST(test.a AS Float64)
1438          TableScan: test projection=[a]
1439        "
1440        )
1441    }
1442
1443    #[test]
1444    fn test_similar_to() -> Result<()> {
1445        let table_scan = test_table_scan()?;
1446        let expr = Box::new(col("a"));
1447        let pattern = Box::new(lit("[0-9]"));
1448        let similar_to_expr =
1449            Expr::SimilarTo(Like::new(false, expr, pattern, None, false));
1450        let plan = LogicalPlanBuilder::from(table_scan)
1451            .project(vec![similar_to_expr])?
1452            .build()?;
1453
1454        assert_optimized_plan_equal!(
1455            plan,
1456            @r#"
1457        Projection: test.a SIMILAR TO Utf8("[0-9]")
1458          TableScan: test projection=[a]
1459        "#
1460        )
1461    }
1462
1463    #[test]
1464    fn test_between() -> Result<()> {
1465        let table_scan = test_table_scan()?;
1466        let plan = LogicalPlanBuilder::from(table_scan)
1467            .project(vec![col("a").between(lit(1), lit(3))])?
1468            .build()?;
1469
1470        assert_optimized_plan_equal!(
1471            plan,
1472            @r"
1473        Projection: test.a BETWEEN Int32(1) AND Int32(3)
1474          TableScan: test projection=[a]
1475        "
1476        )
1477    }
1478
1479    // Test CASE expression handling when projections are merged
1480    #[test]
1481    fn test_case_merged() -> Result<()> {
1482        let table_scan = test_table_scan()?;
1483        let plan = LogicalPlanBuilder::from(table_scan)
1484            .project(vec![col("a"), lit(0).alias("d")])?
1485            .project(vec![
1486                col("a"),
1487                when(col("a").eq(lit(1)), lit(10))
1488                    .otherwise(col("d"))?
1489                    .alias("d"),
1490            ])?
1491            .build()?;
1492
1493        assert_optimized_plan_equal!(
1494            plan,
1495            @r"
1496        Projection: test.a, CASE WHEN test.a = Int32(1) THEN Int32(10) ELSE Int32(0) END AS d
1497          TableScan: test projection=[a]
1498        "
1499        )
1500    }
1501
1502    // Test that the outer projection isn't discarded even though it has the same schema as the inner one
1503    // https://github.com/apache/datafusion/issues/8942
1504    #[test]
1505    fn test_derived_column() -> Result<()> {
1506        let table_scan = test_table_scan()?;
1507        let plan = LogicalPlanBuilder::from(table_scan)
1508            .project(vec![col("a").add(lit(1)).alias("a"), lit(0).alias("d")])?
1509            .project(vec![
1510                col("a"),
1511                when(col("a").eq(lit(1)), lit(10))
1512                    .otherwise(col("d"))?
1513                    .alias("d"),
1514            ])?
1515            .build()?;
1516
1517        assert_optimized_plan_equal!(
1518            plan,
1519            @r"
1520        Projection: a, CASE WHEN a = Int32(1) THEN Int32(10) ELSE d END AS d
1521          Projection: test.a + Int32(1) AS a, Int32(0) AS d
1522            TableScan: test projection=[a]
1523        "
1524        )
1525    }
1526
1527    // Only column `a` is referenced in the output, so the scan should only contain projection=[a].
1528    // A user-defined node should be able to propagate the expressions required by its parent down to its child.
1529    #[test]
1530    fn test_user_defined_logical_plan_node() -> Result<()> {
1531        let table_scan = test_table_scan()?;
1532        let custom_plan = LogicalPlan::Extension(Extension {
1533            node: Arc::new(NoOpUserDefined::new(
1534                Arc::clone(table_scan.schema()),
1535                Arc::new(table_scan.clone()),
1536            )),
1537        });
1538        let plan = LogicalPlanBuilder::from(custom_plan)
1539            .project(vec![col("a"), lit(0).alias("d")])?
1540            .build()?;
1541
1542        assert_optimized_plan_equal!(
1543            plan,
1544            @r"
1545        Projection: test.a, Int32(0) AS d
1546          NoOpUserDefined
1547            TableScan: test projection=[a]
1548        "
1549        )
1550    }
1551
1552    // Only column `a` is referenced in the output. However, the user-defined node itself uses column `b`
1553    // during its operation, so the scan should contain projection=[a, b].
1554    // A user-defined node should be able to propagate the expressions required by its parent as well as its own
1555    // required expressions.
1556    #[test]
1557    fn test_user_defined_logical_plan_node2() -> Result<()> {
1558        let table_scan = test_table_scan()?;
1559        let exprs = vec![Expr::Column(Column::from_qualified_name("b"))];
1560        let custom_plan = LogicalPlan::Extension(Extension {
1561            node: Arc::new(
1562                NoOpUserDefined::new(
1563                    Arc::clone(table_scan.schema()),
1564                    Arc::new(table_scan.clone()),
1565                )
1566                .with_exprs(exprs),
1567            ),
1568        });
1569        let plan = LogicalPlanBuilder::from(custom_plan)
1570            .project(vec![col("a"), lit(0).alias("d")])?
1571            .build()?;
1572
1573        assert_optimized_plan_equal!(
1574            plan,
1575            @r"
1576        Projection: test.a, Int32(0) AS d
1577          NoOpUserDefined
1578            TableScan: test projection=[a, b]
1579        "
1580        )
1581    }
1582
1583    // Only column `a` is referenced in the output. However, the user-defined node itself uses the expression `b + c`
1584    // during its operation, so the scan should contain projection=[a, b, c].
1585    // A user-defined node should be able to propagate the expressions required by its parent as well as its own
1586    // required expressions. Those expressions don't have to be plain columns; requirements coming from complex
1587    // expressions should be propagated too (see the sketch after this test).
1588    #[test]
1589    fn test_user_defined_logical_plan_node3() -> Result<()> {
1590        let table_scan = test_table_scan()?;
1591        let left_expr = Expr::Column(Column::from_qualified_name("b"));
1592        let right_expr = Expr::Column(Column::from_qualified_name("c"));
1593        let binary_expr = Expr::BinaryExpr(BinaryExpr::new(
1594            Box::new(left_expr),
1595            Operator::Plus,
1596            Box::new(right_expr),
1597        ));
1598        let exprs = vec![binary_expr];
1599        let custom_plan = LogicalPlan::Extension(Extension {
1600            node: Arc::new(
1601                NoOpUserDefined::new(
1602                    Arc::clone(table_scan.schema()),
1603                    Arc::new(table_scan.clone()),
1604                )
1605                .with_exprs(exprs),
1606            ),
1607        });
1608        let plan = LogicalPlanBuilder::from(custom_plan)
1609            .project(vec![col("a"), lit(0).alias("d")])?
1610            .build()?;
1611
1612        assert_optimized_plan_equal!(
1613            plan,
1614            @r"
1615        Projection: test.a, Int32(0) AS d
1616          NoOpUserDefined
1617            TableScan: test projection=[a, b, c]
1618        "
1619        )
1620    }
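    // A rough sketch (not the rule's actual implementation) of how the column
    // requirements of an expression such as `b + c` can be collected, assuming an
    // `Expr::column_refs`-style helper is available:
    //
    //     let expr = col("b") + col("c");
    //     let referenced: HashSet<&Column> = expr.column_refs();
    //     // `referenced` contains `b` and `c`; the rule merges them with the parent's
    //     // requirement `a`, which is why the scan ends up with projection=[a, b, c].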
1621
1622    // Columns `l.a`, `l.c` and `r.a` are referenced in the output.
1623    // A user-defined node should be able to propagate the expressions required by its parent to its children,
1624    // even when it has multiple children.
1625    // The left child should end up with `projection=[a, c]`, and the right child with `projection=[a]`.
1626    #[test]
1627    fn test_user_defined_logical_plan_node4() -> Result<()> {
1628        let left_table = test_table_scan_with_name("l")?;
1629        let right_table = test_table_scan_with_name("r")?;
1630        let custom_plan = LogicalPlan::Extension(Extension {
1631            node: Arc::new(UserDefinedCrossJoin::new(
1632                Arc::new(left_table),
1633                Arc::new(right_table),
1634            )),
1635        });
1636        let plan = LogicalPlanBuilder::from(custom_plan)
1637            .project(vec![col("l.a"), col("l.c"), col("r.a"), lit(0).alias("d")])?
1638            .build()?;
1639
1640        assert_optimized_plan_equal!(
1641            plan,
1642            @r"
1643        Projection: l.a, l.c, r.a, Int32(0) AS d
1644          UserDefinedCrossJoin
1645            TableScan: l projection=[a, c]
1646            TableScan: r projection=[a]
1647        "
1648        )
1649    }
1650
1651    #[test]
1652    fn aggregate_no_group_by() -> Result<()> {
1653        let table_scan = test_table_scan()?;
1654
1655        let plan = LogicalPlanBuilder::from(table_scan)
1656            .aggregate(Vec::<Expr>::new(), vec![max(col("b"))])?
1657            .build()?;
1658
1659        assert_optimized_plan_equal!(
1660            plan,
1661            @r"
1662        Aggregate: groupBy=[[]], aggr=[[max(test.b)]]
1663          TableScan: test projection=[b]
1664        "
1665        )
1666    }
1667
1668    #[test]
1669    fn aggregate_group_by() -> Result<()> {
1670        let table_scan = test_table_scan()?;
1671
1672        let plan = LogicalPlanBuilder::from(table_scan)
1673            .aggregate(vec![col("c")], vec![max(col("b"))])?
1674            .build()?;
1675
1676        assert_optimized_plan_equal!(
1677            plan,
1678            @r"
1679        Aggregate: groupBy=[[test.c]], aggr=[[max(test.b)]]
1680          TableScan: test projection=[b, c]
1681        "
1682        )
1683    }
1684
1685    #[test]
1686    fn aggregate_group_by_with_table_alias() -> Result<()> {
1687        let table_scan = test_table_scan()?;
1688
1689        let plan = LogicalPlanBuilder::from(table_scan)
1690            .alias("a")?
1691            .aggregate(vec![col("c")], vec![max(col("b"))])?
1692            .build()?;
1693
1694        assert_optimized_plan_equal!(
1695            plan,
1696            @r"
1697        Aggregate: groupBy=[[a.c]], aggr=[[max(a.b)]]
1698          SubqueryAlias: a
1699            TableScan: test projection=[b, c]
1700        "
1701        )
1702    }
1703
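    // The aggregate below only needs column `b`, while the filter needs `c`; the rule
    // therefore inserts a `Projection: test.b` between the aggregate and the filter so
    // that `c` is dropped as early as possible, while the scan still reads [b, c].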
1704    #[test]
1705    fn aggregate_no_group_by_with_filter() -> Result<()> {
1706        let table_scan = test_table_scan()?;
1707
1708        let plan = LogicalPlanBuilder::from(table_scan)
1709            .filter(col("c").gt(lit(1)))?
1710            .aggregate(Vec::<Expr>::new(), vec![max(col("b"))])?
1711            .build()?;
1712
1713        assert_optimized_plan_equal!(
1714            plan,
1715            @r"
1716        Aggregate: groupBy=[[]], aggr=[[max(test.b)]]
1717          Projection: test.b
1718            Filter: test.c > Int32(1)
1719              TableScan: test projection=[b, c]
1720        "
1721        )
1722    }
1723
1724    #[test]
1725    fn aggregate_with_periods() -> Result<()> {
1726        let schema = Schema::new(vec![Field::new("tag.one", DataType::Utf8, false)]);
1727
1728        // Build a plan that looks as follows (note "tag.one" is a column named
1729        // "tag.one", not a column named "one" in a table named "tag"):
1730        //
1731        // Projection: tag.one
1732        //   Aggregate: groupBy=[], aggr=[max("tag.one") AS "tag.one"]
1733        //    TableScan
1734        let plan = table_scan(Some("m4"), &schema, None)?
1735            .aggregate(
1736                Vec::<Expr>::new(),
1737                vec![max(col(Column::new_unqualified("tag.one"))).alias("tag.one")],
1738            )?
1739            .project([col(Column::new_unqualified("tag.one"))])?
1740            .build()?;
1741
1742        assert_optimized_plan_equal!(
1743            plan,
1744            @r"
1745        Aggregate: groupBy=[[]], aggr=[[max(m4.tag.one) AS tag.one]]
1746          TableScan: m4 projection=[tag.one]
1747        "
1748        )
1749    }
1750
1751    #[test]
1752    fn redundant_project() -> Result<()> {
1753        let table_scan = test_table_scan()?;
1754
1755        let plan = LogicalPlanBuilder::from(table_scan)
1756            .project(vec![col("a"), col("b"), col("c")])?
1757            .project(vec![col("a"), col("c"), col("b")])?
1758            .build()?;
1759        assert_optimized_plan_equal!(
1760            plan,
1761            @r"
1762        Projection: test.a, test.c, test.b
1763          TableScan: test projection=[a, b, c]
1764        "
1765        )
1766    }
1767
1768    #[test]
1769    fn reorder_scan() -> Result<()> {
1770        let schema = Schema::new(test_table_scan_fields());
1771
1772        let plan = table_scan(Some("test"), &schema, Some(vec![1, 0, 2]))?.build()?;
1773        assert_optimized_plan_equal!(
1774            plan,
1775            @"TableScan: test projection=[b, a, c]"
1776        )
1777    }
1778
1779    #[test]
1780    fn reorder_scan_projection() -> Result<()> {
1781        let schema = Schema::new(test_table_scan_fields());
1782
1783        let plan = table_scan(Some("test"), &schema, Some(vec![1, 0, 2]))?
1784            .project(vec![col("a"), col("b")])?
1785            .build()?;
1786        assert_optimized_plan_equal!(
1787            plan,
1788            @r"
1789        Projection: test.a, test.b
1790          TableScan: test projection=[b, a]
1791        "
1792        )
1793    }
1794
1795    #[test]
1796    fn reorder_projection() -> Result<()> {
1797        let table_scan = test_table_scan()?;
1798
1799        let plan = LogicalPlanBuilder::from(table_scan)
1800            .project(vec![col("c"), col("b"), col("a")])?
1801            .build()?;
1802        assert_optimized_plan_equal!(
1803            plan,
1804            @r"
1805        Projection: test.c, test.b, test.a
1806          TableScan: test projection=[a, b, c]
1807        "
1808        )
1809    }
1810
1811    #[test]
1812    fn noncontinuous_redundant_projection() -> Result<()> {
1813        let table_scan = test_table_scan()?;
1814
1815        let plan = LogicalPlanBuilder::from(table_scan)
1816            .project(vec![col("c"), col("b"), col("a")])?
1817            .filter(col("c").gt(lit(1)))?
1818            .project(vec![col("c"), col("a"), col("b")])?
1819            .filter(col("b").gt(lit(1)))?
1820            .filter(col("a").gt(lit(1)))?
1821            .project(vec![col("a"), col("c"), col("b")])?
1822            .build()?;
1823        assert_optimized_plan_equal!(
1824            plan,
1825            @r"
1826        Projection: test.a, test.c, test.b
1827          Filter: test.a > Int32(1)
1828            Filter: test.b > Int32(1)
1829              Projection: test.c, test.a, test.b
1830                Filter: test.c > Int32(1)
1831                  Projection: test.c, test.b, test.a
1832                    TableScan: test projection=[a, b, c]
1833        "
1834        )
1835    }
1836
1837    #[test]
1838    fn join_schema_trim_full_join_column_projection() -> Result<()> {
1839        let table_scan = test_table_scan()?;
1840
1841        let schema = Schema::new(vec![Field::new("c1", DataType::UInt32, false)]);
1842        let table2_scan = scan_empty(Some("test2"), &schema, None)?.build()?;
1843
1844        let plan = LogicalPlanBuilder::from(table_scan)
1845            .join(table2_scan, JoinType::Left, (vec!["a"], vec!["c1"]), None)?
1846            .project(vec![col("a"), col("b"), col("c1")])?
1847            .build()?;
1848
1849        let optimized_plan = optimize(plan)?;
1850
1851        // make sure projections are pushed down to both table scans
1852        assert_snapshot!(
1853            optimized_plan.clone(),
1854            @r"
1855        Left Join: test.a = test2.c1
1856          TableScan: test projection=[a, b]
1857          TableScan: test2 projection=[c1]
1858        "
1859        );
1860
1861        // make sure the schema for the join node includes both join columns
1862        let optimized_join = optimized_plan;
1863        assert_eq!(
1864            **optimized_join.schema(),
1865            DFSchema::new_with_metadata(
1866                vec![
1867                    (
1868                        Some("test".into()),
1869                        Arc::new(Field::new("a", DataType::UInt32, false))
1870                    ),
1871                    (
1872                        Some("test".into()),
1873                        Arc::new(Field::new("b", DataType::UInt32, false))
1874                    ),
1875                    (
1876                        Some("test2".into()),
1877                        Arc::new(Field::new("c1", DataType::UInt32, true))
1878                    ),
1879                ],
1880                HashMap::new()
1881            )?,
1882        );
1883
1884        Ok(())
1885    }
1886
1887    #[test]
1888    fn join_schema_trim_partial_join_column_projection() -> Result<()> {
1889        // test that the join column is pushed down even without an explicit column projection
1890
1891        let table_scan = test_table_scan()?;
1892
1893        let schema = Schema::new(vec![Field::new("c1", DataType::UInt32, false)]);
1894        let table2_scan = scan_empty(Some("test2"), &schema, None)?.build()?;
1895
1896        let plan = LogicalPlanBuilder::from(table_scan)
1897            .join(table2_scan, JoinType::Left, (vec!["a"], vec!["c1"]), None)?
1898            // projecting the join column `a` should also push a projection of the right-side column `c1`
1899            // down into the test2 scan, even though `c1` is not referenced in the projection.
1900            .project(vec![col("a"), col("b")])?
1901            .build()?;
1902
1903        let optimized_plan = optimize(plan)?;
1904
1905        // make sure projections are pushed down to both table scans
1906        assert_snapshot!(
1907            optimized_plan.clone(),
1908            @r"
1909        Projection: test.a, test.b
1910          Left Join: test.a = test2.c1
1911            TableScan: test projection=[a, b]
1912            TableScan: test2 projection=[c1]
1913        "
1914        );
1915
1916        // make sure the schema for the join node includes both join columns
1917        let optimized_join = optimized_plan.inputs()[0];
1918        assert_eq!(
1919            **optimized_join.schema(),
1920            DFSchema::new_with_metadata(
1921                vec![
1922                    (
1923                        Some("test".into()),
1924                        Arc::new(Field::new("a", DataType::UInt32, false))
1925                    ),
1926                    (
1927                        Some("test".into()),
1928                        Arc::new(Field::new("b", DataType::UInt32, false))
1929                    ),
1930                    (
1931                        Some("test2".into()),
1932                        Arc::new(Field::new("c1", DataType::UInt32, true))
1933                    ),
1934                ],
1935                HashMap::new()
1936            )?,
1937        );
1938
1939        Ok(())
1940    }
1941
1942    #[test]
1943    fn join_schema_trim_using_join() -> Result<()> {
1944        // shared join columns from a USING join should be pushed down to both sides
1945
1946        let table_scan = test_table_scan()?;
1947
1948        let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]);
1949        let table2_scan = scan_empty(Some("test2"), &schema, None)?.build()?;
1950
1951        let plan = LogicalPlanBuilder::from(table_scan)
1952            .join_using(table2_scan, JoinType::Left, vec!["a".into()])?
1953            .project(vec![col("a"), col("b")])?
1954            .build()?;
1955
1956        let optimized_plan = optimize(plan)?;
1957
1958        // make sure projections are pushed down to both table scans
1959        assert_snapshot!(
1960            optimized_plan.clone(),
1961            @r"
1962        Projection: test.a, test.b
1963          Left Join: Using test.a = test2.a
1964            TableScan: test projection=[a, b]
1965            TableScan: test2 projection=[a]
1966        "
1967        );
1968
1969        // make sure the schema for the join node includes both join columns
1970        let optimized_join = optimized_plan.inputs()[0];
1971        assert_eq!(
1972            **optimized_join.schema(),
1973            DFSchema::new_with_metadata(
1974                vec![
1975                    (
1976                        Some("test".into()),
1977                        Arc::new(Field::new("a", DataType::UInt32, false))
1978                    ),
1979                    (
1980                        Some("test".into()),
1981                        Arc::new(Field::new("b", DataType::UInt32, false))
1982                    ),
1983                    (
1984                        Some("test2".into()),
1985                        Arc::new(Field::new("a", DataType::UInt32, true))
1986                    ),
1987                ],
1988                HashMap::new()
1989            )?,
1990        );
1991
1992        Ok(())
1993    }
1994
1995    #[test]
1996    fn cast() -> Result<()> {
1997        let table_scan = test_table_scan()?;
1998
1999        let plan = LogicalPlanBuilder::from(table_scan)
2000            .project(vec![Expr::Cast(Cast::new(
2001                Box::new(col("c")),
2002                DataType::Float64,
2003            ))])?
2004            .build()?;
2005
2006        assert_optimized_plan_equal!(
2007            plan,
2008            @r"
2009        Projection: CAST(test.c AS Float64)
2010          TableScan: test projection=[c]
2011        "
2012        )
2013    }
2014
2015    #[test]
2016    fn table_scan_projected_schema() -> Result<()> {
2017        let table_scan = test_table_scan()?;
2018        let plan = LogicalPlanBuilder::from(test_table_scan()?)
2019            .project(vec![col("a"), col("b")])?
2020            .build()?;
2021
2022        assert_eq!(3, table_scan.schema().fields().len());
2023        assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
2024        assert_fields_eq(&plan, vec!["a", "b"]);
2025
2026        assert_optimized_plan_equal!(
2027            plan,
2028            @"TableScan: test projection=[a, b]"
2029        )
2030    }
2031
2032    #[test]
2033    fn table_scan_projected_schema_non_qualified_relation() -> Result<()> {
2034        let table_scan = test_table_scan()?;
2035        let input_schema = table_scan.schema();
2036        assert_eq!(3, input_schema.fields().len());
2037        assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
2038
2039        // Build the LogicalPlan directly (don't use PlanBuilder), so
2040        // that the Column references are unqualified (e.g. their
2041        // relation is `None`). The PlanBuilder would otherwise resolve the expressions.
2042        let expr = vec![col("test.a"), col("test.b")];
2043        let plan =
2044            LogicalPlan::Projection(Projection::try_new(expr, Arc::new(table_scan))?);
2045
2046        assert_fields_eq(&plan, vec!["a", "b"]);
2047
2048        assert_optimized_plan_equal!(
2049            plan,
2050            @"TableScan: test projection=[a, b]"
2051        )
2052    }
2053
2054    #[test]
2055    fn table_limit() -> Result<()> {
2056        let table_scan = test_table_scan()?;
2057        assert_eq!(3, table_scan.schema().fields().len());
2058        assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
2059
2060        let plan = LogicalPlanBuilder::from(table_scan)
2061            .project(vec![col("c"), col("a")])?
2062            .limit(0, Some(5))?
2063            .build()?;
2064
2065        assert_fields_eq(&plan, vec!["c", "a"]);
2066
2067        assert_optimized_plan_equal!(
2068            plan,
2069            @r"
2070        Limit: skip=0, fetch=5
2071          Projection: test.c, test.a
2072            TableScan: test projection=[a, c]
2073        "
2074        )
2075    }
2076
2077    #[test]
2078    fn table_scan_without_projection() -> Result<()> {
2079        let table_scan = test_table_scan()?;
2080        let plan = LogicalPlanBuilder::from(table_scan).build()?;
2081        // a scan without an explicit projection should be expanded to read all columns
2082        assert_optimized_plan_equal!(
2083            plan,
2084            @"TableScan: test projection=[a, b, c]"
2085        )
2086    }
2087
2088    #[test]
2089    fn table_scan_with_literal_projection() -> Result<()> {
2090        let table_scan = test_table_scan()?;
2091        let plan = LogicalPlanBuilder::from(table_scan)
2092            .project(vec![lit(1_i64), lit(2_i64)])?
2093            .build()?;
2094        assert_optimized_plan_equal!(
2095            plan,
2096            @r"
2097        Projection: Int64(1), Int64(2)
2098          TableScan: test projection=[]
2099        "
2100        )
2101    }
2102
2103    /// tests that it removes unused columns in projections
2104    #[test]
2105    fn table_unused_column() -> Result<()> {
2106        let table_scan = test_table_scan()?;
2107        assert_eq!(3, table_scan.schema().fields().len());
2108        assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
2109
2110        // we never use "b" in the first projection => remove it
2111        let plan = LogicalPlanBuilder::from(table_scan)
2112            .project(vec![col("c"), col("a"), col("b")])?
2113            .filter(col("c").gt(lit(1)))?
2114            .aggregate(vec![col("c")], vec![max(col("a"))])?
2115            .build()?;
2116
2117        assert_fields_eq(&plan, vec!["c", "max(test.a)"]);
2118
2119        let plan = optimize(plan).expect("failed to optimize plan");
2120        assert_optimized_plan_equal!(
2121            plan,
2122            @r"
2123        Aggregate: groupBy=[[test.c]], aggr=[[max(test.a)]]
2124          Filter: test.c > Int32(1)
2125            Projection: test.c, test.a
2126              TableScan: test projection=[a, c]
2127        "
2128        )
2129    }
2130
2131    /// tests that it removes unneeded projections
2132    #[test]
2133    fn table_unused_projection() -> Result<()> {
2134        let table_scan = test_table_scan()?;
2135        assert_eq!(3, table_scan.schema().fields().len());
2136        assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
2137
2138        // there is no need for the first projection
2139        let plan = LogicalPlanBuilder::from(table_scan)
2140            .project(vec![col("b")])?
2141            .project(vec![lit(1).alias("a")])?
2142            .build()?;
2143
2144        assert_fields_eq(&plan, vec!["a"]);
2145
2146        assert_optimized_plan_equal!(
2147            plan,
2148            @r"
2149        Projection: Int32(1) AS a
2150          TableScan: test projection=[]
2151        "
2152        )
2153    }
2154
2155    #[test]
2156    fn table_full_filter_pushdown() -> Result<()> {
2157        let schema = Schema::new(test_table_scan_fields());
2158
2159        let table_scan = table_scan_with_filters(
2160            Some("test"),
2161            &schema,
2162            None,
2163            vec![col("b").eq(lit(1))],
2164        )?
2165        .build()?;
2166        assert_eq!(3, table_scan.schema().fields().len());
2167        assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
2168
2169        // there is no need for the first projection
2170        let plan = LogicalPlanBuilder::from(table_scan)
2171            .project(vec![col("b")])?
2172            .project(vec![lit(1).alias("a")])?
2173            .build()?;
2174
2175        assert_fields_eq(&plan, vec!["a"]);
2176
2177        assert_optimized_plan_equal!(
2178            plan,
2179            @r"
2180        Projection: Int32(1) AS a
2181          TableScan: test projection=[], full_filters=[b = Int32(1)]
2182        "
2183        )
2184    }
2185
2186    /// tests that optimizing twice yields the same plan
2187    #[test]
2188    fn test_double_optimization() -> Result<()> {
2189        let table_scan = test_table_scan()?;
2190
2191        let plan = LogicalPlanBuilder::from(table_scan)
2192            .project(vec![col("b")])?
2193            .project(vec![lit(1).alias("a")])?
2194            .build()?;
2195
2196        let optimized_plan1 = optimize(plan).expect("failed to optimize plan");
2197        let optimized_plan2 =
2198            optimize(optimized_plan1.clone()).expect("failed to optimize plan");
2199
2200        let formatted_plan1 = format!("{optimized_plan1:?}");
2201        let formatted_plan2 = format!("{optimized_plan2:?}");
2202        assert_eq!(formatted_plan1, formatted_plan2);
2203        Ok(())
2204    }
2205
2206    /// tests that it removes an aggregate that is never used downstream
2207    #[test]
2208    fn table_unused_aggregate() -> Result<()> {
2209        let table_scan = test_table_scan()?;
2210        assert_eq!(3, table_scan.schema().fields().len());
2211        assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
2212
2213        // we never use "min(b)" => remove it
2214        let plan = LogicalPlanBuilder::from(table_scan)
2215            .aggregate(vec![col("a"), col("c")], vec![max(col("b")), min(col("b"))])?
2216            .filter(col("c").gt(lit(1)))?
2217            .project(vec![col("c"), col("a"), col("max(test.b)")])?
2218            .build()?;
2219
2220        assert_fields_eq(&plan, vec!["c", "a", "max(test.b)"]);
2221
2222        assert_optimized_plan_equal!(
2223            plan,
2224            @r"
2225        Projection: test.c, test.a, max(test.b)
2226          Filter: test.c > Int32(1)
2227            Aggregate: groupBy=[[test.a, test.c]], aggr=[[max(test.b)]]
2228              TableScan: test projection=[a, b, c]
2229        "
2230        )
2231    }
2232
2233    #[test]
2234    fn aggregate_filter_pushdown() -> Result<()> {
2235        let table_scan = test_table_scan()?;
2236        let aggr_with_filter = count_udaf()
2237            .call(vec![col("b")])
2238            .filter(col("c").gt(lit(42)))
2239            .build()?;
2240        let plan = LogicalPlanBuilder::from(table_scan)
2241            .aggregate(
2242                vec![col("a")],
2243                vec![count(col("b")), aggr_with_filter.alias("count2")],
2244            )?
2245            .build()?;
2246
2247        assert_optimized_plan_equal!(
2248            plan,
2249            @r"
2250        Aggregate: groupBy=[[test.a]], aggr=[[count(test.b), count(test.b) FILTER (WHERE test.c > Int32(42)) AS count2]]
2251          TableScan: test projection=[a, b, c]
2252        "
2253        )
2254    }
2255
2256    #[test]
2257    fn pushdown_through_distinct() -> Result<()> {
2258        let table_scan = test_table_scan()?;
2259
2260        let plan = LogicalPlanBuilder::from(table_scan)
2261            .project(vec![col("a"), col("b")])?
2262            .distinct()?
2263            .project(vec![col("a")])?
2264            .build()?;
2265
2266        assert_optimized_plan_equal!(
2267            plan,
2268            @r"
2269        Projection: test.a
2270          Distinct:
2271            TableScan: test projection=[a, b]
2272        "
2273        )
2274    }
2275
2276    #[test]
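    // Two chained window functions: after the first window, only `test.b` (needed by the
    // second window) and the first window's output (needed by the final projection) are
    // kept, so the rule inserts an intermediate projection and trims the scan down to
    // projection=[a, b].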
2277    fn test_window() -> Result<()> {
2278        let table_scan = test_table_scan()?;
2279
2280        let max1 = Expr::from(expr::WindowFunction::new(
2281            WindowFunctionDefinition::AggregateUDF(max_udaf()),
2282            vec![col("test.a")],
2283        ))
2284        .partition_by(vec![col("test.b")])
2285        .build()
2286        .unwrap();
2287
2288        let max2 = Expr::from(expr::WindowFunction::new(
2289            WindowFunctionDefinition::AggregateUDF(max_udaf()),
2290            vec![col("test.b")],
2291        ));
2292        let col1 = col(max1.schema_name().to_string());
2293        let col2 = col(max2.schema_name().to_string());
2294
2295        let plan = LogicalPlanBuilder::from(table_scan)
2296            .window(vec![max1])?
2297            .window(vec![max2])?
2298            .project(vec![col1, col2])?
2299            .build()?;
2300
2301        assert_optimized_plan_equal!(
2302            plan,
2303            @r"
2304        Projection: max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, max(test.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
2305          WindowAggr: windowExpr=[[max(test.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
2306            Projection: test.b, max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
2307              WindowAggr: windowExpr=[[max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
2308                TableScan: test projection=[a, b]
2309        "
2310        )
2311    }
2312
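    /// No-op observer passed to `Optimizer::optimize` by the `optimize` helper below.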
2313    fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
2314
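    /// Runs `plan` through an optimizer containing only the `OptimizeProjections` rule.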
2315    fn optimize(plan: LogicalPlan) -> Result<LogicalPlan> {
2316        let optimizer = Optimizer::with_rules(vec![Arc::new(OptimizeProjections::new())]);
2317        let optimized_plan =
2318            optimizer.optimize(plan, &OptimizerContext::new(), observe)?;
2319        Ok(optimized_plan)
2320    }
2321}