1use std::any::Any;
19use std::fmt::Display;
20use std::ops::Deref;
21use std::sync::Arc;
22use std::vec::IntoIter;
23
24use super::ProjectionMapping;
25use crate::expressions::Literal;
26use crate::physical_expr::add_offset_to_expr;
27use crate::projection::ProjectionTargets;
28use crate::{PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, PhysicalSortRequirement};
29
30use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
31use datafusion_common::{JoinType, Result, ScalarValue};
32use datafusion_physical_expr_common::physical_expr::format_physical_expr_list;
33
34use indexmap::{IndexMap, IndexSet};
35
/// Describes how a constant expression behaves across the partitions of a
/// plan: whether it is only constant within each partition, or takes the same
/// value in every partition.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub enum AcrossPartitions {
    /// The value is constant per partition but may differ between partitions.
    /// This is the default variant.
    #[default]
    Heterogeneous,
    /// The value is the same in every partition. The payload is `Some(value)`
    /// when the value is known (e.g. it comes from a literal) and `None` when
    /// it is not known.
    Uniform(Option<ScalarValue>),
}
48
49impl Display for AcrossPartitions {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 match self {
52 AcrossPartitions::Heterogeneous => write!(f, "(heterogeneous)"),
53 AcrossPartitions::Uniform(value) => {
54 if let Some(val) = value {
55 write!(f, "(uniform: {val})")
56 } else {
57 write!(f, "(uniform: unknown)")
58 }
59 }
60 }
61 }
62}
63
/// A physical expression that is known to be constant, paired with a
/// description of how that constant behaves across partitions.
#[derive(Clone, Debug)]
pub struct ConstExpr {
    /// The constant expression itself.
    pub expr: Arc<dyn PhysicalExpr>,
    /// Whether the constant takes the same value in every partition (and, if
    /// so, which value when it is known).
    pub across_partitions: AcrossPartitions,
}
95impl ConstExpr {
101 pub fn new(expr: Arc<dyn PhysicalExpr>, across_partitions: AcrossPartitions) -> Self {
109 let mut result = ConstExpr::from(expr);
110 if result.across_partitions == AcrossPartitions::Heterogeneous {
113 result.across_partitions = across_partitions;
114 }
115 result
116 }
117
118 pub fn format_list(input: &[ConstExpr]) -> impl Display + '_ {
120 struct DisplayableList<'a>(&'a [ConstExpr]);
121 impl Display for DisplayableList<'_> {
122 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
123 let mut first = true;
124 for const_expr in self.0 {
125 if first {
126 first = false;
127 } else {
128 write!(f, ",")?;
129 }
130 write!(f, "{const_expr}")?;
131 }
132 Ok(())
133 }
134 }
135 DisplayableList(input)
136 }
137}
138
139impl PartialEq for ConstExpr {
140 fn eq(&self, other: &Self) -> bool {
141 self.across_partitions == other.across_partitions && self.expr.eq(&other.expr)
142 }
143}
144
145impl Display for ConstExpr {
146 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147 write!(f, "{}", self.expr)?;
148 write!(f, "{}", self.across_partitions)
149 }
150}
151
152impl From<Arc<dyn PhysicalExpr>> for ConstExpr {
153 fn from(expr: Arc<dyn PhysicalExpr>) -> Self {
154 let across = if let Some(lit) = expr.downcast_ref::<Literal>() {
158 AcrossPartitions::Uniform(Some(lit.value().clone()))
159 } else {
160 AcrossPartitions::Heterogeneous
161 };
162 Self {
163 expr,
164 across_partitions: across,
165 }
166 }
167}
168
/// A set of physical expressions that are known to be equal to each other.
///
/// Members keep their insertion order; the first member serves as the
/// class's canonical representative (see `canonical_expr`).
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct EquivalenceClass {
    /// The member expressions, in insertion order.
    pub(crate) exprs: IndexSet<Arc<dyn PhysicalExpr>>,
    /// `Some` when the class is known to be constant, describing how the
    /// constant behaves across partitions; `None` when it is not known to be
    /// constant.
    pub(crate) constant: Option<AcrossPartitions>,
}
185
186impl EquivalenceClass {
187 pub fn new(exprs: impl IntoIterator<Item = Arc<dyn PhysicalExpr>>) -> Self {
189 let mut class = Self::default();
190 for expr in exprs {
191 class.push(expr);
192 }
193 class
194 }
195
196 pub fn canonical_expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
199 self.exprs.iter().next()
200 }
201
202 pub fn push(&mut self, expr: Arc<dyn PhysicalExpr>) {
205 if let Some(lit) = expr.downcast_ref::<Literal>() {
206 let expr_across = AcrossPartitions::Uniform(Some(lit.value().clone()));
207 if let Some(across) = self.constant.as_mut() {
208 if *across == AcrossPartitions::Heterogeneous {
210 *across = expr_across;
211 }
212 } else {
213 self.constant = Some(expr_across);
214 }
215 }
216 self.exprs.insert(expr);
217 }
218
219 pub fn extend(&mut self, other: Self) {
221 self.exprs.extend(other.exprs);
222 match (&self.constant, &other.constant) {
223 (Some(across), Some(_)) => {
224 if across == &AcrossPartitions::Heterogeneous {
226 self.constant = other.constant;
227 }
228 }
229 (None, Some(_)) => self.constant = other.constant,
230 (_, None) => {}
231 }
232 }
233
234 pub fn contains_any(&self, other: &Self) -> bool {
237 self.exprs.intersection(&other.exprs).next().is_some()
238 }
239
240 pub fn is_trivial(&self) -> bool {
244 self.exprs.is_empty() || (self.exprs.len() == 1 && self.constant.is_none())
245 }
246
247 pub fn try_with_offset(&self, offset: isize) -> Result<Self> {
250 let mut cls = Self::default();
251 for expr_result in self
252 .exprs
253 .iter()
254 .cloned()
255 .map(|e| add_offset_to_expr(e, offset))
256 {
257 cls.push(expr_result?);
258 }
259 Ok(cls)
260 }
261}
262
263impl Deref for EquivalenceClass {
264 type Target = IndexSet<Arc<dyn PhysicalExpr>>;
265
266 fn deref(&self) -> &Self::Target {
267 &self.exprs
268 }
269}
270
271impl IntoIterator for EquivalenceClass {
272 type Item = Arc<dyn PhysicalExpr>;
273 type IntoIter = <IndexSet<Self::Item> as IntoIterator>::IntoIter;
274
275 fn into_iter(self) -> Self::IntoIter {
276 self.exprs.into_iter()
277 }
278}
279
280impl Display for EquivalenceClass {
281 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
282 write!(f, "{{")?;
283 write!(f, "members: {}", format_physical_expr_list(&self.exprs))?;
284 if let Some(across) = &self.constant {
285 write!(f, ", constant: {across}")?;
286 }
287 write!(f, "}}")
288 }
289}
290
291impl From<EquivalenceClass> for Vec<Arc<dyn PhysicalExpr>> {
292 fn from(cls: EquivalenceClass) -> Self {
293 cls.exprs.into_iter().collect()
294 }
295}
296
/// A projection mapping where each source expression is paired with its
/// projection targets and, when the source belongs to one, the equivalence
/// class containing it (built by `augment_projection_mapping`).
type AugmentedMapping<'a> = IndexMap<
    &'a Arc<dyn PhysicalExpr>,
    (&'a ProjectionTargets, Option<&'a EquivalenceClass>),
