use super::{add_offset_to_expr, ProjectionMapping};
use crate::{
expressions::Column, LexOrdering, LexRequirement, PhysicalExpr, PhysicalExprRef,
PhysicalSortExpr, PhysicalSortRequirement,
};
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::{JoinType, ScalarValue};
use datafusion_physical_expr_common::physical_expr::format_physical_expr_list;
use std::fmt::Display;
use std::sync::Arc;
use std::vec::IntoIter;
use indexmap::{IndexMap, IndexSet};
/// Pairs a physical expression with an [`AcrossPartitions`] marker.
///
/// NOTE(review): judging by the type name, the wrapped expression is one that
/// is treated as constant; confirm against the callers of this type.
#[derive(Debug, Clone)]
pub struct ConstExpr {
/// The wrapped physical expression.
expr: Arc<dyn PhysicalExpr>,
/// Marker rendered by `Display` as `(heterogeneous)`, `(uniform: <value>)`
/// or `(uniform: unknown)`.
across_partitions: AcrossPartitions,
}
/// Marker attached to a [`ConstExpr`] describing its value across partitions.
///
/// The default is [`AcrossPartitions::Heterogeneous`].
#[derive(PartialEq, Clone, Debug, Default)]
pub enum AcrossPartitions {
    /// Rendered as `(heterogeneous)`.
    ///
    /// NOTE(review): presumably means the value may differ from one partition
    /// to another; confirm against callers.
    #[default]
    Heterogeneous,
    /// Rendered as `(uniform: <value>)` when the value is `Some`, and as
    /// `(uniform: unknown)` when it is `None`.
    Uniform(Option<ScalarValue>),
}
/// Two [`ConstExpr`]s are equal when both their partition markers and their
/// wrapped expressions compare equal.
impl PartialEq for ConstExpr {
    fn eq(&self, other: &Self) -> bool {
        // The expressions are only compared once the markers already agree.
        if self.across_partitions != other.across_partitions {
            return false;
        }
        self.expr.eq(&other.expr)
    }
}
impl ConstExpr {
pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
Self {
expr,
across_partitions: Default::default(),
}
}
pub fn with_across_partitions(mut self, across_partitions: AcrossPartitions) -> Self {
self.across_partitions = across_partitions;
self
}
pub fn across_partitions(&self) -> AcrossPartitions {
self.across_partitions.clone()
}
pub fn expr(&self) -> &Arc<dyn PhysicalExpr> {
&self.expr
}
pub fn owned_expr(self) -> Arc<dyn PhysicalExpr> {
self.expr
}
pub fn map<F>(&self, f: F) -> Option<Self>
where
F: Fn(&Arc<dyn PhysicalExpr>) -> Option<Arc<dyn PhysicalExpr>>,
{
let maybe_expr = f(&self.expr);
maybe_expr.map(|expr| Self {
expr,
across_partitions: self.across_partitions.clone(),
})
}
pub fn eq_expr(&self, other: impl AsRef<dyn PhysicalExpr>) -> bool {
self.expr.as_ref() == other.as_ref()
}
pub fn format_list(input: &[ConstExpr]) -> impl Display + '_ {
struct DisplayableList<'a>(&'a [ConstExpr]);
impl Display for DisplayableList<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
let mut first = true;
for const_expr in self.0 {
if first {
first = false;
} else {
write!(f, ",")?;
}
write!(f, "{}", const_expr)?;
}
Ok(())
}
}
DisplayableList(input)
}
}
/// Prints the wrapped expression immediately followed by a parenthesised
/// description of the partition marker.
impl Display for ConstExpr {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.expr)?;
        match &self.across_partitions {
            AcrossPartitions::Heterogeneous => write!(f, "(heterogeneous)"),
            AcrossPartitions::Uniform(Some(val)) => write!(f, "(uniform: {})", val),
            AcrossPartitions::Uniform(None) => write!(f, "(uniform: unknown)"),
        }
    }
}
impl From<Arc<dyn PhysicalExpr>> for ConstExpr {
fn from(expr: Arc<dyn PhysicalExpr>) -> Self {
Self::new(expr)
}
}
impl From<&Arc<dyn PhysicalExpr>> for ConstExpr {
fn from(expr: &Arc<dyn PhysicalExpr>) -> Self {
Self::new(Arc::clone(expr))
}
}
/// Returns `true` if any entry of `const_exprs` wraps an expression equal to
/// `expr`.
///
/// Only the wrapped expressions are compared; partition markers are ignored.
/// An empty slice yields `false`.
pub fn const_exprs_contains(
    const_exprs: &[ConstExpr],
    expr: &Arc<dyn PhysicalExpr>,
) -> bool {
    for candidate in const_exprs {
        if candidate.expr.eq(expr) {
            return true;
        }
    }
    false
}
/// A set of physical expressions that are treated as equal to one another.
///
/// Members are stored in an [`IndexSet`]; the first member in iteration order
/// is the one `canonical_expr` returns.
#[derive(Debug, Clone)]
pub struct EquivalenceClass {
/// The member expressions of this class.
exprs: IndexSet<Arc<dyn PhysicalExpr>>,
}
/// Equality of two classes is the equality of their underlying
/// [`IndexSet`]s of members.
impl PartialEq for EquivalenceClass {
    fn eq(&self, other: &Self) -> bool {
        let (Self { exprs: ours }, Self { exprs: theirs }) = (self, other);
        ours == theirs
    }
}
impl EquivalenceClass {
    /// Creates a class with no members.
    pub fn new_empty() -> Self {
        let exprs = IndexSet::new();
        Self { exprs }
    }

