1use std::fmt::Display;
19use std::ops::Deref;
20use std::sync::Arc;
21use std::vec::IntoIter;
22
23use super::ProjectionMapping;
24use crate::expressions::Literal;
25use crate::physical_expr::add_offset_to_expr;
26use crate::projection::ProjectionTargets;
27use crate::{PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, PhysicalSortRequirement};
28
29use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
30use datafusion_common::{HashMap, JoinType, Result, ScalarValue};
31use datafusion_physical_expr_common::physical_expr::format_physical_expr_list;
32
33use indexmap::{IndexMap, IndexSet};
34
/// Describes how the value of a constant expression relates across partitions.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub enum AcrossPartitions {
    /// The default variant; rendered as `(heterogeneous)`.
    // NOTE(review): presumably each partition may hold a different value — confirm.
    #[default]
    Heterogeneous,
    /// Carries the shared value when it is known (`Some`) and `None` when it
    /// is not; rendered as `(uniform: <value>)` or `(uniform: unknown)`.
    Uniform(Option<ScalarValue>),
}
47
48impl Display for AcrossPartitions {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 match self {
51 AcrossPartitions::Heterogeneous => write!(f, "(heterogeneous)"),
52 AcrossPartitions::Uniform(value) => {
53 if let Some(val) = value {
54 write!(f, "(uniform: {val})")
55 } else {
56 write!(f, "(uniform: unknown)")
57 }
58 }
59 }
60 }
61}
62
/// An expression paired with a description of how its value relates across
/// partitions.
#[derive(Clone, Debug)]
pub struct ConstExpr {
    /// The expression itself.
    pub expr: Arc<dyn PhysicalExpr>,
    /// How the value of `expr` relates across partitions.
    pub across_partitions: AcrossPartitions,
}
94impl ConstExpr {
100 pub fn new(expr: Arc<dyn PhysicalExpr>, across_partitions: AcrossPartitions) -> Self {
108 let mut result = ConstExpr::from(expr);
109 if result.across_partitions == AcrossPartitions::Heterogeneous {
112 result.across_partitions = across_partitions;
113 }
114 result
115 }
116
117 pub fn format_list(input: &[ConstExpr]) -> impl Display + '_ {
119 struct DisplayableList<'a>(&'a [ConstExpr]);
120 impl Display for DisplayableList<'_> {
121 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
122 let mut first = true;
123 for const_expr in self.0 {
124 if first {
125 first = false;
126 } else {
127 write!(f, ",")?;
128 }
129 write!(f, "{const_expr}")?;
130 }
131 Ok(())
132 }
133 }
134 DisplayableList(input)
135 }
136}
137
138impl PartialEq for ConstExpr {
139 fn eq(&self, other: &Self) -> bool {
140 self.across_partitions == other.across_partitions && self.expr.eq(&other.expr)
141 }
142}
143
144impl Display for ConstExpr {
145 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146 write!(f, "{}", self.expr)?;
147 write!(f, "{}", self.across_partitions)
148 }
149}
150
151impl From<Arc<dyn PhysicalExpr>> for ConstExpr {
152 fn from(expr: Arc<dyn PhysicalExpr>) -> Self {
153 let across = if let Some(lit) = expr.as_any().downcast_ref::<Literal>() {
157 AcrossPartitions::Uniform(Some(lit.value().clone()))
158 } else {
159 AcrossPartitions::Heterogeneous
160 };
161 Self {
162 expr,
163 across_partitions: across,
164 }
165 }
166}
167
/// A set of expressions treated as equal to one another, optionally flagged
/// as constant.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct EquivalenceClass {
    /// The member expressions, kept in insertion order without duplicates.
    pub(crate) exprs: IndexSet<Arc<dyn PhysicalExpr>>,
    /// `Some` when the class is known to be constant, together with how the
    /// constant relates across partitions; `None` otherwise.
    pub(crate) constant: Option<AcrossPartitions>,
}
184
185impl EquivalenceClass {
186 pub fn new(exprs: impl IntoIterator<Item = Arc<dyn PhysicalExpr>>) -> Self {
188 let mut class = Self::default();
189 for expr in exprs {
190 class.push(expr);
191 }
192 class
193 }
194
    /// Returns the first member of the class in iteration order, or `None`
    /// when the class has no members.
    pub fn canonical_expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
        self.exprs.iter().next()
    }
