datafusion_physical_expr/equivalence/projection.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::ops::Deref;
19use std::sync::Arc;
20
21use crate::expressions::Column;
22use crate::PhysicalExpr;
23
24use arrow::datatypes::SchemaRef;
25use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
26use datafusion_common::{internal_err, Result};
27
28use indexmap::IndexMap;
29
30/// Stores target expressions, along with their indices, that associate with a
31/// source expression in a projection mapping.
32#[derive(Clone, Debug, Default)]
33pub struct ProjectionTargets {
34 /// A non-empty vector of pairs of target expressions and their indices.
35 /// Consider using a special non-empty collection type in the future (e.g.
36 /// if Rust provides one in the standard library).
37 exprs_indices: Vec<(Arc<dyn PhysicalExpr>, usize)>,
38}
39
40impl ProjectionTargets {
41 /// Returns the first target expression and its index.
42 pub fn first(&self) -> &(Arc<dyn PhysicalExpr>, usize) {
43 // Since the vector is non-empty, we can safely unwrap:
44 self.exprs_indices.first().unwrap()
45 }
46
47 /// Adds a target expression and its index to the list of targets.
48 pub fn push(&mut self, target: (Arc<dyn PhysicalExpr>, usize)) {
49 self.exprs_indices.push(target);
50 }
51}
52
53impl Deref for ProjectionTargets {
54 type Target = [(Arc<dyn PhysicalExpr>, usize)];
55
56 fn deref(&self) -> &Self::Target {
57 &self.exprs_indices
58 }
59}
60
61impl From<Vec<(Arc<dyn PhysicalExpr>, usize)>> for ProjectionTargets {
62 fn from(exprs_indices: Vec<(Arc<dyn PhysicalExpr>, usize)>) -> Self {
63 Self { exprs_indices }
64 }
65}
66
67/// Stores the mapping between source expressions and target expressions for a
68/// projection.
69#[derive(Clone, Debug)]
70pub struct ProjectionMapping {
71 /// Mapping between source expressions and target expressions.
72 /// Vector indices correspond to the indices after projection.
73 map: IndexMap<Arc<dyn PhysicalExpr>, ProjectionTargets>,
74}
75
76impl ProjectionMapping {
77 /// Constructs the mapping between a projection's input and output
78 /// expressions.
79 ///
80 /// For example, given the input projection expressions (`a + b`, `c + d`)
81 /// and an output schema with two columns `"c + d"` and `"a + b"`, the
82 /// projection mapping would be:
83 ///
84 /// ```text
85 /// [0]: (c + d, [(col("c + d"), 0)])
86 /// [1]: (a + b, [(col("a + b"), 1)])
87 /// ```
88 ///
89 /// where `col("c + d")` means the column named `"c + d"`.
90 pub fn try_new(
91 expr: impl IntoIterator<Item = (Arc<dyn PhysicalExpr>, String)>,
92 input_schema: &SchemaRef,
93 ) -> Result<Self> {
94 // Construct a map from the input expressions to the output expression of the projection:
95 let mut map = IndexMap::<_, ProjectionTargets>::new();
96 for (expr_idx, (expr, name)) in expr.into_iter().enumerate() {
97 let target_expr = Arc::new(Column::new(&name, expr_idx)) as _;
98 let source_expr = expr.transform_down(|e| match e.as_any().downcast_ref::<Column>() {
99 Some(col) => {
100 // Sometimes, an expression and its name in the input_schema
101 // doesn't match. This can cause problems, so we make sure
102 // that the expression name matches with the name in `input_schema`.
103 // Conceptually, `source_expr` and `expression` should be the same.
104 let idx = col.index();
105 let matching_field = input_schema.field(idx);
106 let matching_name = matching_field.name();
107 if col.name() != matching_name {
108 return internal_err!(
109 "Input field name {} does not match with the projection expression {}",
110 matching_name,
111 col.name()
112 );
113 }
114 let matching_column = Column::new(matching_name, idx);
115 Ok(Transformed::yes(Arc::new(matching_column)))
116 }
117 None => Ok(Transformed::no(e)),
118 })
119 .data()?;
120 map.entry(source_expr)
121 .or_default()
122 .push((target_expr, expr_idx));
123 }
124 Ok(Self { map })
125 }
126
127 /// Constructs a subset mapping using the provided indices.
128 ///
129 /// This is used when the output is a subset of the input without any
130 /// other transformations. The indices are for columns in the schema.
131 pub fn from_indices(indices: &[usize], schema: &SchemaRef) -> Result<Self> {
132 let projection_exprs = indices.iter().map(|index| {
133 let field = schema.field(*index);
134 let column = Arc::new(Column::new(field.name(), *index));
135 (column as _, field.name().clone())
136 });
137 ProjectionMapping::try_new(projection_exprs, schema)
138 }
139}
140
141impl Deref for ProjectionMapping {
142 type Target = IndexMap<Arc<dyn PhysicalExpr>, ProjectionTargets>;
143
144 fn deref(&self) -> &Self::Target {
145 &self.map
146 }
147}
148
149impl FromIterator<(Arc<dyn PhysicalExpr>, ProjectionTargets)> for ProjectionMapping {
150 fn from_iter<T: IntoIterator<Item = (Arc<dyn PhysicalExpr>, ProjectionTargets)>>(
151 iter: T,
152 ) -> Self {
153 Self {
154 map: IndexMap::from_iter(iter),
155 }
156 }
157}
158
159#[cfg(test)]
160mod tests {
161 use super::*;
162 use crate::equivalence::tests::output_schema;
163 use crate::equivalence::{convert_to_orderings, EquivalenceProperties};
164 use crate::expressions::{col, BinaryExpr};
165 use crate::utils::tests::TestScalarUDF;
166 use crate::{PhysicalExprRef, ScalarFunctionExpr};
167
168 use arrow::compute::SortOptions;
169 use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
170 use datafusion_common::config::ConfigOptions;
171 use datafusion_expr::{Operator, ScalarUDF};
172
173 #[test]
174 fn project_orderings() -> Result<()> {
175 let schema = Arc::new(Schema::new(vec![
176 Field::new("a", DataType::Int32, true),
177 Field::new("b", DataType::Int32, true),
178 Field::new("c", DataType::Int32, true),
179 Field::new("d", DataType::Int32, true),
180 Field::new("e", DataType::Int32, true),
181 Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true),
182 ]));
183 let col_a = &col("a", &schema)?;
184 let col_b = &col("b", &schema)?;
185 let col_c = &col("c", &schema)?;
186 let col_d = &col("d", &schema)?;
187 let col_e = &col("e", &schema)?;
188 let col_ts = &col("ts", &schema)?;
189 let a_plus_b = Arc::new(BinaryExpr::new(
190 Arc::clone(col_a),
191 Operator::Plus,
192 Arc::clone(col_b),
193 )) as Arc<dyn PhysicalExpr>;
194 let b_plus_d = Arc::new(BinaryExpr::new(
195 Arc::clone(col_b),
196 Operator::Plus,
197 Arc::clone(col_d),
198 )) as Arc<dyn PhysicalExpr>;
199 let b_plus_e = Arc::new(BinaryExpr::new(
200 Arc::clone(col_b),
201 Operator::Plus,
202 Arc::clone(col_e),
203 )) as Arc<dyn PhysicalExpr>;
204 let c_plus_d = Arc::new(BinaryExpr::new(
205 Arc::clone(col_c),
206 Operator::Plus,
207 Arc::clone(col_d),
208 )) as Arc<dyn PhysicalExpr>;
209
210 let option_asc = SortOptions {
211 descending: false,
212 nulls_first: false,
213 };
214 let option_desc = SortOptions {
215 descending: true,
216 nulls_first: true,
217 };
218
219 let test_cases = vec![
220 // ---------- TEST CASE 1 ------------
221 (
222 // orderings
223 vec![
224 // [b ASC]
225 vec![(col_b, option_asc)],
226 ],
227 // projection exprs
228 vec![(col_b, "b_new".to_string()), (col_a, "a_new".to_string())],
229 // expected
230 vec![
231 // [b_new ASC]
232 vec![("b_new", option_asc)],
233 ],
234 ),
235 // ---------- TEST CASE 2 ------------
236 (
237 // orderings
238 vec![
239 // empty ordering
240 ],
241 // projection exprs
242 vec![(col_c, "c_new".to_string()), (col_b, "b_new".to_string())],
243 // expected
244 vec![
245 // no ordering at the output
246 ],
247 ),
248 // ---------- TEST CASE 3 ------------
249 (
250 // orderings
251 vec![
252 // [ts ASC]
253 vec![(col_ts, option_asc)],
254 ],
255 // projection exprs
256 vec![
257 (col_b, "b_new".to_string()),
258 (col_a, "a_new".to_string()),
259 (col_ts, "ts_new".to_string()),
260 ],
261 // expected
262 vec![
263 // [ts_new ASC]
264 vec![("ts_new", option_asc)],
265 ],
266 ),
267 // ---------- TEST CASE 4 ------------
268 (
269 // orderings
270 vec![
271 // [a ASC, ts ASC]
272 vec![(col_a, option_asc), (col_ts, option_asc)],
273 // [b ASC, ts ASC]
274 vec![(col_b, option_asc), (col_ts, option_asc)],
275 ],
276 // projection exprs
277 vec![
278 (col_b, "b_new".to_string()),
279 (col_a, "a_new".to_string()),
280 (col_ts, "ts_new".to_string()),
281 ],
282 // expected
283 vec![
284 // [a_new ASC, ts_new ASC]
285 vec![("a_new", option_asc), ("ts_new", option_asc)],
286 // [b_new ASC, ts_new ASC]
287 vec![("b_new", option_asc), ("ts_new", option_asc)],
288 ],
289 ),
290 // ---------- TEST CASE 5 ------------
291 (
292 // orderings
293 vec![
294 // [a + b ASC]
295 vec![(&a_plus_b, option_asc)],
296 ],
297 // projection exprs
298 vec![
299 (col_b, "b_new".to_string()),
300 (col_a, "a_new".to_string()),
301 (&a_plus_b, "a+b".to_string()),
302 ],
303 // expected
304 vec![
305 // [a + b ASC]
306 vec![("a+b", option_asc)],
307 ],
308 ),
309 // ---------- TEST CASE 6 ------------
310 (
311 // orderings
312 vec![
313 // [a + b ASC, c ASC]
314 vec![(&a_plus_b, option_asc), (col_c, option_asc)],
315 ],
316 // projection exprs
317 vec![
318 (col_b, "b_new".to_string()),
319 (col_a, "a_new".to_string()),
320 (col_c, "c_new".to_string()),
321 (&a_plus_b, "a+b".to_string()),
322 ],
323 // expected
324 vec![
325 // [a + b ASC, c_new ASC]
326 vec![("a+b", option_asc), ("c_new", option_asc)],
327 ],
328 ),
329 // ------- TEST CASE 7 ----------
330 (
331 vec![
332 // [a ASC, b ASC, c ASC]
333 vec![(col_a, option_asc), (col_b, option_asc)],
334 // [a ASC, d ASC]
335 vec![(col_a, option_asc), (col_d, option_asc)],
336 ],
337 // b as b_new, a as a_new, d as d_new b+d
338 vec![
339 (col_b, "b_new".to_string()),
340 (col_a, "a_new".to_string()),
341 (col_d, "d_new".to_string()),
342 (&b_plus_d, "b+d".to_string()),
343 ],
344 // expected
345 vec![
346 // [a_new ASC, b_new ASC]
347 vec![("a_new", option_asc), ("b_new", option_asc)],
348 // [a_new ASC, d_new ASC]
349 vec![("a_new", option_asc), ("d_new", option_asc)],
350 // [a_new ASC, b+d ASC]
351 vec![("a_new", option_asc), ("b+d", option_asc)],
352 ],
353 ),
354 // ------- TEST CASE 8 ----------
355 (
356 // orderings
357 vec![
358 // [b+d ASC]
359 vec![(&b_plus_d, option_asc)],
360 ],
361 // proj exprs
362 vec![
363 (col_b, "b_new".to_string()),
364 (col_a, "a_new".to_string()),
365 (col_d, "d_new".to_string()),
366 (&b_plus_d, "b+d".to_string()),
367 ],
368 // expected
369 vec![
370 // [b+d ASC]
371 vec![("b+d", option_asc)],
372 ],
373 ),
374 // ------- TEST CASE 9 ----------
375 (
376 // orderings
377 vec![
378 // [a ASC, d ASC, b ASC]
379 vec![
380 (col_a, option_asc),
381 (col_d, option_asc),
382 (col_b, option_asc),
383 ],
384 // [c ASC]
385 vec![(col_c, option_asc)],
386 ],
387 // proj exprs
388 vec![
389 (col_b, "b_new".to_string()),
390 (col_a, "a_new".to_string()),
391 (col_d, "d_new".to_string()),
392 (col_c, "c_new".to_string()),
393 ],
394 // expected
395 vec![
396 // [a_new ASC, d_new ASC, b_new ASC]
397 vec![
398 ("a_new", option_asc),
399 ("d_new", option_asc),
400 ("b_new", option_asc),
401 ],
402 // [c_new ASC],
403 vec![("c_new", option_asc)],
404 ],
405 ),
406 // ------- TEST CASE 10 ----------
407 (
408 vec![
409 // [a ASC, b ASC, c ASC]
410 vec![
411 (col_a, option_asc),
412 (col_b, option_asc),
413 (col_c, option_asc),
414 ],
415 // [a ASC, d ASC]
416 vec![(col_a, option_asc), (col_d, option_asc)],
417 ],
418 // proj exprs
419 vec![
420 (col_b, "b_new".to_string()),
421 (col_a, "a_new".to_string()),
422 (col_c, "c_new".to_string()),
423 (&c_plus_d, "c+d".to_string()),
424 ],
425 // expected
426 vec![
427 // [a_new ASC, b_new ASC, c_new ASC]
428 vec![
429 ("a_new", option_asc),
430 ("b_new", option_asc),
431 ("c_new", option_asc),
432 ],
433 // [a_new ASC, b_new ASC, c+d ASC]
434 vec![
435 ("a_new", option_asc),
436 ("b_new", option_asc),
437 ("c+d", option_asc),
438 ],
439 ],
440 ),
441 // ------- TEST CASE 11 ----------
442 (
443 // orderings
444 vec![
445 // [a ASC, b ASC]
446 vec![(col_a, option_asc), (col_b, option_asc)],
447 // [a ASC, d ASC]
448 vec![(col_a, option_asc), (col_d, option_asc)],
449 ],
450 // proj exprs
451 vec![
452 (col_b, "b_new".to_string()),
453 (col_a, "a_new".to_string()),
454 (&b_plus_d, "b+d".to_string()),
455 ],
456 // expected
457 vec![
458 // [a_new ASC, b_new ASC]
459 vec![("a_new", option_asc), ("b_new", option_asc)],
460 // [a_new ASC, b + d ASC]
461 vec![("a_new", option_asc), ("b+d", option_asc)],
462 ],
463 ),
464 // ------- TEST CASE 12 ----------
465 (
466 // orderings
467 vec![
468 // [a ASC, b ASC, c ASC]
469 vec![
470 (col_a, option_asc),
471 (col_b, option_asc),
472 (col_c, option_asc),
473 ],
474 ],
475 // proj exprs
476 vec![(col_c, "c_new".to_string()), (col_a, "a_new".to_string())],
477 // expected
478 vec![
479 // [a_new ASC]
480 vec![("a_new", option_asc)],
481 ],
482 ),
483 // ------- TEST CASE 13 ----------
484 (
485 // orderings
486 vec![
487 // [a ASC, b ASC, c ASC]
488 vec![
489 (col_a, option_asc),
490 (col_b, option_asc),
491 (col_c, option_asc),
492 ],
493 // [a ASC, a + b ASC, c ASC]
494 vec![
495 (col_a, option_asc),
496 (&a_plus_b, option_asc),
497 (col_c, option_asc),
498 ],
499 ],
500 // proj exprs
501 vec![
502 (col_c, "c_new".to_string()),
503 (col_b, "b_new".to_string()),
504 (col_a, "a_new".to_string()),
505 (&a_plus_b, "a+b".to_string()),
506 ],
507 // expected
508 vec![
509 // [a_new ASC, b_new ASC, c_new ASC]
510 vec![
511 ("a_new", option_asc),
512 ("b_new", option_asc),
513 ("c_new", option_asc),
514 ],
515 // [a_new ASC, a+b ASC, c_new ASC]
516 vec![
517 ("a_new", option_asc),
518 ("a+b", option_asc),
519 ("c_new", option_asc),
520 ],
521 ],
522 ),
523 // ------- TEST CASE 14 ----------
524 (
525 // orderings
526 vec![
527 // [a ASC, b ASC]
528 vec![(col_a, option_asc), (col_b, option_asc)],
529 // [c ASC, b ASC]
530 vec![(col_c, option_asc), (col_b, option_asc)],
531 // [d ASC, e ASC]
532 vec![(col_d, option_asc), (col_e, option_asc)],
533 ],
534 // proj exprs
535 vec![
536 (col_c, "c_new".to_string()),
537 (col_d, "d_new".to_string()),
538 (col_a, "a_new".to_string()),
539 (&b_plus_e, "b+e".to_string()),
540 ],
541 // expected
542 vec![
543 // [a_new ASC, d_new ASC, b+e ASC]
544 vec![
545 ("a_new", option_asc),
546 ("d_new", option_asc),
547 ("b+e", option_asc),
548 ],
549 // [d_new ASC, a_new ASC, b+e ASC]
550 vec![
551 ("d_new", option_asc),
552 ("a_new", option_asc),
553 ("b+e", option_asc),
554 ],
555 // [c_new ASC, d_new ASC, b+e ASC]
556 vec![
557 ("c_new", option_asc),
558 ("d_new", option_asc),
559 ("b+e", option_asc),
560 ],
561 // [d_new ASC, c_new ASC, b+e ASC]
562 vec![
563 ("d_new", option_asc),
564 ("c_new", option_asc),
565 ("b+e", option_asc),
566 ],
567 ],
568 ),
569 // ------- TEST CASE 15 ----------
570 (
571 // orderings
572 vec![
573 // [a ASC, c ASC, b ASC]
574 vec![
575 (col_a, option_asc),
576 (col_c, option_asc),
577 (col_b, option_asc),
578 ],
579 ],
580 // proj exprs
581 vec![
582 (col_c, "c_new".to_string()),
583 (col_a, "a_new".to_string()),
584 (&a_plus_b, "a+b".to_string()),
585 ],
586 // expected
587 vec![
588 // [a_new ASC, d_new ASC, b+e ASC]
589 vec![
590 ("a_new", option_asc),
591 ("c_new", option_asc),
592 ("a+b", option_asc),
593 ],
594 ],
595 ),
596 // ------- TEST CASE 16 ----------
597 (
598 // orderings
599 vec![
600 // [a ASC, b ASC]
601 vec![(col_a, option_asc), (col_b, option_asc)],
602 // [c ASC, b DESC]
603 vec![(col_c, option_asc), (col_b, option_desc)],
604 // [e ASC]
605 vec![(col_e, option_asc)],
606 ],
607 // proj exprs
608 vec![
609 (col_c, "c_new".to_string()),
610 (col_a, "a_new".to_string()),
611 (col_b, "b_new".to_string()),
612 (&b_plus_e, "b+e".to_string()),
613 ],
614 // expected
615 vec![
616 // [a_new ASC, b_new ASC]
617 vec![("a_new", option_asc), ("b_new", option_asc)],
618 // [a_new ASC, b_new ASC]
619 vec![("a_new", option_asc), ("b+e", option_asc)],
620 // [c_new ASC, b_new DESC]
621 vec![("c_new", option_asc), ("b_new", option_desc)],
622 ],
623 ),
624 ];
625
626 for (idx, (orderings, proj_exprs, expected)) in test_cases.into_iter().enumerate()
627 {
628 let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema));
629
630 let orderings = convert_to_orderings(&orderings);
631 eq_properties.add_orderings(orderings);
632
633 let proj_exprs = proj_exprs
634 .into_iter()
635 .map(|(expr, name)| (Arc::clone(expr), name));
636 let projection_mapping = ProjectionMapping::try_new(proj_exprs, &schema)?;
637 let output_schema = output_schema(&projection_mapping, &schema)?;
638
639 let expected = expected
640 .into_iter()
641 .map(|ordering| {
642 ordering
643 .into_iter()
644 .map(|(name, options)| {
645 (col(name, &output_schema).unwrap(), options)
646 })
647 .collect::<Vec<_>>()
648 })
649 .collect::<Vec<_>>();
650 let expected = convert_to_orderings(&expected);
651
652 let projected_eq = eq_properties.project(&projection_mapping, output_schema);
653 let orderings = projected_eq.oeq_class();
654
655 let err_msg = format!(
656 "test_idx: {idx:?}, actual: {orderings:?}, expected: {expected:?}, projection_mapping: {projection_mapping:?}"
657 );
658
659 assert_eq!(orderings.len(), expected.len(), "{err_msg}");
660 for expected_ordering in &expected {
661 assert!(orderings.contains(expected_ordering), "{}", err_msg)
662 }
663 }
664
665 Ok(())
666 }
667
668 #[test]
669 fn project_orderings2() -> Result<()> {
670 let schema = Arc::new(Schema::new(vec![
671 Field::new("a", DataType::Int32, true),
672 Field::new("b", DataType::Int32, true),
673 Field::new("c", DataType::Int32, true),
674 Field::new("d", DataType::Int32, true),
675 Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true),
676 ]));
677 let col_a = &col("a", &schema)?;
678 let col_b = &col("b", &schema)?;
679 let col_c = &col("c", &schema)?;
680 let col_ts = &col("ts", &schema)?;
681 let a_plus_b = Arc::new(BinaryExpr::new(
682 Arc::clone(col_a),
683 Operator::Plus,
684 Arc::clone(col_b),
685 )) as Arc<dyn PhysicalExpr>;
686
687 let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new()));
688
689 let round_c = Arc::new(ScalarFunctionExpr::try_new(
690 test_fun,
691 vec![Arc::clone(col_c)],
692 &schema,
693 Arc::new(ConfigOptions::default()),
694 )?) as PhysicalExprRef;
695
696 let option_asc = SortOptions {
697 descending: false,
698 nulls_first: false,
699 };
700
701 let proj_exprs = vec![
702 (col_b, "b_new".to_string()),
703 (col_a, "a_new".to_string()),
704 (col_c, "c_new".to_string()),
705 (&round_c, "round_c_res".to_string()),
706 ];
707 let proj_exprs = proj_exprs
708 .into_iter()
709 .map(|(expr, name)| (Arc::clone(expr), name));
710 let projection_mapping = ProjectionMapping::try_new(proj_exprs, &schema)?;
711 let output_schema = output_schema(&projection_mapping, &schema)?;
712
713 let col_a_new = &col("a_new", &output_schema)?;
714 let col_b_new = &col("b_new", &output_schema)?;
715 let col_c_new = &col("c_new", &output_schema)?;
716 let col_round_c_res = &col("round_c_res", &output_schema)?;
717 let a_new_plus_b_new = Arc::new(BinaryExpr::new(
718 Arc::clone(col_a_new),
719 Operator::Plus,
720 Arc::clone(col_b_new),
721 )) as Arc<dyn PhysicalExpr>;
722
723 let test_cases = vec![
724 // ---------- TEST CASE 1 ------------
725 (
726 // orderings
727 vec![
728 // [a ASC]
729 vec![(col_a, option_asc)],
730 ],
731 // expected
732 vec![
733 // [b_new ASC]
734 vec![(col_a_new, option_asc)],
735 ],
736 ),
737 // ---------- TEST CASE 2 ------------
738 (
739 // orderings
740 vec![
741 // [a+b ASC]
742 vec![(&a_plus_b, option_asc)],
743 ],
744 // expected
745 vec![
746 // [b_new ASC]
747 vec![(&a_new_plus_b_new, option_asc)],
748 ],
749 ),
750 // ---------- TEST CASE 3 ------------
751 (
752 // orderings
753 vec![
754 // [a ASC, ts ASC]
755 vec![(col_a, option_asc), (col_ts, option_asc)],
756 ],
757 // expected
758 vec![
759 // [a_new ASC, date_bin_res ASC]
760 vec![(col_a_new, option_asc)],
761 ],
762 ),
763 // ---------- TEST CASE 4 ------------
764 (
765 // orderings
766 vec![
767 // [a ASC, ts ASC, b ASC]
768 vec![
769 (col_a, option_asc),
770 (col_ts, option_asc),
771 (col_b, option_asc),
772 ],
773 ],
774 // expected
775 vec![
776 // [a_new ASC, date_bin_res ASC]
777 vec![(col_a_new, option_asc)],
778 ],
779 ),
780 // ---------- TEST CASE 5 ------------
781 (
782 // orderings
783 vec![
784 // [a ASC, c ASC]
785 vec![(col_a, option_asc), (col_c, option_asc)],
786 ],
787 // expected
788 vec![
789 // [a_new ASC, round_c_res ASC, c_new ASC]
790 vec![(col_a_new, option_asc), (col_round_c_res, option_asc)],
791 // [a_new ASC, c_new ASC]
792 vec![(col_a_new, option_asc), (col_c_new, option_asc)],
793 ],
794 ),
795 // ---------- TEST CASE 6 ------------
796 (
797 // orderings
798 vec![
799 // [c ASC, b ASC]
800 vec![(col_c, option_asc), (col_b, option_asc)],
801 ],
802 // expected
803 vec![
804 // [round_c_res ASC]
805 vec![(col_round_c_res, option_asc)],
806 // [c_new ASC, b_new ASC]
807 vec![(col_c_new, option_asc), (col_b_new, option_asc)],
808 ],
809 ),
810 // ---------- TEST CASE 7 ------------
811 (
812 // orderings
813 vec![
814 // [a+b ASC, c ASC]
815 vec![(&a_plus_b, option_asc), (col_c, option_asc)],
816 ],
817 // expected
818 vec![
819 // [a+b ASC, round(c) ASC, c_new ASC]
820 vec![
821 (&a_new_plus_b_new, option_asc),
822 (col_round_c_res, option_asc),
823 ],
824 // [a+b ASC, c_new ASC]
825 vec![(&a_new_plus_b_new, option_asc), (col_c_new, option_asc)],
826 ],
827 ),
828 ];
829
830 for (idx, (orderings, expected)) in test_cases.iter().enumerate() {
831 let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema));
832
833 let orderings = convert_to_orderings(orderings);
834 eq_properties.add_orderings(orderings);
835
836 let expected = convert_to_orderings(expected);
837
838 let projected_eq =
839 eq_properties.project(&projection_mapping, Arc::clone(&output_schema));
840 let orderings = projected_eq.oeq_class();
841
842 let err_msg = format!(
843 "test idx: {idx:?}, actual: {orderings:?}, expected: {expected:?}, projection_mapping: {projection_mapping:?}"
844 );
845
846 assert_eq!(orderings.len(), expected.len(), "{err_msg}");
847 for expected_ordering in &expected {
848 assert!(orderings.contains(expected_ordering), "{}", err_msg)
849 }
850 }
851 Ok(())
852 }
853
854 #[test]
855 fn project_orderings3() -> Result<()> {
856 let schema = Arc::new(Schema::new(vec![
857 Field::new("a", DataType::Int32, true),
858 Field::new("b", DataType::Int32, true),
859 Field::new("c", DataType::Int32, true),
860 Field::new("d", DataType::Int32, true),
861 Field::new("e", DataType::Int32, true),
862 Field::new("f", DataType::Int32, true),
863 ]));
864 let col_a = &col("a", &schema)?;
865 let col_b = &col("b", &schema)?;
866 let col_c = &col("c", &schema)?;
867 let col_d = &col("d", &schema)?;
868 let col_e = &col("e", &schema)?;
869 let col_f = &col("f", &schema)?;
870 let a_plus_b = Arc::new(BinaryExpr::new(
871 Arc::clone(col_a),
872 Operator::Plus,
873 Arc::clone(col_b),
874 )) as Arc<dyn PhysicalExpr>;
875
876 let option_asc = SortOptions {
877 descending: false,
878 nulls_first: false,
879 };
880
881 let proj_exprs = vec![
882 (col_c, "c_new".to_string()),
883 (col_d, "d_new".to_string()),
884 (&a_plus_b, "a+b".to_string()),
885 ];
886 let proj_exprs = proj_exprs
887 .into_iter()
888 .map(|(expr, name)| (Arc::clone(expr), name));
889 let projection_mapping = ProjectionMapping::try_new(proj_exprs, &schema)?;
890 let output_schema = output_schema(&projection_mapping, &schema)?;
891
892 let col_a_plus_b_new = &col("a+b", &output_schema)?;
893 let col_c_new = &col("c_new", &output_schema)?;
894 let col_d_new = &col("d_new", &output_schema)?;
895
896 let test_cases = vec![
897 // ---------- TEST CASE 1 ------------
898 (
899 // orderings
900 vec![
901 // [d ASC, b ASC]
902 vec![(col_d, option_asc), (col_b, option_asc)],
903 // [c ASC, a ASC]
904 vec![(col_c, option_asc), (col_a, option_asc)],
905 ],
906 // equal conditions
907 vec![],
908 // expected
909 vec![
910 // [d_new ASC, c_new ASC, a+b ASC]
911 vec![
912 (col_d_new, option_asc),
913 (col_c_new, option_asc),
914 (col_a_plus_b_new, option_asc),
915 ],
916 // [c_new ASC, d_new ASC, a+b ASC]
917 vec![
918 (col_c_new, option_asc),
919 (col_d_new, option_asc),
920 (col_a_plus_b_new, option_asc),
921 ],
922 ],
923 ),
924 // ---------- TEST CASE 2 ------------
925 (
926 // orderings
927 vec![
928 // [d ASC, b ASC]
929 vec![(col_d, option_asc), (col_b, option_asc)],
930 // [c ASC, e ASC], Please note that a=e
931 vec![(col_c, option_asc), (col_e, option_asc)],
932 ],
933 // equal conditions
934 vec![(col_e, col_a)],
935 // expected
936 vec![
937 // [d_new ASC, c_new ASC, a+b ASC]
938 vec![
939 (col_d_new, option_asc),
940 (col_c_new, option_asc),
941 (col_a_plus_b_new, option_asc),
942 ],
943 // [c_new ASC, d_new ASC, a+b ASC]
944 vec![
945 (col_c_new, option_asc),
946 (col_d_new, option_asc),
947 (col_a_plus_b_new, option_asc),
948 ],
949 ],
950 ),
951 // ---------- TEST CASE 3 ------------
952 (
953 // orderings
954 vec![
955 // [d ASC, b ASC]
956 vec![(col_d, option_asc), (col_b, option_asc)],
957 // [c ASC, e ASC], Please note that a=f
958 vec![(col_c, option_asc), (col_e, option_asc)],
959 ],
960 // equal conditions
961 vec![(col_a, col_f)],
962 // expected
963 vec![
964 // [d_new ASC]
965 vec![(col_d_new, option_asc)],
966 // [c_new ASC]
967 vec![(col_c_new, option_asc)],
968 ],
969 ),
970 ];
971 for (orderings, equal_columns, expected) in test_cases {
972 let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema));
973 for (lhs, rhs) in equal_columns {
974 eq_properties.add_equal_conditions(Arc::clone(lhs), Arc::clone(rhs))?;
975 }
976
977 let orderings = convert_to_orderings(&orderings);
978 eq_properties.add_orderings(orderings);
979
980 let expected = convert_to_orderings(&expected);
981
982 let projected_eq =
983 eq_properties.project(&projection_mapping, Arc::clone(&output_schema));
984 let orderings = projected_eq.oeq_class();
985
986 let err_msg = format!(
987 "actual: {orderings:?}, expected: {expected:?}, projection_mapping: {projection_mapping:?}"
988 );
989
990 assert_eq!(orderings.len(), expected.len(), "{err_msg}");
991 for expected_ordering in &expected {
992 assert!(orderings.contains(expected_ordering), "{}", err_msg)
993 }
994 }
995
996 Ok(())
997 }
998}