    /// Builds a class from `exprs`, inserting them in the given order.
    ///
    /// An expression that is already a member is not added a second time.
    pub fn new(exprs: Vec<Arc<dyn PhysicalExpr>>) -> Self {
        let mut class = Self::new_empty();
        for expr in exprs {
            class.push(expr);
        }
        class
    }

    /// Consumes the class and returns its members as a vector, in the set's
    /// iteration order.
    pub fn into_vec(self) -> Vec<Arc<dyn PhysicalExpr>> {
        let Self { exprs } = self;
        exprs.into_iter().collect()
    }

    /// Returns the first member of the class, or `None` if the class is empty.
    fn canonical_expr(&self) -> Option<Arc<dyn PhysicalExpr>> {
        self.exprs.get_index(0).map(Arc::clone)
    }

    /// Inserts `expr` into the class; a no-op if it is already a member.
    pub fn push(&mut self, expr: Arc<dyn PhysicalExpr>) {
        let _newly_added = self.exprs.insert(expr);
    }

    /// Inserts every member of `other` into this class.
    pub fn extend(&mut self, other: Self) {
        let Self { exprs: incoming } = other;
        self.exprs.extend(incoming);
    }

    /// Returns `true` if `expr` is a member of this class.
    pub fn contains(&self, expr: &Arc<dyn PhysicalExpr>) -> bool {
        self.exprs.contains(expr)
    }

    /// Returns `true` if this class and `other` share at least one member.
    pub fn contains_any(&self, other: &Self) -> bool {
        !self.exprs.is_disjoint(&other.exprs)
    }

    /// Returns the number of members.
    pub fn len(&self) -> usize {
        self.exprs.len()
    }

    /// Returns `true` if the class has no members.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Iterates over the members in the set's iteration order.
    pub fn iter(&self) -> impl Iterator<Item = &Arc<dyn PhysicalExpr>> {
        self.exprs.iter()
    }

    /// Returns a new class in which every member has been passed through
    /// `add_offset_to_expr` with the given `offset`.
    pub fn with_offset(&self, offset: usize) -> Self {
        Self::new(
            self.exprs
                .iter()
                .map(|expr| add_offset_to_expr(Arc::clone(expr), offset))
                .collect(),
        )
    }
}
/// Prints the members, as rendered by `format_physical_expr_list`, wrapped in
/// square brackets.
impl Display for EquivalenceClass {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let members = format_physical_expr_list(&self.exprs);
        write!(f, "[{}]", members)
    }
}
/// A collection of [`EquivalenceClass`]es.
///
/// [`EquivalenceGroup::new`] and [`EquivalenceGroup::extend`] drop classes with
/// fewer than two members and merge classes that share a member.
#[derive(Debug, Clone)]
pub struct EquivalenceGroup {
/// The equivalence classes held by this group.
classes: Vec<EquivalenceClass>,
}
impl EquivalenceGroup {
    /// Creates a group that holds no equivalence classes.
    pub fn empty() -> Self {
        Self { classes: vec![] }
    }