200
201 pub fn push(&mut self, expr: Arc<dyn PhysicalExpr>) {
204 if let Some(lit) = expr.as_any().downcast_ref::<Literal>() {
205 let expr_across = AcrossPartitions::Uniform(Some(lit.value().clone()));
206 if let Some(across) = self.constant.as_mut() {
207 if *across == AcrossPartitions::Heterogeneous {
209 *across = expr_across;
210 }
211 } else {
212 self.constant = Some(expr_across);
213 }
214 }
215 self.exprs.insert(expr);
216 }
217
218 pub fn extend(&mut self, other: Self) {
220 self.exprs.extend(other.exprs);
221 match (&self.constant, &other.constant) {
222 (Some(across), Some(_)) => {
223 if across == &AcrossPartitions::Heterogeneous {
225 self.constant = other.constant;
226 }
227 }
228 (None, Some(_)) => self.constant = other.constant,
229 (_, None) => {}
230 }
231 }
232
233 pub fn contains_any(&self, other: &Self) -> bool {
236 self.exprs.intersection(&other.exprs).next().is_some()
237 }
238
239 pub fn is_trivial(&self) -> bool {
243 self.exprs.is_empty() || (self.exprs.len() == 1 && self.constant.is_none())
244 }
245
246 pub fn try_with_offset(&self, offset: isize) -> Result<Self> {
249 let mut cls = Self::default();
250 for expr_result in self
251 .exprs
252 .iter()
253 .cloned()
254 .map(|e| add_offset_to_expr(e, offset))
255 {
256 cls.push(expr_result?);
257 }
258 Ok(cls)
259 }
260}
261
/// Exposes the read-only API of the underlying expression set directly on the
/// class (e.g. `len`, `iter`, `contains`).
impl Deref for EquivalenceClass {
    type Target = IndexSet<Arc<dyn PhysicalExpr>>;

    /// Borrows the set of member expressions.
    fn deref(&self) -> &Self::Target {
        &self.exprs
    }
}
269
/// Consuming iteration over the member expressions of the class.
impl IntoIterator for EquivalenceClass {
    type Item = Arc<dyn PhysicalExpr>;
    type IntoIter = <IndexSet<Self::Item> as IntoIterator>::IntoIter;

    /// Consumes the class and yields its members; the constant flag is dropped.
    fn into_iter(self) -> Self::IntoIter {
        self.exprs.into_iter()
    }
}
278
279impl Display for EquivalenceClass {
280 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
281 write!(f, "{{")?;
282 write!(f, "members: {}", format_physical_expr_list(&self.exprs))?;
283 if let Some(across) = &self.constant {
284 write!(f, ", constant: {across}")?;
285 }
286 write!(f, "}}")
287 }
288}
289
290impl From<EquivalenceClass> for Vec<Arc<dyn PhysicalExpr>> {
291 fn from(cls: EquivalenceClass) -> Self {
292 cls.exprs.into_iter().collect()
293 }
294}
295
/// A projection mapping keyed by source expression, where each entry carries
/// the source's projection targets together with the equivalence class the
/// source belongs to, if any.
type AugmentedMapping<'a> = IndexMap<
    &'a Arc<dyn PhysicalExpr>,
    (&'a ProjectionTargets, Option<&'a EquivalenceClass>),
