use std::any::Any;
use std::sync::Arc;
use super::{DisplayAs, ExecutionPlanProperties, PlanProperties};
use crate::aggregates::{
no_grouping::AggregateStream, row_hash::GroupedHashAggregateStream,
topk_stream::GroupedTopKAggregateStream,
};
use crate::execution_plan::{CardinalityEffect, EmissionType};
use crate::filter_pushdown::{
ChildFilterDescription, ChildPushdownResult, FilterDescription, FilterPushdownPhase,
FilterPushdownPropagation, PushedDownPredicate,
};
use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
use crate::{
DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode,
SendableRecordBatchStream, Statistics, check_if_same_properties,
};
use datafusion_common::config::ConfigOptions;
use datafusion_physical_expr::utils::collect_columns;
use parking_lot::Mutex;
use std::collections::HashSet;
use arrow::array::{ArrayRef, UInt8Array, UInt16Array, UInt32Array, UInt64Array};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use arrow_schema::FieldRef;
use datafusion_common::stats::Precision;
use datafusion_common::{
Constraint, Constraints, Result, ScalarValue, assert_eq_or_internal_err, not_impl_err,
};
use datafusion_execution::TaskContext;
use datafusion_expr::{Accumulator, Aggregate};
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
use datafusion_physical_expr::equivalence::ProjectionMapping;
use datafusion_physical_expr::expressions::{Column, DynamicFilterPhysicalExpr, lit};
use datafusion_physical_expr::{
ConstExpr, EquivalenceProperties, physical_exprs_contains,
};
use datafusion_physical_expr_common::physical_expr::{PhysicalExpr, fmt_sql};
use datafusion_physical_expr_common::sort_expr::{
LexOrdering, LexRequirement, OrderingRequirements, PhysicalSortRequirement,
};
use datafusion_expr::utils::AggregateOrderSensitivity;
use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays;
use itertools::Itertools;
use topk::hash_table::is_supported_hash_key_type;
use topk::heap::is_supported_heap_type;
pub mod group_values;
mod no_grouping;
pub mod order;
mod row_hash;
mod topk;
mod topk_stream;
/// Returns `true` when the TopK aggregation path can handle the given
/// grouping-key / ordering-value type pair (both the hash table and the
/// heap must support their respective types).
pub fn topk_types_supported(key_type: &DataType, value_type: &DataType) -> bool {
    let key_supported = is_supported_hash_key_type(key_type);
    key_supported && is_supported_heap_type(value_type)
}
/// Fixed-seed hash state for aggregation hash tables. The four seeds are the
/// ASCII values of 'A', 'G', 'G', 'R' ("AGGR"), making hash values
/// deterministic across runs rather than randomly seeded per process.
const AGGREGATION_HASH_SEED: ahash::RandomState =
    ahash::RandomState::with_seeds('A' as u64, 'G' as u64, 'G' as u64, 'R' as u64);
/// What kind of rows an aggregation stage consumes: raw input rows, or the
/// intermediate (partial) state produced by an earlier aggregation stage.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum AggregateInputMode {
    Raw,
    Partial,
}

/// What kind of rows an aggregation stage emits: intermediate (partial)
/// accumulator state, or final aggregate values.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum AggregateOutputMode {
    Partial,
    Final,
}

/// Execution mode of an [`AggregateExec`] stage. The mode determines both the
/// expected input ([`AggregateInputMode`]) and the produced output
/// ([`AggregateOutputMode`]).
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum AggregateMode {
    Partial,
    Final,
    FinalPartitioned,
    Single,
    SinglePartitioned,
    PartialReduce,
}

impl AggregateMode {
    /// The kind of input rows this mode consumes: modes that start a new
    /// aggregation read raw rows; the rest merge partial state.
    pub fn input_mode(&self) -> AggregateInputMode {
        use AggregateMode::*;
        match self {
            Partial | Single | SinglePartitioned => AggregateInputMode::Raw,
            Final | FinalPartitioned | PartialReduce => AggregateInputMode::Partial,
        }
    }

    /// The kind of output rows this mode emits: only `Partial` and
    /// `PartialReduce` keep emitting intermediate accumulator state.
    pub fn output_mode(&self) -> AggregateOutputMode {
        use AggregateMode::*;
        match self {
            Partial | PartialReduce => AggregateOutputMode::Partial,
            Final | FinalPartitioned | Single | SinglePartitioned => {
                AggregateOutputMode::Final
            }
        }
    }
}
/// Physical GROUP BY specification: the group expressions plus, for grouping
/// sets (CUBE / ROLLUP / GROUPING SETS), the null-substituted variants and the
/// per-set null masks.
#[derive(Clone, Debug, Default)]
pub struct PhysicalGroupBy {
    // Group-by expressions, each paired with its output column name.
    expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
    // Null-producing counterparts of `expr`, used where a grouping set masks
    // a column out (empty when there are no grouping sets).
    null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
    // One null mask per grouping set; `groups[s][i] == true` means expression
    // `i` is replaced by `null_expr[i]` in set `s`.
    groups: Vec<Vec<bool>>,
    // Whether this group-by was built from grouping sets (adds the internal
    // grouping-id output column).
    has_grouping_set: bool,
}
impl PhysicalGroupBy {
    /// Create a new `PhysicalGroupBy` from explicit expressions, their
    /// null-substituted counterparts, and the grouping-set null masks.
    pub fn new(
        expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
        null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
        groups: Vec<Vec<bool>>,
        has_grouping_set: bool,
    ) -> Self {
        Self {
            expr,
            null_expr,
            groups,
            has_grouping_set,
        }
    }

    /// Create a group-by with a single (non-grouping-set) group: one all-false
    /// null mask and no null expressions.
    pub fn new_single(expr: Vec<(Arc<dyn PhysicalExpr>, String)>) -> Self {
        let num_exprs = expr.len();
        Self {
            expr,
            null_expr: vec![],
            groups: vec![vec![false; num_exprs]],
            has_grouping_set: false,
        }
    }

    /// For each group expression, whether ANY grouping set masks it to null
    /// (which forces the output column to be nullable).
    pub fn exprs_nullable(&self) -> Vec<bool> {
        let mut exprs_nullable = vec![false; self.expr.len()];
        for group in self.groups.iter() {
            group.iter().enumerate().for_each(|(index, is_null)| {
                if *is_null {
                    exprs_nullable[index] = true;
                }
            })
        }
        exprs_nullable
    }

    /// True only for a genuinely ungrouped aggregation: no group expressions
    /// AND no grouping sets (e.g. `GROUP BY ()` does not qualify).
    pub fn is_true_no_grouping(&self) -> bool {
        self.is_empty() && !self.has_grouping_set
    }

    /// The group expressions with their output names.
    pub fn expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
        &self.expr
    }

    /// The null-substituted counterparts of the group expressions.
    pub fn null_expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
        &self.null_expr
    }

    /// The per-grouping-set null masks.
    pub fn groups(&self) -> &[Vec<bool>] {
        &self.groups
    }

    /// Whether this group-by was built from grouping sets.
    pub fn has_grouping_set(&self) -> bool {
        self.has_grouping_set
    }

    /// Whether there are no group expressions at all.
    pub fn is_empty(&self) -> bool {
        self.expr.is_empty()
    }

    /// Whether this is a plain (single) group, i.e. not grouping sets.
    pub fn is_single(&self) -> bool {
        !self.has_grouping_set
    }

    /// The group expressions as they are evaluated against the INPUT schema.
    pub fn input_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        self.expr
            .iter()
            .map(|(expr, _alias)| Arc::clone(expr))
            .collect()
    }

    /// Number of OUTPUT columns produced for the group: one per expression,
    /// plus the internal grouping-id column when grouping sets are present.
    fn num_output_exprs(&self) -> usize {
        let mut num_exprs = self.expr.len();
        if self.has_grouping_set {
            num_exprs += 1
        }
        num_exprs
    }

    /// Column references to this group-by's OUTPUT columns (by position),
    /// including the internal grouping-id column when present.
    pub fn output_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        let num_output_exprs = self.num_output_exprs();
        let mut output_exprs = Vec::with_capacity(num_output_exprs);
        // NOTE: `num_output_exprs >= self.expr.len()`, so this `take` never
        // truncates; it only bounds the iterator.
        output_exprs.extend(
            self.expr
                .iter()
                .enumerate()
                .take(num_output_exprs)
                .map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _),
        );
        if self.has_grouping_set {
            output_exprs.push(Arc::new(Column::new(
                Aggregate::INTERNAL_GROUPING_ID,
                self.expr.len(),
            )) as _);
        }
        output_exprs
    }

    /// Number of group columns materialized internally (expressions plus the
    /// grouping-id column when grouping sets are present).
    pub fn num_group_exprs(&self) -> usize {
        self.expr.len() + usize::from(self.has_grouping_set)
    }

    /// Schema containing only the group columns, derived from `schema`.
    pub fn group_schema(&self, schema: &Schema) -> Result<SchemaRef> {
        Ok(Arc::new(Schema::new(self.group_fields(schema)?)))
    }

    /// Fields for the group columns: each expression's type, with nullability
    /// widened if any grouping set masks it, plus a non-null grouping-id field.
    fn group_fields(&self, input_schema: &Schema) -> Result<Vec<FieldRef>> {
        let mut fields = Vec::with_capacity(self.num_group_exprs());
        for ((expr, name), group_expr_nullable) in
            self.expr.iter().zip(self.exprs_nullable().into_iter())
        {
            fields.push(
                Field::new(
                    name,
                    expr.data_type(input_schema)?,
                    group_expr_nullable || expr.nullable(input_schema)?,
                )
                .with_metadata(expr.return_field(input_schema)?.metadata().clone())
                .into(),
            );
        }
        if self.has_grouping_set {
            fields.push(
                Field::new(
                    Aggregate::INTERNAL_GROUPING_ID,
                    Aggregate::grouping_id_type(self.expr.len()),
                    false,
                )
                .into(),
            );
        }
        Ok(fields)
    }

    /// Fields for the group columns visible in the OUTPUT schema (truncated
    /// to `num_output_exprs`, which here equals the full group-field count).
    fn output_fields(&self, input_schema: &Schema) -> Result<Vec<FieldRef>> {
        let mut fields = self.group_fields(input_schema)?;
        fields.truncate(self.num_output_exprs());
        Ok(fields)
    }

    /// The group-by to use in a FINAL stage fed by this one's output: group
    /// expressions become column references into this stage's output, the
    /// grouping-set structure is collapsed (the grouping id, if any, becomes
    /// an ordinary key column named `INTERNAL_GROUPING_ID`).
    pub fn as_final(&self) -> PhysicalGroupBy {
        // `zip` truncates to the shorter side, so the chained grouping-id name
        // is only consumed when `output_exprs` includes a grouping-id column.
        let expr: Vec<_> =
            self.output_exprs()
                .into_iter()
                .zip(
                    self.expr.iter().map(|t| t.1.clone()).chain(std::iter::once(
                        Aggregate::INTERNAL_GROUPING_ID.to_owned(),
                    )),
                )
                .collect();
        let num_exprs = expr.len();
        let groups = if self.expr.is_empty() && !self.has_grouping_set {
            vec![]
        } else {
            vec![vec![false; num_exprs]]
        };
        Self {
            expr,
            null_expr: vec![],
            groups,
            has_grouping_set: false,
        }
    }
}
impl PartialEq for PhysicalGroupBy {
fn eq(&self, other: &PhysicalGroupBy) -> bool {
self.expr.len() == other.expr.len()
&& self
.expr
.iter()
.zip(other.expr.iter())
.all(|((expr1, name1), (expr2, name2))| expr1.eq(expr2) && name1 == name2)
&& self.null_expr.len() == other.null_expr.len()
&& self
.null_expr
.iter()
.zip(other.null_expr.iter())
.all(|((expr1, name1), (expr2, name2))| expr1.eq(expr2) && name1 == name2)
&& self.groups == other.groups
&& self.has_grouping_set == other.has_grouping_set
}
}
// The three concrete stream implementations an AggregateExec can produce:
// ungrouped (no hashing), hash-grouped, and the top-k priority-queue path.
#[expect(clippy::large_enum_variant)]
enum StreamType {
    AggregateStream(AggregateStream),
    GroupedHash(GroupedHashAggregateStream),
    GroupedPriorityQueue(GroupedTopKAggregateStream),
}
impl From<StreamType> for SendableRecordBatchStream {
    /// Erase the concrete stream variant into a pinned, boxed
    /// `SendableRecordBatchStream`.
    fn from(stream: StreamType) -> Self {
        match stream {
            StreamType::AggregateStream(inner) => Box::pin(inner),
            StreamType::GroupedHash(inner) => Box::pin(inner),
            StreamType::GroupedPriorityQueue(inner) => Box::pin(inner),
        }
    }
}
// Dynamic filter state shared between an ungrouped min/max aggregation and
// the scan it was pushed down to (see `AggregateExec::init_dynamic_filter`).
#[derive(Debug, Clone)]
struct AggrDynFilter {
    // The filter expression handed to the child during pushdown; starts as
    // `lit(true)` over the aggregated columns.
    filter: Arc<DynamicFilterPhysicalExpr>,
    // One entry per aggregate that supports dynamic filtering (min/max on a
    // plain column).
    supported_accumulators_info: Vec<PerAccumulatorDynFilter>,
}
// Per-accumulator bookkeeping for the dynamic min/max filter.
#[derive(Debug, Clone)]
struct PerAccumulatorDynFilter {
    // Whether this accumulator tightens a lower (Min) or upper (Max) bound.
    aggr_type: DynamicFilterAggregateType,
    // Index of the aggregate expression within `AggregateExec::aggr_expr`.
    aggr_index: usize,
    // Current best bound, shared across partitions; starts as `ScalarValue::Null`.
    shared_bound: Arc<Mutex<ScalarValue>>,
}
// Which aggregate kind drives the dynamic filter bound.
#[derive(Debug, Clone)]
enum DynamicFilterAggregateType {
    Min,
    Max,
}
/// Limit configuration for an aggregation, optionally annotated with the sort
/// direction that makes the top-k path applicable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LimitOptions {
    pub limit: usize,
    pub descending: Option<bool>,
}

impl LimitOptions {
    /// A plain limit with no associated sort direction.
    pub fn new(limit: usize) -> Self {
        Self {
            limit,
            descending: None,
        }
    }

    /// A limit paired with the sort direction of the ordering it serves.
    pub fn new_with_order(limit: usize, descending: bool) -> Self {
        Self {
            limit,
            descending: Some(descending),
        }
    }

    /// The row limit.
    pub fn limit(&self) -> usize {
        self.limit
    }

