use std::borrow::Borrow;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{any::Any, sync::Arc};
use super::{
ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan,
ExecutionPlanProperties, Partitioning, PlanProperties, RecordBatchStream,
SendableRecordBatchStream, Statistics,
metrics::{ExecutionPlanMetricsSet, MetricsSet},
};
use crate::check_if_same_properties;
use crate::execution_plan::{
InvariantLevel, boundedness_from_children, check_default_invariants,
emission_type_from_children,
};
use crate::filter::FilterExec;
use crate::filter_pushdown::{
ChildPushdownResult, FilterDescription, FilterPushdownPhase,
FilterPushdownPropagation, PushedDown,
};
use crate::metrics::BaselineMetrics;
use crate::projection::{ProjectionExec, make_with_child};
use crate::stream::ObservedStream;
use arrow::datatypes::{Field, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion_common::config::ConfigOptions;
use datafusion_common::stats::Precision;
use datafusion_common::{
Result, assert_or_internal_err, exec_err, internal_datafusion_err,
};
use datafusion_execution::TaskContext;
use datafusion_physical_expr::{
EquivalenceProperties, PhysicalExpr, calculate_union, conjunction,
};
use futures::Stream;
use itertools::Itertools;
use log::{debug, trace, warn};
use tokio::macros::support::thread_rng_n;
/// `UnionExec`: `UNION ALL` execution plan.
///
/// Concatenates the partitions of all children: the output partition count
/// is the sum of the children's partition counts, and each output partition
/// maps one-to-one onto a partition of exactly one child (no merging).
#[derive(Debug, Clone)]
pub struct UnionExec {
    /// Child execution plans, in declaration order.
    inputs: Vec<Arc<dyn ExecutionPlan>>,
    /// Execution metrics, one baseline set per output partition.
    metrics: ExecutionPlanMetricsSet,
    /// Cached plan properties (schema, equivalences, partitioning, ...).
    cache: Arc<PlanProperties>,
}
impl UnionExec {
    /// Create a new `UnionExec` without validation.
    ///
    /// Panics if `inputs` is empty or if plan properties cannot be
    /// computed; prefer [`UnionExec::try_new`].
    #[deprecated(since = "44.0.0", note = "Use UnionExec::try_new instead")]
    pub fn new(inputs: Vec<Arc<dyn ExecutionPlan>>) -> Self {
        let schema =
            union_schema(&inputs).expect("UnionExec::new called with empty inputs");
        let cache = Self::compute_properties(&inputs, schema)
            .expect("UnionExec::new failed to compute plan properties");
        UnionExec {
            inputs,
            metrics: ExecutionPlanMetricsSet::new(),
            cache: Arc::new(cache),
        }
    }

    /// Create a new `UnionExec` from `inputs`.
    ///
    /// Returns an error when `inputs` is empty. When exactly one input is
    /// supplied, that input is returned unchanged (a union of one child is
    /// a no-op), so the returned plan is not necessarily a `UnionExec`.
    pub fn try_new(
        inputs: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        match inputs.len() {
            0 => exec_err!("UnionExec requires at least one input"),
            1 => Ok(inputs.into_iter().next().unwrap()),
            _ => {
                let schema = union_schema(&inputs)?;
                // Propagate property-computation failures to the caller
                // instead of panicking (this constructor is fallible).
                let cache = Self::compute_properties(&inputs, schema)?;
                Ok(Arc::new(UnionExec {
                    inputs,
                    metrics: ExecutionPlanMetricsSet::new(),
                    cache: Arc::new(cache),
                }))
            }
        }
    }

    /// The child execution plans of this union.
    pub fn inputs(&self) -> &Vec<Arc<dyn ExecutionPlan>> {
        &self.inputs
    }

    /// Compute plan properties (equivalences, partitioning, emission type,
    /// boundedness) for a union of `inputs` producing `schema`.
    fn compute_properties(
        inputs: &[Arc<dyn ExecutionPlan>],
        schema: SchemaRef,
    ) -> Result<PlanProperties> {
        // Only orderings/equivalences common to every child survive a union.
        let children_eqps = inputs
            .iter()
            .map(|child| child.equivalence_properties().clone())
            .collect::<Vec<_>>();
        let eq_properties = calculate_union(children_eqps, schema)?;
        // Output partitions are the children's partitions concatenated; the
        // combined distribution is unknown even if each child's is known.
        let num_partitions = inputs
            .iter()
            .map(|plan| plan.output_partitioning().partition_count())
            .sum();
        let output_partitioning = Partitioning::UnknownPartitioning(num_partitions);
        Ok(PlanProperties::new(
            eq_properties,
            output_partitioning,
            emission_type_from_children(inputs),
            boundedness_from_children(inputs),
        ))
    }

    /// Rebuild this node with `children`, reusing the cached properties.
    /// Callers must guarantee the new children have the same properties.
    fn with_new_children_and_same_properties(
        &self,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Self {
        Self {
            inputs: children,
            metrics: ExecutionPlanMetricsSet::new(),
            ..Self::clone(self)
        }
    }
}
impl DisplayAs for UnionExec {
    /// Render this operator for `EXPLAIN` output; tree rendering adds no
    /// extra detail beyond the node name supplied elsewhere.
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::TreeRender => Ok(()),
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                f.write_str("UnionExec")
            }
        }
    }
}
impl ExecutionPlan for UnionExec {
    fn name(&self) -> &'static str {
        "UnionExec"
    }

    /// Returns `self` for downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &Arc<PlanProperties> {
        &self.cache
    }

    /// In addition to the default invariants, a constructed `UnionExec`
    /// must have at least two children (`try_new` collapses single-child
    /// unions into the child itself).
    fn check_invariants(&self, check: InvariantLevel) -> Result<()> {
        check_default_invariants(self, check)?;
        (self.inputs().len() >= 2).then_some(()).ok_or_else(|| {
            internal_datafusion_err!("UnionExec should have at least 2 children")
        })
    }

    /// A child maintains the union's output ordering iff its own ordering
    /// has the same length as the union's common ordering.
    fn maintains_input_order(&self) -> Vec<bool> {
        if let Some(output_ordering) = self.properties().output_ordering() {
            self.inputs()
                .iter()
                .map(|child| {
                    if let Some(child_ordering) = child.output_ordering() {
                        output_ordering.len() == child_ordering.len()
                    } else {
                        false
                    }
                })
                .collect()
        } else {
            vec![false; self.inputs().len()]
        }
    }

    /// Union only concatenates partitions; repartitioning children does
    /// not help it.
    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![false; self.children().len()]
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        self.inputs.iter().collect()
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        check_if_same_properties!(self, children);
        UnionExec::try_new(children)
    }

    /// Execute one output partition by delegating to the child that owns
    /// it: output partition `i` is the `i`-th partition counting across
    /// children in order.
    fn execute(
        &self,
        mut partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        trace!(
            "Start UnionExec::execute for partition {} of context session_id {} and task_id {:?}",
            partition,
            context.session_id(),
            context.task_id()
        );
        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
        // This timer covers only locating/starting the child stream; the
        // per-batch work is accounted for by ObservedStream below.
        let elapsed_compute = baseline_metrics.elapsed_compute().clone();
        let _timer = elapsed_compute.timer();
        for input in self.inputs.iter() {
            if partition < input.output_partitioning().partition_count() {
                let stream = input.execute(partition, context)?;
                debug!("Found a Union partition to execute");
                return Ok(Box::pin(ObservedStream::new(
                    stream,
                    baseline_metrics,
                    None,
                )));
            } else {
                // Skip past this child's partitions.
                partition -= input.output_partitioning().partition_count();
            }
        }
        // NOTE: `partition` has been decremented above, so the reported
        // index is the residual count, not the caller's original value.
        warn!("Error in Union: Partition {partition} not found");
        exec_err!("Partition {partition} not found in Union")
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    /// Statistics for one output partition delegate to the owning child;
    /// whole-plan statistics are all children's statistics combined.
    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
        if let Some(partition_idx) = partition {
            let mut remaining_idx = partition_idx;
            for input in &self.inputs {
                let input_partition_count = input.output_partitioning().partition_count();
                if remaining_idx < input_partition_count {
                    return input.partition_statistics(Some(remaining_idx));
                }
                remaining_idx -= input_partition_count;
            }
            // Out-of-range partition index: report unknown statistics
            // rather than erroring.
            Ok(Statistics::new_unknown(&self.schema()))
        } else {
            let stats = self
                .inputs
                .iter()
                .map(|input_exec| input_exec.partition_statistics(None))
                .collect::<Result<Vec<_>>>()?;
            Ok(stats
                .into_iter()
                .reduce(stats_union)
                .unwrap_or_else(|| Statistics::new_unknown(&self.schema())))
        }
    }

    /// A limit can be pushed to each child independently.
    fn supports_limit_pushdown(&self) -> bool {
        true
    }

    /// Push a narrowing projection below the union by applying it to every
    /// child; projections that do not reduce the column count stay put.
    fn try_swapping_with_projection(
        &self,
        projection: &ProjectionExec,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        if projection.expr().len() >= projection.input().schema().fields().len() {
            return Ok(None);
        }
        let new_children = self
            .children()
            .into_iter()
            .map(|child| make_with_child(projection, child))
            .collect::<Result<Vec<_>>>()?;
        // Move `new_children` directly into try_new; cloning it here would
        // be redundant since it is not used again.
        Ok(Some(UnionExec::try_new(new_children)?))
    }

    /// Parent filters can be offered to every child unchanged.
    fn gather_filters_for_pushdown(
        &self,
        _phase: FilterPushdownPhase,
        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
        _config: &ConfigOptions,
    ) -> Result<FilterDescription> {
        FilterDescription::from_children(parent_filters, &self.children())
    }

    /// After children report pushdown support, wrap each child that
    /// rejected some filters in a `FilterExec`, so the union as a whole can
    /// report every parent filter as fully pushed down.
    fn handle_child_pushdown_result(
        &self,
        phase: FilterPushdownPhase,
        child_pushdown_result: ChildPushdownResult,
        _config: &ConfigOptions,
    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
        if phase != FilterPushdownPhase::Pre {
            return Ok(FilterPushdownPropagation::if_all(child_pushdown_result));
        }
        // Collect, per child, the parent filters it could not absorb.
        let mut unsupported_filters_per_child = vec![Vec::new(); self.inputs.len()];
        for parent_filter_result in child_pushdown_result.parent_filters.iter() {
            for (child_idx, &child_result) in
                parent_filter_result.child_results.iter().enumerate()
            {
                if matches!(child_result, PushedDown::No) {
                    unsupported_filters_per_child[child_idx]
                        .push(Arc::clone(&parent_filter_result.filter));
                }
            }
        }
        // Re-apply each child's rejected filters directly above that child.
        let mut new_children = self.inputs.clone();
        for (child_idx, unsupported_filters) in
            unsupported_filters_per_child.iter().enumerate()
        {
            if !unsupported_filters.is_empty() {
                let combined_filter = conjunction(unsupported_filters.clone());
                new_children[child_idx] = Arc::new(FilterExec::try_new(
                    combined_filter,
                    Arc::clone(&self.inputs[child_idx]),
                )?);
            }
        }
        let children_modified = new_children
            .iter()
            .zip(self.inputs.iter())
            .any(|(new, old)| !Arc::ptr_eq(new, old));
        // Every filter is now handled — either by the child itself or by
        // the FilterExec inserted above it.
        let all_filters_pushed =
            vec![PushedDown::Yes; child_pushdown_result.parent_filters.len()];
        let propagation = if children_modified {
            let updated_node = UnionExec::try_new(new_children)?;
            FilterPushdownPropagation::with_parent_pushdown_result(all_filters_pushed)
                .with_updated_node(updated_node)
        } else {
            FilterPushdownPropagation::with_parent_pushdown_result(all_filters_pushed)
        };
        Ok(propagation)
    }
}
/// `InterleaveExec`: combines identically hash-partitioned inputs without
/// increasing the partition count.
///
/// Unlike `UnionExec`, output partition `i` merges partition `i` of every
/// child (all children must share the same hash partitioning; see
/// [`can_interleave`]).
#[derive(Debug, Clone)]
pub struct InterleaveExec {
    /// Child execution plans, all with identical hash partitioning.
    inputs: Vec<Arc<dyn ExecutionPlan>>,
    /// Execution metrics, one baseline set per output partition.
    metrics: ExecutionPlanMetricsSet,
    /// Cached plan properties.
    cache: Arc<PlanProperties>,
}
impl InterleaveExec {
    /// Create a new `InterleaveExec` from identically hash-partitioned
    /// inputs; errors if the children cannot be interleaved.
    pub fn try_new(inputs: Vec<Arc<dyn ExecutionPlan>>) -> Result<Self> {
        assert_or_internal_err!(
            can_interleave(inputs.iter()),
            "Not all InterleaveExec children have a consistent hash partitioning"
        );
        let cache = Arc::new(Self::compute_properties(&inputs)?);
        Ok(InterleaveExec {
            inputs,
            metrics: ExecutionPlanMetricsSet::new(),
            cache,
        })
    }

