Skip to main content

datafusion_physical_expr/
projection.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ProjectionExpr`] and [`ProjectionExprs`] for representing projections.
19
20use std::ops::Deref;
21use std::sync::Arc;
22
23use crate::PhysicalExpr;
24use crate::expressions::{CastExpr, Column, Literal};
25use crate::scalar_function::ScalarFunctionExpr;
26use crate::utils::collect_columns;
27
28use arrow::array::{RecordBatch, RecordBatchOptions};
29use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
30use datafusion_common::stats::{ColumnStatistics, Precision};
31use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
32use datafusion_common::{
33    Result, ScalarValue, Statistics, assert_or_internal_err, internal_datafusion_err,
34    plan_err,
35};
36
37use datafusion_physical_expr_common::metrics::ExecutionPlanMetricsSet;
38use datafusion_physical_expr_common::metrics::ExpressionEvaluatorMetrics;
39use datafusion_physical_expr_common::physical_expr::fmt_sql;
40use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
41use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays_with_metrics;
42use indexmap::IndexMap;
43use itertools::Itertools;
44
45/// An expression used by projection operations.
46///
47/// The expression is evaluated and the result is stored in a column
48/// with the name specified by `alias`.
49///
50/// For example, the SQL expression `a + b AS sum_ab` would be represented
51/// as a `ProjectionExpr` where `expr` is the expression `a + b`
52/// and `alias` is the string `sum_ab`.
53///
54/// See [`ProjectionExprs`] for a collection of projection expressions.
55#[derive(Debug, Clone)]
56pub struct ProjectionExpr {
57    /// The expression that will be evaluated.
58    pub expr: Arc<dyn PhysicalExpr>,
59    /// The name of the output column for use an output schema.
60    pub alias: String,
61}
62
63impl PartialEq for ProjectionExpr {
64    fn eq(&self, other: &Self) -> bool {
65        let ProjectionExpr { expr, alias } = self;
66        expr.eq(&other.expr) && *alias == other.alias
67    }
68}
69
70impl Eq for ProjectionExpr {}
71
72impl std::fmt::Display for ProjectionExpr {
73    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74        if self.expr.to_string() == self.alias {
75            write!(f, "{}", self.alias)
76        } else {
77            write!(f, "{} AS {}", self.expr, self.alias)
78        }
79    }
80}
81
82impl ProjectionExpr {
83    /// Create a new projection expression
84    pub fn new(expr: Arc<dyn PhysicalExpr>, alias: impl Into<String>) -> Self {
85        let alias = alias.into();
86        Self { expr, alias }
87    }
88
89    /// Create a new projection expression from an expression and a schema using the expression's output field name as alias.
90    pub fn new_from_expression(
91        expr: Arc<dyn PhysicalExpr>,
92        schema: &Schema,
93    ) -> Result<Self> {
94        let field = expr.return_field(schema)?;
95        Ok(Self {
96            expr,
97            alias: field.name().to_string(),
98        })
99    }
100}
101
102impl From<(Arc<dyn PhysicalExpr>, String)> for ProjectionExpr {
103    fn from(value: (Arc<dyn PhysicalExpr>, String)) -> Self {
104        Self::new(value.0, value.1)
105    }
106}
107
108impl From<&(Arc<dyn PhysicalExpr>, String)> for ProjectionExpr {
109    fn from(value: &(Arc<dyn PhysicalExpr>, String)) -> Self {
110        Self::new(Arc::clone(&value.0), value.1.clone())
111    }
112}
113
114impl From<ProjectionExpr> for (Arc<dyn PhysicalExpr>, String) {
115    fn from(value: ProjectionExpr) -> Self {
116        (value.expr, value.alias)
117    }
118}
119
120/// A collection of  [`ProjectionExpr`] instances, representing a complete
121/// projection operation.
122///
123/// Projection operations are used in query plans to select specific columns or
124/// compute new columns based on existing ones.
125///
126/// See [`ProjectionExprs::from_indices`] to select a subset of columns by
127/// indices.
128#[derive(Debug, Clone, PartialEq, Eq)]
129pub struct ProjectionExprs {
130    /// [`Arc`] used for a cheap clone, which improves physical plan optimization performance.
131    exprs: Arc<[ProjectionExpr]>,
132}
133
134impl std::fmt::Display for ProjectionExprs {
135    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136        let exprs: Vec<String> = self.exprs.iter().map(|e| e.to_string()).collect();
137        write!(f, "Projection[{}]", exprs.join(", "))
138    }
139}
140
141impl From<Vec<ProjectionExpr>> for ProjectionExprs {
142    fn from(value: Vec<ProjectionExpr>) -> Self {
143        Self {
144            exprs: value.into(),
145        }
146    }
147}
148
149impl From<&[ProjectionExpr]> for ProjectionExprs {
150    fn from(value: &[ProjectionExpr]) -> Self {
151        Self {
152            exprs: value.iter().cloned().collect(),
153        }
154    }
155}
156
157impl FromIterator<ProjectionExpr> for ProjectionExprs {
158    fn from_iter<T: IntoIterator<Item = ProjectionExpr>>(exprs: T) -> Self {
159        Self {
160            exprs: exprs.into_iter().collect(),
161        }
162    }
163}
164
165impl AsRef<[ProjectionExpr]> for ProjectionExprs {
166    fn as_ref(&self) -> &[ProjectionExpr] {
167        &self.exprs
168    }
169}
170
171impl ProjectionExprs {
172    /// Make a new [`ProjectionExprs`] from expressions iterator.
173    pub fn new(exprs: impl IntoIterator<Item = ProjectionExpr>) -> Self {
174        Self {
175            exprs: exprs.into_iter().collect(),
176        }
177    }
178
179    /// Make a new [`ProjectionExprs`] from expressions.
180    pub fn from_expressions(exprs: impl Into<Arc<[ProjectionExpr]>>) -> Self {
181        Self {
182            exprs: exprs.into(),
183        }
184    }
185
    /// Creates a [`ProjectionExprs`] from a list of column indices.
187    ///
188    /// This is a convenience method for creating simple column-only projections, where each projection expression is a reference to a column
189    /// in the input schema.
190    ///
191    /// # Behavior
192    /// - Ordering: the output projection preserves the exact order of indices provided in the input slice
193    ///   For example, `[2, 0, 1]` will produce projections for columns 2, 0, then 1 in that order
194    /// - Duplicates: Duplicate indices are allowed and will create multiple projection expressions referencing the same source column
195    ///   For example, `[0, 0]` creates 2 separate projections both referencing column 0
196    ///
197    /// # Panics
198    /// Panics if any index in `indices` is out of bounds for the provided schema.
199    ///
200    /// # Example
201    ///
202    /// ```rust
203    /// use arrow::datatypes::{DataType, Field, Schema};
204    /// use datafusion_physical_expr::projection::ProjectionExprs;
205    /// use std::sync::Arc;
206    ///
207    /// // Create a schema with three columns
208    /// let schema = Arc::new(Schema::new(vec![
209    ///     Field::new("a", DataType::Int32, false),
210    ///     Field::new("b", DataType::Utf8, false),
211    ///     Field::new("c", DataType::Float64, false),
212    /// ]));
213    ///
214    /// // Project columns at indices 2 and 0 (c and a) - ordering is preserved
215    /// let projection = ProjectionExprs::from_indices(&[2, 0], &schema);
216    ///
217    /// // This creates: SELECT c@2 AS c, a@0 AS a
218    /// assert_eq!(projection.as_ref().len(), 2);
219    /// assert_eq!(projection.as_ref()[0].alias, "c");
220    /// assert_eq!(projection.as_ref()[1].alias, "a");
221    ///
222    /// // Duplicate indices are allowed
223    /// let projection_with_dups = ProjectionExprs::from_indices(&[0, 0, 1], &schema);
224    /// assert_eq!(projection_with_dups.as_ref().len(), 3);
225    /// assert_eq!(projection_with_dups.as_ref()[0].alias, "a");
226    /// assert_eq!(projection_with_dups.as_ref()[1].alias, "a"); // duplicate
227    /// assert_eq!(projection_with_dups.as_ref()[2].alias, "b");
228    /// ```
229    pub fn from_indices(indices: &[usize], schema: &Schema) -> Self {
230        let projection_exprs = indices.iter().map(|&i| {
231            let field = schema.field(i);
232            ProjectionExpr {
233                expr: Arc::new(Column::new(field.name(), i)),
234                alias: field.name().clone(),
235            }
236        });
237
238        Self::from_iter(projection_exprs)
239    }
240
241    /// Returns an iterator over the projection expressions
242    pub fn iter(&self) -> impl Iterator<Item = &ProjectionExpr> {
243        self.exprs.iter()
244    }
245
246    /// Creates a ProjectionMapping from this projection
247    pub fn projection_mapping(
248        &self,
249        input_schema: &SchemaRef,
250    ) -> Result<ProjectionMapping> {
251        ProjectionMapping::try_new(
252            self.exprs
253                .iter()
254                .map(|p| (Arc::clone(&p.expr), p.alias.clone())),
255            input_schema,
256        )
257    }
258
259    /// Iterate over a clone of the projection expressions.
260    pub fn expr_iter(&self) -> impl Iterator<Item = Arc<dyn PhysicalExpr>> + '_ {
261        self.exprs.iter().map(|e| Arc::clone(&e.expr))
262    }
263
264    /// Apply a fallible transformation to the [`PhysicalExpr`] of each projection.
265    ///
266    /// This method transforms the expression in each [`ProjectionExpr`] while preserving
267    /// the alias. This is useful for rewriting expressions, such as when adapting
268    /// expressions to a different schema.
269    ///
270    /// # Example
271    ///
272    /// ```rust
273    /// use std::sync::Arc;
274    /// use arrow::datatypes::{DataType, Field, Schema};
275    /// use datafusion_common::Result;
276    /// use datafusion_physical_expr::expressions::Column;
277    /// use datafusion_physical_expr::projection::ProjectionExprs;
278    /// use datafusion_physical_expr::PhysicalExpr;
279    ///
280    /// // Create a schema and projection
281    /// let schema = Arc::new(Schema::new(vec![
282    ///     Field::new("a", DataType::Int32, false),
283    ///     Field::new("b", DataType::Int32, false),
284    /// ]));
285    /// let projection = ProjectionExprs::from_indices(&[0, 1], &schema);
286    ///
287    /// // Transform each expression (this example just clones them)
288    /// let transformed = projection.try_map_exprs(|expr| Ok(expr))?;
289    /// assert_eq!(transformed.as_ref().len(), 2);
290    /// # Ok::<(), datafusion_common::DataFusionError>(())
291    /// ```
292    pub fn try_map_exprs<F>(self, mut f: F) -> Result<Self>
293    where
294        F: FnMut(Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>>,
295    {
296        let exprs = self
297            .exprs
298            .iter()
299            .cloned()
300            .map(|mut proj| {
301                proj.expr = f(proj.expr)?;
302                Ok(proj)
303            })
304            .collect::<Result<Arc<_>>>()?;
305        Ok(Self::from_expressions(exprs))
306    }
307
308    /// Apply another projection on top of this projection, returning the combined projection.
309    /// For example, if this projection is `SELECT c@2 AS x, b@1 AS y, a@0 as z` and the other projection is `SELECT x@0 + 1 AS c1, y@1 + z@2 as c2`,
310    /// we return a projection equivalent to `SELECT c@2 + 1 AS c1, b@1 + a@0 as c2`.
311    ///
312    /// # Example
313    ///
314    /// ```rust
315    /// use datafusion_common::{Result, ScalarValue};
316    /// use datafusion_expr::Operator;
317    /// use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal};
318    /// use datafusion_physical_expr::projection::{ProjectionExpr, ProjectionExprs};
319    /// use std::sync::Arc;
320    ///
321    /// fn main() -> Result<()> {
322    ///     // Example from the docstring:
323    ///     // Base projection: SELECT c@2 AS x, b@1 AS y, a@0 AS z
324    ///     let base = ProjectionExprs::new(vec![
325    ///         ProjectionExpr {
326    ///             expr: Arc::new(Column::new("c", 2)),
327    ///             alias: "x".to_string(),
328    ///         },
329    ///         ProjectionExpr {
330    ///             expr: Arc::new(Column::new("b", 1)),
331    ///             alias: "y".to_string(),
332    ///         },
333    ///         ProjectionExpr {
334    ///             expr: Arc::new(Column::new("a", 0)),
335    ///             alias: "z".to_string(),
336    ///         },
337    ///     ]);
338    ///
339    ///     // Top projection: SELECT x@0 + 1 AS c1, y@1 + z@2 AS c2
340    ///     let top = ProjectionExprs::new(vec![
341    ///         ProjectionExpr {
342    ///             expr: Arc::new(BinaryExpr::new(
343    ///                 Arc::new(Column::new("x", 0)),
344    ///                 Operator::Plus,
345    ///                 Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
346    ///             )),
347    ///             alias: "c1".to_string(),
348    ///         },
349    ///         ProjectionExpr {
350    ///             expr: Arc::new(BinaryExpr::new(
351    ///                 Arc::new(Column::new("y", 1)),
352    ///                 Operator::Plus,
353    ///                 Arc::new(Column::new("z", 2)),
354    ///             )),
355    ///             alias: "c2".to_string(),
356    ///         },
357    ///     ]);
358    ///
359    ///     // Expected result: SELECT c@2 + 1 AS c1, b@1 + a@0 AS c2
360    ///     let result = base.try_merge(&top)?;
361    ///
362    ///     assert_eq!(result.as_ref().len(), 2);
363    ///     assert_eq!(result.as_ref()[0].alias, "c1");
364    ///     assert_eq!(result.as_ref()[1].alias, "c2");
365    ///
366    ///     Ok(())
367    /// }
368    /// ```
369    ///
370    /// # Errors
371    /// This function returns an error if any expression in the `other` projection cannot be
372    /// applied on top of this projection.
373    pub fn try_merge(&self, other: &ProjectionExprs) -> Result<ProjectionExprs> {
374        let mut new_exprs = Vec::with_capacity(other.exprs.len());
375        for proj_expr in other.exprs.iter() {
376            new_exprs.push(ProjectionExpr {
377                expr: self.unproject_expr(&proj_expr.expr)?,
378                alias: proj_expr.alias.clone(),
379            });
380        }
381        Ok(ProjectionExprs::new(new_exprs))
382    }
383
384    /// Extract the column indices used in this projection.
385    /// For example, for a projection `SELECT a AS x, b + 1 AS y`, where `a` is at index 0 and `b` is at index 1,
386    /// this function would return `[0, 1]`.
387    /// Repeated indices are returned only once, and the order is ascending.
388    pub fn column_indices(&self) -> Vec<usize> {
389        self.exprs
390            .iter()
391            .flat_map(|e| collect_columns(&e.expr).into_iter().map(|col| col.index()))
392            .sorted_unstable()
393            .dedup()
394            .collect_vec()
395    }
396
397    /// Extract the ordered column indices for a column-only projection.
398    ///
399    /// This function assumes that all expressions in the projection are simple column references.
400    /// It returns the column indices in the order they appear in the projection.
401    ///
402    /// # Panics
403    ///
404    /// Panics if any expression in the projection is not a simple column reference. This includes:
405    /// - Computed expressions (e.g., `a + 1`, `CAST(a AS INT)`)
406    /// - Function calls (e.g., `UPPER(name)`, `SUM(amount)`)
407    /// - Literals (e.g., `42`, `'hello'`)
408    /// - Complex nested expressions (e.g., `CASE WHEN ... THEN ... END`)
409    ///
410    /// # Returns
411    ///
412    /// A vector of column indices in projection order. Unlike [`column_indices()`](Self::column_indices),
413    /// this function:
414    /// - Preserves the projection order (does not sort)
415    /// - Preserves duplicates (does not deduplicate)
416    ///
417    /// # Example
418    ///
419    /// For a projection `SELECT c, a, c` where `a` is at index 0 and `c` is at index 2,
420    /// this function would return `[2, 0, 2]`.
421    ///
422    /// Use [`column_indices()`](Self::column_indices) instead if the projection may contain
423    /// non-column expressions or if you need a deduplicated sorted list.
424    ///
425    /// # Panics
426    ///
427    /// Panics if any expression in the projection is not a simple column reference.
428    #[deprecated(
429        since = "52.0.0",
430        note = "Use column_indices() instead. This method will be removed in 58.0.0 or 6 months after 52.0.0 is released, whichever comes first."
431    )]
432    pub fn ordered_column_indices(&self) -> Vec<usize> {
433        self.exprs
434            .iter()
435            .map(|e| {
436                e.expr
437                    .downcast_ref::<Column>()
438                    .expect("Expected column reference in projection")
439                    .index()
440            })
441            .collect()
442    }
443
444    /// Project a schema according to this projection.
445    ///
446    /// For example, given a projection:
447    /// * `SELECT a AS x, b + 1 AS y`
448    /// * where `a` is at index 0
449    /// * `b` is at index 1
450    ///
451    /// If the input schema is `[a: Int32, b: Int32, c: Int32]`, the output
452    /// schema would be `[x: Int32, y: Int32]`.
453    ///
454    /// Note that [`Field`] metadata are preserved from the input schema.
455    pub fn project_schema(&self, input_schema: &Schema) -> Result<Schema> {
456        let fields: Result<Vec<Field>> = self
457            .exprs
458            .iter()
459            .map(|proj_expr| {
460                let metadata = proj_expr
461                    .expr
462                    .return_field(input_schema)?
463                    .metadata()
464                    .clone();
465
466                let field = Field::new(
467                    &proj_expr.alias,
468                    proj_expr.expr.data_type(input_schema)?,
469                    proj_expr.expr.nullable(input_schema)?,
470                )
471                .with_metadata(metadata);
472
473                Ok(field)
474            })
475            .collect();
476
477        Ok(Schema::new_with_metadata(
478            fields?,
479            input_schema.metadata().clone(),
480        ))
481    }
482
483    /// "unproject" an expression by applying this projection in reverse,
484    /// returning a new set of expressions that reference the original input
485    /// columns.
486    ///
487    /// For example, consider
488    /// * an expression `c1_c2 > 5`, and a schema `[c1, c2]`
489    /// * a projection `c1 + c2 as c1_c2`
490    ///
491    /// This method would rewrite the expression to `c1 + c2 > 5`
492    pub fn unproject_expr(
493        &self,
494        expr: &Arc<dyn PhysicalExpr>,
495    ) -> Result<Arc<dyn PhysicalExpr>> {
496        update_expr(expr, &self.exprs, true)?.ok_or_else(|| {
497            internal_datafusion_err!(
498                "Failed to unproject an expression {} with ProjectionExprs {}",
499                expr,
500                self.exprs.iter().map(|e| format!("{e}")).join(", ")
501            )
502        })
503    }
504
    /// "project" an expression using this projection's expressions
