use crate::execution::context::TaskContext;
use crate::physical_plan::aggregates::no_grouping::AggregateStream;
use crate::physical_plan::metrics::{
BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet,
};
use crate::physical_plan::{
DisplayFormatType, Distribution, ExecutionPlan, Partitioning,
SendableRecordBatchStream, Statistics,
};
use arrow::array::ArrayRef;
use arrow::datatypes::{Field, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::Accumulator;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::{
expressions, AggregateExpr, PhysicalExpr, PhysicalSortExpr,
};
use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;
mod no_grouping;
mod row_hash;
use crate::physical_plan::aggregates::row_hash::GroupedHashAggregateStream;
use crate::physical_plan::EquivalenceProperties;
pub use datafusion_expr::AggregateFunction;
use datafusion_physical_expr::aggregate::row_accumulator::RowAccumulator;
use datafusion_physical_expr::equivalence::project_equivalence_properties;
pub use datafusion_physical_expr::expressions::create_aggregate_expr;
use datafusion_physical_expr::normalize_out_expr_with_alias_schema;
/// Which phase of a two-stage (partial / final) aggregation this operator performs.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum AggregateMode {
    /// First stage: produces intermediate accumulator state per input partition.
    Partial,
    /// Second stage: merges intermediate state from a single partition into final values.
    Final,
    /// Second stage over hash-partitioned input: each partition holds all rows for
    /// its group keys, so partitions can be finalized independently.
    FinalPartitioned,
}
/// Physical group-by specification, general enough to express GROUPING SETS
/// (and therefore CUBE / ROLLUP) in addition to a plain GROUP BY.
#[derive(Clone, Debug, Default)]
pub struct PhysicalGroupBy {
    // Group-key expressions paired with their output column names.
    expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
    // Per-key "null" substitute expressions, used when a grouping set omits
    // the corresponding key (indexed in parallel with `expr`).
    null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
    // One entry per grouping set; `groups[i][j]` is true when grouping set `i`
    // should emit `null_expr[j]` instead of `expr[j]`.
    groups: Vec<Vec<bool>>,
}
impl PhysicalGroupBy {
    /// Build a grouping specification from explicit key expressions, their
    /// null substitutes, and grouping-set membership flags.
    pub fn new(
        expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
        null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
        groups: Vec<Vec<bool>>,
    ) -> Self {
        Self { expr, null_expr, groups }
    }

    /// Build a plain GROUP BY: a single grouping set containing every key.
    pub fn new_single(expr: Vec<(Arc<dyn PhysicalExpr>, String)>) -> Self {
        let groups = vec![vec![false; expr.len()]];
        Self {
            expr,
            null_expr: Vec::new(),
            groups,
        }
    }

    /// True when any grouping set substitutes a null for one of its keys.
    pub fn contains_null(&self) -> bool {
        self.groups.iter().any(|set| set.contains(&true))
    }