    /// The sort direction, if one was recorded.
    pub fn descending(&self) -> Option<bool> {
        self.descending
    }
}
/// Physical operator that evaluates aggregate expressions, optionally grouped
/// by a set of group-by expressions (including grouping sets).
#[derive(Debug, Clone)]
pub struct AggregateExec {
    // Execution mode (partial / final / single, possibly partitioned).
    mode: AggregateMode,
    // Group-by expressions and grouping-set structure.
    group_by: Arc<PhysicalGroupBy>,
    // Aggregate functions to evaluate.
    aggr_expr: Arc<[Arc<AggregateFunctionExpr>]>,
    // Optional FILTER expression per aggregate; same length as `aggr_expr`
    // (enforced in `try_new_with_schema`).
    filter_expr: Arc<[Option<Arc<dyn PhysicalExpr>>]>,
    // Optional limit (and sort direction) enabling the top-k stream path.
    limit_options: Option<LimitOptions>,
    // Child plan producing the input rows.
    pub input: Arc<dyn ExecutionPlan>,
    // Output schema of this node.
    schema: SchemaRef,
    // Input schema captured at construction time (kept separately from
    // `input.schema()` — presumably the original pre-aggregation schema;
    // confirm against callers).
    pub input_schema: SchemaRef,
    // Runtime execution metrics.
    metrics: ExecutionPlanMetricsSet,
    // Soft ordering requirement derived from group-by/aggregate orderings.
    required_input_ordering: Option<OrderingRequirements>,
    // How well the input ordering covers the group-by expressions.
    input_order_mode: InputOrderMode,
    // Cached plan properties (equivalences, partitioning, emission type).
    cache: Arc<PlanProperties>,
    // Dynamic min/max filter state; only populated for ungrouped Partial-mode
    // min/max aggregations (see `init_dynamic_filter`).
    dynamic_filter: Option<Arc<AggrDynFilter>>,
}
impl AggregateExec {
    /// Copy of this node with different aggregate expressions; metrics are
    /// reset, every other field (including the cached properties) is reused.
    pub fn with_new_aggr_exprs(
        &self,
        aggr_expr: impl Into<Arc<[Arc<AggregateFunctionExpr>]>>,
    ) -> Self {
        Self {
            aggr_expr: aggr_expr.into(),
            required_input_ordering: self.required_input_ordering.clone(),
            metrics: ExecutionPlanMetricsSet::new(),
            input_order_mode: self.input_order_mode.clone(),
            cache: Arc::clone(&self.cache),
            mode: self.mode,
            group_by: Arc::clone(&self.group_by),
            filter_expr: Arc::clone(&self.filter_expr),
            limit_options: self.limit_options,
            input: Arc::clone(&self.input),
            schema: Arc::clone(&self.schema),
            input_schema: Arc::clone(&self.input_schema),
            dynamic_filter: self.dynamic_filter.clone(),
        }
    }

    /// Copy of this node with different limit options; metrics are reset,
    /// everything else is reused.
    pub fn with_new_limit_options(&self, limit_options: Option<LimitOptions>) -> Self {
        Self {
            limit_options,
            required_input_ordering: self.required_input_ordering.clone(),
            metrics: ExecutionPlanMetricsSet::new(),
            input_order_mode: self.input_order_mode.clone(),
            cache: Arc::clone(&self.cache),
            mode: self.mode,
            group_by: Arc::clone(&self.group_by),
            aggr_expr: Arc::clone(&self.aggr_expr),
            filter_expr: Arc::clone(&self.filter_expr),
            input: Arc::clone(&self.input),
            schema: Arc::clone(&self.schema),
            input_schema: Arc::clone(&self.input_schema),
            dynamic_filter: self.dynamic_filter.clone(),
        }
    }

    /// The cached plan properties.
    pub fn cache(&self) -> &PlanProperties {
        &self.cache
    }

    /// Create an `AggregateExec`, deriving the output schema from the child
    /// schema, the group-by, the aggregates, and the mode.
    pub fn try_new(
        mode: AggregateMode,
        group_by: impl Into<Arc<PhysicalGroupBy>>,
        aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
        filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
        input: Arc<dyn ExecutionPlan>,
        input_schema: SchemaRef,
    ) -> Result<Self> {
        let group_by = group_by.into();
        let schema = create_schema(&input.schema(), &group_by, &aggr_expr, mode)?;
        let schema = Arc::new(schema);
        AggregateExec::try_new_with_schema(
            mode,
            group_by,
            aggr_expr,
            filter_expr,
            input,
            input_schema,
            schema,
        )
    }

    /// Create an `AggregateExec` with an already-computed output schema.
    ///
    /// Besides validating `aggr_expr`/`filter_expr` lengths, this derives the
    /// (soft) required input ordering, classifies how well the input ordering
    /// covers the group keys (`InputOrderMode`), computes the cached plan
    /// properties, and initializes the dynamic min/max filter when eligible.
    /// Note: `get_finer_aggregate_exprs_requirement` may REPLACE aggregate
    /// expressions with their reversed variants, hence `mut aggr_expr`.
    fn try_new_with_schema(
        mode: AggregateMode,
        group_by: impl Into<Arc<PhysicalGroupBy>>,
        mut aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
        filter_expr: impl Into<Arc<[Option<Arc<dyn PhysicalExpr>>]>>,
        input: Arc<dyn ExecutionPlan>,
        input_schema: SchemaRef,
        schema: SchemaRef,
    ) -> Result<Self> {
        let group_by = group_by.into();
        let filter_expr = filter_expr.into();
        assert_eq_or_internal_err!(
            aggr_expr.len(),
            filter_expr.len(),
            "Inconsistent aggregate expr: {:?} and filter expr: {:?} for AggregateExec, their size should match",
            aggr_expr,
            filter_expr
        );
        let input_eq_properties = input.equivalence_properties();
        // Longest prefix of group keys already ordered in the input.
        let groupby_exprs = group_by.input_exprs();
        let (new_sort_exprs, indices) =
            input_eq_properties.find_longest_permutation(&groupby_exprs)?;
        let mut new_requirements = new_sort_exprs
            .into_iter()
            .map(PhysicalSortRequirement::from)
            .collect::<Vec<_>>();
        // Append ordering requirements coming from order-sensitive aggregates.
        let req = get_finer_aggregate_exprs_requirement(
            &mut aggr_expr,
            &group_by,
            input_eq_properties,
            &mode,
        )?;
        new_requirements.extend(req);
        let required_input_ordering =
            LexRequirement::new(new_requirements).map(OrderingRequirements::new_soft);
        // Keep only the ordered group keys that no grouping set nulls out.
        let indices: Vec<usize> = indices
            .into_iter()
            .filter(|idx| group_by.groups.iter().all(|group| !group[*idx]))
            .collect();
        let input_order_mode = if indices.len() == groupby_exprs.len()
            && !indices.is_empty()
            && group_by.groups.len() == 1
        {
            InputOrderMode::Sorted
        } else if !indices.is_empty() {
            InputOrderMode::PartiallySorted(indices)
        } else {
            InputOrderMode::Linear
        };
        let group_expr_mapping =
            ProjectionMapping::try_new(group_by.expr.clone(), &input.schema())?;
        let cache = Self::compute_properties(
            &input,
            Arc::clone(&schema),
            &group_expr_mapping,
            &mode,
            &input_order_mode,
            aggr_expr.as_ref(),
        )?;
        let mut exec = AggregateExec {
            mode,
            group_by,
            aggr_expr: aggr_expr.into(),
            filter_expr,
            input,
            schema,
            input_schema,
            metrics: ExecutionPlanMetricsSet::new(),
            required_input_ordering,
            limit_options: None,
            input_order_mode,
            cache: Arc::new(cache),
            dynamic_filter: None,
        };
        exec.init_dynamic_filter();
        Ok(exec)
    }

    /// The execution mode of this stage.
    pub fn mode(&self) -> &AggregateMode {
        &self.mode
    }

    /// Builder-style setter for the limit options.
    pub fn with_limit_options(mut self, limit_options: Option<LimitOptions>) -> Self {
        self.limit_options = limit_options;
        self
    }

    /// The current limit options, if any.
    pub fn limit_options(&self) -> Option<LimitOptions> {
        self.limit_options
    }

    /// The group-by specification.
    pub fn group_expr(&self) -> &PhysicalGroupBy {
        &self.group_by
    }

    /// Column references to the group columns in the OUTPUT schema.
    pub fn output_group_expr(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        self.group_by.output_exprs()
    }

    /// The aggregate expressions.
    pub fn aggr_expr(&self) -> &[Arc<AggregateFunctionExpr>] {
        &self.aggr_expr
    }

    /// The per-aggregate FILTER expressions.
    pub fn filter_expr(&self) -> &[Option<Arc<dyn PhysicalExpr>>] {
        &self.filter_expr
    }

    /// The child plan.
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

    /// The input schema captured at construction time.
    pub fn input_schema(&self) -> SchemaRef {
        Arc::clone(&self.input_schema)
    }

    /// Pick the concrete stream implementation for one partition:
    /// ungrouped -> `AggregateStream`; limited non-distinct group-by ->
    /// top-k priority queue; otherwise the hash-grouped stream.
    fn execute_typed(
        &self,
        partition: usize,
        context: &Arc<TaskContext>,
    ) -> Result<StreamType> {
        if self.group_by.is_true_no_grouping() {
            return Ok(StreamType::AggregateStream(AggregateStream::new(
                self, context, partition,
            )?));
        }
        if let Some(config) = self.limit_options
            && !self.is_unordered_unfiltered_group_by_distinct()
        {
            return Ok(StreamType::GroupedPriorityQueue(
                GroupedTopKAggregateStream::new(self, context, partition, config.limit)?,
            ));
        }
        Ok(StreamType::GroupedHash(GroupedHashAggregateStream::new(
            self, context, partition,
        )?))
    }

    /// If this node has exactly one aggregate and it is a min/max, return its
    /// output field and whether it is a max (descending) — `None` otherwise.
    pub fn get_minmax_desc(&self) -> Option<(FieldRef, bool)> {
        let agg_expr = self.aggr_expr.iter().exactly_one().ok()?;
        agg_expr.get_minmax_desc()
    }

    /// Whether this aggregation behaves like an unordered, unfiltered
    /// `SELECT DISTINCT <group keys>`: a group-by with no aggregates, no
    /// per-aggregate filters, and no ordering constraints on input or output.
    pub fn is_unordered_unfiltered_group_by_distinct(&self) -> bool {
        // A recorded sort direction implies an ordering is in play.
        if self
            .limit_options()
            .and_then(|config| config.descending)
            .is_some()
        {
            return false;
        }
        // Must actually group by something.
        if self.group_expr().is_empty() && !self.group_expr().has_grouping_set() {
            return false;
        }
        // No aggregate functions allowed.
        if !self.aggr_expr().is_empty() {
            return false;
        }
        if self.filter_expr().iter().any(|e| e.is_some()) {
            return false;
        }
        // Vacuous given the empty-aggregates check above, kept for safety.
        if !self.aggr_expr().iter().all(|e| e.order_bys().is_empty()) {
            return false;
        }
        if self.properties().output_ordering().is_some() {
            return false;
        }
        if let Some(requirement) = self.required_input_ordering().swap_remove(0) {
            return matches!(requirement, OrderingRequirements::Hard(_));
        }
        true
    }

    /// Compute the cached [`PlanProperties`]: project the input equivalences
    /// through the group-by mapping, mark ungrouped aggregate outputs as
    /// constants, record a uniqueness constraint on the group key columns,
    /// and derive output partitioning and emission type from the mode and
    /// input order mode.
    pub fn compute_properties(
        input: &Arc<dyn ExecutionPlan>,
        schema: SchemaRef,
        group_expr_mapping: &ProjectionMapping,
        mode: &AggregateMode,
        input_order_mode: &InputOrderMode,
        aggr_exprs: &[Arc<AggregateFunctionExpr>],
    ) -> Result<PlanProperties> {
        let mut eq_properties = input
            .equivalence_properties()
            .project(group_expr_mapping, schema);
        // With no group keys the output is a single row, so every aggregate
        // output column is constant.
        if group_expr_mapping.is_empty() {
            let new_constants = aggr_exprs.iter().enumerate().map(|(idx, func)| {
                let column = Arc::new(Column::new(func.name(), idx));
                ConstExpr::from(column as Arc<dyn PhysicalExpr>)
            });
            eq_properties.add_constants(new_constants)?;
        }
        // Group keys are unique in the output.
        let mut constraints = eq_properties.constraints().to_vec();
        let new_constraint = Constraint::Unique(
            group_expr_mapping
                .iter()
                .flat_map(|(_, target_cols)| {
                    target_cols.iter().flat_map(|(expr, _)| {
                        expr.as_any().downcast_ref::<Column>().map(|c| c.index())
                    })
                })
                .collect(),
        );
        constraints.push(new_constraint);
        eq_properties =
            eq_properties.with_constraints(Constraints::new_unverified(constraints));
        let input_partitioning = input.output_partitioning().clone();
        let output_partitioning = match mode.input_mode() {
            // Raw input: group keys are re-expressed in the output, so the
            // partitioning must be projected through the mapping.
            AggregateInputMode::Raw => {
                let input_eq_properties = input.equivalence_properties();
                input_partitioning.project(group_expr_mapping, input_eq_properties)
            }
            AggregateInputMode::Partial => input_partitioning.clone(),
        };
        // A linear (unordered) group-by must consume all input before
        // emitting anything.
        let emission_type = if *input_order_mode == InputOrderMode::Linear {
            EmissionType::Final
        } else {
            input.pipeline_behavior()
        };
        Ok(PlanProperties::new(
            eq_properties,
            output_partitioning,
            emission_type,
            input.boundedness(),
        ))
    }

    /// How well the input ordering covers the group keys.
    pub fn input_order_mode(&self) -> &InputOrderMode {
        &self.input_order_mode
    }

    /// Estimate output statistics from the child's statistics: group-key
    /// min/max are carried over; row counts depend on the mode and the
    /// grouping-set count.
    fn statistics_inner(&self, child_statistics: &Statistics) -> Result<Statistics> {
        let column_statistics = {
            let mut column_statistics = Statistics::unknown_column(&self.schema());
            // Group keys that are plain columns keep the child's min/max.
            for (idx, (expr, _)) in self.group_by.expr.iter().enumerate() {
                if let Some(col) = expr.as_any().downcast_ref::<Column>() {
                    column_statistics[idx].max_value = child_statistics.column_statistics
                        [col.index()]
                    .max_value
                    .clone();
                    column_statistics[idx].min_value = child_statistics.column_statistics
                        [col.index()]
                    .min_value
                    .clone();
                }
            }
            column_statistics
        };
        match self.mode {
            // Ungrouped final aggregation always emits exactly one row.
            AggregateMode::Final | AggregateMode::FinalPartitioned
                if self.group_by.expr.is_empty() =>
            {
                let total_byte_size =
                    Self::calculate_scaled_byte_size(child_statistics, 1);
                Ok(Statistics {
                    num_rows: Precision::Exact(1),
                    column_statistics,
                    total_byte_size,
                })
            }
            _ => {
                let num_rows = if let Some(value) = child_statistics.num_rows.get_value()
                {
                    if *value > 1 {
                        // Grouping can only shrink the row count; downgrade
                        // the child's count to an inexact upper bound.
                        child_statistics.num_rows.to_inexact()
                    } else if *value == 0 {
                        child_statistics.num_rows
                    } else {
                        // Exactly one input row: each grouping set yields a row.
                        let grouping_set_num = self.group_by.groups.len();
                        child_statistics.num_rows.map(|x| x * grouping_set_num)
                    }
                } else {
                    Precision::Absent
                };
                let total_byte_size = num_rows
                    .get_value()
                    .and_then(|&output_rows| {
                        Self::calculate_scaled_byte_size(child_statistics, output_rows)
                            .get_value()
                            .map(|&bytes| Precision::Inexact(bytes))
                    })
                    .unwrap_or(Precision::Absent);
                Ok(Statistics {
                    num_rows,
                    column_statistics,
                    total_byte_size,
                })
            }
        }
    }

