datafusion_physical_expr/equivalence/
class.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::fmt::Display;
19use std::ops::Deref;
20use std::sync::Arc;
21use std::vec::IntoIter;
22
23use super::ProjectionMapping;
24use crate::expressions::Literal;
25use crate::physical_expr::add_offset_to_expr;
26use crate::projection::ProjectionTargets;
27use crate::{PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, PhysicalSortRequirement};
28
29use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
30use datafusion_common::{HashMap, JoinType, Result, ScalarValue};
31use datafusion_physical_expr_common::physical_expr::format_physical_expr_list;
32
33use indexmap::{IndexMap, IndexSet};
34
35/// Represents whether a constant expression's value is uniform or varies across
36/// partitions. Has two variants:
37/// - `Heterogeneous`: The constant expression may have different values for
38///   different partitions.
39/// - `Uniform(Option<ScalarValue>)`: The constant expression has the same value
40///   across all partitions, or is `None` if the value is unknown.
41#[derive(Clone, Debug, Default, Eq, PartialEq)]
42pub enum AcrossPartitions {
43    #[default]
44    Heterogeneous,
45    Uniform(Option<ScalarValue>),
46}
47
48impl Display for AcrossPartitions {
49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50        match self {
51            AcrossPartitions::Heterogeneous => write!(f, "(heterogeneous)"),
52            AcrossPartitions::Uniform(value) => {
53                if let Some(val) = value {
54                    write!(f, "(uniform: {val})")
55                } else {
56                    write!(f, "(uniform: unknown)")
57                }
58            }
59        }
60    }
61}
62
63/// A structure representing a expression known to be constant in a physical
64/// execution plan.
65///
66/// The `ConstExpr` struct encapsulates an expression that is constant during
67/// the execution of a query. For example if a filter like `A = 5` appears
68/// earlier in the plan, `A` would become a constant in subsequent operations.
69///
70/// # Fields
71///
72/// - `expr`: Constant expression for a node in the physical plan.
73/// - `across_partitions`: A boolean flag indicating whether the constant
74///   expression is the same across partitions. If set to `true`, the constant
75///   expression has same value for all partitions. If set to `false`, the
76///   constant expression may have different values for different partitions.
77///
78/// # Example
79///
80/// ```rust
81/// # use datafusion_physical_expr::ConstExpr;
82/// # use datafusion_physical_expr::expressions::lit;
83/// let col = lit(5);
84/// // Create a constant expression from a physical expression:
85/// let const_expr = ConstExpr::from(col);
86/// ```
87#[derive(Clone, Debug)]
88pub struct ConstExpr {
89    /// The expression that is known to be constant (e.g. a `Column`).
90    pub expr: Arc<dyn PhysicalExpr>,
91    /// Indicates whether the constant have the same value across all partitions.
92    pub across_partitions: AcrossPartitions,
93}
94// TODO: The `ConstExpr` definition above can be in an inconsistent state where
95//       `expr` is a literal but `across_partitions` is not `Uniform`. Consider
96//       a refactor to ensure that `ConstExpr` is always in a consistent state
97//       (either by changing type definition, or by API constraints).
98
99impl ConstExpr {
100    /// Create a new constant expression from a physical expression, specifying
101    /// whether the constant expression is the same across partitions.
102    ///
103    /// Note that you can also use `ConstExpr::from` to create a constant
104    /// expression from just a physical expression, with the *safe* assumption
105    /// of heterogenous values across partitions unless the expression is a
106    /// literal.
107    pub fn new(expr: Arc<dyn PhysicalExpr>, across_partitions: AcrossPartitions) -> Self {
108        let mut result = ConstExpr::from(expr);
109        // Override the across partitions specification if the expression is not
110        // a literal.
111        if result.across_partitions == AcrossPartitions::Heterogeneous {
112            result.across_partitions = across_partitions;
113        }
114        result
115    }
116
117    /// Returns a [`Display`]able list of `ConstExpr`.
118    pub fn format_list(input: &[ConstExpr]) -> impl Display + '_ {
119        struct DisplayableList<'a>(&'a [ConstExpr]);
120        impl Display for DisplayableList<'_> {
121            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
122                let mut first = true;
123                for const_expr in self.0 {
124                    if first {
125                        first = false;
126                    } else {
127                        write!(f, ",")?;
128                    }
129                    write!(f, "{const_expr}")?;
130                }
131                Ok(())
132            }
133        }
134        DisplayableList(input)
135    }
136}
137
138impl PartialEq for ConstExpr {
139    fn eq(&self, other: &Self) -> bool {
140        self.across_partitions == other.across_partitions && self.expr.eq(&other.expr)
141    }
142}
143
144impl Display for ConstExpr {
145    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146        write!(f, "{}", self.expr)?;
147        write!(f, "{}", self.across_partitions)
148    }
149}
150
151impl From<Arc<dyn PhysicalExpr>> for ConstExpr {
152    fn from(expr: Arc<dyn PhysicalExpr>) -> Self {
153        // By default, assume constant expressions are not same across partitions.
154        // However, if we have a literal, it will have a single value that is the
155        // same across all partitions.
156        let across = if let Some(lit) = expr.as_any().downcast_ref::<Literal>() {
157            AcrossPartitions::Uniform(Some(lit.value().clone()))
158        } else {
159            AcrossPartitions::Heterogeneous
160        };
161        Self {
162            expr,
163            across_partitions: across,
164        }
165    }
166}
167
168/// An `EquivalenceClass` is a set of [`Arc<dyn PhysicalExpr>`]s that are known
169/// to have the same value for all tuples in a relation. These are generated by
170/// equality predicates (e.g. `a = b`), typically equi-join conditions and
171/// equality conditions in filters.
172///
173/// Two `EquivalenceClass`es are equal if they contains the same expressions in
174/// without any ordering.
175#[derive(Clone, Debug, Default, Eq, PartialEq)]
176pub struct EquivalenceClass {
177    /// The expressions in this equivalence class. The order doesn't matter for
178    /// equivalence purposes.
179    pub(crate) exprs: IndexSet<Arc<dyn PhysicalExpr>>,
180    /// Indicates whether the expressions in this equivalence class have a
181    /// constant value. A `Some` value indicates constant-ness.
182    pub(crate) constant: Option<AcrossPartitions>,
183}
184
185impl EquivalenceClass {
186    // Create a new equivalence class from a pre-existing collection.
187    pub fn new(exprs: impl IntoIterator<Item = Arc<dyn PhysicalExpr>>) -> Self {
188        let mut class = Self::default();
189        for expr in exprs {
190            class.push(expr);
191        }
192        class
193    }
194
195    /// Return the "canonical" expression for this class (the first element)
196    /// if non-empty.
197    pub fn canonical_expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
198        self.exprs.iter().next()
199    }
200
201    /// Insert the expression into this class, meaning it is known to be equal to
202    /// all other expressions in this class.
203    pub fn push(&mut self, expr: Arc<dyn PhysicalExpr>) {
204        if let Some(lit) = expr.as_any().downcast_ref::<Literal>() {
205            let expr_across = AcrossPartitions::Uniform(Some(lit.value().clone()));
206            if let Some(across) = self.constant.as_mut() {
207                // TODO: Return an error if constant values do not agree.
208                if *across == AcrossPartitions::Heterogeneous {
209                    *across = expr_across;
210                }
211            } else {
212                self.constant = Some(expr_across);
213            }
214        }
215        self.exprs.insert(expr);
216    }
217
218    /// Inserts all the expressions from other into this class.
219    pub fn extend(&mut self, other: Self) {
220        self.exprs.extend(other.exprs);
221        match (&self.constant, &other.constant) {
222            (Some(across), Some(_)) => {
223                // TODO: Return an error if constant values do not agree.
224                if across == &AcrossPartitions::Heterogeneous {
225                    self.constant = other.constant;
226                }
227            }
228            (None, Some(_)) => self.constant = other.constant,
229            (_, None) => {}
230        }
231    }
232
233    /// Returns whether this equivalence class has any entries in common with
234    /// `other`.
235    pub fn contains_any(&self, other: &Self) -> bool {
236        self.exprs.intersection(&other.exprs).next().is_some()
237    }
238
239    /// Returns whether this equivalence class is trivial, meaning that it is
240    /// either empty, or contains a single expression that is not a constant.
241    /// Such classes are not useful, and can be removed from equivalence groups.
242    pub fn is_trivial(&self) -> bool {
243        self.exprs.is_empty() || (self.exprs.len() == 1 && self.constant.is_none())
244    }
245
246    /// Adds the given offset to all columns in the expressions inside this
247    /// class. This is used when schemas are appended, e.g. in joins.
248    pub fn try_with_offset(&self, offset: isize) -> Result<Self> {
249        let mut cls = Self::default();
250        for expr_result in self
251            .exprs
252            .iter()
253            .cloned()
254            .map(|e| add_offset_to_expr(e, offset))
255        {
256            cls.push(expr_result?);
257        }
258        Ok(cls)
259    }
260}
261
262impl Deref for EquivalenceClass {
263    type Target = IndexSet<Arc<dyn PhysicalExpr>>;
264
265    fn deref(&self) -> &Self::Target {
266        &self.exprs
267    }
268}
269
270impl IntoIterator for EquivalenceClass {
271    type Item = Arc<dyn PhysicalExpr>;
272    type IntoIter = <IndexSet<Self::Item> as IntoIterator>::IntoIter;
273
274    fn into_iter(self) -> Self::IntoIter {
275        self.exprs.into_iter()
276    }
277}
278
279impl Display for EquivalenceClass {
280    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
281        write!(f, "{{")?;
282        write!(f, "members: {}", format_physical_expr_list(&self.exprs))?;
283        if let Some(across) = &self.constant {
284            write!(f, ", constant: {across}")?;
285        }
286        write!(f, "}}")
287    }
288}
289
290impl From<EquivalenceClass> for Vec<Arc<dyn PhysicalExpr>> {
291    fn from(cls: EquivalenceClass) -> Self {
292        cls.exprs.into_iter().collect()
293    }
294}
295
296type AugmentedMapping<'a> = IndexMap<
297    &'a Arc<dyn PhysicalExpr>,
298    (&'a ProjectionTargets, Option<&'a EquivalenceClass>),
299>;
300
301/// A collection of distinct `EquivalenceClass`es. This object supports fast
302/// lookups of expressions and their equivalence classes.
303#[derive(Clone, Debug, Default)]
304pub struct EquivalenceGroup {
305    /// A mapping from expressions to their equivalence class key.
306    map: HashMap<Arc<dyn PhysicalExpr>, usize>,
307    /// The equivalence classes in this group.
308    classes: Vec<EquivalenceClass>,
309}
310
311impl EquivalenceGroup {
312    /// Creates an equivalence group from the given equivalence classes.
313    pub fn new(classes: impl IntoIterator<Item = EquivalenceClass>) -> Self {
314        classes.into_iter().collect::<Vec<_>>().into()
315    }
316
317    /// Adds `expr` as a constant expression to this equivalence group.
318    pub fn add_constant(&mut self, const_expr: ConstExpr) {
319        // If the expression is already in an equivalence class, we should
320        // adjust the constant-ness of the class if necessary:
321        if let Some(idx) = self.map.get(&const_expr.expr) {
322            let cls = &mut self.classes[*idx];
323            if let Some(across) = cls.constant.as_mut() {
324                // TODO: Return an error if constant values do not agree.
325                if *across == AcrossPartitions::Heterogeneous {
326                    *across = const_expr.across_partitions;
327                }
328            } else {
329                cls.constant = Some(const_expr.across_partitions);
330            }
331            return;
332        }
333        // If the expression is not in any equivalence class, but has the same
334        // constant value with some class, add it to that class:
335        if let AcrossPartitions::Uniform(_) = &const_expr.across_partitions {
336            for (idx, cls) in self.classes.iter_mut().enumerate() {
337                if cls
338                    .constant
339                    .as_ref()
340                    .is_some_and(|across| const_expr.across_partitions.eq(across))
341                {
342                    self.map.insert(Arc::clone(&const_expr.expr), idx);
343                    cls.push(const_expr.expr);
344                    return;
345                }
346            }
347        }
348        // Otherwise, create a new class with the expression as the only member:
349        let mut new_class = EquivalenceClass::new(std::iter::once(const_expr.expr));
350        if new_class.constant.is_none() {
351            new_class.constant = Some(const_expr.across_partitions);
352        }
353        Self::update_lookup_table(&mut self.map, &new_class, self.classes.len());
354        self.classes.push(new_class);
355    }
356
357    /// Removes constant expressions that may change across partitions.
358    /// This method should be used when merging data from different partitions.
359    /// Returns whether any change was made to the equivalence group.
360    pub fn clear_per_partition_constants(&mut self) -> bool {
361        let (mut idx, mut change) = (0, false);
362        while idx < self.classes.len() {
363            let cls = &mut self.classes[idx];
364            if let Some(AcrossPartitions::Heterogeneous) = cls.constant {
365                change = true;
366                if cls.len() == 1 {
367                    // If this class becomes trivial, remove it entirely:
368                    self.remove_class_at_idx(idx);
369                    continue;
370                } else {
371                    cls.constant = None;
372                }
373            }
374            idx += 1;
375        }
376        change
377    }
378
379    /// Adds the equality `left` = `right` to this equivalence group. New
380    /// equality conditions often arise after steps like `Filter(a = b)`,
381    /// `Alias(a, a as b)` etc. Returns whether the given equality defines
382    /// a new equivalence class.
383    pub fn add_equal_conditions(
384        &mut self,
385        left: Arc<dyn PhysicalExpr>,
386        right: Arc<dyn PhysicalExpr>,
387    ) -> bool {
388        let first_class = self.map.get(&left).copied();
389        let second_class = self.map.get(&right).copied();
390        match (first_class, second_class) {
391            (Some(mut first_idx), Some(mut second_idx)) => {
392                // If the given left and right sides belong to different classes,
393                // we should unify/bridge these classes.
394                match first_idx.cmp(&second_idx) {
395                    // The equality is already known, return and signal this:
396                    std::cmp::Ordering::Equal => return false,
397                    // Swap indices to ensure `first_idx` is the lesser index.
398                    std::cmp::Ordering::Greater => {
399                        std::mem::swap(&mut first_idx, &mut second_idx);
400                    }
401                    _ => {}
402                }
403                // Remove the class at `second_idx` and merge its values with
404                // the class at `first_idx`. The convention above makes sure
405                // that `first_idx` is still valid after removing `second_idx`.
406                let other_class = self.remove_class_at_idx(second_idx);
407                // Update the lookup table for the second class:
408                Self::update_lookup_table(&mut self.map, &other_class, first_idx);
409                self.classes[first_idx].extend(other_class);
410            }
411            (Some(group_idx), None) => {
412                // Right side is new, extend left side's class:
413                self.map.insert(Arc::clone(&right), group_idx);
414                self.classes[group_idx].push(right);
415            }
416            (None, Some(group_idx)) => {
417                // Left side is new, extend right side's class:
418                self.map.insert(Arc::clone(&left), group_idx);
419                self.classes[group_idx].push(left);
420            }
421            (None, None) => {
422                // None of the expressions is among existing classes.
423                // Create a new equivalence class and extend the group.
424                let class = EquivalenceClass::new([left, right]);
425                Self::update_lookup_table(&mut self.map, &class, self.classes.len());
426                self.classes.push(class);
427                return true;
428            }
429        }
430        false
431    }
432
433    /// Removes the equivalence class at the given index from this group.
434    fn remove_class_at_idx(&mut self, idx: usize) -> EquivalenceClass {
435        // Remove the class at the given index:
436        let cls = self.classes.swap_remove(idx);
437        // Remove its entries from the lookup table:
438        for expr in cls.iter() {
439            self.map.remove(expr);
440        }
441        // Update the lookup table for the moved class:
442        if idx < self.classes.len() {
443            Self::update_lookup_table(&mut self.map, &self.classes[idx], idx);
444        }
445        cls
446    }
447
448    /// Updates the entry in lookup table for the given equivalence class with
449    /// the given index.
450    fn update_lookup_table(
451        map: &mut HashMap<Arc<dyn PhysicalExpr>, usize>,
452        cls: &EquivalenceClass,
453        idx: usize,
454    ) {
455        for expr in cls.iter() {
456            map.insert(Arc::clone(expr), idx);
457        }
458    }
459
460    /// Removes redundant entries from this group. Returns whether any change
461    /// was made to the equivalence group.
462    fn remove_redundant_entries(&mut self) -> bool {
463        // First, remove trivial equivalence classes:
464        let mut change = false;
465        for idx in (0..self.classes.len()).rev() {
466            if self.classes[idx].is_trivial() {
467                self.remove_class_at_idx(idx);
468                change = true;
469            }
470        }
471        // Then, unify/bridge groups that have common expressions:
472        self.bridge_classes() || change
473    }
474
475    /// This utility function unifies/bridges classes that have common expressions.
476    /// For example, assume that we have [`EquivalenceClass`]es `[a, b]` and `[b, c]`.
477    /// Since both classes contain `b`, columns `a`, `b` and `c` are actually all
478    /// equal and belong to one class. This utility converts merges such classes.
479    /// Returns whether any change was made to the equivalence group.
480    fn bridge_classes(&mut self) -> bool {
481        let (mut idx, mut change) = (0, false);
482        'scan: while idx < self.classes.len() {
483            for other_idx in (idx + 1..self.classes.len()).rev() {
484                if self.classes[idx].contains_any(&self.classes[other_idx]) {
485                    let extension = self.remove_class_at_idx(other_idx);
486                    Self::update_lookup_table(&mut self.map, &extension, idx);
487                    self.classes[idx].extend(extension);
488                    change = true;
489                    continue 'scan;
490                }
491            }
492            idx += 1;
493        }
494        change
495    }
496
497    /// Extends this equivalence group with the `other` equivalence group.
498    /// Returns whether any equivalence classes were unified/bridged as a
499    /// result of the extension process.
500    pub fn extend(&mut self, other: Self) -> bool {
501        for (idx, cls) in other.classes.iter().enumerate() {
502            // Update the lookup table for the new class:
503            Self::update_lookup_table(&mut self.map, cls, idx);
504        }
505        self.classes.extend(other.classes);
506        self.bridge_classes()
507    }
508
509    /// Normalizes the given physical expression according to this group. The
510    /// expression is replaced with the first (canonical) expression in the
511    /// equivalence class it matches with (if any).
512    pub fn normalize_expr(&self, expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
513        expr.transform(|expr| {
514            let cls = self.get_equivalence_class(&expr);
515            let Some(canonical) = cls.and_then(|cls| cls.canonical_expr()) else {
516                return Ok(Transformed::no(expr));
517            };
518            Ok(Transformed::yes(Arc::clone(canonical)))
519        })
520        .data()
521        .unwrap()
522        // The unwrap above is safe because the closure always returns `Ok`.
523    }
524
525    /// Normalizes the given sort expression according to this group. The
526    /// underlying physical expression is replaced with the first expression in
527    /// the equivalence class it matches with (if any). If the underlying
528    /// expression does not belong to any equivalence class in this group,
529    /// returns the sort expression as is.
530    pub fn normalize_sort_expr(
531        &self,
532        mut sort_expr: PhysicalSortExpr,
533    ) -> PhysicalSortExpr {
534        sort_expr.expr = self.normalize_expr(sort_expr.expr);
535        sort_expr
536    }
537
538    /// Normalizes the given sort expressions (i.e. `sort_exprs`) by:
539    /// - Replacing sections that belong to some equivalence class in the
540    ///   with the first entry in the matching equivalence class.
541    /// - Removing expressions that have a constant value.
542    ///
543    /// If columns `a` and `b` are known to be equal, `d` is known to be a
544    /// constant, and `sort_exprs` is `[b ASC, d DESC, c ASC, a ASC]`, this
545    /// function would return `[a ASC, c ASC, a ASC]`.
546    pub fn normalize_sort_exprs<'a>(
547        &'a self,
548        sort_exprs: impl IntoIterator<Item = PhysicalSortExpr> + 'a,
549    ) -> impl Iterator<Item = PhysicalSortExpr> + 'a {
550        sort_exprs
551            .into_iter()
552            .map(|sort_expr| self.normalize_sort_expr(sort_expr))
553            .filter(|sort_expr| self.is_expr_constant(&sort_expr.expr).is_none())
554    }
555
556    /// Normalizes the given sort requirement according to this group. The
557    /// underlying physical expression is replaced with the first expression in
558    /// the equivalence class it matches with (if any). If the underlying
559    /// expression does not belong to any equivalence class in this group,
560    /// returns the given sort requirement as is.
561    pub fn normalize_sort_requirement(
562        &self,
563        mut sort_requirement: PhysicalSortRequirement,
564    ) -> PhysicalSortRequirement {
565        sort_requirement.expr = self.normalize_expr(sort_requirement.expr);
566        sort_requirement
567    }
568
569    /// Normalizes the given sort requirements (i.e. `sort_reqs`) by:
570    /// - Replacing sections that belong to some equivalence class in the
571    ///   with the first entry in the matching equivalence class.
572    /// - Removing expressions that have a constant value.
573    ///
574    /// If columns `a` and `b` are known to be equal, `d` is known to be a
575    /// constant, and `sort_reqs` is `[b ASC, d DESC, c ASC, a ASC]`, this
576    /// function would return `[a ASC, c ASC, a ASC]`.
577    pub fn normalize_sort_requirements<'a>(
578        &'a self,
579        sort_reqs: impl IntoIterator<Item = PhysicalSortRequirement> + 'a,
580    ) -> impl Iterator<Item = PhysicalSortRequirement> + 'a {
581        sort_reqs
582            .into_iter()
583            .map(|req| self.normalize_sort_requirement(req))
584            .filter(|req| self.is_expr_constant(&req.expr).is_none())
585    }
586
587    /// Perform an indirect projection of `expr` by consulting the equivalence
588    /// classes.
589    fn project_expr_indirect(
590        aug_mapping: &AugmentedMapping,
591        expr: &Arc<dyn PhysicalExpr>,
592    ) -> Option<Arc<dyn PhysicalExpr>> {
593        // Literals don't need to be projected
594        if expr.as_any().downcast_ref::<Literal>().is_some() {
595            return Some(Arc::clone(expr));
596        }
597
598        // The given expression is not inside the mapping, so we try to project
599        // indirectly using equivalence classes.
600        for (targets, eq_class) in aug_mapping.values() {
601            // If we match an equivalent expression to a source expression in
602            // the mapping, then we can project. For example, if we have the
603            // mapping `(a as a1, a + c)` and the equivalence `a == b`,
604            // expression `b` projects to `a1`.
605            if eq_class.as_ref().is_some_and(|cls| cls.contains(expr)) {
606                let (target, _) = targets.first();
607                return Some(Arc::clone(target));
608            }
609        }
610        // Project a non-leaf expression by projecting its children.
611        let children = expr.children();
612        if children.is_empty() {
613            // A leaf expression should be inside the mapping.
614            return None;
615        }
616        children
617            .into_iter()
618            .map(|child| {
619                // First, we try to project children with an exact match. If
620                // we are unable to do this, we consult equivalence classes.
621                if let Some((targets, _)) = aug_mapping.get(child) {
622                    // If we match the source, we can project directly:
623                    let (target, _) = targets.first();
624                    Some(Arc::clone(target))
625                } else {
626                    Self::project_expr_indirect(aug_mapping, child)
627                }
628            })
629            .collect::<Option<Vec<_>>>()
630            .map(|children| Arc::clone(expr).with_new_children(children).unwrap())
631    }
632
633    fn augment_projection_mapping<'a>(
634        &'a self,
635        mapping: &'a ProjectionMapping,
636    ) -> AugmentedMapping<'a> {
637        mapping
638            .iter()
639            .map(|(k, v)| {
640                let eq_class = self.get_equivalence_class(k);
641                (k, (v, eq_class))
642            })
643            .collect()
644    }
645
646    /// Projects `expr` according to the given projection mapping.
647    /// If the resulting expression is invalid after projection, returns `None`.
648    pub fn project_expr(
649        &self,
650        mapping: &ProjectionMapping,
651        expr: &Arc<dyn PhysicalExpr>,
652    ) -> Option<Arc<dyn PhysicalExpr>> {
653        if let Some(targets) = mapping.get(expr) {
654            // If we match the source, we can project directly:
655            let (target, _) = targets.first();
656            Some(Arc::clone(target))
657        } else {
658            let aug_mapping = self.augment_projection_mapping(mapping);
659            Self::project_expr_indirect(&aug_mapping, expr)
660        }
661    }
662
663    /// Projects `expressions` according to the given projection mapping.
664    /// This function is similar to [`Self::project_expr`], but projects multiple
665    /// expressions at once more efficiently than calling `project_expr` for each
666    /// expression.
667    pub fn project_expressions<'a>(
668        &'a self,
669        mapping: &'a ProjectionMapping,
670        expressions: impl IntoIterator<Item = &'a Arc<dyn PhysicalExpr>> + 'a,
671    ) -> impl Iterator<Item = Option<Arc<dyn PhysicalExpr>>> + 'a {
672        let mut aug_mapping = None;
673        expressions.into_iter().map(move |expr| {
674            if let Some(targets) = mapping.get(expr) {
675                // If we match the source, we can project directly:
676                let (target, _) = targets.first();
677                Some(Arc::clone(target))
678            } else {
679                let aug_mapping = aug_mapping
680                    .get_or_insert_with(|| self.augment_projection_mapping(mapping));
681                Self::project_expr_indirect(aug_mapping, expr)
682            }
683        })
684    }
685
686    /// Projects this equivalence group according to the given projection mapping.
687    pub fn project(&self, mapping: &ProjectionMapping) -> Self {
688        let projected_classes = self.iter().map(|cls| {
689            let new_exprs = self.project_expressions(mapping, cls.iter());
690            EquivalenceClass::new(new_exprs.flatten())
691        });
692
693        // The key is the source expression, and the value is the equivalence
694        // class that contains the corresponding target expression.
695        let mut new_constants = vec![];
696        let mut new_classes = IndexMap::<_, EquivalenceClass>::new();
697        for (source, targets) in mapping.iter() {
698            // We need to find equivalent projected expressions. For example,
699            // consider a table with columns `[a, b, c]` with `a` == `b`, and
700            // projection `[a + c, b + c]`. To conclude that `a + c == b + c`,
701            // we first normalize all source expressions in the mapping, then
702            // merge all equivalent expressions into the classes.
703            let normalized_expr = self.normalize_expr(Arc::clone(source));
704            let cls = new_classes.entry(normalized_expr).or_default();
705            for (target, _) in targets.iter() {
706                cls.push(Arc::clone(target));
707            }
708            // Save new constants arising from the projection:
709            if let Some(across) = self.is_expr_constant(source) {
710                for (target, _) in targets.iter() {
711                    let const_expr = ConstExpr::new(Arc::clone(target), across.clone());
712                    new_constants.push(const_expr);
713                }
714            }
715        }
716
717        // Union projected classes with new classes to make up the result:
718        let classes = projected_classes
719            .chain(new_classes.into_values())
720            .filter(|cls| !cls.is_trivial());
721        let mut result = Self::new(classes);
722        // Add new constants arising from the projection to the equivalence group:
723        for constant in new_constants {
724            result.add_constant(constant);
725        }
726        result
727    }
728
729    /// Returns a `Some` value if the expression is constant according to
730    /// equivalence group, and `None` otherwise. The `Some` variant contains
731    /// an `AcrossPartitions` value indicating whether the expression is
732    /// constant across partitions, and its actual value (if available).
733    pub fn is_expr_constant(
734        &self,
735        expr: &Arc<dyn PhysicalExpr>,
736    ) -> Option<AcrossPartitions> {
737        if let Some(lit) = expr.as_any().downcast_ref::<Literal>() {
738            return Some(AcrossPartitions::Uniform(Some(lit.value().clone())));
739        }
740        if let Some(cls) = self.get_equivalence_class(expr) {
741            if cls.constant.is_some() {
742                return cls.constant.clone();
743            }
744        }
745        // TODO: This function should be able to return values of non-literal
746        //       complex constants as well; e.g. it should return `8` for the
747        //       expression `3 + 5`, not an unknown `heterogenous` value.
748        let children = expr.children();
749        if children.is_empty() {
750            return None;
751        }
752        for child in children {
753            self.is_expr_constant(child)?;
754        }
755        Some(AcrossPartitions::Heterogeneous)
756    }
757
758    /// Returns the equivalence class containing `expr`. If no equivalence class
759    /// contains `expr`, returns `None`.
760    pub fn get_equivalence_class(
761        &self,
762        expr: &Arc<dyn PhysicalExpr>,
763    ) -> Option<&EquivalenceClass> {
764        self.map.get(expr).map(|idx| &self.classes[*idx])
765    }
766
767    /// Combine equivalence groups of the given join children.
768    pub fn join(
769        &self,
770        right_equivalences: &Self,
771        join_type: &JoinType,
772        left_size: usize,
773        on: &[(PhysicalExprRef, PhysicalExprRef)],
774    ) -> Result<Self> {
775        let group = match join_type {
776            JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
777                let mut result = Self::new(
778                    self.iter().cloned().chain(
779                        right_equivalences
780                            .iter()
781                            .map(|cls| cls.try_with_offset(left_size as _))
782                            .collect::<Result<Vec<_>>>()?,
783                    ),
784                );
785                // In we have an inner join, expressions in the "on" condition
786                // are equal in the resulting table.
787                if join_type == &JoinType::Inner {
788                    for (lhs, rhs) in on.iter() {
789                        let new_lhs = Arc::clone(lhs);
790                        // Rewrite rhs to point to the right side of the join:
791                        let new_rhs =
792                            add_offset_to_expr(Arc::clone(rhs), left_size as _)?;
793                        result.add_equal_conditions(new_lhs, new_rhs);
794                    }
795                }
796                result
797            }
798            JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => self.clone(),
799            JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => {
800                right_equivalences.clone()
801            }
802        };
803        Ok(group)
804    }
805
806    /// Checks if two expressions are equal directly or through equivalence
807    /// classes. For complex expressions (e.g. `a + b`), checks that the
808    /// expression trees are structurally identical and their leaf nodes are
809    /// equivalent either directly or through equivalence classes.
810    pub fn exprs_equal(
811        &self,
812        left: &Arc<dyn PhysicalExpr>,
813        right: &Arc<dyn PhysicalExpr>,
814    ) -> bool {
815        // Direct equality check
816        if left.eq(right) {
817            return true;
818        }
819
820        // Check if expressions are equivalent through equivalence classes
821        // We need to check both directions since expressions might be in different classes
822        if let Some(left_class) = self.get_equivalence_class(left) {
823            if left_class.contains(right) {
824                return true;
825            }
826        }
827        if let Some(right_class) = self.get_equivalence_class(right) {
828            if right_class.contains(left) {
829                return true;
830            }
831        }
832
833        // For non-leaf nodes, check structural equality
834        let left_children = left.children();
835        let right_children = right.children();
836
837        // If either expression is a leaf node and we haven't found equality yet,
838        // they must be different
839        if left_children.is_empty() || right_children.is_empty() {
840            return false;
841        }
842
843        // Type equality check through reflection
844        if left.as_any().type_id() != right.as_any().type_id() {
845            return false;
846        }
847
848        // Check if the number of children is the same
849        if left_children.len() != right_children.len() {
850            return false;
851        }
852
853        // Check if all children are equal
854        left_children
855            .into_iter()
856            .zip(right_children)
857            .all(|(left_child, right_child)| self.exprs_equal(left_child, right_child))
858    }
859}
860
861impl Deref for EquivalenceGroup {
862    type Target = [EquivalenceClass];
863
864    fn deref(&self) -> &Self::Target {
865        &self.classes
866    }
867}
868
869impl IntoIterator for EquivalenceGroup {
870    type Item = EquivalenceClass;
871    type IntoIter = IntoIter<Self::Item>;
872
873    fn into_iter(self) -> Self::IntoIter {
874        self.classes.into_iter()
875    }
876}
877
878impl Display for EquivalenceGroup {
879    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
880        write!(f, "[")?;
881        let mut iter = self.iter();
882        if let Some(cls) = iter.next() {
883            write!(f, "{cls}")?;
884        }
885        for cls in iter {
886            write!(f, ", {cls}")?;
887        }
888        write!(f, "]")
889    }
890}
891
892impl From<Vec<EquivalenceClass>> for EquivalenceGroup {
893    fn from(classes: Vec<EquivalenceClass>) -> Self {
894        let mut result = Self {
895            map: classes
896                .iter()
897                .enumerate()
898                .flat_map(|(idx, cls)| {
899                    cls.iter().map(move |expr| (Arc::clone(expr), idx))
900                })
901                .collect(),
902            classes,
903        };
904        result.remove_redundant_entries();
905        result
906    }
907}
908
#[cfg(test)]
mod tests {
    use super::*;
    use crate::equivalence::tests::create_test_params;
    use crate::expressions::{binary, col, lit, BinaryExpr, Column, Literal};
    use arrow::datatypes::{DataType, Field, Schema};

    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::Operator;