>;
301
/// A collection of equivalence classes, together with a lookup table from
/// member expressions to the class that holds them. Overlapping classes are
/// merged ("bridged") so that classes stay disjoint.
#[derive(Clone, Debug, Default)]
pub struct EquivalenceGroup {
    /// Maps each member expression to the index of its class in `classes`.
    map: IndexMap<Arc<dyn PhysicalExpr>, usize>,
    /// The equivalence classes of the group.
    classes: Vec<EquivalenceClass>,
}
311
312impl EquivalenceGroup {
313 pub fn new(classes: impl IntoIterator<Item = EquivalenceClass>) -> Self {
315 classes.into_iter().collect::<Vec<_>>().into()
316 }
317
318 pub fn add_constant(&mut self, const_expr: ConstExpr) {
320 if let Some(idx) = self.map.get(&const_expr.expr) {
323 let cls = &mut self.classes[*idx];
324 if let Some(across) = cls.constant.as_mut() {
325 if *across == AcrossPartitions::Heterogeneous {
327 *across = const_expr.across_partitions;
328 }
329 } else {
330 cls.constant = Some(const_expr.across_partitions);
331 }
332 return;
333 }
334 if let AcrossPartitions::Uniform(_) = &const_expr.across_partitions {
337 for (idx, cls) in self.classes.iter_mut().enumerate() {
338 if cls
339 .constant
340 .as_ref()
341 .is_some_and(|across| const_expr.across_partitions.eq(across))
342 {
343 self.map.insert(Arc::clone(&const_expr.expr), idx);
344 cls.push(const_expr.expr);
345 return;
346 }
347 }
348 }
349 let mut new_class = EquivalenceClass::new(std::iter::once(const_expr.expr));
351 if new_class.constant.is_none() {
352 new_class.constant = Some(const_expr.across_partitions);
353 }
354 Self::update_lookup_table(&mut self.map, &new_class, self.classes.len());
355 self.classes.push(new_class);
356 }
357
358 pub fn clear_per_partition_constants(&mut self) -> bool {
362 let (mut idx, mut change) = (0, false);
363 while idx < self.classes.len() {
364 let cls = &mut self.classes[idx];
365 if let Some(AcrossPartitions::Heterogeneous) = cls.constant {
366 change = true;
367 if cls.len() == 1 {
368 self.remove_class_at_idx(idx);
370 continue;
371 } else {
372 cls.constant = None;
373 }
374 }
375 idx += 1;
376 }
377 change
378 }
379
380 pub fn add_equal_conditions(
385 &mut self,
386 left: Arc<dyn PhysicalExpr>,
387 right: Arc<dyn PhysicalExpr>,
388 ) -> bool {
389 let first_class = self.map.get(&left).copied();
390 let second_class = self.map.get(&right).copied();
391 match (first_class, second_class) {
392 (Some(mut first_idx), Some(mut second_idx)) => {
393 match first_idx.cmp(&second_idx) {
396 std::cmp::Ordering::Equal => return false,
398 std::cmp::Ordering::Greater => {
400 std::mem::swap(&mut first_idx, &mut second_idx);
401 }
402 _ => {}
403 }
404 let other_class = self.remove_class_at_idx(second_idx);
408 Self::update_lookup_table(&mut self.map, &other_class, first_idx);
410 self.classes[first_idx].extend(other_class);
411 }
412 (Some(group_idx), None) => {
413 self.map.insert(Arc::clone(&right), group_idx);
415 self.classes[group_idx].push(right);
416 }
417 (None, Some(group_idx)) => {
418 self.map.insert(Arc::clone(&left), group_idx);
420 self.classes[group_idx].push(left);
421 }
422 (None, None) => {
423 let class = EquivalenceClass::new([left, right]);
426 Self::update_lookup_table(&mut self.map, &class, self.classes.len());
427 self.classes.push(class);
428 return true;
429 }
430 }
431 false
432 }
433
434 fn remove_class_at_idx(&mut self, idx: usize) -> EquivalenceClass {
436 let cls = self.classes.swap_remove(idx);
438 for expr in cls.iter() {
440 self.map.swap_remove(expr);
441 }
442 if idx < self.classes.len() {
444 Self::update_lookup_table(&mut self.map, &self.classes[idx], idx);
445 }
446 cls
447 }
448
449 fn update_lookup_table(
452 map: &mut IndexMap<Arc<dyn PhysicalExpr>, usize>,
453 cls: &EquivalenceClass,
454 idx: usize,
455 ) {
456 for expr in cls.iter() {
457 map.insert(Arc::clone(expr), idx);
458 }
459 }
460
461 fn remove_redundant_entries(&mut self) -> bool {
464 let mut change = false;
466 for idx in (0..self.classes.len()).rev() {
467 if self.classes[idx].is_trivial() {
468 self.remove_class_at_idx(idx);
469 change = true;
470 }
471 }
472 self.bridge_classes() || change
474 }
475
476 fn bridge_classes(&mut self) -> bool {
482 let (mut idx, mut change) = (0, false);
483 'scan: while idx < self.classes.len() {
484 for other_idx in (idx + 1..self.classes.len()).rev() {
485 if self.classes[idx].contains_any(&self.classes[other_idx]) {
486 let extension = self.remove_class_at_idx(other_idx);
487 Self::update_lookup_table(&mut self.map, &extension, idx);
488 self.classes[idx].extend(extension);
489 change = true;
490 continue 'scan;
491 }
492 }
493 idx += 1;
494 }
495 change
496 }
497
498 pub fn extend(&mut self, other: Self) -> bool {
502 for (idx, cls) in other.classes.iter().enumerate() {
503 Self::update_lookup_table(&mut self.map, cls, idx);
505 }
506 self.classes.extend(other.classes);
507 self.bridge_classes()
508 }
509
510 pub fn normalize_expr(&self, expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
514 expr.transform(|expr| {
515 let cls = self.get_equivalence_class(&expr);
516 let Some(canonical) = cls.and_then(|cls| cls.canonical_expr()) else {
517 return Ok(Transformed::no(expr));
518 };
519 Ok(Transformed::yes(Arc::clone(canonical)))
520 })
521 .data()
522 .unwrap()
523 }
525
526 pub fn normalize_sort_expr(
532 &self,
533 mut sort_expr: PhysicalSortExpr,
534 ) -> PhysicalSortExpr {
535 sort_expr.expr = self.normalize_expr(sort_expr.expr);
536 sort_expr
537 }
538
539 pub fn normalize_sort_exprs<'a>(
548 &'a self,
549 sort_exprs: impl IntoIterator<Item = PhysicalSortExpr> + 'a,
550 ) -> impl Iterator<Item = PhysicalSortExpr> + 'a {
551 sort_exprs
552 .into_iter()
553 .map(|sort_expr| self.normalize_sort_expr(sort_expr))
554 .filter(|sort_expr| self.is_expr_constant(&sort_expr.expr).is_none())
555 }
556
557 pub fn normalize_sort_requirement(
563 &self,
564 mut sort_requirement: PhysicalSortRequirement,
565 ) -> PhysicalSortRequirement {
566 sort_requirement.expr = self.normalize_expr(sort_requirement.expr);
567 sort_requirement
568 }
569
570 pub fn normalize_sort_requirements<'a>(
579 &'a self,
580 sort_reqs: impl IntoIterator<Item = PhysicalSortRequirement> + 'a,
581 ) -> impl Iterator<Item = PhysicalSortRequirement> + 'a {
582 sort_reqs
583 .into_iter()
584 .map(|req| self.normalize_sort_requirement(req))
585 .filter(|req| self.is_expr_constant(&req.expr).is_none())
586 }
587
588 fn project_expr_indirect(
591 aug_mapping: &AugmentedMapping,
592 expr: &Arc<dyn PhysicalExpr>,
593 ) -> Option<Arc<dyn PhysicalExpr>> {
594 if expr.downcast_ref::<Literal>().is_some() {
596 return Some(Arc::clone(expr));
597 }
598
599 for (targets, eq_class) in aug_mapping.values() {
602 if eq_class.as_ref().is_some_and(|cls| cls.contains(expr)) {
607 let (target, _) = targets.first();
608 return Some(Arc::clone(target));
609 }
610 }
611 let children = expr.children();
613 if children.is_empty() {
614 return None;
616 }
617 children
618 .into_iter()
619 .map(|child| {
620 if let Some((targets, _)) = aug_mapping.get(child) {
623 let (target, _) = targets.first();
625 Some(Arc::clone(target))
626 } else {
627 Self::project_expr_indirect(aug_mapping, child)
628 }
629 })
630 .collect::<Option<Vec<_>>>()
631 .map(|children| Arc::clone(expr).with_new_children(children).unwrap())
632 }
633
634 fn augment_projection_mapping<'a>(
635 &'a self,
636 mapping: &'a ProjectionMapping,
637 ) -> AugmentedMapping<'a> {
638 mapping
639 .iter()
640 .map(|(k, v)| {
641 let eq_class = self.get_equivalence_class(k);
642 (k, (v, eq_class))
643 })
644 .collect()
645 }
646
647 pub fn project_expr(
650 &self,
651 mapping: &ProjectionMapping,
652 expr: &Arc<dyn PhysicalExpr>,
653 ) -> Option<Arc<dyn PhysicalExpr>> {
654 if let Some(targets) = mapping.get(expr) {
655 let (target, _) = targets.first();
657 Some(Arc::clone(target))
658 } else {
659 let aug_mapping = self.augment_projection_mapping(mapping);
660 Self::project_expr_indirect(&aug_mapping, expr)
661 }
662 }
663
664 pub fn project_expressions<'a>(
669 &'a self,
670 mapping: &'a ProjectionMapping,
671 expressions: impl IntoIterator<Item = &'a Arc<dyn PhysicalExpr>> + 'a,
672 ) -> impl Iterator<Item = Option<Arc<dyn PhysicalExpr>>> + 'a {
673 let mut aug_mapping = None;
674 expressions.into_iter().map(move |expr| {
675 if let Some(targets) = mapping.get(expr) {
676 let (target, _) = targets.first();
678 Some(Arc::clone(target))
679 } else {
680 let aug_mapping = aug_mapping
681 .get_or_insert_with(|| self.augment_projection_mapping(mapping));
682 Self::project_expr_indirect(aug_mapping, expr)
683 }
684 })
685 }
686
687 pub fn project(&self, mapping: &ProjectionMapping) -> Self {
689 let projected_classes = self.iter().map(|cls| {
690 let new_exprs = self.project_expressions(mapping, cls.iter());
691 EquivalenceClass::new(new_exprs.flatten())
692 });
693
694 let mut new_constants = vec![];
697 let mut new_classes = IndexMap::<_, EquivalenceClass>::new();
698 for (source, targets) in mapping.iter() {
699 let normalized_expr = self.normalize_expr(Arc::clone(source));
705 let cls = new_classes.entry(normalized_expr).or_default();
706 for (target, _) in targets.iter() {
707 cls.push(Arc::clone(target));
708 }
709 if let Some(across) = self.is_expr_constant(source) {
711 for (target, _) in targets.iter() {
712 let const_expr = ConstExpr::new(Arc::clone(target), across.clone());
713 new_constants.push(const_expr);
714 }
715 }
716 }
717
718 let classes = projected_classes
720 .chain(new_classes.into_values())
721 .filter(|cls| !cls.is_trivial());
722 let mut result = Self::new(classes);
723 for constant in new_constants {
725 result.add_constant(constant);
726 }
727 result
728 }
729
730 pub fn is_expr_constant(
735 &self,
736 expr: &Arc<dyn PhysicalExpr>,
737 ) -> Option<AcrossPartitions> {
738 if let Some(lit) = expr.downcast_ref::<Literal>() {
739 return Some(AcrossPartitions::Uniform(Some(lit.value().clone())));
740 }
741 if let Some(cls) = self.get_equivalence_class(expr)
742 && cls.constant.is_some()
743 {
744 return cls.constant.clone();
745 }
746 let children = expr.children();
750 if children.is_empty() {
751 return None;
752 }
753 for child in children {
754 self.is_expr_constant(child)?;
755 }
756 Some(AcrossPartitions::Heterogeneous)
757 }
758
759 pub fn get_equivalence_class(
762 &self,
763 expr: &Arc<dyn PhysicalExpr>,
764 ) -> Option<&EquivalenceClass> {
765 self.map.get(expr).map(|idx| &self.classes[*idx])
766 }
767
768 pub fn join(
770 &self,
771 right_equivalences: &Self,
772 join_type: &JoinType,
773 left_size: usize,
774 on: &[(PhysicalExprRef, PhysicalExprRef)],
775 ) -> Result<Self> {
776 let group = match join_type {
777 JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
778 let mut result = Self::new(
779 self.iter().cloned().chain(
780 right_equivalences
781 .iter()
782 .map(|cls| cls.try_with_offset(left_size as _))
783 .collect::<Result<Vec<_>>>()?,
784 ),
785 );
786 if join_type == &JoinType::Inner {
789 for (lhs, rhs) in on.iter() {
790 let new_lhs = Arc::clone(lhs);
791 let new_rhs =
793 add_offset_to_expr(Arc::clone(rhs), left_size as _)?;
794 result.add_equal_conditions(new_lhs, new_rhs);
795 }
796 }
797 result
798 }
799 JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => self.clone(),
800 JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => {
801 right_equivalences.clone()
802 }
803 };
804 Ok(group)
805 }
806
807 pub fn exprs_equal(
812 &self,
813 left: &Arc<dyn PhysicalExpr>,
814 right: &Arc<dyn PhysicalExpr>,
815 ) -> bool {
816 if left.eq(right) {
818 return true;
819 }
820
821 if let Some(left_class) = self.get_equivalence_class(left)
824 && left_class.contains(right)
825 {
826 return true;
827 }
828 if let Some(right_class) = self.get_equivalence_class(right)
829 && right_class.contains(left)
830 {
831 return true;
832 }
833
834 let left_children = left.children();
836 let right_children = right.children();
837
838 if left_children.is_empty() || right_children.is_empty() {
841 return false;
842 }
843
844 if (left as &dyn Any).type_id() != (right as &dyn Any).type_id() {
846 return false;
847 }
848
849 if left_children.len() != right_children.len() {
851 return false;
852 }
853
854 left_children
856 .into_iter()
857 .zip(right_children)
858 .all(|(left_child, right_child)| self.exprs_equal(left_child, right_child))
859 }
860}
861
862impl Deref for EquivalenceGroup {
863 type Target = [EquivalenceClass];
864
865 fn deref(&self) -> &Self::Target {
866 &self.classes
867 }
868}
869
870impl IntoIterator for EquivalenceGroup {
871 type Item = EquivalenceClass;
872 type IntoIter = IntoIter<Self::Item>;
873
874 fn into_iter(self) -> Self::IntoIter {
875 self.classes.into_iter()
876 }
877}
878
879impl Display for EquivalenceGroup {
880 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
881 write!(f, "[")?;
882 let mut iter = self.iter();
883 if let Some(cls) = iter.next() {
884 write!(f, "{cls}")?;
885 }
886 for cls in iter {
887 write!(f, ", {cls}")?;
888 }
889 write!(f, "]")
890 }
891}
892
893impl From<Vec<EquivalenceClass>> for EquivalenceGroup {
894 fn from(classes: Vec<EquivalenceClass>) -> Self {
895 let mut result = Self {
896 map: classes
897 .iter()
898 .enumerate()
899 .flat_map(|(idx, cls)| {
900 cls.iter().map(move |expr| (Arc::clone(expr), idx))
901 })
902 .collect(),
903 classes,
904 };
905 result.remove_redundant_entries();
906 result
907 }
908}
909
#[cfg(test)]
mod tests {
    use super::*;
    use crate::equivalence::tests::create_test_params;
    use crate::expressions::{BinaryExpr, Column, binary, col, lit};
    use arrow::datatypes::{DataType, Field, Schema};

