datafusion_physical_expr/equivalence/projection.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::ops::Deref;
19use std::sync::Arc;
20
21use crate::expressions::Column;
22use crate::PhysicalExpr;
23
24use arrow::datatypes::SchemaRef;
25use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
26use datafusion_common::{internal_err, Result};
27
28use indexmap::IndexMap;
29
30/// Stores target expressions, along with their indices, that associate with a
31/// source expression in a projection mapping.
32#[derive(Clone, Debug, Default)]
33pub struct ProjectionTargets {
34 /// A non-empty vector of pairs of target expressions and their indices.
35 /// Consider using a special non-empty collection type in the future (e.g.
36 /// if Rust provides one in the standard library).
37 exprs_indices: Vec<(Arc<dyn PhysicalExpr>, usize)>,
38}
39
40impl ProjectionTargets {
41 /// Returns the first target expression and its index.
42 pub fn first(&self) -> &(Arc<dyn PhysicalExpr>, usize) {
43 // Since the vector is non-empty, we can safely unwrap:
44 self.exprs_indices.first().unwrap()
45 }
46
47 /// Adds a target expression and its index to the list of targets.
48 pub fn push(&mut self, target: (Arc<dyn PhysicalExpr>, usize)) {
49 self.exprs_indices.push(target);
50 }
51}
52
53impl Deref for ProjectionTargets {
54 type Target = [(Arc<dyn PhysicalExpr>, usize)];
55
56 fn deref(&self) -> &Self::Target {
57 &self.exprs_indices
58 }
59}
60
61impl From<Vec<(Arc<dyn PhysicalExpr>, usize)>> for ProjectionTargets {
62 fn from(exprs_indices: Vec<(Arc<dyn PhysicalExpr>, usize)>) -> Self {
63 Self { exprs_indices }
64 }
65}
66
67/// Stores the mapping between source expressions and target expressions for a
68/// projection.
69#[derive(Clone, Debug)]
70pub struct ProjectionMapping {
71 /// Mapping between source expressions and target expressions.
72 /// Vector indices correspond to the indices after projection.
73 map: IndexMap<Arc<dyn PhysicalExpr>, ProjectionTargets>,
74}
75
76impl ProjectionMapping {
77 /// Constructs the mapping between a projection's input and output
78 /// expressions.
79 ///
80 /// For example, given the input projection expressions (`a + b`, `c + d`)
81 /// and an output schema with two columns `"c + d"` and `"a + b"`, the
82 /// projection mapping would be:
83 ///
84 /// ```text
85 /// [0]: (c + d, [(col("c + d"), 0)])
86 /// [1]: (a + b, [(col("a + b"), 1)])
87 /// ```
88 ///
89 /// where `col("c + d")` means the column named `"c + d"`.
90 pub fn try_new(
91 expr: impl IntoIterator<Item = (Arc<dyn PhysicalExpr>, String)>,
92 input_schema: &SchemaRef,
93 ) -> Result<Self> {
94 // Construct a map from the input expressions to the output expression of the projection:
95 let mut map = IndexMap::<_, ProjectionTargets>::new();
96 for (expr_idx, (expr, name)) in expr.into_iter().enumerate() {
97 let target_expr = Arc::new(Column::new(&name, expr_idx)) as _;
98 let source_expr = expr.transform_down(|e| match e.as_any().downcast_ref::<Column>() {
99 Some(col) => {
100 // Sometimes, an expression and its name in the input_schema
101 // doesn't match. This can cause problems, so we make sure
102 // that the expression name matches with the name in `input_schema`.
103 // Conceptually, `source_expr` and `expression` should be the same.
104 let idx = col.index();
105 let matching_field = input_schema.field(idx);
106 let matching_name = matching_field.name();
107 if col.name() != matching_name {
108 return internal_err!(
109 "Input field name {} does not match with the projection expression {}",
110 matching_name,
111 col.name()
112 );
113 }
114 let matching_column = Column::new(matching_name, idx);
115 Ok(Transformed::yes(Arc::new(matching_column)))
116 }
117 None => Ok(Transformed::no(e)),
118 })
119 .data()?;
120 map.entry(source_expr)
121 .or_default()
122 .push((target_expr, expr_idx));
123 }
124 Ok(Self { map })
125 }
126
127 /// Constructs a subset mapping using the provided indices.
128 ///
129 /// This is used when the output is a subset of the input without any
130 /// other transformations. The indices are for columns in the schema.
131 pub fn from_indices(indices: &[usize], schema: &SchemaRef) -> Result<Self> {
132 let projection_exprs = indices.iter().map(|index| {
133 let field = schema.field(*index);
134 let column = Arc::new(Column::new(field.name(), *index));
135 (column as _, field.name().clone())
136 });
137 ProjectionMapping::try_new(projection_exprs, schema)
138 }
139}
140
141impl Deref for ProjectionMapping {
142 type Target = IndexMap<Arc<dyn PhysicalExpr>, ProjectionTargets>;
143
144 fn deref(&self) -> &Self::Target {
145 &self.map
146 }
147}
148
149impl FromIterator<(Arc<dyn PhysicalExpr>, ProjectionTargets)> for ProjectionMapping {
150 fn from_iter<T: IntoIterator<Item = (Arc<dyn PhysicalExpr>, ProjectionTargets)>>(
151 iter: T,
152 ) -> Self {
153 Self {
154 map: IndexMap::from_iter(iter),
155 }
156 }
157}
158
159#[cfg(test)]
160mod tests {
161 use super::*;
162 use crate::equivalence::tests::output_schema;
163 use crate::equivalence::{convert_to_orderings, EquivalenceProperties};
164 use crate::expressions::{col, BinaryExpr};
165 use crate::utils::tests::TestScalarUDF;
166 use crate::{PhysicalExprRef, ScalarFunctionExpr};
167
168 use arrow::compute::SortOptions;
169 use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
170 use datafusion_expr::{Operator, ScalarUDF};
171
    /// Table-driven test for projecting orderings.
    ///
    /// Each case lists the orderings registered on the input, the projection
    /// expressions (source expression and output name), and the orderings
    /// expected on the projected output, given as output column names.
    #[test]
    fn project_orderings() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
            Field::new("c", DataType::Int32, true),
            Field::new("d", DataType::Int32, true),
            Field::new("e", DataType::Int32, true),
            Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true),
        ]));
        let col_a = &col("a", &schema)?;
        let col_b = &col("b", &schema)?;
        let col_c = &col("c", &schema)?;
        let col_d = &col("d", &schema)?;
        let col_e = &col("e", &schema)?;
        let col_ts = &col("ts", &schema)?;
        let a_plus_b = Arc::new(BinaryExpr::new(
            Arc::clone(col_a),
            Operator::Plus,
            Arc::clone(col_b),
        )) as Arc<dyn PhysicalExpr>;
        let b_plus_d = Arc::new(BinaryExpr::new(
            Arc::clone(col_b),
            Operator::Plus,
            Arc::clone(col_d),
        )) as Arc<dyn PhysicalExpr>;
        let b_plus_e = Arc::new(BinaryExpr::new(
            Arc::clone(col_b),
            Operator::Plus,
            Arc::clone(col_e),
        )) as Arc<dyn PhysicalExpr>;
        let c_plus_d = Arc::new(BinaryExpr::new(
            Arc::clone(col_c),
            Operator::Plus,
            Arc::clone(col_d),
        )) as Arc<dyn PhysicalExpr>;

        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };
        let option_desc = SortOptions {
            descending: true,
            nulls_first: true,
        };

        // Each case is a tuple of (input orderings, projection exprs, expected
        // output orderings).
        let test_cases = vec![
            // ---------- TEST CASE 1 ------------
            (
                // orderings
                vec![
                    // [b ASC]
                    vec![(col_b, option_asc)],
                ],
                // projection exprs
                vec![(col_b, "b_new".to_string()), (col_a, "a_new".to_string())],
                // expected
                vec![
                    // [b_new ASC]
                    vec![("b_new", option_asc)],
                ],
            ),
            // ---------- TEST CASE 2 ------------
            (
                // orderings
                vec![
                    // empty ordering
                ],
                // projection exprs
                vec![(col_c, "c_new".to_string()), (col_b, "b_new".to_string())],
                // expected
                vec![
                    // no ordering at the output
                ],
            ),
            // ---------- TEST CASE 3 ------------
            (
                // orderings
                vec![
                    // [ts ASC]
                    vec![(col_ts, option_asc)],
                ],
                // projection exprs
                vec![
                    (col_b, "b_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (col_ts, "ts_new".to_string()),
                ],
                // expected
                vec![
                    // [ts_new ASC]
                    vec![("ts_new", option_asc)],
                ],
            ),
            // ---------- TEST CASE 4 ------------
            (
                // orderings
                vec![
                    // [a ASC, ts ASC]
                    vec![(col_a, option_asc), (col_ts, option_asc)],
                    // [b ASC, ts ASC]
                    vec![(col_b, option_asc), (col_ts, option_asc)],
                ],
                // projection exprs
                vec![
                    (col_b, "b_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (col_ts, "ts_new".to_string()),
                ],
                // expected
                vec![
                    // [a_new ASC, ts_new ASC]
                    vec![("a_new", option_asc), ("ts_new", option_asc)],
                    // [b_new ASC, ts_new ASC]
                    vec![("b_new", option_asc), ("ts_new", option_asc)],
                ],
            ),
            // ---------- TEST CASE 5 ------------
            (
                // orderings
                vec![
                    // [a + b ASC]
                    vec![(&a_plus_b, option_asc)],
                ],
                // projection exprs
                vec![
                    (col_b, "b_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (&a_plus_b, "a+b".to_string()),
                ],
                // expected
                vec![
                    // [a+b ASC], where "a+b" is the output column name
                    vec![("a+b", option_asc)],
                ],
            ),
            // ---------- TEST CASE 6 ------------
            (
                // orderings
                vec![
                    // [a + b ASC, c ASC]
                    vec![(&a_plus_b, option_asc), (col_c, option_asc)],
                ],
                // projection exprs
                vec![
                    (col_b, "b_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (col_c, "c_new".to_string()),
                    (&a_plus_b, "a+b".to_string()),
                ],
                // expected
                vec![
                    // [a+b ASC, c_new ASC]
                    vec![("a+b", option_asc), ("c_new", option_asc)],
                ],
            ),
            // ------- TEST CASE 7 ----------
            (
                // orderings
                vec![
                    // [a ASC, b ASC]
                    vec![(col_a, option_asc), (col_b, option_asc)],
                    // [a ASC, d ASC]
                    vec![(col_a, option_asc), (col_d, option_asc)],
                ],
                // proj exprs: b as b_new, a as a_new, d as d_new, b + d as "b+d"
                vec![
                    (col_b, "b_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (col_d, "d_new".to_string()),
                    (&b_plus_d, "b+d".to_string()),
                ],
                // expected
                vec![
                    // [a_new ASC, b_new ASC]
                    vec![("a_new", option_asc), ("b_new", option_asc)],
                    // [a_new ASC, d_new ASC]
                    vec![("a_new", option_asc), ("d_new", option_asc)],
                    // [a_new ASC, b+d ASC]
                    vec![("a_new", option_asc), ("b+d", option_asc)],
                ],
            ),
            // ------- TEST CASE 8 ----------
            (
                // orderings
                vec![
                    // [b+d ASC]
                    vec![(&b_plus_d, option_asc)],
                ],
                // proj exprs
                vec![
                    (col_b, "b_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (col_d, "d_new".to_string()),
                    (&b_plus_d, "b+d".to_string()),
                ],
                // expected
                vec![
                    // [b+d ASC]
                    vec![("b+d", option_asc)],
                ],
            ),
            // ------- TEST CASE 9 ----------
            (
                // orderings
                vec![
                    // [a ASC, d ASC, b ASC]
                    vec![
                        (col_a, option_asc),
                        (col_d, option_asc),
                        (col_b, option_asc),
                    ],
                    // [c ASC]
                    vec![(col_c, option_asc)],
                ],
                // proj exprs
                vec![
                    (col_b, "b_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (col_d, "d_new".to_string()),
                    (col_c, "c_new".to_string()),
                ],
                // expected
                vec![
                    // [a_new ASC, d_new ASC, b_new ASC]
                    vec![
                        ("a_new", option_asc),
                        ("d_new", option_asc),
                        ("b_new", option_asc),
                    ],
                    // [c_new ASC],
                    vec![("c_new", option_asc)],
                ],
            ),
            // ------- TEST CASE 10 ----------
            (
                // orderings
                vec![
                    // [a ASC, b ASC, c ASC]
                    vec![
                        (col_a, option_asc),
                        (col_b, option_asc),
                        (col_c, option_asc),
                    ],
                    // [a ASC, d ASC]
                    vec![(col_a, option_asc), (col_d, option_asc)],
                ],
                // proj exprs
                vec![
                    (col_b, "b_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (col_c, "c_new".to_string()),
                    (&c_plus_d, "c+d".to_string()),
                ],
                // expected
                vec![
                    // [a_new ASC, b_new ASC, c_new ASC]
                    vec![
                        ("a_new", option_asc),
                        ("b_new", option_asc),
                        ("c_new", option_asc),
                    ],
                    // [a_new ASC, b_new ASC, c+d ASC]
                    vec![
                        ("a_new", option_asc),
                        ("b_new", option_asc),
                        ("c+d", option_asc),
                    ],
                ],
            ),
            // ------- TEST CASE 11 ----------
            (
                // orderings
                vec![
                    // [a ASC, b ASC]
                    vec![(col_a, option_asc), (col_b, option_asc)],
                    // [a ASC, d ASC]
                    vec![(col_a, option_asc), (col_d, option_asc)],
                ],
                // proj exprs
                vec![
                    (col_b, "b_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (&b_plus_d, "b+d".to_string()),
                ],
                // expected
                vec![
                    // [a_new ASC, b_new ASC]
                    vec![("a_new", option_asc), ("b_new", option_asc)],
                    // [a_new ASC, b+d ASC]
                    vec![("a_new", option_asc), ("b+d", option_asc)],
                ],
            ),
            // ------- TEST CASE 12 ----------
            (
                // orderings
                vec![
                    // [a ASC, b ASC, c ASC]
                    vec![
                        (col_a, option_asc),
                        (col_b, option_asc),
                        (col_c, option_asc),
                    ],
                ],
                // proj exprs
                vec![(col_c, "c_new".to_string()), (col_a, "a_new".to_string())],
                // expected
                vec![
                    // [a_new ASC]
                    vec![("a_new", option_asc)],
                ],
            ),
            // ------- TEST CASE 13 ----------
            (
                // orderings
                vec![
                    // [a ASC, b ASC, c ASC]
                    vec![
                        (col_a, option_asc),
                        (col_b, option_asc),
                        (col_c, option_asc),
                    ],
                    // [a ASC, a + b ASC, c ASC]
                    vec![
                        (col_a, option_asc),
                        (&a_plus_b, option_asc),
                        (col_c, option_asc),
                    ],
                ],
                // proj exprs
                vec![
                    (col_c, "c_new".to_string()),
                    (col_b, "b_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (&a_plus_b, "a+b".to_string()),
                ],
                // expected
                vec![
                    // [a_new ASC, b_new ASC, c_new ASC]
                    vec![
                        ("a_new", option_asc),
                        ("b_new", option_asc),
                        ("c_new", option_asc),
                    ],
                    // [a_new ASC, a+b ASC, c_new ASC]
                    vec![
                        ("a_new", option_asc),
                        ("a+b", option_asc),
                        ("c_new", option_asc),
                    ],
                ],
            ),
            // ------- TEST CASE 14 ----------
            (
                // orderings
                vec![
                    // [a ASC, b ASC]
                    vec![(col_a, option_asc), (col_b, option_asc)],
                    // [c ASC, b ASC]
                    vec![(col_c, option_asc), (col_b, option_asc)],
                    // [d ASC, e ASC]
                    vec![(col_d, option_asc), (col_e, option_asc)],
                ],
                // proj exprs
                vec![
                    (col_c, "c_new".to_string()),
                    (col_d, "d_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (&b_plus_e, "b+e".to_string()),
                ],
                // expected
                vec![
                    // [a_new ASC, d_new ASC, b+e ASC]
                    vec![
                        ("a_new", option_asc),
                        ("d_new", option_asc),
                        ("b+e", option_asc),
                    ],
                    // [d_new ASC, a_new ASC, b+e ASC]
                    vec![
                        ("d_new", option_asc),
                        ("a_new", option_asc),
                        ("b+e", option_asc),
                    ],
                    // [c_new ASC, d_new ASC, b+e ASC]
                    vec![
                        ("c_new", option_asc),
                        ("d_new", option_asc),
                        ("b+e", option_asc),
                    ],
                    // [d_new ASC, c_new ASC, b+e ASC]
                    vec![
                        ("d_new", option_asc),
                        ("c_new", option_asc),
                        ("b+e", option_asc),
                    ],
                ],
            ),
            // ------- TEST CASE 15 ----------
            (
                // orderings
                vec![
                    // [a ASC, c ASC, b ASC]
                    vec![
                        (col_a, option_asc),
                        (col_c, option_asc),
                        (col_b, option_asc),
                    ],
                ],
                // proj exprs
                vec![
                    (col_c, "c_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (&a_plus_b, "a+b".to_string()),
                ],
                // expected
                vec![
                    // [a_new ASC, c_new ASC, a+b ASC]
                    vec![
                        ("a_new", option_asc),
                        ("c_new", option_asc),
                        ("a+b", option_asc),
                    ],
                ],
            ),
            // ------- TEST CASE 16 ----------
            (
                // orderings
                vec![
                    // [a ASC, b ASC]
                    vec![(col_a, option_asc), (col_b, option_asc)],
                    // [c ASC, b DESC]
                    vec![(col_c, option_asc), (col_b, option_desc)],
                    // [e ASC]
                    vec![(col_e, option_asc)],
                ],
                // proj exprs
                vec![
                    (col_c, "c_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (col_b, "b_new".to_string()),
                    (&b_plus_e, "b+e".to_string()),
                ],
                // expected
                vec![
                    // [a_new ASC, b_new ASC]
                    vec![("a_new", option_asc), ("b_new", option_asc)],
                    // [a_new ASC, b+e ASC]
                    vec![("a_new", option_asc), ("b+e", option_asc)],
                    // [c_new ASC, b_new DESC]
                    vec![("c_new", option_asc), ("b_new", option_desc)],
                ],
            ),
        ];

        for (idx, (orderings, proj_exprs, expected)) in test_cases.into_iter().enumerate()
        {
            // Fresh properties per case, so orderings from one case are not
            // carried over into the next.
            let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema));

            let orderings = convert_to_orderings(&orderings);
            eq_properties.add_orderings(orderings);

            let proj_exprs = proj_exprs
                .into_iter()
                .map(|(expr, name)| (Arc::clone(expr), name));
            let projection_mapping = ProjectionMapping::try_new(proj_exprs, &schema)?;
            let output_schema = output_schema(&projection_mapping, &schema)?;

            // Resolve the expected column names against the output schema.
            let expected = expected
                .into_iter()
                .map(|ordering| {
                    ordering
                        .into_iter()
                        .map(|(name, options)| {
                            (col(name, &output_schema).unwrap(), options)
                        })
                        .collect::<Vec<_>>()
                })
                .collect::<Vec<_>>();
            let expected = convert_to_orderings(&expected);

            let projected_eq = eq_properties.project(&projection_mapping, output_schema);
            let orderings = projected_eq.oeq_class();

            let err_msg = format!(
                "test_idx: {idx:?}, actual: {orderings:?}, expected: {expected:?}, projection_mapping: {projection_mapping:?}"
            );

            // Order-insensitive comparison: same number of orderings, and
            // every expected ordering is present in the actual result.
            assert_eq!(orderings.len(), expected.len(), "{err_msg}");
            for expected_ordering in &expected {
                assert!(orderings.contains(expected_ordering), "{}", err_msg)
            }
        }

        Ok(())
    }
666
    /// Table-driven test that projects different input orderings through one
    /// fixed projection: `b AS b_new`, `a AS a_new`, `c AS c_new` and a scalar
    /// function of `c` AS `round_c_res`.
    ///
    /// Each case lists the input orderings and the orderings expected on the
    /// projected output. Note that `ts` is not part of the projection.
    #[test]
    fn project_orderings2() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
            Field::new("c", DataType::Int32, true),
            Field::new("d", DataType::Int32, true),
            Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true),
        ]));
        let col_a = &col("a", &schema)?;
        let col_b = &col("b", &schema)?;
        let col_c = &col("c", &schema)?;
        let col_ts = &col("ts", &schema)?;
        let a_plus_b = Arc::new(BinaryExpr::new(
            Arc::clone(col_a),
            Operator::Plus,
            Arc::clone(col_b),
        )) as Arc<dyn PhysicalExpr>;

        let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new()));

        // `TestScalarUDF` applied to column `c`.
        let round_c = Arc::new(ScalarFunctionExpr::try_new(
            test_fun,
            vec![Arc::clone(col_c)],
            &schema,
        )?) as PhysicalExprRef;

        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };

        // The projection shared by all test cases below.
        let proj_exprs = vec![
            (col_b, "b_new".to_string()),
            (col_a, "a_new".to_string()),
            (col_c, "c_new".to_string()),
            (&round_c, "round_c_res".to_string()),
        ];
        let proj_exprs = proj_exprs
            .into_iter()
            .map(|(expr, name)| (Arc::clone(expr), name));
        let projection_mapping = ProjectionMapping::try_new(proj_exprs, &schema)?;
        let output_schema = output_schema(&projection_mapping, &schema)?;

        let col_a_new = &col("a_new", &output_schema)?;
        let col_b_new = &col("b_new", &output_schema)?;
        let col_c_new = &col("c_new", &output_schema)?;
        let col_round_c_res = &col("round_c_res", &output_schema)?;
        let a_new_plus_b_new = Arc::new(BinaryExpr::new(
            Arc::clone(col_a_new),
            Operator::Plus,
            Arc::clone(col_b_new),
        )) as Arc<dyn PhysicalExpr>;

        // Each case is a tuple of (input orderings, expected output orderings).
        let test_cases = vec![
            // ---------- TEST CASE 1 ------------
            (
                // orderings
                vec![
                    // [a ASC]
                    vec![(col_a, option_asc)],
                ],
                // expected
                vec![
                    // [a_new ASC]
                    vec![(col_a_new, option_asc)],
                ],
            ),
            // ---------- TEST CASE 2 ------------
            (
                // orderings
                vec![
                    // [a+b ASC]
                    vec![(&a_plus_b, option_asc)],
                ],
                // expected
                vec![
                    // [a_new + b_new ASC]
                    vec![(&a_new_plus_b_new, option_asc)],
                ],
            ),
            // ---------- TEST CASE 3 ------------
            (
                // orderings
                vec![
                    // [a ASC, ts ASC]
                    vec![(col_a, option_asc), (col_ts, option_asc)],
                ],
                // expected
                vec![
                    // [a_new ASC]
                    vec![(col_a_new, option_asc)],
                ],
            ),
            // ---------- TEST CASE 4 ------------
            (
                // orderings
                vec![
                    // [a ASC, ts ASC, b ASC]
                    vec![
                        (col_a, option_asc),
                        (col_ts, option_asc),
                        (col_b, option_asc),
                    ],
                ],
                // expected
                vec![
                    // [a_new ASC]
                    vec![(col_a_new, option_asc)],
                ],
            ),
            // ---------- TEST CASE 5 ------------
            (
                // orderings
                vec![
                    // [a ASC, c ASC]
                    vec![(col_a, option_asc), (col_c, option_asc)],
                ],
                // expected
                vec![
                    // [a_new ASC, round_c_res ASC]
                    vec![(col_a_new, option_asc), (col_round_c_res, option_asc)],
                    // [a_new ASC, c_new ASC]
                    vec![(col_a_new, option_asc), (col_c_new, option_asc)],
                ],
            ),
            // ---------- TEST CASE 6 ------------
            (
                // orderings
                vec![
                    // [c ASC, b ASC]
                    vec![(col_c, option_asc), (col_b, option_asc)],
                ],
                // expected
                vec![
                    // [round_c_res ASC]
                    vec![(col_round_c_res, option_asc)],
                    // [c_new ASC, b_new ASC]
                    vec![(col_c_new, option_asc), (col_b_new, option_asc)],
                ],
            ),
            // ---------- TEST CASE 7 ------------
            (
                // orderings
                vec![
                    // [a+b ASC, c ASC]
                    vec![(&a_plus_b, option_asc), (col_c, option_asc)],
                ],
                // expected
                vec![
                    // [a_new + b_new ASC, round_c_res ASC]
                    vec![
                        (&a_new_plus_b_new, option_asc),
                        (col_round_c_res, option_asc),
                    ],
                    // [a_new + b_new ASC, c_new ASC]
                    vec![(&a_new_plus_b_new, option_asc), (col_c_new, option_asc)],
                ],
            ),
        ];

        for (idx, (orderings, expected)) in test_cases.iter().enumerate() {
            // Fresh properties per case, so orderings from one case are not
            // carried over into the next.
            let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema));

            let orderings = convert_to_orderings(orderings);
            eq_properties.add_orderings(orderings);

            let expected = convert_to_orderings(expected);

            let projected_eq =
                eq_properties.project(&projection_mapping, Arc::clone(&output_schema));
            let orderings = projected_eq.oeq_class();

            let err_msg = format!(
                "test idx: {idx:?}, actual: {orderings:?}, expected: {expected:?}, projection_mapping: {projection_mapping:?}"
            );

            // Order-insensitive comparison: same number of orderings, and
            // every expected ordering is present in the actual result.
            assert_eq!(orderings.len(), expected.len(), "{err_msg}");
            for expected_ordering in &expected {
                assert!(orderings.contains(expected_ordering), "{}", err_msg)
            }
        }
        Ok(())
    }
