use super::{add_offset_to_expr, collapse_lex_req, ProjectionMapping};
use crate::{
expressions::Column, physical_expr::deduplicate_physical_exprs,
physical_exprs_bag_equal, physical_exprs_contains, LexOrdering, LexOrderingRef,
LexRequirement, LexRequirementRef, PhysicalExpr, PhysicalSortExpr,
PhysicalSortRequirement,
};
use datafusion_common::tree_node::TreeNode;
use datafusion_common::{tree_node::Transformed, JoinType};
use std::sync::Arc;
/// A set of physical expressions that are treated as equal to one another.
///
/// Members are stored in insertion order. `new` passes its input through
/// `deduplicate_physical_exprs` and `push` skips expressions that are already
/// present, so mutation through this type's API does not add repeated members.
#[derive(Debug, Clone)]
pub struct EquivalenceClass {
/// Member expressions of the class. The first entry is the one returned by
/// `canonical_expr`, i.e. it acts as the representative of the class.
exprs: Vec<Arc<dyn PhysicalExpr>>,
}
/// Equality of two classes is delegated to `physical_exprs_bag_equal`, applied
/// to the member expressions of both sides.
impl PartialEq for EquivalenceClass {
    fn eq(&self, other: &Self) -> bool {
        let Self { exprs: lhs } = self;
        let Self { exprs: rhs } = other;
        physical_exprs_bag_equal(lhs, rhs)
    }
}
impl EquivalenceClass {
    /// Creates a class that has no member expressions.
    pub fn new_empty() -> Self {
        Self { exprs: Vec::new() }
    }

    /// Creates a class from `exprs` after running them through
    /// `deduplicate_physical_exprs`.
    pub fn new(mut exprs: Vec<Arc<dyn PhysicalExpr>>) -> Self {
        deduplicate_physical_exprs(&mut exprs);
        Self { exprs }
    }

    /// Consumes the class and hands back its member expressions.
    pub fn into_vec(self) -> Vec<Arc<dyn PhysicalExpr>> {
        let Self { exprs } = self;
        exprs
    }

    /// Returns the representative of the class, which is its first member,
    /// or `None` when the class is empty.
    fn canonical_expr(&self) -> Option<Arc<dyn PhysicalExpr>> {
        let representative = self.exprs.first();
        representative.cloned()
    }

    /// Appends `expr` to the class unless the class already contains it.
    pub fn push(&mut self, expr: Arc<dyn PhysicalExpr>) {
        if self.contains(&expr) {
            return;
        }
        self.exprs.push(expr);
    }

    /// Moves every member of `other` into this class, skipping members that
    /// are already present.
    pub fn extend(&mut self, other: Self) {
        other
            .into_vec()
            .into_iter()
            .for_each(|expr| self.push(expr));
    }

    /// Returns `true` when `expr` is a member of this class, as decided by
    /// `physical_exprs_contains`.
    pub fn contains(&self, expr: &Arc<dyn PhysicalExpr>) -> bool {
        physical_exprs_contains(&self.exprs, expr)
    }

    /// Returns `true` when at least one member of this class is also a
    /// member of `other`.
    pub fn contains_any(&self, other: &Self) -> bool {
        for expr in &self.exprs {
            if other.contains(expr) {
                return true;
            }
        }
        false
    }

    /// Returns the number of member expressions.
    pub fn len(&self) -> usize {
        self.exprs.len()
    }

    /// Returns `true` when the class has no members.
    pub fn is_empty(&self) -> bool {
        self.exprs.is_empty()
    }

    /// Iterates over the member expressions in storage order.
    pub fn iter(&self) -> impl Iterator<Item = &Arc<dyn PhysicalExpr>> {
        self.exprs.iter()
    }