    use datafusion_expr::Operator;

    /// Classes that share members must be bridged into one class when a
    /// group is built; disjoint classes must stay separate.
    #[test]
    fn test_bridge_groups() -> Result<()> {
        // Each case is (input classes, expected classes after bridging);
        // the numbers are column indices.
        let test_cases = vec![
            (
                // {1,2,3}, {2,4,5} and {7,6,5} are linked via 2 and 5;
                // {11,12,9} shares nothing with them.
                vec![vec![1, 2, 3], vec![2, 4, 5], vec![11, 12, 9], vec![7, 6, 5]],
                vec![vec![1, 2, 3, 4, 5, 6, 7], vec![9, 11, 12]],
            ),
            (
                // Every class is linked to the others through a chain of
                // shared members, so everything collapses into one class.
                vec![vec![1, 2, 3], vec![3, 4, 5], vec![9, 8, 7], vec![7, 6, 5]],
                vec![vec![1, 2, 3, 4, 5, 6, 7, 8, 9]],
            ),
        ];
        for (entries, expected) in test_cases {
            let entries = entries
                .into_iter()
                .map(|entry| {
                    entry.into_iter().map(|idx| {
                        let c = Column::new(format!("col_{idx}").as_str(), idx);
                        Arc::new(c) as _
                    })
                })
                .map(EquivalenceClass::new)
                .collect::<Vec<_>>();
            let expected = expected
                .into_iter()
                .map(|entry| {
                    entry.into_iter().map(|idx| {
                        let c = Column::new(format!("col_{idx}").as_str(), idx);
                        Arc::new(c) as _
                    })
                })
                .map(EquivalenceClass::new)
                .collect::<Vec<_>>();
            let eq_groups: EquivalenceGroup = entries.clone().into();
            let eq_groups = eq_groups.classes;
            let err_msg = format!(
                "error in test entries: {entries:?}, expected: {expected:?}, actual:{eq_groups:?}"
            );
            assert_eq!(eq_groups.len(), expected.len(), "{err_msg}");
            for idx in 0..eq_groups.len() {
                assert_eq!(&eq_groups[idx], &expected[idx], "{err_msg}");
            }
        }
        Ok(())
    }