    /// Creates a group from `classes`.
    ///
    /// Classes with fewer than two members are dropped, and classes that share
    /// a member are merged, before the group is returned.
    pub fn new(classes: Vec<EquivalenceClass>) -> Self {
        let mut result = Self { classes };
        result.remove_redundant_entries();
        result
    }

    /// Returns the number of equivalence classes in this group.
    pub fn len(&self) -> usize {
        self.classes.len()
    }

    /// Returns `true` if this group holds no equivalence classes.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Iterates over the equivalence classes of this group.
    pub fn iter(&self) -> impl Iterator<Item = &EquivalenceClass> {
        self.classes.iter()
    }

    /// Records that `left` and `right` are equal.
    ///
    /// - If both already belong to (different) classes, those classes are
    ///   merged into one.
    /// - If exactly one belongs to a class, the other is added to that class.
    /// - If neither belongs to a class, a new class holding both is created.
    ///
    /// When `left` and `right` are the same expression nothing is recorded:
    /// the equality is trivial, and storing it would create a single-member
    /// class, which this type otherwise treats as redundant.
    pub fn add_equal_conditions(
        &mut self,
        left: &Arc<dyn PhysicalExpr>,
        right: &Arc<dyn PhysicalExpr>,
    ) {
        if left.eq(right) {
            return;
        }
        let mut first_class = None;
        let mut second_class = None;
        for (idx, cls) in self.classes.iter().enumerate() {
            if cls.contains(left) {
                first_class = Some(idx);
            }
            if cls.contains(right) {
                second_class = Some(idx);
            }
        }
        match (first_class, second_class) {
            (Some(mut first_idx), Some(mut second_idx)) => {
                // Both sides are known; merge only if they are in different
                // classes.
                if first_idx != second_idx {
                    // Make `second_idx` the larger index so that removing it
                    // (which moves the last class into its slot) cannot
                    // disturb the class at `first_idx`.
                    if first_idx > second_idx {
                        (first_idx, second_idx) = (second_idx, first_idx);
                    }
                    let other_class = self.classes.swap_remove(second_idx);
                    self.classes[first_idx].extend(other_class);
                }
            }
            (Some(group_idx), None) => {
                // Only `left` is known: `right` joins its class.
                self.classes[group_idx].push(Arc::clone(right));
            }
            (None, Some(group_idx)) => {
                // Only `right` is known: `left` joins its class.
                self.classes[group_idx].push(Arc::clone(left));
            }
            (None, None) => {
                // Neither side is known: start a new class with both.
                self.classes.push(EquivalenceClass::new(vec![
                    Arc::clone(left),
                    Arc::clone(right),
                ]));
            }
        }
    }

    /// Drops every class with fewer than two members (such a class states no
    /// equality) and then merges classes that share a member.
    fn remove_redundant_entries(&mut self) {
        self.classes.retain(|cls| cls.len() > 1);
        self.bridge_classes()
    }

    /// Merges classes that have at least one member in common, repeating until
    /// no two remaining classes overlap.
    fn bridge_classes(&mut self) {
        let mut idx = 0;
        while idx < self.classes.len() {
            let mut next_idx = idx + 1;
            let start_size = self.classes[idx].len();
            while next_idx < self.classes.len() {
                if self.classes[idx].contains_any(&self.classes[next_idx]) {
                    // `swap_remove` moves the last class into `next_idx`, so
                    // the index is deliberately not advanced: the moved class
                    // is examined on the next iteration.
                    let extension = self.classes.swap_remove(next_idx);
                    self.classes[idx].extend(extension);
                } else {
                    next_idx += 1;
                }
            }
            // If the class grew, classes skipped earlier in this pass may now
            // overlap with it, so scan again from the same `idx`.
            if self.classes[idx].len() > start_size {
                continue;
            }
            idx += 1;
        }
    }

    /// Adds all classes of `other` to this group, then drops classes with
    /// fewer than two members and merges overlapping classes.
    pub fn extend(&mut self, other: Self) {
        self.classes.extend(other.classes);
        self.remove_redundant_entries();
    }

