use crate::execution::context::TaskContext;
use crate::physical_plan::aggregates::{
aggregate_expressions, create_accumulators, finalize_aggregation, AccumulatorItem,
AggregateMode,
};
use crate::physical_plan::metrics::{BaselineMetrics, RecordOutput};
use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use datafusion_common::Result;
use datafusion_physical_expr::{AggregateExpr, PhysicalExpr};
use futures::stream::BoxStream;
use std::sync::Arc;
use std::task::{Context, Poll};
use crate::execution::memory_pool::{MemoryConsumer, MemoryReservation};
use futures::stream::{Stream, StreamExt};
/// Record batch stream that aggregates its whole input into one set of
/// accumulators and yields a single item: the finalized result batch, or
/// the first error encountered.
///
/// The aggregation logic itself lives in the boxed `stream`, which is built
/// in [`AggregateStream::new`]; this type only forwards polls to it and
/// reports the output schema.
pub(crate) struct AggregateStream {
/// Boxed inner stream that performs the aggregation when polled.
stream: BoxStream<'static, Result<RecordBatch>>,
/// Output schema, returned by the `RecordBatchStream::schema` implementation.
schema: SchemaRef,
}
/// State object threaded through `futures::stream::unfold` by
/// [`AggregateStream::new`]; it owns everything the aggregation needs
/// between polls.
struct AggregateStreamInner {
/// Schema used to assemble the final output `RecordBatch`.
schema: SchemaRef,
/// Aggregation mode; decides whether accumulators are updated or merged
/// for each batch, and is also passed to `finalize_aggregation`.
mode: AggregateMode,
/// Upstream stream of record batches to aggregate.
input: SendableRecordBatchStream,
/// Metrics used to time compute work and to record the produced output.
baseline_metrics: BaselineMetrics,
/// Input expressions, one list per accumulator (paired positionally with
/// `accumulators`).
aggregate_expressions: Vec<Vec<Arc<dyn PhysicalExpr>>>,
/// One accumulator per aggregate expression.
accumulators: Vec<AccumulatorItem>,
/// Memory reservation grown by the accumulators' reported size increase
/// after each batch.
reservation: MemoryReservation,
/// Set once the single output item has been produced, so the next call of
/// the unfold closure ends the stream.
finished: bool,
}
impl AggregateStream {
pub fn new(
mode: AggregateMode,
schema: SchemaRef,
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
input: SendableRecordBatchStream,
baseline_metrics: BaselineMetrics,
context: Arc<TaskContext>,
partition: usize,
) -> Result<Self> {
let aggregate_expressions = aggregate_expressions(&aggr_expr, &mode, 0)?;
let accumulators = create_accumulators(&aggr_expr)?;
let reservation = MemoryConsumer::new(format!("AggregateStream[{partition}]"))
.register(context.memory_pool());
let inner = AggregateStreamInner {
schema: Arc::clone(&schema),
mode,
input,
baseline_metrics,
aggregate_expressions,
accumulators,
reservation,
finished: false,
};
let stream = futures::stream::unfold(inner, |mut this| async move {
if this.finished {
return None;
}
let elapsed_compute = this.baseline_metrics.elapsed_compute();
loop {
let result = match this.input.next().await {
Some(Ok(batch)) => {
let timer = elapsed_compute.timer();
let result = aggregate_batch(
&this.mode,
&batch,
&mut this.accumulators,
&this.aggregate_expressions,
);
timer.done();
match result
.and_then(|allocated| this.reservation.try_grow(allocated))
{
Ok(_) => continue,
Err(e) => Err(e),
}
}
Some(Err(e)) => Err(e),
None => {
this.finished = true;
let timer = this.baseline_metrics.elapsed_compute().timer();
let result = finalize_aggregation(&this.accumulators, &this.mode)
.and_then(|columns| {
RecordBatch::try_new(this.schema.clone(), columns)
.map_err(Into::into)
})
.record_output(&this.baseline_metrics);
timer.done();
result
}
};
this.finished = true;
return Some((result, this));
}
});
let stream = stream.fuse();
let stream = Box::pin(stream);
Ok(Self { schema, stream })
}
}
impl Stream for AggregateStream {
    type Item = Result<RecordBatch>;

    /// Forwards the poll to the boxed inner stream.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // The inner stream is already pinned on the heap, so it can be
        // re-pinned through a plain mutable reference to `self`.
        self.get_mut().stream.as_mut().poll_next(cx)
    }
}
impl RecordBatchStream for AggregateStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}
/// Feeds one `batch` into every accumulator and reports how many bytes the
/// accumulators grew by.
///
/// Accumulators and expression lists are paired positionally. For each pair
/// the expressions are evaluated against `batch`, then the accumulator is
/// updated (`Partial` mode) or merged (`Final` / `FinalPartitioned` modes).
///
/// Returns the sum of the per-accumulator size increases, as measured by
/// `size()` before and after the call; shrinkage counts as zero.
///
/// # Errors
///
/// Stops at, and returns, the first error from expression evaluation or from
/// the accumulator update/merge.
fn aggregate_batch(
    mode: &AggregateMode,
    batch: &RecordBatch,
    accumulators: &mut [AccumulatorItem],
    expressions: &[Vec<Arc<dyn PhysicalExpr>>],
) -> Result<usize> {
    let num_rows = batch.num_rows();
    let mut grown_bytes = 0usize;

    for (accumulator, exprs) in accumulators.iter_mut().zip(expressions) {
        // Evaluate this accumulator's input expressions into arrays.
        let mut inputs = Vec::with_capacity(exprs.len());
        for expr in exprs {
            inputs.push(expr.evaluate(batch)?.into_array(num_rows));
        }

        let size_before = accumulator.size();
        let outcome = match mode {
            AggregateMode::Partial => accumulator.update_batch(&inputs),
            AggregateMode::Final | AggregateMode::FinalPartitioned => {
                accumulator.merge_batch(&inputs)
            }
        };
        // Size is measured even when the update failed, before the error is
        // propagated.
        let size_after = accumulator.size();
        grown_bytes += size_after.saturating_sub(size_before);
        outcome?;
    }

    Ok(grown_bytes)
}