    /// Duplicate members collapse within a class, and classes made
    /// non-trivial by a literal (constant) member are kept.
    #[test]
    fn test_remove_redundant_entries_eq_group() -> Result<()> {
        let c = |idx| Arc::new(Column::new(format!("col_{idx}").as_str(), idx)) as _;
        let entries = [
            EquivalenceClass::new([c(1), c(1), lit(20)]),
            EquivalenceClass::new([lit(30), lit(30)]),
            EquivalenceClass::new([c(2), c(3), c(4)]),
        ];
        // Duplicates are dropped; the single-literal class survives because
        // the literal marks it as constant.
        let expected = [
            EquivalenceClass::new([c(1), lit(20)]),
            EquivalenceClass::new([lit(30)]),
            EquivalenceClass::new([c(2), c(3), c(4)]),
        ];
        let eq_groups = EquivalenceGroup::new(entries);
        assert_eq!(eq_groups.classes, expected);
        Ok(())
    }

    /// Normalization replaces a member with its class's canonical expression
    /// and leaves expressions outside any class unchanged.
    #[test]
    fn test_schema_normalize_expr_with_equivalence() -> Result<()> {
        let col_a = Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>;
        let col_b = Arc::new(Column::new("b", 1)) as _;
        let col_c = Arc::new(Column::new("c", 2)) as _;
        // NOTE(review): the expectations below assume `create_test_params`
        // puts `a` and `c` in one class with `a` first, and `b` in no class.
        let (_, eq_properties) = create_test_params()?;
        // Each tuple is (input expression, expected normalized expression).
        let expressions = vec![
            // `a` is the canonical (first) member of its class.
            (Arc::clone(&col_a), Arc::clone(&col_a)),
            // `c` normalizes to the canonical member `a`.
            (col_c, col_a),
            // `b` cannot be normalized.
            (Arc::clone(&col_b), Arc::clone(&col_b)),
        ];
        let eq_group = eq_properties.eq_group();
        for (expr, expected_eq) in expressions {
            assert!(expected_eq.eq(&eq_group.normalize_expr(expr)));
        }

        Ok(())
    }