    /// Rewrites `expr` with `TreeNode::transform`, replacing every visited
    /// node that is a member of some class with that class's first member.
    /// Nodes that belong to no class are left unchanged.
    pub fn normalize_expr(&self, expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
        expr.transform(|expr| {
            for cls in self.iter() {
                if cls.contains(&expr) {
                    // A class that contains `expr` is non-empty, so it always
                    // has a first member.
                    return Ok(Transformed::yes(cls.canonical_expr().unwrap()));
                }
            }
            Ok(Transformed::no(expr))
        })
        .data()
        // The closure above never returns `Err`.
        .unwrap()
    }

    /// Returns `sort_expr` with its expression replaced by the result of
    /// [`Self::normalize_expr`]; all other fields are kept.
    pub fn normalize_sort_expr(
        &self,
        mut sort_expr: PhysicalSortExpr,
    ) -> PhysicalSortExpr {
        sort_expr.expr = self.normalize_expr(sort_expr.expr);
        sort_expr
    }

    /// Returns `sort_requirement` with its expression replaced by the result
    /// of [`Self::normalize_expr`]; all other fields are kept.
    pub fn normalize_sort_requirement(
        &self,
        mut sort_requirement: PhysicalSortRequirement,
    ) -> PhysicalSortRequirement {
        sort_requirement.expr = self.normalize_expr(sort_requirement.expr);
        sort_requirement
    }

    /// Applies [`Self::normalize_expr`] to every expression in `exprs` and
    /// collects the results in the same order.
    pub fn normalize_exprs(
        &self,
        exprs: impl IntoIterator<Item = Arc<dyn PhysicalExpr>>,
    ) -> Vec<Arc<dyn PhysicalExpr>> {
        exprs
            .into_iter()
            .map(|expr| self.normalize_expr(expr))
            .collect()
    }

    /// Normalizes an ordering by converting it to a [`LexRequirement`],
    /// running [`Self::normalize_sort_requirements`] on it, and converting the
    /// result back to a [`LexOrdering`].
    pub fn normalize_sort_exprs(&self, sort_exprs: &LexOrdering) -> LexOrdering {
        let sort_reqs = LexRequirement::from(sort_exprs.clone());
        let normalized_sort_reqs = self.normalize_sort_requirements(&sort_reqs);
        LexOrdering::from(normalized_sort_reqs)
    }

    /// Normalizes every requirement in `sort_reqs` with
    /// [`Self::normalize_sort_requirement`] and then calls `collapse()` on the
    /// resulting [`LexRequirement`].
    ///
    /// NOTE(review): `collapse()` presumably removes requirements that became
    /// duplicates after normalization; confirm against `LexRequirement`.
    pub fn normalize_sort_requirements(
        &self,
        sort_reqs: &LexRequirement,
    ) -> LexRequirement {
        LexRequirement::new(
            sort_reqs
                .iter()
                .map(|sort_req| self.normalize_sort_requirement(sort_req.clone()))
                .collect(),
        )
        .collapse()
    }

    /// Projects `expr` through `mapping`, returning `None` if that is not
    /// possible.
    ///
    /// The lookup order is:
    /// 1. the target `mapping` reports for `expr` itself;
    /// 2. the target of the first mapping source whose equivalence class
    ///    contains `expr`;
    /// 3. otherwise, `expr` rebuilt from its projected children, which
    ///    succeeds only if every child can be projected. An expression with no
    ///    children that reaches this step cannot be projected.
    ///
    /// # Panics
    ///
    /// Panics if `with_new_children` fails while rebuilding the expression.
    pub fn project_expr(
        &self,
        mapping: &ProjectionMapping,
        expr: &Arc<dyn PhysicalExpr>,
    ) -> Option<Arc<dyn PhysicalExpr>> {
        if let Some(target) = mapping.target_expr(expr) {
            return Some(target);
        } else {
            for (source, target) in mapping.iter() {
                if self
                    .get_equivalence_class(source)
                    .is_some_and(|group| group.contains(expr))
                {
                    return Some(Arc::clone(target));
                }
            }
        }
        let children = expr.children();
        if children.is_empty() {
            return None;
        }
        children
            .into_iter()
            .map(|child| self.project_expr(mapping, child))
            .collect::<Option<Vec<_>>>()
            .map(|children| Arc::clone(expr).with_new_children(children).unwrap())
    }