>;
300
/// A collection of equivalence classes plus a lookup table from member
/// expression to the class that contains it.
#[derive(Clone, Debug, Default)]
pub struct EquivalenceGroup {
    /// Maps each member expression to the index of its class in `classes`.
    map: HashMap<Arc<dyn PhysicalExpr>, usize>,
    /// The equivalence classes of the group.
    classes: Vec<EquivalenceClass>,
}
310
311impl EquivalenceGroup {
312 pub fn new(classes: impl IntoIterator<Item = EquivalenceClass>) -> Self {
314 classes.into_iter().collect::<Vec<_>>().into()
315 }
316
317 pub fn add_constant(&mut self, const_expr: ConstExpr) {
319 if let Some(idx) = self.map.get(&const_expr.expr) {
322 let cls = &mut self.classes[*idx];
323 if let Some(across) = cls.constant.as_mut() {
324 if *across == AcrossPartitions::Heterogeneous {
326 *across = const_expr.across_partitions;
327 }
328 } else {
329 cls.constant = Some(const_expr.across_partitions);
330 }
331 return;
332 }
333 if let AcrossPartitions::Uniform(_) = &const_expr.across_partitions {
336 for (idx, cls) in self.classes.iter_mut().enumerate() {
337 if cls
338 .constant
339 .as_ref()
340 .is_some_and(|across| const_expr.across_partitions.eq(across))
341 {
342 self.map.insert(Arc::clone(&const_expr.expr), idx);
343 cls.push(const_expr.expr);
344 return;
345 }
346 }
347 }
348 let mut new_class = EquivalenceClass::new(std::iter::once(const_expr.expr));
350 if new_class.constant.is_none() {
351 new_class.constant = Some(const_expr.across_partitions);
352 }
353 Self::update_lookup_table(&mut self.map, &new_class, self.classes.len());
354 self.classes.push(new_class);
355 }
356
357 pub fn clear_per_partition_constants(&mut self) -> bool {
361 let (mut idx, mut change) = (0, false);
362 while idx < self.classes.len() {
363 let cls = &mut self.classes[idx];
364 if let Some(AcrossPartitions::Heterogeneous) = cls.constant {
365 change = true;
366 if cls.len() == 1 {
367 self.remove_class_at_idx(idx);
369 continue;
370 } else {
371 cls.constant = None;
372 }
373 }
374 idx += 1;
375 }
376 change
377 }
378
379 pub fn add_equal_conditions(
384 &mut self,
385 left: Arc<dyn PhysicalExpr>,
386 right: Arc<dyn PhysicalExpr>,
387 ) -> bool {
388 let first_class = self.map.get(&left).copied();
389 let second_class = self.map.get(&right).copied();
390 match (first_class, second_class) {
391 (Some(mut first_idx), Some(mut second_idx)) => {
392 match first_idx.cmp(&second_idx) {
395 std::cmp::Ordering::Equal => return false,
397 std::cmp::Ordering::Greater => {
399 std::mem::swap(&mut first_idx, &mut second_idx);
400 }
401 _ => {}
402 }
403 let other_class = self.remove_class_at_idx(second_idx);
407 Self::update_lookup_table(&mut self.map, &other_class, first_idx);
409 self.classes[first_idx].extend(other_class);
410 }
411 (Some(group_idx), None) => {
412 self.map.insert(Arc::clone(&right), group_idx);
414 self.classes[group_idx].push(right);
415 }
416 (None, Some(group_idx)) => {
417 self.map.insert(Arc::clone(&left), group_idx);
419 self.classes[group_idx].push(left);
420 }
421 (None, None) => {
422 let class = EquivalenceClass::new([left, right]);
425 Self::update_lookup_table(&mut self.map, &class, self.classes.len());
426 self.classes.push(class);
427 return true;
428 }
429 }
430 false
431 }
432
433 fn remove_class_at_idx(&mut self, idx: usize) -> EquivalenceClass {
435 let cls = self.classes.swap_remove(idx);
437 for expr in cls.iter() {
439 self.map.remove(expr);
440 }
441 if idx < self.classes.len() {
443 Self::update_lookup_table(&mut self.map, &self.classes[idx], idx);
444 }
445 cls
446 }
447
448 fn update_lookup_table(
451 map: &mut HashMap<Arc<dyn PhysicalExpr>, usize>,
452 cls: &EquivalenceClass,
453 idx: usize,
454 ) {
455 for expr in cls.iter() {
456 map.insert(Arc::clone(expr), idx);
457 }
458 }
459
460 fn remove_redundant_entries(&mut self) -> bool {
463 let mut change = false;
465 for idx in (0..self.classes.len()).rev() {
466 if self.classes[idx].is_trivial() {
467 self.remove_class_at_idx(idx);
468 change = true;
469 }
470 }
471 self.bridge_classes() || change
473 }
474
475 fn bridge_classes(&mut self) -> bool {
481 let (mut idx, mut change) = (0, false);
482 'scan: while idx < self.classes.len() {
483 for other_idx in (idx + 1..self.classes.len()).rev() {
484 if self.classes[idx].contains_any(&self.classes[other_idx]) {
485 let extension = self.remove_class_at_idx(other_idx);
486 Self::update_lookup_table(&mut self.map, &extension, idx);
487 self.classes[idx].extend(extension);
488 change = true;
489 continue 'scan;
490 }
491 }
492 idx += 1;
493 }
494 change
495 }
496
497 pub fn extend(&mut self, other: Self) -> bool {
501 for (idx, cls) in other.classes.iter().enumerate() {
502 Self::update_lookup_table(&mut self.map, cls, idx);
504 }
505 self.classes.extend(other.classes);
506 self.bridge_classes()
507 }
508
509 pub fn normalize_expr(&self, expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
513 expr.transform(|expr| {
514 let cls = self.get_equivalence_class(&expr);
515 let Some(canonical) = cls.and_then(|cls| cls.canonical_expr()) else {
516 return Ok(Transformed::no(expr));
517 };
518 Ok(Transformed::yes(Arc::clone(canonical)))
519 })
520 .data()
521 .unwrap()
522 }
524
    /// Returns `sort_expr` with its expression replaced by the result of
    /// `normalize_expr`; all other fields are passed through unchanged.
    pub fn normalize_sort_expr(
        &self,
        mut sort_expr: PhysicalSortExpr,
    ) -> PhysicalSortExpr {
        sort_expr.expr = self.normalize_expr(sort_expr.expr);
        sort_expr
    }
