datafusion_physical_plan/
projection.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines the projection execution plan. A projection determines which columns or expressions
19//! are returned from a query. The SQL statement `SELECT a, b, a+b FROM t1` is an example
20//! of a projection on table `t1` where the expressions `a`, `b`, and `a+b` are the
21//! projection expressions. `SELECT` without `FROM` will only evaluate expressions.
22
23use super::expressions::{Column, Literal};
24use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
25use super::{
26    DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream,
27    SendableRecordBatchStream, Statistics,
28};
29use crate::execution_plan::CardinalityEffect;
30use crate::filter_pushdown::{
31    ChildPushdownResult, FilterDescription, FilterPushdownPhase,
32    FilterPushdownPropagation,
33};
34use crate::joins::utils::{ColumnIndex, JoinFilter, JoinOn, JoinOnRef};
35use crate::{ColumnStatistics, DisplayFormatType, ExecutionPlan, PhysicalExpr};
36use std::any::Any;
37use std::collections::HashMap;
38use std::pin::Pin;
39use std::sync::Arc;
40use std::task::{Context, Poll};
41
42use arrow::datatypes::{Field, Schema, SchemaRef};
43use arrow::record_batch::{RecordBatch, RecordBatchOptions};
44use datafusion_common::config::ConfigOptions;
45use datafusion_common::stats::Precision;
46use datafusion_common::tree_node::{
47    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
48};
49use datafusion_common::{internal_err, JoinSide, Result};
50use datafusion_execution::TaskContext;
51use datafusion_physical_expr::equivalence::ProjectionMapping;
52use datafusion_physical_expr::utils::collect_columns;
53use datafusion_physical_expr_common::physical_expr::{fmt_sql, PhysicalExprRef};
54use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement};
55
56use futures::stream::{Stream, StreamExt};
57use log::trace;
58
59/// [`ExecutionPlan`] for a projection
60///
61/// Computes a set of scalar value expressions for each input row, producing one
62/// output row for each input row.
63#[derive(Debug, Clone)]
64pub struct ProjectionExec {
65    /// The projection expressions stored as tuples of (expression, output column name)
66    pub(crate) expr: Vec<ProjectionExpr>,
67    /// The schema once the projection has been applied to the input
68    schema: SchemaRef,
69    /// The input plan
70    input: Arc<dyn ExecutionPlan>,
71    /// Execution metrics
72    metrics: ExecutionPlanMetricsSet,
73    /// Cache holding plan properties like equivalences, output partitioning etc.
74    cache: PlanProperties,
75}
76
77impl ProjectionExec {
78    /// Create a projection on an input
79    ///
80    /// # Example:
81    /// Create a `ProjectionExec` to crate `SELECT a, a+b AS sum_ab FROM t1`:
82    ///
83    /// ```
84    /// # use std::sync::Arc;
85    /// # use arrow_schema::{Schema, Field, DataType};
86    /// # use datafusion_expr::Operator;
87    /// # use datafusion_physical_plan::ExecutionPlan;
88    /// # use datafusion_physical_expr::expressions::{col, binary};
89    /// # use datafusion_physical_plan::empty::EmptyExec;
90    /// # use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr};
91    /// # fn schema() -> Arc<Schema> {
92    /// #  Arc::new(Schema::new(vec![
93    /// #   Field::new("a", DataType::Int32, false),
94    /// #   Field::new("b", DataType::Int32, false),
95    /// # ]))
96    /// # }
97    /// #
98    /// # fn input() -> Arc<dyn ExecutionPlan> {
99    /// #  Arc::new(EmptyExec::new(schema()))
100    /// # }
101    /// #
102    /// # fn main() {
103    /// let schema = schema();
104    /// // Create PhysicalExprs
105    /// let a = col("a", &schema).unwrap();
106    /// let b = col("b", &schema).unwrap();
107    /// let a_plus_b = binary(Arc::clone(&a), Operator::Plus, b, &schema).unwrap();
108    /// // create ProjectionExec
109    /// let proj = ProjectionExec::try_new([
110    ///     ProjectionExpr {
111    ///       // expr a produces the column named "a"
112    ///       expr: a,
113    ///       alias: "a".to_string(),
114    ///     },
115    ///     ProjectionExpr {
116    ///       // expr: a + b produces the column named "sum_ab"
117    ///       expr: a_plus_b,
118    ///       alias: "sum_ab".to_string(),
119    ///     }
120    ///   ], input()).unwrap();
121    /// # }
122    /// ```
123    pub fn try_new<I, E>(expr: I, input: Arc<dyn ExecutionPlan>) -> Result<Self>
124    where
125        I: IntoIterator<Item = E>,
126        E: Into<ProjectionExpr>,
127    {
128        let input_schema = input.schema();
129        // convert argument to Vec<ProjectionExpr>
130        let expr = expr.into_iter().map(Into::into).collect::<Vec<_>>();
131
132        let fields: Result<Vec<Field>> = expr
133            .iter()
134            .map(|proj_expr| {
135                let metadata = proj_expr
136                    .expr
137                    .return_field(&input_schema)?
138                    .metadata()
139                    .clone();
140
141                let field = Field::new(
142                    &proj_expr.alias,
143                    proj_expr.expr.data_type(&input_schema)?,
144                    proj_expr.expr.nullable(&input_schema)?,
145                )
146                .with_metadata(metadata);
147
148                Ok(field)
149            })
150            .collect();
151
152        let schema = Arc::new(Schema::new_with_metadata(
153            fields?,
154            input_schema.metadata().clone(),
155        ));
156
157        // Construct a map from the input expressions to the output expression of the Projection
158        let projection_mapping = ProjectionMapping::try_new(
159            expr.iter().map(|p| (Arc::clone(&p.expr), p.alias.clone())),
160            &input_schema,
161        )?;
162        let cache =
163            Self::compute_properties(&input, &projection_mapping, Arc::clone(&schema))?;
164        Ok(Self {
165            expr,
166            schema,
167            input,
168            metrics: ExecutionPlanMetricsSet::new(),
169            cache,
170        })
171    }
172
173    /// The projection expressions stored as tuples of (expression, output column name)
174    pub fn expr(&self) -> &[ProjectionExpr] {
175        &self.expr
176    }
177
178    /// The input plan
179    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
180        &self.input
181    }
182
183    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
184    fn compute_properties(
185        input: &Arc<dyn ExecutionPlan>,
186        projection_mapping: &ProjectionMapping,
187        schema: SchemaRef,
188    ) -> Result<PlanProperties> {
189        // Calculate equivalence properties:
190        let input_eq_properties = input.equivalence_properties();
191        let eq_properties = input_eq_properties.project(projection_mapping, schema);
192        // Calculate output partitioning, which needs to respect aliases:
193        let output_partitioning = input
194            .output_partitioning()
195            .project(projection_mapping, input_eq_properties);
196
197        Ok(PlanProperties::new(
198            eq_properties,
199            output_partitioning,
200            input.pipeline_behavior(),
201            input.boundedness(),
202        ))
203    }
204}
205
206/// A projection expression that is created by [`ProjectionExec`]
207///
208/// The expression is evaluated and the result is stored in a column
209/// with the name specified by `alias`.
210///
211/// For example, the SQL expression `a + b AS sum_ab` would be represented
212/// as a `ProjectionExpr` where `expr` is the expression `a + b`
213/// and `alias` is the string `sum_ab`.
214#[derive(Debug, Clone)]
215pub struct ProjectionExpr {
216    /// The expression that will be evaluated.
217    pub expr: Arc<dyn PhysicalExpr>,
218    /// The name of the output column for use an output schema.
219    pub alias: String,
220}
221
222impl ProjectionExpr {
223    /// Create a new projection expression
224    pub fn new(expr: Arc<dyn PhysicalExpr>, alias: String) -> Self {
225        Self { expr, alias }
226    }
227}
228
229impl From<(Arc<dyn PhysicalExpr>, String)> for ProjectionExpr {
230    fn from(value: (Arc<dyn PhysicalExpr>, String)) -> Self {
231        Self::new(value.0, value.1)
232    }
233}
234
235impl DisplayAs for ProjectionExec {
236    fn fmt_as(
237        &self,
238        t: DisplayFormatType,
239        f: &mut std::fmt::Formatter,
240    ) -> std::fmt::Result {
241        match t {
242            DisplayFormatType::Default | DisplayFormatType::Verbose => {
243                let expr: Vec<String> = self
244                    .expr
245                    .iter()
246                    .map(|proj_expr| {
247                        let e = proj_expr.expr.to_string();
248                        if e != proj_expr.alias {
249                            format!("{e} as {}", proj_expr.alias)
250                        } else {
251                            e
252                        }
253                    })
254                    .collect();
255
256                write!(f, "ProjectionExec: expr=[{}]", expr.join(", "))
257            }
258            DisplayFormatType::TreeRender => {
259                for (i, proj_expr) in self.expr().iter().enumerate() {
260                    let expr_sql = fmt_sql(proj_expr.expr.as_ref());
261                    if proj_expr.expr.to_string() == proj_expr.alias {
262                        writeln!(f, "expr{i}={expr_sql}")?;
263                    } else {
264                        writeln!(f, "{}={expr_sql}", proj_expr.alias)?;
265                    }
266                }
267
268                Ok(())
269            }
270        }
271    }
272}
273
274impl ExecutionPlan for ProjectionExec {
275    fn name(&self) -> &'static str {
276        "ProjectionExec"
277    }
278
279    /// Return a reference to Any that can be used for downcasting
280    fn as_any(&self) -> &dyn Any {
281        self
282    }
283
284    fn properties(&self) -> &PlanProperties {
285        &self.cache
286    }
287
288    fn maintains_input_order(&self) -> Vec<bool> {
289        // Tell optimizer this operator doesn't reorder its input
290        vec![true]
291    }
292
293    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
294        let all_simple_exprs = self.expr.iter().all(|proj_expr| {
295            proj_expr.expr.as_any().is::<Column>()
296                || proj_expr.expr.as_any().is::<Literal>()
297        });
298        // If expressions are all either column_expr or Literal, then all computations in this projection are reorder or rename,
299        // and projection would not benefit from the repartition, benefits_from_input_partitioning will return false.
300        vec![!all_simple_exprs]
301    }
302
303    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
304        vec![&self.input]
305    }
306
307    fn with_new_children(
308        self: Arc<Self>,
309        mut children: Vec<Arc<dyn ExecutionPlan>>,
310    ) -> Result<Arc<dyn ExecutionPlan>> {
311        ProjectionExec::try_new(self.expr.clone(), children.swap_remove(0))
312            .map(|p| Arc::new(p) as _)
313    }
314
315    fn execute(
316        &self,
317        partition: usize,
318        context: Arc<TaskContext>,
319    ) -> Result<SendableRecordBatchStream> {
320        trace!("Start ProjectionExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id());
321        Ok(Box::pin(ProjectionStream {
322            schema: Arc::clone(&self.schema),
323            expr: self.expr.iter().map(|x| Arc::clone(&x.expr)).collect(),
324            input: self.input.execute(partition, context)?,
325            baseline_metrics: BaselineMetrics::new(&self.metrics, partition),
326        }))
327    }
328
329    fn metrics(&self) -> Option<MetricsSet> {
330        Some(self.metrics.clone_inner())
331    }
332
333    fn statistics(&self) -> Result<Statistics> {
334        self.partition_statistics(None)
335    }
336
337    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
338        let input_stats = self.input.partition_statistics(partition)?;
339        stats_projection(
340            input_stats,
341            self.expr
342                .iter()
343                .map(|proj_expr| Arc::clone(&proj_expr.expr)),
344            Arc::clone(&self.input.schema()),
345        )
346    }
347
348    fn supports_limit_pushdown(&self) -> bool {
349        true
350    }
351
352    fn cardinality_effect(&self) -> CardinalityEffect {
353        CardinalityEffect::Equal
354    }
355
356    fn try_swapping_with_projection(
357        &self,
358        projection: &ProjectionExec,
359    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
360        let maybe_unified = try_unifying_projections(projection, self)?;
361        if let Some(new_plan) = maybe_unified {
362            // To unify 3 or more sequential projections:
363            remove_unnecessary_projections(new_plan).data().map(Some)
364        } else {
365            Ok(Some(Arc::new(projection.clone())))
366        }
367    }
368
369    fn gather_filters_for_pushdown(
370        &self,
371        _phase: FilterPushdownPhase,
372        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
373        _config: &ConfigOptions,
374    ) -> Result<FilterDescription> {
375        // TODO: In future, we can try to handle inverting aliases here.
376        // For the time being, we pass through untransformed filters, so filters on aliases are not handled.
377        // https://github.com/apache/datafusion/issues/17246
378        FilterDescription::from_children(parent_filters, &self.children())
379    }
380
381    fn handle_child_pushdown_result(
382        &self,
383        _phase: FilterPushdownPhase,
384        child_pushdown_result: ChildPushdownResult,
385        _config: &ConfigOptions,
386    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
387        Ok(FilterPushdownPropagation::if_all(child_pushdown_result))
388    }
389}
390
391fn stats_projection(
392    mut stats: Statistics,
393    exprs: impl Iterator<Item = Arc<dyn PhysicalExpr>>,
394    schema: SchemaRef,
395) -> Result<Statistics> {
396    let mut primitive_row_size = 0;
397    let mut primitive_row_size_possible = true;
398    let mut column_statistics = vec![];
399    for expr in exprs {
400        let col_stats = if let Some(col) = expr.as_any().downcast_ref::<Column>() {
401            stats.column_statistics[col.index()].clone()
402        } else {
403            // TODO stats: estimate more statistics from expressions
404            // (expressions should compute their statistics themselves)
405            ColumnStatistics::new_unknown()
406        };
407        column_statistics.push(col_stats);
408        let data_type = expr.data_type(&schema)?;
409        if let Some(value) = data_type.primitive_width() {
410            primitive_row_size += value;
411            continue;
412        }
413        primitive_row_size_possible = false;
414    }
415
416    if primitive_row_size_possible {
417        stats.total_byte_size =
418            Precision::Exact(primitive_row_size).multiply(&stats.num_rows);
419    }
420    stats.column_statistics = column_statistics;
421    Ok(stats)
422}
423
424impl ProjectionStream {
425    fn batch_project(&self, batch: &RecordBatch) -> Result<RecordBatch> {
426        // Records time on drop
427        let _timer = self.baseline_metrics.elapsed_compute().timer();
428        let arrays = self
429            .expr
430            .iter()
431            .map(|expr| {
432                expr.evaluate(batch)
433                    .and_then(|v| v.into_array(batch.num_rows()))
434            })
435            .collect::<Result<Vec<_>>>()?;
436
437        if arrays.is_empty() {
438            let options =
439                RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
440            RecordBatch::try_new_with_options(Arc::clone(&self.schema), arrays, &options)
441                .map_err(Into::into)
442        } else {
443            RecordBatch::try_new(Arc::clone(&self.schema), arrays).map_err(Into::into)
444        }
445    }
446}
447
/// Projection iterator
///
/// Stream that evaluates `expr` against every batch pulled from `input` and
/// yields batches built with `schema`.
struct ProjectionStream {
    /// Schema of the projected output batches.
    schema: SchemaRef,
    /// Expressions evaluated against each input batch, one per output column.
    expr: Vec<Arc<dyn PhysicalExpr>>,
    /// Stream of input batches.
    input: SendableRecordBatchStream,
    /// Metrics updated while polling and while projecting each batch.
    baseline_metrics: BaselineMetrics,
}
455
456impl Stream for ProjectionStream {
457    type Item = Result<RecordBatch>;
458
459    fn poll_next(
460        mut self: Pin<&mut Self>,
461        cx: &mut Context<'_>,
462    ) -> Poll<Option<Self::Item>> {
463        let poll = self.input.poll_next_unpin(cx).map(|x| match x {
464            Some(Ok(batch)) => Some(self.batch_project(&batch)),
465            other => other,
466        });
467
468        self.baseline_metrics.record_poll(poll)
469    }
470
471    fn size_hint(&self) -> (usize, Option<usize>) {
472        // Same number of record batches
473        self.input.size_hint()
474    }
475}
476
477impl RecordBatchStream for ProjectionStream {
478    /// Get the schema
479    fn schema(&self) -> SchemaRef {
480        Arc::clone(&self.schema)
481    }
482}
483
/// An [`ExecutionPlan`] that can apply a column projection itself, so that a
/// separate [`ProjectionExec`] can be folded into it (see
/// [`try_embed_projection`]).
pub trait EmbeddedProjection: ExecutionPlan + Sized {
    /// Returns a copy of this plan whose output contains one field per entry
    /// of `projection`, in that order.
    ///
    /// NOTE(review): presumably `None` means "no projection / all columns" —
    /// confirm against implementors.
    fn with_projection(&self, projection: Option<Vec<usize>>) -> Result<Self>;
}
487
488/// Some projection can't be pushed down left input or right input of hash join because filter or on need may need some columns that won't be used in later.
489/// By embed those projection to hash join, we can reduce the cost of build_batch_from_indices in hash join (build_batch_from_indices need to can compute::take() for each column) and avoid unnecessary output creation.
490pub fn try_embed_projection<Exec: EmbeddedProjection + 'static>(
491    projection: &ProjectionExec,
492    execution_plan: &Exec,
493) -> Result<Option<Arc<dyn ExecutionPlan>>> {
494    // Collect all column indices from the given projection expressions.
495    let projection_index = collect_column_indices(projection.expr());
496
497    if projection_index.is_empty() {
498        return Ok(None);
499    };
500
501    // If the projection indices is the same as the input columns, we don't need to embed the projection to hash join.
502    // Check the projection_index is 0..n-1 and the length of projection_index is the same as the length of execution_plan schema fields.
503    if projection_index.len() == projection_index.last().unwrap() + 1
504        && projection_index.len() == execution_plan.schema().fields().len()
505    {
506        return Ok(None);
507    }
508
509    let new_execution_plan =
510        Arc::new(execution_plan.with_projection(Some(projection_index.to_vec()))?);
511
512    // Build projection expressions for update_expr. Zip the projection_index with the new_execution_plan output schema fields.
513    let embed_project_exprs = projection_index
514        .iter()
515        .zip(new_execution_plan.schema().fields())
516        .map(|(index, field)| ProjectionExpr {
517            expr: Arc::new(Column::new(field.name(), *index)) as Arc<dyn PhysicalExpr>,
518            alias: field.name().to_owned(),
519        })
520        .collect::<Vec<_>>();
521
522    let mut new_projection_exprs = Vec::with_capacity(projection.expr().len());
523
524    for proj_expr in projection.expr() {
525        // update column index for projection expression since the input schema has been changed.
526        let Some(expr) =
527            update_expr(&proj_expr.expr, embed_project_exprs.as_slice(), false)?
528        else {
529            return Ok(None);
530        };
531        new_projection_exprs.push(ProjectionExpr {
532            expr,
533            alias: proj_expr.alias.clone(),
534        });
535    }
536    // Old projection may contain some alias or expression such as `a + 1` and `CAST('true' AS BOOLEAN)`, but our projection_exprs in hash join just contain column, so we need to create the new projection to keep the original projection.
537    let new_projection = Arc::new(ProjectionExec::try_new(
538        new_projection_exprs,
539        Arc::clone(&new_execution_plan) as _,
540    )?);
541    if is_projection_removable(&new_projection) {
542        Ok(Some(new_execution_plan))
543    } else {
544        Ok(Some(new_projection))
545    }
546}
547
/// The pieces needed to rebuild a join after pushing a projection below it,
/// as returned by [`try_pushdown_through_join`].
pub struct JoinData {
    /// Projection to place on top of the join's original left child.
    pub projected_left_child: ProjectionExec,
    /// Projection to place on top of the join's original right child.
    pub projected_right_child: ProjectionExec,
    /// The join filter rewritten against the projected children, if the join had one.
    pub join_filter: Option<JoinFilter>,
    /// The join `on` keys rewritten against the projected children.
    pub join_on: JoinOn,
}
554
555pub fn try_pushdown_through_join(
556    projection: &ProjectionExec,
557    join_left: &Arc<dyn ExecutionPlan>,
558    join_right: &Arc<dyn ExecutionPlan>,
559    join_on: JoinOnRef,
560    schema: SchemaRef,
561    filter: Option<&JoinFilter>,
562) -> Result<Option<JoinData>> {
563    // Convert projected expressions to columns. We can not proceed if this is not possible.
564    let Some(projection_as_columns) = physical_to_column_exprs(projection.expr()) else {
565        return Ok(None);
566    };
567
568    let (far_right_left_col_ind, far_left_right_col_ind) =
569        join_table_borders(join_left.schema().fields().len(), &projection_as_columns);
570
571    if !join_allows_pushdown(
572        &projection_as_columns,
573        &schema,
574        far_right_left_col_ind,
575        far_left_right_col_ind,
576    ) {
577        return Ok(None);
578    }
579
580    let new_filter = if let Some(filter) = filter {
581        match update_join_filter(
582            &projection_as_columns[0..=far_right_left_col_ind as _],
583            &projection_as_columns[far_left_right_col_ind as _..],
584            filter,
585            join_left.schema().fields().len(),
586        ) {
587            Some(updated_filter) => Some(updated_filter),
588            None => return Ok(None),
589        }
590    } else {
591        None
592    };
593
594    let Some(new_on) = update_join_on(
595        &projection_as_columns[0..=far_right_left_col_ind as _],
596        &projection_as_columns[far_left_right_col_ind as _..],
597        join_on,
598        join_left.schema().fields().len(),
599    ) else {
600        return Ok(None);
601    };
602
603    let (new_left, new_right) = new_join_children(
604        &projection_as_columns,
605        far_right_left_col_ind,
606        far_left_right_col_ind,
607        join_left,
608        join_right,
609    )?;
610
611    Ok(Some(JoinData {
612        projected_left_child: new_left,
613        projected_right_child: new_right,
614        join_filter: new_filter,
615        join_on: new_on,
616    }))
617}
618
619/// This function checks if `plan` is a [`ProjectionExec`], and inspects its
620/// input(s) to test whether it can push `plan` under its input(s). This function
621/// will operate on the entire tree and may ultimately remove `plan` entirely
622/// by leveraging source providers with built-in projection capabilities.
623pub fn remove_unnecessary_projections(
624    plan: Arc<dyn ExecutionPlan>,
625) -> Result<Transformed<Arc<dyn ExecutionPlan>>> {
626    let maybe_modified =
627        if let Some(projection) = plan.as_any().downcast_ref::<ProjectionExec>() {
628            // If the projection does not cause any change on the input, we can
629            // safely remove it:
630            if is_projection_removable(projection) {
631                return Ok(Transformed::yes(Arc::clone(projection.input())));
632            }
633            // If it does, check if we can push it under its child(ren):
634            projection
635                .input()
636                .try_swapping_with_projection(projection)?
637        } else {
638            return Ok(Transformed::no(plan));
639        };
640    Ok(maybe_modified.map_or_else(|| Transformed::no(plan), Transformed::yes))
641}
642
643/// Compare the inputs and outputs of the projection. All expressions must be
644/// columns without alias, and projection does not change the order of fields.
645/// For example, if the input schema is `a, b`, `SELECT a, b` is removable,
646/// but `SELECT b, a` and `SELECT a+1, b` and `SELECT a AS c, b` are not.
647fn is_projection_removable(projection: &ProjectionExec) -> bool {
648    let exprs = projection.expr();
649    exprs.iter().enumerate().all(|(idx, proj_expr)| {
650        let Some(col) = proj_expr.expr.as_any().downcast_ref::<Column>() else {
651            return false;
652        };
653        col.name() == proj_expr.alias && col.index() == idx
654    }) && exprs.len() == projection.input().schema().fields().len()
655}
656
657/// Given the expression set of a projection, checks if the projection causes
658/// any renaming or constructs a non-`Column` physical expression.
659pub fn all_alias_free_columns(exprs: &[ProjectionExpr]) -> bool {
660    exprs.iter().all(|proj_expr| {
661        proj_expr
662            .expr
663            .as_any()
664            .downcast_ref::<Column>()
665            .map(|column| column.name() == proj_expr.alias)
666            .unwrap_or(false)
667    })
668}
669
670/// Updates a source provider's projected columns according to the given
671/// projection operator's expressions. To use this function safely, one must
672/// ensure that all expressions are `Column` expressions without aliases.
673pub fn new_projections_for_columns(
674    projection: &[ProjectionExpr],
675    source: &[usize],
676) -> Vec<usize> {
677    projection
678        .iter()
679        .filter_map(|proj_expr| {
680            proj_expr
681                .expr
682                .as_any()
683                .downcast_ref::<Column>()
684                .map(|expr| source[expr.index()])
685        })
686        .collect()
687}
688
689/// Creates a new [`ProjectionExec`] instance with the given child plan and
690/// projected expressions.
691pub fn make_with_child(
692    projection: &ProjectionExec,
693    child: &Arc<dyn ExecutionPlan>,
694) -> Result<Arc<dyn ExecutionPlan>> {
695    ProjectionExec::try_new(projection.expr().to_vec(), Arc::clone(child))
696        .map(|e| Arc::new(e) as _)
697}
698
699/// Returns `true` if all the expressions in the argument are `Column`s.
700pub fn all_columns(exprs: &[ProjectionExpr]) -> bool {
701    exprs
702        .iter()
703        .all(|proj_expr| proj_expr.expr.as_any().is::<Column>())
704}
705
/// The function operates in two modes:
///
/// 1) When `sync_with_child` is `true`:
///
///    The function updates the indices of `expr` if the expression resides
///    in the input plan. For instance, given the expressions `a@1 + b@2`
///    and `c@0` with the input schema `c@2, a@0, b@1`, the expressions are
///    updated to `a@0 + b@1` and `c@2`.
///
///    In this mode, every `Column` in `expr` is replaced by the expression at
///    position `column.index()` of `projected_exprs`.
///
/// 2) When `sync_with_child` is `false`:
///
///    The function determines how the expression would be updated if a projection
///    was placed before the plan associated with the expression. If the expression
///    cannot be rewritten after the projection, it returns `None`. For example,
///    given the expressions `c@0`, `a@1` and `b@2`, and the [`ProjectionExec`] with
///    an output schema of `a, c_new`, then `c@0` becomes `c_new@1`, `a@1` becomes
///    `a@0`, but `b@2` results in `None` since the projection does not include `b`.
///
///    A column matches a projected expression only if that expression is itself
///    a `Column` with the same name and index.
///
/// # Returns
///
/// `Ok(Some(new_expr))` only if at least one `Column` was rewritten and, in
/// mode 2, every visited column could be mapped. Otherwise returns `Ok(None)`.
/// This includes expressions that contain no `Column` at all (e.g. a bare
/// literal), because such an expression is never marked as rewritten.
///
/// # Panics
///
/// In mode 1, panics if a column index is out of bounds for `projected_exprs`.
pub fn update_expr(
    expr: &Arc<dyn PhysicalExpr>,
    projected_exprs: &[ProjectionExpr],
    sync_with_child: bool,
) -> Result<Option<Arc<dyn PhysicalExpr>>> {
    #[derive(Debug, PartialEq)]
    enum RewriteState {
        /// The expression is unchanged.
        Unchanged,
        /// Some part of the expression has been rewritten
        RewrittenValid,
        /// Some part of the expression has been rewritten, but some column
        /// references could not be.
        RewrittenInvalid,
    }

    let mut state = RewriteState::Unchanged;

    // Bottom-up traversal: columns (leaves) are visited before their parents.
    let new_expr = Arc::clone(expr)
        .transform_up(|expr| {
            // Once a column failed to map, skip all further rewriting; the
            // final result will be `None` anyway.
            if state == RewriteState::RewrittenInvalid {
                return Ok(Transformed::no(expr));
            }

            let Some(column) = expr.as_any().downcast_ref::<Column>() else {
                return Ok(Transformed::no(expr));
            };
            if sync_with_child {
                state = RewriteState::RewrittenValid;
                // Update the index of `column`:
                Ok(Transformed::yes(Arc::clone(
                    &projected_exprs[column.index()].expr,
                )))
            } else {
                // default to invalid, in case we can't find the relevant column
                state = RewriteState::RewrittenInvalid;
                // Determine how to update `column` to accommodate `projected_exprs`
                projected_exprs
                    .iter()
                    .enumerate()
                    .find_map(|(index, proj_expr)| {
                        proj_expr.expr.as_any().downcast_ref::<Column>().and_then(
                            |projected_column| {
                                (column.name().eq(projected_column.name())
                                    && column.index() == projected_column.index())
                                .then(|| {
                                    state = RewriteState::RewrittenValid;
                                    // The column now lives at `index` in the
                                    // projection output, under its alias.
                                    Arc::new(Column::new(&proj_expr.alias, index)) as _
                                })
                            },
                        )
                    })
                    .map_or_else(
                        || Ok(Transformed::no(expr)),
                        |c| Ok(Transformed::yes(c)),
                    )
            }
        })
        .data();

    new_expr.map(|e| (state == RewriteState::RewrittenValid).then_some(e))
}
785
786/// Updates the given lexicographic ordering according to given projected
787/// expressions using the [`update_expr`] function.
788pub fn update_ordering(
789    ordering: LexOrdering,
790    projected_exprs: &[ProjectionExpr],
791) -> Result<Option<LexOrdering>> {
792    let mut updated_exprs = vec![];
793    for mut sort_expr in ordering.into_iter() {
794        let Some(updated_expr) = update_expr(&sort_expr.expr, projected_exprs, false)?
795        else {
796            return Ok(None);
797        };
798        sort_expr.expr = updated_expr;
799        updated_exprs.push(sort_expr);
800    }
801    Ok(LexOrdering::new(updated_exprs))
802}
803
804/// Updates the given lexicographic requirement according to given projected
805/// expressions using the [`update_expr`] function.
806pub fn update_ordering_requirement(
807    reqs: LexRequirement,
808    projected_exprs: &[ProjectionExpr],
809) -> Result<Option<LexRequirement>> {
810    let mut updated_exprs = vec![];
811    for mut sort_expr in reqs.into_iter() {
812        let Some(updated_expr) = update_expr(&sort_expr.expr, projected_exprs, false)?
813        else {
814            return Ok(None);
815        };
816        sort_expr.expr = updated_expr;
817        updated_exprs.push(sort_expr);
818    }
819    Ok(LexRequirement::new(updated_exprs))
820}
821
822/// Downcasts all the expressions in `exprs` to `Column`s. If any of the given
823/// expressions is not a `Column`, returns `None`.
824pub fn physical_to_column_exprs(
825    exprs: &[ProjectionExpr],
826) -> Option<Vec<(Column, String)>> {
827    exprs
828        .iter()
829        .map(|proj_expr| {
830            proj_expr
831                .expr
832                .as_any()
833                .downcast_ref::<Column>()
834                .map(|col| (col.clone(), proj_expr.alias.clone()))
835        })
836        .collect()
837}
838
839/// If pushing down the projection over this join's children seems possible,
840/// this function constructs the new [`ProjectionExec`]s that will come on top
841/// of the original children of the join.
842pub fn new_join_children(
843    projection_as_columns: &[(Column, String)],
844    far_right_left_col_ind: i32,
845    far_left_right_col_ind: i32,
846    left_child: &Arc<dyn ExecutionPlan>,
847    right_child: &Arc<dyn ExecutionPlan>,
848) -> Result<(ProjectionExec, ProjectionExec)> {
849    let new_left = ProjectionExec::try_new(
850        projection_as_columns[0..=far_right_left_col_ind as _]
851            .iter()
852            .map(|(col, alias)| ProjectionExpr {
853                expr: Arc::new(Column::new(col.name(), col.index())) as _,
854                alias: alias.clone(),
855            }),
856        Arc::clone(left_child),
857    )?;
858    let left_size = left_child.schema().fields().len() as i32;
859    let new_right = ProjectionExec::try_new(
860        projection_as_columns[far_left_right_col_ind as _..]
861            .iter()
862            .map(|(col, alias)| {
863                ProjectionExpr {
864                    expr: Arc::new(Column::new(
865                        col.name(),
866                        // Align projected expressions coming from the right
867                        // table with the new right child projection:
868                        (col.index() as i32 - left_size) as _,
869                    )) as _,
870                    alias: alias.clone(),
871                }
872            }),
873        Arc::clone(right_child),
874    )?;
875
876    Ok((new_left, new_right))
877}
878
879/// Checks three conditions for pushing a projection down through a join:
880/// - Projection must narrow the join output schema.
881/// - Columns coming from left/right tables must be collected at the left/right
882///   sides of the output table.
883/// - Left or right table is not lost after the projection.
884pub fn join_allows_pushdown(
885    projection_as_columns: &[(Column, String)],
886    join_schema: &SchemaRef,
887    far_right_left_col_ind: i32,
888    far_left_right_col_ind: i32,
889) -> bool {
890    // Projection must narrow the join output:
891    projection_as_columns.len() < join_schema.fields().len()
892    // Are the columns from different tables mixed?
893    && (far_right_left_col_ind + 1 == far_left_right_col_ind)
894    // Left or right table is not lost after the projection.
895    && far_right_left_col_ind >= 0
896    && far_left_right_col_ind < projection_as_columns.len() as i32
897}
898
899/// Returns the last index before encountering a column coming from the right table when traveling
900/// through the projection from left to right, and the last index before encountering a column
901/// coming from the left table when traveling through the projection from right to left.
902/// If there is no column in the projection coming from the left side, it returns (-1, ...),
903/// if there is no column in the projection coming from the right side, it returns (..., projection length).
904pub fn join_table_borders(
905    left_table_column_count: usize,
906    projection_as_columns: &[(Column, String)],
907) -> (i32, i32) {
908    let far_right_left_col_ind = projection_as_columns
909        .iter()
910        .enumerate()
911        .take_while(|(_, (projection_column, _))| {
912            projection_column.index() < left_table_column_count
913        })
914        .last()
915        .map(|(index, _)| index as i32)
916        .unwrap_or(-1);
917
918    let far_left_right_col_ind = projection_as_columns
919        .iter()
920        .enumerate()
921        .rev()
922        .take_while(|(_, (projection_column, _))| {
923            projection_column.index() >= left_table_column_count
924        })
925        .last()
926        .map(|(index, _)| index as i32)
927        .unwrap_or(projection_as_columns.len() as i32);
928
929    (far_right_left_col_ind, far_left_right_col_ind)
930}
931
932/// Tries to update the equi-join `Column`'s of a join as if the input of
933/// the join was replaced by a projection.
934pub fn update_join_on(
935    proj_left_exprs: &[(Column, String)],
936    proj_right_exprs: &[(Column, String)],
937    hash_join_on: &[(PhysicalExprRef, PhysicalExprRef)],
938    left_field_size: usize,
939) -> Option<Vec<(PhysicalExprRef, PhysicalExprRef)>> {
940    // TODO: Clippy wants the "map" call removed, but doing so generates
941    //       a compilation error. Remove the clippy directive once this
942    //       issue is fixed.
943    #[allow(clippy::map_identity)]
944    let (left_idx, right_idx): (Vec<_>, Vec<_>) = hash_join_on
945        .iter()
946        .map(|(left, right)| (left, right))
947        .unzip();
948
949    let new_left_columns = new_columns_for_join_on(&left_idx, proj_left_exprs, 0);
950    let new_right_columns =
951        new_columns_for_join_on(&right_idx, proj_right_exprs, left_field_size);
952
953    match (new_left_columns, new_right_columns) {
954        (Some(left), Some(right)) => Some(left.into_iter().zip(right).collect()),
955        _ => None,
956    }
957}
958
959/// Tries to update the column indices of a [`JoinFilter`] as if the input of
960/// the join was replaced by a projection.
961pub fn update_join_filter(
962    projection_left_exprs: &[(Column, String)],
963    projection_right_exprs: &[(Column, String)],
964    join_filter: &JoinFilter,
965    left_field_size: usize,
966) -> Option<JoinFilter> {
967    let mut new_left_indices = new_indices_for_join_filter(
968        join_filter,
969        JoinSide::Left,
970        projection_left_exprs,
971        0,
972    )
973    .into_iter();
974    let mut new_right_indices = new_indices_for_join_filter(
975        join_filter,
976        JoinSide::Right,
977        projection_right_exprs,
978        left_field_size,
979    )
980    .into_iter();
981
982    // Check if all columns match:
983    (new_right_indices.len() + new_left_indices.len()
984        == join_filter.column_indices().len())
985    .then(|| {
986        JoinFilter::new(
987            Arc::clone(join_filter.expression()),
988            join_filter
989                .column_indices()
990                .iter()
991                .map(|col_idx| ColumnIndex {
992                    index: if col_idx.side == JoinSide::Left {
993                        new_left_indices.next().unwrap()
994                    } else {
995                        new_right_indices.next().unwrap()
996                    },
997                    side: col_idx.side,
998                })
999                .collect(),
1000            Arc::clone(join_filter.schema()),
1001        )
1002    })
1003}
1004
1005/// Unifies `projection` with its input (which is also a [`ProjectionExec`]).
1006fn try_unifying_projections(
1007    projection: &ProjectionExec,
1008    child: &ProjectionExec,
1009) -> Result<Option<Arc<dyn ExecutionPlan>>> {
1010    let mut projected_exprs = vec![];
1011    let mut column_ref_map: HashMap<Column, usize> = HashMap::new();
1012
1013    // Collect the column references usage in the outer projection.
1014    projection.expr().iter().for_each(|proj_expr| {
1015        proj_expr
1016            .expr
1017            .apply(|expr| {
1018                Ok({
1019                    if let Some(column) = expr.as_any().downcast_ref::<Column>() {
1020                        *column_ref_map.entry(column.clone()).or_default() += 1;
1021                    }
1022                    TreeNodeRecursion::Continue
1023                })
1024            })
1025            .unwrap();
1026    });
1027    // Merging these projections is not beneficial, e.g
1028    // If an expression is not trivial and it is referred more than 1, unifies projections will be
1029    // beneficial as caching mechanism for non-trivial computations.
1030    // See discussion in: https://github.com/apache/datafusion/issues/8296
1031    if column_ref_map.iter().any(|(column, count)| {
1032        *count > 1 && !is_expr_trivial(&Arc::clone(&child.expr()[column.index()].expr))
1033    }) {
1034        return Ok(None);
1035    }
1036    for proj_expr in projection.expr() {
1037        // If there is no match in the input projection, we cannot unify these
1038        // projections. This case will arise if the projection expression contains
1039        // a `PhysicalExpr` variant `update_expr` doesn't support.
1040        let Some(expr) = update_expr(&proj_expr.expr, child.expr(), true)? else {
1041            return Ok(None);
1042        };
1043        projected_exprs.push(ProjectionExpr {
1044            expr,
1045            alias: proj_expr.alias.clone(),
1046        });
1047    }
1048    ProjectionExec::try_new(projected_exprs, Arc::clone(child.input()))
1049        .map(|e| Some(Arc::new(e) as _))
1050}
1051
1052/// Collect all column indices from the given projection expressions.
1053fn collect_column_indices(exprs: &[ProjectionExpr]) -> Vec<usize> {
1054    // Collect indices and remove duplicates.
1055    let mut indices = exprs
1056        .iter()
1057        .flat_map(|proj_expr| collect_columns(&proj_expr.expr))
1058        .map(|x| x.index())
1059        .collect::<std::collections::HashSet<_>>()
1060        .into_iter()
1061        .collect::<Vec<_>>();
1062    indices.sort();
1063    indices
1064}
1065
1066/// This function determines and returns a vector of indices representing the
1067/// positions of columns in `projection_exprs` that are involved in `join_filter`,
1068/// and correspond to a particular side (`join_side`) of the join operation.
1069///
1070/// Notes: Column indices in the projection expressions are based on the join schema,
1071/// whereas the join filter is based on the join child schema. `column_index_offset`
1072/// represents the offset between them.
1073fn new_indices_for_join_filter(
1074    join_filter: &JoinFilter,
1075    join_side: JoinSide,
1076    projection_exprs: &[(Column, String)],
1077    column_index_offset: usize,
1078) -> Vec<usize> {
1079    join_filter
1080        .column_indices()
1081        .iter()
1082        .filter(|col_idx| col_idx.side == join_side)
1083        .filter_map(|col_idx| {
1084            projection_exprs
1085                .iter()
1086                .position(|(col, _)| col_idx.index + column_index_offset == col.index())
1087        })
1088        .collect()
1089}
1090
1091/// This function generates a new set of columns to be used in a hash join
1092/// operation based on a set of equi-join conditions (`hash_join_on`) and a
1093/// list of projection expressions (`projection_exprs`).
1094///
1095/// Notes: Column indices in the projection expressions are based on the join schema,
1096/// whereas the join on expressions are based on the join child schema. `column_index_offset`
1097/// represents the offset between them.
1098fn new_columns_for_join_on(
1099    hash_join_on: &[&PhysicalExprRef],
1100    projection_exprs: &[(Column, String)],
1101    column_index_offset: usize,
1102) -> Option<Vec<PhysicalExprRef>> {
1103    let new_columns = hash_join_on
1104        .iter()
1105        .filter_map(|on| {
1106            // Rewrite all columns in `on`
1107            Arc::clone(*on)
1108                .transform(|expr| {
1109                    if let Some(column) = expr.as_any().downcast_ref::<Column>() {
1110                        // Find the column in the projection expressions
1111                        let new_column = projection_exprs
1112                            .iter()
1113                            .enumerate()
1114                            .find(|(_, (proj_column, _))| {
1115                                column.name() == proj_column.name()
1116                                    && column.index() + column_index_offset
1117                                        == proj_column.index()
1118                            })
1119                            .map(|(index, (_, alias))| Column::new(alias, index));
1120                        if let Some(new_column) = new_column {
1121                            Ok(Transformed::yes(Arc::new(new_column)))
1122                        } else {
1123                            // If the column is not found in the projection expressions,
1124                            // it means that the column is not projected. In this case,
1125                            // we cannot push the projection down.
1126                            internal_err!(
1127                                "Column {:?} not found in projection expressions",
1128                                column
1129                            )
1130                        }
1131                    } else {
1132                        Ok(Transformed::no(expr))
1133                    }
1134                })
1135                .data()
1136                .ok()
1137        })
1138        .collect::<Vec<_>>();
1139    (new_columns.len() == hash_join_on.len()).then_some(new_columns)
1140}
1141
1142/// Checks if the given expression is trivial.
1143/// An expression is considered trivial if it is either a `Column` or a `Literal`.
1144fn is_expr_trivial(expr: &Arc<dyn PhysicalExpr>) -> bool {
1145    expr.as_any().downcast_ref::<Column>().is_some()
1146        || expr.as_any().downcast_ref::<Literal>().is_some()
1147}
1148
// Unit tests for the projection helpers and `ProjectionExec` statistics.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use crate::common::collect;
    use crate::test;
    use crate::test::exec::StatisticsExec;

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::stats::{ColumnStatistics, Precision, Statistics};
    use datafusion_common::ScalarValue;

    use datafusion_expr::Operator;
    use datafusion_physical_expr::expressions::{col, BinaryExpr, Column, Literal};

    // `b - (1 + a)` references columns 7 and 1; the helper must return them
    // sorted and without duplicates.
    #[test]
    fn test_collect_column_indices() -> Result<()> {
        let expr = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("b", 7)),
            Operator::Minus,
            Arc::new(BinaryExpr::new(
                Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
                Operator::Plus,
                Arc::new(Column::new("a", 1)),
            )),
        ));
        let column_indices = collect_column_indices(&[ProjectionExpr {
            expr,
            alias: "b-(1+a)".to_string(),
        }]);
        assert_eq!(column_indices, vec![1, 7]);
        Ok(())
    }

    // A column is "left" iff its index < `left_table_column_count`. The
    // expected tuple is (end of the leading left run, start of the trailing
    // right run); -1 / projection length when that run is empty.
    #[test]
    fn test_join_table_borders() -> Result<()> {
        let projections = vec![
            (Column::new("b", 1), "b".to_owned()),
            (Column::new("c", 2), "c".to_owned()),
            (Column::new("e", 4), "e".to_owned()),
            (Column::new("d", 3), "d".to_owned()),
            (Column::new("c", 2), "c".to_owned()),
            (Column::new("f", 5), "f".to_owned()),
            (Column::new("h", 7), "h".to_owned()),
            (Column::new("g", 6), "g".to_owned()),
        ];
        let left_table_column_count = 5;
        assert_eq!(
            join_table_borders(left_table_column_count, &projections),
            (4, 5)
        );

        // Every column is on the left side.
        let left_table_column_count = 8;
        assert_eq!(
            join_table_borders(left_table_column_count, &projections),
            (7, 8)
        );

        // Every column is on the right side.
        let left_table_column_count = 1;
        assert_eq!(
            join_table_borders(left_table_column_count, &projections),
            (-1, 0)
        );

        // Interleaved sides: the runs do not meet, so the borders are not
        // adjacent.
        let projections = vec![
            (Column::new("a", 0), "a".to_owned()),
            (Column::new("b", 1), "b".to_owned()),
            (Column::new("d", 3), "d".to_owned()),
            (Column::new("g", 6), "g".to_owned()),
            (Column::new("e", 4), "e".to_owned()),
            (Column::new("f", 5), "f".to_owned()),
            (Column::new("e", 4), "e".to_owned()),
            (Column::new("h", 7), "h".to_owned()),
        ];
        let left_table_column_count = 5;
        assert_eq!(
            join_table_borders(left_table_column_count, &projections),
            (2, 7)
        );

        let left_table_column_count = 7;
        assert_eq!(
            join_table_borders(left_table_column_count, &projections),
            (6, 7)
        );

        Ok(())
    }

    // A projection with no expressions must still produce as many batches as
    // its input.
    #[tokio::test]
    async fn project_no_column() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());

        let exec = test::scan_partitioned(1);
        let expected = collect(exec.execute(0, Arc::clone(&task_ctx))?).await?;

        let projection = ProjectionExec::try_new(vec![] as Vec<ProjectionExpr>, exec)?;
        let stream = projection.execute(0, Arc::clone(&task_ctx))?;
        let output = collect(stream).await?;
        assert_eq!(output.len(), expected.len());

        Ok(())
    }

    // `ProjectionExec::try_new` must still accept `(expr, alias)` tuples.
    #[tokio::test]
    async fn project_old_syntax() {
        let exec = test::scan_partitioned(1);
        let schema = exec.schema();
        let expr = col("i", &schema).unwrap();
        ProjectionExec::try_new(
            vec![
                // use From impl of ProjectionExpr to create ProjectionExpr
                // to test old syntax
                (expr, "c".to_string()),
            ],
            exec,
        )
        // expect this to succeed
        .unwrap();
    }

    // Input statistics fixture matching `get_schema()`: 5 rows, 23 bytes,
    // columns col0 (Int64), col1 (Utf8), col2 (Float32).
    fn get_stats() -> Statistics {
        Statistics {
            num_rows: Precision::Exact(5),
            total_byte_size: Precision::Exact(23),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
                    null_count: Precision::Exact(0),
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Exact(3),
                },
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Exact(ScalarValue::Float32(Some(1.1))),
                    min_value: Precision::Exact(ScalarValue::Float32(Some(0.1))),
                    sum_value: Precision::Exact(ScalarValue::Float32(Some(5.5))),
                    null_count: Precision::Absent,
                },
            ],
        }
    }

    // Schema fixture matching `get_stats()`.
    fn get_schema() -> Schema {
        let field_0 = Field::new("col0", DataType::Int64, false);
        let field_1 = Field::new("col1", DataType::Utf8, false);
        let field_2 = Field::new("col2", DataType::Float32, false);
        Schema::new(vec![field_0, field_1, field_2])
    }
    // Projecting (col1, col0) reorders the column statistics; the expected
    // total_byte_size stays at the input's 23.
    #[tokio::test]
    async fn test_stats_projection_columns_only() {
        let source = get_stats();
        let schema = get_schema();

        let exprs: Vec<Arc<dyn PhysicalExpr>> = vec![
            Arc::new(Column::new("col1", 1)),
            Arc::new(Column::new("col0", 0)),
        ];

        let result =
            stats_projection(source, exprs.into_iter(), Arc::new(schema)).unwrap();

        let expected = Statistics {
            num_rows: Precision::Exact(5),
            total_byte_size: Precision::Exact(23),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Exact(3),
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
                    null_count: Precision::Exact(0),
                },
            ],
        };

        assert_eq!(result, expected);
    }

    // Projecting (col2, col0), both primitive types: the expected
    // total_byte_size of 60 equals 5 rows * (4 bytes Float32 + 8 bytes Int64).
    #[tokio::test]
    async fn test_stats_projection_column_with_primitive_width_only() {
        let source = get_stats();
        let schema = get_schema();

        let exprs: Vec<Arc<dyn PhysicalExpr>> = vec![
            Arc::new(Column::new("col2", 2)),
            Arc::new(Column::new("col0", 0)),
        ];

        let result =
            stats_projection(source, exprs.into_iter(), Arc::new(schema)).unwrap();

        let expected = Statistics {
            num_rows: Precision::Exact(5),
            total_byte_size: Precision::Exact(60),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Exact(ScalarValue::Float32(Some(1.1))),
                    min_value: Precision::Exact(ScalarValue::Float32(Some(0.1))),
                    sum_value: Precision::Exact(ScalarValue::Float32(Some(5.5))),
                    null_count: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
                    null_count: Precision::Exact(0),
                },
            ],
        };

        assert_eq!(result, expected);
    }

    // Regression test: projection statistics must be computed against the
    // input schema, not the (narrower) output schema.
    #[test]
    fn test_projection_statistics_uses_input_schema() {
        let input_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
            Field::new("d", DataType::Int32, false),
            Field::new("e", DataType::Int32, false),
            Field::new("f", DataType::Int32, false),
        ]);

        let input_statistics = Statistics {
            num_rows: Precision::Exact(10),
            column_statistics: vec![
                ColumnStatistics {
                    min_value: Precision::Exact(ScalarValue::Int32(Some(1))),
                    max_value: Precision::Exact(ScalarValue::Int32(Some(100))),
                    ..Default::default()
                },
                ColumnStatistics {
                    min_value: Precision::Exact(ScalarValue::Int32(Some(5))),
                    max_value: Precision::Exact(ScalarValue::Int32(Some(50))),
                    ..Default::default()
                },
                ColumnStatistics {
                    min_value: Precision::Exact(ScalarValue::Int32(Some(10))),
                    max_value: Precision::Exact(ScalarValue::Int32(Some(40))),
                    ..Default::default()
                },
                ColumnStatistics {
                    min_value: Precision::Exact(ScalarValue::Int32(Some(20))),
                    max_value: Precision::Exact(ScalarValue::Int32(Some(30))),
                    ..Default::default()
                },
                ColumnStatistics {
                    min_value: Precision::Exact(ScalarValue::Int32(Some(21))),
                    max_value: Precision::Exact(ScalarValue::Int32(Some(29))),
                    ..Default::default()
                },
                ColumnStatistics {
                    min_value: Precision::Exact(ScalarValue::Int32(Some(24))),
                    max_value: Precision::Exact(ScalarValue::Int32(Some(26))),
                    ..Default::default()
                },
            ],
            ..Default::default()
        };

        let input = Arc::new(StatisticsExec::new(input_statistics, input_schema));

        // The projection expressions reference input columns with indices up to 5,
        // while the output schema has only 2 columns. If the output schema were
        // passed to the `partition_statistics` method instead of the input schema,
        // bounds checks on these trailing input columns would fail.
        let exprs: Vec<ProjectionExpr> = vec![
            ProjectionExpr {
                expr: Arc::new(Column::new("c", 2)) as Arc<dyn PhysicalExpr>,
                alias: "c_renamed".to_string(),
            },
            ProjectionExpr {
                expr: Arc::new(BinaryExpr::new(
                    Arc::new(Column::new("e", 4)),
                    Operator::Plus,
                    Arc::new(Column::new("f", 5)),
                )) as Arc<dyn PhysicalExpr>,
                alias: "e_plus_f".to_string(),
            },
        ];

        let projection = ProjectionExec::try_new(exprs, input).unwrap();

        let stats = projection.partition_statistics(None).unwrap();

        assert_eq!(stats.num_rows, Precision::Exact(10));
        assert_eq!(
            stats.column_statistics.len(),
            2,
            "Expected 2 columns in projection statistics"
        );
        assert!(stats.total_byte_size.is_exact().unwrap_or(false));
    }
}