506    ///
507    /// For example, consider
508    /// * an expression `c1 + c2 > 5`, and a schema `[c1, c2]`
509    /// * a projection `c1 + c2 as c1_c2`
510    ///
    /// This method would rewrite the expression to `c1_c2 > 5`
512    pub fn project_expr(
513        &self,
514        expr: &Arc<dyn PhysicalExpr>,
515    ) -> Result<Arc<dyn PhysicalExpr>> {
516        update_expr(expr, &self.exprs, false)?.ok_or_else(|| {
517            internal_datafusion_err!(
518                "Failed to project an expression {} with ProjectionExprs {}",
519                expr,
520                self.exprs.iter().map(|e| format!("{e}")).join(", ")
521            )
522        })
523    }
524
525    /// Create a new [`Projector`] from this projection and an input schema.
526    ///
527    /// A [`Projector`] can be used to apply this projection to record batches.
528    ///
529    /// # Errors
530    /// This function returns an error if the output schema cannot be constructed from the input schema
531    /// with the given projection expressions.
532    /// For example, if an expression only works with integer columns but the input schema has a string column at that index.
533    pub fn make_projector(&self, input_schema: &Schema) -> Result<Projector> {
534        let output_schema = Arc::new(self.project_schema(input_schema)?);
535        Ok(Projector {
536            projection: self.clone(),
537            output_schema,
538            expression_metrics: None,
539        })
540    }
541
542    pub fn create_expression_metrics(
543        &self,
544        metrics: &ExecutionPlanMetricsSet,
545        partition: usize,
546    ) -> ExpressionEvaluatorMetrics {
547        let labels: Vec<String> = self
548            .exprs
549            .iter()
550            .map(|proj_expr| {
551                let expr_sql = fmt_sql(proj_expr.expr.as_ref()).to_string();
552                if proj_expr.expr.to_string() == proj_expr.alias {
553                    expr_sql
554                } else {
555                    format!("{expr_sql} AS {}", proj_expr.alias)
556                }
557            })
558            .collect();
559        ExpressionEvaluatorMetrics::new(metrics, partition, labels)
560    }
561
562    /// Project statistics according to this projection.
563    /// For example, for a projection `SELECT a AS x, b + 1 AS y`, where `a` is at index 0 and `b` is at index 1,
564    /// if the input statistics has column statistics for columns `a`, `b`, and `c`, the output statistics would have column statistics for columns `x` and `y`.
565    ///
566    /// # Example
567    ///
568    /// ```rust
569    /// use arrow::datatypes::{DataType, Field, Schema};
570    /// use datafusion_common::stats::{ColumnStatistics, Precision, Statistics};
571    /// use datafusion_physical_expr::projection::ProjectionExprs;
572    /// use datafusion_common::Result;
573    /// use datafusion_common::ScalarValue;
574    /// use std::sync::Arc;
575    ///
576    /// fn main() -> Result<()> {
577    ///     // Input schema: a: Int32, b: Int32, c: Int32
578    ///     let input_schema = Arc::new(Schema::new(vec![
579    ///         Field::new("a", DataType::Int32, false),
580    ///         Field::new("b", DataType::Int32, false),
581    ///         Field::new("c", DataType::Int32, false),
582    ///     ]));
583    ///
584    ///     // Input statistics with column stats for a, b, c
585    ///     let input_stats = Statistics {
586    ///         num_rows: Precision::Exact(100),
587    ///         total_byte_size: Precision::Exact(1200),
588    ///         column_statistics: vec![
589    ///             // Column a stats
590    ///             ColumnStatistics::new_unknown()
591    ///                 .with_null_count(Precision::Exact(0))
592    ///                 .with_min_value(Precision::Exact(ScalarValue::Int32(Some(0))))
593    ///                 .with_max_value(Precision::Exact(ScalarValue::Int32(Some(100))))
594    ///                 .with_distinct_count(Precision::Exact(100)),
595    ///             // Column b stats
596    ///             ColumnStatistics::new_unknown()
597    ///                 .with_null_count(Precision::Exact(0))
598    ///                 .with_min_value(Precision::Exact(ScalarValue::Int32(Some(10))))
599    ///                 .with_max_value(Precision::Exact(ScalarValue::Int32(Some(60))))
600    ///                 .with_distinct_count(Precision::Exact(50)),
601    ///             // Column c stats
602    ///             ColumnStatistics::new_unknown()
603    ///                 .with_null_count(Precision::Exact(5))
604    ///                 .with_min_value(Precision::Exact(ScalarValue::Int32(Some(-10))))
605    ///                 .with_max_value(Precision::Exact(ScalarValue::Int32(Some(200))))
606    ///                 .with_distinct_count(Precision::Exact(25)),
607    ///         ],
608    ///     };
609    ///
610    ///     // Create a projection that selects columns c and a (indices 2 and 0)
611    ///     let projection = ProjectionExprs::from_indices(&[2, 0], &input_schema);
612    ///
613    ///     // Compute output schema
614    ///     let output_schema = projection.project_schema(&input_schema)?;
615    ///
616    ///     // Project the statistics
617    ///     let output_stats = projection.project_statistics(input_stats, &output_schema)?;
618    ///
619    ///     // The output should have 2 column statistics (for c and a, in that order)
620    ///     assert_eq!(output_stats.column_statistics.len(), 2);
621    ///
622    ///     // First column in output is c (was at index 2)
623    ///     assert_eq!(
624    ///         output_stats.column_statistics[0].min_value,
625    ///         Precision::Exact(ScalarValue::Int32(Some(-10)))
626    ///     );
627    ///     assert_eq!(
628    ///         output_stats.column_statistics[0].null_count,
629    ///         Precision::Exact(5)
630    ///     );
631    ///
632    ///     // Second column in output is a (was at index 0)
633    ///     assert_eq!(
634    ///         output_stats.column_statistics[1].min_value,
635    ///         Precision::Exact(ScalarValue::Int32(Some(0)))
636    ///     );
637    ///     assert_eq!(
638    ///         output_stats.column_statistics[1].distinct_count,
639    ///         Precision::Exact(100)
640    ///     );
641    ///
642    ///     // Total byte size is recalculated based on projected columns
643    ///     assert_eq!(
644    ///         output_stats.total_byte_size,
645    ///         Precision::Exact(800), // each Int32 column is 4 bytes * 100 rows * 2 columns
646    ///     );
647    ///
648    ///     // Number of rows remains the same
649    ///     assert_eq!(output_stats.num_rows, Precision::Exact(100));
650    ///
651    ///     Ok(())
652    /// }
653    /// ```
654    pub fn project_statistics(
655        &self,
656        mut stats: Statistics,
657        output_schema: &Schema,
658    ) -> Result<Statistics> {
659        let mut column_statistics = Vec::with_capacity(self.exprs.len());
660
661        for proj_expr in self.exprs.iter() {
662            let expr = &proj_expr.expr;
663            let col_stats = if let Some(col) = expr.downcast_ref::<Column>() {
664                std::mem::take(&mut stats.column_statistics[col.index()])
665            } else if let Some(literal) = expr.downcast_ref::<Literal>() {
666                // Handle literal expressions (constants) by calculating proper statistics
667                let data_type = expr.data_type(output_schema)?;
668
669                if literal.value().is_null() {
670                    let null_count = match stats.num_rows {
671                        Precision::Exact(num_rows) => Precision::Exact(num_rows),
672                        _ => Precision::Absent,
673                    };
674
675                    ColumnStatistics {
676                        min_value: Precision::Exact(literal.value().clone()),
677                        max_value: Precision::Exact(literal.value().clone()),
678                        distinct_count: Precision::Exact(1),
679                        null_count,
680                        sum_value: Precision::Exact(literal.value().clone()),
681                        byte_size: Precision::Exact(0),
682                    }
683                } else {
684                    let value = literal.value();
685                    let distinct_count = Precision::Exact(1);
686                    let null_count = Precision::Exact(0);
687
688                    let byte_size = if let Some(byte_width) = data_type.primitive_width()
689                    {
690                        stats.num_rows.multiply(&Precision::Exact(byte_width))
691                    } else {
692                        // Complex types depend on array encoding, so set to Absent
693                        Precision::Absent
694                    };
695
696                    let widened_sum = Precision::Exact(value.clone()).cast_to_sum_type();
697                    let sum_value = widened_sum
698                        .get_value()
699                        .and_then(|sum| {
700                            Precision::<ScalarValue>::from(stats.num_rows)
701                                .cast_to(&sum.data_type())
702                                .ok()
703                        })
704                        .map(|row_count| widened_sum.multiply(&row_count))
705                        .unwrap_or(Precision::Absent);
706
707                    ColumnStatistics {
708                        min_value: Precision::Exact(value.clone()),
709                        max_value: Precision::Exact(value.clone()),
710                        distinct_count,
711                        null_count,
712                        sum_value,
713                        byte_size,
714                    }
715                }
716            } else {
717                project_column_statistics_through_expr(
718                    expr.as_ref(),
719                    &stats.column_statistics,
720                )
721            };
722            column_statistics.push(col_stats);
723        }
724        stats.calculate_total_byte_size(output_schema);
725        stats.column_statistics = column_statistics;
726        Ok(stats)
727    }
728}
729
730/// Propagate column statistics through CAST projections. Other expressions
731/// return unknown — generalizing via [`PhysicalExpr::evaluate_bounds`] is
732/// unsafe for aggregate folding since many impls (e.g. `sin`) return a fixed
733/// envelope rather than tight bounds on the actual inputs.
734fn project_column_statistics_through_expr(
735    expr: &dyn PhysicalExpr,
736    column_stats: &[ColumnStatistics],
737) -> ColumnStatistics {
738    if let Some(col) = expr.downcast_ref::<Column>() {
739        return column_stats[col.index()].clone();
740    }
741    let Some(cast_expr) = expr.downcast_ref::<CastExpr>() else {
742        return ColumnStatistics::new_unknown();
743    };
744    let inner_stats =
745        project_column_statistics_through_expr(cast_expr.expr.as_ref(), column_stats);
746    let target_type = cast_expr.cast_type();
747    ColumnStatistics {
748        min_value: inner_stats
749            .min_value
750            .cast_to(target_type)
751            .unwrap_or(Precision::Absent),
752        max_value: inner_stats
753            .max_value
754            .cast_to(target_type)
755            .unwrap_or(Precision::Absent),
756        null_count: inner_stats.null_count,
757        distinct_count: inner_stats.distinct_count,
758        sum_value: Precision::Absent,
759        byte_size: Precision::Absent,
760    }
761}
762
763impl<'a> IntoIterator for &'a ProjectionExprs {
764    type Item = &'a ProjectionExpr;
765    type IntoIter = std::slice::Iter<'a, ProjectionExpr>;
766
767    fn into_iter(self) -> Self::IntoIter {
768        self.exprs.iter()
769    }
770}
771
/// Applies a projection to record batches.
///
/// A [`Projector`] pairs a set of projection expressions with a pre-computed
/// output schema and uses both to project record batches.
///
/// The main reason to use a `Projector` is to avoid repeatedly computing
/// the output schema for each batch, which can be costly if the projection
/// expressions are complex.
#[derive(Clone, Debug)]
pub struct Projector {
    /// Expressions evaluated against every input batch.
    projection: ProjectionExprs,
    /// Schema attached to every batch produced by this projector.
    output_schema: SchemaRef,
    /// If `Some`, metrics will be tracked for projection evaluation.
    expression_metrics: Option<ExpressionEvaluatorMetrics>,
}
787
788impl Projector {
789    /// Construct the projector with metrics. After execution, related metrics will
790    /// be tracked inside `ExecutionPlanMetricsSet`
791    ///
792    /// See [`ExpressionEvaluatorMetrics`] for details.
793    pub fn with_metrics(
794        &self,
795        metrics: &ExecutionPlanMetricsSet,
796        partition: usize,
797    ) -> Self {
798        let expr_metrics = self
799            .projection
800            .create_expression_metrics(metrics, partition);
801        Self {
802            expression_metrics: Some(expr_metrics),
803            projection: self.projection.clone(),
804            output_schema: Arc::clone(&self.output_schema),
805        }
806    }
807
808    /// Project a record batch according to this projector's expressions.
809    ///
810    /// # Errors
811    /// This function returns an error if any expression evaluation fails
812    /// or if the output schema of the resulting record batch does not match
813    /// the pre-computed output schema of the projector.
814    pub fn project_batch(&self, batch: &RecordBatch) -> Result<RecordBatch> {
815        let arrays = evaluate_expressions_to_arrays_with_metrics(
816            self.projection.exprs.iter().map(|p| &p.expr),
817            batch,
818            self.expression_metrics.as_ref(),
819        )?;
820
821        if arrays.is_empty() {
822            let options =
823                RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
824            RecordBatch::try_new_with_options(
825                Arc::clone(&self.output_schema),
826                arrays,
827                &options,
828            )
829            .map_err(Into::into)
830        } else {
831            RecordBatch::try_new(Arc::clone(&self.output_schema), arrays)
832                .map_err(Into::into)
833        }
834    }
835
836    pub fn output_schema(&self) -> &SchemaRef {
837        &self.output_schema
838    }
839
840    pub fn projection(&self) -> &ProjectionExprs {
841        &self.projection
842    }
843}
844
/// Describes an immutable, reference-counted projection.
///
/// The slice holds the column indices that make up the projection.
/// [`Arc`] is used to make it cheap to clone.
pub type ProjectionRef = Arc<[usize]>;
850
851/// Combine two projections.
852///
853/// If `p1` is [`None`] then there are no changes.
854/// Otherwise, if passed `p2` is not [`None`] then it is remapped
855/// according to the `p1`. Otherwise, there are no changes.
856///
857/// # Example
858///
859/// If stored projection is [0, 2] and we call `apply_projection([0, 2, 3])`,
860/// then the resulting projection will be [0, 3].
861///
862/// # Error
863///
864/// Returns an internal error if `p1` contains index that is greater than `p2` len.
865///
866pub fn combine_projections(
867    p1: Option<&ProjectionRef>,
868    p2: Option<&ProjectionRef>,
869) -> Result<Option<ProjectionRef>> {
870    let Some(p1) = p1 else {
871        return Ok(None);
872    };
873    let Some(p2) = p2 else {
874        return Ok(Some(Arc::clone(p1)));
875    };
876
877    Ok(Some(
878        p1.iter()
879            .map(|i| {
880                let idx = *i;
881                assert_or_internal_err!(
882                    idx < p2.len(),
883                    "unable to apply projection: index {} is greater than new projection len {}",
884                    idx,
885                    p2.len(),
886                );
887                Ok(p2[*i])
888            })
889            .collect::<Result<Arc<[usize]>>>()?,
890    ))
891}
892
893/// The function projects / unprojects an expression with respect to set of
894/// projection expressions.
895///
896/// See also [`ProjectionExprs::unproject_expr`] and [`ProjectionExprs::project_expr`]
897///
898/// 1) When `unproject` is `true`:
899///
900///    Rewrites an expression with respect to the projection expressions,
901///    effectively "unprojecting" it to reference the original input columns.
902///
903///    For example, given
904///    * the expressions `a@1 + b@2` and `c@0`
905///    * and projection expressions `c@2, a@0, b@1`
906///
907///    Then
908///    * `a@1 + b@2` becomes `a@0 + b@1`
909///    * `c@0` becomes `c@2`
910///
911/// 2) When `unproject` is `false`:
912///
913///    Rewrites the expression to reference the projected expressions,
914///    effectively "projecting" it. The resulting expression will reference the
915///    indices as they appear in the projection.
916///
917///    If the expression cannot be rewritten after the projection, it returns
918///    `None`.
919///
920///    For example, given
921///    * the expressions `c@0`, `a@1` and `b@2`
922///    * the projection `a@1 as a, c@0 as c_new`,
923///
924///    Then
925///    * `c@0` becomes `c_new@1`
926///    * `a@1` becomes `a@0`
927///    * `b@2` results in `None` since the projection does not include `b`.
928///
929/// # Errors
930/// This function returns an error if `unproject` is `true` and if any expression references
931/// an index that is out of bounds for `projected_exprs`.
932/// For example:
933///
934/// - `expr` is `a@3`
935/// - `projected_exprs` is \[`a@0`, `b@1`\]
936///
937/// In this case, `a@3` references index 3, which is out of bounds for `projected_exprs` (which has length 2).
938pub fn update_expr(
939    expr: &Arc<dyn PhysicalExpr>,
940    projected_exprs: &[ProjectionExpr],
941    unproject: bool,
942) -> Result<Option<Arc<dyn PhysicalExpr>>> {
943    #[derive(Debug, PartialEq)]
944    enum RewriteState {
945        /// The expression is unchanged.
946        Unchanged,
947        /// Some part of the expression has been rewritten
948        RewrittenValid,
949        /// Some part of the expression has been rewritten, but some column
950        /// references could not be.
951        RewrittenInvalid,
952    }
953
954    let mut state = RewriteState::Unchanged;
955
956    let new_expr = Arc::clone(expr)
957        .transform_up(|expr| {
958            if state == RewriteState::RewrittenInvalid {
959                return Ok(Transformed::no(expr));
960            }
961
962            let Some(column) = expr.downcast_ref::<Column>() else {
963                return Ok(Transformed::no(expr));
964            };
965            if unproject {
966                state = RewriteState::RewrittenValid;
967                // Update the index of `column`:
968                let projected_expr = projected_exprs.get(column.index()).ok_or_else(|| {
969                    internal_datafusion_err!(
970                        "Column index {} out of bounds for projected expressions of length {}",
971                        column.index(),
972                        projected_exprs.len()
973                    )
974                })?;
975                Ok(Transformed::yes(Arc::clone(&projected_expr.expr)))
976            } else {
977                // default to invalid, in case we can't find the relevant column
978                state = RewriteState::RewrittenInvalid;
979                // Determine how to update `column` to accommodate `projected_exprs`
980                projected_exprs
981                    .iter()
982                    .enumerate()
983                    .find_map(|(index, proj_expr)| {
984                        proj_expr.expr.downcast_ref::<Column>().and_then(
985                            |projected_column| {
986                                (column.name().eq(projected_column.name())
987                                    && column.index() == projected_column.index())
988                                .then(|| {
989                                    state = RewriteState::RewrittenValid;
990                                    Arc::new(Column::new(&proj_expr.alias, index)) as _
991                                })
992                            },
993                        )
994                    })
995                    .map_or_else(
996                        || Ok(Transformed::no(expr)),
997                        |c| Ok(Transformed::yes(c)),
998                    )
999            }
1000        })
1001        .data()?;
1002
1003    match state {
1004        RewriteState::RewrittenInvalid => Ok(None),
1005        // Both Unchanged and RewrittenValid are valid:
1006        // - Unchanged means no columns to rewrite (e.g., literals)
1007        // - RewrittenValid means columns were successfully rewritten
1008        RewriteState::Unchanged | RewriteState::RewrittenValid => Ok(Some(new_expr)),
1009    }
1010}
1011
/// Stores target expressions, along with their indices, that associate with a
/// source expression in a projection mapping.
///
/// Note that the derived [`Default`] yields an empty list; the non-empty
/// expectation below only holds once at least one target has been pushed.
#[derive(Clone, Debug, Default)]
pub struct ProjectionTargets {
    /// A vector of pairs of target expressions and their indices, expected
    /// to be non-empty whenever it is read.
    /// Consider using a special non-empty collection type in the future (e.g.
    /// if Rust provides one in the standard library).
    exprs_indices: Vec<(Arc<dyn PhysicalExpr>, usize)>,
}
1021
1022impl ProjectionTargets {
1023    /// Returns the first target expression and its index.
1024    pub fn first(&self) -> &(Arc<dyn PhysicalExpr>, usize) {
1025        // Since the vector is non-empty, we can safely unwrap:
1026        self.exprs_indices.first().unwrap()
1027    }
1028
1029    /// Adds a target expression and its index to the list of targets.
1030    pub fn push(&mut self, target: (Arc<dyn PhysicalExpr>, usize)) {
1031        self.exprs_indices.push(target);
1032    }
1033}
1034
1035impl Deref for ProjectionTargets {
1036    type Target = [(Arc<dyn PhysicalExpr>, usize)];
1037
1038    fn deref(&self) -> &Self::Target {
1039        &self.exprs_indices
1040    }
1041}
1042
1043impl From<Vec<(Arc<dyn PhysicalExpr>, usize)>> for ProjectionTargets {
1044    fn from(exprs_indices: Vec<(Arc<dyn PhysicalExpr>, usize)>) -> Self {
1045        Self { exprs_indices }
1046    }
1047}
1048
/// Stores the mapping between source expressions and target expressions for a
/// projection.
#[derive(Clone, Debug)]
pub struct ProjectionMapping {
    /// Mapping between source expressions and target expressions.
    /// Keys are source expressions, kept in insertion order; each value lists
    /// the targets the source maps to, every target being paired with the
    /// index of the projection expression that produced it.
    map: IndexMap<Arc<dyn PhysicalExpr>, ProjectionTargets>,
}
1057
1058impl ProjectionMapping {
1059    /// Constructs the mapping between a projection's input and output
1060    /// expressions.
1061    ///
1062    /// For example, given the input projection expressions (`a + b`, `c + d`)
1063    /// and an output schema with two columns `"c + d"` and `"a + b"`, the
1064    /// projection mapping would be:
1065    ///
1066    /// ```text
1067    ///  [0]: (c + d, [(col("c + d"), 0)])
1068    ///  [1]: (a + b, [(col("a + b"), 1)])
1069    /// ```
1070    ///
1071    /// where `col("c + d")` means the column named `"c + d"`.
1072    pub fn try_new(
1073        expr: impl IntoIterator<Item = (Arc<dyn PhysicalExpr>, String)>,
1074        input_schema: &SchemaRef,
1075    ) -> Result<Self> {
1076        // Construct a map from the input expressions to the output expression of the projection:
1077        let mut map = IndexMap::<_, ProjectionTargets>::new();
1078        for (expr_idx, (expr, name)) in expr.into_iter().enumerate() {
1079            let target_expr = Arc::new(Column::new(&name, expr_idx)) as _;
1080            let source_expr = expr.transform_down(|e| match e.downcast_ref::<Column>() {
1081                Some(col) => {
1082                    // Sometimes, an expression and its name in the input_schema
1083                    // doesn't match. This can cause problems, so we make sure
1084                    // that the expression name matches with the name in `input_schema`.
1085                    // Conceptually, `source_expr` and `expression` should be the same.
1086                    let idx = col.index();
1087                    let matching_field = input_schema.field(idx);
1088                    let matching_name = matching_field.name();
1089                    assert_or_internal_err!(
1090                        col.name() == matching_name,
1091                        "Input field name {matching_name} does not match with the projection expression {}",
1092                        col.name()
1093                    );
1094                    let matching_column = Column::new(matching_name, idx);
1095                    Ok(Transformed::yes(Arc::new(matching_column)))
1096                }
1097                None => Ok(Transformed::no(e)),
1098            })
1099            .data()?;
1100            map.entry(Arc::clone(&source_expr))
1101                .or_default()
1102                .push((Arc::clone(&target_expr), expr_idx));
1103
1104            // For struct-producing functions (e.g. named_struct), decompose
1105            // into field-level mapping entries so that orderings propagate
1106            // through struct projections. For example, if the projection has
1107            // `named_struct('ticker', p.ticker, ...) AS details`, this adds:
1108            //   p.ticker → get_field(col("details"), "ticker")
1109            // enabling the optimizer to know that sorting by
1110            // `details.ticker` is equivalent to sorting by `p.ticker`.
1111            if let Some(func_expr) = source_expr.downcast_ref::<ScalarFunctionExpr>() {
1112                let literal_args: Vec<Option<ScalarValue>> = func_expr
1113                    .args()
1114                    .iter()
1115                    .map(|arg| arg.downcast_ref::<Literal>().map(|l| l.value().clone()))
1116                    .collect();
1117
1118                if let Some(field_mapping) =
1119                    func_expr.fun().struct_field_mapping(&literal_args)
1120                    && let DataType::Struct(struct_fields) = func_expr.return_type()
1121                {
1122                    for (accessor_args, source_arg_idx) in &field_mapping.fields {
1123                        let value_expr = Arc::clone(&func_expr.args()[*source_arg_idx]);
1124
1125                        // Build accessor args: [target_col, ...field_name_literals]
1126                        let mut accessor_fn_args: Vec<Arc<dyn PhysicalExpr>> =
1127                            vec![Arc::clone(&target_expr)];
1128                        accessor_fn_args.extend(accessor_args.iter().map(|sv| {
1129                            Arc::new(Literal::new(sv.clone())) as Arc<dyn PhysicalExpr>
1130                        }));
1131
1132                        // Look up the field's return type from the struct schema
1133                        let return_field = accessor_args
1134                            .first()
1135                            .and_then(|sv| sv.try_as_str().flatten())
1136                            .and_then(|field_name| {
1137                                struct_fields
1138                                    .iter()
1139                                    .find(|f| f.name() == field_name)
1140                                    .cloned()
1141                            });
1142
1143                        if let Some(return_field) = return_field {
1144                            let field_access_expr = Arc::new(ScalarFunctionExpr::new(
1145                                field_mapping.field_accessor.name(),
1146                                Arc::clone(&field_mapping.field_accessor),
1147                                accessor_fn_args,
1148                                return_field,
1149                                Arc::new(func_expr.config_options().clone()),
1150                            ))
1151                                as Arc<dyn PhysicalExpr>;
1152
1153                            map.entry(value_expr)
1154                                .or_default()
1155                                .push((field_access_expr, expr_idx));
1156                        }
1157                    }
1158                }
1159            }
1160        }
1161        Ok(Self { map })
1162    }
1163
1164    /// Constructs a subset mapping using the provided indices.
1165    ///
1166    /// This is used when the output is a subset of the input without any
1167    /// other transformations. The indices are for columns in the schema.
1168    pub fn from_indices(indices: &[usize], schema: &SchemaRef) -> Result<Self> {
1169        let projection_exprs = indices.iter().map(|index| {
1170            let field = schema.field(*index);
1171            let column = Arc::new(Column::new(field.name(), *index));
1172            (column as _, field.name().clone())
1173        });
1174        ProjectionMapping::try_new(projection_exprs, schema)
1175    }
1176}
1177
1178impl Deref for ProjectionMapping {
1179    type Target = IndexMap<Arc<dyn PhysicalExpr>, ProjectionTargets>;
1180
1181    fn deref(&self) -> &Self::Target {
1182        &self.map
1183    }
1184}
1185
1186impl FromIterator<(Arc<dyn PhysicalExpr>, ProjectionTargets)> for ProjectionMapping {
1187    fn from_iter<T: IntoIterator<Item = (Arc<dyn PhysicalExpr>, ProjectionTargets)>>(
1188        iter: T,
1189    ) -> Self {
1190        Self {
1191            map: IndexMap::from_iter(iter),
1192        }
1193    }
1194}
1195
1196/// Projects a slice of [LexOrdering]s onto the given schema.
1197///
1198/// This is a convenience wrapper that applies [project_ordering] to each
1199/// input ordering and collects the successful projections:
1200/// - For each input ordering, the result of [project_ordering] is appended to
1201///   the output if it is `Some(...)`.
1202/// - Order is preserved and no deduplication is attempted.
1203/// - If none of the input orderings can be projected, an empty `Vec` is
1204///   returned.
1205///
1206/// See [project_ordering] for the semantics of projecting a single
1207/// [LexOrdering].
1208pub fn project_orderings(
1209    orderings: &[LexOrdering],
1210    schema: &SchemaRef,
1211) -> Vec<LexOrdering> {
1212    let mut projected_orderings = vec![];
1213
1214    for ordering in orderings {
1215        projected_orderings.extend(project_ordering(ordering, schema));
1216    }
1217
1218    projected_orderings
1219}
1220
1221/// Projects a single [LexOrdering] onto the given schema.
1222///
1223/// This function attempts to rewrite every [PhysicalSortExpr] in the provided
1224/// [LexOrdering] so that any [Column] expressions point at the correct field
1225/// indices in `schema`.
1226///
1227/// Key details:
1228/// - Columns are matched by name, not by index. The index of each matched
1229///   column is looked up with [Schema::column_with_name](arrow::datatypes::Schema::column_with_name) and a new
1230///   [Column] with the correct [index](Column::index) is substituted.
1231/// - If an expression references a column name that does not exist in
1232///   `schema`, projection of the current ordering stops and only the already
1233///   rewritten prefix is kept. This models the fact that a lexicographical
1234///   ordering remains valid for any leading prefix whose expressions are
1235///   present in the projected schema.
1236/// - If no expressions can be projected (i.e. the first one is missing), the
1237///   function returns `None`.
1238///
1239/// Return value:
1240/// - `Some(LexOrdering)` if at least one sort expression could be projected.
1241///   The returned ordering may be a strict prefix of the input ordering.
1242/// - `None` if no part of the ordering can be projected onto `schema`.
1243///
1244/// Example
1245///
1246/// Suppose we have an input ordering `[col("a@0"), col("b@1")]` but the projected
1247/// schema only contains b and not a. The result will be `Some([col("a@0")])`. In other
1248/// words, the column reference is reindexed to match the projected schema.
1249/// If neither a nor b is present, the result will be None.
1250pub fn project_ordering(
1251    ordering: &LexOrdering,
1252    schema: &SchemaRef,
1253) -> Option<LexOrdering> {
1254    let mut projected_exprs = vec![];
1255    for PhysicalSortExpr { expr, options } in ordering.iter() {
1256        let transformed = Arc::clone(expr).transform_up(|expr| {
1257            let Some(col) = expr.downcast_ref::<Column>() else {
1258                return Ok(Transformed::no(expr));
1259            };
1260
1261            let name = col.name();
1262            if let Some((idx, _)) = schema.column_with_name(name) {
1263                // Compute the new column expression (with correct index) after projection:
1264                Ok(Transformed::yes(Arc::new(Column::new(name, idx))))
1265            } else {
1266                // Cannot find expression in the projected_schema,
1267                // signal this using an Err result
1268                plan_err!("")
1269            }
1270        });
1271
1272        match transformed {
1273            Ok(transformed) => {
1274                projected_exprs.push(PhysicalSortExpr::new(transformed.data, *options));
1275            }
1276            Err(_) => {
1277                // Err result indicates an expression could not be found in the
1278                // projected_schema, stop iterating since rest of the orderings are violated
1279                break;
1280            }
1281        }
1282    }
1283
1284    LexOrdering::new(projected_exprs)
1285}
1286
1287#[cfg(test)]
1288pub(crate) mod tests {
1289    use std::collections::HashMap;
1290
1291    use super::*;
1292    use crate::equivalence::{EquivalenceProperties, convert_to_orderings};
1293    use crate::expressions::{BinaryExpr, CastExpr, col};
1294    use crate::utils::tests::TestScalarUDF;
1295    use crate::{PhysicalExprRef, ScalarFunctionExpr};
1296
1297    use arrow::compute::SortOptions;
1298    use arrow::datatypes::{DataType, TimeUnit};
1299    use datafusion_common::config::ConfigOptions;
1300    use datafusion_expr::{Operator, ScalarUDF};
1301    use insta::assert_snapshot;
1302
1303    pub(crate) fn output_schema(
1304        mapping: &ProjectionMapping,
1305        input_schema: &Arc<Schema>,
1306    ) -> Result<SchemaRef> {
1307        // Calculate output schema:
1308        let mut fields = vec![];
1309        for (source, targets) in mapping.iter() {
1310            let data_type = source.data_type(input_schema)?;
1311            let nullable = source.nullable(input_schema)?;
1312            for (target, _) in targets.iter() {
1313                // Skip non-Column targets (e.g. struct field decomposition
1314                // entries which are ScalarFunctionExpr targets).
1315                let Some(column) = target.downcast_ref::<Column>() else {
1316                    continue;
1317                };
1318                fields.push(Field::new(column.name(), data_type.clone(), nullable));
1319            }
1320        }
1321
1322        let output_schema = Arc::new(Schema::new_with_metadata(
1323            fields,
1324            input_schema.metadata().clone(),
1325        ));
1326
1327        Ok(output_schema)
1328    }
1329
1330    #[test]
1331    fn project_orderings() -> Result<()> {
1332        let schema = Arc::new(Schema::new(vec![
1333            Field::new("a", DataType::Int32, true),
1334            Field::new("b", DataType::Int32, true),
1335            Field::new("c", DataType::Int32, true),
1336            Field::new("d", DataType::Int32, true),
1337            Field::new("e", DataType::Int32, true),
1338            Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true),
1339        ]));
1340        let col_a = &col("a", &schema)?;
1341        let col_b = &col("b", &schema)?;
1342        let col_c = &col("c", &schema)?;
1343        let col_d = &col("d", &schema)?;
1344        let col_e = &col("e", &schema)?;
1345        let col_ts = &col("ts", &schema)?;
1346        let a_plus_b = Arc::new(BinaryExpr::new(
1347            Arc::clone(col_a),
1348            Operator::Plus,
1349            Arc::clone(col_b),
1350        )) as Arc<dyn PhysicalExpr>;
1351        let b_plus_d = Arc::new(BinaryExpr::new(
1352            Arc::clone(col_b),
1353            Operator::Plus,
1354            Arc::clone(col_d),
1355        )) as Arc<dyn PhysicalExpr>;
1356        let b_plus_e = Arc::new(BinaryExpr::new(
1357            Arc::clone(col_b),
1358            Operator::Plus,
1359            Arc::clone(col_e),
1360        )) as Arc<dyn PhysicalExpr>;
1361        let c_plus_d = Arc::new(BinaryExpr::new(
1362            Arc::clone(col_c),
1363            Operator::Plus,
1364            Arc::clone(col_d),
1365        )) as Arc<dyn PhysicalExpr>;
1366
1367        let option_asc = SortOptions {
1368            descending: false,
1369            nulls_first: false,
1370        };
1371        let option_desc = SortOptions {
1372            descending: true,
1373            nulls_first: true,
1374        };
1375
1376        let test_cases = vec![
1377            // ---------- TEST CASE 1 ------------
1378            (
1379                // orderings
1380                vec![
1381                    // [b ASC]
1382                    vec![(col_b, option_asc)],
1383                ],
1384                // projection exprs
1385                vec![(col_b, "b_new".to_string()), (col_a, "a_new".to_string())],
1386                // expected
1387                vec![
1388                    // [b_new ASC]
1389                    vec![("b_new", option_asc)],
1390                ],
1391            ),
1392            // ---------- TEST CASE 2 ------------
1393            (
1394                // orderings
1395                vec![
1396                    // empty ordering
1397                ],
1398                // projection exprs
1399                vec![(col_c, "c_new".to_string()), (col_b, "b_new".to_string())],
1400                // expected
1401                vec![
1402                    // no ordering at the output
1403                ],
1404            ),
1405            // ---------- TEST CASE 3 ------------
1406            (
1407                // orderings
1408                vec![
1409                    // [ts ASC]
1410                    vec![(col_ts, option_asc)],
1411                ],
1412                // projection exprs
1413                vec![
1414                    (col_b, "b_new".to_string()),
1415                    (col_a, "a_new".to_string()),
1416                    (col_ts, "ts_new".to_string()),
1417                ],
1418                // expected
1419                vec![
1420                    // [ts_new ASC]
1421                    vec![("ts_new", option_asc)],
1422                ],
1423            ),
1424            // ---------- TEST CASE 4 ------------
1425            (
1426                // orderings
1427                vec![
1428                    // [a ASC, ts ASC]
1429                    vec![(col_a, option_asc), (col_ts, option_asc)],
1430                    // [b ASC, ts ASC]
1431                    vec![(col_b, option_asc), (col_ts, option_asc)],
1432                ],
1433                // projection exprs
1434                vec![
1435                    (col_b, "b_new".to_string()),
1436                    (col_a, "a_new".to_string()),
1437                    (col_ts, "ts_new".to_string()),
1438                ],
1439                // expected
1440                vec![
1441                    // [a_new ASC, ts_new ASC]
1442                    vec![("a_new", option_asc), ("ts_new", option_asc)],
1443                    // [b_new ASC, ts_new ASC]
1444                    vec![("b_new", option_asc), ("ts_new", option_asc)],
1445                ],
1446            ),
1447            // ---------- TEST CASE 5 ------------
1448            (
1449                // orderings
1450                vec![
1451                    // [a + b ASC]
1452                    vec![(&a_plus_b, option_asc)],
1453                ],
1454                // projection exprs
1455                vec![
1456                    (col_b, "b_new".to_string()),
1457                    (col_a, "a_new".to_string()),
1458                    (&a_plus_b, "a+b".to_string()),
1459                ],
1460                // expected
1461                vec![
1462                    // [a + b ASC]
1463                    vec![("a+b", option_asc)],
1464                ],
1465            ),
1466            // ---------- TEST CASE 6 ------------
1467            (
1468                // orderings
1469                vec![
1470                    // [a + b ASC, c ASC]
1471                    vec![(&a_plus_b, option_asc), (col_c, option_asc)],
1472                ],
1473                // projection exprs
1474                vec![
1475                    (col_b, "b_new".to_string()),
1476                    (col_a, "a_new".to_string()),
1477                    (col_c, "c_new".to_string()),
1478                    (&a_plus_b, "a+b".to_string()),
1479                ],
1480                // expected
1481                vec![
1482                    // [a + b ASC, c_new ASC]
1483                    vec![("a+b", option_asc), ("c_new", option_asc)],
1484                ],
1485            ),
1486            // ------- TEST CASE 7 ----------
1487            (
1488                vec![
1489                    // [a ASC, b ASC, c ASC]
1490                    vec![(col_a, option_asc), (col_b, option_asc)],
1491                    // [a ASC, d ASC]
1492                    vec![(col_a, option_asc), (col_d, option_asc)],
1493                ],
1494                // b as b_new, a as a_new, d as d_new b+d
1495                vec![
1496                    (col_b, "b_new".to_string()),
1497                    (col_a, "a_new".to_string()),
1498                    (col_d, "d_new".to_string()),
1499                    (&b_plus_d, "b+d".to_string()),
1500                ],
1501                // expected
1502                vec![
1503                    // [a_new ASC, b_new ASC]
1504                    vec![("a_new", option_asc), ("b_new", option_asc)],
1505                    // [a_new ASC, d_new ASC]
1506                    vec![("a_new", option_asc), ("d_new", option_asc)],
1507                    // [a_new ASC, b+d ASC]
1508                    vec![("a_new", option_asc), ("b+d", option_asc)],
1509                ],
1510            ),
1511            // ------- TEST CASE 8 ----------
1512            (
1513                // orderings
1514                vec![
1515                    // [b+d ASC]
1516                    vec![(&b_plus_d, option_asc)],
1517                ],
1518                // proj exprs
1519                vec![
1520                    (col_b, "b_new".to_string()),
1521                    (col_a, "a_new".to_string()),
1522                    (col_d, "d_new".to_string()),
1523                    (&b_plus_d, "b+d".to_string()),
1524                ],
1525                // expected
1526                vec![
1527                    // [b+d ASC]
1528                    vec![("b+d", option_asc)],
1529                ],
1530            ),
1531            // ------- TEST CASE 9 ----------
1532            (
1533                // orderings
1534                vec![
1535                    // [a ASC, d ASC, b ASC]
1536                    vec![
1537                        (col_a, option_asc),
1538                        (col_d, option_asc),
1539                        (col_b, option_asc),
1540                    ],
1541                    // [c ASC]
1542                    vec![(col_c, option_asc)],
1543                ],
1544                // proj exprs
1545                vec![
1546                    (col_b, "b_new".to_string()),
1547                    (col_a, "a_new".to_string()),
1548                    (col_d, "d_new".to_string()),
1549                    (col_c, "c_new".to_string()),
1550                ],
1551                // expected
1552                vec![
1553                    // [a_new ASC, d_new ASC, b_new ASC]
1554                    vec![
1555                        ("a_new", option_asc),
1556                        ("d_new", option_asc),
1557                        ("b_new", option_asc),
1558                    ],
1559                    // [c_new ASC],
1560                    vec![("c_new", option_asc)],
1561                ],
1562            ),
1563            // ------- TEST CASE 10 ----------
1564            (
1565                vec![
1566                    // [a ASC, b ASC, c ASC]
1567                    vec![
1568                        (col_a, option_asc),
1569                        (col_b, option_asc),
1570                        (col_c, option_asc),
1571                    ],
1572                    // [a ASC, d ASC]
1573                    vec![(col_a, option_asc), (col_d, option_asc)],
1574                ],
1575                // proj exprs
1576                vec![
1577                    (col_b, "b_new".to_string()),
1578                    (col_a, "a_new".to_string()),
1579                    (col_c, "c_new".to_string()),
1580                    (&c_plus_d, "c+d".to_string()),
1581                ],
1582                // expected
1583                vec![
1584                    // [a_new ASC, b_new ASC, c_new ASC]
1585                    vec![
1586                        ("a_new", option_asc),
1587                        ("b_new", option_asc),
1588                        ("c_new", option_asc),
1589                    ],
1590                    // [a_new ASC, b_new ASC, c+d ASC]
1591                    vec![
1592                        ("a_new", option_asc),
1593                        ("b_new", option_asc),
1594                        ("c+d", option_asc),
1595                    ],
1596                ],
1597            ),
1598            // ------- TEST CASE 11 ----------
1599            (
1600                // orderings
1601                vec![
1602                    // [a ASC, b ASC]
1603                    vec![(col_a, option_asc), (col_b, option_asc)],
1604                    // [a ASC, d ASC]
1605                    vec![(col_a, option_asc), (col_d, option_asc)],
1606                ],
1607                // proj exprs
1608                vec![
1609                    (col_b, "b_new".to_string()),
1610                    (col_a, "a_new".to_string()),
1611                    (&b_plus_d, "b+d".to_string()),
1612                ],
1613                // expected
1614                vec![
1615                    // [a_new ASC, b_new ASC]
1616                    vec![("a_new", option_asc), ("b_new", option_asc)],
1617                    // [a_new ASC, b + d ASC]
1618                    vec![("a_new", option_asc), ("b+d", option_asc)],
1619                ],
1620            ),
1621            // ------- TEST CASE 12 ----------
1622            (
1623                // orderings
1624                vec![
1625                    // [a ASC, b ASC, c ASC]
1626                    vec![
1627                        (col_a, option_asc),
1628                        (col_b, option_asc),
1629                        (col_c, option_asc),
1630                    ],
1631                ],
1632                // proj exprs
1633                vec![(col_c, "c_new".to_string()), (col_a, "a_new".to_string())],
1634                // expected
1635                vec![
1636                    // [a_new ASC]
1637                    vec![("a_new", option_asc)],
1638                ],
1639            ),
1640            // ------- TEST CASE 13 ----------
1641            (
1642                // orderings
1643                vec![
1644                    // [a ASC, b ASC, c ASC]
1645                    vec![
1646                        (col_a, option_asc),
1647                        (col_b, option_asc),
1648                        (col_c, option_asc),
1649                    ],
1650                    // [a ASC, a + b ASC, c ASC]
1651                    vec![
1652                        (col_a, option_asc),
1653                        (&a_plus_b, option_asc),
1654                        (col_c, option_asc),
1655                    ],
1656                ],
1657                // proj exprs
1658                vec![
1659                    (col_c, "c_new".to_string()),
1660                    (col_b, "b_new".to_string()),
1661                    (col_a, "a_new".to_string()),
1662                    (&a_plus_b, "a+b".to_string()),
1663                ],
1664                // expected
1665                vec![
1666                    // [a_new ASC, b_new ASC, c_new ASC]
1667                    vec![
1668                        ("a_new", option_asc),
1669                        ("b_new", option_asc),
1670                        ("c_new", option_asc),
1671                    ],
1672                    // [a_new ASC, a+b ASC, c_new ASC]
1673                    vec![
1674                        ("a_new", option_asc),
1675                        ("a+b", option_asc),
1676                        ("c_new", option_asc),
1677                    ],
1678                ],
1679            ),
1680            // ------- TEST CASE 14 ----------
1681            (
1682                // orderings
1683                vec![
1684                    // [a ASC, b ASC]
1685                    vec![(col_a, option_asc), (col_b, option_asc)],
1686                    // [c ASC, b ASC]
1687                    vec![(col_c, option_asc), (col_b, option_asc)],
1688                    // [d ASC, e ASC]
1689                    vec![(col_d, option_asc), (col_e, option_asc)],
1690                ],
1691                // proj exprs
1692                vec![
1693                    (col_c, "c_new".to_string()),
1694                    (col_d, "d_new".to_string()),
1695                    (col_a, "a_new".to_string()),
1696                    (&b_plus_e, "b+e".to_string()),
1697                ],
1698                // expected
1699                vec![
1700                    // [a_new ASC, d_new ASC, b+e ASC]
1701                    vec![
1702                        ("a_new", option_asc),
1703                        ("d_new", option_asc),
1704                        ("b+e", option_asc),
1705                    ],
1706                    // [d_new ASC, a_new ASC, b+e ASC]
1707                    vec![
1708                        ("d_new", option_asc),
1709                        ("a_new", option_asc),
1710                        ("b+e", option_asc),
1711                    ],
1712                    // [c_new ASC, d_new ASC, b+e ASC]
1713                    vec![
1714                        ("c_new", option_asc),
1715                        ("d_new", option_asc),
1716                        ("b+e", option_asc),
1717                    ],
1718                    // [d_new ASC, c_new ASC, b+e ASC]
1719                    vec![
1720                        ("d_new", option_asc),
1721                        ("c_new", option_asc),
1722                        ("b+e", option_asc),
1723                    ],
1724                ],
1725            ),
1726            // ------- TEST CASE 15 ----------
1727            (
1728                // orderings
1729                vec![
1730                    // [a ASC, c ASC, b ASC]
1731                    vec![
1732                        (col_a, option_asc),
1733                        (col_c, option_asc),
1734                        (col_b, option_asc),
1735                    ],
1736                ],
1737                // proj exprs
1738                vec![
1739                    (col_c, "c_new".to_string()),
1740                    (col_a, "a_new".to_string()),
1741                    (&a_plus_b, "a+b".to_string()),
1742                ],
1743                // expected
1744                vec![
                    // [a_new ASC, c_new ASC, a+b ASC]
1746                    vec![
1747                        ("a_new", option_asc),
1748                        ("c_new", option_asc),
1749                        ("a+b", option_asc),
1750                    ],
1751                ],
1752            ),
1753            // ------- TEST CASE 16 ----------
1754            (
1755                // orderings
1756                vec![
1757                    // [a ASC, b ASC]
1758                    vec![(col_a, option_asc), (col_b, option_asc)],
1759                    // [c ASC, b DESC]
1760                    vec![(col_c, option_asc), (col_b, option_desc)],
1761                    // [e ASC]
1762                    vec![(col_e, option_asc)],
1763                ],
1764                // proj exprs
1765                vec![
1766                    (col_c, "c_new".to_string()),
1767                    (col_a, "a_new".to_string()),
1768                    (col_b, "b_new".to_string()),
1769                    (&b_plus_e, "b+e".to_string()),
1770                ],
1771                // expected
1772                vec![
1773                    // [a_new ASC, b_new ASC]
1774                    vec![("a_new", option_asc), ("b_new", option_asc)],
                    // [a_new ASC, b+e ASC]
1776                    vec![("a_new", option_asc), ("b+e", option_asc)],
1777                    // [c_new ASC, b_new DESC]
1778                    vec![("c_new", option_asc), ("b_new", option_desc)],
1779                ],
1780            ),
1781        ];
1782
1783        for (idx, (orderings, proj_exprs, expected)) in test_cases.into_iter().enumerate()
1784        {
1785            let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema));
1786
1787            let orderings = convert_to_orderings(&orderings);
1788            eq_properties.add_orderings(orderings);
1789
1790            let proj_exprs = proj_exprs
1791                .into_iter()
1792                .map(|(expr, name)| (Arc::clone(expr), name));
1793            let projection_mapping = ProjectionMapping::try_new(proj_exprs, &schema)?;
1794            let output_schema = output_schema(&projection_mapping, &schema)?;
1795
1796            let expected = expected
1797                .into_iter()
1798                .map(|ordering| {
1799                    ordering
1800                        .into_iter()
1801                        .map(|(name, options)| {
1802                            (col(name, &output_schema).unwrap(), options)
1803                        })
1804                        .collect::<Vec<_>>()
1805                })
1806                .collect::<Vec<_>>();
1807            let expected = convert_to_orderings(&expected);
1808
1809            let projected_eq = eq_properties.project(&projection_mapping, output_schema);
1810            let orderings = projected_eq.oeq_class();
1811
1812            let err_msg = format!(
1813                "test_idx: {idx:?}, actual: {orderings:?}, expected: {expected:?}, projection_mapping: {projection_mapping:?}"
1814            );
1815
1816            assert_eq!(orderings.len(), expected.len(), "{err_msg}");
1817            for expected_ordering in &expected {
1818                assert!(orderings.contains(expected_ordering), "{}", err_msg)
1819            }
1820        }
1821
1822        Ok(())
1823    }
1824
1825    #[test]
1826    fn project_orderings2() -> Result<()> {
1827        let schema = Arc::new(Schema::new(vec![
1828            Field::new("a", DataType::Int32, true),
1829            Field::new("b", DataType::Int32, true),
1830            Field::new("c", DataType::Int32, true),
1831            Field::new("d", DataType::Int32, true),
1832            Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true),
1833        ]));
1834        let col_a = &col("a", &schema)?;
1835        let col_b = &col("b", &schema)?;
1836        let col_c = &col("c", &schema)?;
1837        let col_ts = &col("ts", &schema)?;
1838        let a_plus_b = Arc::new(BinaryExpr::new(
1839            Arc::clone(col_a),
1840            Operator::Plus,
1841            Arc::clone(col_b),
1842        )) as Arc<dyn PhysicalExpr>;
1843
1844        let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new()));
1845
1846        let round_c = Arc::new(ScalarFunctionExpr::try_new(
1847            test_fun,
1848            vec![Arc::clone(col_c)],
1849            &schema,
1850            Arc::new(ConfigOptions::default()),
1851        )?) as PhysicalExprRef;
1852
1853        let option_asc = SortOptions {
1854            descending: false,
1855            nulls_first: false,
1856        };
1857
1858        let proj_exprs = vec![
1859            (col_b, "b_new".to_string()),
1860            (col_a, "a_new".to_string()),
1861            (col_c, "c_new".to_string()),
1862            (&round_c, "round_c_res".to_string()),
1863        ];
1864        let proj_exprs = proj_exprs
1865            .into_iter()
1866            .map(|(expr, name)| (Arc::clone(expr), name));
1867        let projection_mapping = ProjectionMapping::try_new(proj_exprs, &schema)?;
1868        let output_schema = output_schema(&projection_mapping, &schema)?;
1869
1870        let col_a_new = &col("a_new", &output_schema)?;
1871        let col_b_new = &col("b_new", &output_schema)?;
1872        let col_c_new = &col("c_new", &output_schema)?;
1873        let col_round_c_res = &col("round_c_res", &output_schema)?;
1874        let a_new_plus_b_new = Arc::new(BinaryExpr::new(
1875            Arc::clone(col_a_new),
1876            Operator::Plus,
1877            Arc::clone(col_b_new),
1878        )) as Arc<dyn PhysicalExpr>;
1879
1880        let test_cases = [
1881            // ---------- TEST CASE 1 ------------
1882            (
1883                // orderings
1884                vec![
1885                    // [a ASC]
1886                    vec![(col_a, option_asc)],
1887                ],
1888                // expected
1889                vec![
1890                    // [b_new ASC]
1891                    vec![(col_a_new, option_asc)],
1892                ],
1893            ),
1894            // ---------- TEST CASE 2 ------------
1895            (
1896                // orderings
1897                vec![
1898                    // [a+b ASC]
1899                    vec![(&a_plus_b, option_asc)],
1900                ],
1901                // expected
1902                vec![
1903                    // [b_new ASC]
1904                    vec![(&a_new_plus_b_new, option_asc)],
1905                ],
1906            ),
1907            // ---------- TEST CASE 3 ------------
1908            (
1909                // orderings
1910                vec![
1911                    // [a ASC, ts ASC]
1912                    vec![(col_a, option_asc), (col_ts, option_asc)],
1913                ],
1914                // expected
1915                vec![
1916                    // [a_new ASC, date_bin_res ASC]
1917                    vec![(col_a_new, option_asc)],
1918                ],
1919            ),
1920            // ---------- TEST CASE 4 ------------
1921            (
1922                // orderings
1923                vec![
1924                    // [a ASC, ts ASC, b ASC]
1925                    vec![
1926                        (col_a, option_asc),
1927                        (col_ts, option_asc),
1928                        (col_b, option_asc),
1929                    ],
1930                ],
1931                // expected
1932                vec![
1933                    // [a_new ASC, date_bin_res ASC]
1934                    vec![(col_a_new, option_asc)],
1935                ],
1936            ),
1937            // ---------- TEST CASE 5 ------------
1938            (
1939                // orderings
1940                vec![
1941                    // [a ASC, c ASC]
1942                    vec![(col_a, option_asc), (col_c, option_asc)],
1943                ],
1944                // expected
1945                vec![
1946                    // [a_new ASC, round_c_res ASC, c_new ASC]
1947                    vec![(col_a_new, option_asc), (col_round_c_res, option_asc)],
1948                    // [a_new ASC, c_new ASC]
1949                    vec![(col_a_new, option_asc), (col_c_new, option_asc)],
1950                ],
1951            ),
1952            // ---------- TEST CASE 6 ------------
1953            (
1954                // orderings
1955                vec![
1956                    // [c ASC, b ASC]
1957                    vec![(col_c, option_asc), (col_b, option_asc)],
1958                ],
1959                // expected
1960                vec![
1961                    // [round_c_res ASC]
1962                    vec![(col_round_c_res, option_asc)],
1963                    // [c_new ASC, b_new ASC]
1964                    vec![(col_c_new, option_asc), (col_b_new, option_asc)],
1965                ],
1966            ),
1967            // ---------- TEST CASE 7 ------------
1968            (
1969                // orderings
1970                vec![
1971                    // [a+b ASC, c ASC]
1972                    vec![(&a_plus_b, option_asc), (col_c, option_asc)],
1973                ],
1974                // expected
1975                vec![
1976                    // [a+b ASC, round(c) ASC, c_new ASC]
1977                    vec![
1978                        (&a_new_plus_b_new, option_asc),
1979                        (col_round_c_res, option_asc),
1980                    ],
1981                    // [a+b ASC, c_new ASC]
1982                    vec![(&a_new_plus_b_new, option_asc), (col_c_new, option_asc)],
1983                ],
1984            ),
1985        ];
1986
1987        for (idx, (orderings, expected)) in test_cases.iter().enumerate() {
1988            let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema));
1989
1990            let orderings = convert_to_orderings(orderings);
1991            eq_properties.add_orderings(orderings);
1992
1993            let expected = convert_to_orderings(expected);
1994
1995            let projected_eq =
1996                eq_properties.project(&projection_mapping, Arc::clone(&output_schema));
1997            let orderings = projected_eq.oeq_class();
1998
1999            let err_msg = format!(
2000                "test idx: {idx:?}, actual: {orderings:?}, expected: {expected:?}, projection_mapping: {projection_mapping:?}"
2001            );
2002
2003            assert_eq!(orderings.len(), expected.len(), "{err_msg}");
2004            for expected_ordering in &expected {
2005                assert!(orderings.contains(expected_ordering), "{}", err_msg)
2006            }
2007        }
2008        Ok(())
2009    }
2010
2011    #[test]
2012    fn project_orderings3() -> Result<()> {
2013        let schema = Arc::new(Schema::new(vec![
2014            Field::new("a", DataType::Int32, true),
2015            Field::new("b", DataType::Int32, true),
2016            Field::new("c", DataType::Int32, true),
2017            Field::new("d", DataType::Int32, true),
2018            Field::new("e", DataType::Int32, true),
2019            Field::new("f", DataType::Int32, true),
2020        ]));
2021        let col_a = &col("a", &schema)?;
2022        let col_b = &col("b", &schema)?;
2023        let col_c = &col("c", &schema)?;
2024        let col_d = &col("d", &schema)?;
2025        let col_e = &col("e", &schema)?;
2026        let col_f = &col("f", &schema)?;
2027        let a_plus_b = Arc::new(BinaryExpr::new(
2028            Arc::clone(col_a),
2029            Operator::Plus,
2030            Arc::clone(col_b),
2031        )) as Arc<dyn PhysicalExpr>;
2032
2033        let option_asc = SortOptions {
2034            descending: false,
2035            nulls_first: false,
2036        };
2037
2038        let proj_exprs = vec![
2039            (col_c, "c_new".to_string()),
2040            (col_d, "d_new".to_string()),
2041            (&a_plus_b, "a+b".to_string()),
2042        ];
2043        let proj_exprs = proj_exprs
2044            .into_iter()
2045            .map(|(expr, name)| (Arc::clone(expr), name));
2046        let projection_mapping = ProjectionMapping::try_new(proj_exprs, &schema)?;
2047        let output_schema = output_schema(&projection_mapping, &schema)?;
2048
2049        let col_a_plus_b_new = &col("a+b", &output_schema)?;
2050        let col_c_new = &col("c_new", &output_schema)?;
2051        let col_d_new = &col("d_new", &output_schema)?;
2052
2053        let test_cases = vec![
2054            // ---------- TEST CASE 1 ------------
2055            (
2056                // orderings
2057                vec![
2058                    // [d ASC, b ASC]
2059                    vec![(col_d, option_asc), (col_b, option_asc)],
2060                    // [c ASC, a ASC]
2061                    vec![(col_c, option_asc), (col_a, option_asc)],
2062                ],
2063                // equal conditions
2064                vec![],
2065                // expected
2066                vec![
2067                    // [d_new ASC, c_new ASC, a+b ASC]
2068                    vec![
2069                        (col_d_new, option_asc),
2070                        (col_c_new, option_asc),
2071                        (col_a_plus_b_new, option_asc),
2072                    ],
2073                    // [c_new ASC, d_new ASC, a+b ASC]
2074                    vec![
2075                        (col_c_new, option_asc),
2076                        (col_d_new, option_asc),
2077                        (col_a_plus_b_new, option_asc),
2078                    ],
2079                ],
2080            ),
2081            // ---------- TEST CASE 2 ------------
2082            (
2083                // orderings
2084                vec![
2085                    // [d ASC, b ASC]
2086                    vec![(col_d, option_asc), (col_b, option_asc)],
2087                    // [c ASC, e ASC], Please note that a=e
2088                    vec![(col_c, option_asc), (col_e, option_asc)],
2089                ],
2090                // equal conditions
2091                vec![(col_e, col_a)],
2092                // expected
2093                vec![
2094                    // [d_new ASC, c_new ASC, a+b ASC]
2095                    vec![
2096                        (col_d_new, option_asc),
2097                        (col_c_new, option_asc),
2098                        (col_a_plus_b_new, option_asc),
2099                    ],
2100                    // [c_new ASC, d_new ASC, a+b ASC]
2101                    vec![
2102                        (col_c_new, option_asc),
2103                        (col_d_new, option_asc),
2104                        (col_a_plus_b_new, option_asc),
2105                    ],
2106                ],
2107            ),
2108            // ---------- TEST CASE 3 ------------
2109            (
2110                // orderings
2111                vec![
2112                    // [d ASC, b ASC]
2113                    vec![(col_d, option_asc), (col_b, option_asc)],
2114                    // [c ASC, e ASC], Please note that a=f
2115                    vec![(col_c, option_asc), (col_e, option_asc)],
2116                ],
2117                // equal conditions
2118                vec![(col_a, col_f)],
2119                // expected
2120                vec![
2121                    // [d_new ASC]
2122                    vec![(col_d_new, option_asc)],
2123                    // [c_new ASC]
2124                    vec![(col_c_new, option_asc)],
2125                ],
2126            ),
2127        ];
2128        for (orderings, equal_columns, expected) in test_cases {
2129            let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema));
2130            for (lhs, rhs) in equal_columns {
2131                eq_properties.add_equal_conditions(Arc::clone(lhs), Arc::clone(rhs))?;
2132            }
2133
2134            let orderings = convert_to_orderings(&orderings);
2135            eq_properties.add_orderings(orderings);
2136
2137            let expected = convert_to_orderings(&expected);
2138
2139            let projected_eq =
2140                eq_properties.project(&projection_mapping, Arc::clone(&output_schema));
2141            let orderings = projected_eq.oeq_class();
2142
2143            let err_msg = format!(
2144                "actual: {orderings:?}, expected: {expected:?}, projection_mapping: {projection_mapping:?}"
2145            );
2146
2147            assert_eq!(orderings.len(), expected.len(), "{err_msg}");
2148            for expected_ordering in &expected {
2149                assert!(orderings.contains(expected_ordering), "{}", err_msg)
2150            }
2151        }
2152
2153        Ok(())
2154    }
2155
2156    fn get_stats() -> Statistics {
2157        Statistics {
2158            num_rows: Precision::Exact(5),
2159            total_byte_size: Precision::Exact(23),
2160            column_statistics: vec![
2161                ColumnStatistics {
2162                    distinct_count: Precision::Exact(5),
2163                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
2164                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
2165                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
2166                    null_count: Precision::Exact(0),
2167                    byte_size: Precision::Absent,
2168                },
2169                ColumnStatistics {
2170                    distinct_count: Precision::Exact(1),
2171                    max_value: Precision::Exact(ScalarValue::from("x")),
2172                    min_value: Precision::Exact(ScalarValue::from("a")),
2173                    sum_value: Precision::Absent,
2174                    null_count: Precision::Exact(3),
2175                    byte_size: Precision::Absent,
2176                },
2177                ColumnStatistics {
2178                    distinct_count: Precision::Absent,
2179                    max_value: Precision::Exact(ScalarValue::Float32(Some(1.1))),
2180                    min_value: Precision::Exact(ScalarValue::Float32(Some(0.1))),
2181                    sum_value: Precision::Exact(ScalarValue::Float32(Some(5.5))),
2182                    null_count: Precision::Absent,
2183                    byte_size: Precision::Absent,
2184                },
2185            ],
2186        }
2187    }
2188
2189    fn get_schema() -> Schema {
2190        let field_0 = Field::new("col0", DataType::Int64, false);
2191        let field_1 = Field::new("col1", DataType::Utf8, false);
2192        let field_2 = Field::new("col2", DataType::Float32, false);
2193        Schema::new(vec![field_0, field_1, field_2])
2194    }
2195
2196    #[test]
2197    fn test_stats_projection_columns_only() {
2198        let source = get_stats();
2199        let schema = get_schema();
2200
2201        let projection = ProjectionExprs::new(vec![
2202            ProjectionExpr {
2203                expr: Arc::new(Column::new("col1", 1)),
2204                alias: "col1".to_string(),
2205            },
2206            ProjectionExpr {
2207                expr: Arc::new(Column::new("col0", 0)),
2208                alias: "col0".to_string(),
2209            },
2210        ]);
2211
2212        let result = projection
2213            .project_statistics(source, &projection.project_schema(&schema).unwrap())
2214            .unwrap();
2215
2216        let expected = Statistics {
2217            num_rows: Precision::Exact(5),
2218            // Because there is a variable length Utf8 column we cannot calculate exact byte size after projection
2219            // Thus we set it to Inexact (originally it was Exact(23))
2220            total_byte_size: Precision::Inexact(23),
2221            column_statistics: vec![
2222                ColumnStatistics {
2223                    distinct_count: Precision::Exact(1),
2224                    max_value: Precision::Exact(ScalarValue::from("x")),
2225                    min_value: Precision::Exact(ScalarValue::from("a")),
2226                    sum_value: Precision::Absent,
2227                    null_count: Precision::Exact(3),
2228                    byte_size: Precision::Absent,
2229                },
2230                ColumnStatistics {
2231                    distinct_count: Precision::Exact(5),
2232                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
2233                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
2234                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
2235                    null_count: Precision::Exact(0),
2236                    byte_size: Precision::Absent,
2237                },
2238            ],
2239        };
2240
2241        assert_eq!(result, expected);
2242    }
2243
2244    #[test]
2245    fn test_stats_projection_column_with_primitive_width_only() {
2246        let source = get_stats();
2247        let schema = get_schema();
2248
2249        let projection = ProjectionExprs::new(vec![
2250            ProjectionExpr {
2251                expr: Arc::new(Column::new("col2", 2)),
2252                alias: "col2".to_string(),
2253            },
2254            ProjectionExpr {
2255                expr: Arc::new(Column::new("col0", 0)),
2256                alias: "col0".to_string(),
2257            },
2258        ]);
2259
2260        let result = projection
2261            .project_statistics(source, &projection.project_schema(&schema).unwrap())
2262            .unwrap();
2263
2264        let expected = Statistics {
2265            num_rows: Precision::Exact(5),
2266            total_byte_size: Precision::Exact(60),
2267            column_statistics: vec![
2268                ColumnStatistics {
2269                    distinct_count: Precision::Absent,
2270                    max_value: Precision::Exact(ScalarValue::Float32(Some(1.1))),
2271                    min_value: Precision::Exact(ScalarValue::Float32(Some(0.1))),
2272                    sum_value: Precision::Exact(ScalarValue::Float32(Some(5.5))),
2273                    null_count: Precision::Absent,
2274                    byte_size: Precision::Absent,
2275                },
2276                ColumnStatistics {
2277                    distinct_count: Precision::Exact(5),
2278                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
2279                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
2280                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
2281                    null_count: Precision::Exact(0),
2282                    byte_size: Precision::Absent,
2283                },
2284            ],
2285        };
2286
2287        assert_eq!(result, expected);
2288    }
2289
    // Tests for the `ProjectionExprs` struct