    /// The group-key expressions and their output names.
    pub fn expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
        self.expr.as_slice()
    }

    /// The null substitute expressions and their output names.
    pub fn null_expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
        self.null_expr.as_slice()
    }

    /// The per-grouping-set membership flags.
    pub fn groups(&self) -> &[Vec<bool>] {
        self.groups.as_slice()
    }

    /// True when there are no group-key expressions at all.
    pub fn is_empty(&self) -> bool {
        self.expr.is_empty()
    }
}
/// The concrete stream an `AggregateExec` produces, before type erasure.
enum StreamType {
    /// Stream used when there are no group-by expressions.
    AggregateStream(AggregateStream),
    /// Hash-based stream used when group-by expressions are present.
    GroupedHashAggregateStream(GroupedHashAggregateStream),
}
impl From<StreamType> for SendableRecordBatchStream {
    /// Erase the concrete stream type behind a pinned, boxed trait object.
    fn from(stream: StreamType) -> Self {
        match stream {
            StreamType::AggregateStream(s) => Box::pin(s),
            StreamType::GroupedHashAggregateStream(s) => Box::pin(s),
        }
    }
}
/// Hash-aggregate execution plan operator.
#[derive(Debug)]
pub struct AggregateExec {
    // Aggregation phase (partial / final / final-partitioned).
    pub(crate) mode: AggregateMode,
    // Group-by keys, including grouping-set information.
    pub(crate) group_by: PhysicalGroupBy,
    // Aggregate expressions to evaluate.
    pub(crate) aggr_expr: Vec<Arc<dyn AggregateExpr>>,
    // Child plan providing input batches.
    pub(crate) input: Arc<dyn ExecutionPlan>,
    // Output schema of this operator.
    schema: SchemaRef,
    // Schema of the original (pre-aggregation) input, kept so the final
    // stage can build its accumulators against the original types.
    pub(crate) input_schema: SchemaRef,
    // Maps input group-by columns to their renamed/re-indexed output
    // columns; used to rewrite partitioning and equivalence expressions.
    alias_map: HashMap<Column, Vec<Column>>,
    // Runtime execution metrics.
    metrics: ExecutionPlanMetricsSet,
}
impl AggregateExec {
    /// Create a new hash aggregate execution plan.
    ///
    /// # Errors
    /// Returns an error when the output schema cannot be derived from the
    /// input schema, group-by expressions, and aggregate expressions.
    pub fn try_new(
        mode: AggregateMode,
        group_by: PhysicalGroupBy,
        aggr_expr: Vec<Arc<dyn AggregateExpr>>,
        input: Arc<dyn ExecutionPlan>,
        input_schema: SchemaRef,
    ) -> Result<Self> {
        let schema = create_schema(
            &input.schema(),
            &group_by.expr,
            &aggr_expr,
            group_by.contains_null(),
            mode,
        )?;
        let schema = Arc::new(schema);
        // Record, for every group-by input column, the output column(s) it is
        // renamed or re-indexed to, so partitioning and equivalence
        // properties can later be expressed against the output schema.
        let mut alias_map: HashMap<Column, Vec<Column>> = HashMap::new();
        for (expression, name) in group_by.expr.iter() {
            if let Some(column) = expression.as_any().downcast_ref::<Column>() {
                let new_col_idx = schema.index_of(name)?;
                // Only record an alias when the name or position actually changed.
                if (column.name() != name) || (column.index() != new_col_idx) {
                    let entry = alias_map.entry(column.clone()).or_default();
                    entry.push(Column::new(name, new_col_idx));
                }
            }
        }
        Ok(AggregateExec {
            mode,
            group_by,
            aggr_expr,
            input,
            schema,
            input_schema,
            alias_map,
            metrics: ExecutionPlanMetricsSet::new(),
        })
    }

    /// Aggregation mode (partial, final, final-partitioned).
    pub fn mode(&self) -> &AggregateMode {
        &self.mode
    }

    /// Grouping specification.
    pub fn group_expr(&self) -> &PhysicalGroupBy {
        &self.group_by
    }

    /// Group-by expressions as plain column references into this operator's
    /// own output schema (group keys occupy the leading columns, in order).
    pub fn output_group_expr(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        self.group_by
            .expr()
            .iter()
            .enumerate()
            .map(|(index, (_col, name))| {
                Arc::new(expressions::Column::new(name, index)) as Arc<dyn PhysicalExpr>
            })
            .collect()
    }

    /// Aggregate expressions.
    pub fn aggr_expr(&self) -> &[Arc<dyn AggregateExpr>] {
        &self.aggr_expr
    }

    /// Input execution plan.
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

    /// Schema of the original input plan, before any partial aggregation.
    pub fn input_schema(&self) -> SchemaRef {
        self.input_schema.clone()
    }

