use datafusion_common::{internal_err, Result};
use datafusion_physical_expr_common::sort_expr::LexOrdering;
use std::iter::Peekable;
use std::sync::Arc;
use crate::equivalence::class::AcrossPartitions;
use crate::ConstExpr;
use super::EquivalenceProperties;
use crate::PhysicalSortExpr;
use arrow::datatypes::SchemaRef;
use std::slice::Iter;
/// Combines the equivalence properties of two union inputs into the
/// properties that hold for their union.
///
/// `rhs` is first rewritten onto the schema of `lhs` when the two schemas
/// differ. The result then keeps:
/// - every expression that is constant on both sides (the concrete value is
///   retained only when both sides report the same uniform value), and
/// - the orderings gathered by [`UnionEquivalentOrderingBuilder`], which
///   checks the orderings of each side against the other side.
///
/// # Errors
/// Propagates any error returned by `with_new_schema`.
fn calculate_union_binary(
    lhs: EquivalenceProperties,
    mut rhs: EquivalenceProperties,
) -> Result<EquivalenceProperties> {
    // Bring `rhs` onto the schema of `lhs` so expressions can be compared.
    if !rhs.schema.eq(&lhs.schema) {
        rhs = rhs.with_new_schema(Arc::clone(&lhs.schema))?;
    }

    // Keep only the expressions that are constant on both sides.
    let mut constants = Vec::new();
    for lhs_const in lhs.constants().iter() {
        let Some(rhs_const) = rhs
            .constants()
            .iter()
            .find(|candidate| candidate.expr().eq(lhs_const.expr()))
        else {
            continue;
        };
        let shared = ConstExpr::new(Arc::clone(lhs_const.expr()));
        // The value itself survives only if both sides agree on it.
        let shared = match (lhs_const.across_partitions(), rhs_const.across_partitions()) {
            (
                AcrossPartitions::Uniform(Some(lhs_val)),
                AcrossPartitions::Uniform(Some(rhs_val)),
            ) if lhs_val == rhs_val => {
                shared.with_across_partitions(AcrossPartitions::Uniform(Some(lhs_val)))
            }
            _ => shared,
        };
        constants.push(shared);
    }

    // Gather orderings from each side, checked against the opposite side.
    let mut builder = UnionEquivalentOrderingBuilder::new();
    builder.add_satisfied_orderings(lhs.normalized_oeq_class(), lhs.constants(), &rhs);
    builder.add_satisfied_orderings(rhs.normalized_oeq_class(), rhs.constants(), &lhs);
    let union_orderings = builder.build();

    let mut union_properties =
        EquivalenceProperties::new(lhs.schema).with_constants(constants);
    union_properties.add_new_orderings(union_orderings);
    Ok(union_properties)
}
pub fn calculate_union(
eqps: Vec<EquivalenceProperties>,
schema: SchemaRef,
) -> Result<EquivalenceProperties> {
let mut iter = eqps.into_iter();
let Some(mut acc) = iter.next() else {
return internal_err!(
"Cannot calculate EquivalenceProperties for a union with no inputs"
);
};
if !acc.schema.eq(&schema) {
acc = acc.with_new_schema(schema)?;
}
for props in iter {
acc = calculate_union_binary(acc, props)?;
}
Ok(acc)
}
/// Outcome of `UnionEquivalentOrderingBuilder::try_add_ordering`.
#[derive(Debug)]
enum AddedOrdering {
/// The ordering needs no further attempts: it was empty, it was recorded
/// as is, or at least one augmented ordering was recorded for it.
Yes,
/// Nothing was recorded; the ordering is handed back so the caller can
/// shorten it and try again.
No(LexOrdering),
}
/// Collects the orderings that are recorded while combining the two sides of
/// a union (see `calculate_union_binary`).
#[derive(Debug)]
struct UnionEquivalentOrderingBuilder {
/// Orderings recorded so far; returned by `build`.
orderings: Vec<LexOrdering>,
}
impl UnionEquivalentOrderingBuilder {
    /// Creates a builder with no orderings recorded.
    fn new() -> Self {
        Self {
            orderings: Vec::new(),
        }
    }

