1use std::fmt::Display;
19use std::ops::Deref;
20use std::sync::Arc;
21use std::vec::IntoIter;
22
23use super::projection::ProjectionTargets;
24use super::ProjectionMapping;
25use crate::expressions::Literal;
26use crate::physical_expr::add_offset_to_expr;
27use crate::{PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, PhysicalSortRequirement};
28
29use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
30use datafusion_common::{HashMap, JoinType, Result, ScalarValue};
31use datafusion_physical_expr_common::physical_expr::format_physical_expr_list;
32
33use indexmap::{IndexMap, IndexSet};
34
35#[derive(Clone, Debug, Default, Eq, PartialEq)]
42pub enum AcrossPartitions {
43 #[default]
44 Heterogeneous,
45 Uniform(Option<ScalarValue>),
46}
47
48impl Display for AcrossPartitions {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 match self {
51 AcrossPartitions::Heterogeneous => write!(f, "(heterogeneous)"),
52 AcrossPartitions::Uniform(value) => {
53 if let Some(val) = value {
54 write!(f, "(uniform: {val})")
55 } else {
56 write!(f, "(uniform: unknown)")
57 }
58 }
59 }
60 }
61}
62
63#[derive(Clone, Debug)]
88pub struct ConstExpr {
89 pub expr: Arc<dyn PhysicalExpr>,
91 pub across_partitions: AcrossPartitions,
93}
94impl ConstExpr {
100 pub fn new(expr: Arc<dyn PhysicalExpr>, across_partitions: AcrossPartitions) -> Self {
108 let mut result = ConstExpr::from(expr);
109 if result.across_partitions == AcrossPartitions::Heterogeneous {
112 result.across_partitions = across_partitions;
113 }
114 result
115 }
116
117 pub fn format_list(input: &[ConstExpr]) -> impl Display + '_ {
119 struct DisplayableList<'a>(&'a [ConstExpr]);
120 impl Display for DisplayableList<'_> {
121 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
122 let mut first = true;
123 for const_expr in self.0 {
124 if first {
125 first = false;
126 } else {
127 write!(f, ",")?;
128 }
129 write!(f, "{const_expr}")?;
130 }
131 Ok(())
132 }
133 }
134 DisplayableList(input)
135 }
136}
137
138impl PartialEq for ConstExpr {
139 fn eq(&self, other: &Self) -> bool {
140 self.across_partitions == other.across_partitions && self.expr.eq(&other.expr)
141 }
142}
143
144impl Display for ConstExpr {
145 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146 write!(f, "{}", self.expr)?;
147 write!(f, "{}", self.across_partitions)
148 }
149}
150
151impl From<Arc<dyn PhysicalExpr>> for ConstExpr {
152 fn from(expr: Arc<dyn PhysicalExpr>) -> Self {
153 let across = if let Some(lit) = expr.as_any().downcast_ref::<Literal>() {
157 AcrossPartitions::Uniform(Some(lit.value().clone()))
158 } else {
159 AcrossPartitions::Heterogeneous
160 };
161 Self {
162 expr,
163 across_partitions: across,
164 }
165 }
166}
167
168#[derive(Clone, Debug, Default, Eq, PartialEq)]
176pub struct EquivalenceClass {
177 pub(crate) exprs: IndexSet<Arc<dyn PhysicalExpr>>,
180 pub(crate) constant: Option<AcrossPartitions>,
183}
184
185impl EquivalenceClass {
186 pub fn new(exprs: impl IntoIterator<Item = Arc<dyn PhysicalExpr>>) -> Self {
188 let mut class = Self::default();
189 for expr in exprs {
190 class.push(expr);
191 }
192 class
193 }
194
195 pub fn canonical_expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
198 self.exprs.iter().next()
199 }
200
201 pub fn push(&mut self, expr: Arc<dyn PhysicalExpr>) {
204 if let Some(lit) = expr.as_any().downcast_ref::<Literal>() {
205 let expr_across = AcrossPartitions::Uniform(Some(lit.value().clone()));
206 if let Some(across) = self.constant.as_mut() {
207 if *across == AcrossPartitions::Heterogeneous {
209 *across = expr_across;
210 }
211 } else {
212 self.constant = Some(expr_across);
213 }
214 }
215 self.exprs.insert(expr);
216 }
217
218 pub fn extend(&mut self, other: Self) {
220 self.exprs.extend(other.exprs);
221 match (&self.constant, &other.constant) {
222 (Some(across), Some(_)) => {
223 if across == &AcrossPartitions::Heterogeneous {
225 self.constant = other.constant;
226 }
227 }
228 (None, Some(_)) => self.constant = other.constant,
229 (_, None) => {}
230 }
231 }
232
233 pub fn contains_any(&self, other: &Self) -> bool {
236 self.exprs.intersection(&other.exprs).next().is_some()
237 }
238
239 pub fn is_trivial(&self) -> bool {
243 self.exprs.is_empty() || (self.exprs.len() == 1 && self.constant.is_none())
244 }
245
246 pub fn try_with_offset(&self, offset: isize) -> Result<Self> {
249 let mut cls = Self::default();
250 for expr_result in self
251 .exprs
252 .iter()
253 .cloned()
254 .map(|e| add_offset_to_expr(e, offset))
255 {
256 cls.push(expr_result?);
257 }
258 Ok(cls)
259 }
260}
261
262impl Deref for EquivalenceClass {
263 type Target = IndexSet<Arc<dyn PhysicalExpr>>;
264
265 fn deref(&self) -> &Self::Target {
266 &self.exprs
267 }
268}
269
270impl IntoIterator for EquivalenceClass {
271 type Item = Arc<dyn PhysicalExpr>;
272 type IntoIter = <IndexSet<Self::Item> as IntoIterator>::IntoIter;
273
274 fn into_iter(self) -> Self::IntoIter {
275 self.exprs.into_iter()
276 }
277}
278
279impl Display for EquivalenceClass {
280 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
281 write!(f, "{{")?;
282 write!(f, "members: {}", format_physical_expr_list(&self.exprs))?;
283 if let Some(across) = &self.constant {
284 write!(f, ", constant: {across}")?;
285 }
286 write!(f, "}}")
287 }
288}
289
290impl From<EquivalenceClass> for Vec<Arc<dyn PhysicalExpr>> {
291 fn from(cls: EquivalenceClass) -> Self {
292 cls.exprs.into_iter().collect()
293 }
294}
295
296type AugmentedMapping<'a> = IndexMap<
297 &'a Arc<dyn PhysicalExpr>,
298 (&'a ProjectionTargets, Option<&'a EquivalenceClass>),
299>;
300
301#[derive(Clone, Debug, Default)]
304pub struct EquivalenceGroup {
305 map: HashMap<Arc<dyn PhysicalExpr>, usize>,
307 classes: Vec<EquivalenceClass>,
309}
310
311impl EquivalenceGroup {
312 pub fn new(classes: impl IntoIterator<Item = EquivalenceClass>) -> Self {
314 classes.into_iter().collect::<Vec<_>>().into()
315 }
316
317 pub fn add_constant(&mut self, const_expr: ConstExpr) {
319 if let Some(idx) = self.map.get(&const_expr.expr) {
322 let cls = &mut self.classes[*idx];
323 if let Some(across) = cls.constant.as_mut() {
324 if *across == AcrossPartitions::Heterogeneous {
326 *across = const_expr.across_partitions;
327 }
328 } else {
329 cls.constant = Some(const_expr.across_partitions);
330 }
331 return;
332 }
333 if let AcrossPartitions::Uniform(_) = &const_expr.across_partitions {
336 for (idx, cls) in self.classes.iter_mut().enumerate() {
337 if cls
338 .constant
339 .as_ref()
340 .is_some_and(|across| const_expr.across_partitions.eq(across))
341 {
342 self.map.insert(Arc::clone(&const_expr.expr), idx);
343 cls.push(const_expr.expr);
344 return;
345 }
346 }
347 }
348 let mut new_class = EquivalenceClass::new(std::iter::once(const_expr.expr));
350 if new_class.constant.is_none() {
351 new_class.constant = Some(const_expr.across_partitions);
352 }
353 Self::update_lookup_table(&mut self.map, &new_class, self.classes.len());
354 self.classes.push(new_class);
355 }
356
357 pub fn clear_per_partition_constants(&mut self) -> bool {
361 let (mut idx, mut change) = (0, false);
362 while idx < self.classes.len() {
363 let cls = &mut self.classes[idx];
364 if let Some(AcrossPartitions::Heterogeneous) = cls.constant {
365 change = true;
366 if cls.len() == 1 {
367 self.remove_class_at_idx(idx);
369 continue;
370 } else {
371 cls.constant = None;
372 }
373 }
374 idx += 1;
375 }
376 change
377 }
378
379 pub fn add_equal_conditions(
384 &mut self,
385 left: Arc<dyn PhysicalExpr>,
386 right: Arc<dyn PhysicalExpr>,
387 ) -> bool {
388 let first_class = self.map.get(&left).copied();
389 let second_class = self.map.get(&right).copied();
390 match (first_class, second_class) {
391 (Some(mut first_idx), Some(mut second_idx)) => {
392 match first_idx.cmp(&second_idx) {
395 std::cmp::Ordering::Equal => return false,
397 std::cmp::Ordering::Greater => {
399 std::mem::swap(&mut first_idx, &mut second_idx);
400 }
401 _ => {}
402 }
403 let other_class = self.remove_class_at_idx(second_idx);
407 Self::update_lookup_table(&mut self.map, &other_class, first_idx);
409 self.classes[first_idx].extend(other_class);
410 }
411 (Some(group_idx), None) => {
412 self.map.insert(Arc::clone(&right), group_idx);
414 self.classes[group_idx].push(right);
415 }
416 (None, Some(group_idx)) => {
417 self.map.insert(Arc::clone(&left), group_idx);
419 self.classes[group_idx].push(left);
420 }
421 (None, None) => {
422 let class = EquivalenceClass::new([left, right]);
425 Self::update_lookup_table(&mut self.map, &class, self.classes.len());
426 self.classes.push(class);
427 return true;
428 }
429 }
430 false
431 }
432
433 fn remove_class_at_idx(&mut self, idx: usize) -> EquivalenceClass {
435 let cls = self.classes.swap_remove(idx);
437 for expr in cls.iter() {
439 self.map.remove(expr);
440 }
441 if idx < self.classes.len() {
443 Self::update_lookup_table(&mut self.map, &self.classes[idx], idx);
444 }
445 cls
446 }
447
448 fn update_lookup_table(
451 map: &mut HashMap<Arc<dyn PhysicalExpr>, usize>,
452 cls: &EquivalenceClass,
453 idx: usize,
454 ) {
455 for expr in cls.iter() {
456 map.insert(Arc::clone(expr), idx);
457 }
458 }
459
460 fn remove_redundant_entries(&mut self) -> bool {
463 let mut change = false;
465 for idx in (0..self.classes.len()).rev() {
466 if self.classes[idx].is_trivial() {
467 self.remove_class_at_idx(idx);
468 change = true;
469 }
470 }
471 self.bridge_classes() || change
473 }
474
475 fn bridge_classes(&mut self) -> bool {
481 let (mut idx, mut change) = (0, false);
482 'scan: while idx < self.classes.len() {
483 for other_idx in (idx + 1..self.classes.len()).rev() {
484 if self.classes[idx].contains_any(&self.classes[other_idx]) {
485 let extension = self.remove_class_at_idx(other_idx);
486 Self::update_lookup_table(&mut self.map, &extension, idx);
487 self.classes[idx].extend(extension);
488 change = true;
489 continue 'scan;
490 }
491 }
492 idx += 1;
493 }
494 change
495 }
496
497 pub fn extend(&mut self, other: Self) -> bool {
501 for (idx, cls) in other.classes.iter().enumerate() {
502 Self::update_lookup_table(&mut self.map, cls, idx);
504 }
505 self.classes.extend(other.classes);
506 self.bridge_classes()
507 }
508
509 pub fn normalize_expr(&self, expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
513 expr.transform(|expr| {
514 let cls = self.get_equivalence_class(&expr);
515 let Some(canonical) = cls.and_then(|cls| cls.canonical_expr()) else {
516 return Ok(Transformed::no(expr));
517 };
518 Ok(Transformed::yes(Arc::clone(canonical)))
519 })
520 .data()
521 .unwrap()
522 }
524
525 pub fn normalize_sort_expr(
531 &self,
532 mut sort_expr: PhysicalSortExpr,
533 ) -> PhysicalSortExpr {
534 sort_expr.expr = self.normalize_expr(sort_expr.expr);
535 sort_expr
536 }
537
538 pub fn normalize_sort_exprs<'a>(
547 &'a self,
548 sort_exprs: impl IntoIterator<Item = PhysicalSortExpr> + 'a,
549 ) -> impl Iterator<Item = PhysicalSortExpr> + 'a {
550 sort_exprs
551 .into_iter()
552 .map(|sort_expr| self.normalize_sort_expr(sort_expr))
553 .filter(|sort_expr| self.is_expr_constant(&sort_expr.expr).is_none())
554 }
555
556 pub fn normalize_sort_requirement(
562 &self,
563 mut sort_requirement: PhysicalSortRequirement,
564 ) -> PhysicalSortRequirement {
565 sort_requirement.expr = self.normalize_expr(sort_requirement.expr);
566 sort_requirement
567 }
568
569 pub fn normalize_sort_requirements<'a>(
578 &'a self,
579 sort_reqs: impl IntoIterator<Item = PhysicalSortRequirement> + 'a,
580 ) -> impl Iterator<Item = PhysicalSortRequirement> + 'a {
581 sort_reqs
582 .into_iter()
583 .map(|req| self.normalize_sort_requirement(req))
584 .filter(|req| self.is_expr_constant(&req.expr).is_none())
585 }
586
587 fn project_expr_indirect(
590 aug_mapping: &AugmentedMapping,
591 expr: &Arc<dyn PhysicalExpr>,
592 ) -> Option<Arc<dyn PhysicalExpr>> {
593 for (targets, eq_class) in aug_mapping.values() {
596 if eq_class.as_ref().is_some_and(|cls| cls.contains(expr)) {
601 let (target, _) = targets.first();
602 return Some(Arc::clone(target));
603 }
604 }
605 let children = expr.children();
607 if children.is_empty() {
608 return None;
610 }
611 children
612 .into_iter()
613 .map(|child| {
614 if let Some((targets, _)) = aug_mapping.get(child) {
617 let (target, _) = targets.first();
619 Some(Arc::clone(target))
620 } else {
621 Self::project_expr_indirect(aug_mapping, child)
622 }
623 })
624 .collect::<Option<Vec<_>>>()
625 .map(|children| Arc::clone(expr).with_new_children(children).unwrap())
626 }
627
628 fn augment_projection_mapping<'a>(
629 &'a self,
630 mapping: &'a ProjectionMapping,
631 ) -> AugmentedMapping<'a> {
632 mapping
633 .iter()
634 .map(|(k, v)| {
635 let eq_class = self.get_equivalence_class(k);
636 (k, (v, eq_class))
637 })
638 .collect()
639 }
640
641 pub fn project_expr(
644 &self,
645 mapping: &ProjectionMapping,
646 expr: &Arc<dyn PhysicalExpr>,
647 ) -> Option<Arc<dyn PhysicalExpr>> {
648 if let Some(targets) = mapping.get(expr) {
649 let (target, _) = targets.first();
651 Some(Arc::clone(target))
652 } else {
653 let aug_mapping = self.augment_projection_mapping(mapping);
654 Self::project_expr_indirect(&aug_mapping, expr)
655 }
656 }
657
658 pub fn project_expressions<'a>(
663 &'a self,
664 mapping: &'a ProjectionMapping,
665 expressions: impl IntoIterator<Item = &'a Arc<dyn PhysicalExpr>> + 'a,
666 ) -> impl Iterator<Item = Option<Arc<dyn PhysicalExpr>>> + 'a {
667 let mut aug_mapping = None;
668 expressions.into_iter().map(move |expr| {
669 if let Some(targets) = mapping.get(expr) {
670 let (target, _) = targets.first();
672 Some(Arc::clone(target))
673 } else {
674 let aug_mapping = aug_mapping
675 .get_or_insert_with(|| self.augment_projection_mapping(mapping));
676 Self::project_expr_indirect(aug_mapping, expr)
677 }
678 })
679 }
680
681 pub fn project(&self, mapping: &ProjectionMapping) -> Self {
683 let projected_classes = self.iter().map(|cls| {
684 let new_exprs = self.project_expressions(mapping, cls.iter());
685 EquivalenceClass::new(new_exprs.flatten())
686 });
687
688 let mut new_constants = vec![];
691 let mut new_classes = IndexMap::<_, EquivalenceClass>::new();
692 for (source, targets) in mapping.iter() {
693 let normalized_expr = self.normalize_expr(Arc::clone(source));
699 let cls = new_classes.entry(normalized_expr).or_default();
700 for (target, _) in targets.iter() {
701 cls.push(Arc::clone(target));
702 }
703 if let Some(across) = self.is_expr_constant(source) {
705 for (target, _) in targets.iter() {
706 let const_expr = ConstExpr::new(Arc::clone(target), across.clone());
707 new_constants.push(const_expr);
708 }
709 }
710 }
711
712 let classes = projected_classes
714 .chain(new_classes.into_values())
715 .filter(|cls| !cls.is_trivial());
716 let mut result = Self::new(classes);
717 for constant in new_constants {
719 result.add_constant(constant);
720 }
721 result
722 }
723
724 pub fn is_expr_constant(
729 &self,
730 expr: &Arc<dyn PhysicalExpr>,
731 ) -> Option<AcrossPartitions> {
732 if let Some(lit) = expr.as_any().downcast_ref::<Literal>() {
733 return Some(AcrossPartitions::Uniform(Some(lit.value().clone())));
734 }
735 if let Some(cls) = self.get_equivalence_class(expr) {
736 if cls.constant.is_some() {
737 return cls.constant.clone();
738 }
739 }
740 let children = expr.children();
744 if children.is_empty() {
745 return None;
746 }
747 for child in children {
748 self.is_expr_constant(child)?;
749 }
750 Some(AcrossPartitions::Heterogeneous)
751 }
752
753 pub fn get_equivalence_class(
756 &self,
757 expr: &Arc<dyn PhysicalExpr>,
758 ) -> Option<&EquivalenceClass> {
759 self.map.get(expr).map(|idx| &self.classes[*idx])
760 }
761
762 pub fn join(
764 &self,
765 right_equivalences: &Self,
766 join_type: &JoinType,
767 left_size: usize,
768 on: &[(PhysicalExprRef, PhysicalExprRef)],
769 ) -> Result<Self> {
770 let group = match join_type {
771 JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
772 let mut result = Self::new(
773 self.iter().cloned().chain(
774 right_equivalences
775 .iter()
776 .map(|cls| cls.try_with_offset(left_size as _))
777 .collect::<Result<Vec<_>>>()?,
778 ),
779 );
780 if join_type == &JoinType::Inner {
783 for (lhs, rhs) in on.iter() {
784 let new_lhs = Arc::clone(lhs);
785 let new_rhs =
787 add_offset_to_expr(Arc::clone(rhs), left_size as _)?;
788 result.add_equal_conditions(new_lhs, new_rhs);
789 }
790 }
791 result
792 }
793 JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => self.clone(),
794 JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => {
795 right_equivalences.clone()
796 }
797 };
798 Ok(group)
799 }
800
801 pub fn exprs_equal(
806 &self,
807 left: &Arc<dyn PhysicalExpr>,
808 right: &Arc<dyn PhysicalExpr>,
809 ) -> bool {
810 if left.eq(right) {
812 return true;
813 }
814
815 if let Some(left_class) = self.get_equivalence_class(left) {
818 if left_class.contains(right) {
819 return true;
820 }
821 }
822 if let Some(right_class) = self.get_equivalence_class(right) {
823 if right_class.contains(left) {
824 return true;
825 }
826 }
827
828 let left_children = left.children();
830 let right_children = right.children();
831
832 if left_children.is_empty() || right_children.is_empty() {
835 return false;
836 }
837
838 if left.as_any().type_id() != right.as_any().type_id() {
840 return false;
841 }
842
843 if left_children.len() != right_children.len() {
845 return false;
846 }
847
848 left_children
850 .into_iter()
851 .zip(right_children)
852 .all(|(left_child, right_child)| self.exprs_equal(left_child, right_child))
853 }
854}
855
856impl Deref for EquivalenceGroup {
857 type Target = [EquivalenceClass];
858
859 fn deref(&self) -> &Self::Target {
860 &self.classes
861 }
862}
863
864impl IntoIterator for EquivalenceGroup {
865 type Item = EquivalenceClass;
866 type IntoIter = IntoIter<Self::Item>;
867
868 fn into_iter(self) -> Self::IntoIter {
869 self.classes.into_iter()
870 }
871}
872
873impl Display for EquivalenceGroup {
874 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
875 write!(f, "[")?;
876 let mut iter = self.iter();
877 if let Some(cls) = iter.next() {
878 write!(f, "{cls}")?;
879 }
880 for cls in iter {
881 write!(f, ", {cls}")?;
882 }
883 write!(f, "]")
884 }
885}
886
887impl From<Vec<EquivalenceClass>> for EquivalenceGroup {
888 fn from(classes: Vec<EquivalenceClass>) -> Self {
889 let mut result = Self {
890 map: classes
891 .iter()
892 .enumerate()
893 .flat_map(|(idx, cls)| {
894 cls.iter().map(move |expr| (Arc::clone(expr), idx))
895 })
896 .collect(),
897 classes,
898 };
899 result.remove_redundant_entries();
900 result
901 }
902}
903
904#[cfg(test)]
905mod tests {
906 use super::*;
907 use crate::equivalence::tests::create_test_params;
908 use crate::expressions::{binary, col, lit, BinaryExpr, Column, Literal};
909 use arrow::datatypes::{DataType, Field, Schema};
910
911 use datafusion_common::{Result, ScalarValue};
912 use datafusion_expr::Operator;
913
914 #[test]
915 fn test_bridge_groups() -> Result<()> {
916 let test_cases = vec![
918 (
920 vec![vec![1, 2, 3], vec![2, 4, 5], vec![11, 12, 9], vec![7, 6, 5]],
921 vec![vec![1, 2, 3, 4, 5, 6, 7], vec![9, 11, 12]],
923 ),
924 (
926 vec![vec![1, 2, 3], vec![3, 4, 5], vec![9, 8, 7], vec![7, 6, 5]],
927 vec![vec![1, 2, 3, 4, 5, 6, 7, 8, 9]],
929 ),
930 ];
931 for (entries, expected) in test_cases {
932 let entries = entries
933 .into_iter()
934 .map(|entry| {
935 entry.into_iter().map(|idx| {
936 let c = Column::new(format!("col_{idx}").as_str(), idx);
937 Arc::new(c) as _
938 })
939 })
940 .map(EquivalenceClass::new)
941 .collect::<Vec<_>>();
942 let expected = expected
943 .into_iter()
944 .map(|entry| {
945 entry.into_iter().map(|idx| {
946 let c = Column::new(format!("col_{idx}").as_str(), idx);
947 Arc::new(c) as _
948 })
949 })
950 .map(EquivalenceClass::new)
951 .collect::<Vec<_>>();
952 let eq_groups: EquivalenceGroup = entries.clone().into();
953 let eq_groups = eq_groups.classes;
954 let err_msg = format!(
955 "error in test entries: {entries:?}, expected: {expected:?}, actual:{eq_groups:?}"
956 );
957 assert_eq!(eq_groups.len(), expected.len(), "{err_msg}");
958 for idx in 0..eq_groups.len() {
959 assert_eq!(&eq_groups[idx], &expected[idx], "{err_msg}");
960 }
961 }
962 Ok(())
963 }
964
965 #[test]
966 fn test_remove_redundant_entries_eq_group() -> Result<()> {
967 let c = |idx| Arc::new(Column::new(format!("col_{idx}").as_str(), idx)) as _;
968 let entries = [
969 EquivalenceClass::new([c(1), c(1), lit(20)]),
970 EquivalenceClass::new([lit(30), lit(30)]),
971 EquivalenceClass::new([c(2), c(3), c(4)]),
972 ];
973 let expected = [
976 EquivalenceClass::new([c(1), lit(20)]),
977 EquivalenceClass::new([lit(30)]),
978 EquivalenceClass::new([c(2), c(3), c(4)]),
979 ];
980 let eq_groups = EquivalenceGroup::new(entries);
981 assert_eq!(eq_groups.classes, expected);
982 Ok(())
983 }
984
985 #[test]
986 fn test_schema_normalize_expr_with_equivalence() -> Result<()> {
987 let col_a = Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>;
988 let col_b = Arc::new(Column::new("b", 1)) as _;
989 let col_c = Arc::new(Column::new("c", 2)) as _;
990 let (_, eq_properties) = create_test_params()?;
992 let expressions = vec![
995 (Arc::clone(&col_a), Arc::clone(&col_a)),
999 (col_c, col_a),
1000 (Arc::clone(&col_b), Arc::clone(&col_b)),
1002 ];
1003 let eq_group = eq_properties.eq_group();
1004 for (expr, expected_eq) in expressions {
1005 assert!(expected_eq.eq(&eq_group.normalize_expr(expr)));
1006 }
1007
1008 Ok(())
1009 }
1010
1011 #[test]
1012 fn test_contains_any() {
1013 let lit_true = Arc::new(Literal::new(ScalarValue::from(true))) as _;
1014 let lit_false = Arc::new(Literal::new(ScalarValue::from(false))) as _;
1015 let col_a_expr = Arc::new(Column::new("a", 0)) as _;
1016 let col_b_expr = Arc::new(Column::new("b", 1)) as _;
1017 let col_c_expr = Arc::new(Column::new("c", 2)) as _;
1018
1019 let cls1 = EquivalenceClass::new([Arc::clone(&lit_true), col_a_expr]);
1020 let cls2 = EquivalenceClass::new([lit_true, col_b_expr]);
1021 let cls3 = EquivalenceClass::new([col_c_expr, lit_false]);
1022
1023 assert!(cls1.contains_any(&cls2));
1025 assert!(!cls1.contains_any(&cls3));
1027 assert!(!cls2.contains_any(&cls3));
1028 }
1029
1030 #[test]
1031 fn test_exprs_equal() -> Result<()> {
1032 struct TestCase {
1033 left: Arc<dyn PhysicalExpr>,
1034 right: Arc<dyn PhysicalExpr>,
1035 expected: bool,
1036 description: &'static str,
1037 }
1038
1039 let col_a = Arc::new(Column::new("a", 0)) as _;
1041 let col_b = Arc::new(Column::new("b", 1)) as _;
1042 let col_x = Arc::new(Column::new("x", 2)) as _;
1043 let col_y = Arc::new(Column::new("y", 3)) as _;
1044
1045 let lit_1 = Arc::new(Literal::new(ScalarValue::from(1))) as _;
1047 let lit_2 = Arc::new(Literal::new(ScalarValue::from(2))) as _;
1048
1049 let eq_group = EquivalenceGroup::new([
1051 EquivalenceClass::new([Arc::clone(&col_a), Arc::clone(&col_x)]),
1052 EquivalenceClass::new([Arc::clone(&col_b), Arc::clone(&col_y)]),
1053 ]);
1054
1055 let test_cases = vec![
1056 TestCase {
1058 left: Arc::clone(&col_a),
1059 right: Arc::clone(&col_a),
1060 expected: true,
1061 description: "Same column should be equal",
1062 },
1063 TestCase {
1065 left: Arc::clone(&col_a),
1066 right: Arc::clone(&col_x),
1067 expected: true,
1068 description: "Columns in same equivalence class should be equal",
1069 },
1070 TestCase {
1071 left: Arc::clone(&col_b),
1072 right: Arc::clone(&col_y),
1073 expected: true,
1074 description: "Columns in same equivalence class should be equal",
1075 },
1076 TestCase {
1077 left: Arc::clone(&col_a),
1078 right: Arc::clone(&col_b),
1079 expected: false,
1080 description:
1081 "Columns in different equivalence classes should not be equal",
1082 },
1083 TestCase {
1085 left: Arc::clone(&lit_1),
1086 right: Arc::clone(&lit_1),
1087 expected: true,
1088 description: "Same literal should be equal",
1089 },
1090 TestCase {
1091 left: Arc::clone(&lit_1),
1092 right: Arc::clone(&lit_2),
1093 expected: false,
1094 description: "Different literals should not be equal",
1095 },
1096 TestCase {
1098 left: Arc::new(BinaryExpr::new(
1099 Arc::clone(&col_a),
1100 Operator::Plus,
1101 Arc::clone(&col_b),
1102 )) as _,
1103 right: Arc::new(BinaryExpr::new(
1104 Arc::clone(&col_x),
1105 Operator::Plus,
1106 Arc::clone(&col_y),
1107 )) as _,
1108 expected: true,
1109 description:
1110 "Binary expressions with equivalent operands should be equal",
1111 },
1112 TestCase {
1113 left: Arc::new(BinaryExpr::new(
1114 Arc::clone(&col_a),
1115 Operator::Plus,
1116 Arc::clone(&col_b),
1117 )) as _,
1118 right: Arc::new(BinaryExpr::new(
1119 Arc::clone(&col_x),
1120 Operator::Plus,
1121 Arc::clone(&col_a),
1122 )) as _,
1123 expected: false,
1124 description:
1125 "Binary expressions with non-equivalent operands should not be equal",
1126 },
1127 TestCase {
1128 left: Arc::new(BinaryExpr::new(
1129 Arc::clone(&col_a),
1130 Operator::Plus,
1131 Arc::clone(&lit_1),
1132 )) as _,
1133 right: Arc::new(BinaryExpr::new(
1134 Arc::clone(&col_x),
1135 Operator::Plus,
1136 Arc::clone(&lit_1),
1137 )) as _,
1138 expected: true,
1139 description: "Binary expressions with equivalent column and same literal should be equal",
1140 },
1141 TestCase {
1142 left: Arc::new(BinaryExpr::new(
1143 Arc::new(BinaryExpr::new(
1144 Arc::clone(&col_a),
1145 Operator::Plus,
1146 Arc::clone(&col_b),
1147 )),
1148 Operator::Multiply,
1149 Arc::clone(&lit_1),
1150 )) as _,
1151 right: Arc::new(BinaryExpr::new(
1152 Arc::new(BinaryExpr::new(
1153 Arc::clone(&col_x),
1154 Operator::Plus,
1155 Arc::clone(&col_y),
1156 )),
1157 Operator::Multiply,
1158 Arc::clone(&lit_1),
1159 )) as _,
1160 expected: true,
1161 description: "Nested binary expressions with equivalent operands should be equal",
1162 },
1163 ];
1164
1165 for TestCase {
1166 left,
1167 right,
1168 expected,
1169 description,
1170 } in test_cases
1171 {
1172 let actual = eq_group.exprs_equal(&left, &right);
1173 assert_eq!(
1174 actual, expected,
1175 "{description}: Failed comparing {left:?} and {right:?}, expected {expected}, got {actual}"
1176 );
1177 }
1178
1179 Ok(())
1180 }
1181
1182 #[test]
1183 fn test_project_classes() -> Result<()> {
1184 let schema = Arc::new(Schema::new(vec![
1189 Field::new("a", DataType::Int32, false),
1190 Field::new("b", DataType::Int32, false),
1191 Field::new("c", DataType::Int32, false),
1192 ]));
1193 let mut group = EquivalenceGroup::default();
1194 group.add_equal_conditions(col("a", &schema)?, col("b", &schema)?);
1195
1196 let projected_schema = Arc::new(Schema::new(vec![
1197 Field::new("a+c", DataType::Int32, false),
1198 Field::new("b+c", DataType::Int32, false),
1199 ]));
1200
1201 let mapping = [
1202 (
1203 binary(
1204 col("a", &schema)?,
1205 Operator::Plus,
1206 col("c", &schema)?,
1207 &schema,
1208 )?,
1209 vec![(col("a+c", &projected_schema)?, 0)].into(),
1210 ),
1211 (
1212 binary(
1213 col("b", &schema)?,
1214 Operator::Plus,
1215 col("c", &schema)?,
1216 &schema,
1217 )?,
1218 vec![(col("b+c", &projected_schema)?, 1)].into(),
1219 ),
1220 ]
1221 .into_iter()
1222 .collect::<ProjectionMapping>();
1223
1224 let projected = group.project(&mapping);
1225
1226 assert!(!projected.is_empty());
1227 let first_normalized = projected.normalize_expr(col("a+c", &projected_schema)?);
1228 let second_normalized = projected.normalize_expr(col("b+c", &projected_schema)?);
1229
1230 assert!(first_normalized.eq(&second_normalized));
1231
1232 Ok(())
1233 }
1234}