2291
2292    #[test]
2293    fn test_projection_new() -> Result<()> {
2294        let exprs = vec![
2295            ProjectionExpr {
2296                expr: Arc::new(Column::new("a", 0)),
2297                alias: "a".to_string(),
2298            },
2299            ProjectionExpr {
2300                expr: Arc::new(Column::new("b", 1)),
2301                alias: "b".to_string(),
2302            },
2303        ];
2304        let projection = ProjectionExprs::new(exprs.clone());
2305        assert_eq!(projection.as_ref().len(), 2);
2306        Ok(())
2307    }
2308
2309    #[test]
2310    fn test_projection_from_vec() -> Result<()> {
2311        let exprs = vec![ProjectionExpr {
2312            expr: Arc::new(Column::new("x", 0)),
2313            alias: "x".to_string(),
2314        }];
2315        let projection: ProjectionExprs = exprs.clone().into();
2316        assert_eq!(projection.as_ref().len(), 1);
2317        Ok(())
2318    }
2319
2320    #[test]
2321    fn test_projection_as_ref() -> Result<()> {
2322        let exprs = vec![
2323            ProjectionExpr {
2324                expr: Arc::new(Column::new("col1", 0)),
2325                alias: "col1".to_string(),
2326            },
2327            ProjectionExpr {
2328                expr: Arc::new(Column::new("col2", 1)),
2329                alias: "col2".to_string(),
2330            },
2331        ];
2332        let projection = ProjectionExprs::new(exprs);
2333        let as_ref: &[ProjectionExpr] = projection.as_ref();
2334        assert_eq!(as_ref.len(), 2);
2335        Ok(())
2336    }
2337
2338    #[test]
2339    fn test_column_indices_multiple_columns() -> Result<()> {
2340        // Test with reversed column order to ensure proper reordering
2341        let projection = ProjectionExprs::new(vec![
2342            ProjectionExpr {
2343                expr: Arc::new(Column::new("c", 5)),
2344                alias: "c".to_string(),
2345            },
2346            ProjectionExpr {
2347                expr: Arc::new(Column::new("b", 2)),
2348                alias: "b".to_string(),
2349            },
2350            ProjectionExpr {
2351                expr: Arc::new(Column::new("a", 0)),
2352                alias: "a".to_string(),
2353            },
2354        ]);
2355        // Should return sorted indices regardless of projection order
2356        assert_eq!(projection.column_indices(), vec![0, 2, 5]);
2357        Ok(())
2358    }
2359
2360    #[test]
2361    fn test_column_indices_duplicates() -> Result<()> {
2362        // Test that duplicate column indices appear only once
2363        let projection = ProjectionExprs::new(vec![
2364            ProjectionExpr {
2365                expr: Arc::new(Column::new("a", 1)),
2366                alias: "a".to_string(),
2367            },
2368            ProjectionExpr {
2369                expr: Arc::new(Column::new("b", 3)),
2370                alias: "b".to_string(),
2371            },
2372            ProjectionExpr {
2373                expr: Arc::new(Column::new("a2", 1)), // duplicate index
2374                alias: "a2".to_string(),
2375            },
2376        ]);
2377        assert_eq!(projection.column_indices(), vec![1, 3]);
2378        Ok(())
2379    }
2380
2381    #[test]
2382    fn test_column_indices_unsorted() -> Result<()> {
2383        // Test that column indices are sorted in the output
2384        let projection = ProjectionExprs::new(vec![
2385            ProjectionExpr {
2386                expr: Arc::new(Column::new("c", 5)),
2387                alias: "c".to_string(),
2388            },
2389            ProjectionExpr {
2390                expr: Arc::new(Column::new("a", 1)),
2391                alias: "a".to_string(),
2392            },
2393            ProjectionExpr {
2394                expr: Arc::new(Column::new("b", 3)),
2395                alias: "b".to_string(),
2396            },
2397        ]);
2398        assert_eq!(projection.column_indices(), vec![1, 3, 5]);
2399        Ok(())
2400    }
2401
2402    #[test]
2403    fn test_column_indices_complex_expr() -> Result<()> {
2404        // Test with complex expressions containing multiple columns
2405        let expr = Arc::new(BinaryExpr::new(
2406            Arc::new(Column::new("a", 1)),
2407            Operator::Plus,
2408            Arc::new(Column::new("b", 4)),
2409        ));
2410        let projection = ProjectionExprs::new(vec![
2411            ProjectionExpr {
2412                expr,
2413                alias: "sum".to_string(),
2414            },
2415            ProjectionExpr {
2416                expr: Arc::new(Column::new("c", 2)),
2417                alias: "c".to_string(),
2418            },
2419        ]);
2420        // Should return [1, 2, 4] - all columns used, sorted and deduplicated
2421        assert_eq!(projection.column_indices(), vec![1, 2, 4]);
2422        Ok(())
2423    }
2424
2425    #[test]
2426    fn test_column_indices_empty() -> Result<()> {
2427        let projection = ProjectionExprs::new(vec![]);
2428        assert_eq!(projection.column_indices(), Vec::<usize>::new());
2429        Ok(())
2430    }
2431
2432    #[test]
2433    fn test_merge_simple_columns() -> Result<()> {
2434        // First projection: SELECT c@2 AS x, b@1 AS y, a@0 AS z
2435        let base_projection = ProjectionExprs::new(vec![
2436            ProjectionExpr {
2437                expr: Arc::new(Column::new("c", 2)),
2438                alias: "x".to_string(),
2439            },
2440            ProjectionExpr {
2441                expr: Arc::new(Column::new("b", 1)),
2442                alias: "y".to_string(),
2443            },
2444            ProjectionExpr {
2445                expr: Arc::new(Column::new("a", 0)),
2446                alias: "z".to_string(),
2447            },
2448        ]);
2449
2450        // Second projection: SELECT y@1 AS col2, x@0 AS col1
2451        let top_projection = ProjectionExprs::new(vec![
2452            ProjectionExpr {
2453                expr: Arc::new(Column::new("y", 1)),
2454                alias: "col2".to_string(),
2455            },
2456            ProjectionExpr {
2457                expr: Arc::new(Column::new("x", 0)),
2458                alias: "col1".to_string(),
2459            },
2460        ]);
2461
2462        // Merge should produce: SELECT b@1 AS col2, c@2 AS col1
2463        let merged = base_projection.try_merge(&top_projection)?;
2464        assert_snapshot!(format!("{merged}"), @"Projection[b@1 AS col2, c@2 AS col1]");
2465
2466        Ok(())
2467    }
2468
2469    #[test]
2470    fn test_merge_with_expressions() -> Result<()> {
2471        // First projection: SELECT c@2 AS x, b@1 AS y, a@0 AS z
2472        let base_projection = ProjectionExprs::new(vec![
2473            ProjectionExpr {
2474                expr: Arc::new(Column::new("c", 2)),
2475                alias: "x".to_string(),
2476            },
2477            ProjectionExpr {
2478                expr: Arc::new(Column::new("b", 1)),
2479                alias: "y".to_string(),
2480            },
2481            ProjectionExpr {
2482                expr: Arc::new(Column::new("a", 0)),
2483                alias: "z".to_string(),
2484            },
2485        ]);
2486
2487        // Second projection: SELECT y@1 + z@2 AS c2, x@0 + 1 AS c1
2488        let top_projection = ProjectionExprs::new(vec![
2489            ProjectionExpr {
2490                expr: Arc::new(BinaryExpr::new(
2491                    Arc::new(Column::new("y", 1)),
2492                    Operator::Plus,
2493                    Arc::new(Column::new("z", 2)),
2494                )),
2495                alias: "c2".to_string(),
2496            },
2497            ProjectionExpr {
2498                expr: Arc::new(BinaryExpr::new(
2499                    Arc::new(Column::new("x", 0)),
2500                    Operator::Plus,
2501                    Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
2502                )),
2503                alias: "c1".to_string(),
2504            },
2505        ]);
2506
2507        // Merge should produce: SELECT b@1 + a@0 AS c2, c@2 + 1 AS c1
2508        let merged = base_projection.try_merge(&top_projection)?;
2509        assert_snapshot!(format!("{merged}"), @"Projection[b@1 + a@0 AS c2, c@2 + 1 AS c1]");
2510
2511        Ok(())
2512    }
2513
2514    #[test]
2515    fn try_merge_error() {
2516        // Create a base projection
2517        let base = ProjectionExprs::new(vec![
2518            ProjectionExpr {
2519                expr: Arc::new(Column::new("a", 0)),
2520                alias: "x".to_string(),
2521            },
2522            ProjectionExpr {
2523                expr: Arc::new(Column::new("b", 1)),
2524                alias: "y".to_string(),
2525            },
2526        ]);
2527
2528        // Create a top projection that references a non-existent column index
2529        let top = ProjectionExprs::new(vec![ProjectionExpr {
2530            expr: Arc::new(Column::new("z", 5)), // Invalid index
2531            alias: "result".to_string(),
2532        }]);
2533
2534        // Attempt to merge and expect an error
2535        let err_msg = base.try_merge(&top).unwrap_err().to_string();
2536        assert!(
2537            err_msg.contains("Internal error: Column index 5 out of bounds for projected expressions of length 2"),
2538            "Unexpected error message: {err_msg}",
2539        );
2540    }
2541
2542    #[test]
2543    fn test_merge_empty_projection_with_literal() -> Result<()> {
2544        // This test reproduces the issue from roundtrip_empty_projection test
2545        // Query like: SELECT 1 FROM table
2546        // where the file scan needs no columns (empty projection)
2547        // but we project a literal on top
2548
2549        // Empty base projection (no columns needed from file)
2550        let base_projection = ProjectionExprs::new(vec![]);
2551
2552        // Top projection with a literal expression: SELECT 1
2553        let top_projection = ProjectionExprs::new(vec![ProjectionExpr {
2554            expr: Arc::new(Literal::new(ScalarValue::Int64(Some(1)))),
2555            alias: "Int64(1)".to_string(),
2556        }]);
2557
2558        // This should succeed - literals don't reference columns so they should
2559        // pass through unchanged when merged with an empty projection
2560        let merged = base_projection.try_merge(&top_projection)?;
2561        assert_snapshot!(format!("{merged}"), @"Projection[1 AS Int64(1)]");
2562
2563        Ok(())
2564    }
2565
2566    #[test]
2567    fn test_update_expr_with_literal() -> Result<()> {
2568        // Test that update_expr correctly handles expressions without column references
2569        let literal_expr: Arc<dyn PhysicalExpr> =
2570            Arc::new(Literal::new(ScalarValue::Int64(Some(42))));
2571        let empty_projection: Vec<ProjectionExpr> = vec![];
2572
2573        // Updating a literal with an empty projection should return the literal unchanged
2574        let result = update_expr(&literal_expr, &empty_projection, true)?;
2575        assert!(result.is_some(), "Literal expression should be valid");
2576
2577        let result_expr = result.unwrap();
2578        assert_eq!(
2579            result_expr.downcast_ref::<Literal>().unwrap().value(),
2580            &ScalarValue::Int64(Some(42))
2581        );
2582
2583        Ok(())
2584    }
2585
2586    #[test]
2587    fn test_update_expr_with_complex_literal_expr() -> Result<()> {
2588        // Test update_expr with an expression containing both literals and a column
2589        // This tests the case where we have: literal + column
2590        let expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
2591            Arc::new(Literal::new(ScalarValue::Int64(Some(10)))),
2592            Operator::Plus,
2593            Arc::new(Column::new("x", 0)),
2594        ));
2595
2596        // Base projection that maps column 0 to a different expression
2597        let base_projection = vec![ProjectionExpr {
2598            expr: Arc::new(Column::new("a", 5)),
2599            alias: "x".to_string(),
2600        }];
2601
2602        // The expression should be updated: 10 + x@0 becomes 10 + a@5
2603        let result = update_expr(&expr, &base_projection, true)?;
2604        assert!(result.is_some(), "Expression should be valid");
2605
2606        let result_expr = result.unwrap();
2607        let binary = result_expr
2608            .downcast_ref::<BinaryExpr>()
2609            .expect("Should be a BinaryExpr");
2610
2611        // Left side should still be the literal
2612        assert!(binary.left().downcast_ref::<Literal>().is_some());
2613
2614        // Right side should be updated to reference column at index 5
2615        let right_col = binary
2616            .right()
2617            .downcast_ref::<Column>()
2618            .expect("Right should be a Column");
2619        assert_eq!(right_col.index(), 5);
2620
2621        Ok(())
2622    }
2623
2624    #[test]
2625    fn test_project_schema_simple_columns() -> Result<()> {
2626        // Input schema: [col0: Int64, col1: Utf8, col2: Float32]
2627        let input_schema = get_schema();
2628
2629        // Projection: SELECT col2 AS c, col0 AS a
2630        let projection = ProjectionExprs::new(vec![
2631            ProjectionExpr {
2632                expr: Arc::new(Column::new("col2", 2)),
2633                alias: "c".to_string(),
2634            },
2635            ProjectionExpr {
2636                expr: Arc::new(Column::new("col0", 0)),
2637                alias: "a".to_string(),
2638            },
2639        ]);
2640
2641        let output_schema = projection.project_schema(&input_schema)?;
2642
2643        // Should have 2 fields
2644        assert_eq!(output_schema.fields().len(), 2);
2645
2646        // First field should be "c" with Float32 type
2647        assert_eq!(output_schema.field(0).name(), "c");
2648        assert_eq!(output_schema.field(0).data_type(), &DataType::Float32);
2649
2650        // Second field should be "a" with Int64 type
2651        assert_eq!(output_schema.field(1).name(), "a");
2652        assert_eq!(output_schema.field(1).data_type(), &DataType::Int64);
2653
2654        Ok(())
2655    }
2656
2657    #[test]
2658    fn test_project_schema_with_expressions() -> Result<()> {
2659        // Input schema: [col0: Int64, col1: Utf8, col2: Float32]
2660        let input_schema = get_schema();
2661
2662        // Projection: SELECT col0 + 1 AS incremented
2663        let projection = ProjectionExprs::new(vec![ProjectionExpr {
2664            expr: Arc::new(BinaryExpr::new(
2665                Arc::new(Column::new("col0", 0)),
2666                Operator::Plus,
2667                Arc::new(Literal::new(ScalarValue::Int64(Some(1)))),
2668            )),
2669            alias: "incremented".to_string(),
2670        }]);
2671
2672        let output_schema = projection.project_schema(&input_schema)?;
2673
2674        // Should have 1 field
2675        assert_eq!(output_schema.fields().len(), 1);
2676
2677        // Field should be "incremented" with Int64 type
2678        assert_eq!(output_schema.field(0).name(), "incremented");
2679        assert_eq!(output_schema.field(0).data_type(), &DataType::Int64);
2680
2681        Ok(())
2682    }
2683
2684    #[test]
2685    fn test_project_schema_preserves_metadata() -> Result<()> {
2686        // Create schema with metadata
2687        let mut metadata = HashMap::new();
2688        metadata.insert("key".to_string(), "value".to_string());
2689        let field_with_metadata =
2690            Field::new("col0", DataType::Int64, false).with_metadata(metadata.clone());
2691        let input_schema = Schema::new(vec![
2692            field_with_metadata,
2693            Field::new("col1", DataType::Utf8, false),
2694        ]);
2695
2696        // Projection: SELECT col0 AS renamed
2697        let projection = ProjectionExprs::new(vec![ProjectionExpr {
2698            expr: Arc::new(Column::new("col0", 0)),
2699            alias: "renamed".to_string(),
2700        }]);
2701
2702        let output_schema = projection.project_schema(&input_schema)?;
2703
2704        // Should have 1 field
2705        assert_eq!(output_schema.fields().len(), 1);
2706
2707        // Field should be "renamed" with metadata preserved
2708        assert_eq!(output_schema.field(0).name(), "renamed");
2709        assert_eq!(output_schema.field(0).metadata(), &metadata);
2710
2711        Ok(())
2712    }
2713
2714    #[test]
2715    fn test_project_schema_empty() -> Result<()> {
2716        let input_schema = get_schema();
2717        let projection = ProjectionExprs::new(vec![]);
2718
2719        let output_schema = projection.project_schema(&input_schema)?;
2720
2721        assert_eq!(output_schema.fields().len(), 0);
2722
2723        Ok(())
2724    }
2725
2726    #[test]
2727    fn test_project_statistics_columns_only() -> Result<()> {
2728        let input_stats = get_stats();
2729        let input_schema = get_schema();
2730
2731        // Projection: SELECT col1 AS text, col0 AS num
2732        let projection = ProjectionExprs::new(vec![
2733            ProjectionExpr {
2734                expr: Arc::new(Column::new("col1", 1)),
2735                alias: "text".to_string(),
2736            },
2737            ProjectionExpr {
2738                expr: Arc::new(Column::new("col0", 0)),
2739                alias: "num".to_string(),
2740            },
2741        ]);
2742
2743        let output_stats = projection.project_statistics(
2744            input_stats,
2745            &projection.project_schema(&input_schema)?,
2746        )?;
2747
2748        // Row count should be preserved
2749        assert_eq!(output_stats.num_rows, Precision::Exact(5));
2750
2751        // Should have 2 column statistics (reordered from input)
2752        assert_eq!(output_stats.column_statistics.len(), 2);
2753
2754        // First column (col1 from input)
2755        assert_eq!(
2756            output_stats.column_statistics[0].distinct_count,
2757            Precision::Exact(1)
2758        );
2759        assert_eq!(
2760            output_stats.column_statistics[0].max_value,
2761            Precision::Exact(ScalarValue::from("x"))
2762        );
2763
2764        // Second column (col0 from input)
2765        assert_eq!(
2766            output_stats.column_statistics[1].distinct_count,
2767            Precision::Exact(5)
2768        );
2769        assert_eq!(
2770            output_stats.column_statistics[1].max_value,
2771            Precision::Exact(ScalarValue::Int64(Some(21)))
2772        );
2773
2774        Ok(())
2775    }
2776
2777    #[test]
2778    fn test_project_statistics_with_expressions() -> Result<()> {
2779        let input_stats = get_stats();
2780        let input_schema = get_schema();
2781
2782        // Projection with expression: SELECT col0 + 1 AS incremented, col1 AS text
2783        let projection = ProjectionExprs::new(vec![
2784            ProjectionExpr {
2785                expr: Arc::new(BinaryExpr::new(
2786                    Arc::new(Column::new("col0", 0)),
2787                    Operator::Plus,
2788                    Arc::new(Literal::new(ScalarValue::Int64(Some(1)))),
2789                )),
2790                alias: "incremented".to_string(),
2791            },
2792            ProjectionExpr {
2793                expr: Arc::new(Column::new("col1", 1)),
2794                alias: "text".to_string(),
2795            },
2796        ]);
2797
2798        let output_stats = projection.project_statistics(
2799            input_stats,
2800            &projection.project_schema(&input_schema)?,
2801        )?;
2802
2803        // Row count should be preserved
2804        assert_eq!(output_stats.num_rows, Precision::Exact(5));
2805
2806        // Should have 2 column statistics
2807        assert_eq!(output_stats.column_statistics.len(), 2);
2808
2809        // First column (expression) should have unknown statistics
2810        assert_eq!(
2811            output_stats.column_statistics[0].distinct_count,
2812            Precision::Absent
2813        );
2814        assert_eq!(
2815            output_stats.column_statistics[0].max_value,
2816            Precision::Absent
2817        );
2818
2819        // Second column (col1) should preserve statistics
2820        assert_eq!(
2821            output_stats.column_statistics[1].distinct_count,
2822            Precision::Exact(1)
2823        );
2824
2825        Ok(())
2826    }
2827
2828    #[test]
2829    fn test_project_statistics_with_cast() -> Result<()> {
2830        let input_stats = get_stats();
2831        let input_schema = get_schema();
2832
2833        // SELECT CAST(col0 AS Int32) AS casted
2834        let projection = ProjectionExprs::new(vec![ProjectionExpr {
2835            expr: Arc::new(CastExpr::new(
2836                Arc::new(Column::new("col0", 0)),
2837                DataType::Int32,
2838                None,
2839            )),
2840            alias: "casted".to_string(),
2841        }]);
2842
2843        let output_stats = projection.project_statistics(
2844            input_stats,
2845            &projection.project_schema(&input_schema)?,
2846        )?;
2847
2848        assert_eq!(
2849            output_stats.column_statistics[0].min_value,
2850            Precision::Exact(ScalarValue::Int32(Some(-4)))
2851        );
2852        assert_eq!(
2853            output_stats.column_statistics[0].max_value,
2854            Precision::Exact(ScalarValue::Int32(Some(21)))
2855        );
2856
2857        Ok(())
2858    }
2859
2860    #[test]
2861    fn test_project_statistics_primitive_width_only() -> Result<()> {
2862        let input_stats = get_stats();
2863        let input_schema = get_schema();
2864
2865        // Projection with only primitive width columns: SELECT col2 AS f, col0 AS i
2866        let projection = ProjectionExprs::new(vec![
2867            ProjectionExpr {
2868                expr: Arc::new(Column::new("col2", 2)),
2869                alias: "f".to_string(),
2870            },
2871            ProjectionExpr {
2872                expr: Arc::new(Column::new("col0", 0)),
2873                alias: "i".to_string(),
2874            },
2875        ]);
2876
2877        let output_stats = projection.project_statistics(
2878            input_stats,
2879            &projection.project_schema(&input_schema)?,
2880        )?;
2881
2882        // Row count should be preserved
2883        assert_eq!(output_stats.num_rows, Precision::Exact(5));
2884
2885        // Total byte size should be recalculated for primitive types
2886        // Float32 (4 bytes) + Int64 (8 bytes) = 12 bytes per row, 5 rows = 60 bytes
2887        assert_eq!(output_stats.total_byte_size, Precision::Exact(60));
2888
2889        // Should have 2 column statistics
2890        assert_eq!(output_stats.column_statistics.len(), 2);
2891
2892        Ok(())
2893    }
2894
2895    #[test]
2896    fn test_project_statistics_empty() -> Result<()> {
2897        let input_stats = get_stats();
2898        let input_schema = get_schema();
2899
2900        let projection = ProjectionExprs::new(vec![]);
2901
2902        let output_stats = projection.project_statistics(
2903            input_stats,
2904            &projection.project_schema(&input_schema)?,
2905        )?;
2906
2907        // Row count should be preserved
2908        assert_eq!(output_stats.num_rows, Precision::Exact(5));
2909
2910        // Should have no column statistics
2911        assert_eq!(output_stats.column_statistics.len(), 0);
2912
2913        // Total byte size should be 0 for empty projection
2914        assert_eq!(output_stats.total_byte_size, Precision::Exact(0));
2915
2916        Ok(())
2917    }
2918
2919    // Test statistics calculation for non-null literal (numeric constant)
2920    #[test]
2921    fn test_project_statistics_with_literal() -> Result<()> {
2922        let input_stats = get_stats();
2923        let input_schema = get_schema();
2924
2925        // Projection with literal: SELECT 42 AS constant, col0 AS num
2926        let projection = ProjectionExprs::new(vec![
2927            ProjectionExpr {
2928                expr: Arc::new(Literal::new(ScalarValue::Int64(Some(42)))),
2929                alias: "constant".to_string(),
2930            },
2931            ProjectionExpr {
2932                expr: Arc::new(Column::new("col0", 0)),
2933                alias: "num".to_string(),
2934            },
2935        ]);
2936
2937        let output_stats = projection.project_statistics(
2938            input_stats,
2939            &projection.project_schema(&input_schema)?,
2940        )?;
2941
2942        // Row count should be preserved
2943        assert_eq!(output_stats.num_rows, Precision::Exact(5));
2944
2945        // Should have 2 column statistics
2946        assert_eq!(output_stats.column_statistics.len(), 2);
2947
2948        // First column (literal 42) should have proper constant statistics
2949        assert_eq!(
2950            output_stats.column_statistics[0].min_value,
2951            Precision::Exact(ScalarValue::Int64(Some(42)))
2952        );
2953        assert_eq!(
2954            output_stats.column_statistics[0].max_value,
2955            Precision::Exact(ScalarValue::Int64(Some(42)))
2956        );
2957        assert_eq!(
2958            output_stats.column_statistics[0].distinct_count,
2959            Precision::Exact(1)
2960        );
2961        assert_eq!(
2962            output_stats.column_statistics[0].null_count,
2963            Precision::Exact(0)
2964        );
2965        // Int64 is 8 bytes, 5 rows = 40 bytes
2966        assert_eq!(
2967            output_stats.column_statistics[0].byte_size,
2968            Precision::Exact(40)
2969        );
2970        // For a constant column, sum_value = value * num_rows = 42 * 5 = 210
2971        assert_eq!(
2972            output_stats.column_statistics[0].sum_value,
2973            Precision::Exact(ScalarValue::Int64(Some(210)))
2974        );
2975
2976        // Second column (col0) should preserve statistics
2977        assert_eq!(
2978            output_stats.column_statistics[1].distinct_count,
2979            Precision::Exact(5)
2980        );
2981        assert_eq!(
2982            output_stats.column_statistics[1].max_value,
2983            Precision::Exact(ScalarValue::Int64(Some(21)))
2984        );
2985
2986        Ok(())
2987    }
2988
2989    #[test]
2990    fn test_project_statistics_with_i32_literal_sum_widens_to_i64() -> Result<()> {
2991        let input_stats = get_stats();
2992        let input_schema = get_schema();
2993
2994        let projection = ProjectionExprs::new(vec![
2995            ProjectionExpr {
2996                expr: Arc::new(Literal::new(ScalarValue::Int32(Some(10)))),
2997                alias: "constant".to_string(),
2998            },
2999            ProjectionExpr {
3000                expr: Arc::new(Column::new("col0", 0)),
3001                alias: "num".to_string(),
3002            },
3003        ]);
3004
3005        let output_stats = projection.project_statistics(
3006            input_stats,
3007            &projection.project_schema(&input_schema)?,
3008        )?;
3009
3010        assert_eq!(
3011            output_stats.column_statistics[0].sum_value,
3012            Precision::Exact(ScalarValue::Int64(Some(50)))
3013        );
3014
3015        Ok(())
3016    }
3017
3018    // Test statistics calculation for NULL literal (constant NULL column)
3019    #[test]
3020    fn test_project_statistics_with_null_literal() -> Result<()> {
3021        let input_stats = get_stats();
3022        let input_schema = get_schema();
3023
3024        // Projection with NULL literal: SELECT NULL AS null_col, col0 AS num
3025        let projection = ProjectionExprs::new(vec![
3026            ProjectionExpr {
3027                expr: Arc::new(Literal::new(ScalarValue::Int64(None))),
3028                alias: "null_col".to_string(),
3029            },
3030            ProjectionExpr {
3031                expr: Arc::new(Column::new("col0", 0)),
3032                alias: "num".to_string(),
3033            },
3034        ]);
3035
3036        let output_stats = projection.project_statistics(
3037            input_stats,
3038            &projection.project_schema(&input_schema)?,
3039        )?;
3040
3041        // Row count should be preserved
3042        assert_eq!(output_stats.num_rows, Precision::Exact(5));
3043
3044        // Should have 2 column statistics
3045        assert_eq!(output_stats.column_statistics.len(), 2);
3046
3047        // First column (NULL literal) should have proper constant NULL statistics
3048        assert_eq!(
3049            output_stats.column_statistics[0].min_value,
3050            Precision::Exact(ScalarValue::Int64(None))
3051        );
3052        assert_eq!(
3053            output_stats.column_statistics[0].max_value,
3054            Precision::Exact(ScalarValue::Int64(None))
3055        );
3056        assert_eq!(
3057            output_stats.column_statistics[0].distinct_count,
3058            Precision::Exact(1) // All NULLs are considered the same
3059        );
3060        assert_eq!(
3061            output_stats.column_statistics[0].null_count,
3062            Precision::Exact(5) // All rows are NULL
3063        );
3064        assert_eq!(
3065            output_stats.column_statistics[0].byte_size,
3066            Precision::Exact(0)
3067        );
3068        assert_eq!(
3069            output_stats.column_statistics[0].sum_value,
3070            Precision::Exact(ScalarValue::Int64(None))
3071        );
3072
3073        // Second column (col0) should preserve statistics
3074        assert_eq!(
3075            output_stats.column_statistics[1].distinct_count,
3076            Precision::Exact(5)
3077        );
3078        assert_eq!(
3079            output_stats.column_statistics[1].max_value,
3080            Precision::Exact(ScalarValue::Int64(Some(21)))
3081        );
3082
3083        Ok(())
3084    }
3085
3086    // Test statistics calculation for complex type literal (e.g., Utf8 string)
3087    #[test]
3088    fn test_project_statistics_with_complex_type_literal() -> Result<()> {
3089        let input_stats = get_stats();
3090        let input_schema = get_schema();
3091
3092        // Projection with Utf8 literal (complex type): SELECT 'hello' AS text, col0 AS num
3093        let projection = ProjectionExprs::new(vec![
3094            ProjectionExpr {
3095                expr: Arc::new(Literal::new(ScalarValue::Utf8(Some(
3096                    "hello".to_string(),
3097                )))),
3098                alias: "text".to_string(),
3099            },
3100            ProjectionExpr {
3101                expr: Arc::new(Column::new("col0", 0)),
3102                alias: "num".to_string(),
3103            },
3104        ]);
3105
3106        let output_stats = projection.project_statistics(
3107            input_stats,
3108            &projection.project_schema(&input_schema)?,
3109        )?;
3110
3111        // Row count should be preserved
3112        assert_eq!(output_stats.num_rows, Precision::Exact(5));
3113
3114        // Should have 2 column statistics
3115        assert_eq!(output_stats.column_statistics.len(), 2);
3116
3117        // First column (Utf8 literal 'hello') should have proper constant statistics
3118        // but byte_size should be Absent for complex types
3119        assert_eq!(
3120            output_stats.column_statistics[0].min_value,
3121            Precision::Exact(ScalarValue::Utf8(Some("hello".to_string())))
3122        );
3123        assert_eq!(
3124            output_stats.column_statistics[0].max_value,
3125            Precision::Exact(ScalarValue::Utf8(Some("hello".to_string())))
3126        );
3127        assert_eq!(
3128            output_stats.column_statistics[0].distinct_count,
3129            Precision::Exact(1)
3130        );
3131        assert_eq!(
3132            output_stats.column_statistics[0].null_count,
3133            Precision::Exact(0)
3134        );
3135        // Complex types (Utf8, List, etc.) should have byte_size = Absent
3136        // because we can't calculate exact size without knowing the actual data
3137        assert_eq!(
3138            output_stats.column_statistics[0].byte_size,
3139            Precision::Absent
3140        );
3141        // Non-numeric types (Utf8) should have sum_value = Absent
3142        // because sum is only meaningful for numeric types
3143        assert_eq!(
3144            output_stats.column_statistics[0].sum_value,
3145            Precision::Absent
3146        );
3147
3148        // Second column (col0) should preserve statistics
3149        assert_eq!(
3150            output_stats.column_statistics[1].distinct_count,
3151            Precision::Exact(5)
3152        );
3153        assert_eq!(
3154            output_stats.column_statistics[1].max_value,
3155            Precision::Exact(ScalarValue::Int64(Some(21)))
3156        );
3157
3158        Ok(())
3159    }
3160}