use crate::aggregates::{
AccumulatorItem, AggrDynFilter, AggregateInputMode, AggregateMode,
DynamicFilterAggregateType, aggregate_expressions, create_accumulators,
finalize_aggregation,
};
use crate::metrics::{BaselineMetrics, RecordOutput};
use crate::{RecordBatchStream, SendableRecordBatchStream};
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use datafusion_common::{Result, ScalarValue, internal_datafusion_err, internal_err};
use datafusion_execution::TaskContext;
use datafusion_expr::Operator;
use datafusion_physical_expr::PhysicalExpr;
use datafusion_physical_expr::expressions::{BinaryExpr, lit};
use futures::stream::BoxStream;
use std::borrow::Cow;
use std::cmp::Ordering;
use std::sync::Arc;
use std::task::{Context, Poll};
use super::AggregateExec;
use crate::filter::batch_filter;
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays;
use futures::stream::{Stream, StreamExt};
/// Stream that computes an ungrouped aggregation over its entire input and
/// yields a single finalized [`RecordBatch`] (or an error) before ending.
pub(crate) struct AggregateStream {
    /// The driving stream built in [`AggregateStream::new`] via `unfold`.
    stream: BoxStream<'static, Result<RecordBatch>>,
    /// Output schema reported through [`RecordBatchStream::schema`].
    schema: SchemaRef,
}
/// State threaded through the `unfold` in [`AggregateStream::new`]: the input
/// stream, the accumulators, and the bookkeeping needed to produce the final
/// aggregated batch.
struct AggregateStreamInner {
    /// Output schema of the finalized batch.
    schema: SchemaRef,
    /// Aggregation mode; its `input_mode()` decides update vs. merge.
    mode: AggregateMode,
    /// Input stream being drained.
    input: SendableRecordBatchStream,
    /// Argument expressions, one `Vec` per aggregate.
    aggregate_expressions: Vec<Vec<Arc<dyn PhysicalExpr>>>,
    /// Optional FILTER predicate per aggregate (all `None` for partial input).
    filter_expressions: Arc<[Option<Arc<dyn PhysicalExpr>>]>,
    /// One accumulator per aggregate expression.
    accumulators: Vec<AccumulatorItem>,
    /// Shared dynamic-filter state; `None` when pushdown is disabled.
    agg_dyn_filter_state: Option<Arc<AggrDynFilter>>,
    /// Set once the final batch (or an error) has been emitted.
    finished: bool,
    /// Elapsed-compute / output-row metrics for this partition.
    baseline_metrics: BaselineMetrics,
    /// Memory reservation tracking accumulator growth.
    reservation: MemoryReservation,
}
impl AggregateStreamInner {
    /// Builds a pushdown predicate from the current accumulator bounds.
    ///
    /// For every accumulator registered for dynamic filtering this emits
    /// `col < bound` (MIN) or `col > bound` (MAX) — only rows that could
    /// still tighten a bound are kept — and ORs the per-accumulator
    /// predicates together. Accumulators whose shared bound is still null
    /// are skipped; if no bound is usable the predicate degenerates to
    /// `lit(true)` (filter nothing).
    ///
    /// # Errors
    /// Returns an internal error if called while `agg_dyn_filter_state` is
    /// `None`, or if an accumulator index does not match
    /// `aggregate_expressions`.
    fn build_dynamic_filter_from_accumulator_bounds(
        &self,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        let Some(filter_state) = self.agg_dyn_filter_state.as_ref() else {
            return internal_err!(
                "`build_dynamic_filter_from_accumulator_bounds()` is only called when dynamic filter is enabled"
            );
        };
        let mut predicates: Vec<Arc<dyn PhysicalExpr>> =
            Vec::with_capacity(filter_state.supported_accumulators_info.len());
        for acc_info in &filter_state.supported_accumulators_info {
            // Copy the bound out so the lock is held as briefly as possible;
            // a null bound means no value has been observed yet.
            let bound = {
                let guard = acc_info.shared_bound.lock();
                if (*guard).is_null() {
                    continue;
                }
                guard.clone()
            };
            let agg_exprs = self
                .aggregate_expressions
                .get(acc_info.aggr_index)
                .ok_or_else(|| {
                    internal_datafusion_err!(
                        "Invalid aggregate expression index {} for dynamic filter",
                        acc_info.aggr_index
                    )
                })?;
            // MIN/MAX take a single argument: its expression is the column
            // the predicate is built against.
            let column_expr = agg_exprs.first().ok_or_else(|| {
                internal_datafusion_err!(
                    "Aggregate expression at index {} expected a single argument",
                    acc_info.aggr_index
                )
            })?;
            let literal = lit(bound);
            let predicate: Arc<dyn PhysicalExpr> = match acc_info.aggr_type {
                // A row can only improve MIN if it is strictly below the
                // current bound ...
                DynamicFilterAggregateType::Min => Arc::new(BinaryExpr::new(
                    Arc::clone(column_expr),
                    Operator::Lt,
                    literal,
                )),
                // ... and MAX only if strictly above it.
                DynamicFilterAggregateType::Max => Arc::new(BinaryExpr::new(
                    Arc::clone(column_expr),
                    Operator::Gt,
                    literal,
                )),
            };
            predicates.push(predicate);
        }
        // A row must be kept if ANY accumulator could still be improved,
        // hence OR.
        let combined = predicates.into_iter().reduce(|acc, pred| {
            Arc::new(BinaryExpr::new(acc, Operator::Or, pred)) as Arc<dyn PhysicalExpr>
        });
        Ok(combined.unwrap_or_else(|| lit(true)))
    }

