datafusion_physical_expr/equivalence/
projection.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::ops::Deref;
19use std::sync::Arc;
20
21use crate::expressions::Column;
22use crate::PhysicalExpr;
23
24use arrow::datatypes::SchemaRef;
25use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
26use datafusion_common::{internal_err, Result};
27
28use indexmap::IndexMap;
29
/// Stores target expressions, along with their indices, that associate with a
/// source expression in a projection mapping.
///
/// The list is intended to be non-empty. Note that the type system does not
/// enforce this: the derived `Default` (and a conversion from an empty `Vec`)
/// yields an empty list, so a target must be pushed before `first` is called.
#[derive(Clone, Debug, Default)]
pub struct ProjectionTargets {
    /// A non-empty vector of pairs of target expressions and their indices.
    /// Consider using a special non-empty collection type in the future (e.g.
    /// if Rust provides one in the standard library).
    exprs_indices: Vec<(Arc<dyn PhysicalExpr>, usize)>,
}
39
40impl ProjectionTargets {
41    /// Returns the first target expression and its index.
42    pub fn first(&self) -> &(Arc<dyn PhysicalExpr>, usize) {
43        // Since the vector is non-empty, we can safely unwrap:
44        self.exprs_indices.first().unwrap()
45    }
46
47    /// Adds a target expression and its index to the list of targets.
48    pub fn push(&mut self, target: (Arc<dyn PhysicalExpr>, usize)) {
49        self.exprs_indices.push(target);
50    }
51}
52
53impl Deref for ProjectionTargets {
54    type Target = [(Arc<dyn PhysicalExpr>, usize)];
55
56    fn deref(&self) -> &Self::Target {
57        &self.exprs_indices
58    }
59}
60
61impl From<Vec<(Arc<dyn PhysicalExpr>, usize)>> for ProjectionTargets {
62    fn from(exprs_indices: Vec<(Arc<dyn PhysicalExpr>, usize)>) -> Self {
63        Self { exprs_indices }
64    }
65}
66
/// Stores the mapping between source expressions and target expressions for a
/// projection.
///
/// Each distinct source expression is a key; its value lists every target
/// expression (with its output index) that the source expression is
/// projected to.
#[derive(Clone, Debug)]
pub struct ProjectionMapping {
    /// Mapping between source expressions and target expressions.
    /// Vector indices correspond to the indices after projection.
    map: IndexMap<Arc<dyn PhysicalExpr>, ProjectionTargets>,
}
75
76impl ProjectionMapping {
77    /// Constructs the mapping between a projection's input and output
78    /// expressions.
79    ///
80    /// For example, given the input projection expressions (`a + b`, `c + d`)
81    /// and an output schema with two columns `"c + d"` and `"a + b"`, the
82    /// projection mapping would be:
83    ///
84    /// ```text
85    ///  [0]: (c + d, [(col("c + d"), 0)])
86    ///  [1]: (a + b, [(col("a + b"), 1)])
87    /// ```
88    ///
89    /// where `col("c + d")` means the column named `"c + d"`.
90    pub fn try_new(
91        expr: impl IntoIterator<Item = (Arc<dyn PhysicalExpr>, String)>,
92        input_schema: &SchemaRef,
93    ) -> Result<Self> {
94        // Construct a map from the input expressions to the output expression of the projection:
95        let mut map = IndexMap::<_, ProjectionTargets>::new();
96        for (expr_idx, (expr, name)) in expr.into_iter().enumerate() {
97            let target_expr = Arc::new(Column::new(&name, expr_idx)) as _;
98            let source_expr = expr.transform_down(|e| match e.as_any().downcast_ref::<Column>() {
99                Some(col) => {
100                    // Sometimes, an expression and its name in the input_schema
101                    // doesn't match. This can cause problems, so we make sure
102                    // that the expression name matches with the name in `input_schema`.
103                    // Conceptually, `source_expr` and `expression` should be the same.
104                    let idx = col.index();
105                    let matching_field = input_schema.field(idx);
106                    let matching_name = matching_field.name();
107                    if col.name() != matching_name {
108                        return internal_err!(
109                            "Input field name {} does not match with the projection expression {}",
110                            matching_name,
111                            col.name()
112                        );
113                    }
114                    let matching_column = Column::new(matching_name, idx);
115                    Ok(Transformed::yes(Arc::new(matching_column)))
116                }
117                None => Ok(Transformed::no(e)),
118            })
119            .data()?;
120            map.entry(source_expr)
121                .or_default()
122                .push((target_expr, expr_idx));
123        }
124        Ok(Self { map })
125    }
126
127    /// Constructs a subset mapping using the provided indices.
128    ///
129    /// This is used when the output is a subset of the input without any
130    /// other transformations. The indices are for columns in the schema.
131    pub fn from_indices(indices: &[usize], schema: &SchemaRef) -> Result<Self> {
132        let projection_exprs = indices.iter().map(|index| {
133            let field = schema.field(*index);
134            let column = Arc::new(Column::new(field.name(), *index));
135            (column as _, field.name().clone())
136        });
137        ProjectionMapping::try_new(projection_exprs, schema)
138    }
139}
140
141impl Deref for ProjectionMapping {
142    type Target = IndexMap<Arc<dyn PhysicalExpr>, ProjectionTargets>;
143
144    fn deref(&self) -> &Self::Target {
145        &self.map
146    }
147}
148
149impl FromIterator<(Arc<dyn PhysicalExpr>, ProjectionTargets)> for ProjectionMapping {
150    fn from_iter<T: IntoIterator<Item = (Arc<dyn PhysicalExpr>, ProjectionTargets)>>(
151        iter: T,
152    ) -> Self {
153        Self {
154            map: IndexMap::from_iter(iter),
155        }
156    }
157}
158
159#[cfg(test)]
160mod tests {
161    use super::*;
162    use crate::equivalence::tests::output_schema;
163    use crate::equivalence::{convert_to_orderings, EquivalenceProperties};
164    use crate::expressions::{col, BinaryExpr};
165    use crate::utils::tests::TestScalarUDF;
166    use crate::{PhysicalExprRef, ScalarFunctionExpr};
167
168    use arrow::compute::SortOptions;
169    use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
170    use datafusion_common::config::ConfigOptions;
171    use datafusion_expr::{Operator, ScalarUDF};
172
    /// Checks which orderings survive a projection.
    ///
    /// Every test case is a triple of input orderings, projection expressions
    /// (expression plus output column name) and the orderings expected on the
    /// projected output. For each case the input orderings are registered in a
    /// fresh `EquivalenceProperties`, the properties are projected through a
    /// `ProjectionMapping` built from the projection expressions, and the
    /// resulting ordering class must have exactly as many orderings as
    /// expected and contain every expected ordering.
    #[test]
    fn project_orderings() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
            Field::new("c", DataType::Int32, true),
            Field::new("d", DataType::Int32, true),
            Field::new("e", DataType::Int32, true),
            Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true),
        ]));
        let col_a = &col("a", &schema)?;
        let col_b = &col("b", &schema)?;
        let col_c = &col("c", &schema)?;
        let col_d = &col("d", &schema)?;
        let col_e = &col("e", &schema)?;
        let col_ts = &col("ts", &schema)?;
        let a_plus_b = Arc::new(BinaryExpr::new(
            Arc::clone(col_a),
            Operator::Plus,
            Arc::clone(col_b),
        )) as Arc<dyn PhysicalExpr>;
        let b_plus_d = Arc::new(BinaryExpr::new(
            Arc::clone(col_b),
            Operator::Plus,
            Arc::clone(col_d),
        )) as Arc<dyn PhysicalExpr>;
        let b_plus_e = Arc::new(BinaryExpr::new(
            Arc::clone(col_b),
            Operator::Plus,
            Arc::clone(col_e),
        )) as Arc<dyn PhysicalExpr>;
        let c_plus_d = Arc::new(BinaryExpr::new(
            Arc::clone(col_c),
            Operator::Plus,
            Arc::clone(col_d),
        )) as Arc<dyn PhysicalExpr>;

        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };
        let option_desc = SortOptions {
            descending: true,
            nulls_first: true,
        };

        let test_cases = vec![
            // ---------- TEST CASE 1 ------------
            (
                // orderings
                vec![
                    // [b ASC]
                    vec![(col_b, option_asc)],
                ],
                // projection exprs
                vec![(col_b, "b_new".to_string()), (col_a, "a_new".to_string())],
                // expected
                vec![
                    // [b_new ASC]
                    vec![("b_new", option_asc)],
                ],
            ),
            // ---------- TEST CASE 2 ------------
            (
                // orderings
                vec![
                    // empty ordering
                ],
                // projection exprs
                vec![(col_c, "c_new".to_string()), (col_b, "b_new".to_string())],
                // expected
                vec![
                    // no ordering at the output
                ],
            ),
            // ---------- TEST CASE 3 ------------
            (
                // orderings
                vec![
                    // [ts ASC]
                    vec![(col_ts, option_asc)],
                ],
                // projection exprs
                vec![
                    (col_b, "b_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (col_ts, "ts_new".to_string()),
                ],
                // expected
                vec![
                    // [ts_new ASC]
                    vec![("ts_new", option_asc)],
                ],
            ),
            // ---------- TEST CASE 4 ------------
            (
                // orderings
                vec![
                    // [a ASC, ts ASC]
                    vec![(col_a, option_asc), (col_ts, option_asc)],
                    // [b ASC, ts ASC]
                    vec![(col_b, option_asc), (col_ts, option_asc)],
                ],
                // projection exprs
                vec![
                    (col_b, "b_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (col_ts, "ts_new".to_string()),
                ],
                // expected
                vec![
                    // [a_new ASC, ts_new ASC]
                    vec![("a_new", option_asc), ("ts_new", option_asc)],
                    // [b_new ASC, ts_new ASC]
                    vec![("b_new", option_asc), ("ts_new", option_asc)],
                ],
            ),
            // ---------- TEST CASE 5 ------------
            (
                // orderings
                vec![
                    // [a + b ASC]
                    vec![(&a_plus_b, option_asc)],
                ],
                // projection exprs
                vec![
                    (col_b, "b_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (&a_plus_b, "a+b".to_string()),
                ],
                // expected
                vec![
                    // [a + b ASC]
                    vec![("a+b", option_asc)],
                ],
            ),
            // ---------- TEST CASE 6 ------------
            (
                // orderings
                vec![
                    // [a + b ASC, c ASC]
                    vec![(&a_plus_b, option_asc), (col_c, option_asc)],
                ],
                // projection exprs
                vec![
                    (col_b, "b_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (col_c, "c_new".to_string()),
                    (&a_plus_b, "a+b".to_string()),
                ],
                // expected
                vec![
                    // [a + b ASC, c_new ASC]
                    vec![("a+b", option_asc), ("c_new", option_asc)],
                ],
            ),
            // ------- TEST CASE 7 ----------
            (
                // orderings
                vec![
                    // [a ASC, b ASC]
                    vec![(col_a, option_asc), (col_b, option_asc)],
                    // [a ASC, d ASC]
                    vec![(col_a, option_asc), (col_d, option_asc)],
                ],
                // proj exprs: b as b_new, a as a_new, d as d_new, b+d
                vec![
                    (col_b, "b_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (col_d, "d_new".to_string()),
                    (&b_plus_d, "b+d".to_string()),
                ],
                // expected
                vec![
                    // [a_new ASC, b_new ASC]
                    vec![("a_new", option_asc), ("b_new", option_asc)],
                    // [a_new ASC, d_new ASC]
                    vec![("a_new", option_asc), ("d_new", option_asc)],
                    // [a_new ASC, b+d ASC]
                    vec![("a_new", option_asc), ("b+d", option_asc)],
                ],
            ),
            // ------- TEST CASE 8 ----------
            (
                // orderings
                vec![
                    // [b+d ASC]
                    vec![(&b_plus_d, option_asc)],
                ],
                // proj exprs
                vec![
                    (col_b, "b_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (col_d, "d_new".to_string()),
                    (&b_plus_d, "b+d".to_string()),
                ],
                // expected
                vec![
                    // [b+d ASC]
                    vec![("b+d", option_asc)],
                ],
            ),
            // ------- TEST CASE 9 ----------
            (
                // orderings
                vec![
                    // [a ASC, d ASC, b ASC]
                    vec![
                        (col_a, option_asc),
                        (col_d, option_asc),
                        (col_b, option_asc),
                    ],
                    // [c ASC]
                    vec![(col_c, option_asc)],
                ],
                // proj exprs
                vec![
                    (col_b, "b_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (col_d, "d_new".to_string()),
                    (col_c, "c_new".to_string()),
                ],
                // expected
                vec![
                    // [a_new ASC, d_new ASC, b_new ASC]
                    vec![
                        ("a_new", option_asc),
                        ("d_new", option_asc),
                        ("b_new", option_asc),
                    ],
                    // [c_new ASC],
                    vec![("c_new", option_asc)],
                ],
            ),
            // ------- TEST CASE 10 ----------
            (
                // orderings
                vec![
                    // [a ASC, b ASC, c ASC]
                    vec![
                        (col_a, option_asc),
                        (col_b, option_asc),
                        (col_c, option_asc),
                    ],
                    // [a ASC, d ASC]
                    vec![(col_a, option_asc), (col_d, option_asc)],
                ],
                // proj exprs
                vec![
                    (col_b, "b_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (col_c, "c_new".to_string()),
                    (&c_plus_d, "c+d".to_string()),
                ],
                // expected
                vec![
                    // [a_new ASC, b_new ASC, c_new ASC]
                    vec![
                        ("a_new", option_asc),
                        ("b_new", option_asc),
                        ("c_new", option_asc),
                    ],
                    // [a_new ASC, b_new ASC, c+d ASC]
                    vec![
                        ("a_new", option_asc),
                        ("b_new", option_asc),
                        ("c+d", option_asc),
                    ],
                ],
            ),
            // ------- TEST CASE 11 ----------
            (
                // orderings
                vec![
                    // [a ASC, b ASC]
                    vec![(col_a, option_asc), (col_b, option_asc)],
                    // [a ASC, d ASC]
                    vec![(col_a, option_asc), (col_d, option_asc)],
                ],
                // proj exprs
                vec![
                    (col_b, "b_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (&b_plus_d, "b+d".to_string()),
                ],
                // expected
                vec![
                    // [a_new ASC, b_new ASC]
                    vec![("a_new", option_asc), ("b_new", option_asc)],
                    // [a_new ASC, b + d ASC]
                    vec![("a_new", option_asc), ("b+d", option_asc)],
                ],
            ),
            // ------- TEST CASE 12 ----------
            (
                // orderings
                vec![
                    // [a ASC, b ASC, c ASC]
                    vec![
                        (col_a, option_asc),
                        (col_b, option_asc),
                        (col_c, option_asc),
                    ],
                ],
                // proj exprs
                vec![(col_c, "c_new".to_string()), (col_a, "a_new".to_string())],
                // expected
                vec![
                    // [a_new ASC]
                    vec![("a_new", option_asc)],
                ],
            ),
            // ------- TEST CASE 13 ----------
            (
                // orderings
                vec![
                    // [a ASC, b ASC, c ASC]
                    vec![
                        (col_a, option_asc),
                        (col_b, option_asc),
                        (col_c, option_asc),
                    ],
                    // [a ASC, a + b ASC, c ASC]
                    vec![
                        (col_a, option_asc),
                        (&a_plus_b, option_asc),
                        (col_c, option_asc),
                    ],
                ],
                // proj exprs
                vec![
                    (col_c, "c_new".to_string()),
                    (col_b, "b_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (&a_plus_b, "a+b".to_string()),
                ],
                // expected
                vec![
                    // [a_new ASC, b_new ASC, c_new ASC]
                    vec![
                        ("a_new", option_asc),
                        ("b_new", option_asc),
                        ("c_new", option_asc),
                    ],
                    // [a_new ASC, a+b ASC, c_new ASC]
                    vec![
                        ("a_new", option_asc),
                        ("a+b", option_asc),
                        ("c_new", option_asc),
                    ],
                ],
            ),
            // ------- TEST CASE 14 ----------
            (
                // orderings
                vec![
                    // [a ASC, b ASC]
                    vec![(col_a, option_asc), (col_b, option_asc)],
                    // [c ASC, b ASC]
                    vec![(col_c, option_asc), (col_b, option_asc)],
                    // [d ASC, e ASC]
                    vec![(col_d, option_asc), (col_e, option_asc)],
                ],
                // proj exprs
                vec![
                    (col_c, "c_new".to_string()),
                    (col_d, "d_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (&b_plus_e, "b+e".to_string()),
                ],
                // expected
                vec![
                    // [a_new ASC, d_new ASC, b+e ASC]
                    vec![
                        ("a_new", option_asc),
                        ("d_new", option_asc),
                        ("b+e", option_asc),
                    ],
                    // [d_new ASC, a_new ASC, b+e ASC]
                    vec![
                        ("d_new", option_asc),
                        ("a_new", option_asc),
                        ("b+e", option_asc),
                    ],
                    // [c_new ASC, d_new ASC, b+e ASC]
                    vec![
                        ("c_new", option_asc),
                        ("d_new", option_asc),
                        ("b+e", option_asc),
                    ],
                    // [d_new ASC, c_new ASC, b+e ASC]
                    vec![
                        ("d_new", option_asc),
                        ("c_new", option_asc),
                        ("b+e", option_asc),
                    ],
                ],
            ),
            // ------- TEST CASE 15 ----------
            (
                // orderings
                vec![
                    // [a ASC, c ASC, b ASC]
                    vec![
                        (col_a, option_asc),
                        (col_c, option_asc),
                        (col_b, option_asc),
                    ],
                ],
                // proj exprs
                vec![
                    (col_c, "c_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (&a_plus_b, "a+b".to_string()),
                ],
                // expected
                vec![
                    // [a_new ASC, c_new ASC, a+b ASC]
                    vec![
                        ("a_new", option_asc),
                        ("c_new", option_asc),
                        ("a+b", option_asc),
                    ],
                ],
            ),
            // ------- TEST CASE 16 ----------
            (
                // orderings
                vec![
                    // [a ASC, b ASC]
                    vec![(col_a, option_asc), (col_b, option_asc)],
                    // [c ASC, b DESC]
                    vec![(col_c, option_asc), (col_b, option_desc)],
                    // [e ASC]
                    vec![(col_e, option_asc)],
                ],
                // proj exprs
                vec![
                    (col_c, "c_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (col_b, "b_new".to_string()),
                    (&b_plus_e, "b+e".to_string()),
                ],
                // expected
                vec![
                    // [a_new ASC, b_new ASC]
                    vec![("a_new", option_asc), ("b_new", option_asc)],
                    // [a_new ASC, b+e ASC]
                    vec![("a_new", option_asc), ("b+e", option_asc)],
                    // [c_new ASC, b_new DESC]
                    vec![("c_new", option_asc), ("b_new", option_desc)],
                ],
            ),
        ];

        for (idx, (orderings, proj_exprs, expected)) in test_cases.into_iter().enumerate()
        {
            let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema));

            let orderings = convert_to_orderings(&orderings);
            eq_properties.add_orderings(orderings);

            // Build the projection mapping and the schema it produces.
            let proj_exprs = proj_exprs
                .into_iter()
                .map(|(expr, name)| (Arc::clone(expr), name));
            let projection_mapping = ProjectionMapping::try_new(proj_exprs, &schema)?;
            let output_schema = output_schema(&projection_mapping, &schema)?;

            // Resolve the expected column names against the output schema.
            let expected = expected
                .into_iter()
                .map(|ordering| {
                    ordering
                        .into_iter()
                        .map(|(name, options)| {
                            (col(name, &output_schema).unwrap(), options)
                        })
                        .collect::<Vec<_>>()
                })
                .collect::<Vec<_>>();
            let expected = convert_to_orderings(&expected);

            let projected_eq = eq_properties.project(&projection_mapping, output_schema);
            let orderings = projected_eq.oeq_class();

            let err_msg = format!(
                "test_idx: {idx:?}, actual: {orderings:?}, expected: {expected:?}, projection_mapping: {projection_mapping:?}"
            );

            // Same number of orderings, and every expected ordering is present.
            assert_eq!(orderings.len(), expected.len(), "{err_msg}");
            for expected_ordering in &expected {
                assert!(orderings.contains(expected_ordering), "{}", err_msg)
            }
        }

        Ok(())
    }
667
    /// Checks ordering projection through a fixed projection that also
    /// contains a scalar function call.
    ///
    /// The projection is `b as b_new, a as a_new, c as c_new,
    /// test_fun(c) as round_c_res`; note that `ts` is not projected. Every
    /// test case pairs input orderings with the orderings expected after the
    /// projection, expressed directly in terms of the output columns. The
    /// projected ordering class must have exactly as many orderings as
    /// expected and contain every expected ordering.
    #[test]
    fn project_orderings2() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
            Field::new("c", DataType::Int32, true),
            Field::new("d", DataType::Int32, true),
            Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true),
        ]));
        let col_a = &col("a", &schema)?;
        let col_b = &col("b", &schema)?;
        let col_c = &col("c", &schema)?;
        let col_ts = &col("ts", &schema)?;
        let a_plus_b = Arc::new(BinaryExpr::new(
            Arc::clone(col_a),
            Operator::Plus,
            Arc::clone(col_b),
        )) as Arc<dyn PhysicalExpr>;

        let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new()));

        let round_c = Arc::new(ScalarFunctionExpr::try_new(
            test_fun,
            vec![Arc::clone(col_c)],
            &schema,
            Arc::new(ConfigOptions::default()),
        )?) as PhysicalExprRef;

        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };

        let proj_exprs = vec![
            (col_b, "b_new".to_string()),
            (col_a, "a_new".to_string()),
            (col_c, "c_new".to_string()),
            (&round_c, "round_c_res".to_string()),
        ];
        let proj_exprs = proj_exprs
            .into_iter()
            .map(|(expr, name)| (Arc::clone(expr), name));
        let projection_mapping = ProjectionMapping::try_new(proj_exprs, &schema)?;
        let output_schema = output_schema(&projection_mapping, &schema)?;

        // Columns of the projected (output) schema used in the expectations.
        let col_a_new = &col("a_new", &output_schema)?;
        let col_b_new = &col("b_new", &output_schema)?;
        let col_c_new = &col("c_new", &output_schema)?;
        let col_round_c_res = &col("round_c_res", &output_schema)?;
        let a_new_plus_b_new = Arc::new(BinaryExpr::new(
            Arc::clone(col_a_new),
            Operator::Plus,
            Arc::clone(col_b_new),
        )) as Arc<dyn PhysicalExpr>;

        let test_cases = vec![
            // ---------- TEST CASE 1 ------------
            (
                // orderings
                vec![
                    // [a ASC]
                    vec![(col_a, option_asc)],
                ],
                // expected
                vec![
                    // [a_new ASC]
                    vec![(col_a_new, option_asc)],
                ],
            ),
            // ---------- TEST CASE 2 ------------
            (
                // orderings
                vec![
                    // [a+b ASC]
                    vec![(&a_plus_b, option_asc)],
                ],
                // expected
                vec![
                    // [a_new + b_new ASC]
                    vec![(&a_new_plus_b_new, option_asc)],
                ],
            ),
            // ---------- TEST CASE 3 ------------
            (
                // orderings
                vec![
                    // [a ASC, ts ASC]
                    vec![(col_a, option_asc), (col_ts, option_asc)],
                ],
                // expected
                vec![
                    // [a_new ASC] (ts is not part of the projection)
                    vec![(col_a_new, option_asc)],
                ],
            ),
            // ---------- TEST CASE 4 ------------
            (
                // orderings
                vec![
                    // [a ASC, ts ASC, b ASC]
                    vec![
                        (col_a, option_asc),
                        (col_ts, option_asc),
                        (col_b, option_asc),
                    ],
                ],
                // expected
                vec![
                    // [a_new ASC] (ts is not part of the projection)
                    vec![(col_a_new, option_asc)],
                ],
            ),
            // ---------- TEST CASE 5 ------------
            (
                // orderings
                vec![
                    // [a ASC, c ASC]
                    vec![(col_a, option_asc), (col_c, option_asc)],
                ],
                // expected
                vec![
                    // [a_new ASC, round_c_res ASC]
                    vec![(col_a_new, option_asc), (col_round_c_res, option_asc)],
                    // [a_new ASC, c_new ASC]
                    vec![(col_a_new, option_asc), (col_c_new, option_asc)],
                ],
            ),
            // ---------- TEST CASE 6 ------------
            (
                // orderings
                vec![
                    // [c ASC, b ASC]
                    vec![(col_c, option_asc), (col_b, option_asc)],
                ],
                // expected
                vec![
                    // [round_c_res ASC]
                    vec![(col_round_c_res, option_asc)],
                    // [c_new ASC, b_new ASC]
                    vec![(col_c_new, option_asc), (col_b_new, option_asc)],
                ],
            ),
            // ---------- TEST CASE 7 ------------
            (
                // orderings
                vec![
                    // [a+b ASC, c ASC]
                    vec![(&a_plus_b, option_asc), (col_c, option_asc)],
                ],
                // expected
                vec![
                    // [a_new + b_new ASC, round_c_res ASC]
                    vec![
                        (&a_new_plus_b_new, option_asc),
                        (col_round_c_res, option_asc),
                    ],
                    // [a_new + b_new ASC, c_new ASC]
                    vec![(&a_new_plus_b_new, option_asc), (col_c_new, option_asc)],
                ],
            ),
        ];

        for (idx, (orderings, expected)) in test_cases.iter().enumerate() {
            let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema));

            let orderings = convert_to_orderings(orderings);
            eq_properties.add_orderings(orderings);

            let expected = convert_to_orderings(expected);

            let projected_eq =
                eq_properties.project(&projection_mapping, Arc::clone(&output_schema));
            let orderings = projected_eq.oeq_class();

            let err_msg = format!(
                "test idx: {idx:?}, actual: {orderings:?}, expected: {expected:?}, projection_mapping: {projection_mapping:?}"
            );

            // Same number of orderings, and every expected ordering is present.
            assert_eq!(orderings.len(), expected.len(), "{err_msg}");
            for expected_ordering in &expected {
                assert!(orderings.contains(expected_ordering), "{}", err_msg)
            }
        }
        Ok(())
    }
