use std::sync::Arc;
use std::task::{Context, Poll};
use std::vec;
use super::AggregateExec;
use super::order::GroupOrdering;
use crate::aggregates::group_values::{GroupByMetrics, GroupValues, new_group_values};
use crate::aggregates::order::GroupOrderingFull;
use crate::aggregates::{
AggregateInputMode, AggregateMode, AggregateOutputMode, PhysicalGroupBy,
create_schema, evaluate_group_by, evaluate_many, evaluate_optional,
};
use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput};
use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder};
use crate::spill::spill_manager::{GetSlicedSize, SpillManager};
use crate::{PhysicalExpr, aggregates, metrics};
use crate::{RecordBatchStream, SendableRecordBatchStream};
use arrow::array::*;
use arrow::datatypes::SchemaRef;
use datafusion_common::{
DataFusionError, Result, assert_eq_or_internal_err, assert_or_internal_err,
internal_err, resources_datafusion_err,
};
use datafusion_execution::TaskContext;
use datafusion_execution::memory_pool::proxy::VecAllocExt;
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
use datafusion_expr::{EmitTo, GroupsAccumulator};
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::{GroupsAccumulatorAdapter, PhysicalSortExpr};
use datafusion_physical_expr_common::sort_expr::LexOrdering;
use crate::sorts::IncrementalSortIterator;
use datafusion_common::instant::Instant;
use datafusion_common::utils::memory::get_record_batch_memory_size;
use futures::ready;
use futures::stream::{Stream, StreamExt};
use log::debug;
/// State machine for [`GroupedHashAggregateStream`].
#[derive(Debug, Clone)]
pub(crate) enum ExecutionState {
    /// Pulling batches from the input stream and aggregating them.
    ReadingInput,
    /// Holding a materialized output batch, emitted `batch_size` rows at a time.
    ProducingOutput(RecordBatch),
    /// Partial aggregation was abandoned; input batches are converted
    /// straight to intermediate state and passed through.
    SkippingAggregation,
    /// All output produced; the stream returns `None`.
    Done,
}
/// Supporting state for spilling intermediate aggregate data to disk when
/// memory is exhausted (see `OutOfMemoryMode::Spill`).
struct SpillState {
    /// Sort ordering used for spill files and the final streaming merge.
    spill_expr: LexOrdering,
    /// Schema of the intermediate (partial) aggregate state written to spill files.
    spill_schema: SchemaRef,
    /// Aggregate arguments in `AggregateMode::Final` form, used while merging
    /// spilled intermediate state back.
    merging_aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,
    /// GROUP BY rewritten as plain column references over `spill_schema`.
    merging_group_by: PhysicalGroupBy,
    /// Creates and tracks spill files.
    spill_manager: SpillManager,
    /// Spill files written so far (each sorted by `spill_expr`).
    spills: Vec<SortedSpillFile>,
    /// True once the stream switched from the original input to merging
    /// spilled data.
    is_stream_merging: bool,
    /// Peak memory reservation observed, reported as a metric.
    peak_mem_used: metrics::Gauge,
}
/// Tracks whether partial aggregation is reducing cardinality enough to be
/// worth continuing; once enough input rows have been observed, the
/// groups/rows ratio decides whether to skip aggregation entirely.
struct SkipAggregationProbe {
    /// Number of input rows to observe before evaluating the ratio.
    probe_rows_threshold: usize,
    /// Minimum groups/rows ratio at which aggregation is skipped.
    probe_ratio_threshold: f64,
    /// Input rows seen so far.
    input_rows: usize,
    /// Number of distinct groups seen so far.
    num_groups: usize,
    /// Current decision: skip aggregation from now on.
    should_skip: bool,
    /// Once true, the decision is frozen and no longer re-evaluated.
    is_locked: bool,
    /// Metric counting rows that bypassed aggregation.
    skipped_aggregation_rows: metrics::Count,
}
impl SkipAggregationProbe {
    /// Creates a probe that starts unlocked with zero observed rows/groups.
    fn new(
        probe_rows_threshold: usize,
        probe_ratio_threshold: f64,
        skipped_aggregation_rows: metrics::Count,
    ) -> Self {
        Self {
            probe_rows_threshold,
            probe_ratio_threshold,
            input_rows: 0,
            num_groups: 0,
            should_skip: false,
            is_locked: false,
            skipped_aggregation_rows,
        }
    }

    /// Folds in the latest batch statistics and, once enough rows have been
    /// observed, re-evaluates the skip decision. The decision locks as soon
    /// as it becomes positive.
    fn update_state(&mut self, input_rows: usize, num_groups: usize) {
        if self.is_locked {
            return;
        }
        self.input_rows += input_rows;
        self.num_groups = num_groups;
        if self.input_rows < self.probe_rows_threshold {
            return;
        }
        let ratio = self.num_groups as f64 / self.input_rows as f64;
        self.should_skip = ratio >= self.probe_ratio_threshold;
        self.is_locked = self.should_skip;
    }

    /// Current skip decision.
    fn should_skip(&self) -> bool {
        self.should_skip
    }