    // Overlapping classes converted into an `EquivalenceGroup` (via `into()`)
    // must be bridged, i.e. merged transitively into disjoint classes.
    #[test]
    fn test_bridge_groups() -> Result<()> {
        // First entry in the tuple is argument, second entry is the bridged result
        let test_cases = vec![
            // ------- TEST CASE 1 -----------//
            (
                vec![vec![1, 2, 3], vec![2, 4, 5], vec![11, 12, 9], vec![7, 6, 5]],
                // Expected classes are compared position by position below.
                // Member order inside a class may differ from the listing
                // here (e.g. `[11, 12, 9]` vs `[9, 11, 12]`), so class
                // equality presumably ignores member order.
                vec![vec![1, 2, 3, 4, 5, 6, 7], vec![9, 11, 12]],
            ),
            // ------- TEST CASE 2 -----------//
            (
                vec![vec![1, 2, 3], vec![3, 4, 5], vec![9, 8, 7], vec![7, 6, 5]],
                // Expected
                vec![vec![1, 2, 3, 4, 5, 6, 7, 8, 9]],
            ),
        ];
        for (entries, expected) in test_cases {
            // Turn each list of indices into a class of `col_{idx}` columns.
            let entries = entries
                .into_iter()
                .map(|entry| {
                    entry.into_iter().map(|idx| {
                        let c = Column::new(format!("col_{idx}").as_str(), idx);
                        Arc::new(c) as _
                    })
                })
                .map(EquivalenceClass::new)
                .collect::<Vec<_>>();
            let expected = expected
                .into_iter()
                .map(|entry| {
                    entry.into_iter().map(|idx| {
                        let c = Column::new(format!("col_{idx}").as_str(), idx);
                        Arc::new(c) as _
                    })
                })
                .map(EquivalenceClass::new)
                .collect::<Vec<_>>();
            let eq_groups: EquivalenceGroup = entries.clone().into();
            let eq_groups = eq_groups.classes;
            let err_msg = format!(
                "error in test entries: {entries:?}, expected: {expected:?}, actual:{eq_groups:?}"
            );
            assert_eq!(eq_groups.len(), expected.len(), "{err_msg}");
            for idx in 0..eq_groups.len() {
                assert_eq!(&eq_groups[idx], &expected[idx], "{err_msg}");
            }
        }
        Ok(())
    }

