use std::sync::Arc;
use std::task::{Context, Poll};
use std::vec;
use ahash::RandomState;
use futures::{
ready,
stream::{Stream, StreamExt},
};
use crate::error::Result;
use crate::physical_plan::aggregates::{
evaluate_group_by, evaluate_many, AccumulatorItem, AggregateMode, PhysicalGroupBy,
};
use crate::physical_plan::hash_utils::create_hashes;
use crate::physical_plan::metrics::{BaselineMetrics, RecordOutput};
use crate::physical_plan::{aggregates, AggregateExpr, PhysicalExpr};
use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
use crate::scalar::ScalarValue;
use arrow::{array::ArrayRef, compute, compute::cast};
use arrow::{
array::{Array, UInt32Builder},
error::{ArrowError, Result as ArrowResult},
};
use arrow::{
datatypes::{Schema, SchemaRef},
record_batch::RecordBatch,
};
use hashbrown::raw::RawTable;
/// Stream that performs a grouped hash aggregation over its input.
///
/// Every input batch is folded into `accumulators`; once the input is
/// exhausted a single output batch is built from the accumulated group
/// states (see the `Stream` implementation).
pub(crate) struct GroupedHashAggregateStream {
/// Schema of the output batch; also returned by `RecordBatchStream::schema`.
schema: SchemaRef,
/// Upstream batches to aggregate.
input: SendableRecordBatchStream,
/// Selects `update_batch` (partial) vs `merge_batch` (final) on the
/// accumulators, and whether state or evaluated values are emitted.
mode: AggregateMode,
/// Hash table plus per-group state accumulated so far.
accumulators: Accumulators,
/// Input expressions of the aggregates, as produced by
/// `aggregates::aggregate_expressions` in `new`; evaluated against each batch.
aggregate_expressions: Vec<Vec<Arc<dyn PhysicalExpr>>>,
/// Aggregate expressions; used to create the accumulators of a new group.
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
/// Group-by specification evaluated against each batch.
group_by: PhysicalGroupBy,
/// Metrics: compute time is recorded around aggregation work and the
/// output batch is recorded via `record_output`.
baseline_metrics: BaselineMetrics,
/// Hasher state passed to `create_hashes` for every batch.
random_state: RandomState,
/// Set once the stream has produced its single result (or an error);
/// subsequent polls yield `None`.
finished: bool,
}
impl GroupedHashAggregateStream {
pub fn new(
mode: AggregateMode,
schema: SchemaRef,
group_by: PhysicalGroupBy,
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
input: SendableRecordBatchStream,
baseline_metrics: BaselineMetrics,
) -> Result<Self> {
let timer = baseline_metrics.elapsed_compute().timer();
let aggregate_expressions =
aggregates::aggregate_expressions(&aggr_expr, &mode, group_by.expr.len())?;
timer.done();
Ok(Self {
schema,
mode,
input,
aggr_expr,
group_by,
baseline_metrics,
aggregate_expressions,
accumulators: Default::default(),
random_state: Default::default(),
finished: false,
})
}
}
impl Stream for GroupedHashAggregateStream {
    type Item = ArrowResult<RecordBatch>;

    /// Drives the aggregation.
    ///
    /// Input batches are consumed and folded into the accumulators until the
    /// input is exhausted, at which point exactly one item is produced: the
    /// batch built by `create_batch_from_map`. An error from the input or
    /// from aggregating a batch is produced instead and also ends the stream.
    /// After that single item every poll yields `None`.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let this = &mut *self;
        if this.finished {
            return Poll::Ready(None);
        }

        let elapsed_compute = this.baseline_metrics.elapsed_compute();

