Skip to main content

datafusion_physical_expr/equivalence/
class.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::fmt::Display;
20use std::ops::Deref;
21use std::sync::Arc;
22use std::vec::IntoIter;
23
24use super::ProjectionMapping;
25use crate::expressions::Literal;
26use crate::physical_expr::add_offset_to_expr;
27use crate::projection::ProjectionTargets;
28use crate::{PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, PhysicalSortRequirement};
29
30use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
31use datafusion_common::{JoinType, Result, ScalarValue};
32use datafusion_physical_expr_common::physical_expr::format_physical_expr_list;
33
34use indexmap::{IndexMap, IndexSet};
35
/// Represents whether a constant expression's value is uniform or varies across
/// partitions. Has two variants:
/// - `Heterogeneous`: The constant expression may have different values for
///   different partitions.
/// - `Uniform(Option<ScalarValue>)`: The constant expression has the same value
///   across all partitions, or is `None` if the value is unknown.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub enum AcrossPartitions {
    /// The value may differ from one partition to another. This is the
    /// default variant.
    #[default]
    Heterogeneous,
    /// The value is the same in every partition. The payload holds the value
    /// when it is known, and `None` when it is unknown.
    Uniform(Option<ScalarValue>),
}
48
49impl Display for AcrossPartitions {
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        match self {
52            AcrossPartitions::Heterogeneous => write!(f, "(heterogeneous)"),
53            AcrossPartitions::Uniform(value) => {
54                if let Some(val) = value {
55                    write!(f, "(uniform: {val})")
56                } else {
57                    write!(f, "(uniform: unknown)")
58                }
59            }
60        }
61    }
62}
63
/// A structure representing an expression known to be constant in a physical
/// execution plan.
///
/// The `ConstExpr` struct encapsulates an expression that is constant during
/// the execution of a query. For example if a filter like `A = 5` appears
/// earlier in the plan, `A` would become a constant in subsequent operations.
///
/// # Fields
///
/// - `expr`: Constant expression for a node in the physical plan.
/// - `across_partitions`: An [`AcrossPartitions`] value indicating whether the
///   constant expression is the same across partitions. If it is `Uniform`,
///   the constant expression has the same value for all partitions (and the
///   value itself may be recorded). If it is `Heterogeneous`, the constant
///   expression may have different values for different partitions.
///
/// # Example
///
/// ```rust
/// # use datafusion_physical_expr::ConstExpr;
/// # use datafusion_physical_expr::expressions::lit;
/// let col = lit(5);
/// // Create a constant expression from a physical expression:
/// let const_expr = ConstExpr::from(col);
/// ```
#[derive(Clone, Debug)]
pub struct ConstExpr {
    /// The expression that is known to be constant (e.g. a `Column`).
    pub expr: Arc<dyn PhysicalExpr>,
    /// Indicates whether the constant has the same value across all partitions.
    pub across_partitions: AcrossPartitions,
}
95// TODO: The `ConstExpr` definition above can be in an inconsistent state where
96//       `expr` is a literal but `across_partitions` is not `Uniform`. Consider
97//       a refactor to ensure that `ConstExpr` is always in a consistent state
98//       (either by changing type definition, or by API constraints).
99
100impl ConstExpr {
101    /// Create a new constant expression from a physical expression, specifying
102    /// whether the constant expression is the same across partitions.
103    ///
104    /// Note that you can also use `ConstExpr::from` to create a constant
105    /// expression from just a physical expression, with the *safe* assumption
106    /// of heterogenous values across partitions unless the expression is a
107    /// literal.
108    pub fn new(expr: Arc<dyn PhysicalExpr>, across_partitions: AcrossPartitions) -> Self {
109        let mut result = ConstExpr::from(expr);
110        // Override the across partitions specification if the expression is not
111        // a literal.
112        if result.across_partitions == AcrossPartitions::Heterogeneous {
113            result.across_partitions = across_partitions;
114        }
115        result
116    }
117
118    /// Returns a [`Display`]able list of `ConstExpr`.
119    pub fn format_list(input: &[ConstExpr]) -> impl Display + '_ {
120        struct DisplayableList<'a>(&'a [ConstExpr]);
121        impl Display for DisplayableList<'_> {
122            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
123                let mut first = true;
124                for const_expr in self.0 {
125                    if first {
126                        first = false;
127                    } else {
128                        write!(f, ",")?;
129                    }
130                    write!(f, "{const_expr}")?;
131                }
132                Ok(())
133            }
134        }
135        DisplayableList(input)
136    }
137}
138
139impl PartialEq for ConstExpr {
140    fn eq(&self, other: &Self) -> bool {
141        self.across_partitions == other.across_partitions && self.expr.eq(&other.expr)
142    }
143}
144
145impl Display for ConstExpr {
146    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147        write!(f, "{}", self.expr)?;
148        write!(f, "{}", self.across_partitions)
149    }
150}
151
152impl From<Arc<dyn PhysicalExpr>> for ConstExpr {
153    fn from(expr: Arc<dyn PhysicalExpr>) -> Self {
154        // By default, assume constant expressions are not same across partitions.
155        // However, if we have a literal, it will have a single value that is the
156        // same across all partitions.
157        let across = if let Some(lit) = expr.downcast_ref::<Literal>() {
158            AcrossPartitions::Uniform(Some(lit.value().clone()))
159        } else {
160            AcrossPartitions::Heterogeneous
161        };
162        Self {
163            expr,
164            across_partitions: across,
165        }
166    }
167}
168
/// An `EquivalenceClass` is a set of [`Arc<dyn PhysicalExpr>`]s that are known
/// to have the same value for all tuples in a relation. These are generated by
/// equality predicates (e.g. `a = b`), typically equi-join conditions and
/// equality conditions in filters.
///
/// Equality is derived field-by-field: two `EquivalenceClass`es are equal if
/// their expression sets are equal (the order of expressions is not meant to
/// matter; see `IndexSet` equality semantics) and their `constant`
/// annotations are equal.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct EquivalenceClass {
    /// The expressions in this equivalence class. The order doesn't matter for
    /// equivalence purposes.
    pub(crate) exprs: IndexSet<Arc<dyn PhysicalExpr>>,
    /// Indicates whether the expressions in this equivalence class have a
    /// constant value. A `Some` value indicates constant-ness.
    pub(crate) constant: Option<AcrossPartitions>,
}
185
186impl EquivalenceClass {
187    // Create a new equivalence class from a pre-existing collection.
188    pub fn new(exprs: impl IntoIterator<Item = Arc<dyn PhysicalExpr>>) -> Self {
189        let mut class = Self::default();
190        for expr in exprs {
191            class.push(expr);
192        }
193        class
194    }
195
196    /// Return the "canonical" expression for this class (the first element)
197    /// if non-empty.
198    pub fn canonical_expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
199        self.exprs.iter().next()
200    }
201
202    /// Insert the expression into this class, meaning it is known to be equal to
203    /// all other expressions in this class.
204    pub fn push(&mut self, expr: Arc<dyn PhysicalExpr>) {
205        if let Some(lit) = expr.downcast_ref::<Literal>() {
206            let expr_across = AcrossPartitions::Uniform(Some(lit.value().clone()));
207            if let Some(across) = self.constant.as_mut() {
208                // TODO: Return an error if constant values do not agree.
209                if *across == AcrossPartitions::Heterogeneous {
210                    *across = expr_across;
211                }
212            } else {
213                self.constant = Some(expr_across);
214            }
215        }
216        self.exprs.insert(expr);
217    }
218
219    /// Inserts all the expressions from other into this class.
220    pub fn extend(&mut self, other: Self) {
221        self.exprs.extend(other.exprs);
222        match (&self.constant, &other.constant) {
223            (Some(across), Some(_)) => {
224                // TODO: Return an error if constant values do not agree.
225                if across == &AcrossPartitions::Heterogeneous {
226                    self.constant = other.constant;
227                }
228            }
229            (None, Some(_)) => self.constant = other.constant,
230            (_, None) => {}
231        }
232    }
233
234    /// Returns whether this equivalence class has any entries in common with
235    /// `other`.
236    pub fn contains_any(&self, other: &Self) -> bool {
237        self.exprs.intersection(&other.exprs).next().is_some()
238    }
239
240    /// Returns whether this equivalence class is trivial, meaning that it is
241    /// either empty, or contains a single expression that is not a constant.
242    /// Such classes are not useful, and can be removed from equivalence groups.
243    pub fn is_trivial(&self) -> bool {
244        self.exprs.is_empty() || (self.exprs.len() == 1 && self.constant.is_none())
245    }
246
247    /// Adds the given offset to all columns in the expressions inside this
248    /// class. This is used when schemas are appended, e.g. in joins.
249    pub fn try_with_offset(&self, offset: isize) -> Result<Self> {
250        let mut cls = Self::default();
251        for expr_result in self
252            .exprs
253            .iter()
254            .cloned()
255            .map(|e| add_offset_to_expr(e, offset))
256        {
257            cls.push(expr_result?);
258        }
259        Ok(cls)
260    }
261}
262
/// Gives read-only access to the underlying expression set, so that shared
/// `IndexSet` methods can be called directly on an `EquivalenceClass`.
impl Deref for EquivalenceClass {
    type Target = IndexSet<Arc<dyn PhysicalExpr>>;