    // `EquivalenceGroup::new` should collapse duplicate members inside each
    // class while keeping the classes themselves.
    #[test]
    fn test_remove_redundant_entries_eq_group() -> Result<()> {
        let c = |idx| Arc::new(Column::new(format!("col_{idx}").as_str(), idx)) as _;
        let entries = [
            EquivalenceClass::new([c(1), c(1), lit(20)]),
            EquivalenceClass::new([lit(30), lit(30)]),
            EquivalenceClass::new([c(2), c(3), c(4)]),
        ];
        // The given equivalence classes are not in succinct form.
        // The expected form is the plainest representation that is
        // functionally the same.
        let expected = [
            EquivalenceClass::new([c(1), lit(20)]),
            EquivalenceClass::new([lit(30)]),
            EquivalenceClass::new([c(2), c(3), c(4)]),
        ];
        let eq_groups = EquivalenceGroup::new(entries);
        assert_eq!(eq_groups.classes, expected);
        Ok(())
    }

    // Normalizing through the equivalence group maps every member of a class
    // to that class's representative and leaves other expressions untouched.
    #[test]
    fn test_schema_normalize_expr_with_equivalence() -> Result<()> {
        let col_a = Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>;
        let col_b = Arc::new(Column::new("b", 1)) as _;
        let col_c = Arc::new(Column::new("c", 2)) as _;
        // The test parameters are expected to make columns a and c aliases
        // (the assertions below rely on `c` normalizing to `a`).
        let (_, eq_properties) = create_test_params()?;
        // Test cases for equivalence normalization. First entry in the tuple is
        // the argument, second entry is expected result after normalization.
        let expressions = vec![
            // Normalized version of the column a and c should go to a
            // (by convention all the expressions inside equivalence class are mapped to the first entry
            // in this case a is the first entry in the equivalence class.)
            (Arc::clone(&col_a), Arc::clone(&col_a)),
            (col_c, col_a),
            // Cannot normalize column b
            (Arc::clone(&col_b), Arc::clone(&col_b)),
        ];
        let eq_group = eq_properties.eq_group();
        for (expr, expected_eq) in expressions {
            assert!(expected_eq.eq(&eq_group.normalize_expr(expr)));
        }

        Ok(())
    }

