datafusion_physical_expr/equivalence/
projection.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::sync::Arc;
19
20use crate::expressions::Column;
21use crate::PhysicalExpr;
22
23use arrow::datatypes::SchemaRef;
24use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
25use datafusion_common::{internal_err, Result};
26
27/// Stores the mapping between source expressions and target expressions for a
28/// projection.
29#[derive(Debug, Clone)]
30pub struct ProjectionMapping {
31    /// Mapping between source expressions and target expressions.
32    /// Vector indices correspond to the indices after projection.
33    pub map: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)>,
34}
35
36impl ProjectionMapping {
37    /// Constructs the mapping between a projection's input and output
38    /// expressions.
39    ///
40    /// For example, given the input projection expressions (`a + b`, `c + d`)
41    /// and an output schema with two columns `"c + d"` and `"a + b"`, the
42    /// projection mapping would be:
43    ///
44    /// ```text
45    ///  [0]: (c + d, col("c + d"))
46    ///  [1]: (a + b, col("a + b"))
47    /// ```
48    ///
49    /// where `col("c + d")` means the column named `"c + d"`.
50    pub fn try_new(
51        expr: &[(Arc<dyn PhysicalExpr>, String)],
52        input_schema: &SchemaRef,
53    ) -> Result<Self> {
54        // Construct a map from the input expressions to the output expression of the projection:
55        expr.iter()
56            .enumerate()
57            .map(|(expr_idx, (expression, name))| {
58                let target_expr = Arc::new(Column::new(name, expr_idx)) as _;
59                Arc::clone(expression)
60                    .transform_down(|e| match e.as_any().downcast_ref::<Column>() {
61                        Some(col) => {
62                            // Sometimes, an expression and its name in the input_schema
63                            // doesn't match. This can cause problems, so we make sure
64                            // that the expression name matches with the name in `input_schema`.
65                            // Conceptually, `source_expr` and `expression` should be the same.
66                            let idx = col.index();
67                            let matching_input_field = input_schema.field(idx);
68                            if col.name() != matching_input_field.name() {
69                                return internal_err!("Input field name {} does not match with the projection expression {}",
70                                matching_input_field.name(),col.name())
71                            }
72                            let matching_input_column =
73                                Column::new(matching_input_field.name(), idx);
74                            Ok(Transformed::yes(Arc::new(matching_input_column)))
75                        }
76                        None => Ok(Transformed::no(e)),
77                    })
78                    .data()
79                    .map(|source_expr| (source_expr, target_expr))
80            })
81            .collect::<Result<Vec<_>>>()
82            .map(|map| Self { map })
83    }
84
85    /// Constructs a subset mapping using the provided indices.
86    ///
87    /// This is used when the output is a subset of the input without any
88    /// other transformations. The indices are for columns in the schema.
89    pub fn from_indices(indices: &[usize], schema: &SchemaRef) -> Result<Self> {
90        let projection_exprs = project_index_to_exprs(indices, schema);
91        ProjectionMapping::try_new(&projection_exprs, schema)
92    }
93
94    /// Iterate over pairs of (source, target) expressions
95    pub fn iter(
96        &self,
97    ) -> impl Iterator<Item = &(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> + '_ {
98        self.map.iter()
99    }
100
101    /// This function returns the target expression for a given source expression.
102    ///
103    /// # Arguments
104    ///
105    /// * `expr` - Source physical expression.
106    ///
107    /// # Returns
108    ///
109    /// An `Option` containing the target for the given source expression,
110    /// where a `None` value means that `expr` is not inside the mapping.
111    pub fn target_expr(
112        &self,
113        expr: &Arc<dyn PhysicalExpr>,
114    ) -> Option<Arc<dyn PhysicalExpr>> {
115        self.map
116            .iter()
117            .find(|(source, _)| source.eq(expr))
118            .map(|(_, target)| Arc::clone(target))
119    }
120}
121
122fn project_index_to_exprs(
123    projection_index: &[usize],
124    schema: &SchemaRef,
125) -> Vec<(Arc<dyn PhysicalExpr>, String)> {
126    projection_index
127        .iter()
128        .map(|index| {
129            let field = schema.field(*index);
130            (
131                Arc::new(Column::new(field.name(), *index)) as Arc<dyn PhysicalExpr>,
132                field.name().to_owned(),
133            )
134        })
135        .collect::<Vec<_>>()
136}
137
138#[cfg(test)]
139mod tests {
140    use super::*;
141    use crate::equivalence::tests::{
142        convert_to_orderings, convert_to_orderings_owned, output_schema,
143    };
144    use crate::equivalence::EquivalenceProperties;
145    use crate::expressions::{col, BinaryExpr};
146    use crate::utils::tests::TestScalarUDF;
147    use crate::{PhysicalExprRef, ScalarFunctionExpr};
148
149    use arrow::compute::SortOptions;
150    use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
151    use datafusion_expr::{Operator, ScalarUDF};
152
153    #[test]
154    fn project_orderings() -> Result<()> {
155        let schema = Arc::new(Schema::new(vec![
156            Field::new("a", DataType::Int32, true),
157            Field::new("b", DataType::Int32, true),
158            Field::new("c", DataType::Int32, true),
159            Field::new("d", DataType::Int32, true),
160            Field::new("e", DataType::Int32, true),
161            Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true),
162        ]));
163        let col_a = &col("a", &schema)?;
164        let col_b = &col("b", &schema)?;
165        let col_c = &col("c", &schema)?;
166        let col_d = &col("d", &schema)?;
167        let col_e = &col("e", &schema)?;
168        let col_ts = &col("ts", &schema)?;
169        let a_plus_b = Arc::new(BinaryExpr::new(
170            Arc::clone(col_a),
171            Operator::Plus,
172            Arc::clone(col_b),
173        )) as Arc<dyn PhysicalExpr>;
174        let b_plus_d = Arc::new(BinaryExpr::new(
175            Arc::clone(col_b),
176            Operator::Plus,
177            Arc::clone(col_d),
178        )) as Arc<dyn PhysicalExpr>;
179        let b_plus_e = Arc::new(BinaryExpr::new(
180            Arc::clone(col_b),
181            Operator::Plus,
182            Arc::clone(col_e),
183        )) as Arc<dyn PhysicalExpr>;
184        let c_plus_d = Arc::new(BinaryExpr::new(
185            Arc::clone(col_c),
186            Operator::Plus,
187            Arc::clone(col_d),
188        )) as Arc<dyn PhysicalExpr>;
189
190        let option_asc = SortOptions {
191            descending: false,
192            nulls_first: false,
193        };
194        let option_desc = SortOptions {
195            descending: true,
196            nulls_first: true,
197        };
198
199        let test_cases = vec![
200            // ---------- TEST CASE 1 ------------
201            (
202                // orderings
203                vec![
204                    // [b ASC]
205                    vec![(col_b, option_asc)],
206                ],
207                // projection exprs
208                vec![(col_b, "b_new".to_string()), (col_a, "a_new".to_string())],
209                // expected
210                vec![
211                    // [b_new ASC]
212                    vec![("b_new", option_asc)],
213                ],
214            ),
215            // ---------- TEST CASE 2 ------------
216            (
217                // orderings
218                vec![
219                    // empty ordering
220                ],
221                // projection exprs
222                vec![(col_c, "c_new".to_string()), (col_b, "b_new".to_string())],
223                // expected
224                vec![
225                    // no ordering at the output
226                ],
227            ),
228            // ---------- TEST CASE 3 ------------
229            (
230                // orderings
231                vec![
232                    // [ts ASC]
233                    vec![(col_ts, option_asc)],
234                ],
235                // projection exprs
236                vec![
237                    (col_b, "b_new".to_string()),
238                    (col_a, "a_new".to_string()),
239                    (col_ts, "ts_new".to_string()),
240                ],
241                // expected
242                vec![
243                    // [ts_new ASC]
244                    vec![("ts_new", option_asc)],
245                ],
246            ),
247            // ---------- TEST CASE 4 ------------
248            (
249                // orderings
250                vec![
251                    // [a ASC, ts ASC]
252                    vec![(col_a, option_asc), (col_ts, option_asc)],
253                    // [b ASC, ts ASC]
254                    vec![(col_b, option_asc), (col_ts, option_asc)],
255                ],
256                // projection exprs
257                vec![
258                    (col_b, "b_new".to_string()),
259                    (col_a, "a_new".to_string()),
260                    (col_ts, "ts_new".to_string()),
261                ],
262                // expected
263                vec![
264                    // [a_new ASC, ts_new ASC]
265                    vec![("a_new", option_asc), ("ts_new", option_asc)],
266                    // [b_new ASC, ts_new ASC]
267                    vec![("b_new", option_asc), ("ts_new", option_asc)],
268                ],
269            ),
270            // ---------- TEST CASE 5 ------------
271            (
272                // orderings
273                vec![
274                    // [a + b ASC]
275                    vec![(&a_plus_b, option_asc)],
276                ],
277                // projection exprs
278                vec![
279                    (col_b, "b_new".to_string()),
280                    (col_a, "a_new".to_string()),
281                    (&a_plus_b, "a+b".to_string()),
282                ],
283                // expected
284                vec![
285                    // [a + b ASC]
286                    vec![("a+b", option_asc)],
287                ],
288            ),
289            // ---------- TEST CASE 6 ------------
290            (
291                // orderings
292                vec![
293                    // [a + b ASC, c ASC]
294                    vec![(&a_plus_b, option_asc), (col_c, option_asc)],
295                ],
296                // projection exprs
297                vec![
298                    (col_b, "b_new".to_string()),
299                    (col_a, "a_new".to_string()),
300                    (col_c, "c_new".to_string()),
301                    (&a_plus_b, "a+b".to_string()),
302                ],
303                // expected
304                vec![
305                    // [a + b ASC, c_new ASC]
306                    vec![("a+b", option_asc), ("c_new", option_asc)],
307                ],
308            ),
309            // ------- TEST CASE 7 ----------
310            (
311                vec![
312                    // [a ASC, b ASC, c ASC]
313                    vec![(col_a, option_asc), (col_b, option_asc)],
314                    // [a ASC, d ASC]
315                    vec![(col_a, option_asc), (col_d, option_asc)],
316                ],
317                // b as b_new, a as a_new, d as d_new b+d
318                vec![
319                    (col_b, "b_new".to_string()),
320                    (col_a, "a_new".to_string()),
321                    (col_d, "d_new".to_string()),
322                    (&b_plus_d, "b+d".to_string()),
323                ],
324                // expected
325                vec![
326                    // [a_new ASC, b_new ASC]
327                    vec![("a_new", option_asc), ("b_new", option_asc)],
328                    // [a_new ASC, d_new ASC]
329                    vec![("a_new", option_asc), ("d_new", option_asc)],
330                    // [a_new ASC, b+d ASC]
331                    vec![("a_new", option_asc), ("b+d", option_asc)],
332                ],
333            ),
334            // ------- TEST CASE 8 ----------
335            (
336                // orderings
337                vec![
338                    // [b+d ASC]
339                    vec![(&b_plus_d, option_asc)],
340                ],
341                // proj exprs
342                vec![
343                    (col_b, "b_new".to_string()),
344                    (col_a, "a_new".to_string()),
345                    (col_d, "d_new".to_string()),
346                    (&b_plus_d, "b+d".to_string()),
347                ],
348                // expected
349                vec![
350                    // [b+d ASC]
351                    vec![("b+d", option_asc)],
352                ],
353            ),
354            // ------- TEST CASE 9 ----------
355            (
356                // orderings
357                vec![
358                    // [a ASC, d ASC, b ASC]
359                    vec![
360                        (col_a, option_asc),
361                        (col_d, option_asc),
362                        (col_b, option_asc),
363                    ],
364                    // [c ASC]
365                    vec![(col_c, option_asc)],
366                ],
367                // proj exprs
368                vec![
369                    (col_b, "b_new".to_string()),
370                    (col_a, "a_new".to_string()),
371                    (col_d, "d_new".to_string()),
372                    (col_c, "c_new".to_string()),
373                ],
374                // expected
375                vec![
376                    // [a_new ASC, d_new ASC, b_new ASC]
377                    vec![
378                        ("a_new", option_asc),
379                        ("d_new", option_asc),
380                        ("b_new", option_asc),
381                    ],
382                    // [c_new ASC],
383                    vec![("c_new", option_asc)],
384                ],
385            ),
386            // ------- TEST CASE 10 ----------
387            (
388                vec![
389                    // [a ASC, b ASC, c ASC]
390                    vec![
391                        (col_a, option_asc),
392                        (col_b, option_asc),
393                        (col_c, option_asc),
394                    ],
395                    // [a ASC, d ASC]
396                    vec![(col_a, option_asc), (col_d, option_asc)],
397                ],
398                // proj exprs
399                vec![
400                    (col_b, "b_new".to_string()),
401                    (col_a, "a_new".to_string()),
402                    (col_c, "c_new".to_string()),
403                    (&c_plus_d, "c+d".to_string()),
404                ],
405                // expected
406                vec![
407                    // [a_new ASC, b_new ASC, c_new ASC]
408                    vec![
409                        ("a_new", option_asc),
410                        ("b_new", option_asc),
411                        ("c_new", option_asc),
412                    ],
413                    // [a_new ASC, b_new ASC, c+d ASC]
414                    vec![
415                        ("a_new", option_asc),
416                        ("b_new", option_asc),
417                        ("c+d", option_asc),
418                    ],
419                ],
420            ),
421            // ------- TEST CASE 11 ----------
422            (
423                // orderings
424                vec![
425                    // [a ASC, b ASC]
426                    vec![(col_a, option_asc), (col_b, option_asc)],
427                    // [a ASC, d ASC]
428                    vec![(col_a, option_asc), (col_d, option_asc)],
429                ],
430                // proj exprs
431                vec![
432                    (col_b, "b_new".to_string()),
433                    (col_a, "a_new".to_string()),
434                    (&b_plus_d, "b+d".to_string()),
435                ],
436                // expected
437                vec![
438                    // [a_new ASC, b_new ASC]
439                    vec![("a_new", option_asc), ("b_new", option_asc)],
440                    // [a_new ASC, b + d ASC]
441                    vec![("a_new", option_asc), ("b+d", option_asc)],
442                ],
443            ),
444            // ------- TEST CASE 12 ----------
445            (
446                // orderings
447                vec![
448                    // [a ASC, b ASC, c ASC]
449                    vec![
450                        (col_a, option_asc),
451                        (col_b, option_asc),
452                        (col_c, option_asc),
453                    ],
454                ],
455                // proj exprs
456                vec![(col_c, "c_new".to_string()), (col_a, "a_new".to_string())],
457                // expected
458                vec![
459                    // [a_new ASC]
460                    vec![("a_new", option_asc)],
461                ],
462            ),
463            // ------- TEST CASE 13 ----------
464            (
465                // orderings
466                vec![
467                    // [a ASC, b ASC, c ASC]
468                    vec![
469                        (col_a, option_asc),
470                        (col_b, option_asc),
471                        (col_c, option_asc),
472                    ],
473                    // [a ASC, a + b ASC, c ASC]
474                    vec![
475                        (col_a, option_asc),
476                        (&a_plus_b, option_asc),
477                        (col_c, option_asc),
478                    ],
479                ],
480                // proj exprs
481                vec![
482                    (col_c, "c_new".to_string()),
483                    (col_b, "b_new".to_string()),
484                    (col_a, "a_new".to_string()),
485                    (&a_plus_b, "a+b".to_string()),
486                ],
487                // expected
488                vec![
489                    // [a_new ASC, b_new ASC, c_new ASC]
490                    vec![
491                        ("a_new", option_asc),
492                        ("b_new", option_asc),
493                        ("c_new", option_asc),
494                    ],
495                    // [a_new ASC, a+b ASC, c_new ASC]
496                    vec![
497                        ("a_new", option_asc),
498                        ("a+b", option_asc),
499                        ("c_new", option_asc),
500                    ],
501                ],
502            ),
503            // ------- TEST CASE 14 ----------
504            (
505                // orderings
506                vec![
507                    // [a ASC, b ASC]
508                    vec![(col_a, option_asc), (col_b, option_asc)],
509                    // [c ASC, b ASC]
510                    vec![(col_c, option_asc), (col_b, option_asc)],
511                    // [d ASC, e ASC]
512                    vec![(col_d, option_asc), (col_e, option_asc)],
513                ],
514                // proj exprs
515                vec![
516                    (col_c, "c_new".to_string()),
517                    (col_d, "d_new".to_string()),
518                    (col_a, "a_new".to_string()),
519                    (&b_plus_e, "b+e".to_string()),
520                ],
521                // expected
522                vec![
523                    // [a_new ASC, d_new ASC, b+e ASC]
524                    vec![
525                        ("a_new", option_asc),
526                        ("d_new", option_asc),
527                        ("b+e", option_asc),
528                    ],
529                    // [d_new ASC, a_new ASC, b+e ASC]
530                    vec![
531                        ("d_new", option_asc),
532                        ("a_new", option_asc),
533                        ("b+e", option_asc),
534                    ],
535                    // [c_new ASC, d_new ASC, b+e ASC]
536                    vec![
537                        ("c_new", option_asc),
538                        ("d_new", option_asc),
539                        ("b+e", option_asc),
540                    ],
541                    // [d_new ASC, c_new ASC, b+e ASC]
542                    vec![
543                        ("d_new", option_asc),
544                        ("c_new", option_asc),
545                        ("b+e", option_asc),
546                    ],
547                ],
548            ),
549            // ------- TEST CASE 15 ----------
550            (
551                // orderings
552                vec![
553                    // [a ASC, c ASC, b ASC]
554                    vec![
555                        (col_a, option_asc),
556                        (col_c, option_asc),
557                        (col_b, option_asc),
558                    ],
559                ],
560                // proj exprs
561                vec![
562                    (col_c, "c_new".to_string()),
563                    (col_a, "a_new".to_string()),
564                    (&a_plus_b, "a+b".to_string()),
565                ],
566                // expected
567                vec![
568                    // [a_new ASC, d_new ASC, b+e ASC]
569                    vec![
570                        ("a_new", option_asc),
571                        ("c_new", option_asc),
572                        ("a+b", option_asc),
573                    ],
574                ],
575            ),
576            // ------- TEST CASE 16 ----------
577            (
578                // orderings
579                vec![
580                    // [a ASC, b ASC]
581                    vec![(col_a, option_asc), (col_b, option_asc)],
582                    // [c ASC, b DESC]
583                    vec![(col_c, option_asc), (col_b, option_desc)],
584                    // [e ASC]
585                    vec![(col_e, option_asc)],
586                ],
587                // proj exprs
588                vec![
589                    (col_c, "c_new".to_string()),
590                    (col_a, "a_new".to_string()),
591                    (col_b, "b_new".to_string()),
592                    (&b_plus_e, "b+e".to_string()),
593                ],
594                // expected
595                vec![
596                    // [a_new ASC, b_new ASC]
597                    vec![("a_new", option_asc), ("b_new", option_asc)],
598                    // [a_new ASC, b_new ASC]
599                    vec![("a_new", option_asc), ("b+e", option_asc)],
600                    // [c_new ASC, b_new DESC]
601                    vec![("c_new", option_asc), ("b_new", option_desc)],
602                ],
603            ),
604        ];
605
606        for (idx, (orderings, proj_exprs, expected)) in test_cases.into_iter().enumerate()
607        {
608            let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema));
609
610            let orderings = convert_to_orderings(&orderings);
611            eq_properties.add_new_orderings(orderings);
612
613            let proj_exprs = proj_exprs
614                .into_iter()
615                .map(|(expr, name)| (Arc::clone(expr), name))
616                .collect::<Vec<_>>();
617            let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?;
618            let output_schema = output_schema(&projection_mapping, &schema)?;
619
620            let expected = expected
621                .into_iter()
622                .map(|ordering| {
623                    ordering
624                        .into_iter()
625                        .map(|(name, options)| {
626                            (col(name, &output_schema).unwrap(), options)
627                        })
628                        .collect::<Vec<_>>()
629                })
630                .collect::<Vec<_>>();
631            let expected = convert_to_orderings_owned(&expected);
632
633            let projected_eq = eq_properties.project(&projection_mapping, output_schema);
634            let orderings = projected_eq.oeq_class();
635
636            let err_msg = format!(
637                "test_idx: {idx:?}, actual: {orderings:?}, expected: {expected:?}, projection_mapping: {projection_mapping:?}"
638            );
639
640            assert_eq!(orderings.len(), expected.len(), "{err_msg}");
641            for expected_ordering in &expected {
642                assert!(orderings.contains(expected_ordering), "{}", err_msg)
643            }
644        }
645
646        Ok(())
647    }
648
649    #[test]
650    fn project_orderings2() -> Result<()> {
651        let schema = Arc::new(Schema::new(vec![
652            Field::new("a", DataType::Int32, true),
653            Field::new("b", DataType::Int32, true),
654            Field::new("c", DataType::Int32, true),
655            Field::new("d", DataType::Int32, true),
656            Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true),
657        ]));
658        let col_a = &col("a", &schema)?;
659        let col_b = &col("b", &schema)?;
660        let col_c = &col("c", &schema)?;
661        let col_ts = &col("ts", &schema)?;
662        let a_plus_b = Arc::new(BinaryExpr::new(
663            Arc::clone(col_a),
664            Operator::Plus,
665            Arc::clone(col_b),
666        )) as Arc<dyn PhysicalExpr>;
667
668        let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new()));
669
670        let round_c = Arc::new(ScalarFunctionExpr::try_new(
671            test_fun,
672            vec![Arc::clone(col_c)],
673            &schema,
674        )?) as PhysicalExprRef;
675
676        let option_asc = SortOptions {
677            descending: false,
678            nulls_first: false,
679        };
680
681        let proj_exprs = vec![
682            (col_b, "b_new".to_string()),
683            (col_a, "a_new".to_string()),
684            (col_c, "c_new".to_string()),
685            (&round_c, "round_c_res".to_string()),
686        ];
687        let proj_exprs = proj_exprs
688            .into_iter()
689            .map(|(expr, name)| (Arc::clone(expr), name))
690            .collect::<Vec<_>>();
691        let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?;
692        let output_schema = output_schema(&projection_mapping, &schema)?;
693
694        let col_a_new = &col("a_new", &output_schema)?;
695        let col_b_new = &col("b_new", &output_schema)?;
696        let col_c_new = &col("c_new", &output_schema)?;
697        let col_round_c_res = &col("round_c_res", &output_schema)?;
698        let a_new_plus_b_new = Arc::new(BinaryExpr::new(
699            Arc::clone(col_a_new),
700            Operator::Plus,
701            Arc::clone(col_b_new),
702        )) as Arc<dyn PhysicalExpr>;
703
704        let test_cases = vec![
705            // ---------- TEST CASE 1 ------------
706            (
707                // orderings
708                vec![
709                    // [a ASC]
710                    vec![(col_a, option_asc)],
711                ],
712                // expected
713                vec![
714                    // [b_new ASC]
715                    vec![(col_a_new, option_asc)],
716                ],
717            ),
718            // ---------- TEST CASE 2 ------------
719            (
720                // orderings
721                vec![
722                    // [a+b ASC]
723                    vec![(&a_plus_b, option_asc)],
724                ],
725                // expected
726                vec![
727                    // [b_new ASC]
728                    vec![(&a_new_plus_b_new, option_asc)],
729                ],
730            ),
731            // ---------- TEST CASE 3 ------------
732            (
733                // orderings
734                vec![
735                    // [a ASC, ts ASC]
736                    vec![(col_a, option_asc), (col_ts, option_asc)],
737                ],
738                // expected
739                vec![
740                    // [a_new ASC, date_bin_res ASC]
741                    vec![(col_a_new, option_asc)],
742                ],
743            ),
744            // ---------- TEST CASE 4 ------------
745            (
746                // orderings
747                vec![
748                    // [a ASC, ts ASC, b ASC]
749                    vec![
750                        (col_a, option_asc),
751                        (col_ts, option_asc),
752                        (col_b, option_asc),
753                    ],
754                ],
755                // expected
756                vec![
757                    // [a_new ASC, date_bin_res ASC]
758                    vec![(col_a_new, option_asc)],
759                ],
760            ),
761            // ---------- TEST CASE 5 ------------
762            (
763                // orderings
764                vec![
765                    // [a ASC, c ASC]
766                    vec![(col_a, option_asc), (col_c, option_asc)],
767                ],
768                // expected
769                vec![
770                    // [a_new ASC, round_c_res ASC, c_new ASC]
771                    vec![(col_a_new, option_asc), (col_round_c_res, option_asc)],
772                    // [a_new ASC, c_new ASC]
773                    vec![(col_a_new, option_asc), (col_c_new, option_asc)],
774                ],
775            ),
776            // ---------- TEST CASE 6 ------------
777            (
778                // orderings
779                vec![
780                    // [c ASC, b ASC]
781                    vec![(col_c, option_asc), (col_b, option_asc)],
782                ],
783                // expected
784                vec![
785                    // [round_c_res ASC]
786                    vec![(col_round_c_res, option_asc)],
787                    // [c_new ASC, b_new ASC]
788                    vec![(col_c_new, option_asc), (col_b_new, option_asc)],
789                ],
790            ),
791            // ---------- TEST CASE 7 ------------
792            (
793                // orderings
794                vec![
795                    // [a+b ASC, c ASC]
796                    vec![(&a_plus_b, option_asc), (col_c, option_asc)],
797                ],
798                // expected
799                vec![
800                    // [a+b ASC, round(c) ASC, c_new ASC]
801                    vec![
802                        (&a_new_plus_b_new, option_asc),
803                        (col_round_c_res, option_asc),
804                    ],
805                    // [a+b ASC, c_new ASC]
806                    vec![(&a_new_plus_b_new, option_asc), (col_c_new, option_asc)],
807                ],
808            ),
809        ];
810
811        for (idx, (orderings, expected)) in test_cases.iter().enumerate() {
812            let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema));
813
814            let orderings = convert_to_orderings(orderings);
815            eq_properties.add_new_orderings(orderings);
816
817            let expected = convert_to_orderings(expected);
818
819            let projected_eq =
820                eq_properties.project(&projection_mapping, Arc::clone(&output_schema));
821            let orderings = projected_eq.oeq_class();
822
823            let err_msg = format!(
824                "test idx: {idx:?}, actual: {orderings:?}, expected: {expected:?}, projection_mapping: {projection_mapping:?}"
825            );
826
827            assert_eq!(orderings.len(), expected.len(), "{err_msg}");
828            for expected_ordering in &expected {
829                assert!(orderings.contains(expected_ordering), "{}", err_msg)
830            }
831        }
832        Ok(())
833    }
834
835    #[test]
836    fn project_orderings3() -> Result<()> {
837        let schema = Arc::new(Schema::new(vec![
838            Field::new("a", DataType::Int32, true),
839            Field::new("b", DataType::Int32, true),
840            Field::new("c", DataType::Int32, true),
841            Field::new("d", DataType::Int32, true),
842            Field::new("e", DataType::Int32, true),
843            Field::new("f", DataType::Int32, true),
844        ]));
845        let col_a = &col("a", &schema)?;
846        let col_b = &col("b", &schema)?;
847        let col_c = &col("c", &schema)?;
848        let col_d = &col("d", &schema)?;
849        let col_e = &col("e", &schema)?;
850        let col_f = &col("f", &schema)?;
851        let a_plus_b = Arc::new(BinaryExpr::new(
852            Arc::clone(col_a),
853            Operator::Plus,
854            Arc::clone(col_b),
855        )) as Arc<dyn PhysicalExpr>;
856
857        let option_asc = SortOptions {
858            descending: false,
859            nulls_first: false,
860        };
861
862        let proj_exprs = vec![
863            (col_c, "c_new".to_string()),
864            (col_d, "d_new".to_string()),
865            (&a_plus_b, "a+b".to_string()),
866        ];
867        let proj_exprs = proj_exprs
868            .into_iter()
869            .map(|(expr, name)| (Arc::clone(expr), name))
870            .collect::<Vec<_>>();
871        let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?;
872        let output_schema = output_schema(&projection_mapping, &schema)?;
873
874        let col_a_plus_b_new = &col("a+b", &output_schema)?;
875        let col_c_new = &col("c_new", &output_schema)?;
876        let col_d_new = &col("d_new", &output_schema)?;
877
878        let test_cases = vec![
879            // ---------- TEST CASE 1 ------------
880            (
881                // orderings
882                vec![
883                    // [d ASC, b ASC]
884                    vec![(col_d, option_asc), (col_b, option_asc)],
885                    // [c ASC, a ASC]
886                    vec![(col_c, option_asc), (col_a, option_asc)],
887                ],
888                // equal conditions
889                vec![],
890                // expected
891                vec![
892                    // [d_new ASC, c_new ASC, a+b ASC]
893                    vec![
894                        (col_d_new, option_asc),
895                        (col_c_new, option_asc),
896                        (col_a_plus_b_new, option_asc),
897                    ],
898                    // [c_new ASC, d_new ASC, a+b ASC]
899                    vec![
900                        (col_c_new, option_asc),
901                        (col_d_new, option_asc),
902                        (col_a_plus_b_new, option_asc),
903                    ],
904                ],
905            ),
906            // ---------- TEST CASE 2 ------------
907            (
908                // orderings
909                vec![
910                    // [d ASC, b ASC]
911                    vec![(col_d, option_asc), (col_b, option_asc)],
912                    // [c ASC, e ASC], Please note that a=e
913                    vec![(col_c, option_asc), (col_e, option_asc)],
914                ],
915                // equal conditions
916                vec![(col_e, col_a)],
917                // expected
918                vec![
919                    // [d_new ASC, c_new ASC, a+b ASC]
920                    vec![
921                        (col_d_new, option_asc),
922                        (col_c_new, option_asc),
923                        (col_a_plus_b_new, option_asc),
924                    ],
925                    // [c_new ASC, d_new ASC, a+b ASC]
926                    vec![
927                        (col_c_new, option_asc),
928                        (col_d_new, option_asc),
929                        (col_a_plus_b_new, option_asc),
930                    ],
931                ],
932            ),
933            // ---------- TEST CASE 3 ------------
934            (
935                // orderings
936                vec![
937                    // [d ASC, b ASC]
938                    vec![(col_d, option_asc), (col_b, option_asc)],
939                    // [c ASC, e ASC], Please note that a=f
940                    vec![(col_c, option_asc), (col_e, option_asc)],
941                ],
942                // equal conditions
943                vec![(col_a, col_f)],
944                // expected
945                vec![
946                    // [d_new ASC]
947                    vec![(col_d_new, option_asc)],
948                    // [c_new ASC]
949                    vec![(col_c_new, option_asc)],
950                ],
951            ),
952        ];
953        for (orderings, equal_columns, expected) in test_cases {
954            let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema));
955            for (lhs, rhs) in equal_columns {
956                eq_properties.add_equal_conditions(lhs, rhs)?;
957            }
958
959            let orderings = convert_to_orderings(&orderings);
960            eq_properties.add_new_orderings(orderings);
961
962            let expected = convert_to_orderings(&expected);
963
964            let projected_eq =
965                eq_properties.project(&projection_mapping, Arc::clone(&output_schema));
966            let orderings = projected_eq.oeq_class();
967
968            let err_msg = format!(
969                "actual: {orderings:?}, expected: {expected:?}, projection_mapping: {projection_mapping:?}"
970            );
971
972            assert_eq!(orderings.len(), expected.len(), "{err_msg}");
973            for expected_ordering in &expected {
974                assert!(orderings.contains(expected_ordering), "{}", err_msg)
975            }
976        }
977
978        Ok(())
979    }
980}