    /// Execute one partition, returning the concrete stream type so callers
    /// (e.g. tests) can inspect which implementation was selected.
    fn execute_typed(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<StreamType> {
        let batch_size = context.session_config().batch_size();
        let input = self.input.execute(partition, Arc::clone(&context))?;
        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
        if self.group_by.expr.is_empty() {
            // No grouping: a single set of accumulators over all input rows.
            Ok(StreamType::AggregateStream(AggregateStream::new(
                self.mode,
                self.schema.clone(),
                self.aggr_expr.clone(),
                input,
                baseline_metrics,
                context,
                partition,
            )?))
        } else {
            // Grouping present: hash-based per-group accumulation.
            Ok(StreamType::GroupedHashAggregateStream(
                GroupedHashAggregateStream::new(
                    self.mode,
                    self.schema.clone(),
                    self.group_by.clone(),
                    self.aggr_expr.clone(),
                    input,
                    baseline_metrics,
                    batch_size,
                    context,
                    partition,
                )?,
            ))
        }
    }
}
impl ExecutionPlan for AggregateExec {
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
fn output_partitioning(&self) -> Partitioning {
match &self.mode {
AggregateMode::Partial => {
let input_partition = self.input.output_partitioning();
match input_partition {
Partitioning::Hash(exprs, part) => {
let normalized_exprs = exprs
.into_iter()
.map(|expr| {
normalize_out_expr_with_alias_schema(
expr,
&self.alias_map,
&self.schema,
)
})
.collect::<Vec<_>>();
Partitioning::Hash(normalized_exprs, part)
}
_ => input_partition,
}
}
_ => self.input.output_partitioning(),
}
}
fn unbounded_output(&self, children: &[bool]) -> Result<bool> {
if children[0] {
Err(DataFusionError::Plan(
"Aggregate Error: `GROUP BY` clause (including the more general GROUPING SET) is not supported for unbounded inputs.".to_string(),
))
} else {
Ok(false)
}
}
fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
None
}
fn required_input_distribution(&self) -> Vec<Distribution> {
match &self.mode {
AggregateMode::Partial => vec![Distribution::UnspecifiedDistribution],
AggregateMode::FinalPartitioned => {
vec![Distribution::HashPartitioned(self.output_group_expr())]
}
AggregateMode::Final => vec![Distribution::SinglePartition],
}
}
fn equivalence_properties(&self) -> EquivalenceProperties {
let mut new_properties = EquivalenceProperties::new(self.schema());
project_equivalence_properties(
self.input.equivalence_properties(),
&self.alias_map,
&mut new_properties,
);
new_properties
}
fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
vec![self.input.clone()]
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
Ok(Arc::new(AggregateExec::try_new(
self.mode,
self.group_by.clone(),
self.aggr_expr.clone(),
children[0].clone(),
self.input_schema.clone(),
)?))
}
fn execute(
&self,
partition: usize,
context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
self.execute_typed(partition, context)
.map(|stream| stream.into())
}
fn metrics(&self) -> Option<MetricsSet> {
Some(self.metrics.clone_inner())
}
fn fmt_as(
&self,
t: DisplayFormatType,
f: &mut std::fmt::Formatter,
) -> std::fmt::Result {
match t {
DisplayFormatType::Default => {
write!(f, "AggregateExec: mode={:?}", self.mode)?;
let g: Vec<String> = if self.group_by.groups.len() == 1 {
self.group_by
.expr
.iter()
.map(|(e, alias)| {
let e = e.to_string();
if &e != alias {
format!("{e} as {alias}")
} else {
e
}
})
.collect()
} else {
self.group_by
.groups
.iter()
.map(|group| {
let terms = group
.iter()
.enumerate()
.map(|(idx, is_null)| {
if *is_null {
let (e, alias) = &self.group_by.null_expr[idx];
let e = e.to_string();
if &e != alias {
format!("{e} as {alias}")
} else {
e
}
} else {
let (e, alias) = &self.group_by.expr[idx];
let e = e.to_string();
if &e != alias {
format!("{e} as {alias}")
} else {
e
}
}
})
.collect::<Vec<String>>()
.join(", ");
format!("({terms})")
})
.collect()
};
write!(f, ", gby=[{}]", g.join(", "))?;
let a: Vec<String> = self
.aggr_expr
.iter()
.map(|agg| agg.name().to_string())
.collect();
write!(f, ", aggr=[{}]", a.join(", "))?;
}
}
Ok(())
}
fn statistics(&self) -> Statistics {
match self.mode {
AggregateMode::Final | AggregateMode::FinalPartitioned
if self.group_by.expr.is_empty() =>
{
Statistics {
num_rows: Some(1),
is_exact: true,
..Default::default()
}
}
_ => Statistics {
num_rows: self.input.statistics().num_rows,
is_exact: false,
..Default::default()
},
}
}
}
/// Derive the output schema of an aggregation from its grouping and
/// aggregate expressions.
///
/// Group-key fields come first, followed by either the per-aggregate
/// intermediate state fields (partial mode) or the final value fields
/// (final modes). Group keys become nullable when any grouping set can
/// substitute a null for them.
fn create_schema(
    input_schema: &Schema,
    group_expr: &[(Arc<dyn PhysicalExpr>, String)],
    aggr_expr: &[Arc<dyn AggregateExpr>],
    contains_null_expr: bool,
    mode: AggregateMode,
) -> datafusion_common::Result<Schema> {
    let mut fields = Vec::with_capacity(group_expr.len() + aggr_expr.len());
    for (expr, name) in group_expr {
        let nullable = contains_null_expr || expr.nullable(input_schema)?;
        fields.push(Field::new(name, expr.data_type(input_schema)?, nullable));
    }
    match mode {
        // Partial aggregation emits the intermediate accumulator state.
        AggregateMode::Partial => {
            for expr in aggr_expr {
                fields.extend(expr.state_fields()?.iter().cloned());
            }
        }
        // Final aggregation emits one field per aggregate expression.
        AggregateMode::Final | AggregateMode::FinalPartitioned => {
            for expr in aggr_expr {
                fields.push(expr.field()?);
            }
        }
    }
    Ok(Schema::new(fields))
}
/// Project out the leading `group_count` fields — the group-key columns —
/// of an aggregation output schema.
fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef {
    // Group keys always occupy the first `group_count` fields.
    let key_fields = schema.fields()[..group_count].to_vec();
    Arc::new(Schema::new(key_fields))
}
/// Build, per aggregate, the input expressions its accumulator is fed with.
///
/// In partial mode these are the aggregates' own argument expressions; in
/// final modes they are column references into the intermediate state
/// fields, laid out consecutively starting at `col_idx_base`.
fn aggregate_expressions(
    aggr_expr: &[Arc<dyn AggregateExpr>],
    mode: &AggregateMode,
    col_idx_base: usize,
) -> datafusion_common::Result<Vec<Vec<Arc<dyn PhysicalExpr>>>> {
    match mode {
        AggregateMode::Partial => {
            Ok(aggr_expr.iter().map(|agg| agg.expressions()).collect())
        }
        AggregateMode::Final | AggregateMode::FinalPartitioned => {
            let mut col_idx_base = col_idx_base;
            // Collect directly into a Result, short-circuiting on the first
            // error (the previous `Ok(...?)` wrapping was redundant).
            aggr_expr
                .iter()
                .map(|agg| {
                    // Each aggregate consumes as many state columns as it
                    // declares state fields; advance the base past them.
                    let exprs = merge_expressions(col_idx_base, agg)?;
                    col_idx_base += exprs.len();
                    Ok(exprs)
                })
                .collect()
        }
    }
}
/// Column references addressing an aggregate's intermediate state fields,
/// which start at `index_base` in the partial-aggregation output schema.
fn merge_expressions(
    index_base: usize,
    expr: &Arc<dyn AggregateExpr>,
) -> Result<Vec<Arc<dyn PhysicalExpr>>> {
    let state_fields = expr.state_fields()?;
    let columns = state_fields
        .iter()
        .enumerate()
        .map(|(offset, field)| {
            Arc::new(Column::new(field.name(), index_base + offset))
                as Arc<dyn PhysicalExpr>
        })
        .collect();
    Ok(columns)
}
// Boxed accumulator holding scalar-format aggregation state.
pub(crate) type AccumulatorItem = Box<dyn Accumulator>;
// Boxed accumulator holding row-format aggregation state.
pub(crate) type RowAccumulatorItem = Box<dyn RowAccumulator>;
/// Instantiate one accumulator per aggregate expression, failing on the
/// first expression that cannot build one.
fn create_accumulators(
    aggr_expr: &[Arc<dyn AggregateExpr>],
) -> datafusion_common::Result<Vec<AccumulatorItem>> {
    aggr_expr
        .iter()
        .map(|agg| agg.create_accumulator())
        .collect()
}
/// Instantiate one row accumulator per aggregate expression, assigning each
/// a starting offset into the shared row-format state.
fn create_row_accumulators(
    aggr_expr: &[Arc<dyn AggregateExpr>],
) -> datafusion_common::Result<Vec<RowAccumulatorItem>> {
    let mut state_index = 0;
    aggr_expr
        .iter()
        .map(|expr| {
            let result = expr.create_row_accumulator(state_index);
            // Propagate state_fields() errors instead of panicking on them
            // (previously this was an `unwrap()`).
            state_index += expr.state_fields()?.len();
            result
        })
        .collect()
}
/// Convert accumulator contents into output arrays.
///
/// In partial mode this returns each accumulator's intermediate state
/// (possibly several arrays per accumulator, flattened in order); in final
/// modes it returns one evaluated array per accumulator.
fn finalize_aggregation(
    accumulators: &[AccumulatorItem],
    mode: &AggregateMode,
) -> datafusion_common::Result<Vec<ArrayRef>> {
    match mode {
        AggregateMode::Partial => {
            // Build the intermediate state arrays, short-circuiting on the
            // first accumulator error.
            let state_arrays = accumulators
                .iter()
                .map(|accumulator| {
                    accumulator.state().map(|state| {
                        state
                            .iter()
                            .map(|value| value.to_array())
                            .collect::<Vec<ArrayRef>>()
                    })
                })
                .collect::<datafusion_common::Result<Vec<_>>>()?;
            // Flatten by consuming the outer Vec — avoids the per-array
            // clone the previous `iter().flatten().cloned()` performed.
            Ok(state_arrays.into_iter().flatten().collect())
        }
        AggregateMode::Final | AggregateMode::FinalPartitioned => {
            accumulators
                .iter()
                .map(|accumulator| accumulator.evaluate().map(|v| v.to_array()))
                .collect()
        }
    }
}
/// Evaluate each physical expression against `batch`, materializing the
/// results as arrays with the batch's row count.
fn evaluate(
    expr: &[Arc<dyn PhysicalExpr>],
    batch: &RecordBatch,
) -> Result<Vec<ArrayRef>> {
    let num_rows = batch.num_rows();
    expr.iter()
        .map(|e| e.evaluate(batch).map(|value| value.into_array(num_rows)))
        .collect()
}
/// Evaluate several groups of expressions against the same batch, failing
/// on the first group that errors.
fn evaluate_many(
    expr: &[Vec<Arc<dyn PhysicalExpr>>],
    batch: &RecordBatch,
) -> Result<Vec<Vec<ArrayRef>>> {
    expr.iter().map(|group| evaluate(group, batch)).collect()
}
fn evaluate_group_by(
group_by: &PhysicalGroupBy,
batch: &RecordBatch,
) -> Result<Vec<Vec<ArrayRef>>> {
let exprs: Vec<ArrayRef> = group_by
.expr
.iter()
.map(|(expr, _)| {
let value = expr.evaluate(batch)?;
Ok(value.into_array(batch.num_rows()))
})
.collect::<Result<Vec<_>>>()?;
let null_exprs: Vec<ArrayRef> = group_by
.null_expr
.iter()
.map(|(expr, _)| {
let value = expr.evaluate(batch)?;
Ok(value.into_array(batch.num_rows()))
})
.collect::<Result<Vec<_>>>()?;
Ok(group_by
.groups
.iter()
.map(|group| {
group
.iter()
.enumerate()
.map(|(idx, is_null)| {
if *is_null {
null_exprs[idx].clone()
} else {
exprs[idx].clone()
}
})
.collect()
})
.collect())
}
#[cfg(test)]
mod tests {
    use crate::execution::context::{SessionConfig, TaskContext};
    use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
    use crate::from_slice::FromSlice;
    use crate::physical_plan::aggregates::{
        AggregateExec, AggregateMode, PhysicalGroupBy,
    };
    use crate::physical_plan::expressions::{col, Avg};
    use crate::test::assert_is_pending;
    use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
    use crate::{assert_batches_sorted_eq, physical_plan::common};
    use arrow::array::{Float64Array, UInt32Array};
    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
    use arrow::record_batch::RecordBatch;
    use datafusion_common::{DataFusionError, Result, ScalarValue};
    use datafusion_physical_expr::expressions::{lit, ApproxDistinct, Count, Median};
    use datafusion_physical_expr::{AggregateExpr, PhysicalExpr, PhysicalSortExpr};
    use futures::{FutureExt, Stream};
    use std::any::Any;
    use std::sync::Arc;
    use std::task::{Context, Poll};
    use super::StreamType;
    use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
    use crate::physical_plan::{
        ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream,
        Statistics,
    };
    use crate::prelude::SessionContext;