851
    /// Table-driven test that projects orderings through one fixed projection
    /// (`c AS c_new`, `d AS d_new`, `a + b AS "a+b"`) while additional equality
    /// conditions between input columns are registered.
    ///
    /// Each case lists the input orderings, the pairs of columns declared
    /// equal, and the orderings expected on the projected output.
    #[test]
    fn project_orderings3() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
            Field::new("c", DataType::Int32, true),
            Field::new("d", DataType::Int32, true),
            Field::new("e", DataType::Int32, true),
            Field::new("f", DataType::Int32, true),
        ]));
        let col_a = &col("a", &schema)?;
        let col_b = &col("b", &schema)?;
        let col_c = &col("c", &schema)?;
        let col_d = &col("d", &schema)?;
        let col_e = &col("e", &schema)?;
        let col_f = &col("f", &schema)?;
        let a_plus_b = Arc::new(BinaryExpr::new(
            Arc::clone(col_a),
            Operator::Plus,
            Arc::clone(col_b),
        )) as Arc<dyn PhysicalExpr>;

        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };

        // The projection shared by all test cases below.
        let proj_exprs = vec![
            (col_c, "c_new".to_string()),
            (col_d, "d_new".to_string()),
            (&a_plus_b, "a+b".to_string()),
        ];
        let proj_exprs = proj_exprs
            .into_iter()
            .map(|(expr, name)| (Arc::clone(expr), name));
        let projection_mapping = ProjectionMapping::try_new(proj_exprs, &schema)?;
        let output_schema = output_schema(&projection_mapping, &schema)?;

        let col_a_plus_b_new = &col("a+b", &output_schema)?;
        let col_c_new = &col("c_new", &output_schema)?;
        let col_d_new = &col("d_new", &output_schema)?;

        // Each case is a tuple of (input orderings, equal column pairs,
        // expected output orderings).
        let test_cases = vec![
            // ---------- TEST CASE 1 ------------
            (
                // orderings
                vec![
                    // [d ASC, b ASC]
                    vec![(col_d, option_asc), (col_b, option_asc)],
                    // [c ASC, a ASC]
                    vec![(col_c, option_asc), (col_a, option_asc)],
                ],
                // equal conditions
                vec![],
                // expected
                vec![
                    // [d_new ASC, c_new ASC, a+b ASC]
                    vec![
                        (col_d_new, option_asc),
                        (col_c_new, option_asc),
                        (col_a_plus_b_new, option_asc),
                    ],
                    // [c_new ASC, d_new ASC, a+b ASC]
                    vec![
                        (col_c_new, option_asc),
                        (col_d_new, option_asc),
                        (col_a_plus_b_new, option_asc),
                    ],
                ],
            ),
            // ---------- TEST CASE 2 ------------
            (
                // orderings
                vec![
                    // [d ASC, b ASC]
                    vec![(col_d, option_asc), (col_b, option_asc)],
                    // [c ASC, e ASC]; note that e = a is declared below
                    vec![(col_c, option_asc), (col_e, option_asc)],
                ],
                // equal conditions
                vec![(col_e, col_a)],
                // expected
                vec![
                    // [d_new ASC, c_new ASC, a+b ASC]
                    vec![
                        (col_d_new, option_asc),
                        (col_c_new, option_asc),
                        (col_a_plus_b_new, option_asc),
                    ],
                    // [c_new ASC, d_new ASC, a+b ASC]
                    vec![
                        (col_c_new, option_asc),
                        (col_d_new, option_asc),
                        (col_a_plus_b_new, option_asc),
                    ],
                ],
            ),
            // ---------- TEST CASE 3 ------------
            (
                // orderings
                vec![
                    // [d ASC, b ASC]
                    vec![(col_d, option_asc), (col_b, option_asc)],
                    // [c ASC, e ASC]; note that a = f is declared below
                    vec![(col_c, option_asc), (col_e, option_asc)],
                ],
                // equal conditions
                vec![(col_a, col_f)],
                // expected
                vec![
                    // [d_new ASC]
                    vec![(col_d_new, option_asc)],
                    // [c_new ASC]
                    vec![(col_c_new, option_asc)],
                ],
            ),
        ];
        for (orderings, equal_columns, expected) in test_cases {
            // Fresh properties per case, so equalities and orderings from one
            // case are not carried over into the next.
            let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema));
            for (lhs, rhs) in equal_columns {
                eq_properties.add_equal_conditions(Arc::clone(lhs), Arc::clone(rhs))?;
            }

            let orderings = convert_to_orderings(&orderings);
            eq_properties.add_orderings(orderings);

            let expected = convert_to_orderings(&expected);

            let projected_eq =
                eq_properties.project(&projection_mapping, Arc::clone(&output_schema));
            let orderings = projected_eq.oeq_class();

            let err_msg = format!(
                "actual: {orderings:?}, expected: {expected:?}, projection_mapping: {projection_mapping:?}"
            );

            // Order-insensitive comparison: same number of orderings, and
            // every expected ordering is present in the actual result.
            assert_eq!(orderings.len(), expected.len(), "{err_msg}");
            for expected_ordering in &expected {
                assert!(orderings.contains(expected_ordering), "{}", err_msg)
            }
        }

        Ok(())
    }
996}