use std::any::Any;
use std::sync::Arc;
use futures::channel::mpsc;
use futures::sink::SinkExt;
use futures::stream::StreamExt;
use futures::Stream;
use async_trait::async_trait;
use arrow::record_batch::RecordBatch;
use arrow::{
datatypes::SchemaRef,
error::{ArrowError, Result as ArrowResult},
};
use super::RecordBatchStream;
use crate::error::{DataFusionError, Result};
use crate::physical_plan::ExecutionPlan;
use crate::physical_plan::Partitioning;
use super::SendableRecordBatchStream;
use pin_project_lite::pin_project;
/// Physical operator that funnels every partition of its input plan into a
/// single output partition (see `output_partitioning` / `execute` below).
#[derive(Debug)]
pub struct MergeExec {
    /// The plan whose output partitions get merged.
    input: Arc<dyn ExecutionPlan>,
}

impl MergeExec {
    /// Builds a merge operator on top of `input`.
    pub fn new(input: Arc<dyn ExecutionPlan>) -> Self {
        Self { input }
    }

    /// Borrows the wrapped input plan.
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }
}
#[async_trait]
impl ExecutionPlan for MergeExec {
    /// Returns `self` as `Any` so callers can downcast to `MergeExec`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Merging does not alter columns, so the schema is the input's schema.
    fn schema(&self) -> SchemaRef {
        self.input.schema()
    }

    /// The single child is the plan being merged.
    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
        vec![self.input.clone()]
    }

    /// All input partitions are combined into exactly one output partition.
    fn output_partitioning(&self) -> Partitioning {
        Partitioning::UnknownPartitioning(1)
    }

    /// Rebuilds the operator over a new input.
    ///
    /// # Errors
    /// Returns `DataFusionError::Internal` unless exactly one child is given.
    fn with_new_children(
        &self,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        match children.len() {
            1 => Ok(Arc::new(MergeExec::new(children[0].clone()))),
            _ => Err(DataFusionError::Internal(
                "MergeExec wrong number of children".to_string(),
            )),
        }
    }

    /// Executes the only output partition (`partition` must be 0).
    ///
    /// * With one input partition, the input stream is returned unchanged.
    /// * With several input partitions, one tokio task per input partition
    ///   forwards that partition's batches into a bounded channel. The returned
    ///   stream yields batches in whatever order the tasks deliver them, so no
    ///   ordering across partitions is implied.
    ///
    /// # Errors
    /// Returns `DataFusionError::Internal` when `partition != 0` or when the
    /// input reports zero partitions. A failure to start an input partition
    /// is delivered through the stream as `ArrowError::ExternalError`.
    async fn execute(&self, partition: usize) -> Result<SendableRecordBatchStream> {
        if 0 != partition {
            return Err(DataFusionError::Internal(format!(
                "MergeExec invalid partition {}",
                partition
            )));
        }

        let input_partitions = self.input.output_partitioning().partition_count();
        match input_partitions {
            0 => Err(DataFusionError::Internal(
                "MergeExec requires at least one input partition".to_owned(),
            )),
            // Nothing to merge: hand back the input stream directly.
            1 => self.input.execute(0).await,
            _ => {
                let (sender, receiver) =
                    mpsc::channel::<ArrowResult<RecordBatch>>(input_partitions);

                for part_i in 0..input_partitions {
                    let input = self.input.clone();
                    let mut sender = sender.clone();
                    // NOTE(review): the JoinHandle is detached, so a panic inside
                    // this task is not surfaced as an error on the merged stream.
                    tokio::spawn(async move {
                        let mut stream = match input.execute(part_i).await {
                            Err(e) => {
                                let arrow_error = ArrowError::ExternalError(Box::new(e));
                                // A failed send only means the consumer is gone;
                                // there is nobody left to report to.
                                sender.send(Err(arrow_error)).await.ok();
                                return;
                            }
                            Ok(stream) => stream,
                        };

                        while let Some(item) = stream.next().await {
                            // If the receiver was dropped (e.g. the consumer stopped
                            // early), stop pulling from the input instead of
                            // draining the whole partition into the void.
                            if sender.send(item).await.is_err() {
                                break;
                            }
                        }
                    });
                }

                // The original `sender` is dropped when this arm ends, so the
                // receiver only stays open while the per-partition clones live.
                Ok(Box::pin(MergeStream {
                    input: receiver,
                    schema: self.schema(),
                }))
            }
        }
    }
}
pin_project! {
struct MergeStream {
schema: SchemaRef,
#[pin]
input: mpsc::Receiver<ArrowResult<RecordBatch>>,
}
}
impl Stream for MergeStream {
type Item = ArrowResult<RecordBatch>;
fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
let this = self.project();
this.input.poll_next(cx)
}
}
impl RecordBatchStream for MergeStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::physical_plan::common;
    use crate::physical_plan::csv::{CsvExec, CsvReadOptions};
    use crate::test;

    /// Merging a 4-partition CSV scan exposes one output partition that
    /// still carries all 100 rows of the fixture.
    #[tokio::test]
    async fn merge() -> Result<()> {
        let partitions = 4;
        let schema = test::aggr_test_schema();
        let csv_path = test::create_partitioned_csv("aggregate_test_100.csv", partitions)?;

        let options = CsvReadOptions::new().schema(&schema);
        let csv = CsvExec::try_new(&csv_path, options, None, 1024, None)?;
        assert_eq!(csv.output_partitioning().partition_count(), partitions);

        let merge = MergeExec::new(Arc::new(csv));
        // The merge operator always reports a single output partition.
        assert_eq!(merge.output_partitioning().partition_count(), 1);

        let merged = merge.execute(0).await?;
        let batches = common::collect(merged).await?;
        // One batch is expected from each input partition at this batch size.
        assert_eq!(batches.len(), partitions);

        let total_rows = batches.iter().map(RecordBatch::num_rows).sum::<usize>();
        assert_eq!(total_rows, 100);
        Ok(())
    }
}