    /// Builds a new class by applying `add_offset_to_expr` with `offset` to a
    /// clone of every member (presumably shifting column indices; confirm
    /// against that helper). The result goes through `Self::new`.
    pub fn with_offset(&self, offset: usize) -> Self {
        let shifted = self
            .exprs
            .iter()
            .map(|expr| add_offset_to_expr(Arc::clone(expr), offset))
            .collect();
        Self::new(shifted)
    }
}
/// A collection of `EquivalenceClass`es.
///
/// When built through `new` or grown through `extend`, classes with fewer
/// than two members are dropped and classes that share a member are merged
/// (see `remove_redundant_entries` and `bridge_classes`).
#[derive(Debug, Clone)]
pub struct EquivalenceGroup {
/// The equivalence classes held by this group. The field is public, so
/// direct mutation bypasses the clean-up performed by `new`/`extend`.
pub classes: Vec<EquivalenceClass>,
}
impl EquivalenceGroup {
/// Creates a group that holds no equivalence classes.
pub fn empty() -> Self {
    Self {
        classes: Vec::new(),
    }
}
/// Creates a group from `classes`, then removes redundant entries: classes
/// with fewer than two members are dropped and overlapping classes are merged.
pub fn new(classes: Vec<EquivalenceClass>) -> Self {
    let mut group = Self { classes };
    group.remove_redundant_entries();
    group
}
/// Returns the number of equivalence classes in this group.
pub fn len(&self) -> usize {
self.classes.len()
}
/// Returns `true` when this group holds no equivalence classes.
pub fn is_empty(&self) -> bool {
    self.classes.is_empty()
}
/// Iterates over the equivalence classes of this group in storage order.
pub fn iter(&self) -> impl Iterator<Item = &EquivalenceClass> {
self.classes.iter()
}
/// Records that `left` and `right` are equal.
///
/// * If both already live in different classes, those classes are merged.
/// * If only one of them is in a class, the other is added to that class.
/// * If neither is in a class, a new class is created from the pair.
/// * If both are already in the same class, nothing changes.
pub fn add_equal_conditions(
    &mut self,
    left: &Arc<dyn PhysicalExpr>,
    right: &Arc<dyn PhysicalExpr>,
) {
    // Find the class (if any) holding each side. When several classes match,
    // the one with the highest index is remembered.
    let mut left_pos = None;
    let mut right_pos = None;
    for (pos, class) in self.classes.iter().enumerate() {
        if class.contains(left) {
            left_pos = Some(pos);
        }
        if class.contains(right) {
            right_pos = Some(pos);
        }
    }
    match (left_pos, right_pos) {
        // Both sides are already known to be equal.
        (Some(a), Some(b)) if a == b => {}
        (Some(a), Some(b)) => {
            // Always remove the class at the larger index: `swap_remove` only
            // disturbs positions at or after the removed one, so the smaller
            // index stays valid.
            let (keep_idx, absorb_idx) = if a < b { (a, b) } else { (b, a) };
            let absorbed = self.classes.swap_remove(absorb_idx);
            self.classes[keep_idx].extend(absorbed);
        }
        (Some(pos), None) => self.classes[pos].push(right.clone()),
        (None, Some(pos)) => self.classes[pos].push(left.clone()),
        (None, None) => {
            let fresh = EquivalenceClass::new(vec![left.clone(), right.clone()]);
            self.classes.push(fresh);
        }
    }
}
/// Brings the group into a compact form: classes with fewer than two members
/// are discarded (they carry no equality information), then classes that
/// share a member are merged by `bridge_classes`.
fn remove_redundant_entries(&mut self) {
    self.classes.retain(|class| class.len() > 1);
    self.bridge_classes();
}
/// Merges classes that have at least one member in common.
///
/// For each class, every later class that overlaps with it is removed from
/// the group and folded into it. If the class grew during a pass, the pass is
/// repeated for the same class, because the newly added members may overlap
/// with classes that were skipped earlier.
fn bridge_classes(&mut self) {
    let mut current = 0;
    while current < self.classes.len() {
        let size_before = self.classes[current].len();
        let mut candidate = current + 1;
        while candidate < self.classes.len() {
            let overlaps = self.classes[current].contains_any(&self.classes[candidate]);
            if overlaps {
                // `swap_remove` moves the last class into `candidate`, so the
                // same position must be examined again; do not advance.
                let absorbed = self.classes.swap_remove(candidate);
                self.classes[current].extend(absorbed);
            } else {
                candidate += 1;
            }
        }
        let grew = self.classes[current].len() > size_before;
        if !grew {
            current += 1;
        }
    }
}
pub fn extend(&mut self, other: Self) {
self.classes.extend(other.classes);
self.remove_redundant_entries();
}
/// Normalizes `expr` with respect to this group.
///
/// The expression is rewritten with `transform`; every visited node that is a
/// member of some class is replaced by that class's canonical (first) member.
/// If the rewrite returns an error, the original `expr` is returned unchanged.
pub fn normalize_expr(&self, expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
    expr.clone()
        .transform(&|node| {
            let class = self.get_equivalence_class(&node);
            let rewritten = match class {
                // A class that contains `node` is non-empty, so it has a
                // canonical member and the `unwrap` cannot fail.
                Some(cls) => Transformed::Yes(cls.canonical_expr().unwrap()),
                None => Transformed::No(node),
            };
            Ok(rewritten)
        })
        .unwrap_or(expr)
}
pub fn normalize_sort_expr(
&self,
mut sort_expr: PhysicalSortExpr,
) -> PhysicalSortExpr {
sort_expr.expr = self.normalize_expr(sort_expr.expr);
sort_expr
}
pub fn normalize_sort_requirement(
&self,
mut sort_requirement: PhysicalSortRequirement,
) -> PhysicalSortRequirement {
sort_requirement.expr = self.normalize_expr(sort_requirement.expr);
sort_requirement
}
/// Applies `normalize_expr` to every expression in `exprs`, preserving order,
/// and collects the results into a vector.
pub fn normalize_exprs(
    &self,
    exprs: impl IntoIterator<Item = Arc<dyn PhysicalExpr>>,
) -> Vec<Arc<dyn PhysicalExpr>> {
    let normalize = |expr: Arc<dyn PhysicalExpr>| self.normalize_expr(expr);
    exprs.into_iter().map(normalize).collect()
}
/// Normalizes a lexicographical ordering.
///
/// The sort expressions are converted to sort requirements, normalized with
/// `normalize_sort_requirements`, and converted back to sort expressions.
pub fn normalize_sort_exprs(&self, sort_exprs: LexOrderingRef) -> LexOrdering {
    let requirements = PhysicalSortRequirement::from_sort_exprs(sort_exprs.iter());
    PhysicalSortRequirement::to_sort_exprs(self.normalize_sort_requirements(&requirements))
}
/// Normalizes every requirement in `sort_reqs` with
/// `normalize_sort_requirement` and passes the resulting list through
/// `collapse_lex_req`.
pub fn normalize_sort_requirements(
    &self,
    sort_reqs: LexRequirementRef,
) -> LexRequirement {
    let normalized = sort_reqs
        .iter()
        .cloned()
        .map(|requirement| self.normalize_sort_requirement(requirement))
        .collect();
    collapse_lex_req(normalized)
}
/// Projects `expr` according to `mapping`, consulting this group's
/// equivalence classes when there is no direct match.
///
/// Resolution order:
/// 1. If `mapping.target_expr` yields a target for `expr`, return it.
/// 2. Otherwise, if some mapping source belongs to a class that also contains
///    `expr`, return the target of the first such source.
/// 3. Otherwise, project every child recursively and rebuild `expr` from the
///    projected children.
///
/// Returns `None` when `expr` has no children and was not matched above, or
/// when any child cannot be projected.
///
/// # Panics
///
/// Panics if `with_new_children` returns an error while rebuilding `expr`.
pub fn project_expr(
    &self,
    mapping: &ProjectionMapping,
    expr: &Arc<dyn PhysicalExpr>,
) -> Option<Arc<dyn PhysicalExpr>> {
    // Exact match against a mapping source.
    if let Some(target) = mapping.target_expr(expr) {
        return Some(target);
    }
    // Match through equivalence: a source whose class also contains `expr`.
    let equivalent_entry = mapping.iter().find(|(source, _)| {
        self.get_equivalence_class(source)
            .map_or(false, |class| class.contains(expr))
    });
    if let Some((_, target)) = equivalent_entry {
        return Some(target.clone());
    }
    // A childless expression that matched nothing cannot be projected.
    let children = expr.children();
    if children.is_empty() {
        return None;
    }
    let projected_children = children
        .into_iter()
        .map(|child| self.project_expr(mapping, &child))
        .collect::<Option<Vec<_>>>()?;
    Some(expr.clone().with_new_children(projected_children).unwrap())
}
pub fn project(&self, mapping: &ProjectionMapping) -> Self {
let projected_classes = self.iter().filter_map(|cls| {
let new_class = cls
.iter()
.filter_map(|expr| self.project_expr(mapping, expr))
.collect::<Vec<_>>();
(new_class.len() > 1).then_some(EquivalenceClass::new(new_class))
});
let mut new_classes = vec![];
for (source, target) in mapping.iter() {
if new_classes.is_empty() {
new_classes.push((source, vec![target.clone()]));
}
if let Some((_, values)) =
new_classes.iter_mut().find(|(key, _)| key.eq(source))
{
if !physical_exprs_contains(values, target) {
values.push(target.clone());
}
}
}
let new_classes = new_classes
.into_iter()
.filter_map(|(_, values)| (values.len() > 1).then_some(values))
.map(EquivalenceClass::new);
let classes = projected_classes.chain(new_classes).collect();
Self::new(classes)
}
/// Returns the first class of this group that contains `expr`, or `None`
/// when no class does.
fn get_equivalence_class(
    &self,
    expr: &Arc<dyn PhysicalExpr>,
) -> Option<&EquivalenceClass> {
    self.classes.iter().find(|class| class.contains(expr))
}
pub fn join(
&self,
right_equivalences: &Self,
join_type: &JoinType,
left_size: usize,
on: &[(Column, Column)],
) -> Self {
match join_type {
JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
let mut result = Self::new(
self.iter()
.cloned()
.chain(
right_equivalences
.iter()
.map(|cls| cls.with_offset(left_size)),
)
.collect(),
);
if join_type == &JoinType::Inner {
for (lhs, rhs) in on.iter() {
let index = rhs.index() + left_size;
let new_lhs = Arc::new(lhs.clone()) as _;
let new_rhs = Arc::new(Column::new(rhs.name(), index)) as _;
result.add_equal_conditions(&new_lhs, &new_rhs);
}
}
result
}
JoinType::LeftSemi | JoinType::LeftAnti => self.clone(),
JoinType::RightSemi | JoinType::RightAnti => right_equivalences.clone(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::equivalence::tests::create_test_params;
use crate::equivalence::{EquivalenceClass, EquivalenceGroup};
use crate::expressions::lit;
use crate::expressions::Column;
use crate::expressions::Literal;
use datafusion_common::Result;
use datafusion_common::ScalarValue;
use std::sync::Arc;
/// Checks that classes sharing a member are merged into one class, for inputs
/// where the overlap is only discovered transitively.
#[test]
fn test_bridge_groups() -> Result<()> {
    // Each case is `(input classes, expected classes after bridging)`.
    let test_cases = vec![
        (
            vec![vec![1, 2, 3], vec![2, 4, 5], vec![11, 12, 9], vec![7, 6, 5]],
            vec![vec![1, 2, 3, 4, 5, 6, 7], vec![9, 11, 12]],
        ),
        (
            vec![vec![1, 2, 3], vec![3, 4, 5], vec![9, 8, 7], vec![7, 6, 5]],
            vec![vec![1, 2, 3, 4, 5, 6, 7, 8, 9]],
        ),
    ];
    // Turns rows of integers into equivalence classes of literal expressions.
    let to_classes = |rows: Vec<Vec<i32>>| -> Vec<EquivalenceClass> {
        rows.into_iter()
            .map(|row| row.into_iter().map(lit).collect::<Vec<_>>())
            .map(EquivalenceClass::new)
            .collect()
    };
    for (entries, expected) in test_cases {
        let entries = to_classes(entries);
        let expected = to_classes(expected);
        let mut eq_groups = EquivalenceGroup::new(entries.clone());
        eq_groups.bridge_classes();
        let eq_groups = eq_groups.classes;
        let err_msg = format!(
            "error in test entries: {:?}, expected: {:?}, actual:{:?}",
            entries, expected, eq_groups
        );
        assert_eq!(eq_groups.len(), expected.len(), "{}", err_msg);
        for (actual, wanted) in eq_groups.iter().zip(expected.iter()) {
            assert_eq!(actual, wanted, "{}", err_msg);
        }
    }
    Ok(())
}
/// Checks that repeated members are collapsed and that classes left with a
/// single member are removed from the group.
#[test]
fn test_remove_redundant_entries_eq_group() -> Result<()> {
    let entries = vec![
        EquivalenceClass::new(vec![lit(1), lit(1), lit(2)]),
        // Collapses to a single member, so it is expected to disappear.
        EquivalenceClass::new(vec![lit(3), lit(3)]),
        EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]),
    ];
    let expected = vec![
        EquivalenceClass::new(vec![lit(1), lit(2)]),
        EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]),
    ];
    let mut group = EquivalenceGroup::new(entries);
    group.remove_redundant_entries();
    let actual = group.classes;
    assert_eq!(actual.len(), expected.len());
    assert_eq!(actual.len(), 2);
    for (got, wanted) in actual.iter().zip(expected.iter()) {
        assert_eq!(got, wanted);
    }
    Ok(())
}
/// Checks `normalize_expr` against the equivalences provided by
/// `create_test_params`: `a` and `b` normalize to themselves, `c` to `a`.
#[test]
fn test_schema_normalize_expr_with_equivalence() -> Result<()> {
    let (_test_schema, eq_properties) = create_test_params()?;
    let col_a_expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("a", 0));
    let col_b_expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("b", 1));
    let col_c_expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c", 2));
    // Each case is `(expression, expected normalized expression)`.
    let expressions = vec![
        (&col_a_expr, &col_a_expr),
        (&col_c_expr, &col_a_expr),
        (&col_b_expr, &col_b_expr),
    ];
    let eq_group = eq_properties.eq_group();
    for (expr, expected_eq) in expressions {
        let normalized = eq_group.normalize_expr(expr.clone());
        assert!(
            expected_eq.eq(&normalized),
            "error in test: expr: {expr:?}"
        );
    }
    Ok(())
}
/// Checks `contains_any` for classes with and without a common member.
#[test]
fn test_contains_any() {
    // Wraps a scalar value into a literal physical expression.
    let literal = |value: ScalarValue| -> Arc<dyn PhysicalExpr> {
        Arc::new(Literal::new(value))
    };
    let lit_true = literal(ScalarValue::Boolean(Some(true)));
    let lit_false = literal(ScalarValue::Boolean(Some(false)));
    let lit2 = literal(ScalarValue::Int32(Some(2)));
    let lit1 = literal(ScalarValue::Int32(Some(1)));
    let col_b_expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("b", 1));

    let cls1 = EquivalenceClass::new(vec![lit_true.clone(), lit_false]);
    let cls2 = EquivalenceClass::new(vec![lit_true, col_b_expr]);
    let cls3 = EquivalenceClass::new(vec![lit2, lit1]);

    // `true` is a member of both classes.
    assert!(cls1.contains_any(&cls2));
    // These pairs have no member in common.
    assert!(!cls1.contains_any(&cls3));
    assert!(!cls2.contains_any(&cls3));
}
}