    /// Populate `self.dynamic_filter` when eligible: ungrouped, Partial-mode
    /// aggregations where EVERY aggregate is a min or max over a single plain
    /// column. Any non-min/max aggregate aborts the whole setup (early
    /// `return` in the loop).
    fn init_dynamic_filter(&mut self) {
        if (!self.group_by.is_empty()) || (self.mode != AggregateMode::Partial) {
            debug_assert!(
                self.dynamic_filter.is_none(),
                "The current operator node does not support dynamic filter"
            );
            return;
        }
        if self.dynamic_filter.is_some() {
            return;
        }
        let mut aggr_dyn_filters = Vec::new();
        let mut all_cols: Vec<Arc<dyn PhysicalExpr>> = Vec::new();
        for (i, aggr_expr) in self.aggr_expr.iter().enumerate() {
            let fun_name = aggr_expr.fun().name();
            let aggr_type = if fun_name.eq_ignore_ascii_case("min") {
                DynamicFilterAggregateType::Min
            } else if fun_name.eq_ignore_ascii_case("max") {
                DynamicFilterAggregateType::Max
            } else {
                // A single unsupported aggregate disables dynamic filtering
                // for the entire node.
                return;
            };
            // Only single-argument aggregates over a bare column qualify.
            if let [arg] = aggr_expr.expressions().as_slice()
                && arg.as_any().is::<Column>()
            {
                all_cols.push(Arc::clone(arg));
                aggr_dyn_filters.push(PerAccumulatorDynFilter {
                    aggr_type,
                    aggr_index: i,
                    shared_bound: Arc::new(Mutex::new(ScalarValue::Null)),
                });
            }
        }
        if !aggr_dyn_filters.is_empty() {
            self.dynamic_filter = Some(Arc::new(AggrDynFilter {
                // Starts as a pass-through filter; tightened at runtime.
                filter: Arc::new(DynamicFilterPhysicalExpr::new(all_cols, lit(true))),
                supported_accumulators_info: aggr_dyn_filters,
            }))
        }
    }

    /// Scale the input's byte size to `target_row_count` rows using the
    /// input's average bytes-per-row; `Absent` when either input figure is
    /// missing or the input row count is zero.
    #[inline]
    fn calculate_scaled_byte_size(
        input_stats: &Statistics,
        target_row_count: usize,
    ) -> Precision<usize> {
        match (
            input_stats.num_rows.get_value(),
            input_stats.total_byte_size.get_value(),
        ) {
            (Some(&input_rows), Some(&input_bytes)) if input_rows > 0 => {
                let bytes_per_row = input_bytes as f64 / input_rows as f64;
                let scaled_bytes =
                    (bytes_per_row * target_row_count as f64).ceil() as usize;
                Precision::Inexact(scaled_bytes)
            }
            _ => Precision::Absent,
        }
    }

