datafusion_physical_expr/equivalence/
class.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use super::{add_offset_to_expr, ProjectionMapping};
19use crate::{
20    expressions::Column, LexOrdering, LexRequirement, PhysicalExpr, PhysicalExprRef,
21    PhysicalSortExpr, PhysicalSortRequirement,
22};
23use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
24use datafusion_common::{JoinType, ScalarValue};
25use datafusion_physical_expr_common::physical_expr::format_physical_expr_list;
26use std::fmt::Display;
27use std::sync::Arc;
28use std::vec::IntoIter;
29
30use indexmap::{IndexMap, IndexSet};
31
/// A structure representing an expression known to be constant in a physical execution plan.
///
/// The `ConstExpr` struct encapsulates an expression that is constant during the execution
/// of a query. For example, if a predicate like `A = 5` is applied earlier in the plan, `A`
/// would be known to be constant.
///
/// # Fields
///
/// - `expr`: Constant expression for a node in the physical plan.
///
/// - `across_partitions`: An [`AcrossPartitions`] value indicating whether the
///   constant expression is the same across partitions. If it is `Uniform`, the
///   constant expression has the same value for all partitions. If it is
///   `Heterogeneous`, the constant expression may have different values for
///   different partitions.
///
/// # Example
///
/// ```rust
/// # use datafusion_physical_expr::ConstExpr;
/// # use datafusion_physical_expr::expressions::lit;
/// let col = lit(5);
/// // Create a constant expression from a physical expression ref
/// let const_expr = ConstExpr::from(&col);
/// // create a constant expression from a physical expression
/// let const_expr = ConstExpr::from(col);
/// ```
// TODO: Consider refactoring the `across_partitions` and `value` fields into an enum:
//
// ```
// enum PartitionValues {
//     Uniform(Option<ScalarValue>),           // Same value across all partitions
//     Heterogeneous(Vec<Option<ScalarValue>>) // Different values per partition
// }
// ```
//
// This would provide more flexible representation of partition values.
// Note: This is a breaking change for the equivalence API and should be
// addressed in a separate issue/PR.
#[derive(Debug, Clone)]
pub struct ConstExpr {
    /// The expression that is known to be constant (e.g. a `Column`)
    expr: Arc<dyn PhysicalExpr>,
    /// Does the constant have the same value across all partitions? See
    /// struct docs for more details
    across_partitions: AcrossPartitions,
}
78
#[derive(PartialEq, Clone, Debug)]
/// Represents whether a constant expression's value is uniform or varies across partitions.
///
/// The `AcrossPartitions` enum is used to describe the nature of a constant expression
/// in a physical execution plan:
///
/// - `Heterogeneous`: The constant expression may have different values for different partitions.
/// - `Uniform(Option<ScalarValue>)`: The constant expression has the same value across all partitions;
///   the payload is `None` if that value is not specified.
pub enum AcrossPartitions {
    /// The value may differ from one partition to another.
    Heterogeneous,
    /// The value is the same in every partition; the payload carries that
    /// value when it is known and `None` otherwise.
    Uniform(Option<ScalarValue>),
}
92
93impl Default for AcrossPartitions {
94    fn default() -> Self {
95        Self::Heterogeneous
96    }
97}
98
99impl PartialEq for ConstExpr {
100    fn eq(&self, other: &Self) -> bool {
101        self.across_partitions == other.across_partitions && self.expr.eq(&other.expr)
102    }
103}
104
105impl ConstExpr {
106    /// Create a new constant expression from a physical expression.
107    ///
108    /// Note you can also use `ConstExpr::from` to create a constant expression
109    /// from a reference as well
110    pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
111        Self {
112            expr,
113            // By default, assume constant expressions are not same across partitions.
114            across_partitions: Default::default(),
115        }
116    }
117
118    /// Set the `across_partitions` flag
119    ///
120    /// See struct docs for more details
121    pub fn with_across_partitions(mut self, across_partitions: AcrossPartitions) -> Self {
122        self.across_partitions = across_partitions;
123        self
124    }
125
126    /// Is the  expression the same across all partitions?
127    ///
128    /// See struct docs for more details
129    pub fn across_partitions(&self) -> AcrossPartitions {
130        self.across_partitions.clone()
131    }
132
133    pub fn expr(&self) -> &Arc<dyn PhysicalExpr> {
134        &self.expr
135    }
136
137    pub fn owned_expr(self) -> Arc<dyn PhysicalExpr> {
138        self.expr
139    }
140
141    pub fn map<F>(&self, f: F) -> Option<Self>
142    where
143        F: Fn(&Arc<dyn PhysicalExpr>) -> Option<Arc<dyn PhysicalExpr>>,
144    {
145        let maybe_expr = f(&self.expr);
146        maybe_expr.map(|expr| Self {
147            expr,
148            across_partitions: self.across_partitions.clone(),
149        })
150    }
151
152    /// Returns true if this constant expression is equal to the given expression
153    pub fn eq_expr(&self, other: impl AsRef<dyn PhysicalExpr>) -> bool {
154        self.expr.as_ref() == other.as_ref()
155    }
156
157    /// Returns a [`Display`]able list of `ConstExpr`.
158    pub fn format_list(input: &[ConstExpr]) -> impl Display + '_ {
159        struct DisplayableList<'a>(&'a [ConstExpr]);
160        impl Display for DisplayableList<'_> {
161            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
162                let mut first = true;
163                for const_expr in self.0 {
164                    if first {
165                        first = false;
166                    } else {
167                        write!(f, ",")?;
168                    }
169                    write!(f, "{const_expr}")?;
170                }
171                Ok(())
172            }
173        }
174        DisplayableList(input)
175    }
176}
177
178impl Display for ConstExpr {
179    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
180        write!(f, "{}", self.expr)?;
181        match &self.across_partitions {
182            AcrossPartitions::Heterogeneous => {
183                write!(f, "(heterogeneous)")?;
184            }
185            AcrossPartitions::Uniform(value) => {
186                if let Some(val) = value {
187                    write!(f, "(uniform: {val})")?;
188                } else {
189                    write!(f, "(uniform: unknown)")?;
190                }
191            }
192        }
193        Ok(())
194    }
195}
196
197impl From<Arc<dyn PhysicalExpr>> for ConstExpr {
198    fn from(expr: Arc<dyn PhysicalExpr>) -> Self {
199        Self::new(expr)
200    }
201}
202
203impl From<&Arc<dyn PhysicalExpr>> for ConstExpr {
204    fn from(expr: &Arc<dyn PhysicalExpr>) -> Self {
205        Self::new(Arc::clone(expr))
206    }
207}
208
209/// Checks whether `expr` is among in the `const_exprs`.
210pub fn const_exprs_contains(
211    const_exprs: &[ConstExpr],
212    expr: &Arc<dyn PhysicalExpr>,
213) -> bool {
214    const_exprs
215        .iter()
216        .any(|const_expr| const_expr.expr.eq(expr))
217}
218
/// An `EquivalenceClass` is a set of [`Arc<dyn PhysicalExpr>`]s that are known
/// to have the same value for all tuples in a relation. These are generated by
/// equality predicates (e.g. `a = b`), typically equi-join conditions and
/// equality conditions in filters.
///
/// Two `EquivalenceClass`es are equal if they contain the same expressions,
/// regardless of ordering.
#[derive(Debug, Clone)]
pub struct EquivalenceClass {
    /// The expressions in this equivalence class. The order doesn't
    /// matter for equivalence purposes
    ///
    exprs: IndexSet<Arc<dyn PhysicalExpr>>,
}
233
234impl PartialEq for EquivalenceClass {
235    /// Returns true if other is equal in the sense
236    /// of bags (multi-sets), disregarding their orderings.
237    fn eq(&self, other: &Self) -> bool {
238        self.exprs.eq(&other.exprs)
239    }
240}
241
242impl EquivalenceClass {
243    /// Create a new empty equivalence class
244    pub fn new_empty() -> Self {
245        Self {
246            exprs: IndexSet::new(),
247        }
248    }
249
250    // Create a new equivalence class from a pre-existing `Vec`
251    pub fn new(exprs: Vec<Arc<dyn PhysicalExpr>>) -> Self {
252        Self {
253            exprs: exprs.into_iter().collect(),
254        }
255    }
256
257    /// Return the inner vector of expressions
258    pub fn into_vec(self) -> Vec<Arc<dyn PhysicalExpr>> {
259        self.exprs.into_iter().collect()
260    }
261
262    /// Return the "canonical" expression for this class (the first element)
263    /// if any
264    fn canonical_expr(&self) -> Option<Arc<dyn PhysicalExpr>> {
265        self.exprs.iter().next().cloned()
266    }
267
268    /// Insert the expression into this class, meaning it is known to be equal to
269    /// all other expressions in this class
270    pub fn push(&mut self, expr: Arc<dyn PhysicalExpr>) {
271        self.exprs.insert(expr);
272    }
273
274    /// Inserts all the expressions from other into this class
275    pub fn extend(&mut self, other: Self) {
276        for expr in other.exprs {
277            // use push so entries are deduplicated
278            self.push(expr);
279        }
280    }
281
282    /// Returns true if this equivalence class contains t expression
283    pub fn contains(&self, expr: &Arc<dyn PhysicalExpr>) -> bool {
284        self.exprs.contains(expr)
285    }
286
287    /// Returns true if this equivalence class has any entries in common with `other`
288    pub fn contains_any(&self, other: &Self) -> bool {
289        self.exprs.iter().any(|e| other.contains(e))
290    }
291
292    /// return the number of items in this class
293    pub fn len(&self) -> usize {
294        self.exprs.len()
295    }
296
297    /// return true if this class is empty
298    pub fn is_empty(&self) -> bool {
299        self.exprs.is_empty()
300    }
301
302    /// Iterate over all elements in this class, in some arbitrary order
303    pub fn iter(&self) -> impl Iterator<Item = &Arc<dyn PhysicalExpr>> {
304        self.exprs.iter()
305    }
306
307    /// Return a new equivalence class that have the specified offset added to
308    /// each expression (used when schemas are appended such as in joins)
309    pub fn with_offset(&self, offset: usize) -> Self {
310        let new_exprs = self
311            .exprs
312            .iter()
313            .cloned()
314            .map(|e| add_offset_to_expr(e, offset))
315            .collect();
316        Self::new(new_exprs)
317    }
318}
319
320impl Display for EquivalenceClass {
321    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
322        write!(f, "[{}]", format_physical_expr_list(&self.exprs))
323    }
324}
325
/// A collection of distinct `EquivalenceClass`es
#[derive(Debug, Clone)]
pub struct EquivalenceGroup {
    /// The equivalence classes that make up this group.
    classes: Vec<EquivalenceClass>,
}
331
332impl EquivalenceGroup {
333    /// Creates an empty equivalence group.
334    pub fn empty() -> Self {
335        Self { classes: vec![] }
336    }
337
338    /// Creates an equivalence group from the given equivalence classes.
339    pub fn new(classes: Vec<EquivalenceClass>) -> Self {
340        let mut result = Self { classes };
341        result.remove_redundant_entries();
342        result
343    }
344
345    /// Returns how many equivalence classes there are in this group.
346    pub fn len(&self) -> usize {
347        self.classes.len()
348    }
349
350    /// Checks whether this equivalence group is empty.
351    pub fn is_empty(&self) -> bool {
352        self.len() == 0
353    }
354
355    /// Returns an iterator over the equivalence classes in this group.
356    pub fn iter(&self) -> impl Iterator<Item = &EquivalenceClass> {
357        self.classes.iter()
358    }
359
360    /// Adds the equality `left` = `right` to this equivalence group.
361    /// New equality conditions often arise after steps like `Filter(a = b)`,
362    /// `Alias(a, a as b)` etc.
363    pub fn add_equal_conditions(
364        &mut self,
365        left: &Arc<dyn PhysicalExpr>,
366        right: &Arc<dyn PhysicalExpr>,
367    ) {
368        let mut first_class = None;
369        let mut second_class = None;
370        for (idx, cls) in self.classes.iter().enumerate() {
371            if cls.contains(left) {
372                first_class = Some(idx);
373            }
374            if cls.contains(right) {
375                second_class = Some(idx);
376            }
377        }
378        match (first_class, second_class) {
379            (Some(mut first_idx), Some(mut second_idx)) => {
380                // If the given left and right sides belong to different classes,
381                // we should unify/bridge these classes.
382                if first_idx != second_idx {
383                    // By convention, make sure `second_idx` is larger than `first_idx`.
384                    if first_idx > second_idx {
385                        (first_idx, second_idx) = (second_idx, first_idx);
386                    }
387                    // Remove the class at `second_idx` and merge its values with
388                    // the class at `first_idx`. The convention above makes sure
389                    // that `first_idx` is still valid after removing `second_idx`.
390                    let other_class = self.classes.swap_remove(second_idx);
391                    self.classes[first_idx].extend(other_class);
392                }
393            }
394            (Some(group_idx), None) => {
395                // Right side is new, extend left side's class:
396                self.classes[group_idx].push(Arc::clone(right));
397            }
398            (None, Some(group_idx)) => {
399                // Left side is new, extend right side's class:
400                self.classes[group_idx].push(Arc::clone(left));
401            }
402            (None, None) => {
403                // None of the expressions is among existing classes.
404                // Create a new equivalence class and extend the group.
405                self.classes.push(EquivalenceClass::new(vec![
406                    Arc::clone(left),
407                    Arc::clone(right),
408                ]));
409            }
410        }
411    }
412
413    /// Removes redundant entries from this group.
414    fn remove_redundant_entries(&mut self) {
415        // Remove duplicate entries from each equivalence class:
416        self.classes.retain_mut(|cls| {
417            // Keep groups that have at least two entries as singleton class is
418            // meaningless (i.e. it contains no non-trivial information):
419            cls.len() > 1
420        });
421        // Unify/bridge groups that have common expressions:
422        self.bridge_classes()
423    }
424
425    /// This utility function unifies/bridges classes that have common expressions.
426    /// For example, assume that we have [`EquivalenceClass`]es `[a, b]` and `[b, c]`.
427    /// Since both classes contain `b`, columns `a`, `b` and `c` are actually all
428    /// equal and belong to one class. This utility converts merges such classes.
429    fn bridge_classes(&mut self) {
430        let mut idx = 0;
431        while idx < self.classes.len() {
432            let mut next_idx = idx + 1;
433            let start_size = self.classes[idx].len();
434            while next_idx < self.classes.len() {
435                if self.classes[idx].contains_any(&self.classes[next_idx]) {
436                    let extension = self.classes.swap_remove(next_idx);
437                    self.classes[idx].extend(extension);
438                } else {
439                    next_idx += 1;
440                }
441            }
442            if self.classes[idx].len() > start_size {
443                continue;
444            }
445            idx += 1;
446        }
447    }
448
449    /// Extends this equivalence group with the `other` equivalence group.
450    pub fn extend(&mut self, other: Self) {
451        self.classes.extend(other.classes);
452        self.remove_redundant_entries();
453    }
454
455    /// Normalizes the given physical expression according to this group.
456    /// The expression is replaced with the first expression in the equivalence
457    /// class it matches with (if any).
458    pub fn normalize_expr(&self, expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
459        expr.transform(|expr| {
460            for cls in self.iter() {
461                if cls.contains(&expr) {
462                    // The unwrap below is safe because the guard above ensures
463                    // that the class is not empty.
464                    return Ok(Transformed::yes(cls.canonical_expr().unwrap()));
465                }
466            }
467            Ok(Transformed::no(expr))
468        })
469        .data()
470        .unwrap()
471        // The unwrap above is safe because the closure always returns `Ok`.
472    }
473
474    /// Normalizes the given sort expression according to this group.
475    /// The underlying physical expression is replaced with the first expression
476    /// in the equivalence class it matches with (if any). If the underlying
477    /// expression does not belong to any equivalence class in this group, returns
478    /// the sort expression as is.
479    pub fn normalize_sort_expr(
480        &self,
481        mut sort_expr: PhysicalSortExpr,
482    ) -> PhysicalSortExpr {
483        sort_expr.expr = self.normalize_expr(sort_expr.expr);
484        sort_expr
485    }
486
487    /// Normalizes the given sort requirement according to this group.
488    /// The underlying physical expression is replaced with the first expression
489    /// in the equivalence class it matches with (if any). If the underlying
490    /// expression does not belong to any equivalence class in this group, returns
491    /// the given sort requirement as is.
492    pub fn normalize_sort_requirement(
493        &self,
494        mut sort_requirement: PhysicalSortRequirement,
495    ) -> PhysicalSortRequirement {
496        sort_requirement.expr = self.normalize_expr(sort_requirement.expr);
497        sort_requirement
498    }
499
500    /// This function applies the `normalize_expr` function for all expressions
501    /// in `exprs` and returns the corresponding normalized physical expressions.
502    pub fn normalize_exprs(
503        &self,
504        exprs: impl IntoIterator<Item = Arc<dyn PhysicalExpr>>,
505    ) -> Vec<Arc<dyn PhysicalExpr>> {
506        exprs
507            .into_iter()
508            .map(|expr| self.normalize_expr(expr))
509            .collect()
510    }
511
512    /// This function applies the `normalize_sort_expr` function for all sort
513    /// expressions in `sort_exprs` and returns the corresponding normalized
514    /// sort expressions.
515    pub fn normalize_sort_exprs(&self, sort_exprs: &LexOrdering) -> LexOrdering {
516        // Convert sort expressions to sort requirements:
517        let sort_reqs = LexRequirement::from(sort_exprs.clone());
518        // Normalize the requirements:
519        let normalized_sort_reqs = self.normalize_sort_requirements(&sort_reqs);
520        // Convert sort requirements back to sort expressions:
521        LexOrdering::from(normalized_sort_reqs)
522    }
523
524    /// This function applies the `normalize_sort_requirement` function for all
525    /// requirements in `sort_reqs` and returns the corresponding normalized
526    /// sort requirements.
527    pub fn normalize_sort_requirements(
528        &self,
529        sort_reqs: &LexRequirement,
530    ) -> LexRequirement {
531        LexRequirement::new(
532            sort_reqs
533                .iter()
534                .map(|sort_req| self.normalize_sort_requirement(sort_req.clone()))
535                .collect(),
536        )
537        .collapse()
538    }
539
540    /// Projects `expr` according to the given projection mapping.
541    /// If the resulting expression is invalid after projection, returns `None`.
542    pub fn project_expr(
543        &self,
544        mapping: &ProjectionMapping,
545        expr: &Arc<dyn PhysicalExpr>,
546    ) -> Option<Arc<dyn PhysicalExpr>> {
547        // First, we try to project expressions with an exact match. If we are
548        // unable to do this, we consult equivalence classes.
549        if let Some(target) = mapping.target_expr(expr) {
550            // If we match the source, we can project directly:
551            return Some(target);
552        } else {
553            // If the given expression is not inside the mapping, try to project
554            // expressions considering the equivalence classes.
555            for (source, target) in mapping.iter() {
556                // If we match an equivalent expression to `source`, then we can
557                // project. For example, if we have the mapping `(a as a1, a + c)`
558                // and the equivalence class `(a, b)`, expression `b` projects to `a1`.
559                if self
560                    .get_equivalence_class(source)
561                    .is_some_and(|group| group.contains(expr))
562                {
563                    return Some(Arc::clone(target));
564                }
565            }
566        }
567        // Project a non-leaf expression by projecting its children.
568        let children = expr.children();
569        if children.is_empty() {
570            // Leaf expression should be inside mapping.
571            return None;
572        }
573        children
574            .into_iter()
575            .map(|child| self.project_expr(mapping, child))
576            .collect::<Option<Vec<_>>>()
577            .map(|children| Arc::clone(expr).with_new_children(children).unwrap())
578    }
579
580    /// Projects this equivalence group according to the given projection mapping.
581    pub fn project(&self, mapping: &ProjectionMapping) -> Self {
582        let projected_classes = self.iter().filter_map(|cls| {
583            let new_class = cls
584                .iter()
585                .filter_map(|expr| self.project_expr(mapping, expr))
586                .collect::<Vec<_>>();
587            (new_class.len() > 1).then_some(EquivalenceClass::new(new_class))
588        });
589
590        // The key is the source expression, and the value is the equivalence
591        // class that contains the corresponding target expression.
592        let mut new_classes: IndexMap<_, _> = IndexMap::new();
593        for (source, target) in mapping.iter() {
594            // We need to find equivalent projected expressions. For example,
595            // consider a table with columns `[a, b, c]` with `a` == `b`, and
596            // projection `[a + c, b + c]`. To conclude that `a + c == b + c`,
597            // we first normalize all source expressions in the mapping, then
598            // merge all equivalent expressions into the classes.
599            let normalized_expr = self.normalize_expr(Arc::clone(source));
600            new_classes
601                .entry(normalized_expr)
602                .or_insert_with(EquivalenceClass::new_empty)
603                .push(Arc::clone(target));
604        }
605        // Only add equivalence classes with at least two members as singleton
606        // equivalence classes are meaningless.
607        let new_classes = new_classes
608            .into_iter()
609            .filter_map(|(_, cls)| (cls.len() > 1).then_some(cls));
610
611        let classes = projected_classes.chain(new_classes).collect();
612        Self::new(classes)
613    }
614
615    /// Returns the equivalence class containing `expr`. If no equivalence class
616    /// contains `expr`, returns `None`.
617    fn get_equivalence_class(
618        &self,
619        expr: &Arc<dyn PhysicalExpr>,
620    ) -> Option<&EquivalenceClass> {
621        self.iter().find(|cls| cls.contains(expr))
622    }
623
624    /// Combine equivalence groups of the given join children.
625    pub fn join(
626        &self,
627        right_equivalences: &Self,
628        join_type: &JoinType,
629        left_size: usize,
630        on: &[(PhysicalExprRef, PhysicalExprRef)],
631    ) -> Self {
632        match join_type {
633            JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
634                let mut result = Self::new(
635                    self.iter()
636                        .cloned()
637                        .chain(
638                            right_equivalences
639                                .iter()
640                                .map(|cls| cls.with_offset(left_size)),
641                        )
642                        .collect(),
643                );
644                // In we have an inner join, expressions in the "on" condition
645                // are equal in the resulting table.
646                if join_type == &JoinType::Inner {
647                    for (lhs, rhs) in on.iter() {
648                        let new_lhs = Arc::clone(lhs);
649                        // Rewrite rhs to point to the right side of the join:
650                        let new_rhs = Arc::clone(rhs)
651                            .transform(|expr| {
652                                if let Some(column) =
653                                    expr.as_any().downcast_ref::<Column>()
654                                {
655                                    let new_column = Arc::new(Column::new(
656                                        column.name(),
657                                        column.index() + left_size,
658                                    ))
659                                        as _;
660                                    return Ok(Transformed::yes(new_column));
661                                }
662
663                                Ok(Transformed::no(expr))
664                            })
665                            .data()
666                            .unwrap();
667                        result.add_equal_conditions(&new_lhs, &new_rhs);
668                    }
669                }
670                result
671            }
672            JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => self.clone(),
673            JoinType::RightSemi | JoinType::RightAnti => right_equivalences.clone(),
674        }
675    }
676
677    /// Checks if two expressions are equal either directly or through equivalence classes.
678    /// For complex expressions (e.g. a + b), checks that the expression trees are structurally
679    /// identical and their leaf nodes are equivalent either directly or through equivalence classes.
680    pub fn exprs_equal(
681        &self,
682        left: &Arc<dyn PhysicalExpr>,
683        right: &Arc<dyn PhysicalExpr>,
684    ) -> bool {
685        // Direct equality check
686        if left.eq(right) {
687            return true;
688        }
689
690        // Check if expressions are equivalent through equivalence classes
691        // We need to check both directions since expressions might be in different classes
692        if let Some(left_class) = self.get_equivalence_class(left) {
693            if left_class.contains(right) {
694                return true;
695            }
696        }
697        if let Some(right_class) = self.get_equivalence_class(right) {
698            if right_class.contains(left) {
699                return true;
700            }
701        }
702
703        // For non-leaf nodes, check structural equality
704        let left_children = left.children();
705        let right_children = right.children();
706
707        // If either expression is a leaf node and we haven't found equality yet,
708        // they must be different
709        if left_children.is_empty() || right_children.is_empty() {
710            return false;
711        }
712
713        // Type equality check through reflection
714        if left.as_any().type_id() != right.as_any().type_id() {
715            return false;
716        }
717
718        // Check if the number of children is the same
719        if left_children.len() != right_children.len() {
720            return false;
721        }
722
723        // Check if all children are equal
724        left_children
725            .into_iter()
726            .zip(right_children)
727            .all(|(left_child, right_child)| self.exprs_equal(left_child, right_child))
728    }
729
730    /// Return the inner classes of this equivalence group.
731    pub fn into_inner(self) -> Vec<EquivalenceClass> {
732        self.classes
733    }
734}
735
736impl IntoIterator for EquivalenceGroup {
737    type Item = EquivalenceClass;
738    type IntoIter = IntoIter<EquivalenceClass>;
739
740    fn into_iter(self) -> Self::IntoIter {
741        self.classes.into_iter()
742    }
743}
744
745impl Display for EquivalenceGroup {
746    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
747        write!(f, "[")?;
748        let mut iter = self.iter();
749        if let Some(cls) = iter.next() {
750            write!(f, "{cls}")?;
751        }
752        for cls in iter {
753            write!(f, ", {cls}")?;
754        }
755        write!(f, "]")
756    }
757}
758
759#[cfg(test)]
760mod tests {
761    use super::*;
762    use crate::equivalence::tests::create_test_params;
763    use crate::expressions::{binary, col, lit, BinaryExpr, Literal};
764    use arrow::datatypes::{DataType, Field, Schema};
765
766    use datafusion_common::{Result, ScalarValue};
767    use datafusion_expr::Operator;
768
769    #[test]
770    fn test_bridge_groups() -> Result<()> {
771        // First entry in the tuple is argument, second entry is the bridged result
772        let test_cases = vec![
773            // ------- TEST CASE 1 -----------//
774            (
775                vec![vec![1, 2, 3], vec![2, 4, 5], vec![11, 12, 9], vec![7, 6, 5]],
776                // Expected is compared with set equality. Order of the specific results may change.
777                vec![vec![1, 2, 3, 4, 5, 6, 7], vec![9, 11, 12]],
778            ),
779            // ------- TEST CASE 2 -----------//
780            (
781                vec![vec![1, 2, 3], vec![3, 4, 5], vec![9, 8, 7], vec![7, 6, 5]],
782                // Expected
783                vec![vec![1, 2, 3, 4, 5, 6, 7, 8, 9]],
784            ),
785        ];
786        for (entries, expected) in test_cases {
787            let entries = entries
788                .into_iter()
789                .map(|entry| entry.into_iter().map(lit).collect::<Vec<_>>())
790                .map(EquivalenceClass::new)
791                .collect::<Vec<_>>();
792            let expected = expected
793                .into_iter()
794                .map(|entry| entry.into_iter().map(lit).collect::<Vec<_>>())
795                .map(EquivalenceClass::new)
796                .collect::<Vec<_>>();
797            let mut eq_groups = EquivalenceGroup::new(entries.clone());
798            eq_groups.bridge_classes();
799            let eq_groups = eq_groups.classes;
800            let err_msg = format!(
801                "error in test entries: {entries:?}, expected: {expected:?}, actual:{eq_groups:?}"
802            );
803            assert_eq!(eq_groups.len(), expected.len(), "{err_msg}");
804            for idx in 0..eq_groups.len() {
805                assert_eq!(&eq_groups[idx], &expected[idx], "{err_msg}");
806            }
807        }
808        Ok(())
809    }
810
811    #[test]
812    fn test_remove_redundant_entries_eq_group() -> Result<()> {
813        let entries = [
814            EquivalenceClass::new(vec![lit(1), lit(1), lit(2)]),
815            // This group is meaningless should be removed
816            EquivalenceClass::new(vec![lit(3), lit(3)]),
817            EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]),
818        ];
819        // Given equivalences classes are not in succinct form.
820        // Expected form is the most plain representation that is functionally same.
821        let expected = [
822            EquivalenceClass::new(vec![lit(1), lit(2)]),
823            EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]),
824        ];
825        let mut eq_groups = EquivalenceGroup::new(entries.to_vec());
826        eq_groups.remove_redundant_entries();
827
828        let eq_groups = eq_groups.classes;
829        assert_eq!(eq_groups.len(), expected.len());
830        assert_eq!(eq_groups.len(), 2);
831
832        assert_eq!(eq_groups[0], expected[0]);
833        assert_eq!(eq_groups[1], expected[1]);
834        Ok(())
835    }
836
837    #[test]
838    fn test_schema_normalize_expr_with_equivalence() -> Result<()> {
839        let col_a = &Column::new("a", 0);
840        let col_b = &Column::new("b", 1);
841        let col_c = &Column::new("c", 2);
842        // Assume that column a and c are aliases.
843        let (_test_schema, eq_properties) = create_test_params()?;
844
845        let col_a_expr = Arc::new(col_a.clone()) as Arc<dyn PhysicalExpr>;
846        let col_b_expr = Arc::new(col_b.clone()) as Arc<dyn PhysicalExpr>;
847        let col_c_expr = Arc::new(col_c.clone()) as Arc<dyn PhysicalExpr>;
848        // Test cases for equivalence normalization,
849        // First entry in the tuple is argument, second entry is expected result after normalization.
850        let expressions = vec![
851            // Normalized version of the column a and c should go to a
852            // (by convention all the expressions inside equivalence class are mapped to the first entry
853            // in this case a is the first entry in the equivalence class.)
854            (&col_a_expr, &col_a_expr),
855            (&col_c_expr, &col_a_expr),
856            // Cannot normalize column b
857            (&col_b_expr, &col_b_expr),
858        ];
859        let eq_group = eq_properties.eq_group();
860        for (expr, expected_eq) in expressions {
861            assert!(
862                expected_eq.eq(&eq_group.normalize_expr(Arc::clone(expr))),
863                "error in test: expr: {expr:?}"
864            );
865        }
866
867        Ok(())
868    }
869
870    #[test]
871    fn test_contains_any() {
872        let lit_true = Arc::new(Literal::new(ScalarValue::Boolean(Some(true))))
873            as Arc<dyn PhysicalExpr>;
874        let lit_false = Arc::new(Literal::new(ScalarValue::Boolean(Some(false))))
875            as Arc<dyn PhysicalExpr>;
876        let lit2 =
877            Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc<dyn PhysicalExpr>;
878        let lit1 =
879            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc<dyn PhysicalExpr>;
880        let col_b_expr = Arc::new(Column::new("b", 1)) as Arc<dyn PhysicalExpr>;
881
882        let cls1 =
883            EquivalenceClass::new(vec![Arc::clone(&lit_true), Arc::clone(&lit_false)]);
884        let cls2 =
885            EquivalenceClass::new(vec![Arc::clone(&lit_true), Arc::clone(&col_b_expr)]);
886        let cls3 = EquivalenceClass::new(vec![Arc::clone(&lit2), Arc::clone(&lit1)]);
887
888        // lit_true is common
889        assert!(cls1.contains_any(&cls2));
890        // there is no common entry
891        assert!(!cls1.contains_any(&cls3));
892        assert!(!cls2.contains_any(&cls3));
893    }
894
895    #[test]
896    fn test_exprs_equal() -> Result<()> {
897        struct TestCase {
898            left: Arc<dyn PhysicalExpr>,
899            right: Arc<dyn PhysicalExpr>,
900            expected: bool,
901            description: &'static str,
902        }
903
904        // Create test columns
905        let col_a = Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>;
906        let col_b = Arc::new(Column::new("b", 1)) as Arc<dyn PhysicalExpr>;
907        let col_x = Arc::new(Column::new("x", 2)) as Arc<dyn PhysicalExpr>;
908        let col_y = Arc::new(Column::new("y", 3)) as Arc<dyn PhysicalExpr>;
909
910        // Create test literals
911        let lit_1 =
912            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc<dyn PhysicalExpr>;
913        let lit_2 =
914            Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc<dyn PhysicalExpr>;
915
916        // Create equivalence group with classes (a = x) and (b = y)
917        let eq_group = EquivalenceGroup::new(vec![
918            EquivalenceClass::new(vec![Arc::clone(&col_a), Arc::clone(&col_x)]),
919            EquivalenceClass::new(vec![Arc::clone(&col_b), Arc::clone(&col_y)]),
920        ]);
921
922        let test_cases = vec![
923            // Basic equality tests
924            TestCase {
925                left: Arc::clone(&col_a),
926                right: Arc::clone(&col_a),
927                expected: true,
928                description: "Same column should be equal",
929            },
930            // Equivalence class tests
931            TestCase {
932                left: Arc::clone(&col_a),
933                right: Arc::clone(&col_x),
934                expected: true,
935                description: "Columns in same equivalence class should be equal",
936            },
937            TestCase {
938                left: Arc::clone(&col_b),
939                right: Arc::clone(&col_y),
940                expected: true,
941                description: "Columns in same equivalence class should be equal",
942            },
943            TestCase {
944                left: Arc::clone(&col_a),
945                right: Arc::clone(&col_b),
946                expected: false,
947                description:
948                    "Columns in different equivalence classes should not be equal",
949            },
950            // Literal tests
951            TestCase {
952                left: Arc::clone(&lit_1),
953                right: Arc::clone(&lit_1),
954                expected: true,
955                description: "Same literal should be equal",
956            },
957            TestCase {
958                left: Arc::clone(&lit_1),
959                right: Arc::clone(&lit_2),
960                expected: false,
961                description: "Different literals should not be equal",
962            },
963            // Complex expression tests
964            TestCase {
965                left: Arc::new(BinaryExpr::new(
966                    Arc::clone(&col_a),
967                    Operator::Plus,
968                    Arc::clone(&col_b),
969                )) as Arc<dyn PhysicalExpr>,
970                right: Arc::new(BinaryExpr::new(
971                    Arc::clone(&col_x),
972                    Operator::Plus,
973                    Arc::clone(&col_y),
974                )) as Arc<dyn PhysicalExpr>,
975                expected: true,
976                description:
977                    "Binary expressions with equivalent operands should be equal",
978            },
979            TestCase {
980                left: Arc::new(BinaryExpr::new(
981                    Arc::clone(&col_a),
982                    Operator::Plus,
983                    Arc::clone(&col_b),
984                )) as Arc<dyn PhysicalExpr>,
985                right: Arc::new(BinaryExpr::new(
986                    Arc::clone(&col_x),
987                    Operator::Plus,
988                    Arc::clone(&col_a),
989                )) as Arc<dyn PhysicalExpr>,
990                expected: false,
991                description:
992                    "Binary expressions with non-equivalent operands should not be equal",
993            },
994            TestCase {
995                left: Arc::new(BinaryExpr::new(
996                    Arc::clone(&col_a),
997                    Operator::Plus,
998                    Arc::clone(&lit_1),
999                )) as Arc<dyn PhysicalExpr>,
1000                right: Arc::new(BinaryExpr::new(
1001                    Arc::clone(&col_x),
1002                    Operator::Plus,
1003                    Arc::clone(&lit_1),
1004                )) as Arc<dyn PhysicalExpr>,
1005                expected: true,
1006                description: "Binary expressions with equivalent column and same literal should be equal",
1007            },
1008            TestCase {
1009                left: Arc::new(BinaryExpr::new(
1010                    Arc::new(BinaryExpr::new(
1011                        Arc::clone(&col_a),
1012                        Operator::Plus,
1013                        Arc::clone(&col_b),
1014                    )),
1015                    Operator::Multiply,
1016                    Arc::clone(&lit_1),
1017                )) as Arc<dyn PhysicalExpr>,
1018                right: Arc::new(BinaryExpr::new(
1019                    Arc::new(BinaryExpr::new(
1020                        Arc::clone(&col_x),
1021                        Operator::Plus,
1022                        Arc::clone(&col_y),
1023                    )),
1024                    Operator::Multiply,
1025                    Arc::clone(&lit_1),
1026                )) as Arc<dyn PhysicalExpr>,
1027                expected: true,
1028                description: "Nested binary expressions with equivalent operands should be equal",
1029            },
1030        ];
1031
1032        for TestCase {
1033            left,
1034            right,
1035            expected,
1036            description,
1037        } in test_cases
1038        {
1039            let actual = eq_group.exprs_equal(&left, &right);
1040            assert_eq!(
1041                actual, expected,
1042                "{description}: Failed comparing {left:?} and {right:?}, expected {expected}, got {actual}"
1043            );
1044        }
1045
1046        Ok(())
1047    }
1048
1049    #[test]
1050    fn test_project_classes() -> Result<()> {
1051        // - columns: [a, b, c].
1052        // - "a" and "b" in the same equivalence class.
1053        // - then after a+c, b+c projection col(0) and col(1) must be
1054        // in the same class too.
1055        let schema = Arc::new(Schema::new(vec![
1056            Field::new("a", DataType::Int32, false),
1057            Field::new("b", DataType::Int32, false),
1058            Field::new("c", DataType::Int32, false),
1059        ]));
1060        let mut group = EquivalenceGroup::empty();
1061        group.add_equal_conditions(&col("a", &schema)?, &col("b", &schema)?);
1062
1063        let projected_schema = Arc::new(Schema::new(vec![
1064            Field::new("a+c", DataType::Int32, false),
1065            Field::new("b+c", DataType::Int32, false),
1066        ]));
1067
1068        let mapping = ProjectionMapping {
1069            map: vec![
1070                (
1071                    binary(
1072                        col("a", &schema)?,
1073                        Operator::Plus,
1074                        col("c", &schema)?,
1075                        &schema,
1076                    )?,
1077                    col("a+c", &projected_schema)?,
1078                ),
1079                (
1080                    binary(
1081                        col("b", &schema)?,
1082                        Operator::Plus,
1083                        col("c", &schema)?,
1084                        &schema,
1085                    )?,
1086                    col("b+c", &projected_schema)?,
1087                ),
1088            ],
1089        };
1090
1091        let projected = group.project(&mapping);
1092
1093        assert!(!projected.is_empty());
1094        let first_normalized = projected.normalize_expr(col("a+c", &projected_schema)?);
1095        let second_normalized = projected.normalize_expr(col("b+c", &projected_schema)?);
1096
1097        assert!(first_normalized.eq(&second_normalized));
1098
1099        Ok(())
1100    }
1101}