    /// `contains_any` detects a shared member (here the literal `true`).
    #[test]
    fn test_contains_any() {
        let lit_true = Arc::new(Literal::new(ScalarValue::from(true))) as _;
        let lit_false = Arc::new(Literal::new(ScalarValue::from(false))) as _;
        let col_a_expr = Arc::new(Column::new("a", 0)) as _;
        let col_b_expr = Arc::new(Column::new("b", 1)) as _;
        let col_c_expr = Arc::new(Column::new("c", 2)) as _;

        let cls1 = EquivalenceClass::new([Arc::clone(&lit_true), col_a_expr]);
        let cls2 = EquivalenceClass::new([lit_true, col_b_expr]);
        let cls3 = EquivalenceClass::new([col_c_expr, lit_false]);

        // cls1 and cls2 both contain `true`.
        assert!(cls1.contains_any(&cls2));
        // cls1 and cls3 have no members in common.
        assert!(!cls1.contains_any(&cls3));
        // cls2 and cls3 have no members in common.
        assert!(!cls2.contains_any(&cls3));
    }

    /// Table-driven checks for `exprs_equal` on columns, literals and
    /// (nested) binary expressions.
    #[test]
    fn test_exprs_equal() -> Result<()> {
        struct TestCase {
            left: Arc<dyn PhysicalExpr>,
            right: Arc<dyn PhysicalExpr>,
            expected: bool,
            description: &'static str,
        }

        // Columns used in the test cases.
        let col_a = Arc::new(Column::new("a", 0)) as _;
        let col_b = Arc::new(Column::new("b", 1)) as _;
        let col_x = Arc::new(Column::new("x", 2)) as _;
        let col_y = Arc::new(Column::new("y", 3)) as _;

        // Literals used in the test cases.
        let lit_1 = Arc::new(Literal::new(ScalarValue::from(1))) as _;
        let lit_2 = Arc::new(Literal::new(ScalarValue::from(2))) as _;

        // Equivalences: a = x and b = y.
        let eq_group = EquivalenceGroup::new([
            EquivalenceClass::new([Arc::clone(&col_a), Arc::clone(&col_x)]),
            EquivalenceClass::new([Arc::clone(&col_b), Arc::clone(&col_y)]),
        ]);

        let test_cases = vec![
            // Column comparisons:
            TestCase {
                left: Arc::clone(&col_a),
                right: Arc::clone(&col_a),
                expected: true,
                description: "Same column should be equal",
            },
            TestCase {
                left: Arc::clone(&col_a),
                right: Arc::clone(&col_x),
                expected: true,
                description: "Columns in same equivalence class should be equal",
            },
            TestCase {
                left: Arc::clone(&col_b),
                right: Arc::clone(&col_y),
                expected: true,
                description: "Columns in same equivalence class should be equal",
            },
            TestCase {
                left: Arc::clone(&col_a),
                right: Arc::clone(&col_b),
                expected: false,
                description: "Columns in different equivalence classes should not be equal",
            },
            // Literal comparisons:
            TestCase {
                left: Arc::clone(&lit_1),
                right: Arc::clone(&lit_1),
                expected: true,
                description: "Same literal should be equal",
            },
            TestCase {
                left: Arc::clone(&lit_1),
                right: Arc::clone(&lit_2),
                expected: false,
                description: "Different literals should not be equal",
            },
            // Binary expression comparisons:
            TestCase {
                left: Arc::new(BinaryExpr::new(
                    Arc::clone(&col_a),
                    Operator::Plus,
                    Arc::clone(&col_b),
                )) as _,
                right: Arc::new(BinaryExpr::new(
                    Arc::clone(&col_x),
                    Operator::Plus,
                    Arc::clone(&col_y),
                )) as _,
                expected: true,
                description: "Binary expressions with equivalent operands should be equal",
            },
            TestCase {
                left: Arc::new(BinaryExpr::new(
                    Arc::clone(&col_a),
                    Operator::Plus,
                    Arc::clone(&col_b),
                )) as _,
                right: Arc::new(BinaryExpr::new(
                    Arc::clone(&col_x),
                    Operator::Plus,
                    Arc::clone(&col_a),
                )) as _,
                expected: false,
                description: "Binary expressions with non-equivalent operands should not be equal",
            },
            TestCase {
                left: Arc::new(BinaryExpr::new(
                    Arc::clone(&col_a),
                    Operator::Plus,
                    Arc::clone(&lit_1),
                )) as _,
                right: Arc::new(BinaryExpr::new(
                    Arc::clone(&col_x),
                    Operator::Plus,
                    Arc::clone(&lit_1),
                )) as _,
                expected: true,
                description: "Binary expressions with equivalent column and same literal should be equal",
            },
            TestCase {
                left: Arc::new(BinaryExpr::new(
                    Arc::new(BinaryExpr::new(
                        Arc::clone(&col_a),
                        Operator::Plus,
                        Arc::clone(&col_b),
                    )),
                    Operator::Multiply,
                    Arc::clone(&lit_1),
                )) as _,
                right: Arc::new(BinaryExpr::new(
                    Arc::new(BinaryExpr::new(
                        Arc::clone(&col_x),
                        Operator::Plus,
                        Arc::clone(&col_y),
                    )),
                    Operator::Multiply,
                    Arc::clone(&lit_1),
                )) as _,
                expected: true,
                description: "Nested binary expressions with equivalent operands should be equal",
            },
        ];

        for TestCase {
            left,
            right,
            expected,
            description,
        } in test_cases
        {
            let actual = eq_group.exprs_equal(&left, &right);
            assert_eq!(
                actual, expected,
                "{description}: Failed comparing {left:?} and {right:?}, expected {expected}, got {actual}"
            );
        }

        Ok(())
    }