        // Keep pulling input until there is something to emit; `ready!`
        // returns `Poll::Pending` from this function when the input is not
        // ready yet.
        let output = loop {
            match ready!(this.input.poll_next_unpin(cx)) {
                Some(Ok(batch)) => {
                    let timer = elapsed_compute.timer();
                    let aggregated = group_aggregate_batch(
                        &this.mode,
                        &this.random_state,
                        &this.group_by,
                        &this.aggr_expr,
                        batch,
                        &mut this.accumulators,
                        &this.aggregate_expressions,
                    );
                    timer.done();

                    if let Err(e) = aggregated {
                        break Err(ArrowError::ExternalError(Box::new(e)));
                    }
                    // Batch absorbed successfully: go fetch the next one.
                }
                Some(Err(e)) => break Err(e),
                None => {
                    // Mark the stream as done before building the output.
                    this.finished = true;
                    let timer = elapsed_compute.timer();
                    let batch = create_batch_from_map(
                        &this.mode,
                        &this.accumulators,
                        this.group_by.expr.len(),
                        &this.schema,
                    )
                    .record_output(&this.baseline_metrics);
                    timer.done();
                    break batch;
                }
            }
        };

        // Whatever was produced (result batch or error) is the only item.
        this.finished = true;
        Poll::Ready(Some(output))
    }
}
impl RecordBatchStream for GroupedHashAggregateStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}
fn group_aggregate_batch(
mode: &AggregateMode,
random_state: &RandomState,
group_by: &PhysicalGroupBy,
aggr_expr: &[Arc<dyn AggregateExpr>],
batch: RecordBatch,
accumulators: &mut Accumulators,
aggregate_expressions: &[Vec<Arc<dyn PhysicalExpr>>],
) -> Result<()> {
let group_by_values = evaluate_group_by(group_by, &batch)?;
let aggr_input_values = evaluate_many(aggregate_expressions, &batch)?;
for grouping_set_values in group_by_values {
let mut groups_with_rows = vec![];
let mut batch_hashes = vec![0; batch.num_rows()];
create_hashes(&grouping_set_values, random_state, &mut batch_hashes)?;
for (row, hash) in batch_hashes.into_iter().enumerate() {
let Accumulators { map, group_states } = accumulators;
let entry = map.get_mut(hash, |(_hash, group_idx)| {
let group_state = &group_states[*group_idx];
grouping_set_values
.iter()
.zip(group_state.group_by_values.iter())
.all(|(array, scalar)| scalar.eq_array(array, row))
});
match entry {
Some((_hash, group_idx)) => {
let group_state = &mut group_states[*group_idx];
if group_state.indices.is_empty() {
groups_with_rows.push(*group_idx);
};
group_state.indices.push(row as u32); }
None => {
let accumulator_set = aggregates::create_accumulators(aggr_expr)?;
let group_by_values = grouping_set_values
.iter()
.map(|col| ScalarValue::try_from_array(col, row))
.collect::<Result<Vec<_>>>()?;
let group_state = GroupState {
group_by_values: group_by_values.into_boxed_slice(),
accumulator_set,
indices: vec![row as u32], };
let group_idx = group_states.len();
group_states.push(group_state);
groups_with_rows.push(group_idx);
map.insert(hash, (hash, group_idx), |(hash, _group_idx)| *hash);
}
};
}
let mut batch_indices: UInt32Builder = UInt32Builder::with_capacity(0);
let mut offsets = vec![0];
let mut offset_so_far = 0;
for group_idx in groups_with_rows.iter() {
let indices = &accumulators.group_states[*group_idx].indices;
batch_indices.append_slice(indices);
offset_so_far += indices.len();
offsets.push(offset_so_far);
}
let batch_indices = batch_indices.finish();
let values: Vec<Vec<Arc<dyn Array>>> = aggr_input_values
.iter()
.map(|array| {
array
.iter()
.map(|array| {
compute::take(
array.as_ref(),
&batch_indices,
None, )
.unwrap()
})
.collect()
})
.collect();
groups_with_rows
.iter()
.zip(offsets.windows(2))
.try_for_each(|(group_idx, offsets)| {
let group_state = &mut accumulators.group_states[*group_idx];
group_state
.accumulator_set
.iter_mut()
.zip(values.iter())
.map(|(accumulator, aggr_array)| {
(
accumulator,
aggr_array
.iter()
.map(|array| {
array.slice(offsets[0], offsets[1] - offsets[0])
})
.collect::<Vec<ArrayRef>>(),
)
})
.try_for_each(|(accumulator, values)| match mode {
AggregateMode::Partial => accumulator.update_batch(&values),
AggregateMode::FinalPartitioned | AggregateMode::Final => {
accumulator.merge_batch(&values)
}
})
.and({
group_state.indices.clear();
Ok(())
})
})?;
}
Ok(())
}
/// State kept for a single group.
#[derive(Debug)]
struct GroupState {
/// Values of the group-by columns that identify this group.
group_by_values: Box<[ScalarValue]>,
/// Accumulators of this group, created by `aggregates::create_accumulators`.
accumulator_set: Vec<AccumulatorItem>,
/// Rows of the batch currently being processed that fall into this group;
/// emptied again once the batch has been applied to the accumulators.
indices: Vec<u32>,
}
/// All groups seen so far, plus the hash table used to look them up.
#[derive(Default)]
struct Accumulators {
/// Entries are `(hash, index into group_states)`. The group values
/// themselves live in `group_states`, so lookups compare against them.
map: RawTable<(u64, usize)>,
/// Per-group state, in order of group creation.
group_states: Vec<GroupState>,
}
impl std::fmt::Debug for Accumulators {
    /// Formats the group states in full; the hash table is represented only
    /// by the placeholder string `"RawTable"`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let mut out = f.debug_struct("Accumulators");
        out.field("map", &"RawTable");
        out.field("group_states", &self.group_states);
        out.finish()
    }
}
/// Builds the output [`RecordBatch`] from the accumulated group states.
///
/// The batch has one row per group. Its columns are the `num_group_expr`
/// group-by columns followed by the aggregate columns:
///
/// * in `AggregateMode::Partial`, one column per state value of every
///   accumulator (the number of state values is taken from the first group);
/// * in the final modes, one column per accumulator holding its evaluated
///   result.
///
/// Every column is cast to the type of the corresponding field of
/// `output_schema`. If there are no groups an empty batch is returned.
///
/// # Errors
///
/// Returns an error if an accumulator fails to report its state or to
/// evaluate, if the scalars cannot be assembled into arrays, or if the cast
/// or the batch construction fails. Accumulator failures used to panic here;
/// they are now propagated to the caller.
fn create_batch_from_map(
    mode: &AggregateMode,
    accumulators: &Accumulators,
    num_group_expr: usize,
    output_schema: &Schema,
) -> ArrowResult<RecordBatch> {
    if accumulators.group_states.is_empty() {
        return Ok(RecordBatch::new_empty(Arc::new(output_schema.to_owned())));
    }

    // Use the first group to learn how many output columns each accumulator
    // contributes.
    let accs = &accumulators.group_states[0].accumulator_set;
    let state_lens: Vec<usize> = match mode {
        AggregateMode::Partial => accs
            .iter()
            .map(|acc| acc.state().map(|state| state.len()))
            .collect::<Result<Vec<_>>>()?,
        AggregateMode::Final | AggregateMode::FinalPartitioned => vec![1; accs.len()],
    };

    // Group-by columns come first.
    let mut columns = (0..num_group_expr)
        .map(|i| {
            ScalarValue::iter_to_array(
                accumulators
                    .group_states
                    .iter()
                    .map(|group_state| group_state.group_by_values[i].clone()),
            )
        })
        .collect::<Result<Vec<_>>>()?;

    // Then the aggregate columns, accumulator by accumulator.
    for (acc_idx, &state_len) in state_lens.iter().enumerate() {
        for state_idx in 0..state_len {
            // Gather the scalars first so that a failing accumulator yields
            // an error instead of a panic inside the iterator.
            let scalars = accumulators
                .group_states
                .iter()
                .map(|group_state| {
                    let accumulator = &group_state.accumulator_set[acc_idx];
                    match mode {
                        AggregateMode::Partial => accumulator.state().and_then(|state| {
                            state[state_idx].as_scalar().map(|v| v.clone())
                        }),
                        AggregateMode::Final | AggregateMode::FinalPartitioned => {
                            accumulator.evaluate()
                        }
                    }
                })
                .collect::<Result<Vec<_>>>()?;
            columns.push(ScalarValue::iter_to_array(scalars)?);
        }
    }

    // Align the column types with the declared output schema.
    let columns = columns
        .iter()
        .zip(output_schema.fields().iter())
        .map(|(col, desired_field)| cast(col, desired_field.data_type()))
        .collect::<ArrowResult<Vec<_>>>()?;

    RecordBatch::try_new(Arc::new(output_schema.to_owned()), columns)
}