    /// Offers each of `orderings` to [`Self::try_add_ordering`], first in full
    /// and then with its last sort expression removed one at a time, until
    /// the attempt reports [`AddedOrdering::Yes`].
    ///
    /// `constants` are passed through as the constants that go with
    /// `orderings`; `properties` are the properties the orderings are checked
    /// against.
    fn add_satisfied_orderings(
        &mut self,
        orderings: impl IntoIterator<Item = LexOrdering>,
        constants: &[ConstExpr],
        properties: &EquivalenceProperties,
    ) {
        for mut candidate in orderings {
            // An empty ordering is always accepted, so this loop terminates.
            while let AddedOrdering::No(mut rejected) =
                self.try_add_ordering(candidate, constants, properties)
            {
                rejected.pop();
                candidate = rejected;
            }
        }
    }

    /// Makes a single attempt at recording `ordering`.
    ///
    /// Returns [`AddedOrdering::Yes`] when `ordering` is empty (nothing is
    /// recorded), when `properties` satisfies it (it is recorded unchanged),
    /// or when at least one constant-augmented variant was recorded.
    /// Otherwise the ordering is returned inside [`AddedOrdering::No`].
    fn try_add_ordering(
        &mut self,
        ordering: LexOrdering,
        constants: &[ConstExpr],
        properties: &EquivalenceProperties,
    ) -> AddedOrdering {
        if ordering.is_empty() {
            return AddedOrdering::Yes;
        }
        if properties.ordering_satisfy(ordering.as_ref()) {
            // Already satisfied by the target; no constants are needed.
            self.orderings.push(ordering);
            return AddedOrdering::Yes;
        }
        // Not satisfied directly: see whether constants can bridge the gap.
        if self.try_find_augmented_ordering(&ordering, constants, properties) {
            AddedOrdering::Yes
        } else {
            AddedOrdering::No(ordering)
        }
    }

    /// Tries to merge `ordering` with each ordering in `properties.oeq_class`
    /// via [`Self::augment_ordering`] and records every non-empty result.
    ///
    /// Returns `true` if at least one ordering was recorded. Returns `false`
    /// immediately when `constants` is empty.
    ///
    /// # Panics
    /// Panics if a non-empty augmented ordering is not satisfied by
    /// `properties`.
    fn try_find_augmented_ordering(
        &mut self,
        ordering: &LexOrdering,
        constants: &[ConstExpr],
        properties: &EquivalenceProperties,
    ) -> bool {
        if constants.is_empty() {
            return false;
        }

        let mut recorded_any = false;
        for existing_ordering in properties.oeq_class.iter() {
            let merged = self.augment_ordering(
                ordering,
                constants,
                existing_ordering,
                &properties.constants,
            );
            match merged {
                Some(augmented) if !augmented.is_empty() => {
                    assert!(properties.ordering_satisfy(augmented.as_ref()));
                    self.orderings.push(augmented);
                    recorded_any = true;
                }
                _ => {}
            }
        }
        recorded_any
    }

    /// Walks `ordering` and `existing_ordering` side by side and builds a
    /// merged ordering.
    ///
    /// At every step the first rule that applies is used:
    /// 1. both heads are equal: consume both and keep the expression;
    /// 2. the head of `ordering` matches one of `existing_constants`: consume
    ///    it and keep it;
    /// 3. the head of `existing_ordering` matches one of `constants`: consume
    ///    it and keep it.
    ///
    /// The walk stops when both orderings are exhausted or no rule applies.
    /// The result is always `Some`, possibly holding an empty ordering.
    fn augment_ordering(
        &mut self,
        ordering: &LexOrdering,
        constants: &[ConstExpr],
        existing_ordering: &LexOrdering,
        existing_constants: &[ConstExpr],
    ) -> Option<LexOrdering> {
        let mut merged = LexOrdering::default();
        let mut own_exprs = ordering.iter().peekable();
        let mut existing_exprs = existing_ordering.iter().peekable();

        while own_exprs.peek().is_some() || existing_exprs.peek().is_some() {
            let next_expr = advance_if_match(&mut own_exprs, &mut existing_exprs)
                .or_else(|| advance_if_matches_constant(&mut own_exprs, existing_constants))
                .or_else(|| advance_if_matches_constant(&mut existing_exprs, constants));
            match next_expr {
                Some(expr) => merged.push(expr),
                // Nothing lines up any more; keep what has been merged so far.
                None => break,
            }
        }

        Some(merged)
    }