537
538 pub fn normalize_sort_exprs<'a>(
547 &'a self,
548 sort_exprs: impl IntoIterator<Item = PhysicalSortExpr> + 'a,
549 ) -> impl Iterator<Item = PhysicalSortExpr> + 'a {
550 sort_exprs
551 .into_iter()
552 .map(|sort_expr| self.normalize_sort_expr(sort_expr))
553 .filter(|sort_expr| self.is_expr_constant(&sort_expr.expr).is_none())
554 }
555
    /// Returns `sort_requirement` with its expression replaced by the result
    /// of `normalize_expr`; all other fields are passed through unchanged.
    pub fn normalize_sort_requirement(
        &self,
        mut sort_requirement: PhysicalSortRequirement,
    ) -> PhysicalSortRequirement {
        sort_requirement.expr = self.normalize_expr(sort_requirement.expr);
        sort_requirement
    }
568
569 pub fn normalize_sort_requirements<'a>(
578 &'a self,
579 sort_reqs: impl IntoIterator<Item = PhysicalSortRequirement> + 'a,
580 ) -> impl Iterator<Item = PhysicalSortRequirement> + 'a {
581 sort_reqs
582 .into_iter()
583 .map(|req| self.normalize_sort_requirement(req))
584 .filter(|req| self.is_expr_constant(&req.expr).is_none())
585 }
586
587 fn project_expr_indirect(
590 aug_mapping: &AugmentedMapping,
591 expr: &Arc<dyn PhysicalExpr>,
592 ) -> Option<Arc<dyn PhysicalExpr>> {
593 if expr.as_any().downcast_ref::<Literal>().is_some() {
595 return Some(Arc::clone(expr));
596 }
597
598 for (targets, eq_class) in aug_mapping.values() {
601 if eq_class.as_ref().is_some_and(|cls| cls.contains(expr)) {
606 let (target, _) = targets.first();
607 return Some(Arc::clone(target));
608 }
609 }
610 let children = expr.children();
612 if children.is_empty() {
613 return None;
615 }
616 children
617 .into_iter()
618 .map(|child| {
619 if let Some((targets, _)) = aug_mapping.get(child) {
622 let (target, _) = targets.first();
624 Some(Arc::clone(target))
625 } else {
626 Self::project_expr_indirect(aug_mapping, child)
627 }
628 })
629 .collect::<Option<Vec<_>>>()
630 .map(|children| Arc::clone(expr).with_new_children(children).unwrap())
631 }
632
633 fn augment_projection_mapping<'a>(
634 &'a self,
635 mapping: &'a ProjectionMapping,
636 ) -> AugmentedMapping<'a> {
637 mapping
638 .iter()
639 .map(|(k, v)| {
640 let eq_class = self.get_equivalence_class(k);
641 (k, (v, eq_class))
642 })
643 .collect()
644 }
645
646 pub fn project_expr(
649 &self,
650 mapping: &ProjectionMapping,
651 expr: &Arc<dyn PhysicalExpr>,
652 ) -> Option<Arc<dyn PhysicalExpr>> {
653 if let Some(targets) = mapping.get(expr) {
654 let (target, _) = targets.first();
656 Some(Arc::clone(target))
657 } else {
658 let aug_mapping = self.augment_projection_mapping(mapping);
659 Self::project_expr_indirect(&aug_mapping, expr)
660 }
661 }
662
663 pub fn project_expressions<'a>(
668 &'a self,
669 mapping: &'a ProjectionMapping,
670 expressions: impl IntoIterator<Item = &'a Arc<dyn PhysicalExpr>> + 'a,
671 ) -> impl Iterator<Item = Option<Arc<dyn PhysicalExpr>>> + 'a {
672 let mut aug_mapping = None;
673 expressions.into_iter().map(move |expr| {
674 if let Some(targets) = mapping.get(expr) {
675 let (target, _) = targets.first();
677 Some(Arc::clone(target))
678 } else {
679 let aug_mapping = aug_mapping
680 .get_or_insert_with(|| self.augment_projection_mapping(mapping));
681 Self::project_expr_indirect(aug_mapping, expr)
682 }
683 })
684 }
685
686 pub fn project(&self, mapping: &ProjectionMapping) -> Self {
688 let projected_classes = self.iter().map(|cls| {
689 let new_exprs = self.project_expressions(mapping, cls.iter());
690 EquivalenceClass::new(new_exprs.flatten())
691 });
692
693 let mut new_constants = vec![];
696 let mut new_classes = IndexMap::<_, EquivalenceClass>::new();
697 for (source, targets) in mapping.iter() {
698 let normalized_expr = self.normalize_expr(Arc::clone(source));
704 let cls = new_classes.entry(normalized_expr).or_default();
705 for (target, _) in targets.iter() {
706 cls.push(Arc::clone(target));
707 }
708 if let Some(across) = self.is_expr_constant(source) {
710 for (target, _) in targets.iter() {
711 let const_expr = ConstExpr::new(Arc::clone(target), across.clone());
712 new_constants.push(const_expr);
713 }
714 }
715 }
716
717 let classes = projected_classes
719 .chain(new_classes.into_values())
720 .filter(|cls| !cls.is_trivial());
721 let mut result = Self::new(classes);
722 for constant in new_constants {
724 result.add_constant(constant);
725 }
726 result
727 }
728
729 pub fn is_expr_constant(
734 &self,
735 expr: &Arc<dyn PhysicalExpr>,
736 ) -> Option<AcrossPartitions> {
737 if let Some(lit) = expr.as_any().downcast_ref::<Literal>() {
738 return Some(AcrossPartitions::Uniform(Some(lit.value().clone())));
739 }
740 if let Some(cls) = self.get_equivalence_class(expr) {
741 if cls.constant.is_some() {
742 return cls.constant.clone();
743 }
744 }
745 let children = expr.children();
749 if children.is_empty() {
750 return None;
751 }
752 for child in children {
753 self.is_expr_constant(child)?;
754 }
755 Some(AcrossPartitions::Heterogeneous)
756 }
757
758 pub fn get_equivalence_class(
761 &self,
762 expr: &Arc<dyn PhysicalExpr>,
763 ) -> Option<&EquivalenceClass> {
764 self.map.get(expr).map(|idx| &self.classes[*idx])
765 }
766
767 pub fn join(
769 &self,
770 right_equivalences: &Self,
771 join_type: &JoinType,
772 left_size: usize,
773 on: &[(PhysicalExprRef, PhysicalExprRef)],
774 ) -> Result<Self> {
775 let group = match join_type {
776 JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
777 let mut result = Self::new(
778 self.iter().cloned().chain(
779 right_equivalences
780 .iter()
781 .map(|cls| cls.try_with_offset(left_size as _))
782 .collect::<Result<Vec<_>>>()?,
783 ),
784 );
785 if join_type == &JoinType::Inner {
788 for (lhs, rhs) in on.iter() {
789 let new_lhs = Arc::clone(lhs);
790 let new_rhs =
792 add_offset_to_expr(Arc::clone(rhs), left_size as _)?;
793 result.add_equal_conditions(new_lhs, new_rhs);
794 }
795 }
796 result
797 }
798 JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => self.clone(),
799 JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => {
800 right_equivalences.clone()
801 }
802 };
803 Ok(group)
804 }
805
806 pub fn exprs_equal(
811 &self,
812 left: &Arc<dyn PhysicalExpr>,
813 right: &Arc<dyn PhysicalExpr>,
814 ) -> bool {
815 if left.eq(right) {
817 return true;
818 }
819
820 if let Some(left_class) = self.get_equivalence_class(left) {
823 if left_class.contains(right) {
824 return true;
825 }
826 }
827 if let Some(right_class) = self.get_equivalence_class(right) {
828 if right_class.contains(left) {
829 return true;
830 }
831 }
832
833 let left_children = left.children();
835 let right_children = right.children();
836
837 if left_children.is_empty() || right_children.is_empty() {
840 return false;
841 }
842
843 if left.as_any().type_id() != right.as_any().type_id() {
845 return false;
846 }
847
848 if left_children.len() != right_children.len() {
850 return false;
851 }
852
853 left_children
855 .into_iter()
856 .zip(right_children)
857 .all(|(left_child, right_child)| self.exprs_equal(left_child, right_child))
858 }
859}
860
/// Exposes the group's classes as a slice, so slice methods such as `iter`,
/// `len` and `is_empty` can be called directly on the group.
impl Deref for EquivalenceGroup {
    type Target = [EquivalenceClass];

