datafusion_physical_plan/
projection.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines the projection execution plan. A projection determines which columns or expressions
19//! are returned from a query. The SQL statement `SELECT a, b, a+b FROM t1` is an example
20//! of a projection on table `t1` where the expressions `a`, `b`, and `a+b` are the
21//! projection expressions. `SELECT` without `FROM` will only evaluate expressions.
22
23use super::expressions::{Column, Literal};
24use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
25use super::{
26    DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream,
27    SendableRecordBatchStream, Statistics,
28};
29use crate::execution_plan::CardinalityEffect;
30use crate::filter_pushdown::{
31    ChildPushdownResult, FilterDescription, FilterPushdownPhase,
32    FilterPushdownPropagation,
33};
34use crate::joins::utils::{ColumnIndex, JoinFilter, JoinOn, JoinOnRef};
35use crate::{DisplayFormatType, ExecutionPlan, PhysicalExpr};
36use std::any::Any;
37use std::collections::HashMap;
38use std::pin::Pin;
39use std::sync::Arc;
40use std::task::{Context, Poll};
41
42use arrow::datatypes::SchemaRef;
43use arrow::record_batch::{RecordBatch, RecordBatchOptions};
44use datafusion_common::config::ConfigOptions;
45use datafusion_common::tree_node::{
46    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
47};
48use datafusion_common::{internal_err, JoinSide, Result};
49use datafusion_execution::TaskContext;
50use datafusion_physical_expr::equivalence::ProjectionMapping;
51use datafusion_physical_expr::utils::collect_columns;
52use datafusion_physical_expr_common::physical_expr::{fmt_sql, PhysicalExprRef};
53use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement};
54// Re-exported from datafusion-physical-expr for backwards compatibility
55// We recommend updating your imports to use datafusion-physical-expr directly
56pub use datafusion_physical_expr::projection::{
57    update_expr, ProjectionExpr, ProjectionExprs,
58};
59
60use futures::stream::{Stream, StreamExt};
61use log::trace;
62
/// [`ExecutionPlan`] for a projection
///
/// Computes a set of scalar value expressions for each input row, producing one
/// output row for each input row.
#[derive(Debug, Clone)]
pub struct ProjectionExec {
    /// The projection expressions; each entry pairs an expression with the
    /// name (alias) of the output column it produces
    projection: ProjectionExprs,
    /// The schema once the projection has been applied to the input
    schema: SchemaRef,
    /// The input plan
    input: Arc<dyn ExecutionPlan>,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
}
80
81impl ProjectionExec {
82    /// Create a projection on an input
83    ///
84    /// # Example:
85    /// Create a `ProjectionExec` to crate `SELECT a, a+b AS sum_ab FROM t1`:
86    ///
87    /// ```
88    /// # use std::sync::Arc;
89    /// # use arrow_schema::{Schema, Field, DataType};
90    /// # use datafusion_expr::Operator;
91    /// # use datafusion_physical_plan::ExecutionPlan;
92    /// # use datafusion_physical_expr::expressions::{col, binary};
93    /// # use datafusion_physical_plan::empty::EmptyExec;
94    /// # use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr};
95    /// # fn schema() -> Arc<Schema> {
96    /// #  Arc::new(Schema::new(vec![
97    /// #   Field::new("a", DataType::Int32, false),
98    /// #   Field::new("b", DataType::Int32, false),
99    /// # ]))
100    /// # }
101    /// #
102    /// # fn input() -> Arc<dyn ExecutionPlan> {
103    /// #  Arc::new(EmptyExec::new(schema()))
104    /// # }
105    /// #
106    /// # fn main() {
107    /// let schema = schema();
108    /// // Create PhysicalExprs
109    /// let a = col("a", &schema).unwrap();
110    /// let b = col("b", &schema).unwrap();
111    /// let a_plus_b = binary(Arc::clone(&a), Operator::Plus, b, &schema).unwrap();
112    /// // create ProjectionExec
113    /// let proj = ProjectionExec::try_new(
114    ///     [
115    ///         ProjectionExpr {
116    ///             // expr a produces the column named "a"
117    ///             expr: a,
118    ///             alias: "a".to_string(),
119    ///         },
120    ///         ProjectionExpr {
121    ///             // expr: a + b produces the column named "sum_ab"
122    ///             expr: a_plus_b,
123    ///             alias: "sum_ab".to_string(),
124    ///         },
125    ///     ],
126    ///     input(),
127    /// )
128    /// .unwrap();
129    /// # }
130    /// ```
131    pub fn try_new<I, E>(expr: I, input: Arc<dyn ExecutionPlan>) -> Result<Self>
132    where
133        I: IntoIterator<Item = E>,
134        E: Into<ProjectionExpr>,
135    {
136        let input_schema = input.schema();
137        // convert argument to Vec<ProjectionExpr>
138        let expr_vec = expr.into_iter().map(Into::into).collect::<Vec<_>>();
139        let projection = ProjectionExprs::new(expr_vec);
140
141        let schema = Arc::new(projection.project_schema(&input_schema)?);
142
143        // Construct a map from the input expressions to the output expression of the Projection
144        let projection_mapping = projection.projection_mapping(&input_schema)?;
145        let cache =
146            Self::compute_properties(&input, &projection_mapping, Arc::clone(&schema))?;
147        Ok(Self {
148            projection,
149            schema,
150            input,
151            metrics: ExecutionPlanMetricsSet::new(),
152            cache,
153        })
154    }
155
156    /// The projection expressions stored as tuples of (expression, output column name)
157    pub fn expr(&self) -> &[ProjectionExpr] {
158        self.projection.as_ref()
159    }
160
161    /// The input plan
162    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
163        &self.input
164    }
165
166    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
167    fn compute_properties(
168        input: &Arc<dyn ExecutionPlan>,
169        projection_mapping: &ProjectionMapping,
170        schema: SchemaRef,
171    ) -> Result<PlanProperties> {
172        // Calculate equivalence properties:
173        let input_eq_properties = input.equivalence_properties();
174        let eq_properties = input_eq_properties.project(projection_mapping, schema);
175        // Calculate output partitioning, which needs to respect aliases:
176        let output_partitioning = input
177            .output_partitioning()
178            .project(projection_mapping, input_eq_properties);
179
180        Ok(PlanProperties::new(
181            eq_properties,
182            output_partitioning,
183            input.pipeline_behavior(),
184            input.boundedness(),
185        ))
186    }
187}
188
189impl DisplayAs for ProjectionExec {
190    fn fmt_as(
191        &self,
192        t: DisplayFormatType,
193        f: &mut std::fmt::Formatter,
194    ) -> std::fmt::Result {
195        match t {
196            DisplayFormatType::Default | DisplayFormatType::Verbose => {
197                let expr: Vec<String> = self
198                    .projection
199                    .as_ref()
200                    .iter()
201                    .map(|proj_expr| {
202                        let e = proj_expr.expr.to_string();
203                        if e != proj_expr.alias {
204                            format!("{e} as {}", proj_expr.alias)
205                        } else {
206                            e
207                        }
208                    })
209                    .collect();
210
211                write!(f, "ProjectionExec: expr=[{}]", expr.join(", "))
212            }
213            DisplayFormatType::TreeRender => {
214                for (i, proj_expr) in self.expr().iter().enumerate() {
215                    let expr_sql = fmt_sql(proj_expr.expr.as_ref());
216                    if proj_expr.expr.to_string() == proj_expr.alias {
217                        writeln!(f, "expr{i}={expr_sql}")?;
218                    } else {
219                        writeln!(f, "{}={expr_sql}", proj_expr.alias)?;
220                    }
221                }
222
223                Ok(())
224            }
225        }
226    }
227}
228
229impl ExecutionPlan for ProjectionExec {
230    fn name(&self) -> &'static str {
231        "ProjectionExec"
232    }
233
234    /// Return a reference to Any that can be used for downcasting
235    fn as_any(&self) -> &dyn Any {
236        self
237    }
238
239    fn properties(&self) -> &PlanProperties {
240        &self.cache
241    }
242
243    fn maintains_input_order(&self) -> Vec<bool> {
244        // Tell optimizer this operator doesn't reorder its input
245        vec![true]
246    }
247
248    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
249        let all_simple_exprs = self.projection.iter().all(|proj_expr| {
250            proj_expr.expr.as_any().is::<Column>()
251                || proj_expr.expr.as_any().is::<Literal>()
252        });
253        // If expressions are all either column_expr or Literal, then all computations in this projection are reorder or rename,
254        // and projection would not benefit from the repartition, benefits_from_input_partitioning will return false.
255        vec![!all_simple_exprs]
256    }
257
258    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
259        vec![&self.input]
260    }
261
262    fn with_new_children(
263        self: Arc<Self>,
264        mut children: Vec<Arc<dyn ExecutionPlan>>,
265    ) -> Result<Arc<dyn ExecutionPlan>> {
266        ProjectionExec::try_new(self.projection.clone(), children.swap_remove(0))
267            .map(|p| Arc::new(p) as _)
268    }
269
270    fn execute(
271        &self,
272        partition: usize,
273        context: Arc<TaskContext>,
274    ) -> Result<SendableRecordBatchStream> {
275        trace!("Start ProjectionExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id());
276        Ok(Box::pin(ProjectionStream::new(
277            Arc::clone(&self.schema),
278            self.projection.expr_iter().collect(),
279            self.input.execute(partition, context)?,
280            BaselineMetrics::new(&self.metrics, partition),
281        )))
282    }
283
284    fn metrics(&self) -> Option<MetricsSet> {
285        Some(self.metrics.clone_inner())
286    }
287
288    fn statistics(&self) -> Result<Statistics> {
289        self.partition_statistics(None)
290    }
291
292    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
293        let input_stats = self.input.partition_statistics(partition)?;
294        self.projection
295            .project_statistics(input_stats, &self.input.schema())
296    }
297
298    fn supports_limit_pushdown(&self) -> bool {
299        true
300    }
301
302    fn cardinality_effect(&self) -> CardinalityEffect {
303        CardinalityEffect::Equal
304    }
305
306    fn try_swapping_with_projection(
307        &self,
308        projection: &ProjectionExec,
309    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
310        let maybe_unified = try_unifying_projections(projection, self)?;
311        if let Some(new_plan) = maybe_unified {
312            // To unify 3 or more sequential projections:
313            remove_unnecessary_projections(new_plan).data().map(Some)
314        } else {
315            Ok(Some(Arc::new(projection.clone())))
316        }
317    }
318
319    fn gather_filters_for_pushdown(
320        &self,
321        _phase: FilterPushdownPhase,
322        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
323        _config: &ConfigOptions,
324    ) -> Result<FilterDescription> {
325        // TODO: In future, we can try to handle inverting aliases here.
326        // For the time being, we pass through untransformed filters, so filters on aliases are not handled.
327        // https://github.com/apache/datafusion/issues/17246
328        FilterDescription::from_children(parent_filters, &self.children())
329    }
330
331    fn handle_child_pushdown_result(
332        &self,
333        _phase: FilterPushdownPhase,
334        child_pushdown_result: ChildPushdownResult,
335        _config: &ConfigOptions,
336    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
337        Ok(FilterPushdownPropagation::if_all(child_pushdown_result))
338    }
339}
340
341impl ProjectionStream {
342    /// Create a new projection stream
343    fn new(
344        schema: SchemaRef,
345        expr: Vec<Arc<dyn PhysicalExpr>>,
346        input: SendableRecordBatchStream,
347        baseline_metrics: BaselineMetrics,
348    ) -> Self {
349        Self {
350            schema,
351            expr,
352            input,
353            baseline_metrics,
354        }
355    }
356
357    fn batch_project(&self, batch: &RecordBatch) -> Result<RecordBatch> {
358        // Records time on drop
359        let _timer = self.baseline_metrics.elapsed_compute().timer();
360        let arrays = self
361            .expr
362            .iter()
363            .map(|expr| {
364                expr.evaluate(batch)
365                    .and_then(|v| v.into_array(batch.num_rows()))
366            })
367            .collect::<Result<Vec<_>>>()?;
368
369        if arrays.is_empty() {
370            let options =
371                RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
372            RecordBatch::try_new_with_options(Arc::clone(&self.schema), arrays, &options)
373                .map_err(Into::into)
374        } else {
375            RecordBatch::try_new(Arc::clone(&self.schema), arrays).map_err(Into::into)
376        }
377    }
378}
379
/// Projection iterator: a stream that applies the projection expressions to
/// every batch produced by its input stream.
struct ProjectionStream {
    /// Schema of the projected output batches
    schema: SchemaRef,
    /// Expressions evaluated against each input batch, one per output column
    expr: Vec<Arc<dyn PhysicalExpr>>,
    /// Stream of input batches to project
    input: SendableRecordBatchStream,
    /// Metrics for this stream (elapsed compute time and poll results)
    baseline_metrics: BaselineMetrics,
}
387
388impl Stream for ProjectionStream {
389    type Item = Result<RecordBatch>;
390
391    fn poll_next(
392        mut self: Pin<&mut Self>,
393        cx: &mut Context<'_>,
394    ) -> Poll<Option<Self::Item>> {
395        let poll = self.input.poll_next_unpin(cx).map(|x| match x {
396            Some(Ok(batch)) => Some(self.batch_project(&batch)),
397            other => other,
398        });
399
400        self.baseline_metrics.record_poll(poll)
401    }
402
403    fn size_hint(&self) -> (usize, Option<usize>) {
404        // Same number of record batches
405        self.input.size_hint()
406    }
407}
408
409impl RecordBatchStream for ProjectionStream {
410    /// Get the schema
411    fn schema(&self) -> SchemaRef {
412        Arc::clone(&self.schema)
413    }
414}
415
/// An [`ExecutionPlan`] that can be rebuilt with a projection of its own
/// instead of relying on a separate projection operator above it. Used by
/// [`try_embed_projection`].
pub trait EmbeddedProjection: ExecutionPlan + Sized {
    /// Builds a new plan of the same type configured with `projection`.
    ///
    /// [`try_embed_projection`] passes `Some(indices)`, where `indices` are
    /// the column indices collected from the projection's expressions.
    /// NOTE(review): presumably `None` means "no embedded projection" —
    /// confirm against the implementors.
    fn with_projection(&self, projection: Option<Vec<usize>>) -> Result<Self>;
}
419
420/// Some projection can't be pushed down left input or right input of hash join because filter or on need may need some columns that won't be used in later.
421/// By embed those projection to hash join, we can reduce the cost of build_batch_from_indices in hash join (build_batch_from_indices need to can compute::take() for each column) and avoid unnecessary output creation.
422pub fn try_embed_projection<Exec: EmbeddedProjection + 'static>(
423    projection: &ProjectionExec,
424    execution_plan: &Exec,
425) -> Result<Option<Arc<dyn ExecutionPlan>>> {
426    // Collect all column indices from the given projection expressions.
427    let projection_index = collect_column_indices(projection.expr());
428
429    if projection_index.is_empty() {
430        return Ok(None);
431    };
432
433    // If the projection indices is the same as the input columns, we don't need to embed the projection to hash join.
434    // Check the projection_index is 0..n-1 and the length of projection_index is the same as the length of execution_plan schema fields.
435    if projection_index.len() == projection_index.last().unwrap() + 1
436        && projection_index.len() == execution_plan.schema().fields().len()
437    {
438        return Ok(None);
439    }
440
441    let new_execution_plan =
442        Arc::new(execution_plan.with_projection(Some(projection_index.to_vec()))?);
443
444    // Build projection expressions for update_expr. Zip the projection_index with the new_execution_plan output schema fields.
445    let embed_project_exprs = projection_index
446        .iter()
447        .zip(new_execution_plan.schema().fields())
448        .map(|(index, field)| ProjectionExpr {
449            expr: Arc::new(Column::new(field.name(), *index)) as Arc<dyn PhysicalExpr>,
450            alias: field.name().to_owned(),
451        })
452        .collect::<Vec<_>>();
453
454    let mut new_projection_exprs = Vec::with_capacity(projection.expr().len());
455
456    for proj_expr in projection.expr() {
457        // update column index for projection expression since the input schema has been changed.
458        let Some(expr) =
459            update_expr(&proj_expr.expr, embed_project_exprs.as_slice(), false)?
460        else {
461            return Ok(None);
462        };
463        new_projection_exprs.push(ProjectionExpr {
464            expr,
465            alias: proj_expr.alias.clone(),
466        });
467    }
468    // Old projection may contain some alias or expression such as `a + 1` and `CAST('true' AS BOOLEAN)`, but our projection_exprs in hash join just contain column, so we need to create the new projection to keep the original projection.
469    let new_projection = Arc::new(ProjectionExec::try_new(
470        new_projection_exprs,
471        Arc::clone(&new_execution_plan) as _,
472    )?);
473    if is_projection_removable(&new_projection) {
474        Ok(Some(new_execution_plan))
475    } else {
476        Ok(Some(new_projection))
477    }
478}
479
/// Pieces of a join rewritten by [`try_pushdown_through_join`] after a
/// projection has been pushed below the join.
pub struct JoinData {
    /// Projection placed on top of the join's original left child
    pub projected_left_child: ProjectionExec,
    /// Projection placed on top of the join's original right child
    pub projected_right_child: ProjectionExec,
    /// Join filter with column indices updated for the projected children,
    /// or `None` if the join had no filter
    pub join_filter: Option<JoinFilter>,
    /// Equi-join keys updated for the projected children
    pub join_on: JoinOn,
}
486
487pub fn try_pushdown_through_join(
488    projection: &ProjectionExec,
489    join_left: &Arc<dyn ExecutionPlan>,
490    join_right: &Arc<dyn ExecutionPlan>,
491    join_on: JoinOnRef,
492    schema: SchemaRef,
493    filter: Option<&JoinFilter>,
494) -> Result<Option<JoinData>> {
495    // Convert projected expressions to columns. We can not proceed if this is not possible.
496    let Some(projection_as_columns) = physical_to_column_exprs(projection.expr()) else {
497        return Ok(None);
498    };
499
500    let (far_right_left_col_ind, far_left_right_col_ind) =
501        join_table_borders(join_left.schema().fields().len(), &projection_as_columns);
502
503    if !join_allows_pushdown(
504        &projection_as_columns,
505        &schema,
506        far_right_left_col_ind,
507        far_left_right_col_ind,
508    ) {
509        return Ok(None);
510    }
511
512    let new_filter = if let Some(filter) = filter {
513        match update_join_filter(
514            &projection_as_columns[0..=far_right_left_col_ind as _],
515            &projection_as_columns[far_left_right_col_ind as _..],
516            filter,
517            join_left.schema().fields().len(),
518        ) {
519            Some(updated_filter) => Some(updated_filter),
520            None => return Ok(None),
521        }
522    } else {
523        None
524    };
525
526    let Some(new_on) = update_join_on(
527        &projection_as_columns[0..=far_right_left_col_ind as _],
528        &projection_as_columns[far_left_right_col_ind as _..],
529        join_on,
530        join_left.schema().fields().len(),
531    ) else {
532        return Ok(None);
533    };
534
535    let (new_left, new_right) = new_join_children(
536        &projection_as_columns,
537        far_right_left_col_ind,
538        far_left_right_col_ind,
539        join_left,
540        join_right,
541    )?;
542
543    Ok(Some(JoinData {
544        projected_left_child: new_left,
545        projected_right_child: new_right,
546        join_filter: new_filter,
547        join_on: new_on,
548    }))
549}
550
551/// This function checks if `plan` is a [`ProjectionExec`], and inspects its
552/// input(s) to test whether it can push `plan` under its input(s). This function
553/// will operate on the entire tree and may ultimately remove `plan` entirely
554/// by leveraging source providers with built-in projection capabilities.
555pub fn remove_unnecessary_projections(
556    plan: Arc<dyn ExecutionPlan>,
557) -> Result<Transformed<Arc<dyn ExecutionPlan>>> {
558    let maybe_modified =
559        if let Some(projection) = plan.as_any().downcast_ref::<ProjectionExec>() {
560            // If the projection does not cause any change on the input, we can
561            // safely remove it:
562            if is_projection_removable(projection) {
563                return Ok(Transformed::yes(Arc::clone(projection.input())));
564            }
565            // If it does, check if we can push it under its child(ren):
566            projection
567                .input()
568                .try_swapping_with_projection(projection)?
569        } else {
570            return Ok(Transformed::no(plan));
571        };
572    Ok(maybe_modified.map_or_else(|| Transformed::no(plan), Transformed::yes))
573}
574
575/// Compare the inputs and outputs of the projection. All expressions must be
576/// columns without alias, and projection does not change the order of fields.
577/// For example, if the input schema is `a, b`, `SELECT a, b` is removable,
578/// but `SELECT b, a` and `SELECT a+1, b` and `SELECT a AS c, b` are not.
579fn is_projection_removable(projection: &ProjectionExec) -> bool {
580    let exprs = projection.expr();
581    exprs.iter().enumerate().all(|(idx, proj_expr)| {
582        let Some(col) = proj_expr.expr.as_any().downcast_ref::<Column>() else {
583            return false;
584        };
585        col.name() == proj_expr.alias && col.index() == idx
586    }) && exprs.len() == projection.input().schema().fields().len()
587}
588
589/// Given the expression set of a projection, checks if the projection causes
590/// any renaming or constructs a non-`Column` physical expression.
591pub fn all_alias_free_columns(exprs: &[ProjectionExpr]) -> bool {
592    exprs.iter().all(|proj_expr| {
593        proj_expr
594            .expr
595            .as_any()
596            .downcast_ref::<Column>()
597            .map(|column| column.name() == proj_expr.alias)
598            .unwrap_or(false)
599    })
600}
601
602/// Updates a source provider's projected columns according to the given
603/// projection operator's expressions. To use this function safely, one must
604/// ensure that all expressions are `Column` expressions without aliases.
605pub fn new_projections_for_columns(
606    projection: &[ProjectionExpr],
607    source: &[usize],
608) -> Vec<usize> {
609    projection
610        .iter()
611        .filter_map(|proj_expr| {
612            proj_expr
613                .expr
614                .as_any()
615                .downcast_ref::<Column>()
616                .map(|expr| source[expr.index()])
617        })
618        .collect()
619}
620
621/// Creates a new [`ProjectionExec`] instance with the given child plan and
622/// projected expressions.
623pub fn make_with_child(
624    projection: &ProjectionExec,
625    child: &Arc<dyn ExecutionPlan>,
626) -> Result<Arc<dyn ExecutionPlan>> {
627    ProjectionExec::try_new(projection.expr().to_vec(), Arc::clone(child))
628        .map(|e| Arc::new(e) as _)
629}
630
631/// Returns `true` if all the expressions in the argument are `Column`s.
632pub fn all_columns(exprs: &[ProjectionExpr]) -> bool {
633    exprs
634        .iter()
635        .all(|proj_expr| proj_expr.expr.as_any().is::<Column>())
636}
637
638/// Updates the given lexicographic ordering according to given projected
639/// expressions using the [`update_expr`] function.
640pub fn update_ordering(
641    ordering: LexOrdering,
642    projected_exprs: &[ProjectionExpr],
643) -> Result<Option<LexOrdering>> {
644    let mut updated_exprs = vec![];
645    for mut sort_expr in ordering.into_iter() {
646        let Some(updated_expr) = update_expr(&sort_expr.expr, projected_exprs, false)?
647        else {
648            return Ok(None);
649        };
650        sort_expr.expr = updated_expr;
651        updated_exprs.push(sort_expr);
652    }
653    Ok(LexOrdering::new(updated_exprs))
654}
655
656/// Updates the given lexicographic requirement according to given projected
657/// expressions using the [`update_expr`] function.
658pub fn update_ordering_requirement(
659    reqs: LexRequirement,
660    projected_exprs: &[ProjectionExpr],
661) -> Result<Option<LexRequirement>> {
662    let mut updated_exprs = vec![];
663    for mut sort_expr in reqs.into_iter() {
664        let Some(updated_expr) = update_expr(&sort_expr.expr, projected_exprs, false)?
665        else {
666            return Ok(None);
667        };
668        sort_expr.expr = updated_expr;
669        updated_exprs.push(sort_expr);
670    }
671    Ok(LexRequirement::new(updated_exprs))
672}
673
674/// Downcasts all the expressions in `exprs` to `Column`s. If any of the given
675/// expressions is not a `Column`, returns `None`.
676pub fn physical_to_column_exprs(
677    exprs: &[ProjectionExpr],
678) -> Option<Vec<(Column, String)>> {
679    exprs
680        .iter()
681        .map(|proj_expr| {
682            proj_expr
683                .expr
684                .as_any()
685                .downcast_ref::<Column>()
686                .map(|col| (col.clone(), proj_expr.alias.clone()))
687        })
688        .collect()
689}
690
691/// If pushing down the projection over this join's children seems possible,
692/// this function constructs the new [`ProjectionExec`]s that will come on top
693/// of the original children of the join.
694pub fn new_join_children(
695    projection_as_columns: &[(Column, String)],
696    far_right_left_col_ind: i32,
697    far_left_right_col_ind: i32,
698    left_child: &Arc<dyn ExecutionPlan>,
699    right_child: &Arc<dyn ExecutionPlan>,
700) -> Result<(ProjectionExec, ProjectionExec)> {
701    let new_left = ProjectionExec::try_new(
702        projection_as_columns[0..=far_right_left_col_ind as _]
703            .iter()
704            .map(|(col, alias)| ProjectionExpr {
705                expr: Arc::new(Column::new(col.name(), col.index())) as _,
706                alias: alias.clone(),
707            }),
708        Arc::clone(left_child),
709    )?;
710    let left_size = left_child.schema().fields().len() as i32;
711    let new_right = ProjectionExec::try_new(
712        projection_as_columns[far_left_right_col_ind as _..]
713            .iter()
714            .map(|(col, alias)| {
715                ProjectionExpr {
716                    expr: Arc::new(Column::new(
717                        col.name(),
718                        // Align projected expressions coming from the right
719                        // table with the new right child projection:
720                        (col.index() as i32 - left_size) as _,
721                    )) as _,
722                    alias: alias.clone(),
723                }
724            }),
725        Arc::clone(right_child),
726    )?;
727
728    Ok((new_left, new_right))
729}
730
731/// Checks three conditions for pushing a projection down through a join:
732/// - Projection must narrow the join output schema.
733/// - Columns coming from left/right tables must be collected at the left/right
734///   sides of the output table.
735/// - Left or right table is not lost after the projection.
736pub fn join_allows_pushdown(
737    projection_as_columns: &[(Column, String)],
738    join_schema: &SchemaRef,
739    far_right_left_col_ind: i32,
740    far_left_right_col_ind: i32,
741) -> bool {
742    // Projection must narrow the join output:
743    projection_as_columns.len() < join_schema.fields().len()
744    // Are the columns from different tables mixed?
745    && (far_right_left_col_ind + 1 == far_left_right_col_ind)
746    // Left or right table is not lost after the projection.
747    && far_right_left_col_ind >= 0
748    && far_left_right_col_ind < projection_as_columns.len() as i32
749}
750
751/// Returns the last index before encountering a column coming from the right table when traveling
752/// through the projection from left to right, and the last index before encountering a column
753/// coming from the left table when traveling through the projection from right to left.
754/// If there is no column in the projection coming from the left side, it returns (-1, ...),
755/// if there is no column in the projection coming from the right side, it returns (..., projection length).
756pub fn join_table_borders(
757    left_table_column_count: usize,
758    projection_as_columns: &[(Column, String)],
759) -> (i32, i32) {
760    let far_right_left_col_ind = projection_as_columns
761        .iter()
762        .enumerate()
763        .take_while(|(_, (projection_column, _))| {
764            projection_column.index() < left_table_column_count
765        })
766        .last()
767        .map(|(index, _)| index as i32)
768        .unwrap_or(-1);
769
770    let far_left_right_col_ind = projection_as_columns
771        .iter()
772        .enumerate()
773        .rev()
774        .take_while(|(_, (projection_column, _))| {
775            projection_column.index() >= left_table_column_count
776        })
777        .last()
778        .map(|(index, _)| index as i32)
779        .unwrap_or(projection_as_columns.len() as i32);
780
781    (far_right_left_col_ind, far_left_right_col_ind)
782}
783
784/// Tries to update the equi-join `Column`'s of a join as if the input of
785/// the join was replaced by a projection.
786pub fn update_join_on(
787    proj_left_exprs: &[(Column, String)],
788    proj_right_exprs: &[(Column, String)],
789    hash_join_on: &[(PhysicalExprRef, PhysicalExprRef)],
790    left_field_size: usize,
791) -> Option<Vec<(PhysicalExprRef, PhysicalExprRef)>> {
792    // TODO: Clippy wants the "map" call removed, but doing so generates
793    //       a compilation error. Remove the clippy directive once this
794    //       issue is fixed.
795    #[allow(clippy::map_identity)]
796    let (left_idx, right_idx): (Vec<_>, Vec<_>) = hash_join_on
797        .iter()
798        .map(|(left, right)| (left, right))
799        .unzip();
800
801    let new_left_columns = new_columns_for_join_on(&left_idx, proj_left_exprs, 0);
802    let new_right_columns =
803        new_columns_for_join_on(&right_idx, proj_right_exprs, left_field_size);
804
805    match (new_left_columns, new_right_columns) {
806        (Some(left), Some(right)) => Some(left.into_iter().zip(right).collect()),
807        _ => None,
808    }
809}
810
811/// Tries to update the column indices of a [`JoinFilter`] as if the input of
812/// the join was replaced by a projection.
813pub fn update_join_filter(
814    projection_left_exprs: &[(Column, String)],
815    projection_right_exprs: &[(Column, String)],
816    join_filter: &JoinFilter,
817    left_field_size: usize,
818) -> Option<JoinFilter> {
819    let mut new_left_indices = new_indices_for_join_filter(
820        join_filter,
821        JoinSide::Left,
822        projection_left_exprs,
823        0,
824    )
825    .into_iter();
826    let mut new_right_indices = new_indices_for_join_filter(
827        join_filter,
828        JoinSide::Right,
829        projection_right_exprs,
830        left_field_size,
831    )
832    .into_iter();
833
834    // Check if all columns match:
835    (new_right_indices.len() + new_left_indices.len()
836        == join_filter.column_indices().len())
837    .then(|| {
838        JoinFilter::new(
839            Arc::clone(join_filter.expression()),
840            join_filter
841                .column_indices()
842                .iter()
843                .map(|col_idx| ColumnIndex {
844                    index: if col_idx.side == JoinSide::Left {
845                        new_left_indices.next().unwrap()
846                    } else {
847                        new_right_indices.next().unwrap()
848                    },
849                    side: col_idx.side,
850                })
851                .collect(),
852            Arc::clone(join_filter.schema()),
853        )
854    })
855}
856
857/// Unifies `projection` with its input (which is also a [`ProjectionExec`]).
858fn try_unifying_projections(
859    projection: &ProjectionExec,
860    child: &ProjectionExec,
861) -> Result<Option<Arc<dyn ExecutionPlan>>> {
862    let mut projected_exprs = vec![];
863    let mut column_ref_map: HashMap<Column, usize> = HashMap::new();
864
865    // Collect the column references usage in the outer projection.
866    projection.expr().iter().for_each(|proj_expr| {
867        proj_expr
868            .expr
869            .apply(|expr| {
870                Ok({
871                    if let Some(column) = expr.as_any().downcast_ref::<Column>() {
872                        *column_ref_map.entry(column.clone()).or_default() += 1;
873                    }
874                    TreeNodeRecursion::Continue
875                })
876            })
877            .unwrap();
878    });
879    // Merging these projections is not beneficial, e.g
880    // If an expression is not trivial and it is referred more than 1, unifies projections will be
881    // beneficial as caching mechanism for non-trivial computations.
882    // See discussion in: https://github.com/apache/datafusion/issues/8296
883    if column_ref_map.iter().any(|(column, count)| {
884        *count > 1 && !is_expr_trivial(&Arc::clone(&child.expr()[column.index()].expr))
885    }) {
886        return Ok(None);
887    }
888    for proj_expr in projection.expr() {
889        // If there is no match in the input projection, we cannot unify these
890        // projections. This case will arise if the projection expression contains
891        // a `PhysicalExpr` variant `update_expr` doesn't support.
892        let Some(expr) = update_expr(&proj_expr.expr, child.expr(), true)? else {
893            return Ok(None);
894        };
895        projected_exprs.push(ProjectionExpr {
896            expr,
897            alias: proj_expr.alias.clone(),
898        });
899    }
900    ProjectionExec::try_new(projected_exprs, Arc::clone(child.input()))
901        .map(|e| Some(Arc::new(e) as _))
902}
903
904/// Collect all column indices from the given projection expressions.
905fn collect_column_indices(exprs: &[ProjectionExpr]) -> Vec<usize> {
906    // Collect indices and remove duplicates.
907    let mut indices = exprs
908        .iter()
909        .flat_map(|proj_expr| collect_columns(&proj_expr.expr))
910        .map(|x| x.index())
911        .collect::<std::collections::HashSet<_>>()
912        .into_iter()
913        .collect::<Vec<_>>();
914    indices.sort();
915    indices
916}
917
918/// This function determines and returns a vector of indices representing the
919/// positions of columns in `projection_exprs` that are involved in `join_filter`,
920/// and correspond to a particular side (`join_side`) of the join operation.
921///
922/// Notes: Column indices in the projection expressions are based on the join schema,
923/// whereas the join filter is based on the join child schema. `column_index_offset`
924/// represents the offset between them.
925fn new_indices_for_join_filter(
926    join_filter: &JoinFilter,
927    join_side: JoinSide,
928    projection_exprs: &[(Column, String)],
929    column_index_offset: usize,
930) -> Vec<usize> {
931    join_filter
932        .column_indices()
933        .iter()
934        .filter(|col_idx| col_idx.side == join_side)
935        .filter_map(|col_idx| {
936            projection_exprs
937                .iter()
938                .position(|(col, _)| col_idx.index + column_index_offset == col.index())
939        })
940        .collect()
941}
942
943/// This function generates a new set of columns to be used in a hash join
944/// operation based on a set of equi-join conditions (`hash_join_on`) and a
945/// list of projection expressions (`projection_exprs`).
946///
947/// Notes: Column indices in the projection expressions are based on the join schema,
948/// whereas the join on expressions are based on the join child schema. `column_index_offset`
949/// represents the offset between them.
950fn new_columns_for_join_on(
951    hash_join_on: &[&PhysicalExprRef],
952    projection_exprs: &[(Column, String)],
953    column_index_offset: usize,
954) -> Option<Vec<PhysicalExprRef>> {
955    let new_columns = hash_join_on
956        .iter()
957        .filter_map(|on| {
958            // Rewrite all columns in `on`
959            Arc::clone(*on)
960                .transform(|expr| {
961                    if let Some(column) = expr.as_any().downcast_ref::<Column>() {
962                        // Find the column in the projection expressions
963                        let new_column = projection_exprs
964                            .iter()
965                            .enumerate()
966                            .find(|(_, (proj_column, _))| {
967                                column.name() == proj_column.name()
968                                    && column.index() + column_index_offset
969                                        == proj_column.index()
970                            })
971                            .map(|(index, (_, alias))| Column::new(alias, index));
972                        if let Some(new_column) = new_column {
973                            Ok(Transformed::yes(Arc::new(new_column)))
974                        } else {
975                            // If the column is not found in the projection expressions,
976                            // it means that the column is not projected. In this case,
977                            // we cannot push the projection down.
978                            internal_err!(
979                                "Column {:?} not found in projection expressions",
980                                column
981                            )
982                        }
983                    } else {
984                        Ok(Transformed::no(expr))
985                    }
986                })
987                .data()
988                .ok()
989        })
990        .collect::<Vec<_>>();
991    (new_columns.len() == hash_join_on.len()).then_some(new_columns)
992}
993
994/// Checks if the given expression is trivial.
995/// An expression is considered trivial if it is either a `Column` or a `Literal`.
996fn is_expr_trivial(expr: &Arc<dyn PhysicalExpr>) -> bool {
997    expr.as_any().downcast_ref::<Column>().is_some()
998        || expr.as_any().downcast_ref::<Literal>().is_some()
999}
1000
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use crate::common::collect;
    use crate::test;
    use crate::test::exec::StatisticsExec;

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::stats::{ColumnStatistics, Precision, Statistics};
    use datafusion_common::ScalarValue;

    use datafusion_expr::Operator;
    use datafusion_physical_expr::expressions::{col, BinaryExpr, Column, Literal};

    /// Builds `(column, alias)` pairs in which every alias equals the column name.
    fn aliased_columns(columns: &[(&str, usize)]) -> Vec<(Column, String)> {
        columns
            .iter()
            .map(|&(name, index)| (Column::new(name, index), name.to_owned()))
            .collect()
    }

    /// Column indices are collected from nested expressions and returned sorted.
    #[test]
    fn test_collect_column_indices() -> Result<()> {
        // b@7 - (1 + a@1)
        let one_plus_a = BinaryExpr::new(
            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
            Operator::Plus,
            Arc::new(Column::new("a", 1)),
        );
        let b_minus_rest = BinaryExpr::new(
            Arc::new(Column::new("b", 7)),
            Operator::Minus,
            Arc::new(one_plus_a),
        );
        let projection = ProjectionExpr {
            expr: Arc::new(b_minus_rest),
            alias: "b-(1+a)".to_string(),
        };

        assert_eq!(collect_column_indices(&[projection]), vec![1, 7]);
        Ok(())
    }

    /// Checks the borders reported for two projection layouts and several
    /// left-table widths.
    #[test]
    fn test_join_table_borders() -> Result<()> {
        let projections = aliased_columns(&[
            ("b", 1),
            ("c", 2),
            ("e", 4),
            ("d", 3),
            ("c", 2),
            ("f", 5),
            ("h", 7),
            ("g", 6),
        ]);
        for (left_table_column_count, expected) in
            [(5, (4, 5)), (8, (7, 8)), (1, (-1, 0))]
        {
            assert_eq!(
                join_table_borders(left_table_column_count, &projections),
                expected
            );
        }

        let projections = aliased_columns(&[
            ("a", 0),
            ("b", 1),
            ("d", 3),
            ("g", 6),
            ("e", 4),
            ("f", 5),
            ("e", 4),
            ("h", 7),
        ]);
        for (left_table_column_count, expected) in [(5, (2, 7)), (7, (6, 7))] {
            assert_eq!(
                join_table_borders(left_table_column_count, &projections),
                expected
            );
        }

        Ok(())
    }

    /// A projection without expressions yields as many batches as its input.
    #[tokio::test]
    async fn project_no_column() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());

        let scan = test::scan_partitioned(1);
        let expected = collect(scan.execute(0, Arc::clone(&task_ctx))?).await?;

        let no_exprs: Vec<ProjectionExpr> = Vec::new();
        let projection = ProjectionExec::try_new(no_exprs, scan)?;
        let output = collect(projection.execute(0, Arc::clone(&task_ctx))?).await?;
        assert_eq!(output.len(), expected.len());

        Ok(())
    }

    /// `ProjectionExec::try_new` still accepts plain `(expr, alias)` tuples.
    #[tokio::test]
    async fn project_old_syntax() {
        let exec = test::scan_partitioned(1);
        let expr = col("i", &exec.schema()).unwrap();
        // The tuple goes through the `From` impl of `ProjectionExpr`, which is
        // the old syntax under test.
        let old_style_exprs = vec![(expr, "c".to_string())];
        // expect this to succeed
        ProjectionExec::try_new(old_style_exprs, exec).unwrap();
    }

    /// Statistics of a projection are computed even when its expressions refer
    /// to input columns whose index exceeds the width of the output schema.
    #[test]
    fn test_projection_statistics_uses_input_schema() {
        let input_schema = Schema::new(
            ["a", "b", "c", "d", "e", "f"]
                .into_iter()
                .map(|name| Field::new(name, DataType::Int32, false))
                .collect::<Vec<_>>(),
        );

        // One exact (min, max) pair per input column, in schema order.
        let exact_i32 = |value: i32| Precision::Exact(ScalarValue::Int32(Some(value)));
        let column_statistics: Vec<ColumnStatistics> =
            [(1, 100), (5, 50), (10, 40), (20, 30), (21, 29), (24, 26)]
                .into_iter()
                .map(|(min, max)| ColumnStatistics {
                    min_value: exact_i32(min),
                    max_value: exact_i32(max),
                    ..Default::default()
                })
                .collect();
        let input_statistics = Statistics {
            num_rows: Precision::Exact(10),
            column_statistics,
            ..Default::default()
        };

        let input = Arc::new(StatisticsExec::new(input_statistics, input_schema));

        // The output schema (2 columns) is narrower than the input schema (6
        // columns), and the expressions use the last input columns. A bounds
        // check of those columns would fail if the output schema were supplied
        // to the partition statistics computation instead of the input schema.
        let renamed_c = ProjectionExpr {
            expr: Arc::new(Column::new("c", 2)),
            alias: "c_renamed".to_string(),
        };
        let e_plus_f = ProjectionExpr {
            expr: Arc::new(BinaryExpr::new(
                Arc::new(Column::new("e", 4)),
                Operator::Plus,
                Arc::new(Column::new("f", 5)),
            )),
            alias: "e_plus_f".to_string(),
        };

        let projection = ProjectionExec::try_new(vec![renamed_c, e_plus_f], input).unwrap();
        let stats = projection.partition_statistics(None).unwrap();

        assert_eq!(stats.num_rows, Precision::Exact(10));
        assert_eq!(
            stats.column_statistics.len(),
            2,
            "Expected 2 columns in projection statistics"
        );
        assert!(stats.total_byte_size.is_exact().unwrap_or(false));
    }
}