// datafusion_physical_expr/equivalence/projection.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::ops::Deref;
19use std::sync::Arc;
20
21use crate::expressions::Column;
22use crate::PhysicalExpr;
23
24use arrow::datatypes::SchemaRef;
25use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
26use datafusion_common::{internal_err, Result};
27
28use indexmap::IndexMap;
29
30/// Stores target expressions, along with their indices, that associate with a
31/// source expression in a projection mapping.
32#[derive(Clone, Debug, Default)]
33pub struct ProjectionTargets {
34    /// A non-empty vector of pairs of target expressions and their indices.
35    /// Consider using a special non-empty collection type in the future (e.g.
36    /// if Rust provides one in the standard library).
37    exprs_indices: Vec<(Arc<dyn PhysicalExpr>, usize)>,
38}
39
40impl ProjectionTargets {
41    /// Returns the first target expression and its index.
42    pub fn first(&self) -> &(Arc<dyn PhysicalExpr>, usize) {
43        // Since the vector is non-empty, we can safely unwrap:
44        self.exprs_indices.first().unwrap()
45    }
46
47    /// Adds a target expression and its index to the list of targets.
48    pub fn push(&mut self, target: (Arc<dyn PhysicalExpr>, usize)) {
49        self.exprs_indices.push(target);
50    }
51}
52
53impl Deref for ProjectionTargets {
54    type Target = [(Arc<dyn PhysicalExpr>, usize)];
55
56    fn deref(&self) -> &Self::Target {
57        &self.exprs_indices
58    }
59}
60
61impl From<Vec<(Arc<dyn PhysicalExpr>, usize)>> for ProjectionTargets {
62    fn from(exprs_indices: Vec<(Arc<dyn PhysicalExpr>, usize)>) -> Self {
63        Self { exprs_indices }
64    }
65}
66
67/// Stores the mapping between source expressions and target expressions for a
68/// projection.
69#[derive(Clone, Debug)]
70pub struct ProjectionMapping {
71    /// Mapping between source expressions and target expressions.
72    /// Vector indices correspond to the indices after projection.
73    map: IndexMap<Arc<dyn PhysicalExpr>, ProjectionTargets>,
74}
75
76impl ProjectionMapping {
77    /// Constructs the mapping between a projection's input and output
78    /// expressions.
79    ///
80    /// For example, given the input projection expressions (`a + b`, `c + d`)
81    /// and an output schema with two columns `"c + d"` and `"a + b"`, the
82    /// projection mapping would be:
83    ///
84    /// ```text
85    ///  [0]: (c + d, [(col("c + d"), 0)])
86    ///  [1]: (a + b, [(col("a + b"), 1)])
87    /// ```
88    ///
89    /// where `col("c + d")` means the column named `"c + d"`.
90    pub fn try_new(
91        expr: impl IntoIterator<Item = (Arc<dyn PhysicalExpr>, String)>,
92        input_schema: &SchemaRef,
93    ) -> Result<Self> {
94        // Construct a map from the input expressions to the output expression of the projection:
95        let mut map = IndexMap::<_, ProjectionTargets>::new();
96        for (expr_idx, (expr, name)) in expr.into_iter().enumerate() {
97            let target_expr = Arc::new(Column::new(&name, expr_idx)) as _;
98            let source_expr = expr.transform_down(|e| match e.as_any().downcast_ref::<Column>() {
99                Some(col) => {
100                    // Sometimes, an expression and its name in the input_schema
101                    // doesn't match. This can cause problems, so we make sure
102                    // that the expression name matches with the name in `input_schema`.
103                    // Conceptually, `source_expr` and `expression` should be the same.
104                    let idx = col.index();
105                    let matching_field = input_schema.field(idx);
106                    let matching_name = matching_field.name();
107                    if col.name() != matching_name {
108                        return internal_err!(
109                            "Input field name {} does not match with the projection expression {}",
110                            matching_name,
111                            col.name()
112                        );
113                    }
114                    let matching_column = Column::new(matching_name, idx);
115                    Ok(Transformed::yes(Arc::new(matching_column)))
116                }
117                None => Ok(Transformed::no(e)),
118            })
119            .data()?;
120            map.entry(source_expr)
121                .or_default()
122                .push((target_expr, expr_idx));
123        }
124        Ok(Self { map })
125    }
126
127    /// Constructs a subset mapping using the provided indices.
128    ///
129    /// This is used when the output is a subset of the input without any
130    /// other transformations. The indices are for columns in the schema.
131    pub fn from_indices(indices: &[usize], schema: &SchemaRef) -> Result<Self> {
132        let projection_exprs = indices.iter().map(|index| {
133            let field = schema.field(*index);
134            let column = Arc::new(Column::new(field.name(), *index));
135            (column as _, field.name().clone())
136        });
137        ProjectionMapping::try_new(projection_exprs, schema)
138    }
139}
140
141impl Deref for ProjectionMapping {
142    type Target = IndexMap<Arc<dyn PhysicalExpr>, ProjectionTargets>;
143
144    fn deref(&self) -> &Self::Target {
145        &self.map
146    }
147}
148
149impl FromIterator<(Arc<dyn PhysicalExpr>, ProjectionTargets)> for ProjectionMapping {
150    fn from_iter<T: IntoIterator<Item = (Arc<dyn PhysicalExpr>, ProjectionTargets)>>(
151        iter: T,
152    ) -> Self {
153        Self {
154            map: IndexMap::from_iter(iter),
155        }
156    }
157}
158
159#[cfg(test)]
160mod tests {
161    use super::*;
162    use crate::equivalence::tests::output_schema;
163    use crate::equivalence::{convert_to_orderings, EquivalenceProperties};
164    use crate::expressions::{col, BinaryExpr};
165    use crate::utils::tests::TestScalarUDF;
166    use crate::{PhysicalExprRef, ScalarFunctionExpr};
167
168    use arrow::compute::SortOptions;
169    use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
170    use datafusion_expr::{Operator, ScalarUDF};
171
172    #[test]
173    fn project_orderings() -> Result<()> {
174        let schema = Arc::new(Schema::new(vec![
175            Field::new("a", DataType::Int32, true),
176            Field::new("b", DataType::Int32, true),
177            Field::new("c", DataType::Int32, true),
178            Field::new("d", DataType::Int32, true),
179            Field::new("e", DataType::Int32, true),
180            Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true),
181        ]));
182        let col_a = &col("a", &schema)?;
183        let col_b = &col("b", &schema)?;
184        let col_c = &col("c", &schema)?;
185        let col_d = &col("d", &schema)?;
186        let col_e = &col("e", &schema)?;
187        let col_ts = &col("ts", &schema)?;
188        let a_plus_b = Arc::new(BinaryExpr::new(
189            Arc::clone(col_a),
190            Operator::Plus,
191            Arc::clone(col_b),
192        )) as Arc<dyn PhysicalExpr>;
193        let b_plus_d = Arc::new(BinaryExpr::new(
194            Arc::clone(col_b),
195            Operator::Plus,
196            Arc::clone(col_d),
197        )) as Arc<dyn PhysicalExpr>;
198        let b_plus_e = Arc::new(BinaryExpr::new(
199            Arc::clone(col_b),
200            Operator::Plus,
201            Arc::clone(col_e),
202        )) as Arc<dyn PhysicalExpr>;
203        let c_plus_d = Arc::new(BinaryExpr::new(
204            Arc::clone(col_c),
205            Operator::Plus,
206            Arc::clone(col_d),
207        )) as Arc<dyn PhysicalExpr>;
208
209        let option_asc = SortOptions {
210            descending: false,
211            nulls_first: false,
212        };
213        let option_desc = SortOptions {
214            descending: true,
215            nulls_first: true,
216        };
217
218        let test_cases = vec![
219            // ---------- TEST CASE 1 ------------
220            (
221                // orderings
222                vec![
223                    // [b ASC]
224                    vec![(col_b, option_asc)],
225                ],
226                // projection exprs
227                vec![(col_b, "b_new".to_string()), (col_a, "a_new".to_string())],
228                // expected
229                vec![
230                    // [b_new ASC]
231                    vec![("b_new", option_asc)],
232                ],
233            ),
234            // ---------- TEST CASE 2 ------------
235            (
236                // orderings
237                vec![
238                    // empty ordering
239                ],
240                // projection exprs
241                vec![(col_c, "c_new".to_string()), (col_b, "b_new".to_string())],
242                // expected
243                vec![
244                    // no ordering at the output
245                ],
246            ),
247            // ---------- TEST CASE 3 ------------
248            (
249                // orderings
250                vec![
251                    // [ts ASC]
252                    vec![(col_ts, option_asc)],
253                ],
254                // projection exprs
255                vec![
256                    (col_b, "b_new".to_string()),
257                    (col_a, "a_new".to_string()),
258                    (col_ts, "ts_new".to_string()),
259                ],
260                // expected
261                vec![
262                    // [ts_new ASC]
263                    vec![("ts_new", option_asc)],
264                ],
265            ),
266            // ---------- TEST CASE 4 ------------
267            (
268                // orderings
269                vec![
270                    // [a ASC, ts ASC]
271                    vec![(col_a, option_asc), (col_ts, option_asc)],
272                    // [b ASC, ts ASC]
273                    vec![(col_b, option_asc), (col_ts, option_asc)],
274                ],
275                // projection exprs
276                vec![
277                    (col_b, "b_new".to_string()),
278                    (col_a, "a_new".to_string()),
279                    (col_ts, "ts_new".to_string()),
280                ],
281                // expected
282                vec![
283                    // [a_new ASC, ts_new ASC]
284                    vec![("a_new", option_asc), ("ts_new", option_asc)],
285                    // [b_new ASC, ts_new ASC]
286                    vec![("b_new", option_asc), ("ts_new", option_asc)],
287                ],
288            ),
289            // ---------- TEST CASE 5 ------------
290            (
291                // orderings
292                vec![
293                    // [a + b ASC]
294                    vec![(&a_plus_b, option_asc)],
295                ],
296                // projection exprs
297                vec![
298                    (col_b, "b_new".to_string()),
299                    (col_a, "a_new".to_string()),
300                    (&a_plus_b, "a+b".to_string()),
301                ],
302                // expected
303                vec![
304                    // [a + b ASC]
305                    vec![("a+b", option_asc)],
306                ],
307            ),
308            // ---------- TEST CASE 6 ------------
309            (
310                // orderings
311                vec![
312                    // [a + b ASC, c ASC]
313                    vec![(&a_plus_b, option_asc), (col_c, option_asc)],
314                ],
315                // projection exprs
316                vec![
317                    (col_b, "b_new".to_string()),
318                    (col_a, "a_new".to_string()),
319                    (col_c, "c_new".to_string()),
320                    (&a_plus_b, "a+b".to_string()),
321                ],
322                // expected
323                vec![
324                    // [a + b ASC, c_new ASC]
325                    vec![("a+b", option_asc), ("c_new", option_asc)],
326                ],
327            ),
328            // ------- TEST CASE 7 ----------
329            (
330                vec![
331                    // [a ASC, b ASC, c ASC]
332                    vec![(col_a, option_asc), (col_b, option_asc)],
333                    // [a ASC, d ASC]
334                    vec![(col_a, option_asc), (col_d, option_asc)],
335                ],
336                // b as b_new, a as a_new, d as d_new b+d
337                vec![
338                    (col_b, "b_new".to_string()),
339                    (col_a, "a_new".to_string()),
340                    (col_d, "d_new".to_string()),
341                    (&b_plus_d, "b+d".to_string()),
342                ],
343                // expected
344                vec![
345                    // [a_new ASC, b_new ASC]
346                    vec![("a_new", option_asc), ("b_new", option_asc)],
347                    // [a_new ASC, d_new ASC]
348                    vec![("a_new", option_asc), ("d_new", option_asc)],
349                    // [a_new ASC, b+d ASC]
350                    vec![("a_new", option_asc), ("b+d", option_asc)],
351                ],
352            ),
353            // ------- TEST CASE 8 ----------
354            (
355                // orderings
356                vec![
357                    // [b+d ASC]
358                    vec![(&b_plus_d, option_asc)],
359                ],
360                // proj exprs
361                vec![
362                    (col_b, "b_new".to_string()),
363                    (col_a, "a_new".to_string()),
364                    (col_d, "d_new".to_string()),
365                    (&b_plus_d, "b+d".to_string()),
366                ],
367                // expected
368                vec![
369                    // [b+d ASC]
370                    vec![("b+d", option_asc)],
371                ],
372            ),
373            // ------- TEST CASE 9 ----------
374            (
375                // orderings
376                vec![
377                    // [a ASC, d ASC, b ASC]
378                    vec![
379                        (col_a, option_asc),
380                        (col_d, option_asc),
381                        (col_b, option_asc),
382                    ],
383                    // [c ASC]
384                    vec![(col_c, option_asc)],
385                ],
386                // proj exprs
387                vec![
388                    (col_b, "b_new".to_string()),
389                    (col_a, "a_new".to_string()),
390                    (col_d, "d_new".to_string()),
391                    (col_c, "c_new".to_string()),
392                ],
393                // expected
394                vec![
395                    // [a_new ASC, d_new ASC, b_new ASC]
396                    vec![
397                        ("a_new", option_asc),
398                        ("d_new", option_asc),
399                        ("b_new", option_asc),
400                    ],
401                    // [c_new ASC],
402                    vec![("c_new", option_asc)],
403                ],
404            ),
405            // ------- TEST CASE 10 ----------
406            (
407                vec![
408                    // [a ASC, b ASC, c ASC]
409                    vec![
410                        (col_a, option_asc),
411                        (col_b, option_asc),
412                        (col_c, option_asc),
413                    ],
414                    // [a ASC, d ASC]
415                    vec![(col_a, option_asc), (col_d, option_asc)],
416                ],
417                // proj exprs
418                vec![
419                    (col_b, "b_new".to_string()),
420                    (col_a, "a_new".to_string()),
421                    (col_c, "c_new".to_string()),
422                    (&c_plus_d, "c+d".to_string()),
423                ],
424                // expected
425                vec![
426                    // [a_new ASC, b_new ASC, c_new ASC]
427                    vec![
428                        ("a_new", option_asc),
429                        ("b_new", option_asc),
430                        ("c_new", option_asc),
431                    ],
432                    // [a_new ASC, b_new ASC, c+d ASC]
433                    vec![
434                        ("a_new", option_asc),
435                        ("b_new", option_asc),
436                        ("c+d", option_asc),
437                    ],
438                ],
439            ),
440            // ------- TEST CASE 11 ----------
441            (
442                // orderings
443                vec![
444                    // [a ASC, b ASC]
445                    vec![(col_a, option_asc), (col_b, option_asc)],
446                    // [a ASC, d ASC]
447                    vec![(col_a, option_asc), (col_d, option_asc)],
448                ],
449                // proj exprs
450                vec![
451                    (col_b, "b_new".to_string()),
452                    (col_a, "a_new".to_string()),
453                    (&b_plus_d, "b+d".to_string()),
454                ],
455                // expected
456                vec![
457                    // [a_new ASC, b_new ASC]
458                    vec![("a_new", option_asc), ("b_new", option_asc)],
459                    // [a_new ASC, b + d ASC]
460                    vec![("a_new", option_asc), ("b+d", option_asc)],
461                ],
462            ),
463            // ------- TEST CASE 12 ----------
464            (
465                // orderings
466                vec![
467                    // [a ASC, b ASC, c ASC]
468                    vec![
469                        (col_a, option_asc),
470                        (col_b, option_asc),
471                        (col_c, option_asc),
472                    ],
473                ],
474                // proj exprs
475                vec![(col_c, "c_new".to_string()), (col_a, "a_new".to_string())],
476                // expected
477                vec![
478                    // [a_new ASC]
479                    vec![("a_new", option_asc)],
480                ],
481            ),
482            // ------- TEST CASE 13 ----------
483            (
484                // orderings
485                vec![
486                    // [a ASC, b ASC, c ASC]
487                    vec![
488                        (col_a, option_asc),
489                        (col_b, option_asc),
490                        (col_c, option_asc),
491                    ],
492                    // [a ASC, a + b ASC, c ASC]
493                    vec![
494                        (col_a, option_asc),
495                        (&a_plus_b, option_asc),
496                        (col_c, option_asc),
497                    ],
498                ],
499                // proj exprs
500                vec![
501                    (col_c, "c_new".to_string()),
502                    (col_b, "b_new".to_string()),
503                    (col_a, "a_new".to_string()),
504                    (&a_plus_b, "a+b".to_string()),
505                ],
506                // expected
507                vec![
508                    // [a_new ASC, b_new ASC, c_new ASC]
509                    vec![
510                        ("a_new", option_asc),
511                        ("b_new", option_asc),
512                        ("c_new", option_asc),
513                    ],
514                    // [a_new ASC, a+b ASC, c_new ASC]
515                    vec![
516                        ("a_new", option_asc),
517                        ("a+b", option_asc),
518                        ("c_new", option_asc),
519                    ],
520                ],
521            ),
522            // ------- TEST CASE 14 ----------
523            (
524                // orderings
525                vec![
526                    // [a ASC, b ASC]
527                    vec![(col_a, option_asc), (col_b, option_asc)],
528                    // [c ASC, b ASC]
529                    vec![(col_c, option_asc), (col_b, option_asc)],
530                    // [d ASC, e ASC]
531                    vec![(col_d, option_asc), (col_e, option_asc)],
532                ],
533                // proj exprs
534                vec![
535                    (col_c, "c_new".to_string()),
536                    (col_d, "d_new".to_string()),
537                    (col_a, "a_new".to_string()),
538                    (&b_plus_e, "b+e".to_string()),
539                ],
540                // expected
541                vec![
542                    // [a_new ASC, d_new ASC, b+e ASC]
543                    vec![
544                        ("a_new", option_asc),
545                        ("d_new", option_asc),
546                        ("b+e", option_asc),
547                    ],
548                    // [d_new ASC, a_new ASC, b+e ASC]
549                    vec![
550                        ("d_new", option_asc),
551                        ("a_new", option_asc),
552                        ("b+e", option_asc),
553                    ],
554                    // [c_new ASC, d_new ASC, b+e ASC]
555                    vec![
556                        ("c_new", option_asc),
557                        ("d_new", option_asc),
558                        ("b+e", option_asc),
559                    ],
560                    // [d_new ASC, c_new ASC, b+e ASC]
561                    vec![
562                        ("d_new", option_asc),
563                        ("c_new", option_asc),
564                        ("b+e", option_asc),
565                    ],
566                ],
567            ),
568            // ------- TEST CASE 15 ----------
569            (
570                // orderings
571                vec![
572                    // [a ASC, c ASC, b ASC]
573                    vec![
574                        (col_a, option_asc),
575                        (col_c, option_asc),
576                        (col_b, option_asc),
577                    ],
578                ],
579                // proj exprs
580                vec![
581                    (col_c, "c_new".to_string()),
582                    (col_a, "a_new".to_string()),
583                    (&a_plus_b, "a+b".to_string()),
584                ],
585                // expected
586                vec![
587                    // [a_new ASC, d_new ASC, b+e ASC]
588                    vec![
589                        ("a_new", option_asc),
590                        ("c_new", option_asc),
591                        ("a+b", option_asc),
592                    ],
593                ],
594            ),
595            // ------- TEST CASE 16 ----------
596            (
597                // orderings
598                vec![
599                    // [a ASC, b ASC]
600                    vec![(col_a, option_asc), (col_b, option_asc)],
601                    // [c ASC, b DESC]
602                    vec![(col_c, option_asc), (col_b, option_desc)],
603                    // [e ASC]
604                    vec![(col_e, option_asc)],
605                ],
606                // proj exprs
607                vec![
608                    (col_c, "c_new".to_string()),
609                    (col_a, "a_new".to_string()),
610                    (col_b, "b_new".to_string()),
611                    (&b_plus_e, "b+e".to_string()),
612                ],
613                // expected
614                vec![
615                    // [a_new ASC, b_new ASC]
616                    vec![("a_new", option_asc), ("b_new", option_asc)],
617                    // [a_new ASC, b_new ASC]
618                    vec![("a_new", option_asc), ("b+e", option_asc)],
619                    // [c_new ASC, b_new DESC]
620                    vec![("c_new", option_asc), ("b_new", option_desc)],
621                ],
622            ),
623        ];
624
625        for (idx, (orderings, proj_exprs, expected)) in test_cases.into_iter().enumerate()
626        {
627            let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema));
628
629            let orderings = convert_to_orderings(&orderings);
630            eq_properties.add_orderings(orderings);
631
632            let proj_exprs = proj_exprs
633                .into_iter()
634                .map(|(expr, name)| (Arc::clone(expr), name));
635            let projection_mapping = ProjectionMapping::try_new(proj_exprs, &schema)?;
636            let output_schema = output_schema(&projection_mapping, &schema)?;
637
638            let expected = expected
639                .into_iter()
640                .map(|ordering| {
641                    ordering
642                        .into_iter()
643                        .map(|(name, options)| {
644                            (col(name, &output_schema).unwrap(), options)
645                        })
646                        .collect::<Vec<_>>()
647                })
648                .collect::<Vec<_>>();
649            let expected = convert_to_orderings(&expected);
650
651            let projected_eq = eq_properties.project(&projection_mapping, output_schema);
652            let orderings = projected_eq.oeq_class();
653
654            let err_msg = format!(
655                "test_idx: {idx:?}, actual: {orderings:?}, expected: {expected:?}, projection_mapping: {projection_mapping:?}"
656            );
657
658            assert_eq!(orderings.len(), expected.len(), "{err_msg}");
659            for expected_ordering in &expected {
660                assert!(orderings.contains(expected_ordering), "{}", err_msg)
661            }
662        }
663
664        Ok(())
665    }
666
667    #[test]
668    fn project_orderings2() -> Result<()> {
669        let schema = Arc::new(Schema::new(vec![
670            Field::new("a", DataType::Int32, true),
671            Field::new("b", DataType::Int32, true),
672            Field::new("c", DataType::Int32, true),
673            Field::new("d", DataType::Int32, true),
674            Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true),
675        ]));
676        let col_a = &col("a", &schema)?;
677        let col_b = &col("b", &schema)?;
678        let col_c = &col("c", &schema)?;
679        let col_ts = &col("ts", &schema)?;
680        let a_plus_b = Arc::new(BinaryExpr::new(
681            Arc::clone(col_a),
682            Operator::Plus,
683            Arc::clone(col_b),
684        )) as Arc<dyn PhysicalExpr>;
685
686        let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new()));
687
688        let round_c = Arc::new(ScalarFunctionExpr::try_new(
689            test_fun,
690            vec![Arc::clone(col_c)],
691            &schema,
692        )?) as PhysicalExprRef;
693
694        let option_asc = SortOptions {
695            descending: false,
696            nulls_first: false,
697        };
698
699        let proj_exprs = vec![
700            (col_b, "b_new".to_string()),
701            (col_a, "a_new".to_string()),
702            (col_c, "c_new".to_string()),
703            (&round_c, "round_c_res".to_string()),
704        ];
705        let proj_exprs = proj_exprs
706            .into_iter()
707            .map(|(expr, name)| (Arc::clone(expr), name));
708        let projection_mapping = ProjectionMapping::try_new(proj_exprs, &schema)?;
709        let output_schema = output_schema(&projection_mapping, &schema)?;
710
711        let col_a_new = &col("a_new", &output_schema)?;
712        let col_b_new = &col("b_new", &output_schema)?;
713        let col_c_new = &col("c_new", &output_schema)?;
714        let col_round_c_res = &col("round_c_res", &output_schema)?;
715        let a_new_plus_b_new = Arc::new(BinaryExpr::new(
716            Arc::clone(col_a_new),
717            Operator::Plus,
718            Arc::clone(col_b_new),
719        )) as Arc<dyn PhysicalExpr>;
720
721        let test_cases = vec![
722            // ---------- TEST CASE 1 ------------
723            (
724                // orderings
725                vec![
726                    // [a ASC]
727                    vec![(col_a, option_asc)],
728                ],
729                // expected
730                vec![
731                    // [b_new ASC]
732                    vec![(col_a_new, option_asc)],
733                ],
734            ),
735            // ---------- TEST CASE 2 ------------
736            (
737                // orderings
738                vec![
739                    // [a+b ASC]
740                    vec![(&a_plus_b, option_asc)],
741                ],
742                // expected
743                vec![
744                    // [b_new ASC]
745                    vec![(&a_new_plus_b_new, option_asc)],
746                ],
747            ),
748            // ---------- TEST CASE 3 ------------
749            (
750                // orderings
751                vec![
752                    // [a ASC, ts ASC]
753                    vec![(col_a, option_asc), (col_ts, option_asc)],
754                ],
755                // expected
756                vec![
757                    // [a_new ASC, date_bin_res ASC]
758                    vec![(col_a_new, option_asc)],
759                ],
760            ),
761            // ---------- TEST CASE 4 ------------
762            (
763                // orderings
764                vec![
765                    // [a ASC, ts ASC, b ASC]
766                    vec![
767                        (col_a, option_asc),
768                        (col_ts, option_asc),
769                        (col_b, option_asc),
770                    ],
771                ],
772                // expected
773                vec![
774                    // [a_new ASC, date_bin_res ASC]
775                    vec![(col_a_new, option_asc)],
776                ],
777            ),
778            // ---------- TEST CASE 5 ------------
779            (
780                // orderings
781                vec![
782                    // [a ASC, c ASC]
783                    vec![(col_a, option_asc), (col_c, option_asc)],
784                ],
785                // expected
786                vec![
787                    // [a_new ASC, round_c_res ASC, c_new ASC]
788                    vec![(col_a_new, option_asc), (col_round_c_res, option_asc)],
789                    // [a_new ASC, c_new ASC]
790                    vec![(col_a_new, option_asc), (col_c_new, option_asc)],
791                ],
792            ),
793            // ---------- TEST CASE 6 ------------
794            (
795                // orderings
796                vec![
797                    // [c ASC, b ASC]
798                    vec![(col_c, option_asc), (col_b, option_asc)],
799                ],
800                // expected
801                vec![
802                    // [round_c_res ASC]
803                    vec![(col_round_c_res, option_asc)],
804                    // [c_new ASC, b_new ASC]
805                    vec![(col_c_new, option_asc), (col_b_new, option_asc)],
806                ],
807            ),
808            // ---------- TEST CASE 7 ------------
809            (
810                // orderings
811                vec![
812                    // [a+b ASC, c ASC]
813                    vec![(&a_plus_b, option_asc), (col_c, option_asc)],
814                ],
815                // expected
816                vec![
817                    // [a+b ASC, round(c) ASC, c_new ASC]
818                    vec![
819                        (&a_new_plus_b_new, option_asc),
820                        (col_round_c_res, option_asc),
821                    ],
822                    // [a+b ASC, c_new ASC]
823                    vec![(&a_new_plus_b_new, option_asc), (col_c_new, option_asc)],
824                ],
825            ),
826        ];
827
828        for (idx, (orderings, expected)) in test_cases.iter().enumerate() {
829            let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema));
830
831            let orderings = convert_to_orderings(orderings);
832            eq_properties.add_orderings(orderings);
833
834            let expected = convert_to_orderings(expected);
835
836            let projected_eq =
837                eq_properties.project(&projection_mapping, Arc::clone(&output_schema));
838            let orderings = projected_eq.oeq_class();
839
840            let err_msg = format!(
841                "test idx: {idx:?}, actual: {orderings:?}, expected: {expected:?}, projection_mapping: {projection_mapping:?}"
842            );
843
844            assert_eq!(orderings.len(), expected.len(), "{err_msg}");
845            for expected_ordering in &expected {
846                assert!(orderings.contains(expected_ordering), "{}", err_msg)
847            }
848        }
849        Ok(())
850    }
851
852    #[test]
853    fn project_orderings3() -> Result<()> {
854        let schema = Arc::new(Schema::new(vec![
855            Field::new("a", DataType::Int32, true),
856            Field::new("b", DataType::Int32, true),
857            Field::new("c", DataType::Int32, true),
858            Field::new("d", DataType::Int32, true),
859            Field::new("e", DataType::Int32, true),
860            Field::new("f", DataType::Int32, true),
861        ]));
862        let col_a = &col("a", &schema)?;
863        let col_b = &col("b", &schema)?;
864        let col_c = &col("c", &schema)?;
865        let col_d = &col("d", &schema)?;
866        let col_e = &col("e", &schema)?;
867        let col_f = &col("f", &schema)?;
868        let a_plus_b = Arc::new(BinaryExpr::new(
869            Arc::clone(col_a),
870            Operator::Plus,
871            Arc::clone(col_b),
872        )) as Arc<dyn PhysicalExpr>;
873
874        let option_asc = SortOptions {
875            descending: false,
876            nulls_first: false,
877        };
878
879        let proj_exprs = vec![
880            (col_c, "c_new".to_string()),
881            (col_d, "d_new".to_string()),
882            (&a_plus_b, "a+b".to_string()),
883        ];
884        let proj_exprs = proj_exprs
885            .into_iter()
886            .map(|(expr, name)| (Arc::clone(expr), name));
887        let projection_mapping = ProjectionMapping::try_new(proj_exprs, &schema)?;
888        let output_schema = output_schema(&projection_mapping, &schema)?;
889
890        let col_a_plus_b_new = &col("a+b", &output_schema)?;
891        let col_c_new = &col("c_new", &output_schema)?;
892        let col_d_new = &col("d_new", &output_schema)?;
893
894        let test_cases = vec![
895            // ---------- TEST CASE 1 ------------
896            (
897                // orderings
898                vec![
899                    // [d ASC, b ASC]
900                    vec![(col_d, option_asc), (col_b, option_asc)],
901                    // [c ASC, a ASC]
902                    vec![(col_c, option_asc), (col_a, option_asc)],
903                ],
904                // equal conditions
905                vec![],
906                // expected
907                vec![
908                    // [d_new ASC, c_new ASC, a+b ASC]
909                    vec![
910                        (col_d_new, option_asc),
911                        (col_c_new, option_asc),
912                        (col_a_plus_b_new, option_asc),
913                    ],
914                    // [c_new ASC, d_new ASC, a+b ASC]
915                    vec![
916                        (col_c_new, option_asc),
917                        (col_d_new, option_asc),
918                        (col_a_plus_b_new, option_asc),
919                    ],
920                ],
921            ),
922            // ---------- TEST CASE 2 ------------
923            (
924                // orderings
925                vec![
926                    // [d ASC, b ASC]
927                    vec![(col_d, option_asc), (col_b, option_asc)],
928                    // [c ASC, e ASC], Please note that a=e
929                    vec![(col_c, option_asc), (col_e, option_asc)],
930                ],
931                // equal conditions
932                vec![(col_e, col_a)],
933                // expected
934                vec![
935                    // [d_new ASC, c_new ASC, a+b ASC]
936                    vec![
937                        (col_d_new, option_asc),
938                        (col_c_new, option_asc),
939                        (col_a_plus_b_new, option_asc),
940                    ],
941                    // [c_new ASC, d_new ASC, a+b ASC]
942                    vec![
943                        (col_c_new, option_asc),
944                        (col_d_new, option_asc),
945                        (col_a_plus_b_new, option_asc),
946                    ],
947                ],
948            ),
949            // ---------- TEST CASE 3 ------------
950            (
951                // orderings
952                vec![
953                    // [d ASC, b ASC]
954                    vec![(col_d, option_asc), (col_b, option_asc)],
955                    // [c ASC, e ASC], Please note that a=f
956                    vec![(col_c, option_asc), (col_e, option_asc)],
957                ],
958                // equal conditions
959                vec![(col_a, col_f)],
960                // expected
961                vec![
962                    // [d_new ASC]
963                    vec![(col_d_new, option_asc)],
964                    // [c_new ASC]
965                    vec![(col_c_new, option_asc)],
966                ],
967            ),
968        ];
969        for (orderings, equal_columns, expected) in test_cases {
970            let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema));
971            for (lhs, rhs) in equal_columns {
972                eq_properties.add_equal_conditions(Arc::clone(lhs), Arc::clone(rhs))?;
973            }
974
975            let orderings = convert_to_orderings(&orderings);
976            eq_properties.add_orderings(orderings);
977
978            let expected = convert_to_orderings(&expected);
979
980            let projected_eq =
981                eq_properties.project(&projection_mapping, Arc::clone(&output_schema));
982            let orderings = projected_eq.oeq_class();
983
984            let err_msg = format!(
985                "actual: {orderings:?}, expected: {expected:?}, projection_mapping: {projection_mapping:?}"
986            );
987
988            assert_eq!(orderings.len(), expected.len(), "{err_msg}");
989            for expected_ordering in &expected {
990                assert!(orderings.contains(expected_ordering), "{}", err_msg)
991            }
992        }
993
994        Ok(())
995    }
996}