    /// Clone of this node with the (single) child swapped out and metrics
    /// reset; every other field — including the cached properties — is kept.
    fn with_new_children_and_same_properties(
        &self,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Self {
        Self {
            input: children.swap_remove(0),
            metrics: ExecutionPlanMetricsSet::new(),
            ..Self::clone(self)
        }
    }
}
impl DisplayAs for AggregateExec {
    /// Render this node for EXPLAIN output. `Default`/`Verbose` produce a
    /// single-line summary; `TreeRender` produces one attribute per line
    /// using SQL-style expression formatting.
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                // Print "expr as alias" only when the alias differs.
                let format_expr_with_alias =
                    |(e, alias): &(Arc<dyn PhysicalExpr>, String)| -> String {
                        let e = e.to_string();
                        if &e != alias {
                            format!("{e} as {alias}")
                        } else {
                            e
                        }
                    };
                write!(f, "AggregateExec: mode={:?}", self.mode)?;
                let g: Vec<String> = if self.group_by.is_single() {
                    self.group_by
                        .expr
                        .iter()
                        .map(format_expr_with_alias)
                        .collect()
                } else {
                    // Grouping sets: render each set as a parenthesized list,
                    // substituting the null expression where the mask is set.
                    self.group_by
                        .groups
                        .iter()
                        .map(|group| {
                            let terms = group
                                .iter()
                                .enumerate()
                                .map(|(idx, is_null)| {
                                    if *is_null {
                                        format_expr_with_alias(
                                            &self.group_by.null_expr[idx],
                                        )
                                    } else {
                                        format_expr_with_alias(&self.group_by.expr[idx])
                                    }
                                })
                                .collect::<Vec<String>>()
                                .join(", ");
                            format!("({terms})")
                        })
                        .collect()
                };
                write!(f, ", gby=[{}]", g.join(", "))?;
                let a: Vec<String> = self
                    .aggr_expr
                    .iter()
                    .map(|agg| agg.name().to_string())
                    .collect();
                write!(f, ", aggr=[{}]", a.join(", "))?;
                if let Some(config) = self.limit_options {
                    write!(f, ", lim=[{}]", config.limit)?;
                }
                // Linear is the default; only non-default modes are printed.
                if self.input_order_mode != InputOrderMode::Linear {
                    write!(f, ", ordering_mode={:?}", self.input_order_mode)?;
                }
            }
            DisplayFormatType::TreeRender => {
                // Same aliasing rule as above, but with SQL-style formatting.
                let format_expr_with_alias =
                    |(e, alias): &(Arc<dyn PhysicalExpr>, String)| -> String {
                        let expr_sql = fmt_sql(e.as_ref()).to_string();
                        if &expr_sql != alias {
                            format!("{expr_sql} as {alias}")
                        } else {
                            expr_sql
                        }
                    };
                let g: Vec<String> = if self.group_by.is_single() {
                    self.group_by
                        .expr
                        .iter()
                        .map(format_expr_with_alias)
                        .collect()
                } else {
                    self.group_by
                        .groups
                        .iter()
                        .map(|group| {
                            let terms = group
                                .iter()
                                .enumerate()
                                .map(|(idx, is_null)| {
                                    if *is_null {
                                        format_expr_with_alias(
                                            &self.group_by.null_expr[idx],
                                        )
                                    } else {
                                        format_expr_with_alias(&self.group_by.expr[idx])
                                    }
                                })
                                .collect::<Vec<String>>()
                                .join(", ");
                            format!("({terms})")
                        })
                        .collect()
                };
                let a: Vec<String> = self
                    .aggr_expr
                    .iter()
                    .map(|agg| agg.human_display().to_string())
                    .collect();
                writeln!(f, "mode={:?}", self.mode)?;
                // Empty attributes are omitted entirely in tree rendering.
                if !g.is_empty() {
                    writeln!(f, "group_by={}", g.join(", "))?;
                }
                if !a.is_empty() {
                    writeln!(f, "aggr={}", a.join(", "))?;
                }
                if let Some(config) = self.limit_options {
                    writeln!(f, "limit={}", config.limit)?;
                }
            }
        }
        Ok(())
    }
}
impl ExecutionPlan for AggregateExec {
    fn name(&self) -> &'static str {
        "AggregateExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &Arc<PlanProperties> {
        &self.cache
    }

    /// Distribution required of the child: partial stages accept anything;
    /// partitioned final/single stages need hash partitioning on the group
    /// keys; unpartitioned final/single stages need a single partition.
    fn required_input_distribution(&self) -> Vec<Distribution> {
        match &self.mode {
            AggregateMode::Partial | AggregateMode::PartialReduce => {
                vec![Distribution::UnspecifiedDistribution]
            }
            AggregateMode::FinalPartitioned | AggregateMode::SinglePartitioned => {
                vec![Distribution::HashPartitioned(self.group_by.input_exprs())]
            }
            AggregateMode::Final | AggregateMode::Single => {
                vec![Distribution::SinglePartition]
            }
        }
    }

    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
        vec![self.required_input_ordering.clone()]
    }

    /// Input order survives unless we fall back to linear (hash) grouping.
    fn maintains_input_order(&self) -> Vec<bool> {
        vec![self.input_order_mode != InputOrderMode::Linear]
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    /// Rebuild this node around a new child, preserving the original output
    /// schema, limit options, and dynamic-filter state.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        check_if_same_properties!(self, children);
        let mut me = AggregateExec::try_new_with_schema(
            self.mode,
            Arc::clone(&self.group_by),
            self.aggr_expr.to_vec(),
            Arc::clone(&self.filter_expr),
            Arc::clone(&children[0]),
            Arc::clone(&self.input_schema),
            Arc::clone(&self.schema),
        )?;
        me.limit_options = self.limit_options;
        me.dynamic_filter = self.dynamic_filter.clone();
        Ok(Arc::new(me))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        self.execute_typed(partition, &context)
            .map(|stream| stream.into())
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
        let child_statistics = self.input().partition_statistics(partition)?;
        self.statistics_inner(&child_statistics)
    }

    /// Aggregation never increases the row count.
    fn cardinality_effect(&self) -> CardinalityEffect {
        CardinalityEffect::LowerEqual
    }

    /// Decide which parent filters may pass through this aggregation to the
    /// child: a filter is safe only if it references group-key columns
    /// exclusively, and (with multiple grouping sets) none of those columns
    /// is nulled out in any set. In the Post phase, also attach this node's
    /// own dynamic min/max filter when enabled.
    fn gather_filters_for_pushdown(
        &self,
        phase: FilterPushdownPhase,
        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
        config: &ConfigOptions,
    ) -> Result<FilterDescription> {
        let grouping_columns: HashSet<_> = self
            .group_by
            .expr()
            .iter()
            .flat_map(|(expr, _)| collect_columns(expr))
            .collect();
        let mut safe_filters = Vec::new();
        let mut unsafe_filters = Vec::new();
        for filter in parent_filters {
            let filter_columns: HashSet<_> =
                collect_columns(&filter).into_iter().collect();
            // Filters on aggregate outputs (non-group columns) cannot move
            // below the aggregation.
            let references_non_grouping = !grouping_columns.is_empty()
                && !filter_columns.is_subset(&grouping_columns);
            if references_non_grouping {
                unsafe_filters.push(filter);
                continue;
            }
            if self.group_by.groups().len() > 1 {
                // Map each filter column to the group expression it appears in.
                let filter_column_indices: Vec<usize> = filter_columns
                    .iter()
                    .filter_map(|filter_col| {
                        self.group_by.expr().iter().position(|(expr, _)| {
                            collect_columns(expr).contains(filter_col)
                        })
                    })
                    .collect();
                // If any grouping set nulls one of those expressions, the
                // filter would wrongly eliminate that set's rows.
                let has_missing_column = self.group_by.groups().iter().any(|null_mask| {
                    filter_column_indices
                        .iter()
                        .any(|&idx| null_mask.get(idx) == Some(&true))
                });
                if has_missing_column {
                    unsafe_filters.push(filter);
                    continue;
                }
            }
            safe_filters.push(filter);
        }
        let child = self.children()[0];
        let mut child_desc = ChildFilterDescription::from_child(&safe_filters, child)?;
        child_desc.parent_filters.extend(
            unsafe_filters
                .into_iter()
                .map(PushedDownPredicate::unsupported),
        );
        if phase == FilterPushdownPhase::Post
            && config.optimizer.enable_aggregate_dynamic_filter_pushdown
            && let Some(self_dyn_filter) = &self.dynamic_filter
        {
            let dyn_filter = Arc::clone(&self_dyn_filter.filter);
            child_desc = child_desc.with_self_filter(dyn_filter);
        }
        Ok(FilterDescription::new().with_child(child_desc))
    }

    /// After pushdown: if no descendant retained a reference to our dynamic
    /// filter (strong count back to 1), drop it from a fresh copy of this
    /// node so execution skips the useless bookkeeping.
    fn handle_child_pushdown_result(
        &self,
        phase: FilterPushdownPhase,
        child_pushdown_result: ChildPushdownResult,
        _config: &ConfigOptions,
    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
        let mut result = FilterPushdownPropagation::if_any(child_pushdown_result.clone());
        if phase == FilterPushdownPhase::Post
            && let Some(dyn_filter) = &self.dynamic_filter
        {
            // A child that accepted the filter holds a clone of the Arc.
            let child_accepts_dyn_filter = Arc::strong_count(dyn_filter) > 1;
            if !child_accepts_dyn_filter {
                let mut new_node = self.clone();
                new_node.dynamic_filter = None;
                result = result
                    .with_updated_node(Arc::new(new_node) as Arc<dyn ExecutionPlan>);
            }
        }
        Ok(result)
    }
}
/// Build the output schema for an aggregation: the group-by output fields
/// first, then — depending on the mode — either the final aggregate fields or
/// the intermediate per-aggregate state fields. Input-schema metadata is
/// carried over.
fn create_schema(
    input_schema: &Schema,
    group_by: &PhysicalGroupBy,
    aggr_expr: &[Arc<AggregateFunctionExpr>],
    mode: AggregateMode,
) -> Result<Schema> {
    let mut fields = Vec::with_capacity(group_by.num_output_exprs() + aggr_expr.len());
    fields.extend(group_by.output_fields(input_schema)?);
    match mode.output_mode() {
        // Final stages expose one field per aggregate.
        AggregateOutputMode::Final => {
            fields.extend(aggr_expr.iter().map(|expr| expr.field()));
        }
        // Partial stages expose each aggregate's internal state fields.
        AggregateOutputMode::Partial => {
            for expr in aggr_expr {
                fields.extend(expr.state_fields()?.iter().cloned());
            }
        }
    }
    Ok(Schema::new_with_metadata(
        fields,
        input_schema.metadata().clone(),
    ))
}
/// The input ordering a single aggregate expression needs, or `None` when it
/// has no (applicable) requirement.
///
/// Returns `None` for partial-input modes (state merging ignores row order),
/// for order-insensitive / merely-beneficial aggregates, and for soft
/// requirements unless `include_soft_requirement` is set. For a single group
/// (no grouping sets), sort expressions that are themselves group keys are
/// dropped — within a group they are constant.
fn get_aggregate_expr_req(
    aggr_expr: &AggregateFunctionExpr,
    group_by: &PhysicalGroupBy,
    agg_mode: &AggregateMode,
    include_soft_requirement: bool,
) -> Option<LexOrdering> {
    if agg_mode.input_mode() == AggregateInputMode::Partial {
        return None;
    }
    match aggr_expr.order_sensitivity() {
        AggregateOrderSensitivity::Insensitive => return None,
        AggregateOrderSensitivity::HardRequirement => {}
        AggregateOrderSensitivity::SoftRequirement => {
            if !include_soft_requirement {
                return None;
            }
        }
        AggregateOrderSensitivity::Beneficial => return None,
    }
    let mut sort_exprs = aggr_expr.order_bys().to_vec();
    if group_by.is_single() {
        let physical_exprs = group_by.input_exprs();
        sort_exprs.retain(|sort_expr| {
            !physical_exprs_contains(&physical_exprs, &sort_expr.expr)
        });
    }
    LexOrdering::new(sort_exprs)
}
/// Concatenate two slices into a freshly allocated `Vec`, `lhs` first.
pub fn concat_slices<T: Clone>(lhs: &[T], rhs: &[T]) -> Vec<T> {
    let mut combined = Vec::with_capacity(lhs.len() + rhs.len());
    combined.extend_from_slice(lhs);
    combined.extend_from_slice(rhs);
    combined
}
/// Compare `candidate` against the currently chosen ordering requirement.
///
/// Returns `Some(true)` when there is no current requirement or `candidate`
/// is strictly finer, `Some(false)` when it is not finer, and `None` when the
/// two orderings are incomparable.
fn determine_finer(
    current: &Option<LexOrdering>,
    candidate: &LexOrdering,
) -> Option<bool> {
    match current {
        None => Some(true),
        Some(existing) => candidate
            .partial_cmp(existing)
            .map(std::cmp::Ordering::is_gt),
    }
}
/// Pick a single input ordering requirement compatible with all aggregate
/// expressions, possibly REPLACING aggregates with their reversed variants
/// (in place, via `aggr_exprs`) when the reverse direction fits better.
///
/// Two passes: hard requirements first (`include_soft_requirement = false`),
/// then soft ones. Within each pass, a candidate requirement either refines
/// the running choice, is subsumed by it, or — if both forward and reversed
/// forms are incomparable with the running choice during the hard pass — the
/// conflict is reported as a not-implemented error.
pub fn get_finer_aggregate_exprs_requirement(
    aggr_exprs: &mut [Arc<AggregateFunctionExpr>],
    group_by: &PhysicalGroupBy,
    eq_properties: &EquivalenceProperties,
    agg_mode: &AggregateMode,
) -> Result<Vec<PhysicalSortRequirement>> {
    let mut requirement = None;
    for include_soft_requirement in [false, true] {
        for aggr_expr in aggr_exprs.iter_mut() {
            // No (normalized) requirement for this aggregate: nothing to do.
            let Some(aggr_req) = get_aggregate_expr_req(
                aggr_expr,
                group_by,
                agg_mode,
                include_soft_requirement,
            )
            .and_then(|o| eq_properties.normalize_sort_exprs(o)) else {
                continue;
            };
            let forward_finer = determine_finer(&requirement, &aggr_req);
            if let Some(finer) = forward_finer {
                if !finer {
                    // Current choice already covers this aggregate.
                    continue;
                } else if eq_properties.ordering_satisfy(aggr_req.clone())? {
                    // Finer AND already satisfied by the input: adopt it.
                    requirement = Some(aggr_req);
                    continue;
                }
            }
            // Forward form unsatisfiable or incomparable; try the reversed
            // variant of the aggregate.
            if let Some(reverse_aggr_expr) = aggr_expr.reverse_expr() {
                let Some(rev_aggr_req) = get_aggregate_expr_req(
                    &reverse_aggr_expr,
                    group_by,
                    agg_mode,
                    include_soft_requirement,
                )
                .and_then(|o| eq_properties.normalize_sort_exprs(o)) else {
                    // Reversed form has no requirement at all: use it.
                    *aggr_expr = Arc::new(reverse_aggr_expr);
                    continue;
                };
                if let Some(finer) = determine_finer(&requirement, &rev_aggr_req) {
                    if !finer {
                        // Current choice covers the reversed form: swap to it.
                        *aggr_expr = Arc::new(reverse_aggr_expr);
                    } else if eq_properties.ordering_satisfy(rev_aggr_req.clone())? {
                        // Reversed requirement finer and satisfied: adopt both.
                        *aggr_expr = Arc::new(reverse_aggr_expr);
                        requirement = Some(rev_aggr_req);
                    } else {
                        // Neither direction satisfied; keep the forward one.
                        requirement = Some(aggr_req);
                    }
                } else if forward_finer.is_some() {
                    requirement = Some(aggr_req);
                } else {
                    // Both directions incomparable with the running choice:
                    // a hard conflict is unsupported; soft ones are dropped.
                    if !include_soft_requirement {
                        return not_impl_err!(
                            "Conflicting ordering requirements in aggregate functions is not supported"
                        );
                    }
                }
            }
        }
    }
    Ok(requirement.map_or_else(Vec::new, |o| o.into_iter().map(Into::into).collect()))
}
/// The physical expressions each aggregate evaluates against its input.
///
/// For raw-input modes this is the aggregate's arguments followed by its
/// ORDER BY expressions. For partial-input modes it is column references to
/// the aggregate's state fields, laid out consecutively starting at
/// `col_idx_base`.
pub fn aggregate_expressions(
    aggr_expr: &[Arc<AggregateFunctionExpr>],
    mode: &AggregateMode,
    col_idx_base: usize,
) -> Result<Vec<Vec<Arc<dyn PhysicalExpr>>>> {
    match mode.input_mode() {
        AggregateInputMode::Raw => {
            let exprs = aggr_expr
                .iter()
                .map(|agg| {
                    let order_by_exprs =
                        agg.order_bys().iter().map(|item| Arc::clone(&item.expr));
                    agg.expressions().into_iter().chain(order_by_exprs).collect()
                })
                .collect();
            Ok(exprs)
        }
        AggregateInputMode::Partial => {
            // Each aggregate's state columns follow the previous aggregate's.
            let mut next_index = col_idx_base;
            aggr_expr
                .iter()
                .map(|agg| {
                    let state_cols = merge_expressions(next_index, agg)?;
                    next_index += state_cols.len();
                    Ok(state_cols)
                })
                .collect()
        }
    }
}
/// Column references to an aggregate's state fields, numbered consecutively
/// from `index_base` in state-field order.
fn merge_expressions(
    index_base: usize,
    expr: &AggregateFunctionExpr,
) -> Result<Vec<Arc<dyn PhysicalExpr>>> {
    let state_fields = expr.state_fields()?;
    let columns = state_fields
        .iter()
        .enumerate()
        .map(|(offset, field)| {
            Arc::new(Column::new(field.name(), index_base + offset)) as _
        })
        .collect();
    Ok(columns)
}
/// A boxed, dynamically dispatched accumulator instance.
pub type AccumulatorItem = Box<dyn Accumulator>;
/// Instantiate one accumulator per aggregate expression, failing fast on the
/// first expression that cannot create one.
pub fn create_accumulators(
    aggr_expr: &[Arc<AggregateFunctionExpr>],
) -> Result<Vec<AccumulatorItem>> {
    let mut accumulators = Vec::with_capacity(aggr_expr.len());
    for expr in aggr_expr {
        accumulators.push(expr.create_accumulator()?);
    }
    Ok(accumulators)
}
/// Turn a set of accumulators into output arrays.
///
/// Final output modes evaluate each accumulator to a single value (one array
/// per accumulator); partial modes flatten every accumulator's state values
/// into consecutive arrays. Errors short-circuit in both cases.
pub fn finalize_aggregation(
    accumulators: &mut [AccumulatorItem],
    mode: &AggregateMode,
) -> Result<Vec<ArrayRef>> {
    match mode.output_mode() {
        AggregateOutputMode::Final => accumulators
            .iter_mut()
            .map(|accumulator| accumulator.evaluate()?.to_array())
            .collect(),
        AggregateOutputMode::Partial => {
            let mut arrays = Vec::new();
            for accumulator in accumulators.iter_mut() {
                for state_value in accumulator.state()? {
                    arrays.push(state_value.to_array()?);
                }
            }
            Ok(arrays)
        }
    }
}
/// Evaluate each per-aggregate expression group against `batch`, producing
/// one vector of arrays per group.
pub fn evaluate_many(
    expr: &[Vec<Arc<dyn PhysicalExpr>>],
    batch: &RecordBatch,
) -> Result<Vec<Vec<ArrayRef>>> {
    let mut results = Vec::with_capacity(expr.len());
    for exprs in expr {
        results.push(evaluate_expressions_to_arrays(exprs, batch)?);
    }
    Ok(results)
}
/// Evaluate optional (FILTER) expressions against `batch`: `None` entries
/// stay `None`, present expressions are materialized into arrays of
/// `batch.num_rows()` length.
fn evaluate_optional(
    expr: &[Option<Arc<dyn PhysicalExpr>>],
    batch: &RecordBatch,
) -> Result<Vec<Option<ArrayRef>>> {
    expr.iter()
        .map(|maybe_expr| match maybe_expr {
            Some(e) => {
                let value = e.evaluate(batch)?;
                Ok(Some(value.into_array(batch.num_rows())?))
            }
            None => Ok(None),
        })
        .collect()
}
/// Build a constant `__grouping_id` column for one grouping set.
///
/// Each entry of `group` marks whether the corresponding group-by column is
/// NULLed out in this set; the bits are packed most-significant-first into
/// an unsigned integer, repeated for every row of `batch`. The narrowest
/// unsigned type that fits `group.len()` bits is used.
fn group_id_array(group: &[bool], batch: &RecordBatch) -> Result<ArrayRef> {
    let width = group.len();
    if width > 64 {
        return not_impl_err!(
            "Grouping sets with more than 64 columns are not supported"
        );
    }
    // First column in `group` ends up as the most significant bit.
    let mut group_id = 0u64;
    for &is_null in group {
        group_id = (group_id << 1) | u64::from(is_null);
    }
    let num_rows = batch.num_rows();
    let array: ArrayRef = match width {
        0..=8 => Arc::new(UInt8Array::from(vec![group_id as u8; num_rows])),
        9..=16 => Arc::new(UInt16Array::from(vec![group_id as u16; num_rows])),
        17..=32 => Arc::new(UInt32Array::from(vec![group_id as u32; num_rows])),
        _ => Arc::new(UInt64Array::from(vec![group_id; num_rows])),
    };
    Ok(array)
}
/// Evaluate the group-by expressions for every grouping set.
///
/// For each set, columns excluded from the set take their NULL placeholder
/// array; included columns take the real evaluated array. When more than one
/// grouping set exists, a `__grouping_id` column is appended to identify the
/// set each row belongs to.
pub fn evaluate_group_by(
    group_by: &PhysicalGroupBy,
    batch: &RecordBatch,
) -> Result<Vec<Vec<ArrayRef>>> {
    // Real group-by values and their NULL placeholders, evaluated once and
    // shared (via Arc) across all grouping sets.
    let exprs = evaluate_expressions_to_arrays(
        group_by.expr.iter().map(|(expr, _)| expr),
        batch,
    )?;
    let null_exprs = evaluate_expressions_to_arrays(
        group_by.null_expr.iter().map(|(expr, _)| expr),
        batch,
    )?;
    let mut all_groups = Vec::with_capacity(group_by.groups.len());
    for group in &group_by.groups {
        let mut group_values = Vec::with_capacity(group_by.num_group_exprs());
        for (idx, &is_null) in group.iter().enumerate() {
            let source = if is_null { &null_exprs[idx] } else { &exprs[idx] };
            group_values.push(Arc::clone(source));
        }
        if !group_by.is_single() {
            group_values.push(group_id_array(group, batch)?);
        }
        all_groups.push(group_values);
    }
    Ok(all_groups)
}
#[cfg(test)]
mod tests {
use std::task::{Context, Poll};
use super::*;
use crate::RecordBatchStream;
use crate::coalesce_partitions::CoalescePartitionsExec;
use crate::common;
use crate::common::collect;
use crate::execution_plan::Boundedness;
use crate::expressions::col;
use crate::metrics::MetricValue;
use crate::test::TestMemoryExec;
use crate::test::assert_is_pending;
use crate::test::exec::{BlockingExec, assert_strong_count_converges_to_zero};
use arrow::array::{
DictionaryArray, Float32Array, Float64Array, Int32Array, Int64Array, StructArray,
UInt32Array, UInt64Array,
};
use arrow::compute::{SortOptions, concat_batches};
use arrow::datatypes::{DataType, Int32Type};
use datafusion_common::test_util::{batches_to_sort_string, batches_to_string};
use datafusion_common::{DataFusionError, ScalarValue, internal_err};
use datafusion_execution::config::SessionConfig;
use datafusion_execution::memory_pool::FairSpillPool;
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
use datafusion_functions_aggregate::array_agg::array_agg_udaf;
use datafusion_functions_aggregate::average::avg_udaf;
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_functions_aggregate::first_last::{first_value_udaf, last_value_udaf};
use datafusion_functions_aggregate::median::median_udaf;
use datafusion_functions_aggregate::sum::sum_udaf;
use datafusion_physical_expr::Partitioning;
use datafusion_physical_expr::PhysicalSortExpr;
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
use datafusion_physical_expr::expressions::Literal;
use datafusion_physical_expr::expressions::lit;
use crate::projection::ProjectionExec;
use datafusion_physical_expr::projection::ProjectionExpr;
use futures::{FutureExt, Stream};
use insta::{allow_duplicates, assert_snapshot};
fn create_test_schema() -> Result<SchemaRef> {
let a = Field::new("a", DataType::Int32, true);
let b = Field::new("b", DataType::Int32, true);
let c = Field::new("c", DataType::Int32, true);
let d = Field::new("d", DataType::Int32, true);
let e = Field::new("e", DataType::Int32, true);
let schema = Arc::new(Schema::new(vec![a, b, c, d, e]));
Ok(schema)
}
fn some_data() -> (Arc<Schema>, Vec<RecordBatch>) {
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::UInt32, false),
Field::new("b", DataType::Float64, false),
]));
(
Arc::clone(&schema),
vec![
RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(UInt32Array::from(vec![2, 3, 4, 4])),
Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
],
)
.unwrap(),
RecordBatch::try_new(
schema,
vec![
Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
],
)
.unwrap(),
],
)
}
fn some_data_v2() -> (Arc<Schema>, Vec<RecordBatch>) {
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::UInt32, false),
Field::new("b", DataType::Float64, false),
]));
(
Arc::clone(&schema),
vec![
RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(UInt32Array::from(vec![2, 3, 4, 4])),
Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
],
)
.unwrap(),
RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
Arc::new(Float64Array::from(vec![0.0, 1.0, 2.0, 3.0])),
],
)
.unwrap(),
RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
Arc::new(Float64Array::from(vec![3.0, 4.0, 5.0, 6.0])),
],
)
.unwrap(),
RecordBatch::try_new(
schema,
vec![
Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
Arc::new(Float64Array::from(vec![2.0, 3.0, 4.0, 5.0])),
],
)
.unwrap(),
],
)
}
/// Task context with the given batch size and a `FairSpillPool` capped at
/// `max_memory` bytes, so aggregation is forced to spill.
fn new_spill_ctx(batch_size: usize, max_memory: usize) -> Arc<TaskContext> {
    let runtime = RuntimeEnvBuilder::new()
        .with_memory_pool(Arc::new(FairSpillPool::new(max_memory)))
        .build_arc()
        .unwrap();
    Arc::new(
        TaskContext::default()
            .with_session_config(SessionConfig::new().with_batch_size(batch_size))
            .with_runtime(runtime),
    )
}
/// End-to-end check of grouping-set aggregation through a
/// Partial -> CoalescePartitions -> Final plan over `input`, optionally with
/// a memory limit small enough to force spilling (which changes the partial
/// output because per-spill groups are emitted separately).
async fn check_grouping_sets(
    input: Arc<dyn ExecutionPlan>,
    spill: bool,
) -> Result<()> {
    let input_schema = input.schema();
    // Three grouping sets over (a, b): {b}, {a}, {a, b}. The second vector
    // supplies the typed NULL placeholder used when a column is excluded.
    let grouping_set = PhysicalGroupBy::new(
        vec![
            (col("a", &input_schema)?, "a".to_string()),
            (col("b", &input_schema)?, "b".to_string()),
        ],
        vec![
            (lit(ScalarValue::UInt32(None)), "a".to_string()),
            (lit(ScalarValue::Float64(None)), "b".to_string()),
        ],
        vec![
            vec![false, true], vec![true, false], vec![false, false], ],
        true,
    );
    let aggregates = vec![Arc::new(
        AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)])
            .schema(Arc::clone(&input_schema))
            .alias("COUNT(1)")
            .build()?,
    )];
    // Small memory budget (500 bytes) to force spilling in the partial stage.
    let task_ctx = if spill {
        new_spill_ctx(4, 500)
    } else {
        Arc::new(TaskContext::default())
    };
    let partial_aggregate = Arc::new(AggregateExec::try_new(
        AggregateMode::Partial,
        grouping_set.clone(),
        aggregates.clone(),
        vec![None],
        input,
        Arc::clone(&input_schema),
    )?);
    let result =
        collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
    if spill {
        // Spilled partial output contains duplicate group keys (one entry
        // per spill generation); the final stage merges them below.
        allow_duplicates! {
            assert_snapshot!(batches_to_sort_string(&result),
            @r"
            +---+-----+---------------+-----------------+
            | a | b   | __grouping_id | COUNT(1)[count] |
            +---+-----+---------------+-----------------+
            |   | 1.0 | 2             | 1               |
            |   | 1.0 | 2             | 1               |
            |   | 2.0 | 2             | 1               |
            |   | 2.0 | 2             | 1               |
            |   | 3.0 | 2             | 1               |
            |   | 3.0 | 2             | 1               |
            |   | 4.0 | 2             | 1               |
            |   | 4.0 | 2             | 1               |
            | 2 |     | 1             | 1               |
            | 2 |     | 1             | 1               |
            | 2 | 1.0 | 0             | 1               |
            | 2 | 1.0 | 0             | 1               |
            | 3 |     | 1             | 1               |
            | 3 |     | 1             | 2               |
            | 3 | 2.0 | 0             | 2               |
            | 3 | 3.0 | 0             | 1               |
            | 4 |     | 1             | 1               |
            | 4 |     | 1             | 2               |
            | 4 | 3.0 | 0             | 1               |
            | 4 | 4.0 | 0             | 2               |
            +---+-----+---------------+-----------------+
            "
            );
        }
    } else {
        allow_duplicates! {
            assert_snapshot!(batches_to_sort_string(&result),
            @r"
            +---+-----+---------------+-----------------+
            | a | b   | __grouping_id | COUNT(1)[count] |
            +---+-----+---------------+-----------------+
            |   | 1.0 | 2             | 2               |
            |   | 2.0 | 2             | 2               |
            |   | 3.0 | 2             | 2               |
            |   | 4.0 | 2             | 2               |
            | 2 |     | 1             | 2               |
            | 2 | 1.0 | 0             | 2               |
            | 3 |     | 1             | 3               |
            | 3 | 2.0 | 0             | 2               |
            | 3 | 3.0 | 0             | 1               |
            | 4 |     | 1             | 3               |
            | 4 | 3.0 | 0             | 1               |
            | 4 | 4.0 | 0             | 2               |
            +---+-----+---------------+-----------------+
            "
            );
        }
    };
    let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));
    let final_grouping_set = grouping_set.as_final();
    // Larger budget for the final stage so it completes while still
    // exercising the spill-enabled pool.
    let task_ctx = if spill {
        new_spill_ctx(4, 3160)
    } else {
        task_ctx
    };
    let merged_aggregate = Arc::new(AggregateExec::try_new(
        AggregateMode::Final,
        final_grouping_set,
        aggregates,
        vec![None],
        merge,
        input_schema,
    )?);
    let result = collect(merged_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
    let batch = concat_batches(&result[0].schema(), &result)?;
    assert_eq!(batch.num_columns(), 4);
    assert_eq!(batch.num_rows(), 12);
    // Final output is identical whether or not the partial stage spilled.
    allow_duplicates! {
        assert_snapshot!(
            batches_to_sort_string(&result),
            @r"
            +---+-----+---------------+----------+
            | a | b   | __grouping_id | COUNT(1) |
            +---+-----+---------------+----------+
            |   | 1.0 | 2             | 2        |
            |   | 2.0 | 2             | 2        |
            |   | 3.0 | 2             | 2        |
            |   | 4.0 | 2             | 2        |
            | 2 |     | 1             | 2        |
            | 2 | 1.0 | 0             | 2        |
            | 3 |     | 1             | 3        |
            | 3 | 2.0 | 0             | 2        |
            | 3 | 3.0 | 0             | 1        |
            | 4 |     | 1             | 3        |
            | 4 | 3.0 | 0             | 1        |
            | 4 | 4.0 | 0             | 2        |
            +---+-----+---------------+----------+
            "
        );
    }
    // Metrics should report exactly the 12 merged output rows.
    let metrics = merged_aggregate.metrics().unwrap();
    let output_rows = metrics.output_rows().unwrap();
    assert_eq!(12, output_rows);
    Ok(())
}
/// End-to-end check of AVG(b) GROUP BY a through a
/// Partial -> CoalescePartitions -> Final plan, optionally under a memory
/// limit that forces spilling; also verifies spill metrics.
async fn check_aggregates(input: Arc<dyn ExecutionPlan>, spill: bool) -> Result<()> {
    let input_schema = input.schema();
    let grouping_set = PhysicalGroupBy::new(
        vec![(col("a", &input_schema)?, "a".to_string())],
        vec![],
        vec![vec![false]],
        false,
    );
    let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
        AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?])
            .schema(Arc::clone(&input_schema))
            .alias("AVG(b)")
            .build()?,
    )];
    // 1600 bytes forces the partial stage to spill.
    let task_ctx = if spill {
        new_spill_ctx(2, 1600)
    } else {
        Arc::new(TaskContext::default())
    };
    let partial_aggregate = Arc::new(AggregateExec::try_new(
        AggregateMode::Partial,
        grouping_set.clone(),
        aggregates.clone(),
        vec![None],
        input,
        Arc::clone(&input_schema),
    )?);
    let result =
        collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
    if spill {
        // Duplicate partial groups appear because spilled generations are
        // emitted separately; merged by the final stage below.
        allow_duplicates! {
            assert_snapshot!(batches_to_sort_string(&result), @r"
            +---+---------------+-------------+
            | a | AVG(b)[count] | AVG(b)[sum] |
            +---+---------------+-------------+
            | 2 | 1             | 1.0         |
            | 2 | 1             | 1.0         |
            | 3 | 1             | 2.0         |
            | 3 | 2             | 5.0         |
            | 4 | 3             | 11.0        |
            +---+---------------+-------------+
            ");
        }
    } else {
        allow_duplicates! {
            assert_snapshot!(batches_to_sort_string(&result), @r"
            +---+---------------+-------------+
            | a | AVG(b)[count] | AVG(b)[sum] |
            +---+---------------+-------------+
            | 2 | 2             | 2.0         |
            | 3 | 3             | 7.0         |
            | 4 | 3             | 11.0        |
            +---+---------------+-------------+
            ");
        }
    };
    let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));
    let final_grouping_set = grouping_set.as_final();
    let merged_aggregate = Arc::new(AggregateExec::try_new(
        AggregateMode::Final,
        final_grouping_set,
        aggregates,
        vec![None],
        merge,
        input_schema,
    )?);
    // Statistics should be available before execution.
    let final_stats = merged_aggregate.partition_statistics(None)?;
    assert!(final_stats.total_byte_size.get_value().is_some());
    let task_ctx = if spill {
        new_spill_ctx(2, 2600)
    } else {
        Arc::clone(&task_ctx)
    };
    let result = collect(merged_aggregate.execute(0, task_ctx)?).await?;
    let batch = concat_batches(&result[0].schema(), &result)?;
    assert_eq!(batch.num_columns(), 2);
    assert_eq!(batch.num_rows(), 3);
    allow_duplicates! {
        assert_snapshot!(batches_to_sort_string(&result), @r"
        +---+--------------------+
        | a | AVG(b)             |
        +---+--------------------+
        | 2 | 1.0                |
        | 3 | 2.3333333333333335 |
        | 4 | 3.6666666666666665 |
        +---+--------------------+
        ");
    }
    // Spill metrics must be non-zero iff spilling was forced. With spilling
    // the final stage also spills, inflating output_rows to 8.
    let metrics = merged_aggregate.metrics().unwrap();
    let output_rows = metrics.output_rows().unwrap();
    let spill_count = metrics.spill_count().unwrap();
    let spilled_bytes = metrics.spilled_bytes().unwrap();
    let spilled_rows = metrics.spilled_rows().unwrap();
    if spill {
        assert_eq!(8, output_rows);
        assert!(spill_count > 0);
        assert!(spilled_bytes > 0);
        assert!(spilled_rows > 0);
    } else {
        assert_eq!(3, output_rows);
        assert_eq!(0, spill_count);
        assert_eq!(0, spilled_bytes);
        assert_eq!(0, spilled_rows);
    }
    Ok(())
}
/// Test source plan whose stream can optionally return `Pending` on its
/// first poll before yielding the `some_data` batches.
#[derive(Debug)]
struct TestYieldingExec {
    pub yield_first: bool,
    cache: Arc<PlanProperties>,
}