    /// Re-evaluates every dynamic-filter accumulator and, if any shared
    /// bound tightened, publishes a freshly built predicate through the
    /// shared dynamic filter. No-op when dynamic filtering is disabled.
    ///
    /// NOTE(review): assumes `Accumulator::evaluate` is non-destructive for
    /// the MIN/MAX accumulators tracked here — confirm against the
    /// accumulator implementations.
    ///
    /// # Errors
    /// Returns an internal error for an out-of-range accumulator index, and
    /// propagates failures from evaluation, comparison, or the filter update.
    fn maybe_update_dyn_filter(&mut self) -> Result<()> {
        let Some(filter_state) = self.agg_dyn_filter_state.as_ref() else {
            return Ok(());
        };
        let mut bounds_changed = false;
        for acc_info in &filter_state.supported_accumulators_info {
            let acc =
                self.accumulators
                    .get_mut(acc_info.aggr_index)
                    .ok_or_else(|| {
                        internal_datafusion_err!(
                            "Invalid accumulator index {} for dynamic filter",
                            acc_info.aggr_index
                        )
                    })?;
            let current_bound = acc.evaluate()?;
            {
                let mut bound = acc_info.shared_bound.lock();
                // Merge the local value into the shared bound: MAX keeps the
                // larger value, MIN the smaller (nulls yield the other side).
                // Fixed: `&current_bound` had been corrupted into
                // `¤t_bound` by an HTML-entity mangling and did not
                // compile.
                let new_bound = match acc_info.aggr_type {
                    DynamicFilterAggregateType::Max => {
                        scalar_max(&bound, &current_bound)?
                    }
                    DynamicFilterAggregateType::Min => {
                        scalar_min(&bound, &current_bound)?
                    }
                };
                if new_bound != *bound {
                    *bound = new_bound;
                    bounds_changed = true;
                }
            }
        }
        // Rebuild and push the predicate only when something tightened.
        if bounds_changed {
            let predicate = self.build_dynamic_filter_from_accumulator_bounds()?;
            filter_state.filter.update(predicate)?;
        }
        Ok(())
    }
}
fn scalar_min(v1: &ScalarValue, v2: &ScalarValue) -> Result<ScalarValue> {
if let Some(result) = scalar_cmp_null_short_circuit(v1, v2) {
return Ok(result);
}
match v1.partial_cmp(v2) {
Some(Ordering::Less | Ordering::Equal) => Ok(v1.clone()),
Some(Ordering::Greater) => Ok(v2.clone()),
None => datafusion_common::internal_err!(
"cannot compare values of different or incompatible types: {v1:?} vs {v2:?}"
),
}
}
fn scalar_max(v1: &ScalarValue, v2: &ScalarValue) -> Result<ScalarValue> {
if let Some(result) = scalar_cmp_null_short_circuit(v1, v2) {
return Ok(result);
}
match v1.partial_cmp(v2) {
Some(Ordering::Greater | Ordering::Equal) => Ok(v1.clone()),
Some(Ordering::Less) => Ok(v2.clone()),
None => datafusion_common::internal_err!(
"cannot compare values of different or incompatible types: {v1:?} vs {v2:?}"
),
}
}
/// Null short-circuiting shared by `scalar_min` / `scalar_max`: if either
/// side is `Null` the other side wins outright (both null yields `Null`).
/// Returns `None` when both sides are non-null and a real comparison is
/// required.
fn scalar_cmp_null_short_circuit(
    v1: &ScalarValue,
    v2: &ScalarValue,
) -> Option<ScalarValue> {
    if matches!(v1, ScalarValue::Null) {
        // Also covers the both-null case: `v2` is then `Null` as well.
        return Some(v2.clone());
    }
    if matches!(v2, ScalarValue::Null) {
        return Some(v1.clone());
    }
    None
}
/// Prepends a `grouping_id` column (materialized from the given scalar) to
/// `columns`; passes `columns` through untouched when `grouping_id` is
/// `None`.
///
/// # Errors
/// Propagates failures from materializing the scalar into an array.
fn prepend_grouping_id_column(
    mut columns: Vec<Arc<dyn arrow::array::Array>>,
    grouping_id: Option<&ScalarValue>,
) -> Result<Vec<Arc<dyn arrow::array::Array>>> {
    let Some(id) = grouping_id else {
        return Ok(columns);
    };
    // Match the row count of the existing columns; when there are no
    // aggregate columns at all, emit a single row.
    let num_rows = match columns.first() {
        Some(first) => first.len(),
        None => 1,
    };
    columns.insert(0, id.to_array_of_size(num_rows)?);
    Ok(columns)
}
impl AggregateStream {
    /// Creates a new `AggregateStream` for `partition` of `agg`'s input.
    ///
    /// Sets up per-aggregate expressions and FILTER predicates, the
    /// accumulators, memory accounting, and (optionally) dynamic-filter
    /// state, then drives the whole aggregation with a
    /// `futures::stream::unfold` that drains the input and emits exactly
    /// one finalized batch (or one error) before finishing.
    ///
    /// # Errors
    /// Fails if the input partition cannot be executed or the aggregate
    /// expressions / accumulators cannot be created.
    pub fn new(
        agg: &AggregateExec,
        context: &Arc<TaskContext>,
        partition: usize,
    ) -> Result<Self> {
        let agg_schema = Arc::clone(&agg.schema);
        let agg_filter_expr = Arc::clone(&agg.filter_expr);
        let baseline_metrics = BaselineMetrics::new(&agg.metrics, partition);
        let input = agg.input.execute(partition, Arc::clone(context))?;
        let aggregate_expressions = aggregate_expressions(&agg.aggr_expr, &agg.mode, 0)?;
        // FILTER clauses only apply to raw input rows; partially aggregated
        // input already had them applied upstream, so use all-`None` there.
        let filter_expressions = match agg.mode.input_mode() {
            AggregateInputMode::Raw => agg_filter_expr,
            AggregateInputMode::Partial => vec![None; agg.aggr_expr.len()].into(),
        };
        let accumulators = create_accumulators(&agg.aggr_expr)?;
        let reservation = MemoryConsumer::new(format!("AggregateStream[{partition}]"))
            .register(context.memory_pool());
        let mut maybe_dynamic_filter = match agg.dynamic_filter.as_ref() {
            Some(filter) => Some(Arc::clone(filter)),
            _ => None,
        };
        // The session option can veto dynamic filtering even when the plan
        // carries a filter.
        if !context
            .session_config()
            .options()
            .optimizer
            .enable_aggregate_dynamic_filter_pushdown
        {
            maybe_dynamic_filter = None;
        }
        let inner = AggregateStreamInner {
            schema: Arc::clone(&agg.schema),
            mode: agg.mode,
            input,
            baseline_metrics,
            aggregate_expressions,
            filter_expressions,
            accumulators,
            reservation,
            finished: false,
            agg_dyn_filter_state: maybe_dynamic_filter,
        };
        // Pull input batches in a loop; `finished` guards against the stream
        // being polled again after the final batch (or an error) went out.
        let stream = futures::stream::unfold(inner, |mut this| async move {
            if this.finished {
                return None;
            }
            loop {
                let result = match this.input.next().await {
                    Some(Ok(batch)) => {
                        // Accumulate this batch, timing the compute work.
                        let result = {
                            let elapsed_compute = this.baseline_metrics.elapsed_compute();
                            let _timer = elapsed_compute.timer();
                            aggregate_batch(
                                &this.mode,
                                &batch,
                                &mut this.accumulators,
                                &this.aggregate_expressions,
                                &this.filter_expressions,
                            )
                        };
                        // After a successful update, give the dynamic filter
                        // a chance to tighten its bounds.
                        let result = result.and_then(|allocated| {
                            this.maybe_update_dyn_filter()?;
                            Ok(allocated)
                        });
                        // Account for accumulator growth. On success keep
                        // pulling input; otherwise fall through with the
                        // error so it is yielded below.
                        match result
                            .and_then(|allocated| this.reservation.try_grow(allocated))
                        {
                            Ok(_) => continue,
                            Err(e) => Err(e),
                        }
                    }
                    Some(Err(e)) => Err(e),
                    None => {
                        // Input exhausted: finalize the accumulators into the
                        // single output batch and record output metrics.
                        this.finished = true;
                        let timer = this.baseline_metrics.elapsed_compute().timer();
                        let result =
                            finalize_aggregation(&mut this.accumulators, &this.mode)
                                .and_then(|columns| {
                                    prepend_grouping_id_column(columns, None)
                                })
                                .and_then(|columns| {
                                    RecordBatch::try_new(
                                        Arc::clone(&this.schema),
                                        columns,
                                    )
                                    .map_err(Into::into)
                                })
                                .record_output(&this.baseline_metrics);
                        timer.done();
                        result
                    }
                };
                // Reaching here means either an error or the final batch;
                // in both cases the stream is done after yielding it.
                this.finished = true;
                return Some((result, this));
            }
        });
        let stream = stream.fuse();
        let stream = Box::pin(stream);
        Ok(Self {
            schema: agg_schema,
            stream,
        })
    }
}
impl Stream for AggregateStream {
    type Item = Result<RecordBatch>;