    /// The child execution plans being interleaved.
    pub fn inputs(&self) -> &Vec<Arc<dyn ExecutionPlan>> {
        &self.inputs
    }

    /// Compute plan properties for the interleave. Rows move between
    /// input streams, so no child orderings survive; `can_interleave`
    /// guarantees all children share the first child's partitioning.
    fn compute_properties(inputs: &[Arc<dyn ExecutionPlan>]) -> Result<PlanProperties> {
        Ok(PlanProperties::new(
            EquivalenceProperties::new(union_schema(inputs)?),
            inputs[0].output_partitioning().clone(),
            emission_type_from_children(inputs),
            boundedness_from_children(inputs),
        ))
    }

    /// Rebuild this node with `children`, keeping the cached properties.
    /// Callers must guarantee the new children have the same properties.
    fn with_new_children_and_same_properties(
        &self,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Self {
        Self {
            inputs: children,
            metrics: ExecutionPlanMetricsSet::new(),
            cache: Arc::clone(&self.cache),
        }
    }
}
impl DisplayAs for InterleaveExec {
    /// Render this operator for `EXPLAIN` output; tree rendering adds no
    /// extra detail beyond the node name supplied elsewhere.
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::TreeRender => Ok(()),
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                f.write_str("InterleaveExec")
            }
        }
    }
}
impl ExecutionPlan for InterleaveExec {
    fn name(&self) -> &'static str {
        "InterleaveExec"
    }

    /// Returns `self` for downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &Arc<PlanProperties> {
        &self.cache
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        self.inputs.iter().collect()
    }

    /// Interleaving merges rows from all children in arbitrary order, so
    /// no input ordering is preserved.
    fn maintains_input_order(&self) -> Vec<bool> {
        vec![false; self.inputs().len()]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // Re-validate: new children must still share a hash partitioning.
        assert_or_internal_err!(
            can_interleave(children.iter()),
            "Can not create InterleaveExec: new children can not be interleaved"
        );
        check_if_same_properties!(self, children);
        Ok(Arc::new(InterleaveExec::try_new(children)?))
    }

    /// Execute output partition `partition` by merging partition
    /// `partition` of every child into a single combined stream.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        trace!(
            "Start InterleaveExec::execute for partition {} of context session_id {} and task_id {:?}",
            partition,
            context.session_id(),
            context.task_id()
        );
        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
        // Timer covers only stream setup; per-batch work is tracked by
        // ObservedStream below.
        let elapsed_compute = baseline_metrics.elapsed_compute().clone();
        let _timer = elapsed_compute.timer();
        let mut input_stream_vec = vec![];
        for input in self.inputs.iter() {
            if partition < input.output_partitioning().partition_count() {
                input_stream_vec.push(input.execute(partition, Arc::clone(&context))?);
            } else {
                // One child lacking this partition means the requested
                // index is out of range for the interleave; stop early.
                break;
            }
        }
        if input_stream_vec.len() == self.inputs.len() {
            // All children provided the partition: merge their streams,
            // pulling batches from whichever child is ready.
            let stream = Box::pin(CombinedRecordBatchStream::new(
                self.schema(),
                input_stream_vec,
            ));
            return Ok(Box::pin(ObservedStream::new(
                stream,
                baseline_metrics,
                None,
            )));
        }
        warn!("Error in InterleaveExec: Partition {partition} not found");
        exec_err!("Partition {partition} not found in InterleaveExec")
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    /// Combine the children's statistics for the given partition (or for
    /// the whole plan when `partition` is `None`).
    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
        let stats = self
            .inputs
            .iter()
            .map(|stat| stat.partition_statistics(partition))
            .collect::<Result<Vec<_>>>()?;
        Ok(stats
            .into_iter()
            .reduce(stats_union)
            .unwrap_or_else(|| Statistics::new_unknown(&self.schema())))
    }

    /// Children already share the required hash partitioning; further
    /// repartitioning does not help.
    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![false; self.children().len()]
    }
}
/// Returns `true` when all `inputs` share one identical `Partitioning::Hash`,
/// i.e. they can be combined by [`InterleaveExec`]. An empty iterator and
/// any non-hash partitioning both yield `false`.
pub fn can_interleave<T: Borrow<Arc<dyn ExecutionPlan>>>(
    mut inputs: impl Iterator<Item = T>,
) -> bool {
    // No inputs: nothing to interleave.
    let first = match inputs.next() {
        Some(plan) => plan,
        None => return false,
    };
    // The first child fixes the reference partitioning, which must be hash.
    let reference = first.borrow().output_partitioning();
    if !matches!(reference, Partitioning::Hash(_, _)) {
        return false;
    }
    // Every remaining child must match the reference exactly.
    inputs.all(|plan| plan.borrow().output_partitioning() == reference)
}
/// Build the output schema for a union/interleave of `inputs`.
///
/// All inputs must have the same number of fields (types are reconciled by
/// `calculate_union` elsewhere — not checked here). Output field `i` takes
/// its name from the first input, is nullable if any input's field `i` is
/// nullable, and carries field metadata merged from all inputs. Schema-level
/// metadata is merged across all inputs as well; on duplicate keys, later
/// entries in iteration order overwrite earlier ones.
fn union_schema(inputs: &[Arc<dyn ExecutionPlan>]) -> Result<SchemaRef> {
    if inputs.is_empty() {
        return exec_err!("Cannot create union schema from empty inputs");
    }
    let first_schema = inputs[0].schema();
    let first_field_count = first_schema.fields().len();
    // Validate field-count agreement up front to produce a clear error.
    for (idx, input) in inputs.iter().enumerate().skip(1) {
        let field_count = input.schema().fields().len();
        if field_count != first_field_count {
            return exec_err!(
                "UnionExec/InterleaveExec requires all inputs to have the same number of fields. \
Input 0 has {first_field_count} fields, but input {idx} has {field_count} fields"
            );
        }
    }
    let fields = (0..first_field_count)
        .map(|i| {
            // Keep the first input's field name for the output column.
            let base_field = first_schema.field(i).clone();
            inputs
                .iter()
                .enumerate()
                .map(|(input_idx, input)| {
                    let field = input.schema().field(i).clone();
                    // Merge metadata from every *other* input on top of
                    // this field's own metadata.
                    let mut metadata = field.metadata().clone();
                    let other_metadatas = inputs
                        .iter()
                        .enumerate()
                        .filter(|(other_idx, _)| *other_idx != input_idx)
                        .flat_map(|(_, other_input)| {
                            other_input.schema().field(i).metadata().clone().into_iter()
                        });
                    metadata.extend(other_metadatas);
                    field.with_metadata(metadata)
                })
                // Prefer the first nullable candidate so the output field is
                // nullable whenever any input's field is; otherwise fall
                // back to the first input's field.
                .find_or_first(Field::is_nullable)
                // Safe: `inputs` is non-empty, so the iterator yields at
                // least one field.
                .unwrap()
                .with_name(base_field.name())
        })
        .collect::<Vec<_>>();
    // Merge schema-level metadata from all inputs (later inputs win on
    // duplicate keys).
    let all_metadata_merged = inputs
        .iter()
        .flat_map(|i| i.schema().metadata().clone().into_iter())
        .collect();
    Ok(Arc::new(Schema::new_with_metadata(
        fields,
        all_metadata_merged,
    )))
}
/// A stream that yields record batches from a set of input streams,
/// polling them starting at a random index each wakeup so that no single
/// input is systematically favored.
struct CombinedRecordBatchStream {
    /// Schema reported by the combined stream (shared by all entries).
    schema: SchemaRef,
    /// Remaining live input streams; exhausted entries are removed.
    entries: Vec<SendableRecordBatchStream>,
}
impl CombinedRecordBatchStream {
pub fn new(schema: SchemaRef, entries: Vec<SendableRecordBatchStream>) -> Self {
Self { schema, entries }
}
}
impl RecordBatchStream for CombinedRecordBatchStream {
    /// Schema of every batch this stream yields.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
impl Stream for CombinedRecordBatchStream {
    type Item = Result<RecordBatch>;