    /// Projects this group through `mapping`.
    ///
    /// The result is built from two sources:
    /// - each existing class, keeping the members that
    ///   [`Self::project_expr`] can project (a class is kept only if more than
    ///   one projected member remains);
    /// - the mapping targets grouped by their normalized source expression, so
    ///   that targets whose sources normalize to the same expression land in
    ///   one class (again kept only with more than one member).
    ///
    /// The combined classes are passed to [`Self::new`].
    pub fn project(&self, mapping: &ProjectionMapping) -> Self {
        let projected_classes = self.iter().filter_map(|cls| {
            let new_class = cls
                .iter()
                .filter_map(|expr| self.project_expr(mapping, expr))
                .collect::<Vec<_>>();
            (new_class.len() > 1).then_some(EquivalenceClass::new(new_class))
        });
        // Key: normalized source expression. Value: class of the targets whose
        // sources normalize to that key.
        let mut new_classes: IndexMap<_, _> = IndexMap::new();
        for (source, target) in mapping.iter() {
            let normalized_expr = self.normalize_expr(Arc::clone(source));
            new_classes
                .entry(normalized_expr)
                .or_insert_with(EquivalenceClass::new_empty)
                .push(Arc::clone(target));
        }
        let new_classes = new_classes
            .into_iter()
            .filter_map(|(_, cls)| (cls.len() > 1).then_some(cls));
        let classes = projected_classes.chain(new_classes).collect();
        Self::new(classes)
    }

    /// Returns the first class that contains `expr`, if any.
    fn get_equivalence_class(
        &self,
        expr: &Arc<dyn PhysicalExpr>,
    ) -> Option<&EquivalenceClass> {
        self.iter().find(|cls| cls.contains(expr))
    }

    /// Combines this (left) group with `right_equivalences` for a join.
    ///
    /// - `Inner`, `Left`, `Full`, `Right`: the result holds the left classes
    ///   plus the right classes shifted by `left_size` via
    ///   [`EquivalenceClass::with_offset`]. For `Inner` only, each `(lhs, rhs)`
    ///   pair in `on` is additionally recorded as equal, after adding
    ///   `left_size` to the index of every `Column` inside `rhs`.
    /// - `LeftSemi`, `LeftAnti`, `LeftMark`: a clone of the left group.
    /// - `RightSemi`, `RightAnti`: a clone of `right_equivalences`.
    pub fn join(
        &self,
        right_equivalences: &Self,
        join_type: &JoinType,
        left_size: usize,
        on: &[(PhysicalExprRef, PhysicalExprRef)],
    ) -> Self {
        match join_type {
            JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
                let mut result = Self::new(
                    self.iter()
                        .cloned()
                        .chain(
                            right_equivalences
                                .iter()
                                .map(|cls| cls.with_offset(left_size)),
                        )
                        .collect(),
                );
                if join_type == &JoinType::Inner {
                    for (lhs, rhs) in on.iter() {
                        let new_lhs = Arc::clone(lhs);
                        // Shift the column indices of `rhs` by `left_size`
                        // before recording the equality.
                        let new_rhs = Arc::clone(rhs)
                            .transform(|expr| {
                                if let Some(column) =
                                    expr.as_any().downcast_ref::<Column>()
                                {
                                    let new_column = Arc::new(Column::new(
                                        column.name(),
                                        column.index() + left_size,
                                    ))
                                        as _;
                                    return Ok(Transformed::yes(new_column));
                                }
                                Ok(Transformed::no(expr))
                            })
                            .data()
                            // The closure above never returns `Err`.
                            .unwrap();
                        result.add_equal_conditions(&new_lhs, &new_rhs);
                    }
                }
                result
            }
            JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => self.clone(),
            JoinType::RightSemi | JoinType::RightAnti => right_equivalences.clone(),
        }
    }

    /// Returns `true` if `left` and `right` are equal under this group.
    ///
    /// Two expressions are considered equal when any of the following holds:
    /// - they compare equal directly;
    /// - one of them is a member of a class that contains the other;
    /// - both have children, have the same concrete type and the same number
    ///   of children, their children are pairwise equal under this group, and
    ///   `left` rebuilt on top of `right`'s children compares equal to
    ///   `right`.
    ///
    /// The last condition makes sure the two nodes also agree on everything
    /// that is not a child (for example the operator of a binary expression);
    /// matching type and equivalent children alone are not sufficient.
    pub fn exprs_equal(
        &self,
        left: &Arc<dyn PhysicalExpr>,
        right: &Arc<dyn PhysicalExpr>,
    ) -> bool {
        if left.eq(right) {
            return true;
        }
        // Membership is checked from both sides.
        if let Some(left_class) = self.get_equivalence_class(left) {
            if left_class.contains(right) {
                return true;
            }
        }
        if let Some(right_class) = self.get_equivalence_class(right) {
            if right_class.contains(left) {
                return true;
            }
        }
        let left_children = left.children();
        let right_children = right.children();
        // A childless expression that was not matched above cannot be equal.
        if left_children.is_empty() || right_children.is_empty() {
            return false;
        }
        if left.as_any().type_id() != right.as_any().type_id() {
            return false;
        }
        if left_children.len() != right_children.len() {
            return false;
        }
        let children_equivalent = left_children
            .into_iter()
            .zip(right_children.iter().copied())
            .all(|(left_child, right_child)| self.exprs_equal(left_child, right_child));
        if !children_equivalent {
            return false;
        }
        // With the children known to be equivalent, substitute `right`'s
        // children into `left`. The two nodes are equal only if the result is
        // identical to `right`, i.e. their non-child state matches too.
        let substituted_children: Vec<Arc<dyn PhysicalExpr>> =
            right_children.into_iter().map(Arc::clone).collect();
        Arc::clone(left)
            .with_new_children(substituted_children)
            .is_ok_and(|rebuilt| rebuilt.eq(right))
    }

    /// Consumes the group and returns its equivalence classes.
    pub fn into_inner(self) -> Vec<EquivalenceClass> {
        self.classes
    }
}
/// Consuming iteration over the equivalence classes of a group.
impl IntoIterator for EquivalenceGroup {
    type Item = EquivalenceClass;
    type IntoIter = IntoIter<EquivalenceClass>;