    /// Records the rows of a batch that bypassed aggregation.
    fn record_skipped(&mut self, batch: &RecordBatch) {
        self.skipped_aggregation_rows.add(batch.num_rows());
    }
}
/// How the stream reacts when the memory reservation cannot grow.
#[derive(PartialEq, Debug)]
enum OutOfMemoryMode {
    /// Sort and spill the current groups to disk, then continue.
    Spill,
    /// Emit accumulated groups early to free memory (partial mode).
    EmitEarly,
    /// Propagate the resources-exhausted error to the caller.
    ReportError,
}
/// Hash-aggregation stream: maps input rows to dense group indices via a
/// hash table of group values and updates one [`GroupsAccumulator`] per
/// aggregate, emitting aggregated batches.
pub(crate) struct GroupedHashAggregateStream {
    /// Output schema.
    schema: SchemaRef,
    /// Input stream; replaced by a streaming merge over spill files once
    /// spilling has occurred.
    input: SendableRecordBatchStream,
    /// Aggregation mode of this node.
    mode: AggregateMode,
    /// Evaluated argument expressions, one `Vec` per aggregate.
    aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,
    /// Optional FILTER expression per aggregate (all `None` for partial input).
    filter_expressions: Arc<[Option<Arc<dyn PhysicalExpr>>]>,
    /// GROUP BY definition.
    group_by: Arc<PhysicalGroupBy>,
    /// Target output batch size.
    batch_size: usize,
    /// Optional soft limit on the number of groups (limited aggregation).
    group_values_soft_limit: Option<usize>,
    /// Current state of the state machine.
    exec_state: ExecutionState,
    /// True when the input stream is exhausted.
    input_done: bool,
    /// Hash table of group values; assigns each group a dense index.
    group_values: Box<dyn GroupValues>,
    /// Scratch buffer of group indices for the current batch.
    current_group_indices: Vec<usize>,
    /// One accumulator per aggregate expression.
    accumulators: Vec<Box<dyn GroupsAccumulator>>,
    /// Tracks which groups are complete when input is (partially) sorted.
    group_ordering: GroupOrdering,
    /// Spill-to-disk support state.
    spill_state: SpillState,
    /// Probe deciding whether partial aggregation should be skipped.
    skip_aggregation_probe: Option<SkipAggregationProbe>,
    /// Memory reservation backing the hash table and accumulators.
    reservation: MemoryReservation,
    /// Reaction to memory exhaustion.
    oom_mode: OutOfMemoryMode,
    /// Standard baseline metrics.
    baseline_metrics: BaselineMetrics,
    /// Group-by-specific metrics.
    group_by_metrics: GroupByMetrics,
    /// Output/input row-ratio metric (partial mode only).
    reduction_factor: Option<metrics::RatioMetrics>,
}
impl GroupedHashAggregateStream {
    /// Builds the stream for `partition` of `agg`, wiring up accumulators,
    /// group storage, ordering, spill support, the out-of-memory strategy
    /// and metrics.
    pub fn new(
        agg: &AggregateExec,
        context: &Arc<TaskContext>,
        partition: usize,
    ) -> Result<Self> {
        debug!("Creating GroupedHashAggregateStream");
        let agg_schema = Arc::clone(&agg.schema);
        let agg_group_by = Arc::clone(&agg.group_by);
        let agg_filter_expr = Arc::clone(&agg.filter_expr);
        let batch_size = context.session_config().batch_size();
        let input = agg.input.execute(partition, Arc::clone(context))?;
        let baseline_metrics = BaselineMetrics::new(&agg.metrics, partition);
        let group_by_metrics = GroupByMetrics::new(&agg.metrics, partition);
        let timer = baseline_metrics.elapsed_compute().timer();
        let aggregate_exprs = Arc::clone(&agg.aggr_expr);
        // Arguments evaluated according to this node's own mode ...
        let aggregate_arguments = aggregates::aggregate_expressions(
            &agg.aggr_expr,
            &agg.mode,
            agg_group_by.num_group_exprs(),
        )?;
        // ... and in `Final` form, used when merging spilled intermediate state.
        let merging_aggregate_arguments = aggregates::aggregate_expressions(
            &agg.aggr_expr,
            &AggregateMode::Final,
            agg_group_by.num_group_exprs(),
        )?;
        // Filters only apply to raw input; partial input has them pre-applied.
        let filter_expressions = match agg.mode.input_mode() {
            AggregateInputMode::Raw => agg_filter_expr,
            AggregateInputMode::Partial => vec![None; agg.aggr_expr.len()].into(),
        };
        let accumulators: Vec<_> = aggregate_exprs
            .iter()
            .map(create_group_accumulator)
            .collect::<Result<_>>()?;
        let group_schema = agg_group_by.group_schema(&agg.input().schema())?;
        // Spill files hold intermediate (partial) aggregate state.
        let spill_schema = Arc::new(create_schema(
            &agg.input().schema(),
            &agg_group_by,
            &aggregate_exprs,
            AggregateMode::Partial,
        )?);
        // Rewrite the GROUP BY as plain column references over the spill schema.
        let merging_group_by_expr = agg_group_by
            .expr
            .iter()
            .enumerate()
            .map(|(idx, (_, name))| {
                (Arc::new(Column::new(name.as_str(), idx)) as _, name.clone())
            })
            .collect();
        let output_ordering = agg.cache.output_ordering();
        // Sort spilled data on all group columns, reusing output sort options
        // where the output ordering provides them.
        let spill_sort_exprs =
            group_schema
                .fields
                .into_iter()
                .enumerate()
                .map(|(idx, field)| {
                    let output_expr = Column::new(field.name().as_str(), idx);
                    let sort_options = output_ordering
                        .and_then(|o| o.get_sort_options(&output_expr))
                        .unwrap_or_default();
                    PhysicalSortExpr::new(Arc::new(output_expr), sort_options)
                });
        let Some(spill_ordering) = LexOrdering::new(spill_sort_exprs) else {
            return internal_err!("Spill expression is empty");
        };
        let agg_fn_names = aggregate_exprs
            .iter()
            .map(|expr| expr.human_display())
            .collect::<Vec<_>>()
            .join(", ");
        let name = format!("GroupedHashAggregateStream[{partition}] ({agg_fn_names})");
        let group_ordering = GroupOrdering::try_new(&agg.input_order_mode)?;
        // Pick the out-of-memory strategy: partial stages emit early;
        // otherwise spill when temp files are allowed, else report the error.
        let oom_mode = match (agg.mode, &group_ordering) {
            (AggregateMode::Partial, _) => OutOfMemoryMode::EmitEarly,
            (_, GroupOrdering::None | GroupOrdering::Partial(_))
                if context.runtime_env().disk_manager.tmp_files_enabled() =>
            {
                OutOfMemoryMode::Spill
            }
            _ => OutOfMemoryMode::ReportError,
        };
        let group_values = new_group_values(group_schema, &group_ordering)?;
        let reservation = MemoryConsumer::new(name)
            .with_can_spill(oom_mode != OutOfMemoryMode::ReportError)
            .register(context.memory_pool());
        timer.done();
        let exec_state = ExecutionState::ReadingInput;
        let spill_manager = SpillManager::new(
            context.runtime_env(),
            metrics::SpillMetrics::new(&agg.metrics, partition),
            Arc::clone(&spill_schema),
        )
        .with_compression_type(context.session_config().spill_compression());
        let spill_state = SpillState {
            spills: vec![],
            spill_expr: spill_ordering,
            spill_schema,
            is_stream_merging: false,
            merging_aggregate_arguments,
            merging_group_by: PhysicalGroupBy::new_single(merging_group_by_expr),
            peak_mem_used: MetricBuilder::new(&agg.metrics)
                .gauge("peak_mem_used", partition),
            spill_manager,
        };
        // Skip-aggregation probing only applies to unordered, single-grouping
        // partial aggregation where all accumulators can convert input to state.
        let skip_aggregation_probe = if agg.mode == AggregateMode::Partial
            && matches!(group_ordering, GroupOrdering::None)
            && accumulators
                .iter()
                .all(|acc| acc.supports_convert_to_state())
            && agg_group_by.is_single()
        {
            let options = &context.session_config().options().execution;
            let probe_rows_threshold =
                options.skip_partial_aggregation_probe_rows_threshold;
            let probe_ratio_threshold =
                options.skip_partial_aggregation_probe_ratio_threshold;
            let skipped_aggregation_rows = MetricBuilder::new(&agg.metrics)
                .counter("skipped_aggregation_rows", partition);
            Some(SkipAggregationProbe::new(
                probe_rows_threshold,
                probe_ratio_threshold,
                skipped_aggregation_rows,
            ))
        } else {
            None
        };
        // Track the output/input row ratio for partial aggregation.
        let reduction_factor = if agg.mode == AggregateMode::Partial {
            Some(
                MetricBuilder::new(&agg.metrics)
                    .with_type(metrics::MetricType::SUMMARY)
                    .ratio_metrics("reduction_factor", partition),
            )
        } else {
            None
        };
        Ok(GroupedHashAggregateStream {
            schema: agg_schema,
            input,
            mode: agg.mode,
            accumulators,
            aggregate_arguments,
            filter_expressions,
            group_by: agg_group_by,
            reservation,
            oom_mode,
            group_values,
            current_group_indices: Default::default(),
            exec_state,
            baseline_metrics,
            group_by_metrics,
            batch_size,
            group_ordering,
            input_done: false,
            spill_state,
            group_values_soft_limit: agg.limit_options().map(|config| config.limit()),
            skip_aggregation_probe,
            reduction_factor,
        })
    }
}
/// Returns a [`GroupsAccumulator`] for `agg_expr`: the native implementation
/// when supported, otherwise a [`GroupsAccumulatorAdapter`] that drives
/// ordinary accumulators created on demand.
pub(crate) fn create_group_accumulator(
    agg_expr: &Arc<AggregateFunctionExpr>,
) -> Result<Box<dyn GroupsAccumulator>> {
    if agg_expr.groups_accumulator_supported() {
        return agg_expr.create_groups_accumulator();
    }
    debug!(
        "Creating GroupsAccumulatorAdapter for {}: {agg_expr:?}",
        agg_expr.name()
    );
    // The adapter owns a factory that builds one plain accumulator per group.
    let expr = Arc::clone(agg_expr);
    Ok(Box::new(GroupsAccumulatorAdapter::new(move || {
        expr.create_accumulator()
    })))
}
impl Stream for GroupedHashAggregateStream {
    type Item = Result<RecordBatch>;

