use std::{any::Any, sync::Arc};
use arrow::datatypes::SchemaRef;
use super::{ExecutionPlan, Partitioning, SendableRecordBatchStream};
use crate::error::Result;
use async_trait::async_trait;
/// Physical plan node that concatenates the partitions of several child plans.
///
/// The union exposes every partition of its first input, then every partition of
/// its second input, and so on, in the order the inputs were supplied.
#[derive(Debug)]
pub struct UnionExec {
    /// Child plans whose partitions are exposed back-to-back, in this order.
    inputs: Vec<Arc<dyn ExecutionPlan>>,
}

impl UnionExec {
    /// Builds a union over `inputs`, keeping them in the given order.
    pub fn new(inputs: Vec<Arc<dyn ExecutionPlan>>) -> Self {
        Self { inputs }
    }
}
#[async_trait]
impl ExecutionPlan for UnionExec {
    /// Returns `self` as `Any` so callers can downcast to `UnionExec`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Schema reported by the union: the schema of the first input.
    ///
    /// Only the first input is consulted; the inputs are not checked for schema
    /// compatibility here.
    ///
    /// # Panics
    ///
    /// Panics if the union was constructed with no inputs.
    fn schema(&self) -> SchemaRef {
        self.inputs
            .first()
            .expect("UnionExec requires at least one input to determine its schema")
            .schema()
    }

    /// The union's children are exactly its inputs, in order.
    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
        self.inputs.clone()
    }

    /// The union has as many partitions as all of its inputs combined; no
    /// partitioning scheme is preserved across the concatenation.
    fn output_partitioning(&self) -> Partitioning {
        let num_partitions: usize = self
            .inputs
            .iter()
            .map(|plan| plan.output_partitioning().partition_count())
            .sum();
        Partitioning::UnknownPartitioning(num_partitions)
    }

    /// Builds a new union over `children`.
    fn with_new_children(
        &self,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        Ok(Arc::new(UnionExec::new(children)))
    }

    /// Executes the union's `partition`-th partition.
    ///
    /// The global partition index is mapped to an input by walking the inputs in
    /// order and subtracting each input's partition count until the index falls
    /// inside one of them. That input's corresponding local partition is then
    /// executed.
    ///
    /// # Errors
    ///
    /// Returns `DataFusionError::Execution` if `partition` is not smaller than the
    /// total number of partitions across all inputs. Errors from the selected
    /// input's `execute` are propagated unchanged.
    async fn execute(&self, partition: usize) -> Result<SendableRecordBatchStream> {
        // Use a separate counter so the error below can report the index the
        // caller actually requested, not the leftover remainder.
        let mut remaining = partition;
        for input in &self.inputs {
            let count = input.output_partitioning().partition_count();
            if remaining < count {
                return input.execute(remaining).await;
            }
            remaining -= count;
        }
        Err(crate::error::DataFusionError::Execution(format!(
            "Partition {} not found in Union",
            partition
        )))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::physical_plan::{
        collect,
        csv::{CsvExec, CsvReadOptions},
    };
    use crate::test;
    use arrow::record_batch::RecordBatch;

    /// A union over a 4-partition and a 5-partition CSV scan must report 9
    /// partitions and yield 9 record batches when fully collected.
    #[tokio::test]
    async fn test_union_partitions() -> Result<()> {
        let schema = test::aggr_test_schema();

        let path_a = test::create_partitioned_csv("aggregate_test_100.csv", 4)?;
        let path_b = test::create_partitioned_csv("aggregate_test_100.csv", 5)?;

        // Both scans share the same schema and scan arguments; only the path differs.
        let scan = |path: &str| -> Result<CsvExec> {
            CsvExec::try_new(path, CsvReadOptions::new().schema(&schema), None, 1024, None)
        };
        let csv_a = scan(&path_a)?;
        let csv_b = scan(&path_b)?;

        let union_exec = Arc::new(UnionExec::new(vec![Arc::new(csv_a), Arc::new(csv_b)]));

        let expected_partitions = 4 + 5;
        assert_eq!(
            union_exec.output_partitioning().partition_count(),
            expected_partitions
        );

        let batches: Vec<RecordBatch> = collect(union_exec).await?;
        assert_eq!(batches.len(), expected_partitions);
        Ok(())
    }
}