    /// Forwards polling straight to the boxed inner stream.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // `AggregateStream` is `Unpin` (its fields are), so the field can be
        // reached through the pin directly.
        self.stream.poll_next_unpin(cx)
    }
}
impl RecordBatchStream for AggregateStream {
    /// Reports the schema of the batches this stream produces.
    fn schema(&self) -> SchemaRef {
        // `SchemaRef` is an `Arc`, so this clone is a cheap refcount bump.
        self.schema.clone()
    }
}
/// Feeds one input batch into every accumulator and returns the total
/// growth (in bytes) of the accumulators' reported memory footprint.
///
/// Per accumulator: the optional FILTER predicate is applied first, the
/// aggregate's argument expressions are evaluated against the (possibly
/// filtered) batch, and the resulting arrays are fed to `update_batch`
/// (raw input) or `merge_batch` (partially aggregated input).
///
/// # Errors
/// Propagates failures from filtering, expression evaluation, or the
/// accumulator update itself.
fn aggregate_batch(
    mode: &AggregateMode,
    batch: &RecordBatch,
    accumulators: &mut [AccumulatorItem],
    expressions: &[Vec<Arc<dyn PhysicalExpr>>],
    filters: &[Option<Arc<dyn PhysicalExpr>>],
) -> Result<usize> {
    let mut allocated = 0usize;
    for ((accum, exprs), filter) in
        accumulators.iter_mut().zip(expressions).zip(filters)
    {
        // Borrow the batch unless a FILTER clause forces a filtered copy.
        let input = match filter {
            Some(predicate) => Cow::Owned(batch_filter(batch, predicate)?),
            None => Cow::Borrowed(batch),
        };
        let values = evaluate_expressions_to_arrays(exprs, input.as_ref())?;
        let size_before = accum.size();
        match mode.input_mode() {
            AggregateInputMode::Raw => accum.update_batch(&values)?,
            AggregateInputMode::Partial => accum.merge_batch(&values)?,
        }
        // Only count growth; an accumulator may legitimately shrink.
        allocated += accum.size().saturating_sub(size_before);
    }
    Ok(allocated)
}