    /// Poll each entry at most once per wakeup, starting at a random index
    /// for fairness. Returns the first ready item; removes entries that
    /// finish; returns `Pending` if any entry is pending, `Ready(None)`
    /// once all entries are exhausted.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        use Poll::*;
        // NOTE(review): assumes `entries` is non-empty on the first poll
        // (`thread_rng_n(0)` would be a zero-modulus); constructors in this
        // file only build it with at least one stream — confirm for new
        // callers.
        let start = thread_rng_n(self.entries.len() as u32) as usize;
        let mut idx = start;
        // The iteration bound is captured before any removal; since each
        // removal consumes one iteration, `idx` stays in bounds below.
        for _ in 0..self.entries.len() {
            let stream = self.entries.get_mut(idx).unwrap();
            match Pin::new(stream).poll_next(cx) {
                Ready(Some(val)) => return Ready(Some(val)),
                Ready(None) => {
                    // Entry exhausted: swap_remove moves the *last* entry
                    // into slot `idx`, so adjust `idx` to keep visiting
                    // not-yet-polled entries exactly once.
                    self.entries.swap_remove(idx);
                    // If the last entry was removed, wrap to the front.
                    if idx == self.entries.len() {
                        idx = 0;
                    } else if idx < start && start <= self.entries.len() {
                        // The element moved into `idx` was already polled
                        // this round; advance past it.
                        idx = idx.wrapping_add(1) % self.entries.len();
                    }
                }
                Pending => {
                    idx = idx.wrapping_add(1) % self.entries.len();
                }
            }
        }
        // Every entry was visited: either all finished (end of stream) or
        // at least one returned Pending and registered this task's waker.
        if self.entries.is_empty() {
            Ready(None)
        } else {
            Pending
        }
    }
}
/// Merge per-column statistics of two union inputs into statistics for the
/// combined column: counts/sums accumulate, min/max widen, and the distinct
/// count becomes unknown.
fn col_stats_union(
    mut left: ColumnStatistics,
    right: &ColumnStatistics,
) -> ColumnStatistics {
    // Distinct values of the two sides may overlap, so the union's
    // distinct count cannot be derived from the inputs.
    left.distinct_count = Precision::Absent;
    // Null counts and sums accumulate across inputs.
    left.null_count = left.null_count.add(&right.null_count);
    left.sum_value = left.sum_value.add(&right.sum_value);
    // Min/max widen to cover both sides.
    left.min_value = left.min_value.min(&right.min_value);
    left.max_value = left.max_value.max(&right.max_value);
    left
}
/// Merge whole-plan statistics of two union inputs: row counts and byte
/// sizes accumulate, and column statistics merge pairwise via
/// [`col_stats_union`].
fn stats_union(mut left: Statistics, right: Statistics) -> Statistics {
    left.num_rows = left.num_rows.add(&right.num_rows);
    left.total_byte_size = left.total_byte_size.add(&right.total_byte_size);
    // Columns are positionally aligned, so merge them index by index.
    left.column_statistics = left
        .column_statistics
        .into_iter()
        .zip(right.column_statistics.iter())
        .map(|(a, b)| col_stats_union(a, b))
        .collect();
    left
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::collect;
    use crate::test::{self, TestMemoryExec};
    use arrow::compute::SortOptions;
    use arrow::datatypes::DataType;
    use datafusion_common::ScalarValue;
    use datafusion_physical_expr::equivalence::convert_to_orderings;
    use datafusion_physical_expr::expressions::col;

    /// Schema with seven nullable Int32 columns a..g.
    fn create_test_schema() -> Result<SchemaRef> {
        let a = Field::new("a", DataType::Int32, true);
        let b = Field::new("b", DataType::Int32, true);
        let c = Field::new("c", DataType::Int32, true);
        let d = Field::new("d", DataType::Int32, true);
        let e = Field::new("e", DataType::Int32, true);
        let f = Field::new("f", DataType::Int32, true);
        let g = Field::new("g", DataType::Int32, true);
        let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f, g]));
        Ok(schema)
    }

    /// Schema with six nullable Int32 columns a..f (one fewer than
    /// `create_test_schema`, for field-count mismatch tests).
    fn create_test_schema2() -> Result<SchemaRef> {
        let a = Field::new("a", DataType::Int32, true);
        let b = Field::new("b", DataType::Int32, true);
        let c = Field::new("c", DataType::Int32, true);
        let d = Field::new("d", DataType::Int32, true);
        let e = Field::new("e", DataType::Int32, true);
        let f = Field::new("f", DataType::Int32, true);
        let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f]));
        Ok(schema)
    }

    /// A union of 4- and 5-partition inputs exposes 9 partitions and
    /// yields one batch per partition.
    #[tokio::test]
    async fn test_union_partitions() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let csv = test::scan_partitioned(4);
        let csv2 = test::scan_partitioned(5);
        let union_exec: Arc<dyn ExecutionPlan> = UnionExec::try_new(vec![csv, csv2])?;
        assert_eq!(
            union_exec
                .properties()
                .output_partitioning()
                .partition_count(),
            9
        );
        let result: Vec<RecordBatch> = collect(union_exec, task_ctx).await?;
        assert_eq!(result.len(), 9);
        Ok(())
    }

    /// `stats_union` adds rows/bytes/nulls/sums, widens min/max, and drops
    /// distinct counts; Absent on either side propagates per Precision
    /// arithmetic.
    #[tokio::test]
    async fn test_stats_union() {
        let left = Statistics {
            num_rows: Precision::Exact(5),
            total_byte_size: Precision::Exact(23),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
                    null_count: Precision::Exact(0),
                    byte_size: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Exact(3),
                    byte_size: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Exact(ScalarValue::Float32(Some(1.1))),
                    min_value: Precision::Exact(ScalarValue::Float32(Some(0.1))),
                    sum_value: Precision::Exact(ScalarValue::Float32(Some(42.0))),
                    null_count: Precision::Absent,
                    byte_size: Precision::Absent,
                },
            ],
        };
        let right = Statistics {
            num_rows: Precision::Exact(7),
            total_byte_size: Precision::Exact(29),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(3),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(34))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(1))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
                    null_count: Precision::Exact(1),
                    byte_size: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Exact(ScalarValue::from("c")),
                    min_value: Precision::Exact(ScalarValue::from("b")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Absent,
                    byte_size: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Absent,
                    min_value: Precision::Absent,
                    sum_value: Precision::Absent,
                    null_count: Precision::Absent,
                    byte_size: Precision::Absent,
                },
            ],
        };
        let result = stats_union(left, right);
        let expected = Statistics {
            num_rows: Precision::Exact(12),
            total_byte_size: Precision::Exact(52),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Exact(ScalarValue::Int64(Some(34))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(84))),
                    null_count: Precision::Exact(1),
                    byte_size: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Absent,
                    byte_size: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Absent,
                    min_value: Precision::Absent,
                    sum_value: Precision::Absent,
                    null_count: Precision::Absent,
                    byte_size: Precision::Absent,
                },
            ],
        };
        assert_eq!(result, expected);
    }

    /// The union's equivalence properties keep only orderings compatible
    /// with every child. Each test case is (orderings of child 1,
    /// orderings of child 2, expected union orderings).
    #[tokio::test]
    async fn test_union_equivalence_properties() -> Result<()> {
        let schema = create_test_schema()?;
        let col_a = &col("a", &schema)?;
        let col_b = &col("b", &schema)?;
        let col_c = &col("c", &schema)?;
        let col_d = &col("d", &schema)?;
        let col_e = &col("e", &schema)?;
        let col_f = &col("f", &schema)?;
        let options = SortOptions::default();
        let test_cases = [
            // Child 2's second ordering equals child 1's only ordering,
            // so that ordering survives the union.
            (
                vec![
                    vec![(col_a, options), (col_b, options), (col_f, options)],
                ],
                vec![
                    vec![(col_a, options), (col_b, options), (col_c, options)],
                    vec![(col_a, options), (col_b, options), (col_f, options)],
                ],
                vec![
                    vec![(col_a, options), (col_b, options), (col_f, options)],
                ],
            ),
            // Only the common prefix [a, b] is valid for both children.
            (
                vec![
                    vec![(col_a, options), (col_b, options), (col_f, options)],
                    vec![(col_d, options)],
                ],
                vec![
                    vec![(col_a, options), (col_b, options), (col_c, options)],
                    vec![(col_e, options)],
                ],
                vec![
                    vec![(col_a, options), (col_b, options)],
                ],
            ),
        ];
        for (
            test_idx,
            (first_child_orderings, second_child_orderings, union_orderings),
        ) in test_cases.iter().enumerate()
        {
            let first_orderings = convert_to_orderings(first_child_orderings);
            let second_orderings = convert_to_orderings(second_child_orderings);
            let union_expected_orderings = convert_to_orderings(union_orderings);
            let child1_exec = TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?
                .try_with_sort_information(first_orderings)?;
            let child1 = Arc::new(child1_exec);
            let child1 = Arc::new(TestMemoryExec::update_cache(&child1));
            let child2_exec = TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?
                .try_with_sort_information(second_orderings)?;
            let child2 = Arc::new(child2_exec);
            let child2 = Arc::new(TestMemoryExec::update_cache(&child2));
            let mut union_expected_eq = EquivalenceProperties::new(Arc::clone(&schema));
            union_expected_eq.add_orderings(union_expected_orderings);
            let union: Arc<dyn ExecutionPlan> = UnionExec::try_new(vec![child1, child2])?;
            let union_eq_properties = union.properties().equivalence_properties();
            let err_msg = format!(
                "Error in test id: {:?}, test case: {:?}",
                test_idx, test_cases[test_idx]
            );
            assert_eq_properties_same(union_eq_properties, &union_expected_eq, err_msg);
        }
        Ok(())
    }

    /// Assert the two equivalence property sets contain the same ordering
    /// equivalence classes (set equality via size + containment).
    fn assert_eq_properties_same(
        lhs: &EquivalenceProperties,
        rhs: &EquivalenceProperties,
        err_msg: String,
    ) {
        let lhs_orderings = lhs.oeq_class();
        let rhs_orderings = rhs.oeq_class();
        assert_eq!(lhs_orderings.len(), rhs_orderings.len(), "{err_msg}");
        for rhs_ordering in rhs_orderings.iter() {
            assert!(lhs_orderings.contains(rhs_ordering), "{}", err_msg);
        }
    }

    /// `try_new` rejects an empty input list.
    #[test]
    fn test_union_empty_inputs() {
        let result = UnionExec::try_new(vec![]);
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("UnionExec requires at least one input")
        );
    }

    /// `union_schema` rejects an empty input list.
    #[test]
    fn test_union_schema_empty_inputs() {
        let result = union_schema(&[]);
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Cannot create union schema from empty inputs")
        );
    }

    /// A single-input union collapses to the input itself (same Arc).
    #[test]
    fn test_union_single_input() -> Result<()> {
        let schema = create_test_schema()?;
        let memory_exec: Arc<dyn ExecutionPlan> =
            Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?);
        let memory_exec_clone = Arc::clone(&memory_exec);
        let result = UnionExec::try_new(vec![memory_exec])?;
        assert_eq!(result.schema(), schema);
        assert!(Arc::ptr_eq(&result, &memory_exec_clone));
        Ok(())
    }

    /// Two same-schema inputs produce a real UnionExec with that schema.
    #[test]
    fn test_union_schema_multiple_inputs() -> Result<()> {
        let schema = create_test_schema()?;
        let memory_exec1 =
            Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?);
        let memory_exec2 =
            Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?);
        let union_plan = UnionExec::try_new(vec![memory_exec1, memory_exec2])?;
        let union = union_plan
            .as_any()
            .downcast_ref::<UnionExec>()
            .expect("Expected UnionExec");
        assert_eq!(union.schema(), schema);
        assert_eq!(union.inputs().len(), 2);
        Ok(())
    }

    /// Inputs with different field counts are rejected with a clear error.
    #[test]
    fn test_union_schema_mismatch() {
        let schema = create_test_schema().unwrap();
        let schema2 = create_test_schema2().unwrap();
        let memory_exec1 =
            Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema), None).unwrap());
        let memory_exec2 =
            Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema2), None).unwrap());
        let result = UnionExec::try_new(vec![memory_exec1, memory_exec2]);
        assert!(result.is_err());
        assert!(
            result.unwrap_err().to_string().contains(
                "UnionExec/InterleaveExec requires all inputs to have the same number of fields"
            )
        );
    }
}