    /// Returns a reference to the set of expressions in this class.
    fn deref(&self) -> &Self::Target {
        &self.exprs
    }
}
270
/// Consuming iteration over the expressions of the class. The `constant`
/// annotation is dropped in the process.
impl IntoIterator for EquivalenceClass {
    type Item = Arc<dyn PhysicalExpr>;
    type IntoIter = <IndexSet<Self::Item> as IntoIterator>::IntoIter;

    /// Consumes the class and yields its expressions by value.
    fn into_iter(self) -> Self::IntoIter {
        self.exprs.into_iter()
    }
}
279
280impl Display for EquivalenceClass {
281    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
282        write!(f, "{{")?;
283        write!(f, "members: {}", format_physical_expr_list(&self.exprs))?;
284        if let Some(across) = &self.constant {
285            write!(f, ", constant: {across}")?;
286        }
287        write!(f, "}}")
288    }
289}
290
291impl From<EquivalenceClass> for Vec<Arc<dyn PhysicalExpr>> {
292    fn from(cls: EquivalenceClass) -> Self {
293        cls.exprs.into_iter().collect()
294    }
295}
296
/// A projection mapping augmented with equivalence information: each source
/// expression maps to its projection targets, paired with the equivalence
/// class containing that source expression (if any).
type AugmentedMapping<'a> = IndexMap<
    &'a Arc<dyn PhysicalExpr>,
    (&'a ProjectionTargets, Option<&'a EquivalenceClass>),