    /// Drives the [`ExecutionState`] machine: read and aggregate input,
    /// slice pending output into `batch_size` chunks, pass input through
    /// when aggregation is skipped, and finish.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
        loop {
            match &self.exec_state {
                ExecutionState::ReadingInput => 'reading_input: {
                    match ready!(self.input.poll_next_unpin(cx)) {
                        Some(Ok(batch)) => {
                            let timer = elapsed_compute.timer();
                            let input_rows = batch.num_rows();
                            if self.mode == AggregateMode::Partial
                                && let Some(reduction_factor) =
                                    self.reduction_factor.as_ref()
                            {
                                reduction_factor.add_total(input_rows);
                            }
                            // Aggregate the batch into the hash table.
                            self.group_aggregate_batch(&batch)?;
                            assert!(!self.input_done);
                            // A soft group limit ends input consumption early.
                            if self.hit_soft_group_limit() {
                                timer.done();
                                self.set_input_done_and_produce_output()?;
                                break 'reading_input;
                            }
                            // If the group ordering allows, emit completed groups
                            // now (not while un-merged spill files exist).
                            if (self.spill_state.spills.is_empty()
                                || self.spill_state.is_stream_merging)
                                && let Some(to_emit) = self.group_ordering.emit_to()
                            {
                                timer.done();
                                if let Some(batch) = self.emit(to_emit, false)? {
                                    self.exec_state =
                                        ExecutionState::ProducingOutput(batch);
                                };
                                break 'reading_input;
                            }
                            // Partial mode: re-evaluate the skip-aggregation probe.
                            if self.mode == AggregateMode::Partial {
                                assert!(!self.spill_state.is_stream_merging);
                                self.update_skip_aggregation_probe(input_rows);
                                if let Some(new_state) =
                                    self.switch_to_skip_aggregation()?
                                {
                                    timer.done();
                                    self.exec_state = new_state;
                                    break 'reading_input;
                                }
                            }
                            // Grow the reservation; on failure the OOM mode may
                            // spill or emit early via the returned state.
                            if let Some(new_state) =
                                self.try_update_memory_reservation()?
                            {
                                timer.done();
                                self.exec_state = new_state;
                                break 'reading_input;
                            }
                            timer.done();
                        }
                        Some(Err(e)) => {
                            return Poll::Ready(Some(Err(e)));
                        }
                        None => {
                            // Input exhausted: emit remaining groups, or switch
                            // to merging spill files.
                            self.set_input_done_and_produce_output()?;
                        }
                    }
                }
                ExecutionState::SkippingAggregation => {
                    match ready!(self.input.poll_next_unpin(cx)) {
                        Some(Ok(batch)) => {
                            let _timer = elapsed_compute.timer();
                            if let Some(probe) = self.skip_aggregation_probe.as_mut() {
                                probe.record_skipped(&batch);
                            }
                            // Convert raw input directly to intermediate state.
                            let states = self.transform_to_states(&batch)?;
                            return Poll::Ready(Some(Ok(
                                states.record_output(&self.baseline_metrics)
                            )));
                        }
                        Some(Err(e)) => {
                            return Poll::Ready(Some(Err(e)));
                        }
                        None => {
                            // All groups must have been emitted before skipping
                            // started; anything left indicates a bug.
                            if !self.group_values.is_empty() {
                                return Poll::Ready(Some(internal_err!(
                                    "Switching from SkippingAggregation to Done with {} groups still in hash table. \
                                    This is a bug - all groups should have been emitted before skip aggregation started.",
                                    self.group_values.len()
                                )));
                            }
                            self.exec_state = ExecutionState::Done;
                        }
                    }
                }
                ExecutionState::ProducingOutput(batch) => {
                    // Slice the pending batch into `batch_size` chunks, keeping
                    // the remainder in `ProducingOutput`.
                    let output_batch;
                    let size = self.batch_size;
                    (self.exec_state, output_batch) = if batch.num_rows() <= size {
                        (
                            if self.input_done {
                                ExecutionState::Done
                            }
                            else if self.mode == AggregateMode::Partial
                                && self.should_skip_aggregation()
                            {
                                ExecutionState::SkippingAggregation
                            } else {
                                ExecutionState::ReadingInput
                            },
                            batch.clone(),
                        )
                    } else {
                        let size = self.batch_size;
                        let num_remaining = batch.num_rows() - size;
                        let remaining = batch.slice(size, num_remaining);
                        let output = batch.slice(0, size);
                        (ExecutionState::ProducingOutput(remaining), output)
                    };
                    if let Some(reduction_factor) = self.reduction_factor.as_ref() {
                        reduction_factor.add_part(output_batch.num_rows());
                    }
                    debug_assert!(output_batch.num_rows() > 0);
                    return Poll::Ready(Some(Ok(
                        output_batch.record_output(&self.baseline_metrics)
                    )));
                }
                ExecutionState::Done => {
                    // Defensive check: nothing may remain in the hash table.
                    if !self.group_values.is_empty() {
                        return Poll::Ready(Some(internal_err!(
                            "AggregateStream was in Done state with {} groups left in hash table. \
                            This is a bug - all groups should have been emitted before entering Done state.",
                            self.group_values.len()
                        )));
                    }
                    self.clear_all();
                    let _ = self.update_memory_reservation();
                    return Poll::Ready(None);
                }
            }
        }
    }
}
impl RecordBatchStream for GroupedHashAggregateStream {
    /// Output schema of the aggregation stream.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
impl GroupedHashAggregateStream {
/// Evaluates the group-by and aggregate-argument expressions for `batch`,
/// interns the group values into the hash table, and updates every
/// accumulator.
///
/// While merging spilled data (`is_stream_merging`) the merging variants of
/// the group-by / argument expressions are used and filters no longer apply
/// (they were already applied during the partial stage).
fn group_aggregate_batch(&mut self, batch: &RecordBatch) -> Result<()> {
    let group_by_values = if self.spill_state.is_stream_merging {
        evaluate_group_by(&self.spill_state.merging_group_by, batch)?
    } else {
        evaluate_group_by(&self.group_by, batch)?
    };
    // Only time argument evaluation when there are arguments to evaluate.
    let timer = match (
        self.spill_state.is_stream_merging,
        self.spill_state.merging_aggregate_arguments.is_empty(),
        self.aggregate_arguments.is_empty(),
    ) {
        (true, false, _) | (false, _, false) => {
            Some(self.group_by_metrics.aggregate_arguments_time.timer())
        }
        _ => None,
    };
    let input_values = if self.spill_state.is_stream_merging {
        evaluate_many(&self.spill_state.merging_aggregate_arguments, batch)?
    } else {
        evaluate_many(&self.aggregate_arguments, batch)?
    };
    drop(timer);
    let filter_values = if self.spill_state.is_stream_merging {
        let filter_expressions = vec![None; self.accumulators.len()];
        evaluate_optional(&filter_expressions, batch)?
    } else {
        evaluate_optional(&self.filter_expressions, batch)?
    };
    // One iteration per grouping set.
    for group_values in &group_by_values {
        let groups_start_time = Instant::now();
        let starting_num_groups = self.group_values.len();
        // Map each input row to a dense group index, creating new groups as
        // needed.
        self.group_values
            .intern(group_values, &mut self.current_group_indices)?;
        let group_indices = &self.current_group_indices;
        let total_num_groups = self.group_values.len();
        if total_num_groups > starting_num_groups {
            self.group_ordering.new_groups(
                group_values,
                group_indices,
                total_num_groups,
            )?;
        }
        let agg_start_time = Instant::now();
        self.group_by_metrics
            .time_calculating_group_ids
            .add_duration(agg_start_time - groups_start_time);
        let t = self
            .accumulators
            .iter_mut()
            .zip(input_values.iter())
            .zip(filter_values.iter());
        for ((acc, values), opt_filter) in t {
            let opt_filter = opt_filter.as_ref().map(|filter| filter.as_boolean());
            if self.mode.input_mode() == AggregateInputMode::Raw
                && !self.spill_state.is_stream_merging
            {
                acc.update_batch(
                    values,
                    group_indices,
                    opt_filter,
                    total_num_groups,
                )?;
            } else {
                assert_or_internal_err!(
                    opt_filter.is_none(),
                    "aggregate filter should be applied in partial stage, there should be no filter in final stage"
                );
                acc.merge_batch(values, group_indices, None, total_num_groups)?;
            }
        }
        // Record accumulator time once per grouping set. Previously this was
        // inside the accumulator loop, which re-added the elapsed time since
        // `agg_start_time` for every accumulator and over-counted the metric
        // whenever more than one aggregate was present.
        self.group_by_metrics
            .aggregation_time
            .add_elapsed(agg_start_time);
    }
    Ok(())
}
/// Attempts to resize the memory reservation; on `ResourcesExhausted`,
/// applies the configured [`OutOfMemoryMode`]:
/// - `Spill`: spill the current groups to disk and continue,
/// - `EmitEarly`: emit accumulated groups as a partial result,
/// - otherwise propagate the error.
///
/// Returns `Some(new_state)` when early emission produced output.
fn try_update_memory_reservation(&mut self) -> Result<Option<ExecutionState>> {
    let oom = match self.update_memory_reservation() {
        Err(e @ DataFusionError::ResourcesExhausted(_)) => e,
        Err(e) => return Err(e),
        Ok(_) => return Ok(None),
    };
    match self.oom_mode {
        OutOfMemoryMode::Spill if !self.group_values.is_empty() => {
            self.spill()?;
            self.clear_shrink(self.batch_size);
            self.update_memory_reservation()?;
            Ok(None)
        }
        OutOfMemoryMode::EmitEarly if self.group_values.len() > 1 => {
            // Prefer emitting a whole multiple of `batch_size` groups.
            let n = if self.group_values.len() >= self.batch_size {
                self.group_values.len() / self.batch_size * self.batch_size
            } else {
                self.group_values.len()
            };
            // With a group ordering, only completed groups may be emitted.
            let n = match &self.group_ordering {
                GroupOrdering::None => n,
                _ => match self.group_ordering.emit_to() {
                    Some(EmitTo::First(max)) => n.min(max),
                    _ => 0,
                },
            };
            if n > 0
                && let Some(batch) = self.emit(EmitTo::First(n), false)?
            {
                Ok(Some(ExecutionState::ProducingOutput(batch)))
            } else {
                // Nothing could be emitted; surface the original OOM error.
                Err(oom)
            }
        }
        _ => Err(oom),
    }
}
/// Resizes the reservation to the current size of the accumulators, hash
/// table, ordering state and scratch buffer; in `Spill` mode additionally
/// reserves headroom for the sort performed while spilling. Updates the
/// peak-memory gauge on success.
fn update_memory_reservation(&mut self) -> Result<()> {
    let acc = self.accumulators.iter().map(|x| x.size()).sum::<usize>();
    let groups_and_acc_size = acc
        + self.group_values.size()
        + self.group_ordering.size()
        + self.current_group_indices.allocated_size();
    // Extra room so a future spill's sort can be accommodated.
    let sort_headroom =
        if self.oom_mode == OutOfMemoryMode::Spill && !self.group_values.is_empty() {
            acc + self.group_values.size()
        } else {
            0
        };
    let new_size = groups_and_acc_size + sort_headroom;
    let reservation_result = self.reservation.try_resize(new_size);
    if reservation_result.is_ok() {
        self.spill_state
            .peak_mem_used
            .set_max(self.reservation.size());
    }
    reservation_result
}
/// Creates a batch with the requested groups and their aggregate values,
/// removing them from the hash table and accumulators.
///
/// With `spilling`, intermediate state is emitted against the spill schema;
/// otherwise final values or intermediate state per `self.mode`. Returns
/// `None` when there are no groups.
fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result<Option<RecordBatch>> {
    let schema = if spilling {
        Arc::clone(&self.spill_state.spill_schema)
    } else {
        self.schema()
    };
    if self.group_values.is_empty() {
        return Ok(None);
    }
    let timer = self.group_by_metrics.emitting_time.timer();
    let mut output = self.group_values.emit(emit_to)?;
    if let EmitTo::First(n) = emit_to {
        // Keep the ordering state in sync with the removed groups.
        self.group_ordering.remove_groups(n);
    }
    for acc in self.accumulators.iter_mut() {
        if self.mode.output_mode() == AggregateOutputMode::Final && !spilling {
            output.push(acc.evaluate(emit_to)?)
        } else {
            // Emit intermediate state columns instead of final values.
            output.extend(acc.state(emit_to)?)
        }
    }
    drop(timer);
    // Best-effort shrink of the reservation after releasing the groups.
    let _ = self.update_memory_reservation();
    let batch = RecordBatch::try_new(schema, output)?;
    debug_assert!(batch.num_rows() > 0);
    Ok(Some(batch))
}
/// Emits all groups as intermediate state, sorts them by the spill
/// ordering, and writes them to a new spill file.
fn spill(&mut self) -> Result<()> {
    let Some(emit) = self.emit(EmitTo::All, true)? else {
        return Ok(());
    };
    self.clear_shrink(0);
    self.update_memory_reservation()?;
    // Reserve memory for the incremental sort: the emitted batch plus a
    // batch-size-proportional share of its sliced size, capped at 2x the
    // batch's memory.
    let batch_size_ratio = self.batch_size as f32 / emit.num_rows() as f32;
    let batch_memory = get_record_batch_memory_size(&emit);
    let sort_memory = (batch_memory
        + (emit.get_sliced_size()? as f32 * batch_size_ratio) as usize)
        .min(batch_memory * 2);
    self.reservation.try_grow(sort_memory).map_err(|err| {
        resources_datafusion_err!(
            "Failed to reserve memory for sort during spill: {err}"
        )
    })?;
    let sorted_iter = IncrementalSortIterator::new(
        emit,
        self.spill_state.spill_expr.clone(),
        self.batch_size,
    );
    let spillfile = self
        .spill_state
        .spill_manager
        .spill_record_batch_iter_and_return_max_batch_memory(
            sorted_iter,
            "HashAggSpill",
        )?;
    // The sort memory is only needed while writing the spill file.
    self.reservation.shrink(sort_memory);
    match spillfile {
        Some((spillfile, max_record_batch_memory)) => {
            self.spill_state.spills.push(SortedSpillFile {
                file: spillfile,
                max_record_batch_memory,
            })
        }
        None => {
            // `emit` returned rows, so the spill writer must produce a file.
            return internal_err!(
                "Calling spill with no intermediate batch to spill"
            );
        }
    }
    Ok(())
}
/// Clears the hash table and the scratch index buffer, shrinking their
/// capacity to `num_rows`.
fn clear_shrink(&mut self, num_rows: usize) {
    self.group_values.clear_shrink(num_rows);
    self.current_group_indices.clear();
    self.current_group_indices.shrink_to(num_rows);
}
/// Releases all memory held by the group state.
fn clear_all(&mut self) {
    self.clear_shrink(0);
}
/// True when a soft group-count limit is configured and the number of
/// groups in the hash table has reached it.
fn hit_soft_group_limit(&self) -> bool {
    self.group_values_soft_limit
        .is_some_and(|limit| limit <= self.group_values.len())
}
/// Marks the input as finished and transitions to the next state: emit all
/// remaining groups directly, or — when spill files exist — spill the
/// in-memory groups too and replace `self.input` with a streaming merge of
/// all spill files, re-entering `ReadingInput` in stream-merging mode.
fn set_input_done_and_produce_output(&mut self) -> Result<()> {
    self.input_done = true;
    self.group_ordering.input_done();
    let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
    let timer = elapsed_compute.timer();
    self.exec_state = if self.spill_state.spills.is_empty() {
        let batch = self.emit(EmitTo::All, false)?;
        batch.map_or(ExecutionState::Done, ExecutionState::ProducingOutput)
    } else {
        // Spill the remaining in-memory groups so everything is on disk.
        self.spill()?;
        self.spill_state.is_stream_merging = true;
        self.input = StreamingMergeBuilder::new()
            .with_schema(Arc::clone(&self.spill_state.spill_schema))
            .with_spill_manager(self.spill_state.spill_manager.clone())
            .with_sorted_spill_files(std::mem::take(&mut self.spill_state.spills))
            .with_expressions(&self.spill_state.spill_expr)
            .with_metrics(self.baseline_metrics.clone())
            .with_batch_size(self.batch_size)
            .with_reservation(self.reservation.new_empty())
            .build()?;
        self.input_done = false;
        self.clear_all();
        // The merged input arrives fully sorted on the group columns.
        self.group_ordering = GroupOrdering::Full(GroupOrderingFull::new());
        let group_schema = self
            .spill_state
            .merging_group_by
            .group_schema(&self.spill_state.spill_schema)?;
        if group_schema.fields().len() > 1 {
            self.group_values = new_group_values(group_schema, &self.group_ordering)?;
        }
        // No further spilling is possible while merging; a second OOM is fatal.
        self.oom_mode = OutOfMemoryMode::ReportError;
        self.update_memory_reservation()?;
        ExecutionState::ReadingInput
    };
    timer.done();
    Ok(())
}
/// Feeds the latest batch statistics into the skip-aggregation probe, if any.
fn update_skip_aggregation_probe(&mut self, input_rows: usize) {
    if let Some(probe) = self.skip_aggregation_probe.as_mut() {
        // Probing only exists in partial mode, which uses EmitEarly (never
        // spills), so no spill files can be present here.
        assert!(self.spill_state.spills.is_empty());
        probe.update_state(input_rows, self.group_values.len());
    };
}
/// If the probe decided to skip aggregation, emits all accumulated groups
/// and returns a `ProducingOutput` state so they are flushed before the
/// stream switches to pass-through mode.
fn switch_to_skip_aggregation(&mut self) -> Result<Option<ExecutionState>> {
    if let Some(probe) = self.skip_aggregation_probe.as_mut()
        && probe.should_skip()
        && let Some(batch) = self.emit(EmitTo::All, false)?
    {
        return Ok(Some(ExecutionState::ProducingOutput(batch)));
    };
    Ok(None)
}
/// Returns whether the probe (if configured) decided to skip aggregation.
fn should_skip_aggregation(&self) -> bool {
    match &self.skip_aggregation_probe {
        Some(probe) => probe.should_skip(),
        None => false,
    }
}
/// Converts a raw input batch directly into intermediate aggregate state
/// (one state row per input row) without grouping — used while skipping
/// aggregation.
fn transform_to_states(&self, batch: &RecordBatch) -> Result<RecordBatch> {
    let mut group_values = evaluate_group_by(&self.group_by, batch)?;
    let input_values = evaluate_many(&self.aggregate_arguments, batch)?;
    let filter_values = evaluate_optional(&self.filter_expressions, batch)?;
    // Skipping is only enabled for single grouping sets (see `new`).
    assert_eq_or_internal_err!(
        group_values.len(),
        1,
        "group_values expected to have single element"
    );
    let mut output = group_values.swap_remove(0);
    let iter = self
        .accumulators
        .iter()
        .zip(input_values.iter())
        .zip(filter_values.iter());
    for ((acc, values), opt_filter) in iter {
        let opt_filter = opt_filter.as_ref().map(|filter| filter.as_boolean());
        output.extend(acc.convert_to_state(values, opt_filter)?);
    }
    let states_batch = RecordBatch::try_new(self.schema(), output)?;
    Ok(states_batch)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::InputOrderMode;
use crate::execution_plan::ExecutionPlan;
use crate::test::TestMemoryExec;
use arrow::array::{Int32Array, Int64Array};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_execution::TaskContext;
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
use datafusion_physical_expr::expressions::col;
use std::sync::Arc;
#[tokio::test]
async fn test_double_emission_race_condition_bug() -> Result<()> {
    // More groups than one output batch can hold, under a tight memory
    // limit, to exercise early emission; every group must still come out
    // exactly once.
    let schema = Arc::new(Schema::new(vec![
        Field::new("group_col", DataType::Int32, false),
        Field::new("value_col", DataType::Int64, false),
    ]));
    let batch_size = 1024;
    let num_groups = batch_size + 100;
    let group_ids: Vec<i32> = (0..num_groups as i32).collect();
    let values: Vec<i64> = vec![1; num_groups];
    let batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![
            Arc::new(Int32Array::from(group_ids)),
            Arc::new(Int64Array::from(values)),
        ],
    )?;
    let input_partitions = vec![vec![batch]];
    // A tiny memory limit forces the EmitEarly path in partial mode.
    let runtime = RuntimeEnvBuilder::default()
        .with_memory_limit(1024, 1.0)
        .build_arc()?;
    let mut task_ctx = TaskContext::default().with_runtime(runtime);
    let mut session_config = task_ctx.session_config().clone();
    session_config = session_config.set(
        "datafusion.execution.batch_size",
        &datafusion_common::ScalarValue::UInt64(Some(1024)),
    );
    session_config = session_config.set(
        "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
        &datafusion_common::ScalarValue::UInt64(Some(50)),
    );
    session_config = session_config.set(
        "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
        &datafusion_common::ScalarValue::Float64(Some(0.8)),
    );
    task_ctx = task_ctx.with_session_config(session_config);
    let task_ctx = Arc::new(task_ctx);
    let group_expr = vec![(col("group_col", &schema)?, "group_col".to_string())];
    let aggr_expr = vec![Arc::new(
        AggregateExprBuilder::new(count_udaf(), vec![col("value_col", &schema)?])
            .schema(Arc::clone(&schema))
            .alias("count_value")
            .build()?,
    )];
    let exec = TestMemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?;
    let exec = Arc::new(TestMemoryExec::update_cache(&Arc::new(exec)));
    let aggregate_exec = AggregateExec::try_new(
        AggregateMode::Partial,
        PhysicalGroupBy::new_single(group_expr),
        aggr_expr,
        vec![None],
        exec,
        Arc::clone(&schema),
    )?;
    let mut stream =
        GroupedHashAggregateStream::new(&aggregate_exec, &Arc::clone(&task_ctx), 0)?;
    let mut results = Vec::new();
    while let Some(result) = stream.next().await {
        let batch = result?;
        results.push(batch);
    }
    // Every distinct input group must appear across the output batches.
    let mut total_output_groups = 0;
    for batch in &results {
        total_output_groups += batch.num_rows();
    }
    assert_eq!(
        total_output_groups, num_groups,
        "Unexpected number of groups",
    );
    Ok(())
}
#[tokio::test]
async fn test_skip_aggregation_probe_not_locked_until_skip() -> Result<()> {
    // Batch 1 aggregates well (10 groups / 100 rows, ratio 0.1), batch 2
    // is all-new groups and pushes the ratio to the threshold; the probe
    // must not lock on the first (negative) evaluation, so only batch 3's
    // rows end up skipped.
    let schema = Arc::new(Schema::new(vec![
        Field::new("group_col", DataType::Int32, false),
        Field::new("value_col", DataType::Int32, false),
    ]));
    let probe_rows_threshold = 100;
    let probe_ratio_threshold = 0.8;
    // Batch 1: low cardinality — ratio stays below the threshold.
    let batch1_rows = 100;
    let batch1_groups = 10;
    let mut group_ids_batch1 = Vec::new();
    for i in 0..batch1_rows {
        group_ids_batch1.push((i % batch1_groups) as i32);
    }
    let values_batch1: Vec<i32> = vec![1; batch1_rows];
    let batch1 = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![
            Arc::new(Int32Array::from(group_ids_batch1)),
            Arc::new(Int32Array::from(values_batch1)),
        ],
    )?;
    // Batch 2: every row is a new group — raises the ratio to the threshold.
    let batch2_rows = 350;
    let batch2_groups = 350;
    let group_ids_batch2: Vec<i32> = (batch1_groups..(batch1_groups + batch2_groups))
        .map(|x| x as i32)
        .collect();
    let values_batch2: Vec<i32> = vec![1; batch2_rows];
    let batch2 = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![
            Arc::new(Int32Array::from(group_ids_batch2)),
            Arc::new(Int32Array::from(values_batch2)),
        ],
    )?;
    // Batch 3: processed after the skip decision, so it bypasses aggregation.
    let batch3_rows = 100;
    let batch3_groups = 100;
    let batch3_start_group = batch1_groups + batch2_groups;
    let group_ids_batch3: Vec<i32> = (batch3_start_group
        ..(batch3_start_group + batch3_groups))
        .map(|x| x as i32)
        .collect();
    let values_batch3: Vec<i32> = vec![1; batch3_rows];
    let batch3 = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![
            Arc::new(Int32Array::from(group_ids_batch3)),
            Arc::new(Int32Array::from(values_batch3)),
        ],
    )?;
    let input_partitions = vec![vec![batch1, batch2, batch3]];
    let runtime = RuntimeEnvBuilder::default().build_arc()?;
    let mut task_ctx = TaskContext::default().with_runtime(runtime);
    let mut session_config = task_ctx.session_config().clone();
    session_config = session_config.set(
        "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
        &datafusion_common::ScalarValue::UInt64(Some(probe_rows_threshold)),
    );
    session_config = session_config.set(
        "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
        &datafusion_common::ScalarValue::Float64(Some(probe_ratio_threshold)),
    );
    task_ctx = task_ctx.with_session_config(session_config);
    let task_ctx = Arc::new(task_ctx);
    let group_expr = vec![(col("group_col", &schema)?, "group_col".to_string())];
    let aggr_expr = vec![Arc::new(
        AggregateExprBuilder::new(count_udaf(), vec![col("value_col", &schema)?])
            .schema(Arc::clone(&schema))
            .alias("count_value")
            .build()?,
    )];
    let exec = TestMemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?;
    let exec = Arc::new(TestMemoryExec::update_cache(&Arc::new(exec)));
    let aggregate_exec = AggregateExec::try_new(
        AggregateMode::Partial,
        PhysicalGroupBy::new_single(group_expr),
        aggr_expr,
        vec![None],
        exec,
        Arc::clone(&schema),
    )?;
    let mut stream =
        GroupedHashAggregateStream::new(&aggregate_exec, &Arc::clone(&task_ctx), 0)?;
    let mut results = Vec::new();
    while let Some(result) = stream.next().await {
        let batch = result?;
        results.push(batch);
    }
    // Only the rows processed after the skip decision count as skipped.
    let metrics = aggregate_exec.metrics().unwrap();
    let skipped_rows = metrics
        .sum_by_name("skipped_aggregation_rows")
        .map(|m| m.as_usize())
        .unwrap_or(0);
    assert_eq!(
        skipped_rows, batch3_rows,
        "Expected batch 3's rows ({batch3_rows}) to be skipped",
    );
    Ok(())
}
#[tokio::test]
async fn test_emit_early_with_partially_sorted() -> Result<()> {
    // Partially sorted input (sorted on `sort_col` only) under a tight
    // memory limit: the EmitEarly OOM path must either complete or fail
    // with a resources-exhausted error — any other error is a bug.
    let schema = Arc::new(Schema::new(vec![
        Field::new("sort_col", DataType::Int32, false),
        Field::new("group_col", DataType::Int32, false),
        Field::new("value_col", DataType::Int64, false),
    ]));
    let n = 256;
    let batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![
            Arc::new(Int32Array::from(vec![1; n])),
            Arc::new(Int32Array::from((0..n as i32).collect::<Vec<_>>())),
            Arc::new(Int64Array::from(vec![1; n])),
        ],
    )?;
    let runtime = RuntimeEnvBuilder::default()
        .with_memory_limit(4096, 1.0)
        .build_arc()?;
    let mut task_ctx = TaskContext::default().with_runtime(runtime);
    let mut cfg = task_ctx.session_config().clone();
    cfg = cfg.set(
        "datafusion.execution.batch_size",
        &datafusion_common::ScalarValue::UInt64(Some(128)),
    );
    // Effectively disable the skip-aggregation probe.
    cfg = cfg.set(
        "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
        &datafusion_common::ScalarValue::UInt64(Some(u64::MAX)),
    );
    task_ctx = task_ctx.with_session_config(cfg);
    let task_ctx = Arc::new(task_ctx);
    // Declare the input as sorted on the first column only.
    let ordering = LexOrdering::new(vec![PhysicalSortExpr::new_default(Arc::new(
        Column::new("sort_col", 0),
    )
    as _)])
    .unwrap();
    let exec = TestMemoryExec::try_new(&[vec![batch]], Arc::clone(&schema), None)?
        .try_with_sort_information(vec![ordering])?;
    let exec = Arc::new(TestMemoryExec::update_cache(&Arc::new(exec)));
    let aggregate_exec = AggregateExec::try_new(
        AggregateMode::Partial,
        PhysicalGroupBy::new_single(vec![
            (col("sort_col", &schema)?, "sort_col".to_string()),
            (col("group_col", &schema)?, "group_col".to_string()),
        ]),
        vec![Arc::new(
            AggregateExprBuilder::new(count_udaf(), vec![col("value_col", &schema)?])
                .schema(Arc::clone(&schema))
                .alias("count_value")
                .build()?,
        )],
        vec![None],
        exec,
        Arc::clone(&schema),
    )?;
    // Sanity-check we actually hit the partially-sorted code path.
    assert!(matches!(
        aggregate_exec.input_order_mode(),
        InputOrderMode::PartiallySorted(_)
    ));
    let mut stream = GroupedHashAggregateStream::new(&aggregate_exec, &task_ctx, 0)?;
    while let Some(result) = stream.next().await {
        if let Err(e) = result {
            // A memory-limit failure is an acceptable outcome here.
            if e.to_string().contains("Resources exhausted") {
                break;
            }
            return Err(e);
        }
    }
    Ok(())
}
}