datafusion_physical_expr/equivalence/class.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::fmt::Display;
19use std::ops::Deref;
20use std::sync::Arc;
21use std::vec::IntoIter;
22
23use super::projection::ProjectionTargets;
24use super::ProjectionMapping;
25use crate::expressions::Literal;
26use crate::physical_expr::add_offset_to_expr;
27use crate::{PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, PhysicalSortRequirement};
28
29use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
30use datafusion_common::{HashMap, JoinType, Result, ScalarValue};
31use datafusion_physical_expr_common::physical_expr::format_physical_expr_list;
32
33use indexmap::{IndexMap, IndexSet};
34
35/// Represents whether a constant expression's value is uniform or varies across
36/// partitions. Has two variants:
37/// - `Heterogeneous`: The constant expression may have different values for
38///   different partitions.
39/// - `Uniform(Option<ScalarValue>)`: The constant expression has the same value
40///   across all partitions, or is `None` if the value is unknown.
41#[derive(Clone, Debug, Default, Eq, PartialEq)]
42pub enum AcrossPartitions {
43    #[default]
44    Heterogeneous,
45    Uniform(Option<ScalarValue>),
46}
47
48impl Display for AcrossPartitions {
49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50        match self {
51            AcrossPartitions::Heterogeneous => write!(f, "(heterogeneous)"),
52            AcrossPartitions::Uniform(value) => {
53                if let Some(val) = value {
54                    write!(f, "(uniform: {val})")
55                } else {
56                    write!(f, "(uniform: unknown)")
57                }
58            }
59        }
60    }
61}
62
63/// A structure representing a expression known to be constant in a physical
64/// execution plan.
65///
66/// The `ConstExpr` struct encapsulates an expression that is constant during
67/// the execution of a query. For example if a filter like `A = 5` appears
68/// earlier in the plan, `A` would become a constant in subsequent operations.
69///
70/// # Fields
71///
72/// - `expr`: Constant expression for a node in the physical plan.
73/// - `across_partitions`: A boolean flag indicating whether the constant
74///   expression is the same across partitions. If set to `true`, the constant
75///   expression has same value for all partitions. If set to `false`, the
76///   constant expression may have different values for different partitions.
77///
78/// # Example
79///
80/// ```rust
81/// # use datafusion_physical_expr::ConstExpr;
82/// # use datafusion_physical_expr::expressions::lit;
83/// let col = lit(5);
84/// // Create a constant expression from a physical expression:
85/// let const_expr = ConstExpr::from(col);
86/// ```
87#[derive(Clone, Debug)]
88pub struct ConstExpr {
89    /// The expression that is known to be constant (e.g. a `Column`).
90    pub expr: Arc<dyn PhysicalExpr>,
91    /// Indicates whether the constant have the same value across all partitions.
92    pub across_partitions: AcrossPartitions,
93}
94// TODO: The `ConstExpr` definition above can be in an inconsistent state where
95//       `expr` is a literal but `across_partitions` is not `Uniform`. Consider
96//       a refactor to ensure that `ConstExpr` is always in a consistent state
97//       (either by changing type definition, or by API constraints).
98
99impl ConstExpr {
100    /// Create a new constant expression from a physical expression, specifying
101    /// whether the constant expression is the same across partitions.
102    ///
103    /// Note that you can also use `ConstExpr::from` to create a constant
104    /// expression from just a physical expression, with the *safe* assumption
105    /// of heterogenous values across partitions unless the expression is a
106    /// literal.
107    pub fn new(expr: Arc<dyn PhysicalExpr>, across_partitions: AcrossPartitions) -> Self {
108        let mut result = ConstExpr::from(expr);
109        // Override the across partitions specification if the expression is not
110        // a literal.
111        if result.across_partitions == AcrossPartitions::Heterogeneous {
112            result.across_partitions = across_partitions;
113        }
114        result
115    }
116
117    /// Returns a [`Display`]able list of `ConstExpr`.
118    pub fn format_list(input: &[ConstExpr]) -> impl Display + '_ {
119        struct DisplayableList<'a>(&'a [ConstExpr]);
120        impl Display for DisplayableList<'_> {
121            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
122                let mut first = true;
123                for const_expr in self.0 {
124                    if first {
125                        first = false;
126                    } else {
127                        write!(f, ",")?;
128                    }
129                    write!(f, "{const_expr}")?;
130                }
131                Ok(())
132            }
133        }
134        DisplayableList(input)
135    }
136}
137
138impl PartialEq for ConstExpr {
139    fn eq(&self, other: &Self) -> bool {
140        self.across_partitions == other.across_partitions && self.expr.eq(&other.expr)
141    }
142}
143
144impl Display for ConstExpr {
145    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146        write!(f, "{}", self.expr)?;
147        write!(f, "{}", self.across_partitions)
148    }
149}
150
151impl From<Arc<dyn PhysicalExpr>> for ConstExpr {
152    fn from(expr: Arc<dyn PhysicalExpr>) -> Self {
153        // By default, assume constant expressions are not same across partitions.
154        // However, if we have a literal, it will have a single value that is the
155        // same across all partitions.
156        let across = if let Some(lit) = expr.as_any().downcast_ref::<Literal>() {
157            AcrossPartitions::Uniform(Some(lit.value().clone()))
158        } else {
159            AcrossPartitions::Heterogeneous
160        };
161        Self {
162            expr,
163            across_partitions: across,
164        }
165    }
166}
167
168/// An `EquivalenceClass` is a set of [`Arc<dyn PhysicalExpr>`]s that are known
169/// to have the same value for all tuples in a relation. These are generated by
170/// equality predicates (e.g. `a = b`), typically equi-join conditions and
171/// equality conditions in filters.
172///
173/// Two `EquivalenceClass`es are equal if they contains the same expressions in
174/// without any ordering.
175#[derive(Clone, Debug, Default, Eq, PartialEq)]
176pub struct EquivalenceClass {
177    /// The expressions in this equivalence class. The order doesn't matter for
178    /// equivalence purposes.
179    pub(crate) exprs: IndexSet<Arc<dyn PhysicalExpr>>,
180    /// Indicates whether the expressions in this equivalence class have a
181    /// constant value. A `Some` value indicates constant-ness.
182    pub(crate) constant: Option<AcrossPartitions>,
183}
184
185impl EquivalenceClass {
186    // Create a new equivalence class from a pre-existing collection.
187    pub fn new(exprs: impl IntoIterator<Item = Arc<dyn PhysicalExpr>>) -> Self {
188        let mut class = Self::default();
189        for expr in exprs {
190            class.push(expr);
191        }
192        class
193    }
194
195    /// Return the "canonical" expression for this class (the first element)
196    /// if non-empty.
197    pub fn canonical_expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
198        self.exprs.iter().next()
199    }
200
201    /// Insert the expression into this class, meaning it is known to be equal to
202    /// all other expressions in this class.
203    pub fn push(&mut self, expr: Arc<dyn PhysicalExpr>) {
204        if let Some(lit) = expr.as_any().downcast_ref::<Literal>() {
205            let expr_across = AcrossPartitions::Uniform(Some(lit.value().clone()));
206            if let Some(across) = self.constant.as_mut() {
207                // TODO: Return an error if constant values do not agree.
208                if *across == AcrossPartitions::Heterogeneous {
209                    *across = expr_across;
210                }
211            } else {
212                self.constant = Some(expr_across);
213            }
214        }
215        self.exprs.insert(expr);
216    }
217
218    /// Inserts all the expressions from other into this class.
219    pub fn extend(&mut self, other: Self) {
220        self.exprs.extend(other.exprs);
221        match (&self.constant, &other.constant) {
222            (Some(across), Some(_)) => {
223                // TODO: Return an error if constant values do not agree.
224                if across == &AcrossPartitions::Heterogeneous {
225                    self.constant = other.constant;
226                }
227            }
228            (None, Some(_)) => self.constant = other.constant,
229            (_, None) => {}
230        }
231    }
232
233    /// Returns whether this equivalence class has any entries in common with
234    /// `other`.
235    pub fn contains_any(&self, other: &Self) -> bool {
236        self.exprs.intersection(&other.exprs).next().is_some()
237    }
238
239    /// Returns whether this equivalence class is trivial, meaning that it is
240    /// either empty, or contains a single expression that is not a constant.
241    /// Such classes are not useful, and can be removed from equivalence groups.
242    pub fn is_trivial(&self) -> bool {
243        self.exprs.is_empty() || (self.exprs.len() == 1 && self.constant.is_none())
244    }
245
246    /// Adds the given offset to all columns in the expressions inside this
247    /// class. This is used when schemas are appended, e.g. in joins.
248    pub fn try_with_offset(&self, offset: isize) -> Result<Self> {
249        let mut cls = Self::default();
250        for expr_result in self
251            .exprs
252            .iter()
253            .cloned()
254            .map(|e| add_offset_to_expr(e, offset))
255        {
256            cls.push(expr_result?);
257        }
258        Ok(cls)
259    }
260}
261
262impl Deref for EquivalenceClass {
263    type Target = IndexSet<Arc<dyn PhysicalExpr>>;
264
265    fn deref(&self) -> &Self::Target {
266        &self.exprs
267    }
268}
269
270impl IntoIterator for EquivalenceClass {
271    type Item = Arc<dyn PhysicalExpr>;
272    type IntoIter = <IndexSet<Self::Item> as IntoIterator>::IntoIter;
273
274    fn into_iter(self) -> Self::IntoIter {
275        self.exprs.into_iter()
276    }
277}
278
279impl Display for EquivalenceClass {
280    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
281        write!(f, "{{")?;
282        write!(f, "members: {}", format_physical_expr_list(&self.exprs))?;
283        if let Some(across) = &self.constant {
284            write!(f, ", constant: {across}")?;
285        }
286        write!(f, "}}")
287    }
288}
289
290impl From<EquivalenceClass> for Vec<Arc<dyn PhysicalExpr>> {
291    fn from(cls: EquivalenceClass) -> Self {
292        cls.exprs.into_iter().collect()
293    }
294}
295
296type AugmentedMapping<'a> = IndexMap<
297    &'a Arc<dyn PhysicalExpr>,
298    (&'a ProjectionTargets, Option<&'a EquivalenceClass>),
299>;
300
301/// A collection of distinct `EquivalenceClass`es. This object supports fast
302/// lookups of expressions and their equivalence classes.
303#[derive(Clone, Debug, Default)]
304pub struct EquivalenceGroup {
305    /// A mapping from expressions to their equivalence class key.
306    map: HashMap<Arc<dyn PhysicalExpr>, usize>,
307    /// The equivalence classes in this group.
308    classes: Vec<EquivalenceClass>,
309}
310
311impl EquivalenceGroup {
312    /// Creates an equivalence group from the given equivalence classes.
313    pub fn new(classes: impl IntoIterator<Item = EquivalenceClass>) -> Self {
314        classes.into_iter().collect::<Vec<_>>().into()
315    }
316
317    /// Adds `expr` as a constant expression to this equivalence group.
318    pub fn add_constant(&mut self, const_expr: ConstExpr) {
319        // If the expression is already in an equivalence class, we should
320        // adjust the constant-ness of the class if necessary:
321        if let Some(idx) = self.map.get(&const_expr.expr) {
322            let cls = &mut self.classes[*idx];
323            if let Some(across) = cls.constant.as_mut() {
324                // TODO: Return an error if constant values do not agree.
325                if *across == AcrossPartitions::Heterogeneous {
326                    *across = const_expr.across_partitions;
327                }
328            } else {
329                cls.constant = Some(const_expr.across_partitions);
330            }
331            return;
332        }
333        // If the expression is not in any equivalence class, but has the same
334        // constant value with some class, add it to that class:
335        if let AcrossPartitions::Uniform(_) = &const_expr.across_partitions {
336            for (idx, cls) in self.classes.iter_mut().enumerate() {
337                if cls
338                    .constant
339                    .as_ref()
340                    .is_some_and(|across| const_expr.across_partitions.eq(across))
341                {
342                    self.map.insert(Arc::clone(&const_expr.expr), idx);
343                    cls.push(const_expr.expr);
344                    return;
345                }
346            }
347        }
348        // Otherwise, create a new class with the expression as the only member:
349        let mut new_class = EquivalenceClass::new(std::iter::once(const_expr.expr));
350        if new_class.constant.is_none() {
351            new_class.constant = Some(const_expr.across_partitions);
352        }
353        Self::update_lookup_table(&mut self.map, &new_class, self.classes.len());
354        self.classes.push(new_class);
355    }
356
357    /// Removes constant expressions that may change across partitions.
358    /// This method should be used when merging data from different partitions.
359    /// Returns whether any change was made to the equivalence group.
360    pub fn clear_per_partition_constants(&mut self) -> bool {
361        let (mut idx, mut change) = (0, false);
362        while idx < self.classes.len() {
363            let cls = &mut self.classes[idx];
364            if let Some(AcrossPartitions::Heterogeneous) = cls.constant {
365                change = true;
366                if cls.len() == 1 {
367                    // If this class becomes trivial, remove it entirely:
368                    self.remove_class_at_idx(idx);
369                    continue;
370                } else {
371                    cls.constant = None;
372                }
373            }
374            idx += 1;
375        }
376        change
377    }
378
379    /// Adds the equality `left` = `right` to this equivalence group. New
380    /// equality conditions often arise after steps like `Filter(a = b)`,
381    /// `Alias(a, a as b)` etc. Returns whether the given equality defines
382    /// a new equivalence class.
383    pub fn add_equal_conditions(
384        &mut self,
385        left: Arc<dyn PhysicalExpr>,
386        right: Arc<dyn PhysicalExpr>,
387    ) -> bool {
388        let first_class = self.map.get(&left).copied();
389        let second_class = self.map.get(&right).copied();
390        match (first_class, second_class) {
391            (Some(mut first_idx), Some(mut second_idx)) => {
392                // If the given left and right sides belong to different classes,
393                // we should unify/bridge these classes.
394                match first_idx.cmp(&second_idx) {
395                    // The equality is already known, return and signal this:
396                    std::cmp::Ordering::Equal => return false,
397                    // Swap indices to ensure `first_idx` is the lesser index.
398                    std::cmp::Ordering::Greater => {
399                        std::mem::swap(&mut first_idx, &mut second_idx);
400                    }
401                    _ => {}
402                }
403                // Remove the class at `second_idx` and merge its values with
404                // the class at `first_idx`. The convention above makes sure
405                // that `first_idx` is still valid after removing `second_idx`.
406                let other_class = self.remove_class_at_idx(second_idx);
407                // Update the lookup table for the second class:
408                Self::update_lookup_table(&mut self.map, &other_class, first_idx);
409                self.classes[first_idx].extend(other_class);
410            }
411            (Some(group_idx), None) => {
412                // Right side is new, extend left side's class:
413                self.map.insert(Arc::clone(&right), group_idx);
414                self.classes[group_idx].push(right);
415            }
416            (None, Some(group_idx)) => {
417                // Left side is new, extend right side's class:
418                self.map.insert(Arc::clone(&left), group_idx);
419                self.classes[group_idx].push(left);
420            }
421            (None, None) => {
422                // None of the expressions is among existing classes.
423                // Create a new equivalence class and extend the group.
424                let class = EquivalenceClass::new([left, right]);
425                Self::update_lookup_table(&mut self.map, &class, self.classes.len());
426                self.classes.push(class);
427                return true;
428            }
429        }
430        false
431    }
432
433    /// Removes the equivalence class at the given index from this group.
434    fn remove_class_at_idx(&mut self, idx: usize) -> EquivalenceClass {
435        // Remove the class at the given index:
436        let cls = self.classes.swap_remove(idx);
437        // Remove its entries from the lookup table:
438        for expr in cls.iter() {
439            self.map.remove(expr);
440        }
441        // Update the lookup table for the moved class:
442        if idx < self.classes.len() {
443            Self::update_lookup_table(&mut self.map, &self.classes[idx], idx);
444        }
445        cls
446    }
447
448    /// Updates the entry in lookup table for the given equivalence class with
449    /// the given index.
450    fn update_lookup_table(
451        map: &mut HashMap<Arc<dyn PhysicalExpr>, usize>,
452        cls: &EquivalenceClass,
453        idx: usize,
454    ) {
455        for expr in cls.iter() {
456            map.insert(Arc::clone(expr), idx);
457        }
458    }
459
460    /// Removes redundant entries from this group. Returns whether any change
461    /// was made to the equivalence group.
462    fn remove_redundant_entries(&mut self) -> bool {
463        // First, remove trivial equivalence classes:
464        let mut change = false;
465        for idx in (0..self.classes.len()).rev() {
466            if self.classes[idx].is_trivial() {
467                self.remove_class_at_idx(idx);
468                change = true;
469            }
470        }
471        // Then, unify/bridge groups that have common expressions:
472        self.bridge_classes() || change
473    }
474
475    /// This utility function unifies/bridges classes that have common expressions.
476    /// For example, assume that we have [`EquivalenceClass`]es `[a, b]` and `[b, c]`.
477    /// Since both classes contain `b`, columns `a`, `b` and `c` are actually all
478    /// equal and belong to one class. This utility converts merges such classes.
479    /// Returns whether any change was made to the equivalence group.
480    fn bridge_classes(&mut self) -> bool {
481        let (mut idx, mut change) = (0, false);
482        'scan: while idx < self.classes.len() {
483            for other_idx in (idx + 1..self.classes.len()).rev() {
484                if self.classes[idx].contains_any(&self.classes[other_idx]) {
485                    let extension = self.remove_class_at_idx(other_idx);
486                    Self::update_lookup_table(&mut self.map, &extension, idx);
487                    self.classes[idx].extend(extension);
488                    change = true;
489                    continue 'scan;
490                }
491            }
492            idx += 1;
493        }
494        change
495    }
496
497    /// Extends this equivalence group with the `other` equivalence group.
498    /// Returns whether any equivalence classes were unified/bridged as a
499    /// result of the extension process.
500    pub fn extend(&mut self, other: Self) -> bool {
501        for (idx, cls) in other.classes.iter().enumerate() {
502            // Update the lookup table for the new class:
503            Self::update_lookup_table(&mut self.map, cls, idx);
504        }
505        self.classes.extend(other.classes);
506        self.bridge_classes()
507    }
508
509    /// Normalizes the given physical expression according to this group. The
510    /// expression is replaced with the first (canonical) expression in the
511    /// equivalence class it matches with (if any).
512    pub fn normalize_expr(&self, expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
513        expr.transform(|expr| {
514            let cls = self.get_equivalence_class(&expr);
515            let Some(canonical) = cls.and_then(|cls| cls.canonical_expr()) else {
516                return Ok(Transformed::no(expr));
517            };
518            Ok(Transformed::yes(Arc::clone(canonical)))
519        })
520        .data()
521        .unwrap()
522        // The unwrap above is safe because the closure always returns `Ok`.
523    }
524
525    /// Normalizes the given sort expression according to this group. The
526    /// underlying physical expression is replaced with the first expression in
527    /// the equivalence class it matches with (if any). If the underlying
528    /// expression does not belong to any equivalence class in this group,
529    /// returns the sort expression as is.
530    pub fn normalize_sort_expr(
531        &self,
532        mut sort_expr: PhysicalSortExpr,
533    ) -> PhysicalSortExpr {
534        sort_expr.expr = self.normalize_expr(sort_expr.expr);
535        sort_expr
536    }
537
538    /// Normalizes the given sort expressions (i.e. `sort_exprs`) by:
539    /// - Replacing sections that belong to some equivalence class in the
540    ///   with the first entry in the matching equivalence class.
541    /// - Removing expressions that have a constant value.
542    ///
543    /// If columns `a` and `b` are known to be equal, `d` is known to be a
544    /// constant, and `sort_exprs` is `[b ASC, d DESC, c ASC, a ASC]`, this
545    /// function would return `[a ASC, c ASC, a ASC]`.
546    pub fn normalize_sort_exprs<'a>(
547        &'a self,
548        sort_exprs: impl IntoIterator<Item = PhysicalSortExpr> + 'a,
549    ) -> impl Iterator<Item = PhysicalSortExpr> + 'a {
550        sort_exprs
551            .into_iter()
552            .map(|sort_expr| self.normalize_sort_expr(sort_expr))
553            .filter(|sort_expr| self.is_expr_constant(&sort_expr.expr).is_none())
554    }
555
556    /// Normalizes the given sort requirement according to this group. The
557    /// underlying physical expression is replaced with the first expression in
558    /// the equivalence class it matches with (if any). If the underlying
559    /// expression does not belong to any equivalence class in this group,
560    /// returns the given sort requirement as is.
561    pub fn normalize_sort_requirement(
562        &self,
563        mut sort_requirement: PhysicalSortRequirement,
564    ) -> PhysicalSortRequirement {
565        sort_requirement.expr = self.normalize_expr(sort_requirement.expr);
566        sort_requirement
567    }
568
569    /// Normalizes the given sort requirements (i.e. `sort_reqs`) by:
570    /// - Replacing sections that belong to some equivalence class in the
571    ///   with the first entry in the matching equivalence class.
572    /// - Removing expressions that have a constant value.
573    ///
574    /// If columns `a` and `b` are known to be equal, `d` is known to be a
575    /// constant, and `sort_reqs` is `[b ASC, d DESC, c ASC, a ASC]`, this
576    /// function would return `[a ASC, c ASC, a ASC]`.
577    pub fn normalize_sort_requirements<'a>(
578        &'a self,
579        sort_reqs: impl IntoIterator<Item = PhysicalSortRequirement> + 'a,
580    ) -> impl Iterator<Item = PhysicalSortRequirement> + 'a {
581        sort_reqs
582            .into_iter()
583            .map(|req| self.normalize_sort_requirement(req))
584            .filter(|req| self.is_expr_constant(&req.expr).is_none())
585    }
586
587    /// Perform an indirect projection of `expr` by consulting the equivalence
588    /// classes.
589    fn project_expr_indirect(
590        aug_mapping: &AugmentedMapping,
591        expr: &Arc<dyn PhysicalExpr>,
592    ) -> Option<Arc<dyn PhysicalExpr>> {
593        // The given expression is not inside the mapping, so we try to project
594        // indirectly using equivalence classes.
595        for (targets, eq_class) in aug_mapping.values() {
596            // If we match an equivalent expression to a source expression in
597            // the mapping, then we can project. For example, if we have the
598            // mapping `(a as a1, a + c)` and the equivalence `a == b`,
599            // expression `b` projects to `a1`.
600            if eq_class.as_ref().is_some_and(|cls| cls.contains(expr)) {
601                let (target, _) = targets.first();
602                return Some(Arc::clone(target));
603            }
604        }
605        // Project a non-leaf expression by projecting its children.
606        let children = expr.children();
607        if children.is_empty() {
608            // A leaf expression should be inside the mapping.
609            return None;
610        }
611        children
612            .into_iter()
613            .map(|child| {
614                // First, we try to project children with an exact match. If
615                // we are unable to do this, we consult equivalence classes.
616                if let Some((targets, _)) = aug_mapping.get(child) {
617                    // If we match the source, we can project directly:
618                    let (target, _) = targets.first();
619                    Some(Arc::clone(target))
620                } else {
621                    Self::project_expr_indirect(aug_mapping, child)
622                }
623            })
624            .collect::<Option<Vec<_>>>()
625            .map(|children| Arc::clone(expr).with_new_children(children).unwrap())
626    }
627
628    fn augment_projection_mapping<'a>(
629        &'a self,
630        mapping: &'a ProjectionMapping,
631    ) -> AugmentedMapping<'a> {
632        mapping
633            .iter()
634            .map(|(k, v)| {
635                let eq_class = self.get_equivalence_class(k);
636                (k, (v, eq_class))
637            })
638            .collect()
639    }
640
641    /// Projects `expr` according to the given projection mapping.
642    /// If the resulting expression is invalid after projection, returns `None`.
643    pub fn project_expr(
644        &self,
645        mapping: &ProjectionMapping,
646        expr: &Arc<dyn PhysicalExpr>,
647    ) -> Option<Arc<dyn PhysicalExpr>> {
648        if let Some(targets) = mapping.get(expr) {
649            // If we match the source, we can project directly:
650            let (target, _) = targets.first();
651            Some(Arc::clone(target))
652        } else {
653            let aug_mapping = self.augment_projection_mapping(mapping);
654            Self::project_expr_indirect(&aug_mapping, expr)
655        }
656    }
657
658    /// Projects `expressions` according to the given projection mapping.
659    /// This function is similar to [`Self::project_expr`], but projects multiple
660    /// expressions at once more efficiently than calling `project_expr` for each
661    /// expression.
662    pub fn project_expressions<'a>(
663        &'a self,
664        mapping: &'a ProjectionMapping,
665        expressions: impl IntoIterator<Item = &'a Arc<dyn PhysicalExpr>> + 'a,
666    ) -> impl Iterator<Item = Option<Arc<dyn PhysicalExpr>>> + 'a {
667        let mut aug_mapping = None;
668        expressions.into_iter().map(move |expr| {
669            if let Some(targets) = mapping.get(expr) {
670                // If we match the source, we can project directly:
671                let (target, _) = targets.first();
672                Some(Arc::clone(target))
673            } else {
674                let aug_mapping = aug_mapping
675                    .get_or_insert_with(|| self.augment_projection_mapping(mapping));
676                Self::project_expr_indirect(aug_mapping, expr)
677            }
678        })
679    }
680
681    /// Projects this equivalence group according to the given projection mapping.
682    pub fn project(&self, mapping: &ProjectionMapping) -> Self {
683        let projected_classes = self.iter().map(|cls| {
684            let new_exprs = self.project_expressions(mapping, cls.iter());
685            EquivalenceClass::new(new_exprs.flatten())
686        });
687
688        // The key is the source expression, and the value is the equivalence
689        // class that contains the corresponding target expression.
690        let mut new_constants = vec![];
691        let mut new_classes = IndexMap::<_, EquivalenceClass>::new();
692        for (source, targets) in mapping.iter() {
693            // We need to find equivalent projected expressions. For example,
694            // consider a table with columns `[a, b, c]` with `a` == `b`, and
695            // projection `[a + c, b + c]`. To conclude that `a + c == b + c`,
696            // we first normalize all source expressions in the mapping, then
697            // merge all equivalent expressions into the classes.
698            let normalized_expr = self.normalize_expr(Arc::clone(source));
699            let cls = new_classes.entry(normalized_expr).or_default();
700            for (target, _) in targets.iter() {
701                cls.push(Arc::clone(target));
702            }
703            // Save new constants arising from the projection:
704            if let Some(across) = self.is_expr_constant(source) {
705                for (target, _) in targets.iter() {
706                    let const_expr = ConstExpr::new(Arc::clone(target), across.clone());
707                    new_constants.push(const_expr);
708                }
709            }
710        }
711
712        // Union projected classes with new classes to make up the result:
713        let classes = projected_classes
714            .chain(new_classes.into_values())
715            .filter(|cls| !cls.is_trivial());
716        let mut result = Self::new(classes);
717        // Add new constants arising from the projection to the equivalence group:
718        for constant in new_constants {
719            result.add_constant(constant);
720        }
721        result
722    }
723
724    /// Returns a `Some` value if the expression is constant according to
725    /// equivalence group, and `None` otherwise. The `Some` variant contains
726    /// an `AcrossPartitions` value indicating whether the expression is
727    /// constant across partitions, and its actual value (if available).
728    pub fn is_expr_constant(
729        &self,
730        expr: &Arc<dyn PhysicalExpr>,
731    ) -> Option<AcrossPartitions> {
732        if let Some(lit) = expr.as_any().downcast_ref::<Literal>() {
733            return Some(AcrossPartitions::Uniform(Some(lit.value().clone())));
734        }
735        if let Some(cls) = self.get_equivalence_class(expr) {
736            if cls.constant.is_some() {
737                return cls.constant.clone();
738            }
739        }
740        // TODO: This function should be able to return values of non-literal
741        //       complex constants as well; e.g. it should return `8` for the
742        //       expression `3 + 5`, not an unknown `heterogenous` value.
743        let children = expr.children();
744        if children.is_empty() {
745            return None;
746        }
747        for child in children {
748            self.is_expr_constant(child)?;
749        }
750        Some(AcrossPartitions::Heterogeneous)
751    }
752
753    /// Returns the equivalence class containing `expr`. If no equivalence class
754    /// contains `expr`, returns `None`.
755    pub fn get_equivalence_class(
756        &self,
757        expr: &Arc<dyn PhysicalExpr>,
758    ) -> Option<&EquivalenceClass> {
759        self.map.get(expr).map(|idx| &self.classes[*idx])
760    }
761
762    /// Combine equivalence groups of the given join children.
763    pub fn join(
764        &self,
765        right_equivalences: &Self,
766        join_type: &JoinType,
767        left_size: usize,
768        on: &[(PhysicalExprRef, PhysicalExprRef)],
769    ) -> Result<Self> {
770        let group = match join_type {
771            JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
772                let mut result = Self::new(
773                    self.iter().cloned().chain(
774                        right_equivalences
775                            .iter()
776                            .map(|cls| cls.try_with_offset(left_size as _))
777                            .collect::<Result<Vec<_>>>()?,
778                    ),
779                );
780                // In we have an inner join, expressions in the "on" condition
781                // are equal in the resulting table.
782                if join_type == &JoinType::Inner {
783                    for (lhs, rhs) in on.iter() {
784                        let new_lhs = Arc::clone(lhs);
785                        // Rewrite rhs to point to the right side of the join:
786                        let new_rhs =
787                            add_offset_to_expr(Arc::clone(rhs), left_size as _)?;
788                        result.add_equal_conditions(new_lhs, new_rhs);
789                    }
790                }
791                result
792            }
793            JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => self.clone(),
794            JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => {
795                right_equivalences.clone()
796            }
797        };
798        Ok(group)
799    }
800
801    /// Checks if two expressions are equal directly or through equivalence
802    /// classes. For complex expressions (e.g. `a + b`), checks that the
803    /// expression trees are structurally identical and their leaf nodes are
804    /// equivalent either directly or through equivalence classes.
805    pub fn exprs_equal(
806        &self,
807        left: &Arc<dyn PhysicalExpr>,
808        right: &Arc<dyn PhysicalExpr>,
809    ) -> bool {
810        // Direct equality check
811        if left.eq(right) {
812            return true;
813        }
814
815        // Check if expressions are equivalent through equivalence classes
816        // We need to check both directions since expressions might be in different classes
817        if let Some(left_class) = self.get_equivalence_class(left) {
818            if left_class.contains(right) {
819                return true;
820            }
821        }
822        if let Some(right_class) = self.get_equivalence_class(right) {
823            if right_class.contains(left) {
824                return true;
825            }
826        }
827
828        // For non-leaf nodes, check structural equality
829        let left_children = left.children();
830        let right_children = right.children();
831
832        // If either expression is a leaf node and we haven't found equality yet,
833        // they must be different
834        if left_children.is_empty() || right_children.is_empty() {
835            return false;
836        }
837
838        // Type equality check through reflection
839        if left.as_any().type_id() != right.as_any().type_id() {
840            return false;
841        }
842
843        // Check if the number of children is the same
844        if left_children.len() != right_children.len() {
845            return false;
846        }
847
848        // Check if all children are equal
849        left_children
850            .into_iter()
851            .zip(right_children)
852            .all(|(left_child, right_child)| self.exprs_equal(left_child, right_child))
853    }
854}
855
856impl Deref for EquivalenceGroup {
857    type Target = [EquivalenceClass];
858
859    fn deref(&self) -> &Self::Target {
860        &self.classes
861    }
862}
863
864impl IntoIterator for EquivalenceGroup {
865    type Item = EquivalenceClass;
866    type IntoIter = IntoIter<Self::Item>;
867
868    fn into_iter(self) -> Self::IntoIter {
869        self.classes.into_iter()
870    }
871}
872
873impl Display for EquivalenceGroup {
874    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
875        write!(f, "[")?;
876        let mut iter = self.iter();
877        if let Some(cls) = iter.next() {
878            write!(f, "{cls}")?;
879        }
880        for cls in iter {
881            write!(f, ", {cls}")?;
882        }
883        write!(f, "]")
884    }
885}
886
887impl From<Vec<EquivalenceClass>> for EquivalenceGroup {
888    fn from(classes: Vec<EquivalenceClass>) -> Self {
889        let mut result = Self {
890            map: classes
891                .iter()
892                .enumerate()
893                .flat_map(|(idx, cls)| {
894                    cls.iter().map(move |expr| (Arc::clone(expr), idx))
895                })
896                .collect(),
897            classes,
898        };
899        result.remove_redundant_entries();
900        result
901    }
902}
903
// Unit tests for `EquivalenceClass` and `EquivalenceGroup`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::equivalence::tests::create_test_params;
    use crate::expressions::{binary, col, lit, BinaryExpr, Column, Literal};
    use arrow::datatypes::{DataType, Field, Schema};

    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::Operator;

    /// Checks that overlapping classes are merged ("bridged") when an
    /// `EquivalenceGroup` is built from them via `From<Vec<EquivalenceClass>>`.
    #[test]
    fn test_bridge_groups() -> Result<()> {
        // First entry in the tuple is argument, second entry is the bridged result
        let test_cases = vec![
            // ------- TEST CASE 1 -----------//
            (
                vec![vec![1, 2, 3], vec![2, 4, 5], vec![11, 12, 9], vec![7, 6, 5]],
                // Expected is compared with set equality. Order of the specific results may change.
                vec![vec![1, 2, 3, 4, 5, 6, 7], vec![9, 11, 12]],
            ),
            // ------- TEST CASE 2 -----------//
            (
                vec![vec![1, 2, 3], vec![3, 4, 5], vec![9, 8, 7], vec![7, 6, 5]],
                // Expected
                vec![vec![1, 2, 3, 4, 5, 6, 7, 8, 9]],
            ),
        ];
        for (entries, expected) in test_cases {
            let entries = entries
                .into_iter()
                .map(|entry| {
                    entry.into_iter().map(|idx| {
                        let c = Column::new(format!("col_{idx}").as_str(), idx);
                        Arc::new(c) as _
                    })
                })
                .map(EquivalenceClass::new)
                .collect::<Vec<_>>();
            let expected = expected
                .into_iter()
                .map(|entry| {
                    entry.into_iter().map(|idx| {
                        let c = Column::new(format!("col_{idx}").as_str(), idx);
                        Arc::new(c) as _
                    })
                })
                .map(EquivalenceClass::new)
                .collect::<Vec<_>>();
            let eq_groups: EquivalenceGroup = entries.clone().into();
            let eq_groups = eq_groups.classes;
            let err_msg = format!(
                "error in test entries: {entries:?}, expected: {expected:?}, actual:{eq_groups:?}"
            );
            assert_eq!(eq_groups.len(), expected.len(), "{err_msg}");
            // Classes are compared position by position; the order of the
            // merged classes is part of what this test pins down.
            for idx in 0..eq_groups.len() {
                assert_eq!(&eq_groups[idx], &expected[idx], "{err_msg}");
            }
        }
        Ok(())
    }

    /// Checks that `EquivalenceGroup::new` removes duplicate expressions
    /// inside each class.
    #[test]
    fn test_remove_redundant_entries_eq_group() -> Result<()> {
        let c = |idx| Arc::new(Column::new(format!("col_{idx}").as_str(), idx)) as _;
        let entries = [
            EquivalenceClass::new([c(1), c(1), lit(20)]),
            EquivalenceClass::new([lit(30), lit(30)]),
            EquivalenceClass::new([c(2), c(3), c(4)]),
        ];
        // Given equivalences classes are not in succinct form.
        // Expected form is the most plain representation that is functionally same.
        let expected = [
            EquivalenceClass::new([c(1), lit(20)]),
            EquivalenceClass::new([lit(30)]),
            EquivalenceClass::new([c(2), c(3), c(4)]),
        ];
        let eq_groups = EquivalenceGroup::new(entries);
        assert_eq!(eq_groups.classes, expected);
        Ok(())
    }

    /// Checks that `normalize_expr` rewrites an expression to the
    /// representative of its equivalence class and leaves expressions that
    /// belong to no class unchanged.
    #[test]
    fn test_schema_normalize_expr_with_equivalence() -> Result<()> {
        let col_a = Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>;
        let col_b = Arc::new(Column::new("b", 1)) as _;
        let col_c = Arc::new(Column::new("c", 2)) as _;
        // Assume that column a and c are aliases.
        let (_, eq_properties) = create_test_params()?;
        // Test cases for equivalence normalization. First entry in the tuple is
        // the argument, second entry is expected result after normalization.
        let expressions = vec![
            // Normalized version of the column a and c should go to a
            // (by convention all the expressions inside equivalence class are mapped to the first entry
            // in this case a is the first entry in the equivalence class.)
            (Arc::clone(&col_a), Arc::clone(&col_a)),
            (col_c, col_a),
            // Cannot normalize column b
            (Arc::clone(&col_b), Arc::clone(&col_b)),
        ];
        let eq_group = eq_properties.eq_group();
        for (expr, expected_eq) in expressions {
            assert!(expected_eq.eq(&eq_group.normalize_expr(expr)));
        }

        Ok(())
    }

    /// Checks `EquivalenceClass::contains_any`, which reports whether two
    /// classes share at least one expression.
    #[test]
    fn test_contains_any() {
        let lit_true = Arc::new(Literal::new(ScalarValue::from(true))) as _;
        let lit_false = Arc::new(Literal::new(ScalarValue::from(false))) as _;
        let col_a_expr = Arc::new(Column::new("a", 0)) as _;
        let col_b_expr = Arc::new(Column::new("b", 1)) as _;
        let col_c_expr = Arc::new(Column::new("c", 2)) as _;

        let cls1 = EquivalenceClass::new([Arc::clone(&lit_true), col_a_expr]);
        let cls2 = EquivalenceClass::new([lit_true, col_b_expr]);
        let cls3 = EquivalenceClass::new([col_c_expr, lit_false]);

        // lit_true is common
        assert!(cls1.contains_any(&cls2));
        // there is no common entry
        assert!(!cls1.contains_any(&cls3));
        assert!(!cls2.contains_any(&cls3));
    }

    /// Checks `EquivalenceGroup::exprs_equal` on columns, literals and
    /// (nested) binary expressions, with `a = x` and `b = y` as equivalences.
    #[test]
    fn test_exprs_equal() -> Result<()> {
        struct TestCase {
            left: Arc<dyn PhysicalExpr>,
            right: Arc<dyn PhysicalExpr>,
            expected: bool,
            description: &'static str,
        }

        // Create test columns
        let col_a = Arc::new(Column::new("a", 0)) as _;
        let col_b = Arc::new(Column::new("b", 1)) as _;
        let col_x = Arc::new(Column::new("x", 2)) as _;
        let col_y = Arc::new(Column::new("y", 3)) as _;

        // Create test literals
        let lit_1 = Arc::new(Literal::new(ScalarValue::from(1))) as _;
        let lit_2 = Arc::new(Literal::new(ScalarValue::from(2))) as _;

        // Create equivalence group with classes (a = x) and (b = y)
        let eq_group = EquivalenceGroup::new([
            EquivalenceClass::new([Arc::clone(&col_a), Arc::clone(&col_x)]),
            EquivalenceClass::new([Arc::clone(&col_b), Arc::clone(&col_y)]),
        ]);

        let test_cases = vec![
            // Basic equality tests
            TestCase {
                left: Arc::clone(&col_a),
                right: Arc::clone(&col_a),
                expected: true,
                description: "Same column should be equal",
            },
            // Equivalence class tests
            TestCase {
                left: Arc::clone(&col_a),
                right: Arc::clone(&col_x),
                expected: true,
                description: "Columns in same equivalence class should be equal",
            },
            TestCase {
                left: Arc::clone(&col_b),
                right: Arc::clone(&col_y),
                expected: true,
                description: "Columns in same equivalence class should be equal",
            },
            TestCase {
                left: Arc::clone(&col_a),
                right: Arc::clone(&col_b),
                expected: false,
                description:
                    "Columns in different equivalence classes should not be equal",
            },
            // Literal tests
            TestCase {
                left: Arc::clone(&lit_1),
                right: Arc::clone(&lit_1),
                expected: true,
                description: "Same literal should be equal",
            },
            TestCase {
                left: Arc::clone(&lit_1),
                right: Arc::clone(&lit_2),
                expected: false,
                description: "Different literals should not be equal",
            },
            // Complex expression tests
            TestCase {
                left: Arc::new(BinaryExpr::new(
                    Arc::clone(&col_a),
                    Operator::Plus,
                    Arc::clone(&col_b),
                )) as _,
                right: Arc::new(BinaryExpr::new(
                    Arc::clone(&col_x),
                    Operator::Plus,
                    Arc::clone(&col_y),
                )) as _,
                expected: true,
                description:
                    "Binary expressions with equivalent operands should be equal",
            },
            TestCase {
                left: Arc::new(BinaryExpr::new(
                    Arc::clone(&col_a),
                    Operator::Plus,
                    Arc::clone(&col_b),
                )) as _,
                right: Arc::new(BinaryExpr::new(
                    Arc::clone(&col_x),
                    Operator::Plus,
                    Arc::clone(&col_a),
                )) as _,
                expected: false,
                description:
                    "Binary expressions with non-equivalent operands should not be equal",
            },
            TestCase {
                left: Arc::new(BinaryExpr::new(
                    Arc::clone(&col_a),
                    Operator::Plus,
                    Arc::clone(&lit_1),
                )) as _,
                right: Arc::new(BinaryExpr::new(
                    Arc::clone(&col_x),
                    Operator::Plus,
                    Arc::clone(&lit_1),
                )) as _,
                expected: true,
                description: "Binary expressions with equivalent column and same literal should be equal",
            },
            TestCase {
                left: Arc::new(BinaryExpr::new(
                    Arc::new(BinaryExpr::new(
                        Arc::clone(&col_a),
                        Operator::Plus,
                        Arc::clone(&col_b),
                    )),
                    Operator::Multiply,
                    Arc::clone(&lit_1),
                )) as _,
                right: Arc::new(BinaryExpr::new(
                    Arc::new(BinaryExpr::new(
                        Arc::clone(&col_x),
                        Operator::Plus,
                        Arc::clone(&col_y),
                    )),
                    Operator::Multiply,
                    Arc::clone(&lit_1),
                )) as _,
                expected: true,
                description: "Nested binary expressions with equivalent operands should be equal",
            },
        ];

        for TestCase {
            left,
            right,
            expected,
            description,
        } in test_cases
        {
            let actual = eq_group.exprs_equal(&left, &right);
            assert_eq!(
                actual, expected,
                "{description}: Failed comparing {left:?} and {right:?}, expected {expected}, got {actual}"
            );
        }

        Ok(())
    }

    /// Checks that `EquivalenceGroup::project` carries source equivalences
    /// over to projected expressions built from equivalent inputs.
    #[test]
    fn test_project_classes() -> Result<()> {
        // - columns: [a, b, c].
        // - "a" and "b" in the same equivalence class.
        // - then after a+c, b+c projection col(0) and col(1) must be
        // in the same class too.
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));
        let mut group = EquivalenceGroup::default();
        group.add_equal_conditions(col("a", &schema)?, col("b", &schema)?);

        let projected_schema = Arc::new(Schema::new(vec![
            Field::new("a+c", DataType::Int32, false),
            Field::new("b+c", DataType::Int32, false),
        ]));

        // Each entry maps a source expression to its projected target column
        // and that column's index in `projected_schema`.
        let mapping = [
            (
                binary(
                    col("a", &schema)?,
                    Operator::Plus,
                    col("c", &schema)?,
                    &schema,
                )?,
                vec![(col("a+c", &projected_schema)?, 0)].into(),
            ),
            (
                binary(
                    col("b", &schema)?,
                    Operator::Plus,
                    col("c", &schema)?,
                    &schema,
                )?,
                vec![(col("b+c", &projected_schema)?, 1)].into(),
            ),
        ]
        .into_iter()
        .collect::<ProjectionMapping>();

        let projected = group.project(&mapping);

        assert!(!projected.is_empty());
        let first_normalized = projected.normalize_expr(col("a+c", &projected_schema)?);
        let second_normalized = projected.normalize_expr(col("b+c", &projected_schema)?);

        assert!(first_normalized.eq(&second_normalized));

        Ok(())
    }
}