    /// Consumes the builder and returns the recorded orderings.
    fn build(self) -> Vec<LexOrdering> {
        let Self { orderings } = self;
        orderings
    }
}
/// Advances both iterators when their next elements are equal.
///
/// Returns a clone of the shared element in that case. When either iterator
/// is exhausted or the next elements differ, neither iterator is advanced and
/// `None` is returned.
fn advance_if_match(
    iter1: &mut Peekable<Iter<PhysicalSortExpr>>,
    iter2: &mut Peekable<Iter<PhysicalSortExpr>>,
) -> Option<PhysicalSortExpr> {
    let heads_agree = match (iter1.peek(), iter2.peek()) {
        (Some(expr1), Some(expr2)) => expr1.eq(expr2),
        _ => false,
    };
    if !heads_agree {
        return None;
    }
    // Both heads are known to exist here, so both iterators step forward.
    iter1.next();
    iter2.next().cloned()
}
/// Advances `iter` when its next sort expression matches one of `constants`
/// (as decided by `ConstExpr::eq_expr`).
///
/// On a match, returns a new sort expression built from the constant's
/// expression and the sort options of the consumed element. Otherwise `iter`
/// is left untouched and `None` is returned.
fn advance_if_matches_constant(
    iter: &mut Peekable<Iter<PhysicalSortExpr>>,
    constants: &[ConstExpr],
) -> Option<PhysicalSortExpr> {
    let head = iter.peek()?;
    let replacement = constants
        .iter()
        .find(|candidate| candidate.eq_expr(head))
        .map(|matched| PhysicalSortExpr::new(Arc::clone(matched.expr()), head.options))?;
    // Only step forward once a matching constant has been found.
    iter.next();
    Some(replacement)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::equivalence::class::const_exprs_contains;
use crate::equivalence::tests::{create_test_schema, parse_sort_expr};
use crate::expressions::col;
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::ScalarValue;
use itertools::Itertools;
/// Three children sorted on `[a, b, c]`, `[a1, b1, c1]` and `[a2, b2]`;
/// the union is expected to be sorted on `[a, b]`.
#[test]
fn test_union_equivalence_properties_multi_children_1() {
    let schema = create_test_schema().unwrap();
    let schema2 = append_fields(&schema, "1");
    let schema3 = append_fields(&schema, "2");

    let first_child = vec![vec!["a", "b", "c"]];
    let second_child = vec![vec!["a1", "b1", "c1"]];
    let third_child = vec![vec!["a2", "b2"]];
    let expected = vec![vec!["a", "b"]];

    let test = UnionEquivalenceTest::new(&schema)
        .with_child_sort(first_child, &schema)
        .with_child_sort(second_child, &schema2)
        .with_child_sort(third_child, &schema3)
        .with_expected_sort(expected);
    test.run()
}
/// Three children sorted on `[a, b, c]`, `[a1, b1, c1]` and `[a2, b2, c2]`;
/// the union is expected to be sorted on `[a, b, c]`.
#[test]
fn test_union_equivalence_properties_multi_children_2() {
    let schema = create_test_schema().unwrap();
    let schema2 = append_fields(&schema, "1");
    let schema3 = append_fields(&schema, "2");

    let first_child = vec![vec!["a", "b", "c"]];
    let second_child = vec![vec!["a1", "b1", "c1"]];
    let third_child = vec![vec!["a2", "b2", "c2"]];
    let expected = vec![vec!["a", "b", "c"]];

    let test = UnionEquivalenceTest::new(&schema)
        .with_child_sort(first_child, &schema)
        .with_child_sort(second_child, &schema2)
        .with_child_sort(third_child, &schema3)
        .with_expected_sort(expected);
    test.run()
}
/// Three children sorted on `[a, b]`, `[a1, b1, c1]` and `[a2, b2, c2]`;
/// the union is expected to be sorted on `[a, b]`.
#[test]
fn test_union_equivalence_properties_multi_children_3() {
    let schema = create_test_schema().unwrap();
    let schema2 = append_fields(&schema, "1");
    let schema3 = append_fields(&schema, "2");

    let first_child = vec![vec!["a", "b"]];
    let second_child = vec![vec!["a1", "b1", "c1"]];
    let third_child = vec![vec!["a2", "b2", "c2"]];
    let expected = vec![vec!["a", "b"]];

    let test = UnionEquivalenceTest::new(&schema)
        .with_child_sort(first_child, &schema)
        .with_child_sort(second_child, &schema2)
        .with_child_sort(third_child, &schema3)
        .with_expected_sort(expected);
    test.run()
}
/// Three children sorted on `[a, b]`, `[a1, b1]` and `[b2, c2]`; the union
/// is expected to have no ordering at all.
#[test]
fn test_union_equivalence_properties_multi_children_4() {
    let schema = create_test_schema().unwrap();
    let schema2 = append_fields(&schema, "1");
    let schema3 = append_fields(&schema, "2");

    let first_child = vec![vec!["a", "b"]];
    let second_child = vec![vec!["a1", "b1"]];
    let third_child = vec![vec!["b2", "c2"]];

    let test = UnionEquivalenceTest::new(&schema)
        .with_child_sort(first_child, &schema)
        .with_child_sort(second_child, &schema2)
        .with_child_sort(third_child, &schema3)
        .with_expected_sort(Vec::new());
    test.run()
}
/// Two children, each with two orderings (`[a, b]` and `[c]`, respectively
/// `[a1, b1]` and `[c1]`); the union is expected to keep both `[a, b]` and
/// `[c]`.
#[test]
fn test_union_equivalence_properties_multi_children_5() {
    let schema = create_test_schema().unwrap();
    let schema2 = append_fields(&schema, "1");

    let first_child = vec![vec!["a", "b"], vec!["c"]];
    let second_child = vec![vec!["a1", "b1"], vec!["c1"]];
    let expected = vec![vec!["a", "b"], vec!["c"]];

    let test = UnionEquivalenceTest::new(&schema)
        .with_child_sort(first_child, &schema)
        .with_child_sort(second_child, &schema2)
        .with_expected_sort(expected);
    test.run()
}
/// First child: sorted on `[a]`, constants `b` and `c`.
/// Second child: sorted on `[b]`, constants `a` and `c`.
/// Expected union: orderings `[a]` and `[b]`, constant `c`.
#[test]
fn test_union_equivalence_properties_constants_common_constants() {
    let schema = create_test_schema().unwrap();

    let first_sort = vec![vec!["a"]];
    let first_consts = vec!["b", "c"];
    let second_sort = vec![vec!["b"]];
    let second_consts = vec!["a", "c"];
    let expected_sort = vec![vec!["a"], vec!["b"]];
    let expected_consts = vec!["c"];

    let test = UnionEquivalenceTest::new(&schema)
        .with_child_sort_and_const_exprs(first_sort, first_consts, &schema)
        .with_child_sort_and_const_exprs(second_sort, second_consts, &schema)
        .with_expected_sort_and_const_exprs(expected_sort, expected_consts);
    test.run()
}
/// First child: sorted on `[a]`, no constants.
/// Second child: sorted on `[a, b]`, no constants.
/// Expected union: ordering `[a]`, no constants.
#[test]
fn test_union_equivalence_properties_constants_prefix() {
    let schema = create_test_schema().unwrap();

    let first_sort = vec![vec!["a"]];
    let second_sort = vec![vec!["a", "b"]];
    let expected_sort = vec![vec!["a"]];

    let test = UnionEquivalenceTest::new(&schema)
        .with_child_sort_and_const_exprs(first_sort, Vec::new(), &schema)
        .with_child_sort_and_const_exprs(second_sort, Vec::new(), &schema)
        .with_expected_sort_and_const_exprs(expected_sort, Vec::new());
    test.run()
}
/// First child: sorted on `[a]`, no constants.
/// Second child: sorted on `[a DESC]`, no constants.
/// Expected union: no orderings and no constants.
#[test]
fn test_union_equivalence_properties_constants_asc_desc_mismatch() {
    let schema = create_test_schema().unwrap();

    let first_sort = vec![vec!["a"]];
    let second_sort = vec![vec!["a DESC"]];

    let test = UnionEquivalenceTest::new(&schema)
        .with_child_sort_and_const_exprs(first_sort, Vec::new(), &schema)
        .with_child_sort_and_const_exprs(second_sort, Vec::new(), &schema)
        .with_expected_sort_and_const_exprs(Vec::new(), Vec::new());
    test.run()
}
/// First child: sorted on `[a]` over the base schema, no constants.
/// Second child: sorted on `[a1, b1]` over the suffixed schema, no constants.
/// Expected union: ordering `[a]`, no constants.
#[test]
fn test_union_equivalence_properties_constants_different_schemas() {
    let schema = create_test_schema().unwrap();
    let schema2 = append_fields(&schema, "1");

    let first_sort = vec![vec!["a"]];
    let second_sort = vec![vec!["a1", "b1"]];
    let expected_sort = vec![vec!["a"]];

    let test = UnionEquivalenceTest::new(&schema)
        .with_child_sort_and_const_exprs(first_sort, Vec::new(), &schema)
        .with_child_sort_and_const_exprs(second_sort, Vec::new(), &schema2)
        .with_expected_sort_and_const_exprs(expected_sort, Vec::new());
    test.run()
}
/// First child: sorted on `[a, c]`, constant `b`.
/// Second child: sorted on `[b, c]`, constant `a`.
/// Expected union: orderings `[a, b, c]` and `[b, a, c]`, no constants.
#[test]
fn test_union_equivalence_properties_constants_fill_gaps() {
    let schema = create_test_schema().unwrap();

    let first_sort = vec![vec!["a", "c"]];
    let first_consts = vec!["b"];
    let second_sort = vec![vec!["b", "c"]];
    let second_consts = vec!["a"];
    let expected_sort = vec![vec!["a", "b", "c"], vec!["b", "a", "c"]];

    let test = UnionEquivalenceTest::new(&schema)
        .with_child_sort_and_const_exprs(first_sort, first_consts, &schema)
        .with_child_sort_and_const_exprs(second_sort, second_consts, &schema)
        .with_expected_sort_and_const_exprs(expected_sort, Vec::new());
    test.run()
}
/// First child: sorted on `[a, c]`, constant `d`.
/// Second child: sorted on `[b, c]`, constant `a`.
/// Expected union: ordering `[a]`, no constants.
#[test]
fn test_union_equivalence_properties_constants_no_fill_gaps() {
    let schema = create_test_schema().unwrap();

    let first_sort = vec![vec!["a", "c"]];
    let first_consts = vec!["d"];
    let second_sort = vec![vec!["b", "c"]];
    let second_consts = vec!["a"];
    let expected_sort = vec![vec!["a"]];

    let test = UnionEquivalenceTest::new(&schema)
        .with_child_sort_and_const_exprs(first_sort, first_consts, &schema)
        .with_child_sort_and_const_exprs(second_sort, second_consts, &schema)
        .with_expected_sort_and_const_exprs(expected_sort, Vec::new());
    test.run()
}
/// First child: sorted on `[c]`, constants `a` and `b`.
/// Second child: sorted on `[a DESC, b]`, no constants.
/// Expected union: ordering `[a DESC, b]`, no constants.
#[test]
fn test_union_equivalence_properties_constants_fill_some_gaps() {
    let schema = create_test_schema().unwrap();

    let first_sort = vec![vec!["c"]];
    let first_consts = vec!["a", "b"];
    let second_sort = vec![vec!["a DESC", "b"]];
    let expected_sort = vec![vec!["a DESC", "b"]];

    let test = UnionEquivalenceTest::new(&schema)
        .with_child_sort_and_const_exprs(first_sort, first_consts, &schema)
        .with_child_sort_and_const_exprs(second_sort, Vec::new(), &schema)
        .with_expected_sort_and_const_exprs(expected_sort, Vec::new());
    test.run()
}
/// First child: sorted on `[a, c]`, constant `b`.
/// Second child: sorted on `[b DESC, c]`, constant `a`.
/// Expected union: orderings `[a, b DESC, c]` and `[b DESC, a, c]`, no
/// constants.
#[test]
fn test_union_equivalence_properties_constants_fill_gaps_non_symmetric() {
    let schema = create_test_schema().unwrap();

    let first_sort = vec![vec!["a", "c"]];
    let first_consts = vec!["b"];
    let second_sort = vec![vec!["b DESC", "c"]];
    let second_consts = vec!["a"];
    let expected_sort = vec![vec!["a", "b DESC", "c"], vec!["b DESC", "a", "c"]];

    let test = UnionEquivalenceTest::new(&schema)
        .with_child_sort_and_const_exprs(first_sort, first_consts, &schema)
        .with_child_sort_and_const_exprs(second_sort, second_consts, &schema)
        .with_expected_sort_and_const_exprs(expected_sort, Vec::new());
    test.run()
}
/// First child: sorted on `[a, b, d]`, constant `c`.
/// Second child: sorted on `[a, c, d]`, constant `b`.
/// Expected union: orderings `[a, c, b, d]` and `[a, b, c, d]`, no constants.
#[test]
fn test_union_equivalence_properties_constants_gap_fill_symmetric() {
    let schema = create_test_schema().unwrap();

    let first_sort = vec![vec!["a", "b", "d"]];
    let first_consts = vec!["c"];
    let second_sort = vec![vec!["a", "c", "d"]];
    let second_consts = vec!["b"];
    let expected_sort = vec![vec!["a", "c", "b", "d"], vec!["a", "b", "c", "d"]];

    let test = UnionEquivalenceTest::new(&schema)
        .with_child_sort_and_const_exprs(first_sort, first_consts, &schema)
        .with_child_sort_and_const_exprs(second_sort, second_consts, &schema)
        .with_expected_sort_and_const_exprs(expected_sort, Vec::new());
    test.run()
}
/// First child: sorted on `[a DESC, d]`, constants `b` and `c`.
/// Second child: sorted on `[a DESC, c, d]`, constant `b`.
/// Expected union: ordering `[a DESC, c, d]`, constant `b`.
#[test]
fn test_union_equivalence_properties_constants_gap_fill_and_common() {
    let schema = create_test_schema().unwrap();

    let first_sort = vec![vec!["a DESC", "d"]];
    let first_consts = vec!["b", "c"];
    let second_sort = vec![vec!["a DESC", "c", "d"]];
    let second_consts = vec!["b"];
    let expected_sort = vec![vec!["a DESC", "c", "d"]];
    let expected_consts = vec!["b"];

    let test = UnionEquivalenceTest::new(&schema)
        .with_child_sort_and_const_exprs(first_sort, first_consts, &schema)
        .with_child_sort_and_const_exprs(second_sort, second_consts, &schema)
        .with_expected_sort_and_const_exprs(expected_sort, expected_consts);
    test.run()
}
/// First child: sorted on `[a, b DESC, d]`, constant `c`.
/// Second child: sorted on `[a, c, d]`, constant `b`.
/// Expected union: orderings `[a, c, b DESC, d]` and `[a, b DESC, c, d]`, no
/// constants.
#[test]
fn test_union_equivalence_properties_constants_middle_desc() {
    let schema = create_test_schema().unwrap();

    let first_sort = vec![vec!["a", "b DESC", "d"]];
    let first_consts = vec!["c"];
    let second_sort = vec![vec!["a", "c", "d"]];
    let second_consts = vec!["b"];
    let expected_sort = vec![
        vec!["a", "c", "b DESC", "d"],
        vec!["a", "b DESC", "c", "d"],
    ];

    let test = UnionEquivalenceTest::new(&schema)
        .with_child_sort_and_const_exprs(first_sort, first_consts, &schema)
        .with_child_sort_and_const_exprs(second_sort, second_consts, &schema)
        .with_expected_sort_and_const_exprs(expected_sort, Vec::new());
    test.run()
}
/// Test harness describing one union scenario: the properties of every child
/// and the properties expected for the union.
#[derive(Debug)]
struct UnionEquivalenceTest {
/// Schema passed to `calculate_union` and used to build the expected
/// properties.
output_schema: SchemaRef,
/// Properties of each child, in the order they were added.
child_properties: Vec<EquivalenceProperties>,
/// Properties the union is expected to have; `run` panics when unset.
expected_properties: Option<EquivalenceProperties>,
}
impl UnionEquivalenceTest {
    /// Starts a scenario whose union output uses `output_schema`.
    fn new(output_schema: &SchemaRef) -> Self {
        Self {
            output_schema: Arc::clone(output_schema),
            child_properties: Vec::new(),
            expected_properties: None,
        }
    }

    /// Adds a child with the given orderings and no constants.
    fn with_child_sort(self, orderings: Vec<Vec<&str>>, schema: &SchemaRef) -> Self {
        self.with_child_sort_and_const_exprs(orderings, Vec::new(), schema)
    }

    /// Adds a child with the given orderings and constant columns, both
    /// resolved against `schema`.
    fn with_child_sort_and_const_exprs(
        mut self,
        orderings: Vec<Vec<&str>>,
        constants: Vec<&str>,
        schema: &SchemaRef,
    ) -> Self {
        let child = self.make_props(orderings, constants, schema);
        self.child_properties.push(child);
        self
    }

    /// Sets the expected orderings of the union, with no expected constants.
    fn with_expected_sort(self, orderings: Vec<Vec<&str>>) -> Self {
        self.with_expected_sort_and_const_exprs(orderings, Vec::new())
    }

    /// Sets the expected orderings and constants of the union, resolved
    /// against the output schema.
    fn with_expected_sort_and_const_exprs(
        mut self,
        orderings: Vec<Vec<&str>>,
        constants: Vec<&str>,
    ) -> Self {
        let expected = self.make_props(orderings, constants, &self.output_schema);
        self.expected_properties = Some(expected);
        self
    }

    /// Runs `calculate_union` for every permutation of the children and
    /// checks each result against the expected properties.
    ///
    /// Panics if no expectation was set, if `calculate_union` fails, or if a
    /// result differs from the expectation.
    fn run(self) {
        let expected_properties = self
            .expected_properties
            .expect("expected_properties not set");
        let output_schema = self.output_schema;
        let children = self.child_properties;
        let num_children = children.len();

        // Every input order is exercised so the result cannot depend on it.
        for child_properties in children.iter().cloned().permutations(num_children) {
            println!("--- permutation ---");
            child_properties.iter().for_each(|c| println!("{c}"));

            let actual_properties =
                calculate_union(child_properties, Arc::clone(&output_schema))
                    .expect("failed to calculate union equivalence properties");
            let err_msg = format!(
                "expected: {expected_properties:?}\nactual: {actual_properties:?}"
            );
            Self::assert_eq_properties_same(
                &actual_properties,
                &expected_properties,
                err_msg,
            );
        }
    }

    /// Asserts that `lhs` and `rhs` hold the same constants and the same
    /// orderings, ignoring order.
    ///
    /// Each check verifies that every item of `rhs` is found in `lhs` and
    /// that both sides hold the same number of items. `err_msg` is prepended
    /// to every failure message.
    fn assert_eq_properties_same(
        lhs: &EquivalenceProperties,
        rhs: &EquivalenceProperties,
        err_msg: String,
    ) {
        // Constants.
        let (lhs_constants, rhs_constants) = (lhs.constants(), rhs.constants());
        rhs_constants.iter().for_each(|rhs_constant| {
            assert!(
                const_exprs_contains(lhs_constants, rhs_constant.expr()),
                "{err_msg}\nlhs: {lhs}\nrhs: {rhs}"
            );
        });
        assert_eq!(
            lhs_constants.len(),
            rhs_constants.len(),
            "{err_msg}\nlhs: {lhs}\nrhs: {rhs}"
        );

        // Orderings.
        let (lhs_orderings, rhs_orderings) = (lhs.oeq_class(), rhs.oeq_class());
        rhs_orderings.iter().for_each(|rhs_ordering| {
            assert!(
                lhs_orderings.contains(rhs_ordering),
                "{err_msg}\nlhs: {lhs}\nrhs: {rhs}"
            );
        });
        assert_eq!(
            lhs_orderings.len(),
            rhs_orderings.len(),
            "{err_msg}\nlhs: {lhs}\nrhs: {rhs}"
        );
    }

    /// Builds `EquivalenceProperties` over `schema` from textual sort
    /// expressions (parsed with `parse_sort_expr`) and constant column names.
    ///
    /// Panics if a constant column cannot be resolved in `schema`.
    fn make_props(
        &self,
        orderings: Vec<Vec<&str>>,
        constants: Vec<&str>,
        schema: &SchemaRef,
    ) -> EquivalenceProperties {
        let mut lex_orderings = Vec::with_capacity(orderings.len());
        for ordering in &orderings {
            let lex_ordering: LexOrdering = ordering
                .iter()
                .map(|name| parse_sort_expr(name, schema))
                .collect();
            lex_orderings.push(lex_ordering);
        }

        let mut const_exprs = Vec::with_capacity(constants.len());
        for col_name in &constants {
            const_exprs.push(ConstExpr::new(col(col_name, schema).unwrap()));
        }

        EquivalenceProperties::new_with_orderings(Arc::clone(schema), &lex_orderings)
            .with_constants(const_exprs)
    }
}
/// Two inputs that both mark column `a` as uniformly constant with the value
/// `10`; the union is expected to keep `a` constant with that same value.
#[test]
fn test_union_constant_value_preservation() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Int32, true),
    ]));
    let col_a = col("a", &schema)?;
    let literal_10 = ScalarValue::Int32(Some(10));

    // Both inputs are built the same way: `a` is uniformly equal to 10.
    let make_input = || {
        let constant = ConstExpr::new(Arc::clone(&col_a))
            .with_across_partitions(AcrossPartitions::Uniform(Some(literal_10.clone())));
        EquivalenceProperties::new(Arc::clone(&schema)).with_constants(vec![constant])
    };
    let inputs = vec![make_input(), make_input()];

    let union_props = calculate_union(inputs, schema)?;

    // The first constant of the union must be `a`, still carrying the value.
    let const_a = &union_props.constants()[0];
    assert!(const_a.expr().eq(&col_a));
    assert_eq!(
        const_a.across_partitions(),
        AcrossPartitions::Uniform(Some(literal_10))
    );
    Ok(())
}
/// Returns a new schema with the same fields as `schema`, in the same order,
/// where every field name has `text` appended. Data type and nullability are
/// copied from the original field.
fn append_fields(schema: &SchemaRef, text: &str) -> SchemaRef {
    let mut renamed = Vec::with_capacity(schema.fields().len());
    for field in schema.fields().iter() {
        let name = format!("{}{}", field.name(), text);
        renamed.push(Field::new(
            name,
            field.data_type().clone(),
            field.is_nullable(),
        ));
    }
    Arc::new(Schema::new(renamed))
}
}