    fn into_iter(self) -> Self::IntoIter {
        let Self { classes } = self;
        classes.into_iter()
    }
}
/// Prints the classes separated by `", "` and wrapped in square brackets.
impl Display for EquivalenceGroup {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "[")?;
        for (position, cls) in self.iter().enumerate() {
            // Every class except the first is preceded by a separator.
            if position > 0 {
                write!(f, ", ")?;
            }
            write!(f, "{}", cls)?;
        }
        write!(f, "]")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::equivalence::tests::create_test_params;
    use crate::expressions::{binary, col, lit, BinaryExpr, Literal};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::Operator;

    /// Turns each inner list of integers into an [`EquivalenceClass`] of
    /// literal expressions.
    fn literal_classes(groups: Vec<Vec<i32>>) -> Vec<EquivalenceClass> {
        groups
            .into_iter()
            .map(|members| {
                EquivalenceClass::new(members.into_iter().map(lit).collect())
            })
            .collect()
    }

    /// Classes that share members must end up merged into a single class.
    #[test]
    fn test_bridge_groups() -> Result<()> {
        // Each case is (input classes, classes expected after bridging).
        let test_cases = vec![
            (
                vec![vec![1, 2, 3], vec![2, 4, 5], vec![11, 12, 9], vec![7, 6, 5]],
                vec![vec![1, 2, 3, 4, 5, 6, 7], vec![9, 11, 12]],
            ),
            (
                vec![vec![1, 2, 3], vec![3, 4, 5], vec![9, 8, 7], vec![7, 6, 5]],
                vec![vec![1, 2, 3, 4, 5, 6, 7, 8, 9]],
            ),
        ];
        for (entries, expected) in test_cases {
            let entries = literal_classes(entries);
            let expected = literal_classes(expected);
            let mut eq_groups = EquivalenceGroup::new(entries.clone());
            eq_groups.bridge_classes();
            let eq_groups = eq_groups.classes;
            let err_msg = format!(
                "error in test entries: {:?}, expected: {:?}, actual:{:?}",
                entries, expected, eq_groups
            );
            assert_eq!(eq_groups.len(), expected.len(), "{}", err_msg);
            for (actual_class, expected_class) in eq_groups.iter().zip(expected.iter()) {
                assert_eq!(actual_class, expected_class, "{}", err_msg);
            }
        }
        Ok(())
    }

    /// Duplicate members collapse, and classes left with a single member are
    /// removed from the group.
    #[test]
    fn test_remove_redundant_entries_eq_group() -> Result<()> {
        let entries = vec![
            EquivalenceClass::new(vec![lit(1), lit(1), lit(2)]),
            // Collapses to a single member, so the whole class is dropped.
            EquivalenceClass::new(vec![lit(3), lit(3)]),
            EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]),
        ];
        let expected = vec![
            EquivalenceClass::new(vec![lit(1), lit(2)]),
            EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]),
        ];
        let mut eq_groups = EquivalenceGroup::new(entries);
        eq_groups.remove_redundant_entries();
        let actual = eq_groups.classes;
        assert_eq!(actual.len(), expected.len());
        assert_eq!(actual.len(), 2);
        assert_eq!(actual[0], expected[0]);
        assert_eq!(actual[1], expected[1]);
        Ok(())
    }

    /// Normalizing a column yields the expected expression for the group
    /// provided by `create_test_params`.
    #[test]
    fn test_schema_normalize_expr_with_equivalence() -> Result<()> {
        let col_a_expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("a", 0));
        let col_b_expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("b", 1));
        let col_c_expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c", 2));
        let (_test_schema, eq_properties) = create_test_params()?;
        let eq_group = eq_properties.eq_group();
        // Each case is (expression to normalize, expected normalized form).
        let expressions = vec![
            (&col_a_expr, &col_a_expr),
            (&col_c_expr, &col_a_expr),
            (&col_b_expr, &col_b_expr),
        ];
        for (expr, expected_eq) in expressions {
            let normalized = eq_group.normalize_expr(Arc::clone(expr));
            assert!(
                expected_eq.eq(&normalized),
                "error in test: expr: {expr:?}"
            );
        }
        Ok(())
    }

    /// `contains_any` is true only when two classes share a member.
    #[test]
    fn test_contains_any() {
        let lit_true: Arc<dyn PhysicalExpr> =
            Arc::new(Literal::new(ScalarValue::Boolean(Some(true))));
        let lit_false: Arc<dyn PhysicalExpr> =
            Arc::new(Literal::new(ScalarValue::Boolean(Some(false))));
        let lit2: Arc<dyn PhysicalExpr> =
            Arc::new(Literal::new(ScalarValue::Int32(Some(2))));
        let lit1: Arc<dyn PhysicalExpr> =
            Arc::new(Literal::new(ScalarValue::Int32(Some(1))));
        let col_b_expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("b", 1));

        let cls1 =
            EquivalenceClass::new(vec![Arc::clone(&lit_true), Arc::clone(&lit_false)]);
        let cls2 =
            EquivalenceClass::new(vec![Arc::clone(&lit_true), Arc::clone(&col_b_expr)]);
        let cls3 = EquivalenceClass::new(vec![Arc::clone(&lit2), Arc::clone(&lit1)]);

        // `lit_true` is shared by the first two classes.
        assert!(cls1.contains_any(&cls2));
        // The third class has nothing in common with the others.
        assert!(!cls1.contains_any(&cls3));
        assert!(!cls2.contains_any(&cls3));
    }

    /// Exercises `exprs_equal` on leaves, class members, and (nested) binary
    /// expressions.
    #[test]
    fn test_exprs_equal() -> Result<()> {
        fn column(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
            Arc::new(Column::new(name, index))
        }
        fn int_literal(value: i32) -> Arc<dyn PhysicalExpr> {
            Arc::new(Literal::new(ScalarValue::Int32(Some(value))))
        }
        fn binary_of(
            lhs: &Arc<dyn PhysicalExpr>,
            op: Operator,
            rhs: &Arc<dyn PhysicalExpr>,
        ) -> Arc<dyn PhysicalExpr> {
            Arc::new(BinaryExpr::new(Arc::clone(lhs), op, Arc::clone(rhs)))
        }

        let col_a = column("a", 0);
        let col_b = column("b", 1);
        let col_x = column("x", 2);
        let col_y = column("y", 3);
        let lit_1 = int_literal(1);
        let lit_2 = int_literal(2);

        // a == x and b == y.
        let eq_group = EquivalenceGroup::new(vec![
            EquivalenceClass::new(vec![Arc::clone(&col_a), Arc::clone(&col_x)]),
            EquivalenceClass::new(vec![Arc::clone(&col_b), Arc::clone(&col_y)]),
        ]);

        // Each case is (left, right, expected result, description).
        let test_cases: Vec<(
            Arc<dyn PhysicalExpr>,
            Arc<dyn PhysicalExpr>,
            bool,
            &'static str,
        )> = vec![
            (
                Arc::clone(&col_a),
                Arc::clone(&col_a),
                true,
                "Same column should be equal",
            ),
            (
                Arc::clone(&col_a),
                Arc::clone(&col_x),
                true,
                "Columns in same equivalence class should be equal",
            ),
            (
                Arc::clone(&col_b),
                Arc::clone(&col_y),
                true,
                "Columns in same equivalence class should be equal",
            ),
            (
                Arc::clone(&col_a),
                Arc::clone(&col_b),
                false,
                "Columns in different equivalence classes should not be equal",
            ),
            (
                Arc::clone(&lit_1),
                Arc::clone(&lit_1),
                true,
                "Same literal should be equal",
            ),
            (
                Arc::clone(&lit_1),
                Arc::clone(&lit_2),
                false,
                "Different literals should not be equal",
            ),
            (
                binary_of(&col_a, Operator::Plus, &col_b),
                binary_of(&col_x, Operator::Plus, &col_y),
                true,
                "Binary expressions with equivalent operands should be equal",
            ),
            (
                binary_of(&col_a, Operator::Plus, &col_b),
                binary_of(&col_x, Operator::Plus, &col_a),
                false,
                "Binary expressions with non-equivalent operands should not be equal",
            ),
            (
                binary_of(&col_a, Operator::Plus, &lit_1),
                binary_of(&col_x, Operator::Plus, &lit_1),
                true,
                "Binary expressions with equivalent column and same literal should be equal",
            ),
            (
                binary_of(
                    &binary_of(&col_a, Operator::Plus, &col_b),
                    Operator::Multiply,
                    &lit_1,
                ),
                binary_of(
                    &binary_of(&col_x, Operator::Plus, &col_y),
                    Operator::Multiply,
                    &lit_1,
                ),
                true,
                "Nested binary expressions with equivalent operands should be equal",
            ),
        ];

        for (left, right, expected, description) in test_cases {
            let actual = eq_group.exprs_equal(&left, &right);
            assert_eq!(
                actual, expected,
                "{}: Failed comparing {:?} and {:?}, expected {}, got {}",
                description, left, right, expected, actual
            );
        }
        Ok(())
    }

    /// With `a == b`, projecting `[a + c, b + c]` must leave the two projected
    /// columns in the same equivalence class.
    #[test]
    fn test_project_classes() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));
        let mut group = EquivalenceGroup::empty();
        group.add_equal_conditions(&col("a", &schema)?, &col("b", &schema)?);

        let projected_schema = Arc::new(Schema::new(vec![
            Field::new("a+c", DataType::Int32, false),
            Field::new("b+c", DataType::Int32, false),
        ]));
        let a_plus_c = binary(
            col("a", &schema)?,
            Operator::Plus,
            col("c", &schema)?,
            &schema,
        )?;
        let b_plus_c = binary(
            col("b", &schema)?,
            Operator::Plus,
            col("c", &schema)?,
            &schema,
        )?;
        let mapping = ProjectionMapping {
            map: vec![
                (a_plus_c, col("a+c", &projected_schema)?),
                (b_plus_c, col("b+c", &projected_schema)?),
            ],
        };

        let projected = group.project(&mapping);
        assert!(!projected.is_empty());

        let first_normalized = projected.normalize_expr(col("a+c", &projected_schema)?);
        let second_normalized = projected.normalize_expr(col("b+c", &projected_schema)?);
        assert!(first_normalized.eq(&second_normalized));
        Ok(())
    }
}