853
/// Verifies how orderings are carried through a projection that mixes plain
/// column aliases with a computed expression.
///
/// The input schema has six nullable `Int32` columns `a`..`f`. The projection
/// keeps `c` as `c_new`, `d` as `d_new`, and exposes `a + b` under the name
/// `"a+b"`. Each test case provides:
///
/// 1. the orderings registered on the input,
/// 2. pairs of expressions registered as equal before projecting,
/// 3. the orderings expected on the projected output.
///
/// A case passes when the projected ordering class has exactly as many
/// orderings as expected and contains every expected ordering.
#[test]
fn project_orderings3() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Int32, true),
        Field::new("c", DataType::Int32, true),
        Field::new("d", DataType::Int32, true),
        Field::new("e", DataType::Int32, true),
        Field::new("f", DataType::Int32, true),
    ]));
    let col_a = &col("a", &schema)?;
    let col_b = &col("b", &schema)?;
    let col_c = &col("c", &schema)?;
    let col_d = &col("d", &schema)?;
    let col_e = &col("e", &schema)?;
    let col_f = &col("f", &schema)?;
    let a_plus_b = Arc::new(BinaryExpr::new(
        Arc::clone(col_a),
        Operator::Plus,
        Arc::clone(col_b),
    )) as Arc<dyn PhysicalExpr>;

    let option_asc = SortOptions {
        descending: false,
        nulls_first: false,
    };

    // Projection: c -> c_new, d -> d_new, (a + b) -> "a+b".
    let proj_exprs = vec![
        (col_c, "c_new".to_string()),
        (col_d, "d_new".to_string()),
        (&a_plus_b, "a+b".to_string()),
    ];
    let proj_exprs = proj_exprs
        .into_iter()
        .map(|(expr, name)| (Arc::clone(expr), name));
    let projection_mapping = ProjectionMapping::try_new(proj_exprs, &schema)?;
    let output_schema = output_schema(&projection_mapping, &schema)?;

    // Columns of the projected schema, used to spell out expected orderings.
    let col_a_plus_b_new = &col("a+b", &output_schema)?;
    let col_c_new = &col("c_new", &output_schema)?;
    let col_d_new = &col("d_new", &output_schema)?;

    let test_cases = vec![
        // ---------- TEST CASE 1 ------------
        (
            // orderings
            vec![
                // [d ASC, b ASC]
                vec![(col_d, option_asc), (col_b, option_asc)],
                // [c ASC, a ASC]
                vec![(col_c, option_asc), (col_a, option_asc)],
            ],
            // equal conditions
            vec![],
            // expected
            vec![
                // [d_new ASC, c_new ASC, a+b ASC]
                vec![
                    (col_d_new, option_asc),
                    (col_c_new, option_asc),
                    (col_a_plus_b_new, option_asc),
                ],
                // [c_new ASC, d_new ASC, a+b ASC]
                vec![
                    (col_c_new, option_asc),
                    (col_d_new, option_asc),
                    (col_a_plus_b_new, option_asc),
                ],
            ],
        ),
        // ---------- TEST CASE 2 ------------
        (
            // orderings
            vec![
                // [d ASC, b ASC]
                vec![(col_d, option_asc), (col_b, option_asc)],
                // [c ASC, e ASC], Please note that a=e
                vec![(col_c, option_asc), (col_e, option_asc)],
            ],
            // equal conditions
            vec![(col_e, col_a)],
            // expected
            vec![
                // [d_new ASC, c_new ASC, a+b ASC]
                vec![
                    (col_d_new, option_asc),
                    (col_c_new, option_asc),
                    (col_a_plus_b_new, option_asc),
                ],
                // [c_new ASC, d_new ASC, a+b ASC]
                vec![
                    (col_c_new, option_asc),
                    (col_d_new, option_asc),
                    (col_a_plus_b_new, option_asc),
                ],
            ],
        ),
        // ---------- TEST CASE 3 ------------
        (
            // orderings
            vec![
                // [d ASC, b ASC]
                vec![(col_d, option_asc), (col_b, option_asc)],
                // [c ASC, e ASC], Please note that a=f
                vec![(col_c, option_asc), (col_e, option_asc)],
            ],
            // equal conditions
            vec![(col_a, col_f)],
            // expected
            vec![
                // [d_new ASC]
                vec![(col_d_new, option_asc)],
                // [c_new ASC]
                vec![(col_c_new, option_asc)],
            ],
        ),
    ];

    // Enumerate the cases so that a failure reports which one broke.
    for (idx, (orderings, equal_columns, expected)) in test_cases.into_iter().enumerate() {
        let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema));
        // Equalities are registered before the orderings are added.
        for (lhs, rhs) in equal_columns {
            eq_properties.add_equal_conditions(Arc::clone(lhs), Arc::clone(rhs))?;
        }

        let orderings = convert_to_orderings(&orderings);
        eq_properties.add_orderings(orderings);

        let expected = convert_to_orderings(&expected);

        let projected_eq =
            eq_properties.project(&projection_mapping, Arc::clone(&output_schema));
        let orderings = projected_eq.oeq_class();

        let err_msg = format!(
            "test idx: {idx:?}, actual: {orderings:?}, expected: {expected:?}, projection_mapping: {projection_mapping:?}"
        );

        // Same count plus containment of every expected ordering; the order in
        // which the orderings are stored is not checked.
        assert_eq!(orderings.len(), expected.len(), "{err_msg}");
        for expected_ordering in &expected {
            assert!(orderings.contains(expected_ordering), "{err_msg}");
        }
    }

    Ok(())
}
998}