    // Shared fixture: two batches of (a: UInt32, b: Float64) with overlapping
    // key values so grouped results span both batches.
    fn some_data() -> (Arc<Schema>, Vec<RecordBatch>) {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::UInt32, false),
            Field::new("b", DataType::Float64, false),
        ]));
        (
            schema.clone(),
            vec![
                RecordBatch::try_new(
                    schema.clone(),
                    vec![
                        Arc::new(UInt32Array::from_slice([2, 3, 4, 4])),
                        Arc::new(Float64Array::from_slice([1.0, 2.0, 3.0, 4.0])),
                    ],
                )
                .unwrap(),
                RecordBatch::try_new(
                    schema,
                    vec![
                        Arc::new(UInt32Array::from_slice([2, 3, 3, 4])),
                        Arc::new(Float64Array::from_slice([1.0, 2.0, 3.0, 4.0])),
                    ],
                )
                .unwrap(),
            ],
        )
    }

    // End-to-end check of a two-stage (partial -> final) aggregation over
    // GROUPING SETS ((b), (a), (a, b)), verifying both stages' output tables
    // and the output-row metric.
    async fn check_grouping_sets(input: Arc<dyn ExecutionPlan>) -> Result<()> {
        let input_schema = input.schema();
        let grouping_set = PhysicalGroupBy {
            expr: vec![
                (col("a", &input_schema)?, "a".to_string()),
                (col("b", &input_schema)?, "b".to_string()),
            ],
            null_expr: vec![
                (lit(ScalarValue::UInt32(None)), "a".to_string()),
                (lit(ScalarValue::Float64(None)), "b".to_string()),
            ],
            groups: vec![
                vec![false, true],
                vec![true, false],
                vec![false, false],
            ],
        };
        let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Count::new(
            lit(1i8),
            "COUNT(1)".to_string(),
            DataType::Int64,
        ))];
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        let partial_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            grouping_set.clone(),
            aggregates.clone(),
            input,
            input_schema.clone(),
        )?);
        let result =
            common::collect(partial_aggregate.execute(0, task_ctx.clone())?).await?;
        // Partial stage output: intermediate COUNT state per grouping-set key.
        let expected = vec![
            "+---+-----+-----------------+",
            "| a | b | COUNT(1)[count] |",
            "+---+-----+-----------------+",
            "| | 1.0 | 2 |",
            "| | 2.0 | 2 |",
            "| | 3.0 | 2 |",
            "| | 4.0 | 2 |",
            "| 2 | | 2 |",
            "| 2 | 1.0 | 2 |",
            "| 3 | | 3 |",
            "| 3 | 2.0 | 2 |",
            "| 3 | 3.0 | 1 |",
            "| 4 | | 3 |",
            "| 4 | 3.0 | 1 |",
            "| 4 | 4.0 | 2 |",
            "+---+-----+-----------------+",
        ];
        assert_batches_sorted_eq!(expected, &result);
        let groups = partial_aggregate.group_expr().expr().to_vec();
        let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));
        // Final stage groups on the partial stage's output columns by name.
        let final_group: Vec<(Arc<dyn PhysicalExpr>, String)> = groups
            .iter()
            .map(|(_expr, name)| Ok((col(name, &input_schema)?, name.clone())))
            .collect::<Result<_>>()?;
        let final_grouping_set = PhysicalGroupBy::new_single(final_group);
        let merged_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Final,
            final_grouping_set,
            aggregates,
            merge,
            input_schema,
        )?);
        let result =
            common::collect(merged_aggregate.execute(0, task_ctx.clone())?).await?;
        assert_eq!(result.len(), 1);
        let batch = &result[0];
        assert_eq!(batch.num_columns(), 3);
        assert_eq!(batch.num_rows(), 12);
        let expected = vec![
            "+---+-----+----------+",
            "| a | b | COUNT(1) |",
            "+---+-----+----------+",
            "| | 1.0 | 2 |",
            "| | 2.0 | 2 |",
            "| | 3.0 | 2 |",
            "| | 4.0 | 2 |",
            "| 2 | | 2 |",
            "| 2 | 1.0 | 2 |",
            "| 3 | | 3 |",
            "| 3 | 2.0 | 2 |",
            "| 3 | 3.0 | 1 |",
            "| 4 | | 3 |",
            "| 4 | 3.0 | 1 |",
            "| 4 | 4.0 | 2 |",
            "+---+-----+----------+",
        ];
        assert_batches_sorted_eq!(&expected, &result);
        let metrics = merged_aggregate.metrics().unwrap();
        let output_rows = metrics.output_rows().unwrap();
        assert_eq!(12, output_rows);
        Ok(())
    }

    // End-to-end check of a two-stage AVG(b) GROUP BY a, verifying the
    // partial state (count + sum), the merged final values, and the metric.
    async fn check_aggregates(input: Arc<dyn ExecutionPlan>) -> Result<()> {
        let input_schema = input.schema();
        let grouping_set = PhysicalGroupBy {
            expr: vec![(col("a", &input_schema)?, "a".to_string())],
            null_expr: vec![],
            groups: vec![vec![false]],
        };
        let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Avg::new(
            col("b", &input_schema)?,
            "AVG(b)".to_string(),
            DataType::Float64,
        ))];
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        let partial_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            grouping_set.clone(),
            aggregates.clone(),
            input,
            input_schema.clone(),
        )?);
        let result =
            common::collect(partial_aggregate.execute(0, task_ctx.clone())?).await?;
        let expected = vec![
            "+---+---------------+-------------+",
            "| a | AVG(b)[count] | AVG(b)[sum] |",
            "+---+---------------+-------------+",
            "| 2 | 2 | 2.0 |",
            "| 3 | 3 | 7.0 |",
            "| 4 | 3 | 11.0 |",
            "+---+---------------+-------------+",
        ];
        assert_batches_sorted_eq!(expected, &result);
        let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));
        let final_group: Vec<(Arc<dyn PhysicalExpr>, String)> = grouping_set
            .expr
            .iter()
            .map(|(_expr, name)| Ok((col(name, &input_schema)?, name.clone())))
            .collect::<Result<_>>()?;
        let final_grouping_set = PhysicalGroupBy::new_single(final_group);
        let merged_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Final,
            final_grouping_set,
            aggregates,
            merge,
            input_schema,
        )?);
        let result =
            common::collect(merged_aggregate.execute(0, task_ctx.clone())?).await?;
        assert_eq!(result.len(), 1);
        let batch = &result[0];
        assert_eq!(batch.num_columns(), 2);
        assert_eq!(batch.num_rows(), 3);
        let expected = vec![
            "+---+--------------------+",
            "| a | AVG(b) |",
            "+---+--------------------+",
            "| 2 | 1.0 |",
            "| 3 | 2.3333333333333335 |",
            "| 4 | 3.6666666666666665 |",
            "+---+--------------------+",
        ];
        assert_batches_sorted_eq!(&expected, &result);
        let metrics = merged_aggregate.metrics().unwrap();
        let output_rows = metrics.output_rows().unwrap();
        assert_eq!(3, output_rows);
        Ok(())
    }

    // Test source that serves `some_data()` and optionally yields (returns
    // Poll::Pending once) before producing data, to exercise both stream
    // polling paths in the aggregate operators.
    #[derive(Debug)]
    struct TestYieldingExec {
        // When true, the stream returns Pending on its first poll.
        pub yield_first: bool,
    }

    impl ExecutionPlan for TestYieldingExec {
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn schema(&self) -> SchemaRef {
            some_data().0
        }
        fn output_partitioning(&self) -> Partitioning {
            Partitioning::UnknownPartitioning(1)
        }
        fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
            None
        }
        fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
            vec![]
        }
        fn with_new_children(
            self: Arc<Self>,
            _: Vec<Arc<dyn ExecutionPlan>>,
        ) -> Result<Arc<dyn ExecutionPlan>> {
            Err(DataFusionError::Internal(format!(
                "Children cannot be replaced in {self:?}"
            )))
        }
        fn execute(
            &self,
            _partition: usize,
            _context: Arc<TaskContext>,
        ) -> Result<SendableRecordBatchStream> {
            let stream = if self.yield_first {
                TestYieldingStream::New
            } else {
                TestYieldingStream::Yielded
            };
            Ok(Box::pin(stream))
        }
        fn statistics(&self) -> Statistics {
            let (_, batches) = some_data();
            common::compute_record_batch_statistics(&[batches], &self.schema(), None)
        }
    }

    // Stream state machine: New -> (Pending) -> Yielded -> batch 1 -> batch 2 -> end.
    enum TestYieldingStream {
        New,
        Yielded,
        ReturnedBatch1,
        ReturnedBatch2,
    }

    impl Stream for TestYieldingStream {
        type Item = Result<RecordBatch>;
        fn poll_next(
            mut self: std::pin::Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Self::Item>> {
            match &*self {
                TestYieldingStream::New => {
                    *(self.as_mut()) = TestYieldingStream::Yielded;
                    // Wake immediately so the executor polls again.
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
                TestYieldingStream::Yielded => {
                    *(self.as_mut()) = TestYieldingStream::ReturnedBatch1;
                    Poll::Ready(Some(Ok(some_data().1[0].clone())))
                }
                TestYieldingStream::ReturnedBatch1 => {
                    *(self.as_mut()) = TestYieldingStream::ReturnedBatch2;
                    Poll::Ready(Some(Ok(some_data().1[1].clone())))
                }
                TestYieldingStream::ReturnedBatch2 => Poll::Ready(None),
            }
        }
    }

    impl RecordBatchStream for TestYieldingStream {
        fn schema(&self) -> SchemaRef {
            some_data().0
        }
    }

    #[tokio::test]
    async fn aggregate_source_not_yielding() -> Result<()> {
        let input: Arc<dyn ExecutionPlan> =
            Arc::new(TestYieldingExec { yield_first: false });
        check_aggregates(input).await
    }

    #[tokio::test]
    async fn aggregate_grouping_sets_source_not_yielding() -> Result<()> {
        let input: Arc<dyn ExecutionPlan> =
            Arc::new(TestYieldingExec { yield_first: false });
        check_grouping_sets(input).await
    }

    #[tokio::test]
    async fn aggregate_source_with_yielding() -> Result<()> {
        let input: Arc<dyn ExecutionPlan> =
            Arc::new(TestYieldingExec { yield_first: true });
        check_aggregates(input).await
    }

    #[tokio::test]
    async fn aggregate_grouping_sets_with_yielding() -> Result<()> {
        let input: Arc<dyn ExecutionPlan> =
            Arc::new(TestYieldingExec { yield_first: true });
        check_grouping_sets(input).await
    }

    // With a 1-byte memory limit, every aggregate stream variant must fail
    // with ResourcesExhausted rather than panic or succeed.
    #[tokio::test]
    async fn test_oom() -> Result<()> {
        let input: Arc<dyn ExecutionPlan> =
            Arc::new(TestYieldingExec { yield_first: true });
        let input_schema = input.schema();
        let session_ctx = SessionContext::with_config_rt(
            SessionConfig::default(),
            Arc::new(
                RuntimeEnv::new(RuntimeConfig::default().with_memory_limit(1, 1.0))
                    .unwrap(),
            ),
        );
        let task_ctx = session_ctx.task_ctx();
        let groups_none = PhysicalGroupBy::default();
        let groups_some = PhysicalGroupBy {
            expr: vec![(col("a", &input_schema)?, "a".to_string())],
            null_expr: vec![],
            groups: vec![vec![false]],
        };
        let aggregates_v0: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Median::new(
            col("a", &input_schema)?,
            "MEDIAN(a)".to_string(),
            DataType::UInt32,
        ))];
        let aggregates_v1: Vec<Arc<dyn AggregateExpr>> =
            vec![Arc::new(ApproxDistinct::new(
                col("a", &input_schema)?,
                "APPROX_DISTINCT(a)".to_string(),
                DataType::UInt32,
            ))];
        let aggregates_v2: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Avg::new(
            col("b", &input_schema)?,
            "AVG(b)".to_string(),
            DataType::Float64,
        ))];
        for (version, groups, aggregates) in [
            (0, groups_none, aggregates_v0),
            (1, groups_some.clone(), aggregates_v1),
            (2, groups_some, aggregates_v2),
        ] {
            let partial_aggregate = Arc::new(AggregateExec::try_new(
                AggregateMode::Partial,
                groups,
                aggregates,
                input.clone(),
                input_schema.clone(),
            )?);
            let stream = partial_aggregate.execute_typed(0, task_ctx.clone())?;
            // Verify the expected concrete stream implementation was chosen.
            match version {
                0 => {
                    assert!(matches!(stream, StreamType::AggregateStream(_)));
                }
                1 => {
                    assert!(matches!(stream, StreamType::GroupedHashAggregateStream(_)));
                }
                2 => {
                    assert!(matches!(stream, StreamType::GroupedHashAggregateStream(_)));
                }
                _ => panic!("Unknown version: {version}"),
            }
            let stream: SendableRecordBatchStream = stream.into();
            let err = common::collect(stream).await.unwrap_err();
            let err = err.find_root();
            assert!(
                matches!(err, DataFusionError::ResourcesExhausted(_)),
                "Wrong error type: {err}",
            );
        }
        Ok(())
    }

    // Dropping the collect future before completion must release all
    // references to the input plan (no task leak) — ungrouped variant.
    #[tokio::test]
    async fn test_drop_cancel_without_groups() -> Result<()> {
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        let schema =
            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));
        let groups = PhysicalGroupBy::default();
        let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Avg::new(
            col("a", &schema)?,
            "AVG(a)".to_string(),
            DataType::Float64,
        ))];
        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
        let refs = blocking_exec.refs();
        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            groups.clone(),
            aggregates.clone(),
            blocking_exec,
            schema,
        )?);
        let fut = crate::physical_plan::collect(aggregate_exec, task_ctx);
        let mut fut = fut.boxed();
        assert_is_pending(&mut fut);
        drop(fut);
        assert_strong_count_converges_to_zero(refs).await;
        Ok(())
    }

    // Same drop-cancellation check for the grouped (hash) variant.
    #[tokio::test]
    async fn test_drop_cancel_with_groups() -> Result<()> {
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Float32, true),
            Field::new("b", DataType::Float32, true),
        ]));
        let groups =
            PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);
        let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Avg::new(
            col("b", &schema)?,
            "AVG(b)".to_string(),
            DataType::Float64,
        ))];
        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
        let refs = blocking_exec.refs();
        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            groups,
            aggregates.clone(),
            blocking_exec,
            schema,
        )?);
        let fut = crate::physical_plan::collect(aggregate_exec, task_ctx);
        let mut fut = fut.boxed();
        assert_is_pending(&mut fut);
        drop(fut);
        assert_strong_count_converges_to_zero(refs).await;
        Ok(())
    }
}