// kapot_executor/collect.rs
use std::sync::Arc;
use std::task::{Context, Poll};
use std::{any::Any, pin::Pin};
use datafusion::arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
use datafusion::error::DataFusionError;
use datafusion::execution::context::TaskContext;
use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
use datafusion::physical_plan::{
DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
SendableRecordBatchStream, Statistics,
};
use datafusion::{error::Result, physical_plan::RecordBatchStream};
use futures::stream::SelectAll;
use futures::Stream;
#[derive(Debug, Clone)]
pub struct CollectExec {
plan: Arc<dyn ExecutionPlan>,
properties: PlanProperties,
}
impl CollectExec {
pub fn new(plan: Arc<dyn ExecutionPlan>) -> Self {
let properties = PlanProperties::new(
datafusion::physical_expr::EquivalenceProperties::new(plan.schema()),
Partitioning::UnknownPartitioning(1),
EmissionType::Both,
Boundedness::Bounded,
);
Self { plan, properties }
}
}
impl DisplayAs for CollectExec {
    /// Renders the node as the bare label `CollectExec`.
    ///
    /// The default and verbose formats produce identical output, because the
    /// node has no parameters worth printing.
    fn fmt_as(
        &self,
        format_type: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        let label = "CollectExec";
        match format_type {
            DisplayFormatType::Default | DisplayFormatType::Verbose => f.write_str(label),
        }
    }
}
impl ExecutionPlan for CollectExec {
fn name(&self) -> &str {
"CollectExec"
}
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
self.plan.schema()
}
fn properties(&self) -> &PlanProperties {
&self.properties
}
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![&self.plan]
}
fn with_new_children(
self: Arc<Self>,
_children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
unimplemented!()
}
fn execute(
&self,
partition: usize,
context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
assert_eq!(0, partition);
let num_partitions = self
.plan
.properties()
.output_partitioning()
.partition_count();
let streams = (0..num_partitions)
.map(|i| self.plan.execute(i, context.clone()))
.collect::<Result<Vec<_>>>()
.map_err(|e| DataFusionError::Execution(format!("kapotError: {e:?}")))?;
Ok(Box::pin(MergedRecordBatchStream {
schema: self.schema(),
select_all: Box::pin(futures::stream::select_all(streams)),
}))
}
fn statistics(&self) -> Result<Statistics> {
self.plan.statistics()
}
}
struct MergedRecordBatchStream {
schema: SchemaRef,
select_all: Pin<Box<SelectAll<SendableRecordBatchStream>>>,
}
impl Stream for MergedRecordBatchStream {
type Item = Result<RecordBatch>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
self.select_all.as_mut().poll_next(cx)
}
}
impl RecordBatchStream for MergedRecordBatchStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}