datafusion_physical_plan/projection.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Defines the projection execution plan. A projection determines which columns or expressions
//! are returned from a query. The SQL statement `SELECT a, b, a+b FROM t1` is an example
//! of a projection on table `t1` where the expressions `a`, `b`, and `a+b` are the
//! projection expressions. `SELECT` without `FROM` will only evaluate expressions.
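//!
//! As a rough sketch (assuming an `input: Arc<dyn ExecutionPlan>` is available whose
//! schema has column `a` at index 0 and column `b` at index 1), such a projection can
//! be built as follows:
//!
//! ```ignore
//! use std::sync::Arc;
//! use datafusion_expr::Operator;
//! use datafusion_physical_expr::expressions::{BinaryExpr, Column};
//! use datafusion_physical_plan::projection::ProjectionExec;
//! use datafusion_physical_plan::PhysicalExpr;
//!
//! // `input` is assumed to be some Arc<dyn ExecutionPlan> producing columns `a` and `b`.
//! let a = Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>;
//! let b = Arc::new(Column::new("b", 1)) as Arc<dyn PhysicalExpr>;
//! let sum = Arc::new(BinaryExpr::new(Arc::clone(&a), Operator::Plus, Arc::clone(&b)))
//!     as Arc<dyn PhysicalExpr>;
//! // `SELECT a, b, a + b FROM t1` as a physical projection:
//! let projection = ProjectionExec::try_new(
//!     vec![(a, "a".to_string()), (b, "b".to_string()), (sum, "a_plus_b".to_string())],
//!     input,
//! )?;
//! ```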

use std::any::Any;
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use super::expressions::{Column, Literal};
use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use super::{
    DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream,
    SendableRecordBatchStream, Statistics,
};
use crate::execution_plan::CardinalityEffect;
use crate::joins::utils::{ColumnIndex, JoinFilter, JoinOn, JoinOnRef};
use crate::{ColumnStatistics, DisplayFormatType, ExecutionPlan, PhysicalExpr};

use arrow::datatypes::{Field, Schema, SchemaRef};
use arrow::record_batch::{RecordBatch, RecordBatchOptions};
use datafusion_common::stats::Precision;
use datafusion_common::tree_node::{
    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
};
use datafusion_common::{internal_err, JoinSide, Result};
use datafusion_execution::TaskContext;
use datafusion_physical_expr::equivalence::ProjectionMapping;
use datafusion_physical_expr::utils::collect_columns;
use datafusion_physical_expr::PhysicalExprRef;

use datafusion_physical_expr_common::physical_expr::fmt_sql;
use futures::stream::{Stream, StreamExt};
use itertools::Itertools;
use log::trace;

/// Execution plan for a projection
#[derive(Debug, Clone)]
pub struct ProjectionExec {
    /// The projection expressions stored as tuples of (expression, output column name)
    pub(crate) expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
    /// The schema once the projection has been applied to the input
    schema: SchemaRef,
    /// The input plan
    input: Arc<dyn ExecutionPlan>,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
}

impl ProjectionExec {
    /// Create a projection on an input
    pub fn try_new(
        expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
        input: Arc<dyn ExecutionPlan>,
    ) -> Result<Self> {
        let input_schema = input.schema();

        let fields: Result<Vec<Field>> = expr
            .iter()
            .map(|(e, name)| {
                let metadata = e.return_field(&input_schema)?.metadata().clone();

                let field = Field::new(
                    name,
                    e.data_type(&input_schema)?,
                    e.nullable(&input_schema)?,
                )
                .with_metadata(metadata);

                Ok(field)
            })
            .collect();

        let schema = Arc::new(Schema::new_with_metadata(
            fields?,
            input_schema.metadata().clone(),
        ));

        // Construct a map from the input expressions to the output expressions of the Projection
        let projection_mapping = ProjectionMapping::try_new(&expr, &input_schema)?;
        let cache =
            Self::compute_properties(&input, &projection_mapping, Arc::clone(&schema))?;
        Ok(Self {
            expr,
            schema,
            input,
            metrics: ExecutionPlanMetricsSet::new(),
            cache,
        })
    }

    /// The projection expressions stored as tuples of (expression, output column name)
    pub fn expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
        &self.expr
    }

    /// The input plan
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
    fn compute_properties(
        input: &Arc<dyn ExecutionPlan>,
        projection_mapping: &ProjectionMapping,
        schema: SchemaRef,
    ) -> Result<PlanProperties> {
        // Calculate equivalence properties:
        let mut input_eq_properties = input.equivalence_properties().clone();
        input_eq_properties.substitute_oeq_class(projection_mapping)?;
        let eq_properties = input_eq_properties.project(projection_mapping, schema);

        // Calculate output partitioning, which needs to respect aliases:
        let input_partition = input.output_partitioning();
        let output_partitioning =
            input_partition.project(projection_mapping, &input_eq_properties);

        Ok(PlanProperties::new(
            eq_properties,
            output_partitioning,
            input.pipeline_behavior(),
            input.boundedness(),
        ))
    }
}

impl DisplayAs for ProjectionExec {
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                let expr: Vec<String> = self
                    .expr
                    .iter()
                    .map(|(e, alias)| {
                        let e = e.to_string();
                        if &e != alias {
                            format!("{e} as {alias}")
                        } else {
                            e
                        }
                    })
                    .collect();

                write!(f, "ProjectionExec: expr=[{}]", expr.join(", "))
            }
            DisplayFormatType::TreeRender => {
                for (i, (e, alias)) in self.expr().iter().enumerate() {
                    let expr_sql = fmt_sql(e.as_ref());
                    if &e.to_string() == alias {
                        writeln!(f, "expr{i}={expr_sql}")?;
                    } else {
                        writeln!(f, "{alias}={expr_sql}")?;
                    }
                }

                Ok(())
            }
        }
    }
}

impl ExecutionPlan for ProjectionExec {
    fn name(&self) -> &'static str {
        "ProjectionExec"
    }

    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        // Tell optimizer this operator doesn't reorder its input
        vec![true]
    }

    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        let all_simple_exprs = self
            .expr
            .iter()
            .all(|(e, _)| e.as_any().is::<Column>() || e.as_any().is::<Literal>());
        // If all expressions are either `Column`s or `Literal`s, this projection only
        // reorders or renames columns, so it does not benefit from additional input
        // partitioning and this function returns false.
        vec![!all_simple_exprs]
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        ProjectionExec::try_new(self.expr.clone(), children.swap_remove(0))
            .map(|p| Arc::new(p) as _)
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        trace!("Start ProjectionExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id());
        Ok(Box::pin(ProjectionStream {
            schema: Arc::clone(&self.schema),
            expr: self.expr.iter().map(|x| Arc::clone(&x.0)).collect(),
            input: self.input.execute(partition, context)?,
            baseline_metrics: BaselineMetrics::new(&self.metrics, partition),
        }))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        self.partition_statistics(None)
    }

    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
        let input_stats = self.input.partition_statistics(partition)?;
        Ok(stats_projection(
            input_stats,
            self.expr.iter().map(|(e, _)| Arc::clone(e)),
            Arc::clone(&self.schema),
        ))
    }

    fn supports_limit_pushdown(&self) -> bool {
        true
    }

    fn cardinality_effect(&self) -> CardinalityEffect {
        CardinalityEffect::Equal
    }

    fn try_swapping_with_projection(
        &self,
        projection: &ProjectionExec,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        let maybe_unified = try_unifying_projections(projection, self)?;
        if let Some(new_plan) = maybe_unified {
            // To unify 3 or more sequential projections:
            remove_unnecessary_projections(new_plan).data().map(Some)
        } else {
            Ok(Some(Arc::new(projection.clone())))
        }
    }
}

fn stats_projection(
    mut stats: Statistics,
    exprs: impl Iterator<Item = Arc<dyn PhysicalExpr>>,
    schema: SchemaRef,
) -> Statistics {
    let mut primitive_row_size = 0;
    let mut primitive_row_size_possible = true;
    let mut column_statistics = vec![];
    for expr in exprs {
        let col_stats = if let Some(col) = expr.as_any().downcast_ref::<Column>() {
            stats.column_statistics[col.index()].clone()
        } else {
            // TODO stats: estimate more statistics from expressions
            // (expressions should compute their statistics themselves)
            ColumnStatistics::new_unknown()
        };
        column_statistics.push(col_stats);
        if let Ok(data_type) = expr.data_type(&schema) {
            if let Some(value) = data_type.primitive_width() {
                primitive_row_size += value;
                continue;
            }
        }
        primitive_row_size_possible = false;
    }

    if primitive_row_size_possible {
        stats.total_byte_size =
            Precision::Exact(primitive_row_size).multiply(&stats.num_rows);
    }
    stats.column_statistics = column_statistics;
    stats
}

impl ProjectionStream {
    fn batch_project(&self, batch: &RecordBatch) -> Result<RecordBatch> {
        // Records time on drop
        let _timer = self.baseline_metrics.elapsed_compute().timer();
        let arrays = self
            .expr
            .iter()
            .map(|expr| {
                expr.evaluate(batch)
                    .and_then(|v| v.into_array(batch.num_rows()))
            })
            .collect::<Result<Vec<_>>>()?;

        if arrays.is_empty() {
            let options =
                RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
            RecordBatch::try_new_with_options(Arc::clone(&self.schema), arrays, &options)
                .map_err(Into::into)
        } else {
            RecordBatch::try_new(Arc::clone(&self.schema), arrays).map_err(Into::into)
        }
    }
}

/// Stream that applies the projection expressions to every batch of its input stream
struct ProjectionStream {
    schema: SchemaRef,
    expr: Vec<Arc<dyn PhysicalExpr>>,
    input: SendableRecordBatchStream,
    baseline_metrics: BaselineMetrics,
}

impl Stream for ProjectionStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let poll = self.input.poll_next_unpin(cx).map(|x| match x {
            Some(Ok(batch)) => Some(self.batch_project(&batch)),
            other => other,
        });

        self.baseline_metrics.record_poll(poll)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Same number of record batches
        self.input.size_hint()
    }
}

impl RecordBatchStream for ProjectionStream {
    /// Get the schema
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

pub trait EmbeddedProjection: ExecutionPlan + Sized {
    fn with_projection(&self, projection: Option<Vec<usize>>) -> Result<Self>;
}

/// Some projections cannot be pushed down to the left or right input of a hash join,
/// because the join filter or the `on` clause may still need columns that are otherwise
/// unused downstream.
/// By embedding such a projection into the hash join, we reduce the cost of
/// `build_batch_from_indices` in the hash join (which must call `compute::take()` for
/// each output column) and avoid creating unnecessary output.
pub fn try_embed_projection<Exec: EmbeddedProjection + 'static>(
    projection: &ProjectionExec,
    execution_plan: &Exec,
) -> Result<Option<Arc<dyn ExecutionPlan>>> {
    // Collect all column indices from the given projection expressions.
    let projection_index = collect_column_indices(projection.expr());

    if projection_index.is_empty() {
        return Ok(None);
    };

    // If the projection indices cover exactly the input columns (i.e. `projection_index`
    // is `0..n` where `n` is the number of fields in the `execution_plan` schema), there
    // is nothing to embed into the hash join.
    if projection_index.len() == projection_index.last().unwrap() + 1
        && projection_index.len() == execution_plan.schema().fields().len()
    {
        return Ok(None);
    }

    let new_execution_plan =
        Arc::new(execution_plan.with_projection(Some(projection_index.to_vec()))?);

    // Build projection expressions for update_expr. Zip the projection_index with the new_execution_plan output schema fields.
    let embed_project_exprs = projection_index
        .iter()
        .zip(new_execution_plan.schema().fields())
        .map(|(index, field)| {
            (
                Arc::new(Column::new(field.name(), *index)) as Arc<dyn PhysicalExpr>,
                field.name().to_owned(),
            )
        })
        .collect::<Vec<_>>();

    let mut new_projection_exprs = Vec::with_capacity(projection.expr().len());

    for (expr, alias) in projection.expr() {
        // Update the column indices of the projection expression, since the input schema has changed.
        let Some(expr) = update_expr(expr, embed_project_exprs.as_slice(), false)? else {
            return Ok(None);
        };
        new_projection_exprs.push((expr, alias.clone()));
    }
    // The old projection may contain aliases or expressions such as `a + 1` and
    // `CAST('true' AS BOOLEAN)`, while the projection embedded in the hash join contains
    // only columns, so wrap the join in a new projection that preserves the original output.
    let new_projection = Arc::new(ProjectionExec::try_new(
        new_projection_exprs,
        Arc::clone(&new_execution_plan) as _,
    )?);
    if is_projection_removable(&new_projection) {
        Ok(Some(new_execution_plan))
    } else {
        Ok(Some(new_projection))
    }
}

pub struct JoinData {
    pub projected_left_child: ProjectionExec,
    pub projected_right_child: ProjectionExec,
    pub join_filter: Option<JoinFilter>,
    pub join_on: JoinOn,
}

pub fn try_pushdown_through_join(
    projection: &ProjectionExec,
    join_left: &Arc<dyn ExecutionPlan>,
    join_right: &Arc<dyn ExecutionPlan>,
    join_on: JoinOnRef,
    schema: SchemaRef,
    filter: Option<&JoinFilter>,
) -> Result<Option<JoinData>> {
    // Convert the projected expressions to columns. We cannot proceed if this is not possible.
    let Some(projection_as_columns) = physical_to_column_exprs(projection.expr()) else {
        return Ok(None);
    };

    let (far_right_left_col_ind, far_left_right_col_ind) =
        join_table_borders(join_left.schema().fields().len(), &projection_as_columns);

    if !join_allows_pushdown(
        &projection_as_columns,
        &schema,
        far_right_left_col_ind,
        far_left_right_col_ind,
    ) {
        return Ok(None);
    }

    let new_filter = if let Some(filter) = filter {
        match update_join_filter(
            &projection_as_columns[0..=far_right_left_col_ind as _],
            &projection_as_columns[far_left_right_col_ind as _..],
            filter,
            join_left.schema().fields().len(),
        ) {
            Some(updated_filter) => Some(updated_filter),
            None => return Ok(None),
        }
    } else {
        None
    };

    let Some(new_on) = update_join_on(
        &projection_as_columns[0..=far_right_left_col_ind as _],
        &projection_as_columns[far_left_right_col_ind as _..],
        join_on,
        join_left.schema().fields().len(),
    ) else {
        return Ok(None);
    };

    let (new_left, new_right) = new_join_children(
        &projection_as_columns,
        far_right_left_col_ind,
        far_left_right_col_ind,
        join_left,
        join_right,
    )?;

    Ok(Some(JoinData {
        projected_left_child: new_left,
        projected_right_child: new_right,
        join_filter: new_filter,
        join_on: new_on,
    }))
}

/// This function checks if `plan` is a [`ProjectionExec`], and inspects its
/// input(s) to test whether it can push `plan` under its input(s). This function
/// will operate on the entire tree and may ultimately remove `plan` entirely
/// by leveraging source providers with built-in projection capabilities.
pub fn remove_unnecessary_projections(
    plan: Arc<dyn ExecutionPlan>,
) -> Result<Transformed<Arc<dyn ExecutionPlan>>> {
    let maybe_modified =
        if let Some(projection) = plan.as_any().downcast_ref::<ProjectionExec>() {
            // If the projection does not cause any change on the input, we can
            // safely remove it:
            if is_projection_removable(projection) {
                return Ok(Transformed::yes(Arc::clone(projection.input())));
            }
            // If it does, check if we can push it under its child(ren):
            projection
                .input()
                .try_swapping_with_projection(projection)?
        } else {
            return Ok(Transformed::no(plan));
        };
    Ok(maybe_modified.map_or_else(|| Transformed::no(plan), Transformed::yes))
}

/// Compare the inputs and outputs of the projection. All expressions must be
/// columns without alias, and projection does not change the order of fields.
/// For example, if the input schema is `a, b`, `SELECT a, b` is removable,
/// but `SELECT b, a` and `SELECT a+1, b` and `SELECT a AS c, b` are not.
fn is_projection_removable(projection: &ProjectionExec) -> bool {
    let exprs = projection.expr();
    exprs.iter().enumerate().all(|(idx, (expr, alias))| {
        let Some(col) = expr.as_any().downcast_ref::<Column>() else {
            return false;
        };
        col.name() == alias && col.index() == idx
    }) && exprs.len() == projection.input().schema().fields().len()
}

/// Given the expression set of a projection, returns `true` if the projection performs
/// no renaming and every expression is a plain `Column` (i.e. it is alias-free).
pub fn all_alias_free_columns(exprs: &[(Arc<dyn PhysicalExpr>, String)]) -> bool {
    exprs.iter().all(|(expr, alias)| {
        expr.as_any()
            .downcast_ref::<Column>()
            .map(|column| column.name() == alias)
            .unwrap_or(false)
    })
}

/// Updates a source provider's projected columns according to the given
/// projection operator's expressions. To use this function safely, one must
/// ensure that all expressions are `Column` expressions without aliases.
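///
/// As an illustrative (hypothetical) example: if the source currently projects
/// file columns `[2, 4, 5]` and the projection expressions are `[col@1, col@0]`,
/// the result is `[4, 2]`, i.e. the source indices re-ordered by the projection.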
pub fn new_projections_for_columns(
    projection: &ProjectionExec,
    source: &[usize],
) -> Vec<usize> {
    projection
        .expr()
        .iter()
        .filter_map(|(expr, _)| {
            expr.as_any()
                .downcast_ref::<Column>()
                .map(|expr| source[expr.index()])
        })
        .collect()
}

/// Creates a new [`ProjectionExec`] instance with the given child plan and
/// projected expressions.
pub fn make_with_child(
    projection: &ProjectionExec,
    child: &Arc<dyn ExecutionPlan>,
) -> Result<Arc<dyn ExecutionPlan>> {
    ProjectionExec::try_new(projection.expr().to_vec(), Arc::clone(child))
        .map(|e| Arc::new(e) as _)
}

/// Returns `true` if all the expressions in the argument are `Column`s.
pub fn all_columns(exprs: &[(Arc<dyn PhysicalExpr>, String)]) -> bool {
    exprs.iter().all(|(expr, _)| expr.as_any().is::<Column>())
}

/// The function operates in two modes:
///
/// 1) When `sync_with_child` is `true`:
///
///    The function updates the indices of `expr` if the expression resides
///    in the input plan. For instance, given the expressions `a@1 + b@2`
///    and `c@0` with the input schema `c@2, a@0, b@1`, the expressions are
///    updated to `a@0 + b@1` and `c@2`.
///
/// 2) When `sync_with_child` is `false`:
///
///    The function determines how the expression would be updated if a projection
///    was placed before the plan associated with the expression. If the expression
///    cannot be rewritten after the projection, it returns `None`. For example,
///    given the expressions `c@0`, `a@1` and `b@2`, and the [`ProjectionExec`] with
///    an output schema of `a, c_new`, then `c@0` becomes `c_new@1`, `a@1` becomes
///    `a@0`, but `b@2` results in `None` since the projection does not include `b`.
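///
/// A rough sketch of mode 2 (`sync_with_child = false`), using hypothetical columns
/// that mirror the example above:
///
/// ```ignore
/// // Projection output is `a, c_new`, where `a` is `a@1` and `c_new` aliases `c@0`:
/// let projected: Vec<(Arc<dyn PhysicalExpr>, String)> = vec![
///     (Arc::new(Column::new("a", 1)) as _, "a".to_string()),
///     (Arc::new(Column::new("c", 0)) as _, "c_new".to_string()),
/// ];
/// let c = Arc::new(Column::new("c", 0)) as Arc<dyn PhysicalExpr>;
/// // `c@0` is rewritten to `c_new@1`:
/// let rewritten = update_expr(&c, &projected, false)?;
/// ```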
pub fn update_expr(
    expr: &Arc<dyn PhysicalExpr>,
    projected_exprs: &[(Arc<dyn PhysicalExpr>, String)],
    sync_with_child: bool,
) -> Result<Option<Arc<dyn PhysicalExpr>>> {
    #[derive(Debug, PartialEq)]
    enum RewriteState {
        /// The expression is unchanged.
        Unchanged,
        /// Some part of the expression has been rewritten
        RewrittenValid,
        /// Some part of the expression has been rewritten, but some column
        /// references could not be.
        RewrittenInvalid,
    }

    let mut state = RewriteState::Unchanged;

    let new_expr = Arc::clone(expr)
        .transform_up(|expr: Arc<dyn PhysicalExpr>| {
            if state == RewriteState::RewrittenInvalid {
                return Ok(Transformed::no(expr));
            }

            let Some(column) = expr.as_any().downcast_ref::<Column>() else {
                return Ok(Transformed::no(expr));
            };
            if sync_with_child {
                state = RewriteState::RewrittenValid;
                // Update the index of `column`:
                Ok(Transformed::yes(Arc::clone(
                    &projected_exprs[column.index()].0,
                )))
            } else {
                // default to invalid, in case we can't find the relevant column
                state = RewriteState::RewrittenInvalid;
                // Determine how to update `column` to accommodate `projected_exprs`
                projected_exprs
                    .iter()
                    .enumerate()
                    .find_map(|(index, (projected_expr, alias))| {
                        projected_expr.as_any().downcast_ref::<Column>().and_then(
                            |projected_column| {
                                (column.name().eq(projected_column.name())
                                    && column.index() == projected_column.index())
                                .then(|| {
                                    state = RewriteState::RewrittenValid;
                                    Arc::new(Column::new(alias, index)) as _
                                })
                            },
                        )
                    })
                    .map_or_else(
                        || Ok(Transformed::no(expr)),
                        |c| Ok(Transformed::yes(c)),
                    )
            }
        })
        .data();

    new_expr.map(|e| (state == RewriteState::RewrittenValid).then_some(e))
}

/// Downcasts all the expressions in `exprs` to `Column`s. If any of the given
/// expressions is not a `Column`, returns `None`.
pub fn physical_to_column_exprs(
    exprs: &[(Arc<dyn PhysicalExpr>, String)],
) -> Option<Vec<(Column, String)>> {
    exprs
        .iter()
        .map(|(expr, alias)| {
            expr.as_any()
                .downcast_ref::<Column>()
                .map(|col| (col.clone(), alias.clone()))
        })
        .collect()
}

/// If pushing down the projection over this join's children seems possible,
/// this function constructs the new [`ProjectionExec`]s that will come on top
/// of the original children of the join.
pub fn new_join_children(
    projection_as_columns: &[(Column, String)],
    far_right_left_col_ind: i32,
    far_left_right_col_ind: i32,
    left_child: &Arc<dyn ExecutionPlan>,
    right_child: &Arc<dyn ExecutionPlan>,
) -> Result<(ProjectionExec, ProjectionExec)> {
    let new_left = ProjectionExec::try_new(
        projection_as_columns[0..=far_right_left_col_ind as _]
            .iter()
            .map(|(col, alias)| {
                (
                    Arc::new(Column::new(col.name(), col.index())) as _,
                    alias.clone(),
                )
            })
            .collect_vec(),
        Arc::clone(left_child),
    )?;
    let left_size = left_child.schema().fields().len() as i32;
    let new_right = ProjectionExec::try_new(
        projection_as_columns[far_left_right_col_ind as _..]
            .iter()
            .map(|(col, alias)| {
                (
                    Arc::new(Column::new(
                        col.name(),
                        // Align projected expressions coming from the right
                        // table with the new right child projection:
                        (col.index() as i32 - left_size) as _,
                    )) as _,
                    alias.clone(),
                )
            })
            .collect_vec(),
        Arc::clone(right_child),
    )?;

    Ok((new_left, new_right))
}

/// Checks three conditions for pushing a projection down through a join:
/// - The projection must narrow the join output schema.
/// - Columns coming from the left/right tables must be collected at the left/right
///   sides of the output table (i.e. not interleaved).
/// - Neither the left nor the right table is lost entirely after the projection.
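///
/// As a worked (hypothetical) example: with a join schema of 6 fields, a left side of
/// width 3, and a projection `[l0@0, l2@2, r0@3, r1@4]`, the borders computed by
/// [`join_table_borders`] are `(1, 2)`. The projection narrows the schema (4 < 6),
/// the left and right columns are not interleaved (1 + 1 == 2), and both sides survive,
/// so the pushdown is allowed.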
pub fn join_allows_pushdown(
    projection_as_columns: &[(Column, String)],
    join_schema: &SchemaRef,
    far_right_left_col_ind: i32,
    far_left_right_col_ind: i32,
) -> bool {
    // Projection must narrow the join output:
    projection_as_columns.len() < join_schema.fields().len()
    // Are the columns from different tables mixed?
    && (far_right_left_col_ind + 1 == far_left_right_col_ind)
    // Left or right table is not lost after the projection.
    && far_right_left_col_ind >= 0
    && far_left_right_col_ind < projection_as_columns.len() as i32
}

/// Returns the last index before encountering a column coming from the right table when
/// traveling through the projection from left to right, and the last index before
/// encountering a column coming from the left table when traveling through the projection
/// from right to left. If there is no column in the projection coming from the left side,
/// it returns `(-1, ...)`; if there is no column coming from the right side, it returns
/// `(..., projection length)`.
pub fn join_table_borders(
    left_table_column_count: usize,
    projection_as_columns: &[(Column, String)],
) -> (i32, i32) {
    let far_right_left_col_ind = projection_as_columns
        .iter()
        .enumerate()
        .take_while(|(_, (projection_column, _))| {
            projection_column.index() < left_table_column_count
        })
        .last()
        .map(|(index, _)| index as i32)
        .unwrap_or(-1);

    let far_left_right_col_ind = projection_as_columns
        .iter()
        .enumerate()
        .rev()
        .take_while(|(_, (projection_column, _))| {
            projection_column.index() >= left_table_column_count
        })
        .last()
        .map(|(index, _)| index as i32)
        .unwrap_or(projection_as_columns.len() as i32);

    (far_right_left_col_ind, far_left_right_col_ind)
}

/// Tries to update the equi-join `Column`s of a join as if the input of
/// the join was replaced by a projection.
pub fn update_join_on(
    proj_left_exprs: &[(Column, String)],
    proj_right_exprs: &[(Column, String)],
    hash_join_on: &[(PhysicalExprRef, PhysicalExprRef)],
    left_field_size: usize,
) -> Option<Vec<(PhysicalExprRef, PhysicalExprRef)>> {
    // TODO: Clippy wants the "map" call removed, but doing so generates
    //       a compilation error. Remove the clippy directive once this
    //       issue is fixed.
    #[allow(clippy::map_identity)]
    let (left_idx, right_idx): (Vec<_>, Vec<_>) = hash_join_on
        .iter()
        .map(|(left, right)| (left, right))
        .unzip();

    let new_left_columns = new_columns_for_join_on(&left_idx, proj_left_exprs, 0);
    let new_right_columns =
        new_columns_for_join_on(&right_idx, proj_right_exprs, left_field_size);

    match (new_left_columns, new_right_columns) {
        (Some(left), Some(right)) => Some(left.into_iter().zip(right).collect()),
        _ => None,
    }
}

/// Tries to update the column indices of a [`JoinFilter`] as if the input of
/// the join was replaced by a projection.
pub fn update_join_filter(
    projection_left_exprs: &[(Column, String)],
    projection_right_exprs: &[(Column, String)],
    join_filter: &JoinFilter,
    left_field_size: usize,
) -> Option<JoinFilter> {
    let mut new_left_indices = new_indices_for_join_filter(
        join_filter,
        JoinSide::Left,
        projection_left_exprs,
        0,
    )
    .into_iter();
    let mut new_right_indices = new_indices_for_join_filter(
        join_filter,
        JoinSide::Right,
        projection_right_exprs,
        left_field_size,
    )
    .into_iter();

    // Check if all columns match:
    (new_right_indices.len() + new_left_indices.len()
        == join_filter.column_indices().len())
    .then(|| {
        JoinFilter::new(
            Arc::clone(join_filter.expression()),
            join_filter
                .column_indices()
                .iter()
                .map(|col_idx| ColumnIndex {
                    index: if col_idx.side == JoinSide::Left {
                        new_left_indices.next().unwrap()
                    } else {
                        new_right_indices.next().unwrap()
                    },
                    side: col_idx.side,
                })
                .collect(),
            Arc::clone(join_filter.schema()),
        )
    })
}

/// Unifies `projection` with its input (which is also a [`ProjectionExec`]).
fn try_unifying_projections(
    projection: &ProjectionExec,
    child: &ProjectionExec,
) -> Result<Option<Arc<dyn ExecutionPlan>>> {
    let mut projected_exprs = vec![];
    let mut column_ref_map: HashMap<Column, usize> = HashMap::new();

    // Count how many times each column is referenced in the outer projection.
    projection.expr().iter().for_each(|(expr, _)| {
        expr.apply(|expr| {
            Ok({
                if let Some(column) = expr.as_any().downcast_ref::<Column>() {
                    *column_ref_map.entry(column.clone()).or_default() += 1;
                }
                TreeNodeRecursion::Continue
            })
        })
        .unwrap();
    });
    // Merging the projections is not beneficial if a non-trivial expression is referenced
    // more than once: keeping the projections separate lets the child projection act as a
    // cache for the non-trivial computation.
    // See discussion in: https://github.com/apache/datafusion/issues/8296
    if column_ref_map.iter().any(|(column, count)| {
        *count > 1 && !is_expr_trivial(&Arc::clone(&child.expr()[column.index()].0))
    }) {
        return Ok(None);
    }
    for (expr, alias) in projection.expr() {
        // If there is no match in the input projection, we cannot unify these
        // projections. This case will arise if the projection expression contains
        // a `PhysicalExpr` variant `update_expr` doesn't support.
        let Some(expr) = update_expr(expr, child.expr(), true)? else {
            return Ok(None);
        };
        projected_exprs.push((expr, alias.clone()));
    }
    ProjectionExec::try_new(projected_exprs, Arc::clone(child.input()))
        .map(|e| Some(Arc::new(e) as _))
}

/// Collect all column indices from the given projection expressions.
fn collect_column_indices(exprs: &[(Arc<dyn PhysicalExpr>, String)]) -> Vec<usize> {
    // Collect indices and remove duplicates.
    let mut indices = exprs
        .iter()
        .flat_map(|(expr, _)| collect_columns(expr))
        .map(|x| x.index())
        .collect::<std::collections::HashSet<_>>()
        .into_iter()
        .collect::<Vec<_>>();
    indices.sort();
    indices
}

/// This function determines and returns a vector of indices representing the
/// positions of columns in `projection_exprs` that are involved in `join_filter`,
/// and correspond to a particular side (`join_side`) of the join operation.
///
/// Notes: Column indices in the projection expressions are based on the join schema,
/// whereas the join filter is based on the join child schema. `column_index_offset`
/// represents the offset between them.
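///
/// As a small (hypothetical) illustration: with `column_index_offset = 3` (the left
/// field size) and a right-side filter column whose child-schema index is `1`, the
/// corresponding join-schema index is `4`; if `projection_exprs` contains a column
/// with index `4` at position `0`, the returned vector contains `0` for that filter
/// column.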
fn new_indices_for_join_filter(
    join_filter: &JoinFilter,
    join_side: JoinSide,
    projection_exprs: &[(Column, String)],
    column_index_offset: usize,
) -> Vec<usize> {
    join_filter
        .column_indices()
        .iter()
        .filter(|col_idx| col_idx.side == join_side)
        .filter_map(|col_idx| {
            projection_exprs
                .iter()
                .position(|(col, _)| col_idx.index + column_index_offset == col.index())
        })
        .collect()
}

/// This function generates a new set of columns to be used in a hash join
/// operation based on a set of equi-join conditions (`hash_join_on`) and a
/// list of projection expressions (`projection_exprs`).
///
/// Notes: Column indices in the projection expressions are based on the join schema,
/// whereas the join on expressions are based on the join child schema. `column_index_offset`
/// represents the offset between them.
fn new_columns_for_join_on(
    hash_join_on: &[&PhysicalExprRef],
    projection_exprs: &[(Column, String)],
    column_index_offset: usize,
) -> Option<Vec<PhysicalExprRef>> {
    let new_columns = hash_join_on
        .iter()
        .filter_map(|on| {
            // Rewrite all columns in `on`
            Arc::clone(*on)
                .transform(|expr| {
                    if let Some(column) = expr.as_any().downcast_ref::<Column>() {
                        // Find the column in the projection expressions
                        let new_column = projection_exprs
                            .iter()
                            .enumerate()
                            .find(|(_, (proj_column, _))| {
                                column.name() == proj_column.name()
                                    && column.index() + column_index_offset
                                        == proj_column.index()
                            })
                            .map(|(index, (_, alias))| Column::new(alias, index));
                        if let Some(new_column) = new_column {
                            Ok(Transformed::yes(Arc::new(new_column)))
                        } else {
                            // If the column is not found in the projection expressions,
                            // it means that the column is not projected. In this case,
                            // we cannot push the projection down.
                            internal_err!(
                                "Column {:?} not found in projection expressions",
                                column
                            )
                        }
                    } else {
                        Ok(Transformed::no(expr))
                    }
                })
                .data()
                .ok()
        })
        .collect::<Vec<_>>();
    (new_columns.len() == hash_join_on.len()).then_some(new_columns)
}

/// Checks if the given expression is trivial.
/// An expression is considered trivial if it is either a `Column` or a `Literal`.
fn is_expr_trivial(expr: &Arc<dyn PhysicalExpr>) -> bool {
    expr.as_any().downcast_ref::<Column>().is_some()
        || expr.as_any().downcast_ref::<Literal>().is_some()
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use crate::common::collect;
    use crate::test;

    use arrow::datatypes::DataType;
    use datafusion_common::ScalarValue;

    use datafusion_expr::Operator;
    use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal};

    #[test]
    fn test_collect_column_indices() -> Result<()> {
        let expr = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("b", 7)),
            Operator::Minus,
            Arc::new(BinaryExpr::new(
                Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
                Operator::Plus,
                Arc::new(Column::new("a", 1)),
            )),
        ));
        let column_indices = collect_column_indices(&[(expr, "b-(1+a)".to_string())]);
        assert_eq!(column_indices, vec![1, 7]);
        Ok(())
    }

    #[test]
    fn test_join_table_borders() -> Result<()> {
        let projections = vec![
            (Column::new("b", 1), "b".to_owned()),
            (Column::new("c", 2), "c".to_owned()),
            (Column::new("e", 4), "e".to_owned()),
            (Column::new("d", 3), "d".to_owned()),
            (Column::new("c", 2), "c".to_owned()),
            (Column::new("f", 5), "f".to_owned()),
            (Column::new("h", 7), "h".to_owned()),
            (Column::new("g", 6), "g".to_owned()),
        ];
        let left_table_column_count = 5;
        assert_eq!(
            join_table_borders(left_table_column_count, &projections),
            (4, 5)
        );

        let left_table_column_count = 8;
        assert_eq!(
            join_table_borders(left_table_column_count, &projections),
            (7, 8)
        );

        let left_table_column_count = 1;
        assert_eq!(
            join_table_borders(left_table_column_count, &projections),
            (-1, 0)
        );

        let projections = vec![
            (Column::new("a", 0), "a".to_owned()),
            (Column::new("b", 1), "b".to_owned()),
            (Column::new("d", 3), "d".to_owned()),
            (Column::new("g", 6), "g".to_owned()),
            (Column::new("e", 4), "e".to_owned()),
            (Column::new("f", 5), "f".to_owned()),
            (Column::new("e", 4), "e".to_owned()),
            (Column::new("h", 7), "h".to_owned()),
        ];
        let left_table_column_count = 5;
        assert_eq!(
            join_table_borders(left_table_column_count, &projections),
            (2, 7)
        );

        let left_table_column_count = 7;
        assert_eq!(
            join_table_borders(left_table_column_count, &projections),
            (6, 7)
        );

        Ok(())
    }

    #[tokio::test]
    async fn project_no_column() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());

        let exec = test::scan_partitioned(1);
        let expected = collect(exec.execute(0, Arc::clone(&task_ctx))?).await?;

        let projection = ProjectionExec::try_new(vec![], exec)?;
        let stream = projection.execute(0, Arc::clone(&task_ctx))?;
        let output = collect(stream).await?;
        assert_eq!(output.len(), expected.len());

        Ok(())
    }

    fn get_stats() -> Statistics {
        Statistics {
            num_rows: Precision::Exact(5),
            total_byte_size: Precision::Exact(23),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
                    null_count: Precision::Exact(0),
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Exact(3),
                },
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Exact(ScalarValue::Float32(Some(1.1))),
                    min_value: Precision::Exact(ScalarValue::Float32(Some(0.1))),
                    sum_value: Precision::Exact(ScalarValue::Float32(Some(5.5))),
                    null_count: Precision::Absent,
                },
            ],
        }
    }

    fn get_schema() -> Schema {
        let field_0 = Field::new("col0", DataType::Int64, false);
        let field_1 = Field::new("col1", DataType::Utf8, false);
        let field_2 = Field::new("col2", DataType::Float32, false);
        Schema::new(vec![field_0, field_1, field_2])
    }
    #[tokio::test]
    async fn test_stats_projection_columns_only() {
        let source = get_stats();
        let schema = get_schema();

        let exprs: Vec<Arc<dyn PhysicalExpr>> = vec![
            Arc::new(Column::new("col1", 1)),
            Arc::new(Column::new("col0", 0)),
        ];

        let result = stats_projection(source, exprs.into_iter(), Arc::new(schema));

        let expected = Statistics {
            num_rows: Precision::Exact(5),
            total_byte_size: Precision::Exact(23),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Exact(3),
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
                    null_count: Precision::Exact(0),
                },
            ],
        };

        assert_eq!(result, expected);
    }

    #[tokio::test]
    async fn test_stats_projection_column_with_primitive_width_only() {
        let source = get_stats();
        let schema = get_schema();

        let exprs: Vec<Arc<dyn PhysicalExpr>> = vec![
            Arc::new(Column::new("col2", 2)),
            Arc::new(Column::new("col0", 0)),
        ];

        let result = stats_projection(source, exprs.into_iter(), Arc::new(schema));

        let expected = Statistics {
            num_rows: Precision::Exact(5),
            total_byte_size: Precision::Exact(60),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Exact(ScalarValue::Float32(Some(1.1))),
                    min_value: Precision::Exact(ScalarValue::Float32(Some(0.1))),
                    sum_value: Precision::Exact(ScalarValue::Float32(Some(5.5))),
                    null_count: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
                    null_count: Precision::Exact(0),
                },
            ],
        };

        assert_eq!(result, expected);
    }
}