    /// With a = b, projecting `a + c` and `b + c` puts both projected
    /// columns into one class, so they normalize to the same expression.
    #[test]
    fn test_project_classes() -> Result<()> {
        // Input schema (a, b, c), with the equivalence a = b.
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));
        let mut group = EquivalenceGroup::default();
        group.add_equal_conditions(col("a", &schema)?, col("b", &schema)?);

        let projected_schema = Arc::new(Schema::new(vec![
            Field::new("a+c", DataType::Int32, false),
            Field::new("b+c", DataType::Int32, false),
        ]));

        let mapping = [
            (
                binary(
                    col("a", &schema)?,
                    Operator::Plus,
                    col("c", &schema)?,
                    &schema,
                )?,
                vec![(col("a+c", &projected_schema)?, 0)].into(),
            ),
            (
                binary(
                    col("b", &schema)?,
                    Operator::Plus,
                    col("c", &schema)?,
                    &schema,
                )?,
                vec![(col("b+c", &projected_schema)?, 1)].into(),
            ),
        ]
        .into_iter()
        .collect::<ProjectionMapping>();

        let projected = group.project(&mapping);

        assert!(!projected.is_empty());
        let first_normalized = projected.normalize_expr(col("a+c", &projected_schema)?);
        let second_normalized = projected.normalize_expr(col("b+c", &projected_schema)?);

        assert!(first_normalized.eq(&second_normalized));

        Ok(())
    }
}