>;
301
/// A collection of distinct `EquivalenceClass`es. This object supports fast
/// lookups of expressions and their equivalence classes.
#[derive(Clone, Debug, Default)]
pub struct EquivalenceGroup {
    /// A mapping from expressions to their equivalence class key. The key is
    /// used as an index into `classes`.
    map: IndexMap<Arc<dyn PhysicalExpr>, usize>,
    /// The equivalence classes in this group.
    classes: Vec<EquivalenceClass>,
}
311
312impl EquivalenceGroup {
313    /// Creates an equivalence group from the given equivalence classes.
314    pub fn new(classes: impl IntoIterator<Item = EquivalenceClass>) -> Self {
315        classes.into_iter().collect::<Vec<_>>().into()
316    }
317
318    /// Adds `expr` as a constant expression to this equivalence group.
319    pub fn add_constant(&mut self, const_expr: ConstExpr) {
320        // If the expression is already in an equivalence class, we should
321        // adjust the constant-ness of the class if necessary:
322        if let Some(idx) = self.map.get(&const_expr.expr) {
323            let cls = &mut self.classes[*idx];
324            if let Some(across) = cls.constant.as_mut() {
325                // TODO: Return an error if constant values do not agree.
326                if *across == AcrossPartitions::Heterogeneous {
327                    *across = const_expr.across_partitions;
328                }
329            } else {
330                cls.constant = Some(const_expr.across_partitions);
331            }
332            return;
333        }
334        // If the expression is not in any equivalence class, but has the same
335        // constant value with some class, add it to that class:
336        if let AcrossPartitions::Uniform(_) = &const_expr.across_partitions {
337            for (idx, cls) in self.classes.iter_mut().enumerate() {
338                if cls
339                    .constant
340                    .as_ref()
341                    .is_some_and(|across| const_expr.across_partitions.eq(across))
342                {
343                    self.map.insert(Arc::clone(&const_expr.expr), idx);
344                    cls.push(const_expr.expr);
345                    return;
346                }
347            }
348        }
349        // Otherwise, create a new class with the expression as the only member:
350        let mut new_class = EquivalenceClass::new(std::iter::once(const_expr.expr));
351        if new_class.constant.is_none() {
352            new_class.constant = Some(const_expr.across_partitions);
353        }
354        Self::update_lookup_table(&mut self.map, &new_class, self.classes.len());
355        self.classes.push(new_class);
356    }
357
358    /// Removes constant expressions that may change across partitions.
359    /// This method should be used when merging data from different partitions.
360    /// Returns whether any change was made to the equivalence group.
361    pub fn clear_per_partition_constants(&mut self) -> bool {
362        let (mut idx, mut change) = (0, false);
363        while idx < self.classes.len() {
364            let cls = &mut self.classes[idx];
365            if let Some(AcrossPartitions::Heterogeneous) = cls.constant {
366                change = true;
367                if cls.len() == 1 {
368                    // If this class becomes trivial, remove it entirely:
369                    self.remove_class_at_idx(idx);
370                    continue;
371                } else {
372                    cls.constant = None;
373                }
374            }
375            idx += 1;
376        }
377        change
378    }
379
380    /// Adds the equality `left` = `right` to this equivalence group. New
381    /// equality conditions often arise after steps like `Filter(a = b)`,
382    /// `Alias(a, a as b)` etc. Returns whether the given equality defines
383    /// a new equivalence class.
384    pub fn add_equal_conditions(
385        &mut self,
386        left: Arc<dyn PhysicalExpr>,
387        right: Arc<dyn PhysicalExpr>,
388    ) -> bool {
389        let first_class = self.map.get(&left).copied();
390        let second_class = self.map.get(&right).copied();
391        match (first_class, second_class) {
392            (Some(mut first_idx), Some(mut second_idx)) => {
393                // If the given left and right sides belong to different classes,
394                // we should unify/bridge these classes.
395                match first_idx.cmp(&second_idx) {
396                    // The equality is already known, return and signal this:
397                    std::cmp::Ordering::Equal => return false,
398                    // Swap indices to ensure `first_idx` is the lesser index.
399                    std::cmp::Ordering::Greater => {
400                        std::mem::swap(&mut first_idx, &mut second_idx);
401                    }
402                    _ => {}
403                }
404                // Remove the class at `second_idx` and merge its values with
405                // the class at `first_idx`. The convention above makes sure
406                // that `first_idx` is still valid after removing `second_idx`.
407                let other_class = self.remove_class_at_idx(second_idx);
408                // Update the lookup table for the second class:
409                Self::update_lookup_table(&mut self.map, &other_class, first_idx);
410                self.classes[first_idx].extend(other_class);
411            }
412            (Some(group_idx), None) => {
413                // Right side is new, extend left side's class:
414                self.map.insert(Arc::clone(&right), group_idx);
415                self.classes[group_idx].push(right);
416            }
417            (None, Some(group_idx)) => {
418                // Left side is new, extend right side's class:
419                self.map.insert(Arc::clone(&left), group_idx);
420                self.classes[group_idx].push(left);
421            }
422            (None, None) => {
423                // None of the expressions is among existing classes.
424                // Create a new equivalence class and extend the group.
425                let class = EquivalenceClass::new([left, right]);
426                Self::update_lookup_table(&mut self.map, &class, self.classes.len());
427                self.classes.push(class);
428                return true;
429            }
430        }
431        false
432    }
433
434    /// Removes the equivalence class at the given index from this group.
435    fn remove_class_at_idx(&mut self, idx: usize) -> EquivalenceClass {
436        // Remove the class at the given index:
437        let cls = self.classes.swap_remove(idx);
438        // Remove its entries from the lookup table:
439        for expr in cls.iter() {
440            self.map.swap_remove(expr);
441        }
442        // Update the lookup table for the moved class:
443        if idx < self.classes.len() {
444            Self::update_lookup_table(&mut self.map, &self.classes[idx], idx);
445        }
446        cls
447    }
448
449    /// Updates the entry in lookup table for the given equivalence class with
450    /// the given index.
451    fn update_lookup_table(
452        map: &mut IndexMap<Arc<dyn PhysicalExpr>, usize>,
453        cls: &EquivalenceClass,
454        idx: usize,
455    ) {
456        for expr in cls.iter() {
457            map.insert(Arc::clone(expr), idx);
458        }
459    }
460
461    /// Removes redundant entries from this group. Returns whether any change
462    /// was made to the equivalence group.
463    fn remove_redundant_entries(&mut self) -> bool {
464        // First, remove trivial equivalence classes:
465        let mut change = false;
466        for idx in (0..self.classes.len()).rev() {
467            if self.classes[idx].is_trivial() {
468                self.remove_class_at_idx(idx);
469                change = true;
470            }
471        }
472        // Then, unify/bridge groups that have common expressions:
473        self.bridge_classes() || change
474    }
475
476    /// This utility function unifies/bridges classes that have common expressions.
477    /// For example, assume that we have [`EquivalenceClass`]es `[a, b]` and `[b, c]`.
478    /// Since both classes contain `b`, columns `a`, `b` and `c` are actually all
479    /// equal and belong to one class. This utility converts merges such classes.
480    /// Returns whether any change was made to the equivalence group.
481    fn bridge_classes(&mut self) -> bool {
482        let (mut idx, mut change) = (0, false);
483        'scan: while idx < self.classes.len() {
484            for other_idx in (idx + 1..self.classes.len()).rev() {
485                if self.classes[idx].contains_any(&self.classes[other_idx]) {
486                    let extension = self.remove_class_at_idx(other_idx);
487                    Self::update_lookup_table(&mut self.map, &extension, idx);
488                    self.classes[idx].extend(extension);
489                    change = true;
490                    continue 'scan;
491                }
492            }
493            idx += 1;
494        }
495        change
496    }
497
498    /// Extends this equivalence group with the `other` equivalence group.
499    /// Returns whether any equivalence classes were unified/bridged as a
500    /// result of the extension process.
501    pub fn extend(&mut self, other: Self) -> bool {
502        for (idx, cls) in other.classes.iter().enumerate() {
503            // Update the lookup table for the new class:
504            Self::update_lookup_table(&mut self.map, cls, idx);
505        }
506        self.classes.extend(other.classes);
507        self.bridge_classes()
508    }
509
510    /// Normalizes the given physical expression according to this group. The
511    /// expression is replaced with the first (canonical) expression in the
512    /// equivalence class it matches with (if any).
513    pub fn normalize_expr(&self, expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
514        expr.transform(|expr| {
515            let cls = self.get_equivalence_class(&expr);
516            let Some(canonical) = cls.and_then(|cls| cls.canonical_expr()) else {
517                return Ok(Transformed::no(expr));
518            };
519            Ok(Transformed::yes(Arc::clone(canonical)))
520        })
521        .data()
522        .unwrap()
523        // The unwrap above is safe because the closure always returns `Ok`.
524    }
525
526    /// Normalizes the given sort expression according to this group. The
527    /// underlying physical expression is replaced with the first expression in
528    /// the equivalence class it matches with (if any). If the underlying
529    /// expression does not belong to any equivalence class in this group,
530    /// returns the sort expression as is.
531    pub fn normalize_sort_expr(
532        &self,
533        mut sort_expr: PhysicalSortExpr,
534    ) -> PhysicalSortExpr {
535        sort_expr.expr = self.normalize_expr(sort_expr.expr);
536        sort_expr
537    }
538
539    /// Normalizes the given sort expressions (i.e. `sort_exprs`) by:
540    /// - Replacing sections that belong to some equivalence class in the
541    ///   with the first entry in the matching equivalence class.
542    /// - Removing expressions that have a constant value.
543    ///
544    /// If columns `a` and `b` are known to be equal, `d` is known to be a
545    /// constant, and `sort_exprs` is `[b ASC, d DESC, c ASC, a ASC]`, this
546    /// function would return `[a ASC, c ASC, a ASC]`.
547    pub fn normalize_sort_exprs<'a>(
548        &'a self,
549        sort_exprs: impl IntoIterator<Item = PhysicalSortExpr> + 'a,
550    ) -> impl Iterator<Item = PhysicalSortExpr> + 'a {
551        sort_exprs
552            .into_iter()
553            .map(|sort_expr| self.normalize_sort_expr(sort_expr))
554            .filter(|sort_expr| self.is_expr_constant(&sort_expr.expr).is_none())
555    }
556
557    /// Normalizes the given sort requirement according to this group. The
558    /// underlying physical expression is replaced with the first expression in
559    /// the equivalence class it matches with (if any). If the underlying
560    /// expression does not belong to any equivalence class in this group,
561    /// returns the given sort requirement as is.
562    pub fn normalize_sort_requirement(
563        &self,
564        mut sort_requirement: PhysicalSortRequirement,
565    ) -> PhysicalSortRequirement {
566        sort_requirement.expr = self.normalize_expr(sort_requirement.expr);
567        sort_requirement
568    }
569
570    /// Normalizes the given sort requirements (i.e. `sort_reqs`) by:
571    /// - Replacing sections that belong to some equivalence class in the
572    ///   with the first entry in the matching equivalence class.
573    /// - Removing expressions that have a constant value.
574    ///
575    /// If columns `a` and `b` are known to be equal, `d` is known to be a
576    /// constant, and `sort_reqs` is `[b ASC, d DESC, c ASC, a ASC]`, this
577    /// function would return `[a ASC, c ASC, a ASC]`.
578    pub fn normalize_sort_requirements<'a>(
579        &'a self,
580        sort_reqs: impl IntoIterator<Item = PhysicalSortRequirement> + 'a,
581    ) -> impl Iterator<Item = PhysicalSortRequirement> + 'a {
582        sort_reqs
583            .into_iter()
584            .map(|req| self.normalize_sort_requirement(req))
585            .filter(|req| self.is_expr_constant(&req.expr).is_none())
586    }
587
588    /// Perform an indirect projection of `expr` by consulting the equivalence
589    /// classes.
590    fn project_expr_indirect(
591        aug_mapping: &AugmentedMapping,
592        expr: &Arc<dyn PhysicalExpr>,
593    ) -> Option<Arc<dyn PhysicalExpr>> {
594        // Literals don't need to be projected
595        if expr.downcast_ref::<Literal>().is_some() {
596            return Some(Arc::clone(expr));
597        }
598
599        // The given expression is not inside the mapping, so we try to project
600        // indirectly using equivalence classes.
601        for (targets, eq_class) in aug_mapping.values() {
602            // If we match an equivalent expression to a source expression in
603            // the mapping, then we can project. For example, if we have the
604            // mapping `(a as a1, a + c)` and the equivalence `a == b`,
605            // expression `b` projects to `a1`.
606            if eq_class.as_ref().is_some_and(|cls| cls.contains(expr)) {
607                let (target, _) = targets.first();
608                return Some(Arc::clone(target));
609            }
610        }
611        // Project a non-leaf expression by projecting its children.
612        let children = expr.children();
613        if children.is_empty() {
614            // A leaf expression should be inside the mapping.
615            return None;
616        }
617        children
618            .into_iter()
619            .map(|child| {
620                // First, we try to project children with an exact match. If
621                // we are unable to do this, we consult equivalence classes.
622                if let Some((targets, _)) = aug_mapping.get(child) {
623                    // If we match the source, we can project directly:
624                    let (target, _) = targets.first();
625                    Some(Arc::clone(target))
626                } else {
627                    Self::project_expr_indirect(aug_mapping, child)
628                }
629            })
630            .collect::<Option<Vec<_>>>()
631            .map(|children| Arc::clone(expr).with_new_children(children).unwrap())
632    }
633
634    fn augment_projection_mapping<'a>(
635        &'a self,
636        mapping: &'a ProjectionMapping,
637    ) -> AugmentedMapping<'a> {
638        mapping
639            .iter()
640            .map(|(k, v)| {
641                let eq_class = self.get_equivalence_class(k);
642                (k, (v, eq_class))
643            })
644            .collect()
645    }
646
647    /// Projects `expr` according to the given projection mapping.
648    /// If the resulting expression is invalid after projection, returns `None`.
649    pub fn project_expr(
650        &self,
651        mapping: &ProjectionMapping,
652        expr: &Arc<dyn PhysicalExpr>,
653    ) -> Option<Arc<dyn PhysicalExpr>> {
654        if let Some(targets) = mapping.get(expr) {
655            // If we match the source, we can project directly:
656            let (target, _) = targets.first();
657            Some(Arc::clone(target))
658        } else {
659            let aug_mapping = self.augment_projection_mapping(mapping);
660            Self::project_expr_indirect(&aug_mapping, expr)
661        }
662    }
663
664    /// Projects `expressions` according to the given projection mapping.
665    /// This function is similar to [`Self::project_expr`], but projects multiple
666    /// expressions at once more efficiently than calling `project_expr` for each
667    /// expression.
668    pub fn project_expressions<'a>(
669        &'a self,
670        mapping: &'a ProjectionMapping,
671        expressions: impl IntoIterator<Item = &'a Arc<dyn PhysicalExpr>> + 'a,
672    ) -> impl Iterator<Item = Option<Arc<dyn PhysicalExpr>>> + 'a {
673        let mut aug_mapping = None;
674        expressions.into_iter().map(move |expr| {
675            if let Some(targets) = mapping.get(expr) {
676                // If we match the source, we can project directly:
677                let (target, _) = targets.first();
678                Some(Arc::clone(target))
679            } else {
680                let aug_mapping = aug_mapping
681                    .get_or_insert_with(|| self.augment_projection_mapping(mapping));
682                Self::project_expr_indirect(aug_mapping, expr)
683            }
684        })
685    }
686
687    /// Projects this equivalence group according to the given projection mapping.
688    pub fn project(&self, mapping: &ProjectionMapping) -> Self {
689        let projected_classes = self.iter().map(|cls| {
690            let new_exprs = self.project_expressions(mapping, cls.iter());
691            EquivalenceClass::new(new_exprs.flatten())
692        });
693
694        // The key is the source expression, and the value is the equivalence
695        // class that contains the corresponding target expression.
696        let mut new_constants = vec![];
697        let mut new_classes = IndexMap::<_, EquivalenceClass>::new();
698        for (source, targets) in mapping.iter() {
699            // We need to find equivalent projected expressions. For example,
700            // consider a table with columns `[a, b, c]` with `a` == `b`, and
701            // projection `[a + c, b + c]`. To conclude that `a + c == b + c`,
702            // we first normalize all source expressions in the mapping, then
703            // merge all equivalent expressions into the classes.
704            let normalized_expr = self.normalize_expr(Arc::clone(source));
705            let cls = new_classes.entry(normalized_expr).or_default();
706            for (target, _) in targets.iter() {
707                cls.push(Arc::clone(target));
708            }
709            // Save new constants arising from the projection:
710            if let Some(across) = self.is_expr_constant(source) {
711                for (target, _) in targets.iter() {
712                    let const_expr = ConstExpr::new(Arc::clone(target), across.clone());
713                    new_constants.push(const_expr);
714                }
715            }
716        }
717
718        // Union projected classes with new classes to make up the result:
719        let classes = projected_classes
720            .chain(new_classes.into_values())
721            .filter(|cls| !cls.is_trivial());
722        let mut result = Self::new(classes);
723        // Add new constants arising from the projection to the equivalence group:
724        for constant in new_constants {
725            result.add_constant(constant);
726        }
727        result
728    }
729
730    /// Returns a `Some` value if the expression is constant according to
731    /// equivalence group, and `None` otherwise. The `Some` variant contains
732    /// an `AcrossPartitions` value indicating whether the expression is
733    /// constant across partitions, and its actual value (if available).
734    pub fn is_expr_constant(
735        &self,
736        expr: &Arc<dyn PhysicalExpr>,
737    ) -> Option<AcrossPartitions> {
738        if let Some(lit) = expr.downcast_ref::<Literal>() {
739            return Some(AcrossPartitions::Uniform(Some(lit.value().clone())));
740        }
741        if let Some(cls) = self.get_equivalence_class(expr)
742            && cls.constant.is_some()
743        {
744            return cls.constant.clone();
745        }
746        // TODO: This function should be able to return values of non-literal
747        //       complex constants as well; e.g. it should return `8` for the
748        //       expression `3 + 5`, not an unknown `heterogenous` value.
749        let children = expr.children();
750        if children.is_empty() {
751            return None;
752        }
753        for child in children {
754            self.is_expr_constant(child)?;
755        }
756        Some(AcrossPartitions::Heterogeneous)
757    }
758
759    /// Returns the equivalence class containing `expr`. If no equivalence class
760    /// contains `expr`, returns `None`.
761    pub fn get_equivalence_class(
762        &self,
763        expr: &Arc<dyn PhysicalExpr>,
764    ) -> Option<&EquivalenceClass> {
765        self.map.get(expr).map(|idx| &self.classes[*idx])
766    }
767
768    /// Combine equivalence groups of the given join children.
769    pub fn join(
770        &self,
771        right_equivalences: &Self,
772        join_type: &JoinType,
773        left_size: usize,
774        on: &[(PhysicalExprRef, PhysicalExprRef)],
775    ) -> Result<Self> {
776        let group = match join_type {
777            JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
778                let mut result = Self::new(
779                    self.iter().cloned().chain(
780                        right_equivalences
781                            .iter()
782                            .map(|cls| cls.try_with_offset(left_size as _))
783                            .collect::<Result<Vec<_>>>()?,
784                    ),
785                );
786                // In we have an inner join, expressions in the "on" condition
787                // are equal in the resulting table.
788                if join_type == &JoinType::Inner {
789                    for (lhs, rhs) in on.iter() {
790                        let new_lhs = Arc::clone(lhs);
791                        // Rewrite rhs to point to the right side of the join:
792                        let new_rhs =
793                            add_offset_to_expr(Arc::clone(rhs), left_size as _)?;
794                        result.add_equal_conditions(new_lhs, new_rhs);
795                    }
796                }
797                result
798            }
799            JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => self.clone(),
800            JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => {
801                right_equivalences.clone()
802            }
803        };
804        Ok(group)
805    }
806
807    /// Checks if two expressions are equal directly or through equivalence
808    /// classes. For complex expressions (e.g. `a + b`), checks that the
809    /// expression trees are structurally identical and their leaf nodes are
810    /// equivalent either directly or through equivalence classes.
811    pub fn exprs_equal(
812        &self,
813        left: &Arc<dyn PhysicalExpr>,
814        right: &Arc<dyn PhysicalExpr>,
815    ) -> bool {
816        // Direct equality check
817        if left.eq(right) {
818            return true;
819        }
820
821        // Check if expressions are equivalent through equivalence classes
822        // We need to check both directions since expressions might be in different classes
823        if let Some(left_class) = self.get_equivalence_class(left)
824            && left_class.contains(right)
825        {
826            return true;
827        }
828        if let Some(right_class) = self.get_equivalence_class(right)
829            && right_class.contains(left)
830        {
831            return true;
832        }
833
834        // For non-leaf nodes, check structural equality
835        let left_children = left.children();
836        let right_children = right.children();
837
838        // If either expression is a leaf node and we haven't found equality yet,
839        // they must be different
840        if left_children.is_empty() || right_children.is_empty() {
841            return false;
842        }
843
844        // Type equality check through reflection
845        if (left as &dyn Any).type_id() != (right as &dyn Any).type_id() {
846            return false;
847        }
848
849        // Check if the number of children is the same
850        if left_children.len() != right_children.len() {
851            return false;
852        }
853
854        // Check if all children are equal
855        left_children
856            .into_iter()
857            .zip(right_children)
858            .all(|(left_child, right_child)| self.exprs_equal(left_child, right_child))
859    }
860}
861
862impl Deref for EquivalenceGroup {
863    type Target = [EquivalenceClass];
864
865    fn deref(&self) -> &Self::Target {
866        &self.classes
867    }
868}
869
870impl IntoIterator for EquivalenceGroup {
871    type Item = EquivalenceClass;
872    type IntoIter = IntoIter<Self::Item>;
873
874    fn into_iter(self) -> Self::IntoIter {
875        self.classes.into_iter()
876    }
877}
878
879impl Display for EquivalenceGroup {
880    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
881        write!(f, "[")?;
882        let mut iter = self.iter();
883        if let Some(cls) = iter.next() {
884            write!(f, "{cls}")?;
885        }
886        for cls in iter {
887            write!(f, ", {cls}")?;
888        }
889        write!(f, "]")
890    }
891}
892
893impl From<Vec<EquivalenceClass>> for EquivalenceGroup {
894    fn from(classes: Vec<EquivalenceClass>) -> Self {
895        let mut result = Self {
896            map: classes
897                .iter()
898                .enumerate()
899                .flat_map(|(idx, cls)| {
900                    cls.iter().map(move |expr| (Arc::clone(expr), idx))
901                })
902                .collect(),
903            classes,
904        };
905        result.remove_redundant_entries();
906        result
907    }
908}
909
910#[cfg(test)]
911mod tests {
912    use super::*;
913    use crate::equivalence::tests::create_test_params;
914    use crate::expressions::{BinaryExpr, Column, binary, col, lit};
915    use arrow::datatypes::{DataType, Field, Schema};
916
917    use datafusion_expr::Operator;
918
919    #[test]
920    fn test_bridge_groups() -> Result<()> {
921        // First entry in the tuple is argument, second entry is the bridged result
922        let test_cases = vec![
923            // ------- TEST CASE 1 -----------//
924            (
925                vec![vec![1, 2, 3], vec![2, 4, 5], vec![11, 12, 9], vec![7, 6, 5]],
926                // Expected is compared with set equality. Order of the specific results may change.
927                vec![vec![1, 2, 3, 4, 5, 6, 7], vec![9, 11, 12]],
928            ),
929            // ------- TEST CASE 2 -----------//
930            (
931                vec![vec![1, 2, 3], vec![3, 4, 5], vec![9, 8, 7], vec![7, 6, 5]],
932                // Expected
933                vec![vec![1, 2, 3, 4, 5, 6, 7, 8, 9]],
934            ),
935        ];
936        for (entries, expected) in test_cases {
937            let entries = entries
938                .into_iter()
939                .map(|entry| {
940                    entry.into_iter().map(|idx| {
941                        let c = Column::new(format!("col_{idx}").as_str(), idx);
942                        Arc::new(c) as _
943                    })
944                })
945                .map(EquivalenceClass::new)
946                .collect::<Vec<_>>();
947            let expected = expected
948                .into_iter()
949                .map(|entry| {
950                    entry.into_iter().map(|idx| {
951                        let c = Column::new(format!("col_{idx}").as_str(), idx);
952                        Arc::new(c) as _
953                    })
954                })
955                .map(EquivalenceClass::new)
956                .collect::<Vec<_>>();
957            let eq_groups: EquivalenceGroup = entries.clone().into();
958            let eq_groups = eq_groups.classes;
959            let err_msg = format!(
960                "error in test entries: {entries:?}, expected: {expected:?}, actual:{eq_groups:?}"
961            );
962            assert_eq!(eq_groups.len(), expected.len(), "{err_msg}");
963            for idx in 0..eq_groups.len() {
964                assert_eq!(&eq_groups[idx], &expected[idx], "{err_msg}");
965            }
966        }
967        Ok(())
968    }
969
970    #[test]
971    fn test_remove_redundant_entries_eq_group() -> Result<()> {
972        let c = |idx| Arc::new(Column::new(format!("col_{idx}").as_str(), idx)) as _;
973        let entries = [
974            EquivalenceClass::new([c(1), c(1), lit(20)]),
975            EquivalenceClass::new([lit(30), lit(30)]),
976            EquivalenceClass::new([c(2), c(3), c(4)]),
977        ];
978        // Given equivalences classes are not in succinct form.
979        // Expected form is the most plain representation that is functionally same.
980        let expected = [
981            EquivalenceClass::new([c(1), lit(20)]),
982            EquivalenceClass::new([lit(30)]),
983            EquivalenceClass::new([c(2), c(3), c(4)]),
984        ];
985        let eq_groups = EquivalenceGroup::new(entries);
986        assert_eq!(eq_groups.classes, expected);
987        Ok(())
988    }
989
990    #[test]
991    fn test_schema_normalize_expr_with_equivalence() -> Result<()> {
992        let col_a = Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>;
993        let col_b = Arc::new(Column::new("b", 1)) as _;
994        let col_c = Arc::new(Column::new("c", 2)) as _;
995        // Assume that column a and c are aliases.
996        let (_, eq_properties) = create_test_params()?;
997        // Test cases for equivalence normalization. First entry in the tuple is
998        // the argument, second entry is expected result after normalization.
999        let expressions = vec![
1000            // Normalized version of the column a and c should go to a
1001            // (by convention all the expressions inside equivalence class are mapped to the first entry
1002            // in this case a is the first entry in the equivalence class.)
1003            (Arc::clone(&col_a), Arc::clone(&col_a)),
1004            (col_c, col_a),
1005            // Cannot normalize column b
1006            (Arc::clone(&col_b), Arc::clone(&col_b)),
1007        ];
1008        let eq_group = eq_properties.eq_group();
1009        for (expr, expected_eq) in expressions {
1010            assert!(expected_eq.eq(&eq_group.normalize_expr(expr)));
1011        }
1012
1013        Ok(())
1014    }
1015
1016    #[test]
1017    fn test_contains_any() {
1018        let lit_true = Arc::new(Literal::new(ScalarValue::from(true))) as _;
1019        let lit_false = Arc::new(Literal::new(ScalarValue::from(false))) as _;
1020        let col_a_expr = Arc::new(Column::new("a", 0)) as _;
1021        let col_b_expr = Arc::new(Column::new("b", 1)) as _;
1022        let col_c_expr = Arc::new(Column::new("c", 2)) as _;
1023
1024        let cls1 = EquivalenceClass::new([Arc::clone(&lit_true), col_a_expr]);
1025        let cls2 = EquivalenceClass::new([lit_true, col_b_expr]);
1026        let cls3 = EquivalenceClass::new([col_c_expr, lit_false]);
1027
1028        // lit_true is common
1029        assert!(cls1.contains_any(&cls2));
1030        // there is no common entry
1031        assert!(!cls1.contains_any(&cls3));
1032        assert!(!cls2.contains_any(&cls3));
1033    }
1034
1035    #[test]
1036    fn test_exprs_equal() -> Result<()> {
1037        struct TestCase {
1038            left: Arc<dyn PhysicalExpr>,
1039            right: Arc<dyn PhysicalExpr>,
1040            expected: bool,
1041            description: &'static str,
1042        }
1043
1044        // Create test columns
1045        let col_a = Arc::new(Column::new("a", 0)) as _;
1046        let col_b = Arc::new(Column::new("b", 1)) as _;
1047        let col_x = Arc::new(Column::new("x", 2)) as _;
1048        let col_y = Arc::new(Column::new("y", 3)) as _;
1049
1050        // Create test literals
1051        let lit_1 = Arc::new(Literal::new(ScalarValue::from(1))) as _;
1052        let lit_2 = Arc::new(Literal::new(ScalarValue::from(2))) as _;
1053
1054        // Create equivalence group with classes (a = x) and (b = y)
1055        let eq_group = EquivalenceGroup::new([
1056            EquivalenceClass::new([Arc::clone(&col_a), Arc::clone(&col_x)]),
1057            EquivalenceClass::new([Arc::clone(&col_b), Arc::clone(&col_y)]),
1058        ]);
1059
1060        let test_cases = vec![
1061            // Basic equality tests
1062            TestCase {
1063                left: Arc::clone(&col_a),
1064                right: Arc::clone(&col_a),
1065                expected: true,
1066                description: "Same column should be equal",
1067            },
1068            // Equivalence class tests
1069            TestCase {
1070                left: Arc::clone(&col_a),
1071                right: Arc::clone(&col_x),
1072                expected: true,
1073                description: "Columns in same equivalence class should be equal",
1074            },
1075            TestCase {
1076                left: Arc::clone(&col_b),
1077                right: Arc::clone(&col_y),
1078                expected: true,
1079                description: "Columns in same equivalence class should be equal",
1080            },
1081            TestCase {
1082                left: Arc::clone(&col_a),
1083                right: Arc::clone(&col_b),
1084                expected: false,
1085                description: "Columns in different equivalence classes should not be equal",
1086            },
1087            // Literal tests
1088            TestCase {
1089                left: Arc::clone(&lit_1),
1090                right: Arc::clone(&lit_1),
1091                expected: true,
1092                description: "Same literal should be equal",
1093            },
1094            TestCase {
1095                left: Arc::clone(&lit_1),
1096                right: Arc::clone(&lit_2),
1097                expected: false,
1098                description: "Different literals should not be equal",
1099            },
1100            // Complex expression tests
1101            TestCase {
1102                left: Arc::new(BinaryExpr::new(
1103                    Arc::clone(&col_a),
1104                    Operator::Plus,
1105                    Arc::clone(&col_b),
1106                )) as _,
1107                right: Arc::new(BinaryExpr::new(
1108                    Arc::clone(&col_x),
1109                    Operator::Plus,
1110                    Arc::clone(&col_y),
1111                )) as _,
1112                expected: true,
1113                description: "Binary expressions with equivalent operands should be equal",
1114            },
1115            TestCase {
1116                left: Arc::new(BinaryExpr::new(
1117                    Arc::clone(&col_a),
1118                    Operator::Plus,
1119                    Arc::clone(&col_b),
1120                )) as _,
1121                right: Arc::new(BinaryExpr::new(
1122                    Arc::clone(&col_x),
1123                    Operator::Plus,
1124                    Arc::clone(&col_a),
1125                )) as _,
1126                expected: false,
1127                description: "Binary expressions with non-equivalent operands should not be equal",
1128            },
1129            TestCase {
1130                left: Arc::new(BinaryExpr::new(
1131                    Arc::clone(&col_a),
1132                    Operator::Plus,
1133                    Arc::clone(&lit_1),
1134                )) as _,
1135                right: Arc::new(BinaryExpr::new(
1136                    Arc::clone(&col_x),
1137                    Operator::Plus,
1138                    Arc::clone(&lit_1),
1139                )) as _,
1140                expected: true,
1141                description: "Binary expressions with equivalent column and same literal should be equal",
1142            },
1143            TestCase {
1144                left: Arc::new(BinaryExpr::new(
1145                    Arc::new(BinaryExpr::new(
1146                        Arc::clone(&col_a),
1147                        Operator::Plus,
1148                        Arc::clone(&col_b),
1149                    )),
1150                    Operator::Multiply,
1151                    Arc::clone(&lit_1),
1152                )) as _,
1153                right: Arc::new(BinaryExpr::new(
1154                    Arc::new(BinaryExpr::new(
1155                        Arc::clone(&col_x),
1156                        Operator::Plus,
1157                        Arc::clone(&col_y),
1158                    )),
1159                    Operator::Multiply,
1160                    Arc::clone(&lit_1),
1161                )) as _,
1162                expected: true,
1163                description: "Nested binary expressions with equivalent operands should be equal",
1164            },
1165        ];
1166
1167        for TestCase {
1168            left,
1169            right,
1170            expected,
1171            description,
1172        } in test_cases
1173        {
1174            let actual = eq_group.exprs_equal(&left, &right);
1175            assert_eq!(
1176                actual, expected,
1177                "{description}: Failed comparing {left:?} and {right:?}, expected {expected}, got {actual}"
1178            );
1179        }
1180
1181        Ok(())
1182    }
1183
1184    #[test]
1185    fn test_project_classes() -> Result<()> {
1186        // - columns: [a, b, c].
1187        // - "a" and "b" in the same equivalence class.
1188        // - then after a+c, b+c projection col(0) and col(1) must be
1189        // in the same class too.
1190        let schema = Arc::new(Schema::new(vec![
1191            Field::new("a", DataType::Int32, false),
1192            Field::new("b", DataType::Int32, false),
1193            Field::new("c", DataType::Int32, false),
1194        ]));
1195        let mut group = EquivalenceGroup::default();
1196        group.add_equal_conditions(col("a", &schema)?, col("b", &schema)?);
1197
1198        let projected_schema = Arc::new(Schema::new(vec![
1199            Field::new("a+c", DataType::Int32, false),
1200            Field::new("b+c", DataType::Int32, false),
1201        ]));
1202
1203        let mapping = [
1204            (
1205                binary(
1206                    col("a", &schema)?,
1207                    Operator::Plus,
1208                    col("c", &schema)?,
1209                    &schema,
1210                )?,
1211                vec![(col("a+c", &projected_schema)?, 0)].into(),
1212            ),
1213            (
1214                binary(
1215                    col("b", &schema)?,
1216                    Operator::Plus,
1217                    col("c", &schema)?,
1218                    &schema,
1219                )?,
1220                vec![(col("b+c", &projected_schema)?, 1)].into(),
1221            ),
1222        ]
1223        .into_iter()
1224        .collect::<ProjectionMapping>();
1225
1226        let projected = group.project(&mapping);
1227
1228        assert!(!projected.is_empty());
1229        let first_normalized = projected.normalize_expr(col("a+c", &projected_schema)?);
1230        let second_normalized = projected.normalize_expr(col("b+c", &projected_schema)?);
1231
1232        assert!(first_normalized.eq(&second_normalized));
1233
1234        Ok(())
1235    }
1236}