    // `contains_any` is true exactly when two classes share at least one
    // member.
    #[test]
    fn test_contains_any() {
        let lit_true = Arc::new(Literal::new(ScalarValue::from(true))) as _;
        let lit_false = Arc::new(Literal::new(ScalarValue::from(false))) as _;
        let col_a_expr = Arc::new(Column::new("a", 0)) as _;
        let col_b_expr = Arc::new(Column::new("b", 1)) as _;
        let col_c_expr = Arc::new(Column::new("c", 2)) as _;

        let cls1 = EquivalenceClass::new([Arc::clone(&lit_true), col_a_expr]);
        let cls2 = EquivalenceClass::new([lit_true, col_b_expr]);
        let cls3 = EquivalenceClass::new([col_c_expr, lit_false]);

        // lit_true is common
        assert!(cls1.contains_any(&cls2));
        // there is no common entry
        assert!(!cls1.contains_any(&cls3));
        assert!(!cls2.contains_any(&cls3));
    }

    // `exprs_equal` should treat expressions as equal when they are identical,
    // when they belong to the same class, or when their trees match with
    // equivalent leaves.
    // NOTE(review): no case below compares expressions that differ only in
    // their operator (e.g. `a + b` vs `x - y`); consider adding one.
    #[test]
    fn test_exprs_equal() -> Result<()> {
        struct TestCase {
            left: Arc<dyn PhysicalExpr>,
            right: Arc<dyn PhysicalExpr>,
            expected: bool,
            description: &'static str,
        }

        // Create test columns
        let col_a = Arc::new(Column::new("a", 0)) as _;
        let col_b = Arc::new(Column::new("b", 1)) as _;
        let col_x = Arc::new(Column::new("x", 2)) as _;
        let col_y = Arc::new(Column::new("y", 3)) as _;

        // Create test literals
        let lit_1 = Arc::new(Literal::new(ScalarValue::from(1))) as _;
        let lit_2 = Arc::new(Literal::new(ScalarValue::from(2))) as _;

        // Create equivalence group with classes (a = x) and (b = y)
        let eq_group = EquivalenceGroup::new([
            EquivalenceClass::new([Arc::clone(&col_a), Arc::clone(&col_x)]),
            EquivalenceClass::new([Arc::clone(&col_b), Arc::clone(&col_y)]),
        ]);

        let test_cases = vec![
            // Basic equality tests
            TestCase {
                left: Arc::clone(&col_a),
                right: Arc::clone(&col_a),
                expected: true,
                description: "Same column should be equal",
            },
            // Equivalence class tests
            TestCase {
                left: Arc::clone(&col_a),
                right: Arc::clone(&col_x),
                expected: true,
                description: "Columns in same equivalence class should be equal",
            },
            TestCase {
                left: Arc::clone(&col_b),
                right: Arc::clone(&col_y),
                expected: true,
                description: "Columns in same equivalence class should be equal",
            },
            TestCase {
                left: Arc::clone(&col_a),
                right: Arc::clone(&col_b),
                expected: false,
                description:
                    "Columns in different equivalence classes should not be equal",
            },
            // Literal tests
            TestCase {
                left: Arc::clone(&lit_1),
                right: Arc::clone(&lit_1),
                expected: true,
                description: "Same literal should be equal",
            },
            TestCase {
                left: Arc::clone(&lit_1),
                right: Arc::clone(&lit_2),
                expected: false,
                description: "Different literals should not be equal",
            },
            // Complex expression tests
            TestCase {
                left: Arc::new(BinaryExpr::new(
                    Arc::clone(&col_a),
                    Operator::Plus,
                    Arc::clone(&col_b),
                )) as _,
                right: Arc::new(BinaryExpr::new(
                    Arc::clone(&col_x),
                    Operator::Plus,
                    Arc::clone(&col_y),
                )) as _,
                expected: true,
                description:
                    "Binary expressions with equivalent operands should be equal",
            },
            TestCase {
                left: Arc::new(BinaryExpr::new(
                    Arc::clone(&col_a),
                    Operator::Plus,
                    Arc::clone(&col_b),
                )) as _,
                right: Arc::new(BinaryExpr::new(
                    Arc::clone(&col_x),
                    Operator::Plus,
                    Arc::clone(&col_a),
                )) as _,
                expected: false,
                description:
                    "Binary expressions with non-equivalent operands should not be equal",
            },
            TestCase {
                left: Arc::new(BinaryExpr::new(
                    Arc::clone(&col_a),
                    Operator::Plus,
                    Arc::clone(&lit_1),
                )) as _,
                right: Arc::new(BinaryExpr::new(
                    Arc::clone(&col_x),
                    Operator::Plus,
                    Arc::clone(&lit_1),
                )) as _,
                expected: true,
                description: "Binary expressions with equivalent column and same literal should be equal",
            },
            TestCase {
                left: Arc::new(BinaryExpr::new(
                    Arc::new(BinaryExpr::new(
                        Arc::clone(&col_a),
                        Operator::Plus,
                        Arc::clone(&col_b),
                    )),
                    Operator::Multiply,
                    Arc::clone(&lit_1),
                )) as _,
                right: Arc::new(BinaryExpr::new(
                    Arc::new(BinaryExpr::new(
                        Arc::clone(&col_x),
                        Operator::Plus,
                        Arc::clone(&col_y),
                    )),
                    Operator::Multiply,
                    Arc::clone(&lit_1),
                )) as _,
                expected: true,
                description: "Nested binary expressions with equivalent operands should be equal",
            },
        ];

        for TestCase {
            left,
            right,
            expected,
            description,
        } in test_cases
        {
            let actual = eq_group.exprs_equal(&left, &right);
            assert_eq!(
                actual, expected,
                "{description}: Failed comparing {left:?} and {right:?}, expected {expected}, got {actual}"
            );
        }

        Ok(())
    }