impl TestYieldingExec {
    fn new(yield_first: bool) -> Self {
        let (schema, _) = some_data();
        Self {
            yield_first,
            cache: Arc::new(Self::compute_properties(schema)),
        }
    }

    /// Single unknown partition, incremental emission, bounded input.
    fn compute_properties(schema: SchemaRef) -> PlanProperties {
        PlanProperties::new(
            EquivalenceProperties::new(schema),
            Partitioning::UnknownPartitioning(1),
            EmissionType::Incremental,
            Boundedness::Bounded,
        )
    }
}
impl DisplayAs for TestYieldingExec {
    /// Plan name for default/verbose output; empty in tree rendering.
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        let label = match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                "TestYieldingExec"
            }
            DisplayFormatType::TreeRender => "",
        };
        write!(f, "{label}")
    }
}
impl ExecutionPlan for TestYieldingExec {
    fn name(&self) -> &'static str {
        "TestYieldingExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &Arc<PlanProperties> {
        &self.cache
    }

    // Leaf node: no children.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![]
    }

    fn with_new_children(
        self: Arc<Self>,
        _: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        internal_err!("Children cannot be replaced in {self:?}")
    }

    // Fresh stream per call; `yield_first` selects whether the first poll
    // returns `Pending` (state `New`) or data immediately (state `Yielded`).
    fn execute(
        &self,
        _partition: usize,
        _context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        let stream = if self.yield_first {
            TestYieldingStream::New
        } else {
            TestYieldingStream::Yielded
        };
        Ok(Box::pin(stream))
    }

    // Exact statistics for the whole plan (partition = None), computed from
    // the static `some_data` batches; unknown for any specific partition.
    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
        if partition.is_some() {
            return Ok(Statistics::new_unknown(self.schema().as_ref()));
        }
        let (_, batches) = some_data();
        Ok(common::compute_record_batch_statistics(
            &[batches],
            &self.schema(),
            None,
        ))
    }
}
/// State machine backing `TestYieldingExec`: optionally yields `Pending`
/// once (`New`), then emits the two `some_data` batches, then terminates.
enum TestYieldingStream {
    New,
    Yielded,
    ReturnedBatch1,
    ReturnedBatch2,
}
impl Stream for TestYieldingStream {
    type Item = Result<RecordBatch>;

    // Advances one state per poll:
    // New -> (wake + Pending) -> Yielded -> batch 1 -> batch 2 -> end.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        match &*self {
            TestYieldingStream::New => {
                *(self.as_mut()) = TestYieldingStream::Yielded;
                // Wake immediately so the executor polls us again.
                cx.waker().wake_by_ref();
                Poll::Pending
            }
            TestYieldingStream::Yielded => {
                *(self.as_mut()) = TestYieldingStream::ReturnedBatch1;
                Poll::Ready(Some(Ok(some_data().1[0].clone())))
            }
            TestYieldingStream::ReturnedBatch1 => {
                *(self.as_mut()) = TestYieldingStream::ReturnedBatch2;
                Poll::Ready(Some(Ok(some_data().1[1].clone())))
            }
            TestYieldingStream::ReturnedBatch2 => Poll::Ready(None),
        }
    }
}
impl RecordBatchStream for TestYieldingStream {
    /// Schema of the batches produced: the shared `some_data` schema.
    fn schema(&self) -> SchemaRef {
        let (schema, _) = some_data();
        schema
    }
}
// The eight tests below run `check_aggregates` / `check_grouping_sets` over
// the cross product of: source yields Pending first or not, and spilling
// forced or not.

#[tokio::test]
async fn aggregate_source_not_yielding() -> Result<()> {
    let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
    check_aggregates(input, false).await
}

#[tokio::test]
async fn aggregate_grouping_sets_source_not_yielding() -> Result<()> {
    let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
    check_grouping_sets(input, false).await
}

#[tokio::test]
async fn aggregate_source_with_yielding() -> Result<()> {
    let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
    check_aggregates(input, false).await
}

#[tokio::test]
async fn aggregate_grouping_sets_with_yielding() -> Result<()> {
    let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
    check_grouping_sets(input, false).await
}

#[tokio::test]
async fn aggregate_source_not_yielding_with_spill() -> Result<()> {
    let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
    check_aggregates(input, true).await
}

#[tokio::test]
async fn aggregate_grouping_sets_source_not_yielding_with_spill() -> Result<()> {
    let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
    check_grouping_sets(input, true).await
}