    /// Borrows the classes of the group as a slice.
    fn deref(&self) -> &Self::Target {
        &self.classes
    }
}
868
/// Consuming iteration over the classes of the group.
impl IntoIterator for EquivalenceGroup {
    type Item = EquivalenceClass;
    type IntoIter = IntoIter<Self::Item>;

    /// Consumes the group and yields its classes; the lookup table is dropped.
    fn into_iter(self) -> Self::IntoIter {
        self.classes.into_iter()
    }
}
877
878impl Display for EquivalenceGroup {
879 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
880 write!(f, "[")?;
881 let mut iter = self.iter();
882 if let Some(cls) = iter.next() {
883 write!(f, "{cls}")?;
884 }
885 for cls in iter {
886 write!(f, ", {cls}")?;
887 }
888 write!(f, "]")
889 }
890}
891
892impl From<Vec<EquivalenceClass>> for EquivalenceGroup {
893 fn from(classes: Vec<EquivalenceClass>) -> Self {
894 let mut result = Self {
895 map: classes
896 .iter()
897 .enumerate()
898 .flat_map(|(idx, cls)| {
899 cls.iter().map(move |expr| (Arc::clone(expr), idx))
900 })
901 .collect(),
902 classes,
903 };
904 result.remove_redundant_entries();
905 result
906 }
907}
908
909#[cfg(test)]
910mod tests {
911 use super::*;
912 use crate::equivalence::tests::create_test_params;
913 use crate::expressions::{binary, col, lit, BinaryExpr, Column, Literal};
914 use arrow::datatypes::{DataType, Field, Schema};
915
916 use datafusion_common::{Result, ScalarValue};
917 use datafusion_expr::Operator;
918
919 #[test]
920 fn test_bridge_groups() -> Result<()> {
921 let test_cases = vec![
923 (
925 vec![vec![1, 2, 3], vec![2, 4, 5], vec![11, 12, 9], vec![7, 6, 5]],
926 vec![vec![1, 2, 3, 4, 5, 6, 7], vec![9, 11, 12]],
928 ),
929 (
931 vec![vec![1, 2, 3], vec![3, 4, 5], vec![9, 8, 7], vec![7, 6, 5]],
932 vec![vec![1, 2, 3, 4, 5, 6, 7, 8, 9]],
934 ),
935 ];
936 for (entries, expected) in test_cases {
937 let entries = entries
938 .into_iter()
939 .map(|entry| {
940 entry.into_iter().map(|idx| {
941 let c = Column::new(format!("col_{idx}").as_str(), idx);
942 Arc::new(c) as _
943 })
944 })
945 .map(EquivalenceClass::new)
946 .collect::<Vec<_>>();
947 let expected = expected
948 .into_iter()
949 .map(|entry| {
950 entry.into_iter().map(|idx| {
951 let c = Column::new(format!("col_{idx}").as_str(), idx);
952 Arc::new(c) as _
953 })
954 })
955 .map(EquivalenceClass::new)
956 .collect::<Vec<_>>();
957 let eq_groups: EquivalenceGroup = entries.clone().into();
958 let eq_groups = eq_groups.classes;
959 let err_msg = format!(
960 "error in test entries: {entries:?}, expected: {expected:?}, actual:{eq_groups:?}"
961 );
962 assert_eq!(eq_groups.len(), expected.len(), "{err_msg}");
963 for idx in 0..eq_groups.len() {
964 assert_eq!(&eq_groups[idx], &expected[idx], "{err_msg}");
965 }
966 }
967 Ok(())
968 }
969
970 #[test]
971 fn test_remove_redundant_entries_eq_group() -> Result<()> {
972 let c = |idx| Arc::new(Column::new(format!("col_{idx}").as_str(), idx)) as _;
973 let entries = [
974 EquivalenceClass::new([c(1), c(1), lit(20)]),
975 EquivalenceClass::new([lit(30), lit(30)]),
976 EquivalenceClass::new([c(2), c(3), c(4)]),
977 ];
978 let expected = [
981 EquivalenceClass::new([c(1), lit(20)]),
982 EquivalenceClass::new([lit(30)]),
983 EquivalenceClass::new([c(2), c(3), c(4)]),
984 ];
985 let eq_groups = EquivalenceGroup::new(entries);
986 assert_eq!(eq_groups.classes, expected);
987 Ok(())
988 }
989
990 #[test]
991 fn test_schema_normalize_expr_with_equivalence() -> Result<()> {
992 let col_a = Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>;
993 let col_b = Arc::new(Column::new("b", 1)) as _;
994 let col_c = Arc::new(Column::new("c", 2)) as _;
995 let (_, eq_properties) = create_test_params()?;
997 let expressions = vec![
1000 (Arc::clone(&col_a), Arc::clone(&col_a)),
1004 (col_c, col_a),
1005 (Arc::clone(&col_b), Arc::clone(&col_b)),
1007 ];
1008 let eq_group = eq_properties.eq_group();
1009 for (expr, expected_eq) in expressions {
1010 assert!(expected_eq.eq(&eq_group.normalize_expr(expr)));
1011 }
1012
1013 Ok(())
1014 }
1015
1016 #[test]
1017 fn test_contains_any() {
1018 let lit_true = Arc::new(Literal::new(ScalarValue::from(true))) as _;
1019 let lit_false = Arc::new(Literal::new(ScalarValue::from(false))) as _;
1020 let col_a_expr = Arc::new(Column::new("a", 0)) as _;
1021 let col_b_expr = Arc::new(Column::new("b", 1)) as _;
1022 let col_c_expr = Arc::new(Column::new("c", 2)) as _;
1023
1024 let cls1 = EquivalenceClass::new([Arc::clone(&lit_true), col_a_expr]);
1025 let cls2 = EquivalenceClass::new([lit_true, col_b_expr]);
1026 let cls3 = EquivalenceClass::new([col_c_expr, lit_false]);
1027
1028 assert!(cls1.contains_any(&cls2));
1030 assert!(!cls1.contains_any(&cls3));
1032 assert!(!cls2.contains_any(&cls3));
1033 }
1034
1035 #[test]
1036 fn test_exprs_equal() -> Result<()> {
1037 struct TestCase {
1038 left: Arc<dyn PhysicalExpr>,
1039 right: Arc<dyn PhysicalExpr>,
1040 expected: bool,
1041 description: &'static str,
1042 }
1043
1044 let col_a = Arc::new(Column::new("a", 0)) as _;
1046 let col_b = Arc::new(Column::new("b", 1)) as _;
1047 let col_x = Arc::new(Column::new("x", 2)) as _;
1048 let col_y = Arc::new(Column::new("y", 3)) as _;
1049
1050 let lit_1 = Arc::new(Literal::new(ScalarValue::from(1))) as _;
1052 let lit_2 = Arc::new(Literal::new(ScalarValue::from(2))) as _;
1053
1054 let eq_group = EquivalenceGroup::new([
1056 EquivalenceClass::new([Arc::clone(&col_a), Arc::clone(&col_x)]),
1057 EquivalenceClass::new([Arc::clone(&col_b), Arc::clone(&col_y)]),
1058 ]);
1059
1060 let test_cases = vec![
1061 TestCase {
1063 left: Arc::clone(&col_a),
1064 right: Arc::clone(&col_a),
1065 expected: true,
1066 description: "Same column should be equal",
1067 },
1068 TestCase {
1070 left: Arc::clone(&col_a),
1071 right: Arc::clone(&col_x),
1072 expected: true,
1073 description: "Columns in same equivalence class should be equal",
1074 },
1075 TestCase {
1076 left: Arc::clone(&col_b),
1077 right: Arc::clone(&col_y),
1078 expected: true,
1079 description: "Columns in same equivalence class should be equal",
1080 },
1081 TestCase {
1082 left: Arc::clone(&col_a),
1083 right: Arc::clone(&col_b),
1084 expected: false,
1085 description:
1086 "Columns in different equivalence classes should not be equal",
1087 },
1088 TestCase {
1090 left: Arc::clone(&lit_1),
1091 right: Arc::clone(&lit_1),
1092 expected: true,
1093 description: "Same literal should be equal",
1094 },
1095 TestCase {
1096 left: Arc::clone(&lit_1),
1097 right: Arc::clone(&lit_2),
1098 expected: false,
1099 description: "Different literals should not be equal",
1100 },
1101 TestCase {
1103 left: Arc::new(BinaryExpr::new(
1104 Arc::clone(&col_a),
1105 Operator::Plus,
1106 Arc::clone(&col_b),
1107 )) as _,
1108 right: Arc::new(BinaryExpr::new(
1109 Arc::clone(&col_x),
1110 Operator::Plus,
1111 Arc::clone(&col_y),
1112 )) as _,
1113 expected: true,
1114 description:
1115 "Binary expressions with equivalent operands should be equal",
1116 },
1117 TestCase {
1118 left: Arc::new(BinaryExpr::new(
1119 Arc::clone(&col_a),
1120 Operator::Plus,
1121 Arc::clone(&col_b),
1122 )) as _,
1123 right: Arc::new(BinaryExpr::new(
1124 Arc::clone(&col_x),
1125 Operator::Plus,
1126 Arc::clone(&col_a),
1127 )) as _,
1128 expected: false,
1129 description:
1130 "Binary expressions with non-equivalent operands should not be equal",
1131 },
1132 TestCase {
1133 left: Arc::new(BinaryExpr::new(
1134 Arc::clone(&col_a),
1135 Operator::Plus,
1136 Arc::clone(&lit_1),
1137 )) as _,
1138 right: Arc::new(BinaryExpr::new(
1139 Arc::clone(&col_x),
1140 Operator::Plus,
1141 Arc::clone(&lit_1),
1142 )) as _,
1143 expected: true,
1144 description: "Binary expressions with equivalent column and same literal should be equal",
1145 },
1146 TestCase {
1147 left: Arc::new(BinaryExpr::new(
1148 Arc::new(BinaryExpr::new(
1149 Arc::clone(&col_a),
1150 Operator::Plus,
1151 Arc::clone(&col_b),
1152 )),
1153 Operator::Multiply,
1154 Arc::clone(&lit_1),
1155 )) as _,
1156 right: Arc::new(BinaryExpr::new(
1157 Arc::new(BinaryExpr::new(
1158 Arc::clone(&col_x),
1159 Operator::Plus,
1160 Arc::clone(&col_y),
1161 )),
1162 Operator::Multiply,
1163 Arc::clone(&lit_1),
1164 )) as _,
1165 expected: true,
1166 description: "Nested binary expressions with equivalent operands should be equal",
1167 },
1168 ];
1169
1170 for TestCase {
1171 left,
1172 right,
1173 expected,
1174 description,
1175 } in test_cases
1176 {
1177 let actual = eq_group.exprs_equal(&left, &right);
1178 assert_eq!(
1179 actual, expected,
1180 "{description}: Failed comparing {left:?} and {right:?}, expected {expected}, got {actual}"
1181 );
1182 }
1183
1184 Ok(())
1185 }
1186
1187 #[test]
1188 fn test_project_classes() -> Result<()> {
1189 let schema = Arc::new(Schema::new(vec![
1194 Field::new("a", DataType::Int32, false),
1195 Field::new("b", DataType::Int32, false),
1196 Field::new("c", DataType::Int32, false),
1197 ]));
1198 let mut group = EquivalenceGroup::default();
1199 group.add_equal_conditions(col("a", &schema)?, col("b", &schema)?);
1200
1201 let projected_schema = Arc::new(Schema::new(vec![
1202 Field::new("a+c", DataType::Int32, false),
1203 Field::new("b+c", DataType::Int32, false),
1204 ]));
1205
1206 let mapping = [
1207 (
1208 binary(
1209 col("a", &schema)?,
1210 Operator::Plus,
1211 col("c", &schema)?,
1212 &schema,
1213 )?,
1214 vec![(col("a+c", &projected_schema)?, 0)].into(),
1215 ),
1216 (
1217 binary(
1218 col("b", &schema)?,
1219 Operator::Plus,
1220 col("c", &schema)?,
1221 &schema,
1222 )?,
1223 vec![(col("b+c", &projected_schema)?, 1)].into(),
1224 ),
1225 ]
1226 .into_iter()
1227 .collect::<ProjectionMapping>();
1228
1229 let projected = group.project(&mapping);
1230
1231 assert!(!projected.is_empty());
1232 let first_normalized = projected.normalize_expr(col("a+c", &projected_schema)?);
1233 let second_normalized = projected.normalize_expr(col("b+c", &projected_schema)?);
1234
1235 assert!(first_normalized.eq(&second_normalized));
1236
1237 Ok(())
1238 }
1239}