    // Projecting equivalent sources must produce equivalent targets: with
    // `a == b`, the projected `a+c` and `b+c` must normalize to the same
    // expression.
    #[test]
    fn test_project_classes() -> Result<()> {
        // - columns: [a, b, c].
        // - "a" and "b" in the same equivalence class.
        // - then after a+c, b+c projection col(0) and col(1) must be
        // in the same class too.
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));
        let mut group = EquivalenceGroup::default();
        group.add_equal_conditions(col("a", &schema)?, col("b", &schema)?);

        let projected_schema = Arc::new(Schema::new(vec![
            Field::new("a+c", DataType::Int32, false),
            Field::new("b+c", DataType::Int32, false),
        ]));

        let mapping = [
            (
                binary(
                    col("a", &schema)?,
                    Operator::Plus,
                    col("c", &schema)?,
                    &schema,
                )?,
                vec![(col("a+c", &projected_schema)?, 0)].into(),
            ),
            (
                binary(
                    col("b", &schema)?,
                    Operator::Plus,
                    col("c", &schema)?,
                    &schema,
                )?,
                vec![(col("b+c", &projected_schema)?, 1)].into(),
            ),
        ]
        .into_iter()
        .collect::<ProjectionMapping>();

        let projected = group.project(&mapping);

        assert!(!projected.is_empty());
        let first_normalized = projected.normalize_expr(col("a+c", &projected_schema)?);
        let second_normalized = projected.normalize_expr(col("b+c", &projected_schema)?);

        assert!(first_normalized.eq(&second_normalized));

        Ok(())
    }
}