#[tokio::test]
async fn aggregate_source_with_yielding_with_spill() -> Result<()> {
    let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
    check_aggregates(input, true).await
}

#[tokio::test]
async fn aggregate_grouping_sets_with_yielding_with_spill() -> Result<()> {
    let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
    check_grouping_sets(input, true).await
}
/// Build a `MEDIAN(a)` aggregate expression over `schema`.
fn test_median_agg_expr(schema: SchemaRef) -> Result<AggregateFunctionExpr> {
    let input = col("a", &schema)?;
    AggregateExprBuilder::new(median_udaf(), vec![input])
        .schema(schema)
        .alias("MEDIAN(a)")
        .build()
}
/// With a 1-byte memory limit, both no-grouping (AggregateStream) and
/// grouped (GroupedHash) execution must fail with `ResourcesExhausted`
/// rather than panic or succeed.
#[tokio::test]
async fn test_oom() -> Result<()> {
    let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
    let input_schema = input.schema();
    // Effectively zero memory budget: any accumulation must fail.
    let runtime = RuntimeEnvBuilder::new()
        .with_memory_limit(1, 1.0)
        .build_arc()?;
    let task_ctx = TaskContext::default().with_runtime(runtime);
    let task_ctx = Arc::new(task_ctx);
    let groups_none = PhysicalGroupBy::default();
    let groups_some = PhysicalGroupBy::new(
        vec![(col("a", &input_schema)?, "a".to_string())],
        vec![],
        vec![vec![false]],
        false,
    );
    // MEDIAN buffers all input, so it reliably exceeds the budget.
    let aggregates_v0: Vec<Arc<AggregateFunctionExpr>> =
        vec![Arc::new(test_median_agg_expr(Arc::clone(&input_schema))?)];
    let aggregates_v2: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
        AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?])
            .schema(Arc::clone(&input_schema))
            .alias("AVG(b)")
            .build()?,
    )];
    for (version, groups, aggregates) in [
        (0, groups_none, aggregates_v0),
        (2, groups_some, aggregates_v2),
    ] {
        let n_aggr = aggregates.len();
        let partial_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Single,
            groups,
            aggregates,
            vec![None; n_aggr],
            Arc::clone(&input),
            Arc::clone(&input_schema),
        )?);
        let stream = partial_aggregate.execute_typed(0, &task_ctx)?;
        // NOTE(review): only versions 0 and 2 are exercised by the loop
        // above; the `1 =>` arm is currently unreachable.
        match version {
            0 => {
                assert!(matches!(stream, StreamType::AggregateStream(_)));
            }
            1 => {
                assert!(matches!(stream, StreamType::GroupedHash(_)));
            }
            2 => {
                assert!(matches!(stream, StreamType::GroupedHash(_)));
            }
            _ => panic!("Unknown version: {version}"),
        }
        let stream: SendableRecordBatchStream = stream.into();
        let err = collect(stream).await.unwrap_err();
        // The memory error may be wrapped; inspect the root cause.
        let err = err.find_root();
        assert!(
            matches!(err, DataFusionError::ResourcesExhausted(_)),
            "Wrong error type: {err}",
        );
    }
    Ok(())
}
/// Dropping the collect future of a non-grouped aggregate over a blocked
/// input must release all references to the input (no leaked tasks).
#[tokio::test]
async fn test_drop_cancel_without_groups() -> Result<()> {
    let task_ctx = Arc::new(TaskContext::default());
    let schema =
        Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, true)]));
    let groups = PhysicalGroupBy::default();
    let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
        AggregateExprBuilder::new(avg_udaf(), vec![col("a", &schema)?])
            .schema(Arc::clone(&schema))
            .alias("AVG(a)")
            .build()?,
    )];
    // BlockingExec never produces data, so the future stays pending.
    let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
    let refs = blocking_exec.refs();
    let aggregate_exec = Arc::new(AggregateExec::try_new(
        AggregateMode::Partial,
        groups.clone(),
        aggregates.clone(),
        vec![None],
        blocking_exec,
        schema,
    )?);
    let fut = crate::collect(aggregate_exec, task_ctx);
    let mut fut = fut.boxed();
    assert_is_pending(&mut fut);
    drop(fut);
    // All Arc references to the blocked input must be released.
    assert_strong_count_converges_to_zero(refs).await;
    Ok(())
}
/// Same as `test_drop_cancel_without_groups`, but with a GROUP BY so the
/// grouped-hash stream path is exercised.
#[tokio::test]
async fn test_drop_cancel_with_groups() -> Result<()> {
    let task_ctx = Arc::new(TaskContext::default());
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Float64, true),
        Field::new("b", DataType::Float64, true),
    ]));
    let groups =
        PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);
    let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
        AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
            .schema(Arc::clone(&schema))
            .alias("AVG(b)")
            .build()?,
    )];
    // BlockingExec never produces data, so the future stays pending.
    let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
    let refs = blocking_exec.refs();
    let aggregate_exec = Arc::new(AggregateExec::try_new(
        AggregateMode::Partial,
        groups,
        aggregates.clone(),
        vec![None],
        blocking_exec,
        schema,
    )?);
    let fut = crate::collect(aggregate_exec, task_ctx);
    let mut fut = fut.boxed();
    assert_is_pending(&mut fut);
    drop(fut);
    // All Arc references to the blocked input must be released.
    assert_strong_count_converges_to_zero(refs).await;
    Ok(())
}
/// Run the first/last-value multi-partition check across every combination
/// of accumulator kind (first vs last) and spilling.
#[tokio::test]
async fn run_first_last_multi_partitions() -> Result<()> {
    for (is_first_acc, spill) in
        [(false, false), (false, true), (true, false), (true, true)]
    {
        first_last_multi_partitions(is_first_acc, spill, 4200).await?
    }
    Ok(())
}
/// Build a `first_value(b)` aggregate ordered by `b` with `sort_options`.
fn test_first_value_agg_expr(
    schema: &Schema,
    sort_options: SortOptions,
) -> Result<Arc<AggregateFunctionExpr>> {
    let sort_expr = PhysicalSortExpr {
        expr: col("b", schema)?,
        options: sort_options,
    };
    AggregateExprBuilder::new(first_value_udaf(), vec![col("b", schema)?])
        .order_by(vec![sort_expr])
        .schema(Arc::new(schema.clone()))
        .alias("first_value(b) ORDER BY [b ASC NULLS LAST]".to_string())
        .build()
        .map(Arc::new)
}
/// Build a `last_value(b)` aggregate ordered by `b` with `sort_options`.
fn test_last_value_agg_expr(
    schema: &Schema,
    sort_options: SortOptions,
) -> Result<Arc<AggregateFunctionExpr>> {
    let sort_expr = PhysicalSortExpr {
        expr: col("b", schema)?,
        options: sort_options,
    };
    AggregateExprBuilder::new(last_value_udaf(), vec![col("b", schema)?])
        .order_by(vec![sort_expr])
        .schema(Arc::new(schema.clone()))
        .alias("last_value(b) ORDER BY [b ASC NULLS LAST]".to_string())
        .build()
        .map(Arc::new)
}
/// Check first/last-value over four input partitions through a
/// Partial -> CoalescePartitions -> Final plan, optionally with spilling.
async fn first_last_multi_partitions(
    is_first_acc: bool,
    spill: bool,
    max_memory: usize,
) -> Result<()> {
    let task_ctx = if spill {
        new_spill_ctx(2, max_memory)
    } else {
        Arc::new(TaskContext::default())
    };
    // One `some_data_v2` batch per partition.
    let (schema, data) = some_data_v2();
    let partition1 = data[0].clone();
    let partition2 = data[1].clone();
    let partition3 = data[2].clone();
    let partition4 = data[3].clone();
    let groups =
        PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);
    let sort_options = SortOptions {
        descending: false,
        nulls_first: false,
    };
    let aggregates: Vec<Arc<AggregateFunctionExpr>> = if is_first_acc {
        vec![test_first_value_agg_expr(&schema, sort_options)?]
    } else {
        vec![test_last_value_agg_expr(&schema, sort_options)?]
    };
    let memory_exec = TestMemoryExec::try_new_exec(
        &[
            vec![partition1],
            vec![partition2],
            vec![partition3],
            vec![partition4],
        ],
        Arc::clone(&schema),
        None,
    )?;
    let aggregate_exec = Arc::new(AggregateExec::try_new(
        AggregateMode::Partial,
        groups.clone(),
        aggregates.clone(),
        vec![None],
        memory_exec,
        Arc::clone(&schema),
    )?);
    let coalesce = Arc::new(CoalescePartitionsExec::new(aggregate_exec))
        as Arc<dyn ExecutionPlan>;
    let aggregate_final = Arc::new(AggregateExec::try_new(
        AggregateMode::Final,
        groups,
        aggregates.clone(),
        vec![None],
        coalesce,
        schema,
    )?) as Arc<dyn ExecutionPlan>;
    let result = crate::collect(aggregate_final, task_ctx).await?;
    // Expected values are the min (first) / max (last) of `b` per group
    // across all four partitions under ASC NULLS LAST ordering.
    if is_first_acc {
        allow_duplicates! {
            assert_snapshot!(batches_to_string(&result), @r"
            +---+--------------------------------------------+
            | a | first_value(b) ORDER BY [b ASC NULLS LAST] |
            +---+--------------------------------------------+
            | 2 | 0.0                                        |
            | 3 | 1.0                                        |
            | 4 | 3.0                                        |
            +---+--------------------------------------------+
            ");
        }
    } else {
        allow_duplicates! {
            assert_snapshot!(batches_to_string(&result), @r"
            +---+-------------------------------------------+
            | a | last_value(b) ORDER BY [b ASC NULLS LAST] |
            +---+-------------------------------------------+
            | 2 | 3.0                                       |
            | 3 | 5.0                                       |
            | 4 | 6.0                                       |
            +---+-------------------------------------------+
            ");
        }
    };
    Ok(())
}
/// `get_finer_aggregate_exprs_requirement` should combine the ORDER BY
/// requirements of several order-sensitive aggregates into the finest
/// common requirement, taking equivalences (a == b) into account.
#[tokio::test]
async fn test_get_finest_requirements() -> Result<()> {
    let test_schema = create_test_schema()?;
    let options = SortOptions {
        descending: false,
        nulls_first: false,
    };
    let col_a = &col("a", &test_schema)?;
    let col_b = &col("b", &test_schema)?;
    let col_c = &col("c", &test_schema)?;
    let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema));
    // a == b, so orderings over b are interchangeable with orderings over a.
    eq_properties.add_equal_conditions(Arc::clone(col_a), Arc::clone(col_b))?;
    // Candidate ORDER BYs: [], [a], [a, b, c], [a, b].
    let order_by_exprs = vec![
        vec![],
        vec![PhysicalSortExpr {
            expr: Arc::clone(col_a),
            options,
        }],
        vec![
            PhysicalSortExpr {
                expr: Arc::clone(col_a),
                options,
            },
            PhysicalSortExpr {
                expr: Arc::clone(col_b),
                options,
            },
            PhysicalSortExpr {
                expr: Arc::clone(col_c),
                options,
            },
        ],
        vec![
            PhysicalSortExpr {
                expr: Arc::clone(col_a),
                options,
            },
            PhysicalSortExpr {
                expr: Arc::clone(col_b),
                options,
            },
        ],
    ];
    // With a == b, [a, b, c] collapses to [a, c]: the finest requirement.
    let common_requirement = vec![
        PhysicalSortRequirement::new(Arc::clone(col_a), Some(options)),
        PhysicalSortRequirement::new(Arc::clone(col_c), Some(options)),
    ];
    let mut aggr_exprs = order_by_exprs
        .into_iter()
        .map(|order_by_expr| {
            AggregateExprBuilder::new(array_agg_udaf(), vec![Arc::clone(col_a)])
                .alias("a")
                .order_by(order_by_expr)
                .schema(Arc::clone(&test_schema))
                .build()
                .map(Arc::new)
                .unwrap()
        })
        .collect::<Vec<_>>();
    let group_by = PhysicalGroupBy::new_single(vec![]);
    let result = get_finer_aggregate_exprs_requirement(
        &mut aggr_exprs,
        &group_by,
        &eq_properties,
        &AggregateMode::Partial,
    )?;
    assert_eq!(result, common_requirement);
    Ok(())
}
/// `with_new_children` must produce an AggregateExec with an identical
/// output schema.
#[test]
fn test_agg_exec_same_schema() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Float32, true),
        Field::new("b", DataType::Float32, true),
    ]));
    let col_a = col("a", &schema)?;
    let option_desc = SortOptions {
        descending: true,
        nulls_first: true,
    };
    let groups = PhysicalGroupBy::new_single(vec![(col_a, "a".to_string())]);
    let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
        test_first_value_agg_expr(&schema, option_desc)?,
        test_last_value_agg_expr(&schema, option_desc)?,
    ];
    let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
    let aggregate_exec = Arc::new(AggregateExec::try_new(
        AggregateMode::Partial,
        groups,
        aggregates,
        vec![None, None],
        Arc::clone(&blocking_exec) as Arc<dyn ExecutionPlan>,
        schema,
    )?);
    // Rebuild with the same child; the schema must be unchanged.
    let new_agg =
        Arc::clone(&aggregate_exec).with_new_children(vec![blocking_exec])?;
    assert_eq!(new_agg.schema(), aggregate_exec.schema());
    Ok(())
}
/// Grouping sets that include a constant group expression must still
/// aggregate correctly (each of the three single-column sets sees all
/// 4 * 8192 = 32768 input rows).
#[tokio::test]
async fn test_agg_exec_group_by_const() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Float32, true),
        Field::new("b", DataType::Float32, true),
        Field::new("const", DataType::Int32, false),
    ]));
    let col_a = col("a", &schema)?;
    let col_b = col("b", &schema)?;
    // Literal 1 used as a group-by key.
    let const_expr = Arc::new(Literal::new(ScalarValue::Int32(Some(1))));
    let groups = PhysicalGroupBy::new(
        vec![
            (col_a, "a".to_string()),
            (col_b, "b".to_string()),
            (const_expr, "const".to_string()),
        ],
        vec![
            (
                Arc::new(Literal::new(ScalarValue::Float32(None))),
                "a".to_string(),
            ),
            (
                Arc::new(Literal::new(ScalarValue::Float32(None))),
                "b".to_string(),
            ),
            (
                Arc::new(Literal::new(ScalarValue::Int32(None))),
                "const".to_string(),
            ),
        ],
        // Three grouping sets, each keeping exactly one column.
        vec![
            vec![false, true, true],
            vec![true, false, true],
            vec![true, true, false],
        ],
        true,
    );
    let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
        AggregateExprBuilder::new(count_udaf(), vec![lit(1)])
            .schema(Arc::clone(&schema))
            .alias("1")
            .build()
            .map(Arc::new)?,
    ];
    // Four identical 8192-row batches with constant column values.
    let input_batches = (0..4)
        .map(|_| {
            let a = Arc::new(Float32Array::from(vec![0.; 8192]));
            let b = Arc::new(Float32Array::from(vec![0.; 8192]));
            let c = Arc::new(Int32Array::from(vec![1; 8192]));
            RecordBatch::try_new(Arc::clone(&schema), vec![a, b, c]).unwrap()
        })
        .collect();
    let input =
        TestMemoryExec::try_new_exec(&[input_batches], Arc::clone(&schema), None)?;
    let aggregate_exec = Arc::new(AggregateExec::try_new(
        AggregateMode::Single,
        groups,
        aggregates.clone(),
        vec![None],
        input,
        schema,
    )?);
    let output =
        collect(aggregate_exec.execute(0, Arc::new(TaskContext::default()))?).await?;
    allow_duplicates! {
        assert_snapshot!(batches_to_sort_string(&output), @r"
        +-----+-----+-------+---------------+-------+
        | a   | b   | const | __grouping_id | 1     |
        +-----+-----+-------+---------------+-------+
        |     |     | 1     | 6             | 32768 |
        |     | 0.0 |       | 5             | 32768 |
        | 0.0 |     |       | 3             | 32768 |
        +-----+-----+-------+---------------+-------+
        ");
    }
    Ok(())
}
/// Grouping on a struct column whose fields are dictionary-encoded strings
/// must hash/compare correctly (regression coverage for struct-of-dict
/// group keys).
#[tokio::test]
async fn test_agg_exec_struct_of_dicts() -> Result<()> {
    // Three rows; rows 0 and 2 share the same {a: "a", b: "b"} key, row 1
    // has a NULL `a` and b = "c".
    let batch = RecordBatch::try_new(
        Arc::new(Schema::new(vec![
            Field::new(
                "labels".to_string(),
                DataType::Struct(
                    vec![
                        Field::new(
                            "a".to_string(),
                            DataType::Dictionary(
                                Box::new(DataType::Int32),
                                Box::new(DataType::Utf8),
                            ),
                            true,
                        ),
                        Field::new(
                            "b".to_string(),
                            DataType::Dictionary(
                                Box::new(DataType::Int32),
                                Box::new(DataType::Utf8),
                            ),
                            true,
                        ),
                    ]
                    .into(),
                ),
                false,
            ),
            Field::new("value", DataType::UInt64, false),
        ])),
        vec![
            Arc::new(StructArray::from(vec![
                (
                    Arc::new(Field::new(
                        "a".to_string(),
                        DataType::Dictionary(
                            Box::new(DataType::Int32),
                            Box::new(DataType::Utf8),
                        ),
                        true,
                    )),
                    Arc::new(
                        vec![Some("a"), None, Some("a")]
                            .into_iter()
                            .collect::<DictionaryArray<Int32Type>>(),
                    ) as ArrayRef,
                ),
                (
                    Arc::new(Field::new(
                        "b".to_string(),
                        DataType::Dictionary(
                            Box::new(DataType::Int32),
                            Box::new(DataType::Utf8),
                        ),
                        true,
                    )),
                    Arc::new(
                        vec![Some("b"), Some("c"), Some("b")]
                            .into_iter()
                            .collect::<DictionaryArray<Int32Type>>(),
                    ) as ArrayRef,
                ),
            ])),
            Arc::new(UInt64Array::from(vec![1, 1, 1])),
        ],
    )
    .expect("Failed to create RecordBatch");
    let group_by = PhysicalGroupBy::new_single(vec![(
        col("labels", &batch.schema())?,
        "labels".to_string(),
    )]);
    let aggr_expr = vec![
        AggregateExprBuilder::new(sum_udaf(), vec![col("value", &batch.schema())?])
            .schema(Arc::clone(&batch.schema()))
            .alias(String::from("SUM(value)"))
            .build()
            .map(Arc::new)?,
    ];
    let input = TestMemoryExec::try_new_exec(
        &[vec![batch.clone()]],
        Arc::<Schema>::clone(&batch.schema()),
        None,
    )?;
    let aggregate_exec = Arc::new(AggregateExec::try_new(
        AggregateMode::FinalPartitioned,
        group_by,
        aggr_expr,
        vec![None],
        Arc::clone(&input) as Arc<dyn ExecutionPlan>,
        batch.schema(),
    )?);
    let session_config = SessionConfig::default();
    let ctx = TaskContext::default().with_session_config(session_config);
    let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;
    // Rows 0 and 2 collapse into one group (sum 2); row 1 is its own group.
    allow_duplicates! {
        assert_snapshot!(batches_to_string(&output), @r"
        +--------------+------------+
        | labels       | SUM(value) |
        +--------------+------------+
        | {a: a, b: b} | 2          |
        | {a: , b: c}  | 1          |
        +--------------+------------+
        ");
    }
    Ok(())
}
/// With a probe threshold of 2 rows and a low ratio, partial aggregation
/// should switch to pass-through after the first batch: batch 1 is
/// aggregated, batch 2's rows are emitted un-merged.
#[tokio::test]
async fn test_skip_aggregation_after_first_batch() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("key", DataType::Int32, true),
        Field::new("val", DataType::Int32, true),
    ]));
    let group_by =
        PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]);
    let aggr_expr = vec![
        AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?])
            .schema(Arc::clone(&schema))
            .alias(String::from("COUNT(val)"))
            .build()
            .map(Arc::new)?,
    ];
    let input_data = vec![
        RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(Int32Array::from(vec![0, 0, 0])),
            ],
        )
        .unwrap(),
        RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Int32Array::from(vec![2, 3, 4])),
                Arc::new(Int32Array::from(vec![0, 0, 0])),
            ],
        )
        .unwrap(),
    ];
    let input =
        TestMemoryExec::try_new_exec(&[input_data], Arc::clone(&schema), None)?;
    let aggregate_exec = Arc::new(AggregateExec::try_new(
        AggregateMode::Partial,
        group_by,
        aggr_expr,
        vec![None],
        Arc::clone(&input) as Arc<dyn ExecutionPlan>,
        schema,
    )?);
    // Probe after 2 rows; skip aggregation unless >10% of rows merged.
    let mut session_config = SessionConfig::default();
    session_config = session_config.set(
        "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
        &ScalarValue::Int64(Some(2)),
    );
    session_config = session_config.set(
        "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
        &ScalarValue::Float64(Some(0.1)),
    );
    let ctx = TaskContext::default().with_session_config(session_config);
    let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;
    // Six output rows: duplicate keys 2 and 3 are NOT merged across batches.
    allow_duplicates! {
        assert_snapshot!(batches_to_string(&output), @r"
        +-----+-------------------+
        | key | COUNT(val)[count] |
        +-----+-------------------+
        | 1   | 1                 |
        | 2   | 1                 |
        | 3   | 1                 |
        | 2   | 1                 |
        | 3   | 1                 |
        | 4   | 1                 |
        +-----+-------------------+
        ");
    }
    Ok(())
}
/// With a probe threshold of 5 rows, the first two batches are aggregated
/// (6 rows > threshold triggers the probe), then the third batch is passed
/// through un-merged.
#[tokio::test]
async fn test_skip_aggregation_after_threshold() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("key", DataType::Int32, true),
        Field::new("val", DataType::Int32, true),
    ]));
    let group_by =
        PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]);
    let aggr_expr = vec![
        AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?])
            .schema(Arc::clone(&schema))
            .alias(String::from("COUNT(val)"))
            .build()
            .map(Arc::new)?,
    ];
    let input_data = vec![
        RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(Int32Array::from(vec![0, 0, 0])),
            ],
        )
        .unwrap(),
        RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Int32Array::from(vec![2, 3, 4])),
                Arc::new(Int32Array::from(vec![0, 0, 0])),
            ],
        )
        .unwrap(),
        RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Int32Array::from(vec![2, 3, 4])),
                Arc::new(Int32Array::from(vec![0, 0, 0])),
            ],
        )
        .unwrap(),
    ];
    let input =
        TestMemoryExec::try_new_exec(&[input_data], Arc::clone(&schema), None)?;
    let aggregate_exec = Arc::new(AggregateExec::try_new(
        AggregateMode::Partial,
        group_by,
        aggr_expr,
        vec![None],
        Arc::clone(&input) as Arc<dyn ExecutionPlan>,
        schema,
    )?);
    // Probe after 5 rows; skip aggregation unless >10% of rows merged.
    let mut session_config = SessionConfig::default();
    session_config = session_config.set(
        "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
        &ScalarValue::Int64(Some(5)),
    );
    session_config = session_config.set(
        "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
        &ScalarValue::Float64(Some(0.1)),
    );
    let ctx = TaskContext::default().with_session_config(session_config);
    let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;
    // First four rows are the merged result of batches 1-2; the final three
    // rows are batch 3 passed through un-merged.
    allow_duplicates! {
        assert_snapshot!(batches_to_string(&output), @r"
        +-----+-------------------+
        | key | COUNT(val)[count] |
        +-----+-------------------+
        | 1   | 1                 |
        | 2   | 2                 |
        | 3   | 2                 |
        | 4   | 1                 |
        | 2   | 1                 |
        | 3   | 1                 |
        | 4   | 1                 |
        +-----+-------------------+
        ");
    }
    Ok(())
}
/// A grouping-set aggregation must mark a group column nullable in the
/// output schema when some grouping set omits it, even though the input
/// column itself is non-nullable. Here `b` is replaced by NULL in the
/// first grouping set (so it becomes nullable), while `a` appears in
/// every set and stays non-nullable.
#[test]
fn group_exprs_nullable() -> Result<()> {
    let input_schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Float32, false),
        Field::new("b", DataType::Float32, false),
    ]));
    let count_a = AggregateExprBuilder::new(count_udaf(), vec![col("a", &input_schema)?])
        .schema(Arc::clone(&input_schema))
        .alias("COUNT(a)")
        .build()
        .map(Arc::new)?;
    let aggr_expr = vec![count_a];
    // Two grouping sets over (a, b): {a} (b nulled out) and {a, b}.
    let grouping_set = PhysicalGroupBy::new(
        vec![
            (col("a", &input_schema)?, "a".to_string()),
            (col("b", &input_schema)?, "b".to_string()),
        ],
        vec![
            (lit(ScalarValue::Float32(None)), "a".to_string()),
            (lit(ScalarValue::Float32(None)), "b".to_string()),
        ],
        vec![vec![false, true], vec![false, false]],
        true,
    );
    let aggr_schema = create_schema(
        &input_schema,
        &grouping_set,
        &aggr_expr,
        AggregateMode::Final,
    )?;
    // `__grouping_id` is appended because multiple grouping sets are present.
    let expected_schema = Schema::new(vec![
        Field::new("a", DataType::Float32, false),
        Field::new("b", DataType::Float32, true),
        Field::new("__grouping_id", DataType::UInt8, false),
        Field::new("COUNT(a)", DataType::Int64, false),
    ]);
    assert_eq!(aggr_schema, expected_schema);
    Ok(())
}
/// Runs a single-mode MIN/AVG aggregation grouped by `a` under a
/// `FairSpillPool` of `pool_size` bytes, asserts via the operator metrics
/// whether spilling occurred as `expect_spill` demands, and checks the
/// aggregated output is correct either way.
async fn run_test_with_spill_pool_if_necessary(
    pool_size: usize,
    expect_spill: bool,
) -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::UInt32, false),
        Field::new("b", DataType::Float64, false),
    ]));
    // Two identical 4-row batches so every group key appears twice.
    let make_batch = || {
        RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(UInt32Array::from(vec![2, 3, 4, 4])),
                Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
            ],
        )
    };
    let batches = vec![make_batch()?, make_batch()?];
    let plan: Arc<dyn ExecutionPlan> =
        TestMemoryExec::try_new_exec(&[batches], Arc::clone(&schema), None)?;
    let grouping_set = PhysicalGroupBy::new(
        vec![(col("a", &schema)?, "a".to_string())],
        vec![],
        vec![vec![false]],
        false,
    );
    let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
        Arc::new(
            AggregateExprBuilder::new(
                datafusion_functions_aggregate::min_max::min_udaf(),
                vec![col("b", &schema)?],
            )
            .schema(Arc::clone(&schema))
            .alias("MIN(b)")
            .build()?,
        ),
        Arc::new(
            AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
                .schema(Arc::clone(&schema))
                .alias("AVG(b)")
                .build()?,
        ),
    ];
    let single_aggregate = Arc::new(AggregateExec::try_new(
        AggregateMode::Single,
        grouping_set,
        aggregates,
        vec![None, None],
        plan,
        Arc::clone(&schema),
    )?);
    // A tiny batch size plus the bounded pool makes spilling depend solely
    // on `pool_size`.
    let runtime = RuntimeEnvBuilder::new()
        .with_memory_pool(Arc::new(FairSpillPool::new(pool_size)))
        .build()?;
    let task_ctx = Arc::new(
        TaskContext::default()
            .with_session_config(SessionConfig::new().with_batch_size(2))
            .with_runtime(Arc::new(runtime)),
    );
    let result = collect(single_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
    assert_spill_count_metric(expect_spill, single_aggregate);
    allow_duplicates! {
        assert_snapshot!(batches_to_string(&result), @r"
+---+--------+--------+
| a | MIN(b) | AVG(b) |
+---+--------+--------+
| 2 | 1.0    | 1.0    |
| 3 | 2.0    | 2.0    |
| 4 | 3.0    | 3.5    |
+---+--------+--------+
");
    }
    Ok(())
}
/// Reads the operator's `SpillCount` metric (treating an absent metric as
/// zero) and panics if the observed value contradicts `expect_spill`.
fn assert_spill_count_metric(
    expect_spill: bool,
    single_aggregate: Arc<AggregateExec>,
) {
    let Some(metrics_set) = single_aggregate.metrics() else {
        panic!("No metrics returned from the operator; cannot verify spilling.");
    };
    // Take the first SpillCount metric found, defaulting to 0 when none exists.
    let spill_count = metrics_set
        .iter()
        .find_map(|metric| match metric.value() {
            MetricValue::SpillCount(count) => Some(count.value()),
            _ => None,
        })
        .unwrap_or(0);
    if expect_spill && spill_count == 0 {
        panic!(
            "Expected spill but SpillCount metric not found or SpillCount was 0."
        );
    } else if !expect_spill && spill_count > 0 {
        panic!(
            "Expected no spill but found SpillCount metric with value greater than 0."
        );
    }
}
/// With a 2 KB pool the aggregation must spill; with a 20 KB pool it must
/// complete entirely in memory.
#[tokio::test]
async fn test_aggregate_with_spill_if_necessary() -> Result<()> {
    for (pool_size, expect_spill) in [(2_000, true), (20_000, false)] {
        run_test_with_spill_pool_if_necessary(pool_size, expect_spill).await?;
    }
    Ok(())
}
/// Under a bounded (2000-byte) memory pool, a grouped aggregation must
/// either complete by spilling to disk or fail with `ResourcesExhausted` —
/// both outcomes are accepted here; silently exceeding the limit is not.
#[tokio::test]
async fn test_grouped_aggregation_respects_memory_limit() -> Result<()> {
// Helper: builds a (UInt32, Float64) batch from parallel value vectors.
fn create_record_batch(
schema: &Arc<Schema>,
data: (Vec<u32>, Vec<f64>),
) -> Result<RecordBatch> {
Ok(RecordBatch::try_new(
Arc::clone(schema),
vec![
Arc::new(UInt32Array::from(data.0)),
Arc::new(Float64Array::from(data.1)),
],
)?)
}
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::UInt32, false),
Field::new("b", DataType::Float64, false),
]));
// Two identical batches so each group key occurs twice.
let batches = vec![
create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
];
let plan: Arc<dyn ExecutionPlan> =
TestMemoryExec::try_new_exec(&[batches], Arc::clone(&schema), None)?;
// Prepend a constant string column "l" via a projection, presumably to
// enlarge the per-group key state — TODO confirm intent.
let proj = ProjectionExec::try_new(
vec![
ProjectionExpr::new(lit("0"), "l".to_string()),
ProjectionExpr::new_from_expression(col("a", &schema)?, &schema)?,
ProjectionExpr::new_from_expression(col("b", &schema)?, &schema)?,
],
plan,
)?;
let plan: Arc<dyn ExecutionPlan> = Arc::new(proj);
// Re-resolve the schema: it now includes the projected "l" column.
let schema = plan.schema();
let grouping_set = PhysicalGroupBy::new(
vec![
(col("l", &schema)?, "l".to_string()),
(col("a", &schema)?, "a".to_string()),
],
vec![],
vec![vec![false, false]],
false,
);
let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
Arc::new(
AggregateExprBuilder::new(
datafusion_functions_aggregate::min_max::min_udaf(),
vec![col("b", &schema)?],
)
.schema(Arc::clone(&schema))
.alias("MIN(b)")
.build()?,
),
Arc::new(
AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
.schema(Arc::clone(&schema))
.alias("AVG(b)")
.build()?,
),
];
let single_aggregate = Arc::new(AggregateExec::try_new(
AggregateMode::Single,
grouping_set,
aggregates,
vec![None, None],
plan,
Arc::clone(&schema),
)?);
// Small batch size plus a 2000-byte FairSpillPool pressures the operator
// into spilling (or erroring) rather than buffering everything.
let batch_size = 2;
let memory_pool = Arc::new(FairSpillPool::new(2000));
let task_ctx = Arc::new(
TaskContext::default()
.with_session_config(SessionConfig::new().with_batch_size(batch_size))
.with_runtime(Arc::new(
RuntimeEnvBuilder::new()
.with_memory_pool(memory_pool)
.build()?,
)),
);
let result = collect(single_aggregate.execute(0, Arc::clone(&task_ctx))?).await;
match result {
// Success is only acceptable if the operator actually spilled, and the
// aggregated values must still be correct.
Ok(result) => {
assert_spill_count_metric(true, single_aggregate);
allow_duplicates! {
assert_snapshot!(batches_to_string(&result), @r"
+---+---+--------+--------+
| l | a | MIN(b) | AVG(b) |
+---+---+--------+--------+
| 0 | 2 | 1.0 | 1.0 |
| 0 | 3 | 2.0 | 2.0 |
| 0 | 4 | 3.0 | 3.5 |
+---+---+--------+--------+
");
}
}
// Otherwise the failure must be a memory-limit error, not anything else.
Err(e) => assert!(matches!(e, DataFusionError::ResourcesExhausted(_))),
}
Ok(())
}
#[tokio::test]
async fn test_aggregate_statistics_edge_cases() -> Result<()> {
use crate::test::exec::StatisticsExec;
use datafusion_common::ColumnStatistics;
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Float64, false),
]));
let input = Arc::new(StatisticsExec::new(
Statistics {
num_rows: Precision::Exact(100),
total_byte_size: Precision::Absent,
column_statistics: vec![
ColumnStatistics::new_unknown(),
ColumnStatistics::new_unknown(),
],
},
(*schema).clone(),
)) as Arc<dyn ExecutionPlan>;
let agg = Arc::new(AggregateExec::try_new(
AggregateMode::Final,
PhysicalGroupBy::default(),
vec![Arc::new(
AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema)?])
.schema(Arc::clone(&schema))
.alias("COUNT(a)")
.build()?,
)],
vec![None],
input,
Arc::clone(&schema),
)?);
let stats = agg.partition_statistics(None)?;
assert_eq!(stats.total_byte_size, Precision::Absent);
let input_zero = Arc::new(StatisticsExec::new(
Statistics {
num_rows: Precision::Exact(0),
total_byte_size: Precision::Exact(0),
column_statistics: vec![
ColumnStatistics::new_unknown(),
ColumnStatistics::new_unknown(),
],
},
(*schema).clone(),
)) as Arc<dyn ExecutionPlan>;
let agg_zero = Arc::new(AggregateExec::try_new(
AggregateMode::Final,
PhysicalGroupBy::default(),
vec![Arc::new(
AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema)?])
.schema(Arc::clone(&schema))
.alias("COUNT(a)")
.build()?,
)],
vec![None],
input_zero,
Arc::clone(&schema),
)?);
let stats_zero = agg_zero.partition_statistics(None)?;
assert_eq!(stats_zero.total_byte_size, Precision::Absent);
Ok(())
}
/// The input is declared sorted descending on `b` (batches arrive as
/// b = 2, 1, 0). With a 600-byte spill pool forcing the aggregation to
/// spill, the output must still come back in that descending order on `b`.
#[tokio::test]
async fn test_order_is_retained_when_spilling() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int64, false),
Field::new("b", DataType::Int64, false),
Field::new("c", DataType::Int64, false),
]));
// Three single-row batches with b descending: 2, 1, 0.
let batches = vec![vec![
RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(Int64Array::from(vec![2])),
Arc::new(Int64Array::from(vec![2])),
Arc::new(Int64Array::from(vec![1])),
],
)?,
RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(Int64Array::from(vec![1])),
Arc::new(Int64Array::from(vec![1])),
Arc::new(Int64Array::from(vec![1])),
],
)?,
RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(Int64Array::from(vec![0])),
Arc::new(Int64Array::from(vec![0])),
Arc::new(Int64Array::from(vec![1])),
],
)?,
]];
let scan = TestMemoryExec::try_new(&batches, Arc::clone(&schema), None)?;
// Attach the sort metadata (b DESC) so the aggregate can rely on input order.
let scan = scan.try_with_sort_information(vec![
LexOrdering::new([PhysicalSortExpr::new(
col("b", schema.as_ref())?,
SortOptions::default().desc(),
)])
.unwrap(),
])?;
// SUM(c) grouped by (b, c) in Single mode.
let aggr = Arc::new(AggregateExec::try_new(
AggregateMode::Single,
PhysicalGroupBy::new(
vec![
(col("b", schema.as_ref())?, "b".to_string()),
(col("c", schema.as_ref())?, "c".to_string()),
],
vec![],
vec![vec![false, false]],
false,
),
vec![Arc::new(
AggregateExprBuilder::new(sum_udaf(), vec![col("c", schema.as_ref())?])
.schema(Arc::clone(&schema))
.alias("SUM(c)")
.build()?,
)],
vec![None],
Arc::new(scan) as Arc<dyn ExecutionPlan>,
Arc::clone(&schema),
)?);
// 600-byte pool: small enough to force a spill (asserted below).
let task_ctx = new_spill_ctx(1, 600);
let result = collect(aggr.execute(0, Arc::clone(&task_ctx))?).await?;
assert_spill_count_metric(true, aggr);
// Output retains the descending order on b even though spilling occurred.
allow_duplicates! {
assert_snapshot!(batches_to_string(&result), @r"
+---+---+--------+
| b | c | SUM(c) |
+---+---+--------+
| 2 | 1 | 1 |
| 1 | 1 | 1 |
| 0 | 1 | 1 |
+---+---+--------+
");
}
Ok(())
}
#[tokio::test]
async fn test_sort_reservation_fails_during_spill() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("g", DataType::Int64, false),
Field::new("a", DataType::Float64, false),
Field::new("b", DataType::Float64, false),
Field::new("c", DataType::Float64, false),
Field::new("d", DataType::Float64, false),
Field::new("e", DataType::Float64, false),
]));
let batches = vec![vec![
RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(Int64Array::from(vec![1])),
Arc::new(Float64Array::from(vec![10.0])),
Arc::new(Float64Array::from(vec![20.0])),
Arc::new(Float64Array::from(vec![30.0])),
Arc::new(Float64Array::from(vec![40.0])),
Arc::new(Float64Array::from(vec![50.0])),
],
)?,
RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(Int64Array::from(vec![2])),
Arc::new(Float64Array::from(vec![11.0])),
Arc::new(Float64Array::from(vec![21.0])),
Arc::new(Float64Array::from(vec![31.0])),
Arc::new(Float64Array::from(vec![41.0])),
Arc::new(Float64Array::from(vec![51.0])),
],
)?,
RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(Int64Array::from(vec![3])),
Arc::new(Float64Array::from(vec![12.0])),
Arc::new(Float64Array::from(vec![22.0])),
Arc::new(Float64Array::from(vec![32.0])),
Arc::new(Float64Array::from(vec![42.0])),
Arc::new(Float64Array::from(vec![52.0])),
],
)?,
]];
let scan = TestMemoryExec::try_new(&batches, Arc::clone(&schema), None)?;
let aggr = Arc::new(AggregateExec::try_new(
AggregateMode::Single,
PhysicalGroupBy::new(
vec![(col("g", schema.as_ref())?, "g".to_string())],
vec![],
vec![vec![false]],
false,
),
vec![
Arc::new(
AggregateExprBuilder::new(
avg_udaf(),
vec![col("a", schema.as_ref())?],
)
.schema(Arc::clone(&schema))
.alias("AVG(a)")
.build()?,
),
Arc::new(
AggregateExprBuilder::new(
avg_udaf(),
vec![col("b", schema.as_ref())?],
)
.schema(Arc::clone(&schema))
.alias("AVG(b)")
.build()?,
),
Arc::new(
AggregateExprBuilder::new(
avg_udaf(),
vec![col("c", schema.as_ref())?],
)
.schema(Arc::clone(&schema))
.alias("AVG(c)")
.build()?,
),
Arc::new(
AggregateExprBuilder::new(
avg_udaf(),
vec![col("d", schema.as_ref())?],
)
.schema(Arc::clone(&schema))
.alias("AVG(d)")
.build()?,
),
Arc::new(
AggregateExprBuilder::new(
avg_udaf(),
vec![col("e", schema.as_ref())?],
)
.schema(Arc::clone(&schema))
.alias("AVG(e)")
.build()?,
),
],
vec![None, None, None, None, None],
Arc::new(scan) as Arc<dyn ExecutionPlan>,
Arc::clone(&schema),
)?);
let task_ctx = new_spill_ctx(1, 500);
let result = collect(aggr.execute(0, Arc::clone(&task_ctx))?).await;
match &result {
Ok(_) => panic!("Expected ResourcesExhausted error but query succeeded"),
Err(e) => {
let root = e.find_root();
assert!(
matches!(root, DataFusionError::ResourcesExhausted(_)),
"Expected ResourcesExhausted, got: {root}",
);
let msg = root.to_string();
assert!(
msg.contains("Failed to reserve memory for sort during spill"),
"Expected sort reservation error, got: {msg}",
);
}
}
Ok(())
}
/// Exercises the three-stage pipeline Partial -> PartialReduce -> Final:
/// two independent Partial aggregates produce partial states, a
/// PartialReduce stage merges them while keeping the partial-state schema
/// unchanged (asserted below), and a Final stage produces the grand totals
/// SUM(b) per key (10+40, 20+50, 30+60).
#[tokio::test]
async fn test_partial_reduce_mode() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::UInt32, false),
Field::new("b", DataType::Float64, false),
]));
// Two batches with the same keys (1, 2, 3) but different values, each fed
// to its own Partial aggregate to simulate separate partitions.
let batch1 = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(UInt32Array::from(vec![1, 2, 3])),
Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0])),
],
)?;
let batch2 = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(UInt32Array::from(vec![1, 2, 3])),
Arc::new(Float64Array::from(vec![40.0, 50.0, 60.0])),
],
)?;
let groups =
PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);
let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
AggregateExprBuilder::new(sum_udaf(), vec![col("b", &schema)?])
.schema(Arc::clone(&schema))
.alias("SUM(b)")
.build()?,
)];
// Stage 1: two Partial aggregates, one per input batch.
let input1 =
TestMemoryExec::try_new_exec(&[vec![batch1]], Arc::clone(&schema), None)?;
let partial1 = Arc::new(AggregateExec::try_new(
AggregateMode::Partial,
groups.clone(),
aggregates.clone(),
vec![None],
input1,
Arc::clone(&schema),
)?);
let input2 =
TestMemoryExec::try_new_exec(&[vec![batch2]], Arc::clone(&schema), None)?;
let partial2 = Arc::new(AggregateExec::try_new(
AggregateMode::Partial,
groups.clone(),
aggregates.clone(),
vec![None],
input2,
Arc::clone(&schema),
)?);
let task_ctx = Arc::new(TaskContext::default());
let partial_result1 =
crate::collect(Arc::clone(&partial1) as _, Arc::clone(&task_ctx)).await?;
let partial_result2 =
crate::collect(Arc::clone(&partial2) as _, Arc::clone(&task_ctx)).await?;
// Stage 2: PartialReduce over the coalesced partial-state batches.
let partial_schema = partial1.schema();
let combined_input = TestMemoryExec::try_new_exec(
&[partial_result1, partial_result2],
Arc::clone(&partial_schema),
None,
)?;
let coalesced = Arc::new(CoalescePartitionsExec::new(combined_input));
let partial_reduce = Arc::new(AggregateExec::try_new(
AggregateMode::PartialReduce,
groups.clone(),
aggregates.clone(),
vec![None],
coalesced,
Arc::clone(&partial_schema),
)?);
// PartialReduce must emit the same partial-state schema it consumes.
assert_eq!(partial_reduce.schema(), partial_schema);
let reduce_result =
crate::collect(Arc::clone(&partial_reduce) as _, Arc::clone(&task_ctx))
.await?;
// Stage 3: Final aggregate over the reduced partial states.
let final_input = TestMemoryExec::try_new_exec(
&[reduce_result],
Arc::clone(&partial_schema),
None,
)?;
let final_agg = Arc::new(AggregateExec::try_new(
AggregateMode::Final,
groups.clone(),
aggregates.clone(),
vec![None],
final_input,
Arc::clone(&partial_schema),
)?);
let result = crate::collect(final_agg, Arc::clone(&task_ctx)).await?;
// Grand totals: 10+40=50, 20+50=70, 30+60=90.
assert_snapshot!(batches_to_sort_string(&result), @r"
+---+--------+
| a | SUM(b) |
+---+--------+
| 1 | 50.0 |
| 2 | 70.0 |
| 3 | 90.0 |
+---+--------+
");
Ok(())
}
}