// datafusion_physical_plan/aggregates/mod.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Aggregates functionalities
19
20use std::any::Any;
21use std::sync::Arc;
22
23use super::{DisplayAs, ExecutionPlanProperties, PlanProperties};
24use crate::aggregates::{
25    no_grouping::AggregateStream, row_hash::GroupedHashAggregateStream,
26    topk_stream::GroupedTopKAggregateStream,
27};
28use crate::execution_plan::{CardinalityEffect, EmissionType};
29use crate::filter_pushdown::{
30    ChildFilterDescription, ChildPushdownResult, FilterDescription, FilterPushdownPhase,
31    FilterPushdownPropagation, PushedDownPredicate,
32};
33use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
34use crate::{
35    DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode,
36    SendableRecordBatchStream, Statistics, check_if_same_properties,
37};
38use datafusion_common::config::ConfigOptions;
39use datafusion_physical_expr::utils::collect_columns;
40use parking_lot::Mutex;
41use std::collections::HashSet;
42
43use arrow::array::{ArrayRef, UInt8Array, UInt16Array, UInt32Array, UInt64Array};
44use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
45use arrow::record_batch::RecordBatch;
46use arrow_schema::FieldRef;
47use datafusion_common::stats::Precision;
48use datafusion_common::{
49    Constraint, Constraints, Result, ScalarValue, assert_eq_or_internal_err, not_impl_err,
50};
51use datafusion_execution::TaskContext;
52use datafusion_expr::{Accumulator, Aggregate};
53use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
54use datafusion_physical_expr::equivalence::ProjectionMapping;
55use datafusion_physical_expr::expressions::{Column, DynamicFilterPhysicalExpr, lit};
56use datafusion_physical_expr::{
57    ConstExpr, EquivalenceProperties, physical_exprs_contains,
58};
59use datafusion_physical_expr_common::physical_expr::{PhysicalExpr, fmt_sql};
60use datafusion_physical_expr_common::sort_expr::{
61    LexOrdering, LexRequirement, OrderingRequirements, PhysicalSortRequirement,
62};
63
64use datafusion_expr::utils::AggregateOrderSensitivity;
65use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays;
66use itertools::Itertools;
67use topk::hash_table::is_supported_hash_key_type;
68use topk::heap::is_supported_heap_type;
69
70pub mod group_values;
71mod no_grouping;
72pub mod order;
73mod row_hash;
74mod topk;
75mod topk_stream;
76
77/// Returns true if TopK aggregation data structures support the provided key and value types.
78///
79/// This function checks whether both the key type (used for grouping) and value type
80/// (used in min/max aggregation) can be handled by the TopK aggregation heap and hash table.
81/// Supported types include Arrow primitives (integers, floats, decimals, intervals) and
82/// UTF-8 strings (`Utf8`, `LargeUtf8`, `Utf8View`).
83/// ```text
84pub fn topk_types_supported(key_type: &DataType, value_type: &DataType) -> bool {
85    is_supported_hash_key_type(key_type) && is_supported_heap_type(value_type)
86}
87
/// Hard-coded seed for aggregations to ensure hash values differ from `RepartitionExec`, avoiding collisions.
///
/// The four seed words spell "AGGR" so the constant's origin is recognizable.
const AGGREGATION_HASH_SEED: ahash::RandomState =
    ahash::RandomState::with_seeds('A' as u64, 'G' as u64, 'G' as u64, 'R' as u64);
91
/// Whether an aggregate stage consumes raw input data or intermediate
/// accumulator state from a previous aggregation stage.
///
/// See the [table on `AggregateMode`](AggregateMode#variants-and-their-inputoutput-modes)
/// for how this relates to aggregate modes.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum AggregateInputMode {
    /// The stage consumes raw, unaggregated input data and calls
    /// [`Accumulator::update_batch`].
    Raw,
    /// The stage consumes intermediate accumulator state from a previous
    /// aggregation stage and calls [`Accumulator::merge_batch`].
    Partial,
}
106
/// Whether an aggregate stage produces intermediate accumulator state
/// or final output values.
///
/// See the [table on `AggregateMode`](AggregateMode#variants-and-their-inputoutput-modes)
/// for how this relates to aggregate modes.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum AggregateOutputMode {
    /// The stage produces intermediate accumulator state, serialized via
    /// [`Accumulator::state`].
    Partial,
    /// The stage produces final output values via
    /// [`Accumulator::evaluate`].
    Final,
}
121
/// Aggregation modes
///
/// See [`Accumulator::state`] for background information on multi-phase
/// aggregation and how these modes are used.
///
/// # Variants and their input/output modes
///
/// Each variant can be characterized by its [`AggregateInputMode`] and
/// [`AggregateOutputMode`]:
///
/// ```text
///                       | Input: Raw data           | Input: Partial state
/// Output: Final values  | Single, SinglePartitioned | Final, FinalPartitioned
/// Output: Partial state | Partial                   | PartialReduce
/// ```
///
/// Use [`AggregateMode::input_mode`] and [`AggregateMode::output_mode`]
/// to query these properties.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum AggregateMode {
    /// One of multiple layers of aggregation, any input partitioning
    ///
    /// Partial aggregate that can be applied in parallel across input
    /// partitions.
    ///
    /// This is the first phase of a multi-phase aggregation.
    Partial,
    /// *Final* of multiple layers of aggregation, in exactly one partition
    ///
    /// Final aggregate that produces a single partition of output by combining
    /// the output of multiple partial aggregates.
    ///
    /// This is the second phase of a multi-phase aggregation.
    ///
    /// This mode requires that the input is a single partition
    ///
    /// Note: Adjacent `Partial` and `Final` mode aggregation is equivalent to a `Single`
    /// mode aggregation node. The `Final` mode is required since this is used in an
    /// intermediate step. The [`CombinePartialFinalAggregate`] physical optimizer rule
    /// will replace this combination with `Single` mode for more efficient execution.
    ///
    /// [`CombinePartialFinalAggregate`]: https://docs.rs/datafusion/latest/datafusion/physical_optimizer/combine_partial_final_agg/struct.CombinePartialFinalAggregate.html
    Final,
    /// *Final* of multiple layers of aggregation, input is *Partitioned*
    ///
    /// Final aggregate that works on pre-partitioned data.
    ///
    /// This mode requires that all rows with a particular grouping key are in
    /// the same partitions, such as is the case with Hash repartitioning on the
    /// group keys. If a group key is duplicated, duplicate groups would be
    /// produced
    FinalPartitioned,
    /// *Single* layer of Aggregation, input is exactly one partition
    ///
    /// Applies the entire logical aggregation operation in a single operator,
    /// as opposed to Partial / Final modes which apply the logical aggregation using
    /// two operators.
    ///
    /// This mode requires that the input is a single partition (like Final)
    Single,
    /// *Single* layer of Aggregation, input is *Partitioned*
    ///
    /// Applies the entire logical aggregation operation in a single operator,
    /// as opposed to Partial / Final modes which apply the logical aggregation
    /// using two operators.
    ///
    /// This mode requires that the input has more than one partition, and is
    /// partitioned by group key (like FinalPartitioned).
    SinglePartitioned,
    /// Combine multiple partial aggregations to produce a new partial
    /// aggregation.
    ///
    /// Input is intermediate accumulator state (like Final), but output is
    /// also intermediate accumulator state (like Partial). This enables
    /// tree-reduce aggregation strategies where partial results from
    /// multiple workers are combined in multiple stages before a final
    /// evaluation.
    ///
    /// ```text
    ///               Final
    ///            /        \
    ///     PartialReduce   PartialReduce
    ///     /         \      /         \
    ///  Partial   Partial  Partial   Partial
    /// ```
    PartialReduce,
}
209
210impl AggregateMode {
211    /// Returns the [`AggregateInputMode`] for this mode: whether this
212    /// stage consumes raw input data or intermediate accumulator state.
213    ///
214    /// See the [table above](AggregateMode#variants-and-their-inputoutput-modes)
215    /// for details.
216    pub fn input_mode(&self) -> AggregateInputMode {
217        match self {
218            AggregateMode::Partial
219            | AggregateMode::Single
220            | AggregateMode::SinglePartitioned => AggregateInputMode::Raw,
221            AggregateMode::Final
222            | AggregateMode::FinalPartitioned
223            | AggregateMode::PartialReduce => AggregateInputMode::Partial,
224        }
225    }
226
227    /// Returns the [`AggregateOutputMode`] for this mode: whether this
228    /// stage produces intermediate accumulator state or final output values.
229    ///
230    /// See the [table above](AggregateMode#variants-and-their-inputoutput-modes)
231    /// for details.
232    pub fn output_mode(&self) -> AggregateOutputMode {
233        match self {
234            AggregateMode::Final
235            | AggregateMode::FinalPartitioned
236            | AggregateMode::Single
237            | AggregateMode::SinglePartitioned => AggregateOutputMode::Final,
238            AggregateMode::Partial | AggregateMode::PartialReduce => {
239                AggregateOutputMode::Partial
240            }
241        }
242    }
243}
244
/// Represents `GROUP BY` clause in the plan (including the more general GROUPING SET)
/// In the case of a simple `GROUP BY a, b` clause, this will contain the expression [a, b]
/// and a single group [false, false].
/// In the case of `GROUP BY GROUPING SETS/CUBE/ROLLUP` the planner will expand the expression
/// into multiple groups, using null expressions to align each group.
/// For example, with a group by clause `GROUP BY GROUPING SETS ((a,b),(a),(b))` the planner should
/// create a `PhysicalGroupBy` like
/// ```text
/// PhysicalGroupBy {
///     expr: [(col(a), a), (col(b), b)],
///     null_expr: [(NULL, a), (NULL, b)],
///     groups: [
///         [false, false], // (a,b)
///         [false, true],  // (a) <=> (a, NULL)
///         [true, false]   // (b) <=> (NULL, b)
///     ]
/// }
/// ```
#[derive(Clone, Debug, Default)]
pub struct PhysicalGroupBy {
    /// Distinct (Physical Expr, Alias) in the grouping set
    expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
    /// Corresponding NULL expressions for expr
    null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
    /// Null mask for each group in this grouping set. Each group is
    /// composed of either one of the group expressions in expr or a null
    /// expression in null_expr. If `groups[i][j]` is true, then the
    /// j-th expression in the i-th group is NULL, otherwise it is `expr[j]`.
    groups: Vec<Vec<bool>>,
    /// True when GROUPING SETS/CUBE/ROLLUP are used so `__grouping_id` should
    /// be included in the output schema.
    has_grouping_set: bool,
}
278
impl PhysicalGroupBy {
    /// Create a new `PhysicalGroupBy`
    pub fn new(
        expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
        null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
        groups: Vec<Vec<bool>>,
        has_grouping_set: bool,
    ) -> Self {
        Self {
            expr,
            null_expr,
            groups,
            has_grouping_set,
        }
    }

    /// Create a GROUPING SET with only a single group. This is the "standard"
    /// case when building a plan from an expression such as `GROUP BY a,b,c`
    pub fn new_single(expr: Vec<(Arc<dyn PhysicalExpr>, String)>) -> Self {
        let num_exprs = expr.len();
        Self {
            expr,
            // No GROUPING SETS, so no NULL placeholders and a single
            // all-`false` (nothing nulled) group mask.
            null_expr: vec![],
            groups: vec![vec![false; num_exprs]],
            has_grouping_set: false,
        }
    }

    /// Calculate GROUP BY expressions nullable
    ///
    /// An expression is reported nullable if *any* grouping set replaces it
    /// with NULL (i.e. its mask bit is `true` in at least one group).
    pub fn exprs_nullable(&self) -> Vec<bool> {
        let mut exprs_nullable = vec![false; self.expr.len()];
        for group in self.groups.iter() {
            group.iter().enumerate().for_each(|(index, is_null)| {
                if *is_null {
                    exprs_nullable[index] = true;
                }
            })
        }
        exprs_nullable
    }

    /// Returns true if this has no grouping at all (including no GROUPING SETS)
    pub fn is_true_no_grouping(&self) -> bool {
        self.is_empty() && !self.has_grouping_set
    }

    /// Returns the group expressions
    pub fn expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
        &self.expr
    }

    /// Returns the null expressions
    pub fn null_expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
        &self.null_expr
    }

    /// Returns the group null masks
    pub fn groups(&self) -> &[Vec<bool>] {
        &self.groups
    }

    /// Returns true if this grouping uses GROUPING SETS, CUBE or ROLLUP.
    pub fn has_grouping_set(&self) -> bool {
        self.has_grouping_set
    }

    /// Returns true if this `PhysicalGroupBy` has no group expressions
    pub fn is_empty(&self) -> bool {
        self.expr.is_empty()
    }

    /// Returns true if this is a "simple" GROUP BY (not using GROUPING SETS/CUBE/ROLLUP).
    /// This determines whether the `__grouping_id` column is included in the output schema.
    pub fn is_single(&self) -> bool {
        !self.has_grouping_set
    }

    /// Calculate GROUP BY expressions according to input schema.
    pub fn input_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        self.expr
            .iter()
            .map(|(expr, _alias)| Arc::clone(expr))
            .collect()
    }

    /// The number of expressions in the output schema.
    ///
    /// One extra slot is reserved for the internal `__grouping_id` column
    /// when GROUPING SETS/CUBE/ROLLUP are in use.
    fn num_output_exprs(&self) -> usize {
        let mut num_exprs = self.expr.len();
        if self.has_grouping_set {
            num_exprs += 1
        }
        num_exprs
    }

    /// Return grouping expressions as they occur in the output schema.
    ///
    /// Each expression becomes a positional [`Column`] reference into this
    /// node's output schema (index matches its position in `self.expr`).
    pub fn output_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        let num_output_exprs = self.num_output_exprs();
        let mut output_exprs = Vec::with_capacity(num_output_exprs);
        // `take(num_output_exprs)` never truncates here since
        // num_output_exprs >= self.expr.len(); it is a defensive bound.
        output_exprs.extend(
            self.expr
                .iter()
                .enumerate()
                .take(num_output_exprs)
                .map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _),
        );
        if self.has_grouping_set {
            // The internal grouping id column is appended after all group
            // expressions.
            output_exprs.push(Arc::new(Column::new(
                Aggregate::INTERNAL_GROUPING_ID,
                self.expr.len(),
            )) as _);
        }
        output_exprs
    }

    /// Returns the number expression as grouping keys.
    pub fn num_group_exprs(&self) -> usize {
        self.expr.len() + usize::from(self.has_grouping_set)
    }

    /// Builds a schema containing only the grouping-key fields.
    pub fn group_schema(&self, schema: &Schema) -> Result<SchemaRef> {
        Ok(Arc::new(Schema::new(self.group_fields(schema)?)))
    }

    /// Returns the fields that are used as the grouping keys.
    fn group_fields(&self, input_schema: &Schema) -> Result<Vec<FieldRef>> {
        let mut fields = Vec::with_capacity(self.num_group_exprs());
        for ((expr, name), group_expr_nullable) in
            self.expr.iter().zip(self.exprs_nullable().into_iter())
        {
            fields.push(
                Field::new(
                    name,
                    expr.data_type(input_schema)?,
                    // A field is nullable if a grouping set nulls it OR the
                    // expression itself can produce NULL.
                    group_expr_nullable || expr.nullable(input_schema)?,
                )
                .with_metadata(expr.return_field(input_schema)?.metadata().clone())
                .into(),
            );
        }
        if self.has_grouping_set {
            fields.push(
                Field::new(
                    Aggregate::INTERNAL_GROUPING_ID,
                    Aggregate::grouping_id_type(self.expr.len()),
                    false,
                )
                .into(),
            );
        }
        Ok(fields)
    }

    /// Returns the output fields of the group by.
    ///
    /// This might be different from the `group_fields` that might contain internal expressions that
    /// should not be part of the output schema.
    fn output_fields(&self, input_schema: &Schema) -> Result<Vec<FieldRef>> {
        let mut fields = self.group_fields(input_schema)?;
        fields.truncate(self.num_output_exprs());
        Ok(fields)
    }

    /// Returns the `PhysicalGroupBy` for a final aggregation if `self` is used for a partial
    /// aggregation.
    pub fn as_final(&self) -> PhysicalGroupBy {
        // Pair each output column with its alias; `zip` truncates the chained
        // `__grouping_id` name when there is no grouping set (output_exprs is
        // shorter in that case).
        let expr: Vec<_> =
            self.output_exprs()
                .into_iter()
                .zip(
                    self.expr.iter().map(|t| t.1.clone()).chain(std::iter::once(
                        Aggregate::INTERNAL_GROUPING_ID.to_owned(),
                    )),
                )
                .collect();
        let num_exprs = expr.len();
        let groups = if self.expr.is_empty() && !self.has_grouping_set {
            // No GROUP BY expressions - should have no groups
            vec![]
        } else {
            vec![vec![false; num_exprs]]
        };
        Self {
            expr,
            // The final stage reads already-expanded keys, so no NULL
            // placeholders or grouping-set flag are needed.
            null_expr: vec![],
            groups,
            has_grouping_set: false,
        }
    }
}
468
469impl PartialEq for PhysicalGroupBy {
470    fn eq(&self, other: &PhysicalGroupBy) -> bool {
471        self.expr.len() == other.expr.len()
472            && self
473                .expr
474                .iter()
475                .zip(other.expr.iter())
476                .all(|((expr1, name1), (expr2, name2))| expr1.eq(expr2) && name1 == name2)
477            && self.null_expr.len() == other.null_expr.len()
478            && self
479                .null_expr
480                .iter()
481                .zip(other.null_expr.iter())
482                .all(|((expr1, name1), (expr2, name2))| expr1.eq(expr2) && name1 == name2)
483            && self.groups == other.groups
484            && self.has_grouping_set == other.has_grouping_set
485    }
486}
487
/// The concrete stream implementations an `AggregateExec` can execute as.
#[expect(clippy::large_enum_variant)]
enum StreamType {
    /// Stream for aggregation without grouping (see `no_grouping` module).
    AggregateStream(AggregateStream),
    /// Stream for hash-based grouped aggregation (see `row_hash` module).
    GroupedHash(GroupedHashAggregateStream),
    /// Stream for grouped TopK aggregation (see `topk_stream` module).
    GroupedPriorityQueue(GroupedTopKAggregateStream),
}
494
495impl From<StreamType> for SendableRecordBatchStream {
496    fn from(stream: StreamType) -> Self {
497        match stream {
498            StreamType::AggregateStream(stream) => Box::pin(stream),
499            StreamType::GroupedHash(stream) => Box::pin(stream),
500            StreamType::GroupedPriorityQueue(stream) => Box::pin(stream),
501        }
502    }
503}
504
/// # Aggregate Dynamic Filter Pushdown Overview
///
/// For queries like
///   -- `example_table(type TEXT, val INT)`
///   SELECT min(val)
///   FROM example_table
///   WHERE type='A';
///
/// And `example_table`'s physical representation is a partitioned parquet file with
/// column statistics
/// - part-0.parquet: val {min=0, max=100}
/// - part-1.parquet: val {min=100, max=200}
/// - ...
/// - part-100.parquet: val {min=10000, max=10100}
///
/// After scanning the 1st file, we know we only have to read files if their minimal
/// value on `val` column is less than 0, the minimal `val` value in the 1st file.
///
/// We can skip scanning the remaining file by implementing dynamic filter, the
/// intuition is we keep a shared data structure for current min in both `AggregateExec`
/// and `DataSourceExec`, and let it update during execution, so the scanner can
/// know during execution if it's possible to skip scanning certain files. See
/// physical optimizer rule `FilterPushdown` for details.
///
/// # Implementation
///
/// ## Enable Condition
/// - No grouping (no `GROUP BY` clause in the sql, only a single global group to aggregate)
/// - The aggregate expression must be `min`/`max`, and evaluate directly on columns.
///   Note multiple aggregate expressions that satisfy this requirement are allowed,
///   and a dynamic filter will be constructed combining all applicable expr's
///   states. See more in the following example with dynamic filter on multiple columns.
///
/// ## Filter Construction
/// The filter is kept in the `DataSourceExec`, and it gets updated during execution.
/// The reader interprets it as "the upstream only needs rows for which this filter
/// predicate evaluates to true", and certain scanner implementations like `parquet`
/// can evaluate column statistics against those dynamic filters, to decide if they can
/// prune a whole range.
///
/// ### Examples
/// - Expr: `min(a)`, Dynamic Filter: `a < a_cur_min`
/// - Expr: `min(a), max(a), min(b)`, Dynamic Filter: `(a < a_cur_min) OR (a > a_cur_max) OR (b < b_cur_min)`
#[derive(Debug, Clone)]
struct AggrDynFilter {
    /// The physical expr for the dynamic filter shared between the `AggregateExec`
    /// and the parquet scanner.
    filter: Arc<DynamicFilterPhysicalExpr>,
    /// The current bounds for the dynamic filter, updates during the execution to
    /// tighten the bound for more effective pruning.
    ///
    /// Each vector element is for the accumulators that support dynamic filter.
    /// e.g. This `AggregateExec` has accumulator:
    /// min(a), avg(a), max(b)
    /// And this field stores [PerAccumulatorDynFilter(min(a)), PerAccumulatorDynFilter(min(b))]
    supported_accumulators_info: Vec<PerAccumulatorDynFilter>,
}
562
563// ---- Aggregate Dynamic Filter Utility Structs ----
564
/// Aggregate expressions that support the dynamic filter pushdown in aggregation.
/// See comments in [`AggrDynFilter`] for conditions.
#[derive(Debug, Clone)]
struct PerAccumulatorDynFilter {
    /// Whether this accumulator is a `min` or a `max` aggregate.
    aggr_type: DynamicFilterAggregateType,
    /// During planning and optimization, the parent structure is kept in `AggregateExec`,
    /// this index is into `aggr_expr` vec inside `AggregateExec`.
    /// During execution, the parent struct is moved into `AggregateStream` (stream
    /// for no grouping aggregate execution), and this index is into `aggregate_expressions`
    /// vec inside `AggregateStreamInner`
    aggr_index: usize,
    // The current bound. Shared among all streams.
    shared_bound: Arc<Mutex<ScalarValue>>,
}
579
/// Aggregate types that are supported for dynamic filter in `AggregateExec`
#[derive(Debug, Clone)]
enum DynamicFilterAggregateType {
    /// `min(col)` accumulator — the filter bound is an upper bound on needed rows.
    Min,
    /// `max(col)` accumulator — the filter bound is a lower bound on needed rows.
    Max,
}
586
/// Configuration for limit-based optimizations in aggregation
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LimitOptions {
    /// The maximum number of rows to return
    pub limit: usize,
    /// Optional ordering direction (true = descending, false = ascending)
    /// This is used for TopK aggregation to maintain a priority queue with the correct ordering
    pub descending: Option<bool>,
}
596
597impl LimitOptions {
598    /// Create a new LimitOptions with a limit and no specific ordering
599    pub fn new(limit: usize) -> Self {
600        Self {
601            limit,
602            descending: None,
603        }
604    }
605
606    /// Create a new LimitOptions with a limit and ordering direction
607    pub fn new_with_order(limit: usize, descending: bool) -> Self {
608        Self {
609            limit,
610            descending: Some(descending),
611        }
612    }
613
614    pub fn limit(&self) -> usize {
615        self.limit
616    }
617
618    pub fn descending(&self) -> Option<bool> {
619        self.descending
620    }
621}
622
/// Hash aggregate execution plan
#[derive(Debug, Clone)]
pub struct AggregateExec {
    /// Aggregation mode (full, partial)
    mode: AggregateMode,
    /// Group by expressions
    /// [`Arc`] used for a cheap clone, which improves physical plan optimization performance.
    group_by: Arc<PhysicalGroupBy>,
    /// Aggregate expressions
    /// The same reason to [`Arc`] it as for [`Self::group_by`].
    aggr_expr: Arc<[Arc<AggregateFunctionExpr>]>,
    /// FILTER (WHERE clause) expression for each aggregate expression
    /// The same reason to [`Arc`] it as for [`Self::group_by`].
    filter_expr: Arc<[Option<Arc<dyn PhysicalExpr>>]>,
    /// Configuration for limit-based optimizations
    limit_options: Option<LimitOptions>,
    /// Input plan, could be a partial aggregate or the input to the aggregate
    pub input: Arc<dyn ExecutionPlan>,
    /// Schema after the aggregate is applied
    schema: SchemaRef,
    /// Input schema before any aggregation is applied. For partial aggregate this will be the
    /// same as input.schema() but for the final aggregate it will be the same as the input
    /// to the partial aggregate, i.e., partial and final aggregates have same `input_schema`.
    /// We need the input schema of partial aggregate to be able to deserialize aggregate
    /// expressions from protobuf for final aggregate.
    pub input_schema: SchemaRef,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Soft ordering requirement on the input, if any (see `try_new_with_schema`).
    required_input_ordering: Option<OrderingRequirements>,
    /// Describes how the input is ordered relative to the group by columns
    input_order_mode: InputOrderMode,
    /// Cached plan properties (partitioning, equivalences, etc.).
    cache: Arc<PlanProperties>,
    /// During initialization, if the plan supports dynamic filtering (see [`AggrDynFilter`]),
    /// it is set to `Some(..)` regardless of whether it can be pushed down to a child node.
    ///
    /// During filter pushdown optimization, if a child node can accept this filter,
    /// it remains `Some(..)` to enable dynamic filtering during aggregate execution;
    /// otherwise, it is cleared to `None`.
    dynamic_filter: Option<Arc<AggrDynFilter>>,
}
663
664impl AggregateExec {
665    /// Function used in `OptimizeAggregateOrder` optimizer rule,
666    /// where we need parts of the new value, others cloned from the old one
667    /// Rewrites aggregate exec with new aggregate expressions.
668    pub fn with_new_aggr_exprs(
669        &self,
670        aggr_expr: impl Into<Arc<[Arc<AggregateFunctionExpr>]>>,
671    ) -> Self {
672        Self {
673            aggr_expr: aggr_expr.into(),
674            // clone the rest of the fields
675            required_input_ordering: self.required_input_ordering.clone(),
676            metrics: ExecutionPlanMetricsSet::new(),
677            input_order_mode: self.input_order_mode.clone(),
678            cache: Arc::clone(&self.cache),
679            mode: self.mode,
680            group_by: Arc::clone(&self.group_by),
681            filter_expr: Arc::clone(&self.filter_expr),
682            limit_options: self.limit_options,
683            input: Arc::clone(&self.input),
684            schema: Arc::clone(&self.schema),
685            input_schema: Arc::clone(&self.input_schema),
686            dynamic_filter: self.dynamic_filter.clone(),
687        }
688    }
689
690    /// Clone this exec, overriding only the limit hint.
691    pub fn with_new_limit_options(&self, limit_options: Option<LimitOptions>) -> Self {
692        Self {
693            limit_options,
694            // clone the rest of the fields
695            required_input_ordering: self.required_input_ordering.clone(),
696            metrics: ExecutionPlanMetricsSet::new(),
697            input_order_mode: self.input_order_mode.clone(),
698            cache: Arc::clone(&self.cache),
699            mode: self.mode,
700            group_by: Arc::clone(&self.group_by),
701            aggr_expr: Arc::clone(&self.aggr_expr),
702            filter_expr: Arc::clone(&self.filter_expr),
703            input: Arc::clone(&self.input),
704            schema: Arc::clone(&self.schema),
705            input_schema: Arc::clone(&self.input_schema),
706            dynamic_filter: self.dynamic_filter.clone(),
707        }
708    }
709
    /// Returns the cached [`PlanProperties`] for this plan.
    pub fn cache(&self) -> &PlanProperties {
        &self.cache
    }
713
714    /// Create a new hash aggregate execution plan
715    pub fn try_new(
716        mode: AggregateMode,
717        group_by: impl Into<Arc<PhysicalGroupBy>>,
718        aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
719        filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
720        input: Arc<dyn ExecutionPlan>,
721        input_schema: SchemaRef,
722    ) -> Result<Self> {
723        let group_by = group_by.into();
724        let schema = create_schema(&input.schema(), &group_by, &aggr_expr, mode)?;
725
726        let schema = Arc::new(schema);
727        AggregateExec::try_new_with_schema(
728            mode,
729            group_by,
730            aggr_expr,
731            filter_expr,
732            input,
733            input_schema,
734            schema,
735        )
736    }
737
    /// Create a new hash aggregate execution plan with the given schema.
    ///
    /// This constructor isn't part of the public API; it is used internally
    /// by DataFusion to enforce schema consistency when re-creating
    /// `AggregateExec`s inside optimization rules. Schema field names of an
    /// `AggregateExec` depend on the names of aggregate expressions. Since
    /// a rule may re-write aggregate expressions (e.g. reverse them) during
    /// initialization, field names may change inadvertently if one re-creates
    /// the schema in such cases.
    fn try_new_with_schema(
        mode: AggregateMode,
        group_by: impl Into<Arc<PhysicalGroupBy>>,
        mut aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
        filter_expr: impl Into<Arc<[Option<Arc<dyn PhysicalExpr>>]>>,
        input: Arc<dyn ExecutionPlan>,
        input_schema: SchemaRef,
        schema: SchemaRef,
    ) -> Result<Self> {
        let group_by = group_by.into();
        let filter_expr = filter_expr.into();

        // Make sure arguments are consistent in size: each aggregate
        // expression gets exactly one (possibly absent) filter slot.
        assert_eq_or_internal_err!(
            aggr_expr.len(),
            filter_expr.len(),
            "Inconsistent aggregate expr: {:?} and filter expr: {:?} for AggregateExec, their size should match",
            aggr_expr,
            filter_expr
        );

        let input_eq_properties = input.equivalence_properties();
        // Get GROUP BY expressions:
        let groupby_exprs = group_by.input_exprs();
        // If existing ordering satisfies a prefix of the GROUP BY expressions,
        // prefix requirements with this section. In this case, aggregation will
        // work more efficiently.
        // Copy the `PhysicalSortExpr`s to retain the sort options.
        let (new_sort_exprs, indices) =
            input_eq_properties.find_longest_permutation(&groupby_exprs)?;

        let mut new_requirements = new_sort_exprs
            .into_iter()
            .map(PhysicalSortRequirement::from)
            .collect::<Vec<_>>();

        // Fold in any ordering requirement coming from the aggregate
        // expressions themselves. NOTE: this may rewrite `aggr_expr` in place
        // (hence the `&mut`), e.g. by reversing ordering-sensitive aggregates.
        let req = get_finer_aggregate_exprs_requirement(
            &mut aggr_expr,
            &group_by,
            input_eq_properties,
            &mode,
        )?;
        new_requirements.extend(req);

        // The resulting requirement is "soft": the plan benefits from, but
        // does not strictly need, this input ordering.
        let required_input_ordering =
            LexRequirement::new(new_requirements).map(OrderingRequirements::new_soft);

        // If our aggregation has grouping sets then our base grouping exprs will
        // be expanded based on the flags in `group_by.groups` where for each
        // group we swap the grouping expr for `null` if the flag is `true`
        // That means that each index in `indices` is valid if and only if
        // it is not null in every group
        let indices: Vec<usize> = indices
            .into_iter()
            .filter(|idx| group_by.groups.iter().all(|group| !group[*idx]))
            .collect();

        // Classify how well the input ordering covers the GROUP BY columns:
        // all of them and a single group (Sorted), a nonempty subset
        // (PartiallySorted), or none (Linear).
        let input_order_mode = if indices.len() == groupby_exprs.len()
            && !indices.is_empty()
            && group_by.groups.len() == 1
        {
            InputOrderMode::Sorted
        } else if !indices.is_empty() {
            InputOrderMode::PartiallySorted(indices)
        } else {
            InputOrderMode::Linear
        };

        // construct a map from the input expression to the output expression of the Aggregation group by
        let group_expr_mapping =
            ProjectionMapping::try_new(group_by.expr.clone(), &input.schema())?;

        let cache = Self::compute_properties(
            &input,
            Arc::clone(&schema),
            &group_expr_mapping,
            &mode,
            &input_order_mode,
            aggr_expr.as_ref(),
        )?;

        let mut exec = AggregateExec {
            mode,
            group_by,
            aggr_expr: aggr_expr.into(),
            filter_expr,
            input,
            schema,
            input_schema,
            metrics: ExecutionPlanMetricsSet::new(),
            required_input_ordering,
            limit_options: None,
            input_order_mode,
            cache: Arc::new(cache),
            dynamic_filter: None,
        };

        // Populates `dynamic_filter` only when this node supports it;
        // otherwise it stays `None` (see `init_dynamic_filter`).
        exec.init_dynamic_filter();

        Ok(exec)
    }
847
    /// Aggregation mode of this node (e.g. partial or final stage).
    pub fn mode(&self) -> &AggregateMode {
        &self.mode
    }
852
    /// Set the limit options for this AggExec (builder style);
    /// `None` clears any previously configured limit.
    pub fn with_limit_options(mut self, limit_options: Option<LimitOptions>) -> Self {
        self.limit_options = limit_options;
        self
    }
858
    /// Get the limit options (if set), copied out by value.
    pub fn limit_options(&self) -> Option<LimitOptions> {
        self.limit_options
    }
863
    /// Grouping expressions, as provided at construction time.
    pub fn group_expr(&self) -> &PhysicalGroupBy {
        &self.group_by
    }
868
    /// Grouping expressions as they occur in the output schema
    /// (delegates to `PhysicalGroupBy::output_exprs`).
    pub fn output_group_expr(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        self.group_by.output_exprs()
    }
873
    /// Aggregate expressions, in output order.
    pub fn aggr_expr(&self) -> &[Arc<AggregateFunctionExpr>] {
        &self.aggr_expr
    }
878
    /// FILTER (WHERE clause) expression for each aggregate expression;
    /// parallel to [`Self::aggr_expr`], `None` when an aggregate is unfiltered.
    pub fn filter_expr(&self) -> &[Option<Arc<dyn PhysicalExpr>>] {
        &self.filter_expr
    }
883
    /// Input plan (aggregation is a unary operator with a single child).
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }
888
    /// Get the input schema before any aggregates are applied
    /// (a cheap `Arc` clone of the stored schema).
    pub fn input_schema(&self) -> SchemaRef {
        Arc::clone(&self.input_schema)
    }
893
894    fn execute_typed(
895        &self,
896        partition: usize,
897        context: &Arc<TaskContext>,
898    ) -> Result<StreamType> {
899        if self.group_by.is_true_no_grouping() {
900            return Ok(StreamType::AggregateStream(AggregateStream::new(
901                self, context, partition,
902            )?));
903        }
904
905        // grouping by an expression that has a sort/limit upstream
906        if let Some(config) = self.limit_options
907            && !self.is_unordered_unfiltered_group_by_distinct()
908        {
909            return Ok(StreamType::GroupedPriorityQueue(
910                GroupedTopKAggregateStream::new(self, context, partition, config.limit)?,
911            ));
912        }
913
914        // grouping by something else and we need to just materialize all results
915        Ok(StreamType::GroupedHash(GroupedHashAggregateStream::new(
916            self, context, partition,
917        )?))
918    }
919
920    /// Finds the DataType and SortDirection for this Aggregate, if there is one
921    pub fn get_minmax_desc(&self) -> Option<(FieldRef, bool)> {
922        let agg_expr = self.aggr_expr.iter().exactly_one().ok()?;
923        agg_expr.get_minmax_desc()
924    }
925
    /// true, if this Aggregate has a group-by with no required or explicit ordering,
    /// no filtering and no aggregate expressions
    /// This method qualifies the use of the LimitedDistinctAggregation rewrite rule
    /// on an AggregateExec.
    pub fn is_unordered_unfiltered_group_by_distinct(&self) -> bool {
        // a descending top-k configuration implies an ordering, so the
        // rewrite does not apply
        if self
            .limit_options()
            .and_then(|config| config.descending)
            .is_some()
        {
            return false;
        }
        // ensure there is a group by
        if self.group_expr().is_empty() && !self.group_expr().has_grouping_set() {
            return false;
        }
        // ensure there are no aggregate expressions
        if !self.aggr_expr().is_empty() {
            return false;
        }
        // ensure there are no filters on aggregate expressions; the above check
        // may preclude this case
        if self.filter_expr().iter().any(|e| e.is_some()) {
            return false;
        }
        // ensure there are no order by expressions
        if !self.aggr_expr().iter().all(|e| e.order_bys().is_empty()) {
            return false;
        }
        // ensure there is no output ordering; can this rule be relaxed?
        if self.properties().output_ordering().is_some() {
            return false;
        }
        // ensure no ordering is required on the input
        // NOTE(review): as written this returns `true` exactly when the
        // requirement is `Hard` and `false` when it is soft — verify that
        // this inversion (hard requirement qualifying, soft disqualifying)
        // is intentional.
        if let Some(requirement) = self.required_input_ordering().swap_remove(0) {
            return matches!(requirement, OrderingRequirements::Hard(_));
        }
        true
    }
965
    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
    ///
    /// Computed once at construction time and stored in `self.cache`.
    pub fn compute_properties(
        input: &Arc<dyn ExecutionPlan>,
        schema: SchemaRef,
        group_expr_mapping: &ProjectionMapping,
        mode: &AggregateMode,
        input_order_mode: &InputOrderMode,
        aggr_exprs: &[Arc<AggregateFunctionExpr>],
    ) -> Result<PlanProperties> {
        // Construct equivalence properties:
        let mut eq_properties = input
            .equivalence_properties()
            .project(group_expr_mapping, schema);

        // If the group by is empty, then we ensure that the operator will produce
        // only one row, and mark the generated result as a constant value.
        if group_expr_mapping.is_empty() {
            let new_constants = aggr_exprs.iter().enumerate().map(|(idx, func)| {
                let column = Arc::new(Column::new(func.name(), idx));
                ConstExpr::from(column as Arc<dyn PhysicalExpr>)
            });
            eq_properties.add_constants(new_constants)?;
        }

        // Group by expression will be a distinct value after the aggregation.
        // Add it into the constraint set.
        let mut constraints = eq_properties.constraints().to_vec();
        let new_constraint = Constraint::Unique(
            group_expr_mapping
                .iter()
                .flat_map(|(_, target_cols)| {
                    // Only plain output column references can take part in
                    // the uniqueness constraint; other expressions are skipped.
                    target_cols.iter().flat_map(|(expr, _)| {
                        expr.as_any().downcast_ref::<Column>().map(|c| c.index())
                    })
                })
                .collect(),
        );
        constraints.push(new_constraint);
        eq_properties =
            eq_properties.with_constraints(Constraints::new_unverified(constraints));

        // Get output partitioning:
        let input_partitioning = input.output_partitioning().clone();
        let output_partitioning = match mode.input_mode() {
            AggregateInputMode::Raw => {
                // First stage aggregation will not change the output partitioning,
                // but needs to respect aliases (e.g. mapping in the GROUP BY
                // expression).
                let input_eq_properties = input.equivalence_properties();
                input_partitioning.project(group_expr_mapping, input_eq_properties)
            }
            AggregateInputMode::Partial => input_partitioning.clone(),
        };

        // TODO: Emission type and boundedness information can be enhanced here
        // `Linear` input-order mode means results are only available after all
        // input has been consumed; otherwise inherit the input's behavior.
        let emission_type = if *input_order_mode == InputOrderMode::Linear {
            EmissionType::Final
        } else {
            input.pipeline_behavior()
        };

        Ok(PlanProperties::new(
            eq_properties,
            output_partitioning,
            emission_type,
            input.boundedness(),
        ))
    }
1034
    /// How well the input ordering covers the GROUP BY columns
    /// (`Sorted`, `PartiallySorted(..)`, or `Linear`), as determined
    /// at construction time.
    pub fn input_order_mode(&self) -> &InputOrderMode {
        &self.input_order_mode
    }
1038
    /// Estimate this node's output statistics from the child's statistics.
    ///
    /// Group-by columns that are plain column references inherit the child's
    /// min/max (grouping cannot widen a column's value range). Row counts are
    /// degraded to inexact since aggregation can only reduce them — except the
    /// no-group-by final modes, which emit exactly one row.
    fn statistics_inner(&self, child_statistics: &Statistics) -> Result<Statistics> {
        // TODO stats: group expressions:
        // - once expressions will be able to compute their own stats, use it here
        // - case where we group by on a column for which with have the `distinct` stat
        // TODO stats: aggr expression:
        // - aggregations sometimes also preserve invariants such as min, max...

        let column_statistics = {
            // self.schema: [<group by exprs>, <aggregate exprs>]
            let mut column_statistics = Statistics::unknown_column(&self.schema());

            // Propagate min/max only for group-by entries that are direct
            // column references; anything else remains unknown.
            for (idx, (expr, _)) in self.group_by.expr.iter().enumerate() {
                if let Some(col) = expr.as_any().downcast_ref::<Column>() {
                    column_statistics[idx].max_value = child_statistics.column_statistics
                        [col.index()]
                    .max_value
                    .clone();

                    column_statistics[idx].min_value = child_statistics.column_statistics
                        [col.index()]
                    .min_value
                    .clone();
                }
            }

            column_statistics
        };
        match self.mode {
            AggregateMode::Final | AggregateMode::FinalPartitioned
                if self.group_by.expr.is_empty() =>
            {
                // Final aggregation without GROUP BY: exactly one output row.
                let total_byte_size =
                    Self::calculate_scaled_byte_size(child_statistics, 1);

                Ok(Statistics {
                    num_rows: Precision::Exact(1),
                    column_statistics,
                    total_byte_size,
                })
            }
            _ => {
                // When the input row count is 1, we can adopt that statistic keeping its reliability.
                // When it is larger than 1, we degrade the precision since it may decrease after aggregation.
                let num_rows = if let Some(value) = child_statistics.num_rows.get_value()
                {
                    if *value > 1 {
                        child_statistics.num_rows.to_inexact()
                    } else if *value == 0 {
                        child_statistics.num_rows
                    } else {
                        // num_rows = 1 case: each grouping set contributes one
                        // output row for the single input row.
                        let grouping_set_num = self.group_by.groups.len();
                        child_statistics.num_rows.map(|x| x * grouping_set_num)
                    }
                } else {
                    Precision::Absent
                };

                // Scale byte size by the estimated output row count; the
                // result is always marked inexact.
                let total_byte_size = num_rows
                    .get_value()
                    .and_then(|&output_rows| {
                        Self::calculate_scaled_byte_size(child_statistics, output_rows)
                            .get_value()
                            .map(|&bytes| Precision::Inexact(bytes))
                    })
                    .unwrap_or(Precision::Absent);

                Ok(Statistics {
                    num_rows,
                    column_statistics,
                    total_byte_size,
                })
            }
        }
    }
1114
    /// Check if dynamic filter is possible for the current plan node.
    /// - If yes, init one inside `AggregateExec`'s `dynamic_filter` field.
    /// - If not supported, `self.dynamic_filter` should be kept `None`
    fn init_dynamic_filter(&mut self) {
        // Only a partial-mode aggregation without GROUP BY is eligible.
        if (!self.group_by.is_empty()) || (self.mode != AggregateMode::Partial) {
            debug_assert!(
                self.dynamic_filter.is_none(),
                "The current operator node does not support dynamic filter"
            );
            return;
        }

        // Already initialized.
        if self.dynamic_filter.is_some() {
            return;
        }

        // Collect supported accumulators
        // It is assumed the order of aggregate expressions are not changed from `AggregateExec`
        // to `AggregateStream`
        let mut aggr_dyn_filters = Vec::new();
        // All column references in the dynamic filter, used when initializing the dynamic
        // filter, and it's used to decide if this dynamic filter is able to get push
        // through certain node during optimization.
        let mut all_cols: Vec<Arc<dyn PhysicalExpr>> = Vec::new();
        for (i, aggr_expr) in self.aggr_expr.iter().enumerate() {
            // 1. Only `min` or `max` aggregate function
            let fun_name = aggr_expr.fun().name();
            // HACK: Should check the function type more precisely
            // Issue: <https://github.com/apache/datafusion/issues/18643>
            let aggr_type = if fun_name.eq_ignore_ascii_case("min") {
                DynamicFilterAggregateType::Min
            } else if fun_name.eq_ignore_ascii_case("max") {
                DynamicFilterAggregateType::Max
            } else {
                // Any non-min/max aggregate disables the dynamic filter for
                // the whole node (bail out leaving `dynamic_filter` as None).
                return;
            };

            // 2. arg should be only 1 column reference
            // NOTE(review): a min/max over a non-column expression is merely
            // skipped here (not a bail-out as above) — confirm that a
            // partially-covering dynamic filter is intended in that case.
            if let [arg] = aggr_expr.expressions().as_slice()
                && arg.as_any().is::<Column>()
            {
                all_cols.push(Arc::clone(arg));
                aggr_dyn_filters.push(PerAccumulatorDynFilter {
                    aggr_type,
                    aggr_index: i,
                    // Bound starts as NULL (no rows observed yet) and is
                    // tightened by the accumulators at runtime.
                    shared_bound: Arc::new(Mutex::new(ScalarValue::Null)),
                });
            }
        }

        if !aggr_dyn_filters.is_empty() {
            // Start from a trivially-true predicate; it is updated as bounds
            // become known during execution.
            self.dynamic_filter = Some(Arc::new(AggrDynFilter {
                filter: Arc::new(DynamicFilterPhysicalExpr::new(all_cols, lit(true))),
                supported_accumulators_info: aggr_dyn_filters,
            }))
        }
    }
1173
1174    /// Calculate scaled byte size based on row count ratio.
1175    /// Returns `Precision::Absent` if input statistics are insufficient.
1176    /// Returns `Precision::Inexact` with the scaled value otherwise.
1177    ///
1178    /// This is a simple heuristic that assumes uniform row sizes.
1179    #[inline]
1180    fn calculate_scaled_byte_size(
1181        input_stats: &Statistics,
1182        target_row_count: usize,
1183    ) -> Precision<usize> {
1184        match (
1185            input_stats.num_rows.get_value(),
1186            input_stats.total_byte_size.get_value(),
1187        ) {
1188            (Some(&input_rows), Some(&input_bytes)) if input_rows > 0 => {
1189                let bytes_per_row = input_bytes as f64 / input_rows as f64;
1190                let scaled_bytes =
1191                    (bytes_per_row * target_row_count as f64).ceil() as usize;
1192                Precision::Inexact(scaled_bytes)
1193            }
1194            _ => Precision::Absent,
1195        }
1196    }
1197
1198    fn with_new_children_and_same_properties(
1199        &self,
1200        mut children: Vec<Arc<dyn ExecutionPlan>>,
1201    ) -> Self {
1202        Self {
1203            input: children.swap_remove(0),
1204            metrics: ExecutionPlanMetricsSet::new(),
1205            ..Self::clone(self)
1206        }
1207    }
1208}
1209
1210impl DisplayAs for AggregateExec {
1211    fn fmt_as(
1212        &self,
1213        t: DisplayFormatType,
1214        f: &mut std::fmt::Formatter,
1215    ) -> std::fmt::Result {
1216        match t {
1217            DisplayFormatType::Default | DisplayFormatType::Verbose => {
1218                let format_expr_with_alias =
1219                    |(e, alias): &(Arc<dyn PhysicalExpr>, String)| -> String {
1220                        let e = e.to_string();
1221                        if &e != alias {
1222                            format!("{e} as {alias}")
1223                        } else {
1224                            e
1225                        }
1226                    };
1227
1228                write!(f, "AggregateExec: mode={:?}", self.mode)?;
1229                let g: Vec<String> = if self.group_by.is_single() {
1230                    self.group_by
1231                        .expr
1232                        .iter()
1233                        .map(format_expr_with_alias)
1234                        .collect()
1235                } else {
1236                    self.group_by
1237                        .groups
1238                        .iter()
1239                        .map(|group| {
1240                            let terms = group
1241                                .iter()
1242                                .enumerate()
1243                                .map(|(idx, is_null)| {
1244                                    if *is_null {
1245                                        format_expr_with_alias(
1246                                            &self.group_by.null_expr[idx],
1247                                        )
1248                                    } else {
1249                                        format_expr_with_alias(&self.group_by.expr[idx])
1250                                    }
1251                                })
1252                                .collect::<Vec<String>>()
1253                                .join(", ");
1254                            format!("({terms})")
1255                        })
1256                        .collect()
1257                };
1258
1259                write!(f, ", gby=[{}]", g.join(", "))?;
1260
1261                let a: Vec<String> = self
1262                    .aggr_expr
1263                    .iter()
1264                    .map(|agg| agg.name().to_string())
1265                    .collect();
1266                write!(f, ", aggr=[{}]", a.join(", "))?;
1267                if let Some(config) = self.limit_options {
1268                    write!(f, ", lim=[{}]", config.limit)?;
1269                }
1270
1271                if self.input_order_mode != InputOrderMode::Linear {
1272                    write!(f, ", ordering_mode={:?}", self.input_order_mode)?;
1273                }
1274            }
1275            DisplayFormatType::TreeRender => {
1276                let format_expr_with_alias =
1277                    |(e, alias): &(Arc<dyn PhysicalExpr>, String)| -> String {
1278                        let expr_sql = fmt_sql(e.as_ref()).to_string();
1279                        if &expr_sql != alias {
1280                            format!("{expr_sql} as {alias}")
1281                        } else {
1282                            expr_sql
1283                        }
1284                    };
1285
1286                let g: Vec<String> = if self.group_by.is_single() {
1287                    self.group_by
1288                        .expr
1289                        .iter()
1290                        .map(format_expr_with_alias)
1291                        .collect()
1292                } else {
1293                    self.group_by
1294                        .groups
1295                        .iter()
1296                        .map(|group| {
1297                            let terms = group
1298                                .iter()
1299                                .enumerate()
1300                                .map(|(idx, is_null)| {
1301                                    if *is_null {
1302                                        format_expr_with_alias(
1303                                            &self.group_by.null_expr[idx],
1304                                        )
1305                                    } else {
1306                                        format_expr_with_alias(&self.group_by.expr[idx])
1307                                    }
1308                                })
1309                                .collect::<Vec<String>>()
1310                                .join(", ");
1311                            format!("({terms})")
1312                        })
1313                        .collect()
1314                };
1315                let a: Vec<String> = self
1316                    .aggr_expr
1317                    .iter()
1318                    .map(|agg| agg.human_display().to_string())
1319                    .collect();
1320                writeln!(f, "mode={:?}", self.mode)?;
1321                if !g.is_empty() {
1322                    writeln!(f, "group_by={}", g.join(", "))?;
1323                }
1324                if !a.is_empty() {
1325                    writeln!(f, "aggr={}", a.join(", "))?;
1326                }
1327                if let Some(config) = self.limit_options {
1328                    writeln!(f, "limit={}", config.limit)?;
1329                }
1330            }
1331        }
1332        Ok(())
1333    }
1334}
1335
1336impl ExecutionPlan for AggregateExec {
    /// Short, stable operator name used in plan displays.
    fn name(&self) -> &'static str {
        "AggregateExec"
    }
1340
    /// Return a reference to [`Any`] so callers can downcast this node
    /// back to a concrete `AggregateExec`.
    fn as_any(&self) -> &dyn Any {
        self
    }
1345
    fn properties(&self) -> &Arc<PlanProperties> {
        // Properties were computed once at construction (see
        // `compute_properties`) and cached.
        &self.cache
    }
1349
    fn required_input_distribution(&self) -> Vec<Distribution> {
        match &self.mode {
            // Pre-shuffle stages accept any input distribution.
            AggregateMode::Partial | AggregateMode::PartialReduce => {
                vec![Distribution::UnspecifiedDistribution]
            }
            // Partitioned final/single stages require the input to be hash
            // partitioned on the GROUP BY input expressions so each group
            // key lands in a single partition.
            AggregateMode::FinalPartitioned | AggregateMode::SinglePartitioned => {
                vec![Distribution::HashPartitioned(self.group_by.input_exprs())]
            }
            // Non-partitioned final/single stages need all rows in one
            // partition.
            AggregateMode::Final | AggregateMode::Single => {
                vec![Distribution::SinglePartition]
            }
        }
    }
1363
    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
        // One entry for the single child; the (soft) requirement was derived
        // at construction from the GROUP BY and aggregate expressions.
        vec![self.required_input_ordering.clone()]
    }
1367
1368    /// The output ordering of [`AggregateExec`] is determined by its `group_by`
1369    /// columns. Although this method is not explicitly used by any optimizer
1370    /// rules yet, overriding the default implementation ensures that it
1371    /// accurately reflects the actual behavior.
1372    ///
1373    /// If the [`InputOrderMode`] is `Linear`, the `group_by` columns don't have
1374    /// an ordering, which means the results do not either. However, in the
1375    /// `Ordered` and `PartiallyOrdered` cases, the `group_by` columns do have
1376    /// an ordering, which is preserved in the output.
1377    fn maintains_input_order(&self) -> Vec<bool> {
1378        vec![self.input_order_mode != InputOrderMode::Linear]
1379    }
1380
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        // Aggregation is a unary operator: exactly one input child.
        vec![&self.input]
    }
1384
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // Presumably short-circuits when the new child's properties match the
        // current input's, via `with_new_children_and_same_properties` — see
        // the macro definition to confirm.
        check_if_same_properties!(self, children);

        // Rebuild with the pre-computed schema to keep field names stable
        // (see `try_new_with_schema` docs), then carry over the state that
        // the constructor resets.
        let mut me = AggregateExec::try_new_with_schema(
            self.mode,
            Arc::clone(&self.group_by),
            self.aggr_expr.to_vec(),
            Arc::clone(&self.filter_expr),
            Arc::clone(&children[0]),
            Arc::clone(&self.input_schema),
            Arc::clone(&self.schema),
        )?;
        me.limit_options = self.limit_options;
        me.dynamic_filter = self.dynamic_filter.clone();

        Ok(Arc::new(me))
    }
1405
1406    fn execute(
1407        &self,
1408        partition: usize,
1409        context: Arc<TaskContext>,
1410    ) -> Result<SendableRecordBatchStream> {
1411        self.execute_typed(partition, &context)
1412            .map(|stream| stream.into())
1413    }
1414
    fn metrics(&self) -> Option<MetricsSet> {
        // Snapshot of the metrics recorded so far.
        Some(self.metrics.clone_inner())
    }
1418
    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
        // Derive this node's statistics from the child's statistics for the
        // same partition selector.
        let child_statistics = self.input().partition_statistics(partition)?;
        self.statistics_inner(&child_statistics)
    }
1423
    fn cardinality_effect(&self) -> CardinalityEffect {
        // Aggregation never emits more rows than it consumes.
        CardinalityEffect::LowerEqual
    }
1427
    /// Push down parent filters when possible (see implementation comment for details),
    /// and also pushdown self dynamic filters (see `AggrDynFilter` for details)
    fn gather_filters_for_pushdown(
        &self,
        phase: FilterPushdownPhase,
        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
        config: &ConfigOptions,
    ) -> Result<FilterDescription> {
        // It's safe to push down filters through aggregates when filters only reference
        // grouping columns, because such filters determine which groups to compute, not
        // *how* to compute them. Each group's aggregate values (SUM, COUNT, etc.) are
        // calculated from the same input rows regardless of whether we filter before or
        // after grouping - filtering before just eliminates entire groups early.
        // This optimization is NOT safe for filters on aggregated columns (like filtering on
        // the result of SUM or COUNT), as those require computing all groups first.

        let grouping_columns: HashSet<_> = self
            .group_by
            .expr()
            .iter()
            .flat_map(|(expr, _)| collect_columns(expr))
            .collect();

        // Analyze each filter separately to determine if it can be pushed down
        let mut safe_filters = Vec::new();
        let mut unsafe_filters = Vec::new();

        for filter in parent_filters {
            let filter_columns: HashSet<_> =
                collect_columns(&filter).into_iter().collect();

            // Check if this filter references non-grouping columns
            // (with no grouping columns at all, nothing can be referenced,
            // so the filter is treated as safe)
            let references_non_grouping = !grouping_columns.is_empty()
                && !filter_columns.is_subset(&grouping_columns);

            if references_non_grouping {
                unsafe_filters.push(filter);
                continue;
            }

            // For GROUPING SETS, verify this filter's columns appear in all grouping sets
            if self.group_by.groups().len() > 1 {
                // Map each filter column to the index of the first group-by
                // expression that references it.
                let filter_column_indices: Vec<usize> = filter_columns
                    .iter()
                    .filter_map(|filter_col| {
                        self.group_by.expr().iter().position(|(expr, _)| {
                            collect_columns(expr).contains(filter_col)
                        })
                    })
                    .collect();

                // Check if any of this filter's columns are missing from any grouping set
                // (a `true` entry in the null mask means the expression is
                // replaced by null for that set)
                let has_missing_column = self.group_by.groups().iter().any(|null_mask| {
                    filter_column_indices
                        .iter()
                        .any(|&idx| null_mask.get(idx) == Some(&true))
                });

                if has_missing_column {
                    unsafe_filters.push(filter);
                    continue;
                }
            }

            // This filter is safe to push down
            safe_filters.push(filter);
        }

        // Build child filter description with both safe and unsafe filters
        let child = self.children()[0];
        let mut child_desc = ChildFilterDescription::from_child(&safe_filters, child)?;

        // Add unsafe filters as unsupported
        child_desc.parent_filters.extend(
            unsafe_filters
                .into_iter()
                .map(PushedDownPredicate::unsupported),
        );

        // Include self dynamic filter when it's possible
        if phase == FilterPushdownPhase::Post
            && config.optimizer.enable_aggregate_dynamic_filter_pushdown
            && let Some(self_dyn_filter) = &self.dynamic_filter
        {
            let dyn_filter = Arc::clone(&self_dyn_filter.filter);
            child_desc = child_desc.with_self_filter(dyn_filter);
        }

        Ok(FilterDescription::new().with_child(child_desc))
    }
1518
    /// Inspects the result of pushing filters into the child and disables this
    /// node's dynamic filter when the child did not actually consume it.
    ///
    /// If the child accepted the self dynamic filter, `self.dynamic_filter`
    /// stays `Some`; otherwise a copy of this node with the filter cleared to
    /// `None` replaces it via the returned propagation result.
    fn handle_child_pushdown_result(
        &self,
        phase: FilterPushdownPhase,
        child_pushdown_result: ChildPushdownResult,
        _config: &ConfigOptions,
    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
        let mut result = FilterPushdownPropagation::if_any(child_pushdown_result.clone());

        // If this node tried to push down a dynamic filter before (Post phase
        // only), check whether the child accepted the filter.
        if phase == FilterPushdownPhase::Post
            && let Some(dyn_filter) = &self.dynamic_filter
        {
            // let child_accepts_dyn_filter = child_pushdown_result
            //     .self_filters
            //     .first()
            //     .map(|filters| {
            //         assert_eq_or_internal_err!(
            //             filters.len(),
            //             1,
            //             "Aggregate only pushdown one self dynamic filter"
            //         );
            //         let filter = filters.get(0).unwrap(); // Asserted above
            //         Ok(matches!(filter.discriminant, PushedDown::Yes))
            //     })
            //     .unwrap_or_else(|| internal_err!("The length of self filters equals to the number of child of this ExecutionPlan, so it must be 1"))?;

            // HACK: The above snippet should be used, however, now the child reply
            // `PushDown::No` can indicate they're not able to push down row-level
            // filter, but still keep the filter for statistics pruning.
            // So here, we try to use ref count to determine if the dynamic filter
            // has actually be pushed down.
            // Issue: <https://github.com/apache/datafusion/issues/18856>
            //
            // A strong count > 1 means someone below us retained a clone of the
            // filter Arc, i.e. the child kept (accepted) it.
            let child_accepts_dyn_filter = Arc::strong_count(dyn_filter) > 1;

            if !child_accepts_dyn_filter {
                // Child can't consume the self dynamic filter, so disable it by setting
                // to `None`
                let mut new_node = self.clone();
                new_node.dynamic_filter = None;

                result = result
                    .with_updated_node(Arc::new(new_node) as Arc<dyn ExecutionPlan>);
            }
        }

        Ok(result)
    }
1569}
1570
1571fn create_schema(
1572    input_schema: &Schema,
1573    group_by: &PhysicalGroupBy,
1574    aggr_expr: &[Arc<AggregateFunctionExpr>],
1575    mode: AggregateMode,
1576) -> Result<Schema> {
1577    let mut fields = Vec::with_capacity(group_by.num_output_exprs() + aggr_expr.len());
1578    fields.extend(group_by.output_fields(input_schema)?);
1579
1580    match mode.output_mode() {
1581        AggregateOutputMode::Final => {
1582            // in final mode, the field with the final result of the accumulator
1583            for expr in aggr_expr {
1584                fields.push(expr.field())
1585            }
1586        }
1587        AggregateOutputMode::Partial => {
1588            // in partial mode, the fields of the accumulator's state
1589            for expr in aggr_expr {
1590                fields.extend(expr.state_fields()?.iter().cloned());
1591            }
1592        }
1593    }
1594
1595    Ok(Schema::new_with_metadata(
1596        fields,
1597        input_schema.metadata().clone(),
1598    ))
1599}
1600
1601/// Determines the lexical ordering requirement for an aggregate expression.
1602///
1603/// # Parameters
1604///
1605/// - `aggr_expr`: A reference to an `AggregateFunctionExpr` representing the
1606///   aggregate expression.
1607/// - `group_by`: A reference to a `PhysicalGroupBy` instance representing the
1608///   physical GROUP BY expression.
1609/// - `agg_mode`: A reference to an `AggregateMode` instance representing the
1610///   mode of aggregation.
1611/// - `include_soft_requirement`: When `false`, only hard requirements are
1612///   considered, as indicated by [`AggregateFunctionExpr::order_sensitivity`]
1613///   returning [`AggregateOrderSensitivity::HardRequirement`].
1614///   Otherwise, also soft requirements ([`AggregateOrderSensitivity::SoftRequirement`])
1615///   are considered.
1616///
1617/// # Returns
1618///
1619/// A `LexOrdering` instance indicating the lexical ordering requirement for
1620/// the aggregate expression.
1621fn get_aggregate_expr_req(
1622    aggr_expr: &AggregateFunctionExpr,
1623    group_by: &PhysicalGroupBy,
1624    agg_mode: &AggregateMode,
1625    include_soft_requirement: bool,
1626) -> Option<LexOrdering> {
1627    // If the aggregation is performing a "second stage" calculation,
1628    // then ignore the ordering requirement. Ordering requirement applies
1629    // only to the aggregation input data.
1630    if agg_mode.input_mode() == AggregateInputMode::Partial {
1631        return None;
1632    }
1633
1634    match aggr_expr.order_sensitivity() {
1635        AggregateOrderSensitivity::Insensitive => return None,
1636        AggregateOrderSensitivity::HardRequirement => {}
1637        AggregateOrderSensitivity::SoftRequirement => {
1638            if !include_soft_requirement {
1639                return None;
1640            }
1641        }
1642        AggregateOrderSensitivity::Beneficial => return None,
1643    }
1644
1645    let mut sort_exprs = aggr_expr.order_bys().to_vec();
1646    // In non-first stage modes, we accumulate data (using `merge_batch`) from
1647    // different partitions (i.e. merge partial results). During this merge, we
1648    // consider the ordering of each partial result. Hence, we do not need to
1649    // use the ordering requirement in such modes as long as partial results are
1650    // generated with the correct ordering.
1651    if group_by.is_single() {
1652        // Remove all orderings that occur in the group by. These requirements
1653        // will definitely be satisfied -- Each group by expression will have
1654        // distinct values per group, hence all requirements are satisfied.
1655        let physical_exprs = group_by.input_exprs();
1656        sort_exprs.retain(|sort_expr| {
1657            !physical_exprs_contains(&physical_exprs, &sort_expr.expr)
1658        });
1659    }
1660    LexOrdering::new(sort_exprs)
1661}
1662
/// Concatenates the given slices into a freshly allocated `Vec`.
pub fn concat_slices<T: Clone>(lhs: &[T], rhs: &[T]) -> Vec<T> {
    // Reserve the exact total length up front, then copy both halves in order.
    let mut combined = Vec::with_capacity(lhs.len() + rhs.len());
    combined.extend_from_slice(lhs);
    combined.extend_from_slice(rhs);
    combined
}
1667
1668// Determines if the candidate ordering is finer than the current ordering.
1669// Returns `None` if they are incomparable, `Some(true)` if there is no current
1670// ordering or candidate ordering is finer, and `Some(false)` otherwise.
1671fn determine_finer(
1672    current: &Option<LexOrdering>,
1673    candidate: &LexOrdering,
1674) -> Option<bool> {
1675    if let Some(ordering) = current {
1676        candidate.partial_cmp(ordering).map(|cmp| cmp.is_gt())
1677    } else {
1678        Some(true)
1679    }
1680}
1681
/// Gets the common requirement that satisfies all the aggregate expressions.
/// When possible, chooses the requirement that is already satisfied by the
/// equivalence properties.
///
/// Order-sensitive aggregates whose requirement conflicts with the common one
/// may be swapped for their reverse expression (via
/// [`AggregateFunctionExpr::reverse_expr`]) in place, when the reverse
/// requirement is compatible.
///
/// # Parameters
///
/// - `aggr_exprs`: A slice of `AggregateFunctionExpr` containing all the
///   aggregate expressions.
/// - `group_by`: A reference to a `PhysicalGroupBy` instance representing the
///   physical GROUP BY expression.
/// - `eq_properties`: A reference to an `EquivalenceProperties` instance
///   representing equivalence properties for ordering.
/// - `agg_mode`: A reference to an `AggregateMode` instance representing the
///   mode of aggregation.
///
/// # Returns
///
/// A `Result<Vec<PhysicalSortRequirement>>` instance, which is the requirement
/// that satisfies all the aggregate requirements. Returns an error in case of
/// conflicting requirements.
pub fn get_finer_aggregate_exprs_requirement(
    aggr_exprs: &mut [Arc<AggregateFunctionExpr>],
    group_by: &PhysicalGroupBy,
    eq_properties: &EquivalenceProperties,
    agg_mode: &AggregateMode,
) -> Result<Vec<PhysicalSortRequirement>> {
    let mut requirement = None;

    // Two passes: the first pass (`include_soft_requirement == false`) only
    // matches hard requirements, and conflicts among them are an error. The
    // second pass additionally folds in soft requirements; conflicts that
    // involve soft requirements are ignored.
    for include_soft_requirement in [false, true] {
        for aggr_expr in aggr_exprs.iter_mut() {
            let Some(aggr_req) = get_aggregate_expr_req(
                aggr_expr,
                group_by,
                agg_mode,
                include_soft_requirement,
            )
            .and_then(|o| eq_properties.normalize_sort_exprs(o)) else {
                // There is no aggregate ordering requirement, or it is trivially
                // satisfied -- we can skip this expression.
                continue;
            };
            // If the common requirement is finer than the current expression's,
            // we can skip this expression. If the latter is finer than the former,
            // adopt it if it is satisfied by the equivalence properties. Otherwise,
            // defer the analysis to the reverse expression.
            let forward_finer = determine_finer(&requirement, &aggr_req);
            if let Some(finer) = forward_finer {
                if !finer {
                    continue;
                } else if eq_properties.ordering_satisfy(aggr_req.clone())? {
                    requirement = Some(aggr_req);
                    continue;
                }
            }
            if let Some(reverse_aggr_expr) = aggr_expr.reverse_expr() {
                let Some(rev_aggr_req) = get_aggregate_expr_req(
                    &reverse_aggr_expr,
                    group_by,
                    agg_mode,
                    include_soft_requirement,
                )
                .and_then(|o| eq_properties.normalize_sort_exprs(o)) else {
                    // The reverse requirement is trivially satisfied -- just reverse
                    // the expression and continue with the next one:
                    *aggr_expr = Arc::new(reverse_aggr_expr);
                    continue;
                };
                // If the common requirement is finer than the reverse expression's,
                // just reverse it and continue the loop with the next aggregate
                // expression. If the latter is finer than the former, adopt it if
                // it is satisfied by the equivalence properties. Otherwise, adopt
                // the forward expression.
                if let Some(finer) = determine_finer(&requirement, &rev_aggr_req) {
                    if !finer {
                        *aggr_expr = Arc::new(reverse_aggr_expr);
                    } else if eq_properties.ordering_satisfy(rev_aggr_req.clone())? {
                        *aggr_expr = Arc::new(reverse_aggr_expr);
                        requirement = Some(rev_aggr_req);
                    } else {
                        requirement = Some(aggr_req);
                    }
                } else if forward_finer.is_some() {
                    requirement = Some(aggr_req);
                } else {
                    // Neither the existing requirement nor the current aggregate
                    // requirement satisfy the other (forward or reverse), this
                    // means they are conflicting. This is a problem only for hard
                    // requirements. Unsatisfied soft requirements can be ignored.
                    if !include_soft_requirement {
                        return not_impl_err!(
                            "Conflicting ordering requirements in aggregate functions is not supported"
                        );
                    }
                }
            }
        }
    }

    // Convert the chosen ordering (if any) into per-column sort requirements.
    Ok(requirement.map_or_else(Vec::new, |o| o.into_iter().map(Into::into).collect()))
}
1785
1786/// Returns physical expressions for arguments to evaluate against a batch.
1787///
1788/// The expressions are different depending on `mode`:
1789/// * Partial: AggregateFunctionExpr::expressions
1790/// * Final: columns of `AggregateFunctionExpr::state_fields()`
1791pub fn aggregate_expressions(
1792    aggr_expr: &[Arc<AggregateFunctionExpr>],
1793    mode: &AggregateMode,
1794    col_idx_base: usize,
1795) -> Result<Vec<Vec<Arc<dyn PhysicalExpr>>>> {
1796    match mode.input_mode() {
1797        AggregateInputMode::Raw => Ok(aggr_expr
1798            .iter()
1799            .map(|agg| {
1800                let mut result = agg.expressions();
1801                // Append ordering requirements to expressions' results. This
1802                // way order sensitive aggregators can satisfy requirement
1803                // themselves.
1804                result.extend(agg.order_bys().iter().map(|item| Arc::clone(&item.expr)));
1805                result
1806            })
1807            .collect()),
1808        AggregateInputMode::Partial => {
1809            // In merge mode, we build the merge expressions of the aggregation.
1810            let mut col_idx_base = col_idx_base;
1811            aggr_expr
1812                .iter()
1813                .map(|agg| {
1814                    let exprs = merge_expressions(col_idx_base, agg)?;
1815                    col_idx_base += exprs.len();
1816                    Ok(exprs)
1817                })
1818                .collect()
1819        }
1820    }
1821}
1822
1823/// uses `state_fields` to build a vec of physical column expressions required to merge the
1824/// AggregateFunctionExpr' accumulator's state.
1825///
1826/// `index_base` is the starting physical column index for the next expanded state field.
1827fn merge_expressions(
1828    index_base: usize,
1829    expr: &AggregateFunctionExpr,
1830) -> Result<Vec<Arc<dyn PhysicalExpr>>> {
1831    expr.state_fields().map(|fields| {
1832        fields
1833            .iter()
1834            .enumerate()
1835            .map(|(idx, f)| Arc::new(Column::new(f.name(), index_base + idx)) as _)
1836            .collect()
1837    })
1838}
1839
/// A boxed, dynamically dispatched [`Accumulator`] instance.
pub type AccumulatorItem = Box<dyn Accumulator>;
1841
1842pub fn create_accumulators(
1843    aggr_expr: &[Arc<AggregateFunctionExpr>],
1844) -> Result<Vec<AccumulatorItem>> {
1845    aggr_expr
1846        .iter()
1847        .map(|expr| expr.create_accumulator())
1848        .collect()
1849}
1850
1851/// returns a vector of ArrayRefs, where each entry corresponds to either the
1852/// final value (mode = Final, FinalPartitioned and Single) or states (mode = Partial)
1853pub fn finalize_aggregation(
1854    accumulators: &mut [AccumulatorItem],
1855    mode: &AggregateMode,
1856) -> Result<Vec<ArrayRef>> {
1857    match mode.output_mode() {
1858        AggregateOutputMode::Final => {
1859            // Merge the state to the final value
1860            accumulators
1861                .iter_mut()
1862                .map(|accumulator| accumulator.evaluate().and_then(|v| v.to_array()))
1863                .collect()
1864        }
1865        AggregateOutputMode::Partial => {
1866            // Build the vector of states
1867            accumulators
1868                .iter_mut()
1869                .map(|accumulator| {
1870                    accumulator.state().and_then(|e| {
1871                        e.iter()
1872                            .map(|v| v.to_array())
1873                            .collect::<Result<Vec<ArrayRef>>>()
1874                    })
1875                })
1876                .flatten_ok()
1877                .collect()
1878        }
1879    }
1880}
1881
1882/// Evaluates groups of expressions against a record batch.
1883pub fn evaluate_many(
1884    expr: &[Vec<Arc<dyn PhysicalExpr>>],
1885    batch: &RecordBatch,
1886) -> Result<Vec<Vec<ArrayRef>>> {
1887    expr.iter()
1888        .map(|expr| evaluate_expressions_to_arrays(expr, batch))
1889        .collect()
1890}
1891
1892fn evaluate_optional(
1893    expr: &[Option<Arc<dyn PhysicalExpr>>],
1894    batch: &RecordBatch,
1895) -> Result<Vec<Option<ArrayRef>>> {
1896    expr.iter()
1897        .map(|expr| {
1898            expr.as_ref()
1899                .map(|expr| {
1900                    expr.evaluate(batch)
1901                        .and_then(|v| v.into_array(batch.num_rows()))
1902                })
1903                .transpose()
1904        })
1905        .collect()
1906}
1907
1908fn group_id_array(group: &[bool], batch: &RecordBatch) -> Result<ArrayRef> {
1909    if group.len() > 64 {
1910        return not_impl_err!(
1911            "Grouping sets with more than 64 columns are not supported"
1912        );
1913    }
1914    let group_id = group.iter().fold(0u64, |acc, &is_null| {
1915        (acc << 1) | if is_null { 1 } else { 0 }
1916    });
1917    let num_rows = batch.num_rows();
1918    if group.len() <= 8 {
1919        Ok(Arc::new(UInt8Array::from(vec![group_id as u8; num_rows])))
1920    } else if group.len() <= 16 {
1921        Ok(Arc::new(UInt16Array::from(vec![group_id as u16; num_rows])))
1922    } else if group.len() <= 32 {
1923        Ok(Arc::new(UInt32Array::from(vec![group_id as u32; num_rows])))
1924    } else {
1925        Ok(Arc::new(UInt64Array::from(vec![group_id; num_rows])))
1926    }
1927}
1928
1929/// Evaluate a group by expression against a `RecordBatch`
1930///
1931/// Arguments:
1932/// - `group_by`: the expression to evaluate
1933/// - `batch`: the `RecordBatch` to evaluate against
1934///
1935/// Returns: A Vec of Vecs of Array of results
1936/// The outer Vec appears to be for grouping sets
1937/// The inner Vec contains the results per expression
1938/// The inner-inner Array contains the results per row
1939pub fn evaluate_group_by(
1940    group_by: &PhysicalGroupBy,
1941    batch: &RecordBatch,
1942) -> Result<Vec<Vec<ArrayRef>>> {
1943    let exprs = evaluate_expressions_to_arrays(
1944        group_by.expr.iter().map(|(expr, _)| expr),
1945        batch,
1946    )?;
1947    let null_exprs = evaluate_expressions_to_arrays(
1948        group_by.null_expr.iter().map(|(expr, _)| expr),
1949        batch,
1950    )?;
1951
1952    group_by
1953        .groups
1954        .iter()
1955        .map(|group| {
1956            let mut group_values = Vec::with_capacity(group_by.num_group_exprs());
1957            group_values.extend(group.iter().enumerate().map(|(idx, is_null)| {
1958                if *is_null {
1959                    Arc::clone(&null_exprs[idx])
1960                } else {
1961                    Arc::clone(&exprs[idx])
1962                }
1963            }));
1964            if !group_by.is_single() {
1965                group_values.push(group_id_array(group, batch)?);
1966            }
1967            Ok(group_values)
1968        })
1969        .collect()
1970}
1971
1972#[cfg(test)]
1973mod tests {
1974    use std::task::{Context, Poll};
1975
1976    use super::*;
1977    use crate::RecordBatchStream;
1978    use crate::coalesce_partitions::CoalescePartitionsExec;
1979    use crate::common;
1980    use crate::common::collect;
1981    use crate::execution_plan::Boundedness;
1982    use crate::expressions::col;
1983    use crate::metrics::MetricValue;
1984    use crate::test::TestMemoryExec;
1985    use crate::test::assert_is_pending;
1986    use crate::test::exec::{BlockingExec, assert_strong_count_converges_to_zero};
1987
1988    use arrow::array::{
1989        DictionaryArray, Float32Array, Float64Array, Int32Array, Int64Array, StructArray,
1990        UInt32Array, UInt64Array,
1991    };
1992    use arrow::compute::{SortOptions, concat_batches};
1993    use arrow::datatypes::{DataType, Int32Type};
1994    use datafusion_common::test_util::{batches_to_sort_string, batches_to_string};
1995    use datafusion_common::{DataFusionError, ScalarValue, internal_err};
1996    use datafusion_execution::config::SessionConfig;
1997    use datafusion_execution::memory_pool::FairSpillPool;
1998    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
1999    use datafusion_functions_aggregate::array_agg::array_agg_udaf;
2000    use datafusion_functions_aggregate::average::avg_udaf;
2001    use datafusion_functions_aggregate::count::count_udaf;
2002    use datafusion_functions_aggregate::first_last::{first_value_udaf, last_value_udaf};
2003    use datafusion_functions_aggregate::median::median_udaf;
2004    use datafusion_functions_aggregate::sum::sum_udaf;
2005    use datafusion_physical_expr::Partitioning;
2006    use datafusion_physical_expr::PhysicalSortExpr;
2007    use datafusion_physical_expr::aggregate::AggregateExprBuilder;
2008    use datafusion_physical_expr::expressions::Literal;
2009    use datafusion_physical_expr::expressions::lit;
2010
2011    use crate::projection::ProjectionExec;
2012    use datafusion_physical_expr::projection::ProjectionExpr;
2013    use futures::{FutureExt, Stream};
2014    use insta::{allow_duplicates, assert_snapshot};
2015
2016    // Generate a schema which consists of 5 columns (a, b, c, d, e)
2017    fn create_test_schema() -> Result<SchemaRef> {
2018        let a = Field::new("a", DataType::Int32, true);
2019        let b = Field::new("b", DataType::Int32, true);
2020        let c = Field::new("c", DataType::Int32, true);
2021        let d = Field::new("d", DataType::Int32, true);
2022        let e = Field::new("e", DataType::Int32, true);
2023        let schema = Arc::new(Schema::new(vec![a, b, c, d, e]));
2024
2025        Ok(schema)
2026    }
2027
2028    /// some mock data to aggregates
2029    fn some_data() -> (Arc<Schema>, Vec<RecordBatch>) {
2030        // define a schema.
2031        let schema = Arc::new(Schema::new(vec![
2032            Field::new("a", DataType::UInt32, false),
2033            Field::new("b", DataType::Float64, false),
2034        ]));
2035
2036        // define data.
2037        (
2038            Arc::clone(&schema),
2039            vec![
2040                RecordBatch::try_new(
2041                    Arc::clone(&schema),
2042                    vec![
2043                        Arc::new(UInt32Array::from(vec![2, 3, 4, 4])),
2044                        Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
2045                    ],
2046                )
2047                .unwrap(),
2048                RecordBatch::try_new(
2049                    schema,
2050                    vec![
2051                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
2052                        Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
2053                    ],
2054                )
2055                .unwrap(),
2056            ],
2057        )
2058    }
2059
2060    /// Generates some mock data for aggregate tests.
2061    fn some_data_v2() -> (Arc<Schema>, Vec<RecordBatch>) {
2062        // Define a schema:
2063        let schema = Arc::new(Schema::new(vec![
2064            Field::new("a", DataType::UInt32, false),
2065            Field::new("b", DataType::Float64, false),
2066        ]));
2067
2068        // Generate data so that first and last value results are at 2nd and
2069        // 3rd partitions.  With this construction, we guarantee we don't receive
2070        // the expected result by accident, but merging actually works properly;
2071        // i.e. it doesn't depend on the data insertion order.
2072        (
2073            Arc::clone(&schema),
2074            vec![
2075                RecordBatch::try_new(
2076                    Arc::clone(&schema),
2077                    vec![
2078                        Arc::new(UInt32Array::from(vec![2, 3, 4, 4])),
2079                        Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
2080                    ],
2081                )
2082                .unwrap(),
2083                RecordBatch::try_new(
2084                    Arc::clone(&schema),
2085                    vec![
2086                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
2087                        Arc::new(Float64Array::from(vec![0.0, 1.0, 2.0, 3.0])),
2088                    ],
2089                )
2090                .unwrap(),
2091                RecordBatch::try_new(
2092                    Arc::clone(&schema),
2093                    vec![
2094                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
2095                        Arc::new(Float64Array::from(vec![3.0, 4.0, 5.0, 6.0])),
2096                    ],
2097                )
2098                .unwrap(),
2099                RecordBatch::try_new(
2100                    schema,
2101                    vec![
2102                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
2103                        Arc::new(Float64Array::from(vec![2.0, 3.0, 4.0, 5.0])),
2104                    ],
2105                )
2106                .unwrap(),
2107            ],
2108        )
2109    }
2110
2111    fn new_spill_ctx(batch_size: usize, max_memory: usize) -> Arc<TaskContext> {
2112        let session_config = SessionConfig::new().with_batch_size(batch_size);
2113        let runtime = RuntimeEnvBuilder::new()
2114            .with_memory_pool(Arc::new(FairSpillPool::new(max_memory)))
2115            .build_arc()
2116            .unwrap();
2117        let task_ctx = TaskContext::default()
2118            .with_session_config(session_config)
2119            .with_runtime(runtime);
2120        Arc::new(task_ctx)
2121    }
2122
    /// Runs a two-stage (Partial then Final) aggregation with grouping sets
    /// ((a), (b), (a, b)) and a `COUNT(1)` aggregate over `input`, checking
    /// both stages' output with snapshots. When `spill` is true, the task
    /// contexts use a memory pool sized to force the spill/early-emit path.
    async fn check_grouping_sets(
        input: Arc<dyn ExecutionPlan>,
        spill: bool,
    ) -> Result<()> {
        let input_schema = input.schema();

        let grouping_set = PhysicalGroupBy::new(
            vec![
                (col("a", &input_schema)?, "a".to_string()),
                (col("b", &input_schema)?, "b".to_string()),
            ],
            vec![
                (lit(ScalarValue::UInt32(None)), "a".to_string()),
                (lit(ScalarValue::Float64(None)), "b".to_string()),
            ],
            vec![
                vec![false, true],  // (a, NULL)
                vec![true, false],  // (NULL, b)
                vec![false, false], // (a,b)
            ],
            true,
        );

        let aggregates = vec![Arc::new(
            AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)])
                .schema(Arc::clone(&input_schema))
                .alias("COUNT(1)")
                .build()?,
        )];

        let task_ctx = if spill {
            // adjust the max memory size to have the partial aggregate result for spill mode.
            new_spill_ctx(4, 500)
        } else {
            Arc::new(TaskContext::default())
        };

        let partial_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            grouping_set.clone(),
            aggregates.clone(),
            vec![None],
            input,
            Arc::clone(&input_schema),
        )?);

        let result =
            collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;

        if spill {
            // In spill mode, we test with the limited memory, if the mem usage exceeds,
            // we trigger the early emit rule, which turns out the partial aggregate result.
            allow_duplicates! {
            assert_snapshot!(batches_to_sort_string(&result),
            @r"
            +---+-----+---------------+-----------------+
            | a | b   | __grouping_id | COUNT(1)[count] |
            +---+-----+---------------+-----------------+
            |   | 1.0 | 2             | 1               |
            |   | 1.0 | 2             | 1               |
            |   | 2.0 | 2             | 1               |
            |   | 2.0 | 2             | 1               |
            |   | 3.0 | 2             | 1               |
            |   | 3.0 | 2             | 1               |
            |   | 4.0 | 2             | 1               |
            |   | 4.0 | 2             | 1               |
            | 2 |     | 1             | 1               |
            | 2 |     | 1             | 1               |
            | 2 | 1.0 | 0             | 1               |
            | 2 | 1.0 | 0             | 1               |
            | 3 |     | 1             | 1               |
            | 3 |     | 1             | 2               |
            | 3 | 2.0 | 0             | 2               |
            | 3 | 3.0 | 0             | 1               |
            | 4 |     | 1             | 1               |
            | 4 |     | 1             | 2               |
            | 4 | 3.0 | 0             | 1               |
            | 4 | 4.0 | 0             | 2               |
            +---+-----+---------------+-----------------+
            "
            );
            }
        } else {
            allow_duplicates! {
            assert_snapshot!(batches_to_sort_string(&result),
            @r"
            +---+-----+---------------+-----------------+
            | a | b   | __grouping_id | COUNT(1)[count] |
            +---+-----+---------------+-----------------+
            |   | 1.0 | 2             | 2               |
            |   | 2.0 | 2             | 2               |
            |   | 3.0 | 2             | 2               |
            |   | 4.0 | 2             | 2               |
            | 2 |     | 1             | 2               |
            | 2 | 1.0 | 0             | 2               |
            | 3 |     | 1             | 3               |
            | 3 | 2.0 | 0             | 2               |
            | 3 | 3.0 | 0             | 1               |
            | 4 |     | 1             | 3               |
            | 4 | 3.0 | 0             | 1               |
            | 4 | 4.0 | 0             | 2               |
            +---+-----+---------------+-----------------+
            "
            );
            }
        };

        // Coalesce all partial partitions into one stream for the final stage.
        let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));

        let final_grouping_set = grouping_set.as_final();

        let task_ctx = if spill {
            new_spill_ctx(4, 3160)
        } else {
            task_ctx
        };

        let merged_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Final,
            final_grouping_set,
            aggregates,
            vec![None],
            merge,
            input_schema,
        )?);

        let result = collect(merged_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
        let batch = concat_batches(&result[0].schema(), &result)?;
        assert_eq!(batch.num_columns(), 4);
        assert_eq!(batch.num_rows(), 12);

        allow_duplicates! {
        assert_snapshot!(
            batches_to_sort_string(&result),
            @r"
        +---+-----+---------------+----------+
        | a | b   | __grouping_id | COUNT(1) |
        +---+-----+---------------+----------+
        |   | 1.0 | 2             | 2        |
        |   | 2.0 | 2             | 2        |
        |   | 3.0 | 2             | 2        |
        |   | 4.0 | 2             | 2        |
        | 2 |     | 1             | 2        |
        | 2 | 1.0 | 0             | 2        |
        | 3 |     | 1             | 3        |
        | 3 | 2.0 | 0             | 2        |
        | 3 | 3.0 | 0             | 1        |
        | 4 |     | 1             | 3        |
        | 4 | 3.0 | 0             | 1        |
        | 4 | 4.0 | 0             | 2        |
        +---+-----+---------------+----------+
        "
        );
        }

        let metrics = merged_aggregate.metrics().unwrap();
        let output_rows = metrics.output_rows().unwrap();
        assert_eq!(12, output_rows);

        Ok(())
    }
2284
    /// Build the aggregates on the data from `some_data()` and check the results.
    ///
    /// Constructs a two-phase plan — `AggregateExec(Partial)` ->
    /// `CoalescePartitionsExec` -> `AggregateExec(Final)` — computing
    /// `AVG(b) GROUP BY a`, then verifies the intermediate (partial) output,
    /// the final output, and the execution metrics.
    ///
    /// When `spill` is true, both phases run under tight memory limits so the
    /// partial phase is forced to spill; the partial output then contains
    /// duplicate group keys (see the spill snapshot below) and the spill
    /// metrics must be non-zero.
    async fn check_aggregates(input: Arc<dyn ExecutionPlan>, spill: bool) -> Result<()> {
        let input_schema = input.schema();

        // GROUP BY a: a single plain group key (no grouping sets).
        let grouping_set = PhysicalGroupBy::new(
            vec![(col("a", &input_schema)?, "a".to_string())],
            vec![],
            vec![vec![false]],
            false,
        );

        // AVG(b): in Partial mode this emits two state columns,
        // `AVG(b)[count]` and `AVG(b)[sum]` (see snapshots below).
        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
            AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?])
                .schema(Arc::clone(&input_schema))
                .alias("AVG(b)")
                .build()?,
        )];

        let task_ctx = if spill {
            // set to an appropriate value to trigger spill
            new_spill_ctx(2, 1600)
        } else {
            Arc::new(TaskContext::default())
        };

        // Phase 1: partial aggregation directly over the input.
        let partial_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            grouping_set.clone(),
            aggregates.clone(),
            vec![None],
            input,
            Arc::clone(&input_schema),
        )?);

        let result =
            collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;

        if spill {
            // Spilling causes groups to be re-emitted per chunk, hence the
            // duplicate keys (e.g. two rows for a = 2) in the partial output.
            allow_duplicates! {
            assert_snapshot!(batches_to_sort_string(&result), @r"
            +---+---------------+-------------+
            | a | AVG(b)[count] | AVG(b)[sum] |
            +---+---------------+-------------+
            | 2 | 1             | 1.0         |
            | 2 | 1             | 1.0         |
            | 3 | 1             | 2.0         |
            | 3 | 2             | 5.0         |
            | 4 | 3             | 11.0        |
            +---+---------------+-------------+
            ");
            }
        } else {
            allow_duplicates! {
            assert_snapshot!(batches_to_sort_string(&result), @r"
            +---+---------------+-------------+
            | a | AVG(b)[count] | AVG(b)[sum] |
            +---+---------------+-------------+
            | 2 | 2             | 2.0         |
            | 3 | 3             | 7.0         |
            | 4 | 3             | 11.0        |
            +---+---------------+-------------+
            ");
            }
        };

        // Collapse all partitions into one before the final phase.
        let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));

        let final_grouping_set = grouping_set.as_final();

        let task_ctx = if spill {
            new_spill_ctx(4, 3160)
        } else {
            task_ctx
        };

        // Phase 2: final aggregation over the merged partial states.
        let merged_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Final,
            final_grouping_set,
            aggregates,
            vec![None],
            merge,
            input_schema,
        )?);

        let result = collect(merged_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
        let batch = concat_batches(&result[0].schema(), &result)?;
        assert_eq!(batch.num_columns(), 4);
        assert_eq!(batch.num_rows(), 12);

        allow_duplicates! {
        assert_snapshot!(
            batches_to_sort_string(&result),
            @r"
        +---+-----+---------------+----------+
        | a | b   | __grouping_id | COUNT(1) |
        +---+-----+---------------+----------+
        |   | 1.0 | 2             | 2        |
        |   | 2.0 | 2             | 2        |
        |   | 3.0 | 2             | 2        |
        |   | 4.0 | 2             | 2        |
        | 2 |     | 1             | 2        |
        | 2 | 1.0 | 0             | 2        |
        | 3 |     | 1             | 3        |
        | 3 | 2.0 | 0             | 2        |
        | 3 | 3.0 | 0             | 1        |
        | 4 |     | 1             | 3        |
        | 4 | 3.0 | 0             | 1        |
        | 4 | 4.0 | 0             | 2        |
        +---+-----+---------------+----------+
        "
        );
        }

        let metrics = merged_aggregate.metrics().unwrap();
        let output_rows = metrics.output_rows().unwrap();
        assert_eq!(12, output_rows);

        Ok(())
    }
2416
    //--- Define a test source that can yield back to the runtime before returning its first item ---//
2418
    /// Test [`ExecutionPlan`] over the `some_data()` batches whose stream can
    /// optionally yield back to the runtime once before producing its first
    /// batch (see [`TestYieldingStream`]).
    #[derive(Debug)]
    struct TestYieldingExec {
        /// True if this exec should yield back to runtime the first time it is polled
        pub yield_first: bool,
        // Cached plan properties (schema, partitioning, emission, boundedness).
        cache: Arc<PlanProperties>,
    }
2425
2426    impl TestYieldingExec {
2427        fn new(yield_first: bool) -> Self {
2428            let schema = some_data().0;
2429            let cache = Self::compute_properties(schema);
2430            Self {
2431                yield_first,
2432                cache: Arc::new(cache),
2433            }
2434        }
2435
2436        /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
2437        fn compute_properties(schema: SchemaRef) -> PlanProperties {
2438            PlanProperties::new(
2439                EquivalenceProperties::new(schema),
2440                Partitioning::UnknownPartitioning(1),
2441                EmissionType::Incremental,
2442                Boundedness::Bounded,
2443            )
2444        }
2445    }
2446
2447    impl DisplayAs for TestYieldingExec {
2448        fn fmt_as(
2449            &self,
2450            t: DisplayFormatType,
2451            f: &mut std::fmt::Formatter,
2452        ) -> std::fmt::Result {
2453            match t {
2454                DisplayFormatType::Default | DisplayFormatType::Verbose => {
2455                    write!(f, "TestYieldingExec")
2456                }
2457                DisplayFormatType::TreeRender => {
2458                    // TODO: collect info
2459                    write!(f, "")
2460                }
2461            }
2462        }
2463    }
2464
2465    impl ExecutionPlan for TestYieldingExec {
2466        fn name(&self) -> &'static str {
2467            "TestYieldingExec"
2468        }
2469
2470        fn as_any(&self) -> &dyn Any {
2471            self
2472        }
2473
2474        fn properties(&self) -> &Arc<PlanProperties> {
2475            &self.cache
2476        }
2477
2478        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
2479            vec![]
2480        }
2481
2482        fn with_new_children(
2483            self: Arc<Self>,
2484            _: Vec<Arc<dyn ExecutionPlan>>,
2485        ) -> Result<Arc<dyn ExecutionPlan>> {
2486            internal_err!("Children cannot be replaced in {self:?}")
2487        }
2488
2489        fn execute(
2490            &self,
2491            _partition: usize,
2492            _context: Arc<TaskContext>,
2493        ) -> Result<SendableRecordBatchStream> {
2494            let stream = if self.yield_first {
2495                TestYieldingStream::New
2496            } else {
2497                TestYieldingStream::Yielded
2498            };
2499
2500            Ok(Box::pin(stream))
2501        }
2502
2503        fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
2504            if partition.is_some() {
2505                return Ok(Statistics::new_unknown(self.schema().as_ref()));
2506            }
2507            let (_, batches) = some_data();
2508            Ok(common::compute_record_batch_statistics(
2509                &[batches],
2510                &self.schema(),
2511                None,
2512            ))
2513        }
2514    }
2515
    /// A stream using the demo data. If initialized as `New`, it will first
    /// yield to the runtime before returning records, advancing through the
    /// states below one poll at a time.
    enum TestYieldingStream {
        // Not yet polled; first poll returns `Pending` and wakes itself.
        New,
        // Ready to emit the first demo batch.
        Yielded,
        // First batch emitted; second batch is next.
        ReturnedBatch1,
        // Both batches emitted; the stream is exhausted.
        ReturnedBatch2,
    }
2523
2524    impl Stream for TestYieldingStream {
2525        type Item = Result<RecordBatch>;
2526
2527        fn poll_next(
2528            mut self: std::pin::Pin<&mut Self>,
2529            cx: &mut Context<'_>,
2530        ) -> Poll<Option<Self::Item>> {
2531            match &*self {
2532                TestYieldingStream::New => {
2533                    *(self.as_mut()) = TestYieldingStream::Yielded;
2534                    cx.waker().wake_by_ref();
2535                    Poll::Pending
2536                }
2537                TestYieldingStream::Yielded => {
2538                    *(self.as_mut()) = TestYieldingStream::ReturnedBatch1;
2539                    Poll::Ready(Some(Ok(some_data().1[0].clone())))
2540                }
2541                TestYieldingStream::ReturnedBatch1 => {
2542                    *(self.as_mut()) = TestYieldingStream::ReturnedBatch2;
2543                    Poll::Ready(Some(Ok(some_data().1[1].clone())))
2544                }
2545                TestYieldingStream::ReturnedBatch2 => Poll::Ready(None),
2546            }
2547        }
2548    }
2549
2550    impl RecordBatchStream for TestYieldingStream {
2551        fn schema(&self) -> SchemaRef {
2552            some_data().0
2553        }
2554    }
2555
2556    //--- Tests ---//
2557
2558    #[tokio::test]
2559    async fn aggregate_source_not_yielding() -> Result<()> {
2560        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
2561
2562        check_aggregates(input, false).await
2563    }
2564
2565    #[tokio::test]
2566    async fn aggregate_grouping_sets_source_not_yielding() -> Result<()> {
2567        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
2568
2569        check_grouping_sets(input, false).await
2570    }
2571
2572    #[tokio::test]
2573    async fn aggregate_source_with_yielding() -> Result<()> {
2574        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
2575
2576        check_aggregates(input, false).await
2577    }
2578
2579    #[tokio::test]
2580    async fn aggregate_grouping_sets_with_yielding() -> Result<()> {
2581        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
2582
2583        check_grouping_sets(input, false).await
2584    }
2585
2586    #[tokio::test]
2587    async fn aggregate_source_not_yielding_with_spill() -> Result<()> {
2588        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
2589
2590        check_aggregates(input, true).await
2591    }
2592
2593    #[tokio::test]
2594    async fn aggregate_grouping_sets_source_not_yielding_with_spill() -> Result<()> {
2595        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
2596
2597        check_grouping_sets(input, true).await
2598    }
2599
2600    #[tokio::test]
2601    async fn aggregate_source_with_yielding_with_spill() -> Result<()> {
2602        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
2603
2604        check_aggregates(input, true).await
2605    }
2606
2607    #[tokio::test]
2608    async fn aggregate_grouping_sets_with_yielding_with_spill() -> Result<()> {
2609        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
2610
2611        check_grouping_sets(input, true).await
2612    }
2613
2614    // Median(a)
2615    fn test_median_agg_expr(schema: SchemaRef) -> Result<AggregateFunctionExpr> {
2616        AggregateExprBuilder::new(median_udaf(), vec![col("a", &schema)?])
2617            .schema(schema)
2618            .alias("MEDIAN(a)")
2619            .build()
2620    }
2621
2622    #[tokio::test]
2623    async fn test_oom() -> Result<()> {
2624        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
2625        let input_schema = input.schema();
2626
2627        let runtime = RuntimeEnvBuilder::new()
2628            .with_memory_limit(1, 1.0)
2629            .build_arc()?;
2630        let task_ctx = TaskContext::default().with_runtime(runtime);
2631        let task_ctx = Arc::new(task_ctx);
2632
2633        let groups_none = PhysicalGroupBy::default();
2634        let groups_some = PhysicalGroupBy::new(
2635            vec![(col("a", &input_schema)?, "a".to_string())],
2636            vec![],
2637            vec![vec![false]],
2638            false,
2639        );
2640
2641        // something that allocates within the aggregator
2642        let aggregates_v0: Vec<Arc<AggregateFunctionExpr>> =
2643            vec![Arc::new(test_median_agg_expr(Arc::clone(&input_schema))?)];
2644
2645        // use fast-path in `row_hash.rs`.
2646        let aggregates_v2: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
2647            AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?])
2648                .schema(Arc::clone(&input_schema))
2649                .alias("AVG(b)")
2650                .build()?,
2651        )];
2652
2653        for (version, groups, aggregates) in [
2654            (0, groups_none, aggregates_v0),
2655            (2, groups_some, aggregates_v2),
2656        ] {
2657            let n_aggr = aggregates.len();
2658            let partial_aggregate = Arc::new(AggregateExec::try_new(
2659                AggregateMode::Single,
2660                groups,
2661                aggregates,
2662                vec![None; n_aggr],
2663                Arc::clone(&input),
2664                Arc::clone(&input_schema),
2665            )?);
2666
2667            let stream = partial_aggregate.execute_typed(0, &task_ctx)?;
2668
2669            // ensure that we really got the version we wanted
2670            match version {
2671                0 => {
2672                    assert!(matches!(stream, StreamType::AggregateStream(_)));
2673                }
2674                1 => {
2675                    assert!(matches!(stream, StreamType::GroupedHash(_)));
2676                }
2677                2 => {
2678                    assert!(matches!(stream, StreamType::GroupedHash(_)));
2679                }
2680                _ => panic!("Unknown version: {version}"),
2681            }
2682
2683            let stream: SendableRecordBatchStream = stream.into();
2684            let err = collect(stream).await.unwrap_err();
2685
2686            // error root cause traversal is a bit complicated, see #4172.
2687            let err = err.find_root();
2688            assert!(
2689                matches!(err, DataFusionError::ResourcesExhausted(_)),
2690                "Wrong error type: {err}",
2691            );
2692        }
2693
2694        Ok(())
2695    }
2696
2697    #[tokio::test]
2698    async fn test_drop_cancel_without_groups() -> Result<()> {
2699        let task_ctx = Arc::new(TaskContext::default());
2700        let schema =
2701            Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, true)]));
2702
2703        let groups = PhysicalGroupBy::default();
2704
2705        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
2706            AggregateExprBuilder::new(avg_udaf(), vec![col("a", &schema)?])
2707                .schema(Arc::clone(&schema))
2708                .alias("AVG(a)")
2709                .build()?,
2710        )];
2711
2712        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
2713        let refs = blocking_exec.refs();
2714        let aggregate_exec = Arc::new(AggregateExec::try_new(
2715            AggregateMode::Partial,
2716            groups.clone(),
2717            aggregates.clone(),
2718            vec![None],
2719            blocking_exec,
2720            schema,
2721        )?);
2722
2723        let fut = crate::collect(aggregate_exec, task_ctx);
2724        let mut fut = fut.boxed();
2725
2726        assert_is_pending(&mut fut);
2727        drop(fut);
2728        assert_strong_count_converges_to_zero(refs).await;
2729
2730        Ok(())
2731    }
2732
2733    #[tokio::test]
2734    async fn test_drop_cancel_with_groups() -> Result<()> {
2735        let task_ctx = Arc::new(TaskContext::default());
2736        let schema = Arc::new(Schema::new(vec![
2737            Field::new("a", DataType::Float64, true),
2738            Field::new("b", DataType::Float64, true),
2739        ]));
2740
2741        let groups =
2742            PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);
2743
2744        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
2745            AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
2746                .schema(Arc::clone(&schema))
2747                .alias("AVG(b)")
2748                .build()?,
2749        )];
2750
2751        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
2752        let refs = blocking_exec.refs();
2753        let aggregate_exec = Arc::new(AggregateExec::try_new(
2754            AggregateMode::Partial,
2755            groups,
2756            aggregates.clone(),
2757            vec![None],
2758            blocking_exec,
2759            schema,
2760        )?);
2761
2762        let fut = crate::collect(aggregate_exec, task_ctx);
2763        let mut fut = fut.boxed();
2764
2765        assert_is_pending(&mut fut);
2766        drop(fut);
2767        assert_strong_count_converges_to_zero(refs).await;
2768
2769        Ok(())
2770    }
2771
2772    #[tokio::test]
2773    async fn run_first_last_multi_partitions() -> Result<()> {
2774        for is_first_acc in [false, true] {
2775            for spill in [false, true] {
2776                first_last_multi_partitions(is_first_acc, spill, 4200).await?
2777            }
2778        }
2779        Ok(())
2780    }
2781
2782    // FIRST_VALUE(b ORDER BY b <SortOptions>)
2783    fn test_first_value_agg_expr(
2784        schema: &Schema,
2785        sort_options: SortOptions,
2786    ) -> Result<Arc<AggregateFunctionExpr>> {
2787        let order_bys = vec![PhysicalSortExpr {
2788            expr: col("b", schema)?,
2789            options: sort_options,
2790        }];
2791        let args = [col("b", schema)?];
2792
2793        AggregateExprBuilder::new(first_value_udaf(), args.to_vec())
2794            .order_by(order_bys)
2795            .schema(Arc::new(schema.clone()))
2796            .alias(String::from("first_value(b) ORDER BY [b ASC NULLS LAST]"))
2797            .build()
2798            .map(Arc::new)
2799    }
2800
2801    // LAST_VALUE(b ORDER BY b <SortOptions>)
2802    fn test_last_value_agg_expr(
2803        schema: &Schema,
2804        sort_options: SortOptions,
2805    ) -> Result<Arc<AggregateFunctionExpr>> {
2806        let order_bys = vec![PhysicalSortExpr {
2807            expr: col("b", schema)?,
2808            options: sort_options,
2809        }];
2810        let args = [col("b", schema)?];
2811        AggregateExprBuilder::new(last_value_udaf(), args.to_vec())
2812            .order_by(order_bys)
2813            .schema(Arc::new(schema.clone()))
2814            .alias(String::from("last_value(b) ORDER BY [b ASC NULLS LAST]"))
2815            .build()
2816            .map(Arc::new)
2817    }
2818
    // This function constructs the physical plan below,
    //
    // "AggregateExec: mode=Final, gby=[a@0 as a], aggr=[FIRST_VALUE(b)]",
    // "  CoalescePartitionsExec",
    // "    AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[FIRST_VALUE(b)], ordering_mode=None",
    // "      DataSourceExec: partitions=4, partition_sizes=[1, 1, 1, 1]",
    //
    // and checks whether the function `merge_batch` works correctly for
    // FIRST_VALUE and LAST_VALUE functions.
    //
    // Parameters:
    // - `is_first_acc`: test FIRST_VALUE when true, LAST_VALUE otherwise
    // - `spill`: when true, the plan runs under a memory limit of
    //   `max_memory` bytes to force spilling
    // - `max_memory`: memory budget (in bytes) used when `spill` is true
    async fn first_last_multi_partitions(
        is_first_acc: bool,
        spill: bool,
        max_memory: usize,
    ) -> Result<()> {
        let task_ctx = if spill {
            new_spill_ctx(2, max_memory)
        } else {
            Arc::new(TaskContext::default())
        };

        // Four single-batch partitions of the demo data.
        let (schema, data) = some_data_v2();
        let partition1 = data[0].clone();
        let partition2 = data[1].clone();
        let partition3 = data[2].clone();
        let partition4 = data[3].clone();

        // GROUP BY a
        let groups =
            PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);

        // ORDER BY b ASC NULLS LAST
        let sort_options = SortOptions {
            descending: false,
            nulls_first: false,
        };
        let aggregates: Vec<Arc<AggregateFunctionExpr>> = if is_first_acc {
            vec![test_first_value_agg_expr(&schema, sort_options)?]
        } else {
            vec![test_last_value_agg_expr(&schema, sort_options)?]
        };

        let memory_exec = TestMemoryExec::try_new_exec(
            &[
                vec![partition1],
                vec![partition2],
                vec![partition3],
                vec![partition4],
            ],
            Arc::clone(&schema),
            None,
        )?;
        // Partial aggregation per partition ...
        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            groups.clone(),
            aggregates.clone(),
            vec![None],
            memory_exec,
            Arc::clone(&schema),
        )?);
        // ... merged into one partition ...
        let coalesce = Arc::new(CoalescePartitionsExec::new(aggregate_exec))
            as Arc<dyn ExecutionPlan>;
        // ... and finalized, which exercises the accumulators' `merge_batch`.
        let aggregate_final = Arc::new(AggregateExec::try_new(
            AggregateMode::Final,
            groups,
            aggregates.clone(),
            vec![None],
            coalesce,
            schema,
        )?) as Arc<dyn ExecutionPlan>;

        let result = crate::collect(aggregate_final, task_ctx).await?;
        if is_first_acc {
            allow_duplicates! {
            assert_snapshot!(batches_to_string(&result), @r"
            +---+--------------------------------------------+
            | a | first_value(b) ORDER BY [b ASC NULLS LAST] |
            +---+--------------------------------------------+
            | 2 | 0.0                                        |
            | 3 | 1.0                                        |
            | 4 | 3.0                                        |
            +---+--------------------------------------------+
            ");
            }
        } else {
            allow_duplicates! {
            assert_snapshot!(batches_to_string(&result), @r"
            +---+-------------------------------------------+
            | a | last_value(b) ORDER BY [b ASC NULLS LAST] |
            +---+-------------------------------------------+
            | 2 | 3.0                                       |
            | 3 | 5.0                                       |
            | 4 | 6.0                                       |
            +---+-------------------------------------------+
            ");
            }
        };
        Ok(())
    }
2915
    /// Check that `get_finer_aggregate_exprs_requirement` combines the
    /// ORDER BY requirements of several aggregate expressions into the finest
    /// common sort requirement, taking equivalences into account.
    #[tokio::test]
    async fn test_get_finest_requirements() -> Result<()> {
        let test_schema = create_test_schema()?;

        // All orderings below are ASC NULLS LAST.
        let options = SortOptions {
            descending: false,
            nulls_first: false,
        };
        let col_a = &col("a", &test_schema)?;
        let col_b = &col("b", &test_schema)?;
        let col_c = &col("c", &test_schema)?;
        let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema));
        // Columns a and b are equal.
        eq_properties.add_equal_conditions(Arc::clone(col_a), Arc::clone(col_b))?;
        // Aggregate requirements are
        // [None], [a ASC], [a ASC, b ASC, c ASC], [a ASC, b ASC] respectively
        let order_by_exprs = vec![
            vec![],
            vec![PhysicalSortExpr {
                expr: Arc::clone(col_a),
                options,
            }],
            vec![
                PhysicalSortExpr {
                    expr: Arc::clone(col_a),
                    options,
                },
                PhysicalSortExpr {
                    expr: Arc::clone(col_b),
                    options,
                },
                PhysicalSortExpr {
                    expr: Arc::clone(col_c),
                    options,
                },
            ],
            vec![
                PhysicalSortExpr {
                    expr: Arc::clone(col_a),
                    options,
                },
                PhysicalSortExpr {
                    expr: Arc::clone(col_b),
                    options,
                },
            ],
        ];

        // Expected finest requirement is [a ASC, c ASC]: with a == b the
        // b entries are redundant and collapse into a.
        let common_requirement = vec![
            PhysicalSortRequirement::new(Arc::clone(col_a), Some(options)),
            PhysicalSortRequirement::new(Arc::clone(col_c), Some(options)),
        ];
        // One ARRAY_AGG(a) per ORDER BY requirement above.
        let mut aggr_exprs = order_by_exprs
            .into_iter()
            .map(|order_by_expr| {
                AggregateExprBuilder::new(array_agg_udaf(), vec![Arc::clone(col_a)])
                    .alias("a")
                    .order_by(order_by_expr)
                    .schema(Arc::clone(&test_schema))
                    .build()
                    .map(Arc::new)
                    .unwrap()
            })
            .collect::<Vec<_>>();
        let group_by = PhysicalGroupBy::new_single(vec![]);
        let result = get_finer_aggregate_exprs_requirement(
            &mut aggr_exprs,
            &group_by,
            &eq_properties,
            &AggregateMode::Partial,
        )?;
        assert_eq!(result, common_requirement);
        Ok(())
    }
2990
2991    #[test]
2992    fn test_agg_exec_same_schema() -> Result<()> {
2993        let schema = Arc::new(Schema::new(vec![
2994            Field::new("a", DataType::Float32, true),
2995            Field::new("b", DataType::Float32, true),
2996        ]));
2997
2998        let col_a = col("a", &schema)?;
2999        let option_desc = SortOptions {
3000            descending: true,
3001            nulls_first: true,
3002        };
3003        let groups = PhysicalGroupBy::new_single(vec![(col_a, "a".to_string())]);
3004
3005        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
3006            test_first_value_agg_expr(&schema, option_desc)?,
3007            test_last_value_agg_expr(&schema, option_desc)?,
3008        ];
3009        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
3010        let aggregate_exec = Arc::new(AggregateExec::try_new(
3011            AggregateMode::Partial,
3012            groups,
3013            aggregates,
3014            vec![None, None],
3015            Arc::clone(&blocking_exec) as Arc<dyn ExecutionPlan>,
3016            schema,
3017        )?);
3018        let new_agg =
3019            Arc::clone(&aggregate_exec).with_new_children(vec![blocking_exec])?;
3020        assert_eq!(new_agg.schema(), aggregate_exec.schema());
3021        Ok(())
3022    }
3023
    /// Aggregation with grouping sets where one group key is a constant
    /// literal rather than a column reference.
    #[tokio::test]
    async fn test_agg_exec_group_by_const() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Float32, true),
            Field::new("b", DataType::Float32, true),
            Field::new("const", DataType::Int32, false),
        ]));

        let col_a = col("a", &schema)?;
        let col_b = col("b", &schema)?;
        // The "const" group key is the literal 1, not a column expression.
        let const_expr = Arc::new(Literal::new(ScalarValue::Int32(Some(1))));

        // GROUPING SETS ((a), (b), (const)): each mask row below marks which
        // keys are replaced by the corresponding null literal.
        let groups = PhysicalGroupBy::new(
            vec![
                (col_a, "a".to_string()),
                (col_b, "b".to_string()),
                (const_expr, "const".to_string()),
            ],
            vec![
                (
                    Arc::new(Literal::new(ScalarValue::Float32(None))),
                    "a".to_string(),
                ),
                (
                    Arc::new(Literal::new(ScalarValue::Float32(None))),
                    "b".to_string(),
                ),
                (
                    Arc::new(Literal::new(ScalarValue::Int32(None))),
                    "const".to_string(),
                ),
            ],
            vec![
                vec![false, true, true],
                vec![true, false, true],
                vec![true, true, false],
            ],
            true,
        );

        // COUNT(1), aliased as "1".
        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
            AggregateExprBuilder::new(count_udaf(), vec![lit(1)])
                .schema(Arc::clone(&schema))
                .alias("1")
                .build()
                .map(Arc::new)?,
        ];

        // 4 batches x 8192 identical rows = 32768 rows per grouping set.
        let input_batches = (0..4)
            .map(|_| {
                let a = Arc::new(Float32Array::from(vec![0.; 8192]));
                let b = Arc::new(Float32Array::from(vec![0.; 8192]));
                let c = Arc::new(Int32Array::from(vec![1; 8192]));

                RecordBatch::try_new(Arc::clone(&schema), vec![a, b, c]).unwrap()
            })
            .collect();

        let input =
            TestMemoryExec::try_new_exec(&[input_batches], Arc::clone(&schema), None)?;

        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::Single,
            groups,
            aggregates.clone(),
            vec![None],
            input,
            schema,
        )?);

        let output =
            collect(aggregate_exec.execute(0, Arc::new(TaskContext::default()))?).await?;

        // One output row per grouping set; `__grouping_id` encodes which keys
        // were nulled out in each row.
        allow_duplicates! {
        assert_snapshot!(batches_to_sort_string(&output), @r"
        +-----+-----+-------+---------------+-------+
        | a   | b   | const | __grouping_id | 1     |
        +-----+-----+-------+---------------+-------+
        |     |     | 1     | 6             | 32768 |
        |     | 0.0 |       | 5             | 32768 |
        | 0.0 |     |       | 3             | 32768 |
        +-----+-----+-------+---------------+-------+
        ");
        }

        Ok(())
    }
3111
    #[tokio::test]
    async fn test_agg_exec_struct_of_dicts() -> Result<()> {
        // Regression-style test: grouping on a Struct column whose fields are
        // dictionary-encoded strings must hash/compare the struct values
        // correctly (rows 0 and 2 carry identical struct values).
        let batch = RecordBatch::try_new(
            Arc::new(Schema::new(vec![
                Field::new(
                    "labels".to_string(),
                    DataType::Struct(
                        vec![
                            Field::new(
                                "a".to_string(),
                                DataType::Dictionary(
                                    Box::new(DataType::Int32),
                                    Box::new(DataType::Utf8),
                                ),
                                true,
                            ),
                            Field::new(
                                "b".to_string(),
                                DataType::Dictionary(
                                    Box::new(DataType::Int32),
                                    Box::new(DataType::Utf8),
                                ),
                                true,
                            ),
                        ]
                        .into(),
                    ),
                    false,
                ),
                Field::new("value", DataType::UInt64, false),
            ])),
            vec![
                Arc::new(StructArray::from(vec![
                    (
                        Arc::new(Field::new(
                            "a".to_string(),
                            DataType::Dictionary(
                                Box::new(DataType::Int32),
                                Box::new(DataType::Utf8),
                            ),
                            true,
                        )),
                        // Field `a`: rows 0 and 2 are "a"; row 1 is NULL.
                        Arc::new(
                            vec![Some("a"), None, Some("a")]
                                .into_iter()
                                .collect::<DictionaryArray<Int32Type>>(),
                        ) as ArrayRef,
                    ),
                    (
                        Arc::new(Field::new(
                            "b".to_string(),
                            DataType::Dictionary(
                                Box::new(DataType::Int32),
                                Box::new(DataType::Utf8),
                            ),
                            true,
                        )),
                        // Field `b`: rows 0 and 2 are "b", so rows 0/2 form
                        // one group {a: a, b: b}; row 1 is {a: NULL, b: c}.
                        Arc::new(
                            vec![Some("b"), Some("c"), Some("b")]
                                .into_iter()
                                .collect::<DictionaryArray<Int32Type>>(),
                        ) as ArrayRef,
                    ),
                ])),
                // Each row contributes 1 to SUM(value).
                Arc::new(UInt64Array::from(vec![1, 1, 1])),
            ],
        )
        .expect("Failed to create RecordBatch");

        // Group by the entire struct column.
        let group_by = PhysicalGroupBy::new_single(vec![(
            col("labels", &batch.schema())?,
            "labels".to_string(),
        )]);

        let aggr_expr = vec![
            AggregateExprBuilder::new(sum_udaf(), vec![col("value", &batch.schema())?])
                .schema(Arc::clone(&batch.schema()))
                .alias(String::from("SUM(value)"))
                .build()
                .map(Arc::new)?,
        ];

        let input = TestMemoryExec::try_new_exec(
            &[vec![batch.clone()]],
            Arc::<Schema>::clone(&batch.schema()),
            None,
        )?;
        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::FinalPartitioned,
            group_by,
            aggr_expr,
            vec![None],
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            batch.schema(),
        )?);

        let session_config = SessionConfig::default();
        let ctx = TaskContext::default().with_session_config(session_config);
        let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;

        // Two groups: {a: a, b: b} summed over rows 0/2, and {a: NULL, b: c}
        // from row 1 alone.
        allow_duplicates! {
        assert_snapshot!(batches_to_string(&output), @r"
        +--------------+------------+
        | labels       | SUM(value) |
        +--------------+------------+
        | {a: a, b: b} | 2          |
        | {a: , b: c}  | 1          |
        +--------------+------------+
        ");
        }

        Ok(())
    }
3225
3226    #[tokio::test]
3227    async fn test_skip_aggregation_after_first_batch() -> Result<()> {
3228        let schema = Arc::new(Schema::new(vec![
3229            Field::new("key", DataType::Int32, true),
3230            Field::new("val", DataType::Int32, true),
3231        ]));
3232
3233        let group_by =
3234            PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]);
3235
3236        let aggr_expr = vec![
3237            AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?])
3238                .schema(Arc::clone(&schema))
3239                .alias(String::from("COUNT(val)"))
3240                .build()
3241                .map(Arc::new)?,
3242        ];
3243
3244        let input_data = vec![
3245            RecordBatch::try_new(
3246                Arc::clone(&schema),
3247                vec![
3248                    Arc::new(Int32Array::from(vec![1, 2, 3])),
3249                    Arc::new(Int32Array::from(vec![0, 0, 0])),
3250                ],
3251            )
3252            .unwrap(),
3253            RecordBatch::try_new(
3254                Arc::clone(&schema),
3255                vec![
3256                    Arc::new(Int32Array::from(vec![2, 3, 4])),
3257                    Arc::new(Int32Array::from(vec![0, 0, 0])),
3258                ],
3259            )
3260            .unwrap(),
3261        ];
3262
3263        let input =
3264            TestMemoryExec::try_new_exec(&[input_data], Arc::clone(&schema), None)?;
3265        let aggregate_exec = Arc::new(AggregateExec::try_new(
3266            AggregateMode::Partial,
3267            group_by,
3268            aggr_expr,
3269            vec![None],
3270            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
3271            schema,
3272        )?);
3273
3274        let mut session_config = SessionConfig::default();
3275        session_config = session_config.set(
3276            "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
3277            &ScalarValue::Int64(Some(2)),
3278        );
3279        session_config = session_config.set(
3280            "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
3281            &ScalarValue::Float64(Some(0.1)),
3282        );
3283
3284        let ctx = TaskContext::default().with_session_config(session_config);
3285        let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;
3286
3287        allow_duplicates! {
3288            assert_snapshot!(batches_to_string(&output), @r"
3289            +-----+-------------------+
3290            | key | COUNT(val)[count] |
3291            +-----+-------------------+
3292            | 1   | 1                 |
3293            | 2   | 1                 |
3294            | 3   | 1                 |
3295            | 2   | 1                 |
3296            | 3   | 1                 |
3297            | 4   | 1                 |
3298            +-----+-------------------+
3299            ");
3300        }
3301
3302        Ok(())
3303    }
3304
3305    #[tokio::test]
3306    async fn test_skip_aggregation_after_threshold() -> Result<()> {
3307        let schema = Arc::new(Schema::new(vec![
3308            Field::new("key", DataType::Int32, true),
3309            Field::new("val", DataType::Int32, true),
3310        ]));
3311
3312        let group_by =
3313            PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]);
3314
3315        let aggr_expr = vec![
3316            AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?])
3317                .schema(Arc::clone(&schema))
3318                .alias(String::from("COUNT(val)"))
3319                .build()
3320                .map(Arc::new)?,
3321        ];
3322
3323        let input_data = vec![
3324            RecordBatch::try_new(
3325                Arc::clone(&schema),
3326                vec![
3327                    Arc::new(Int32Array::from(vec![1, 2, 3])),
3328                    Arc::new(Int32Array::from(vec![0, 0, 0])),
3329                ],
3330            )
3331            .unwrap(),
3332            RecordBatch::try_new(
3333                Arc::clone(&schema),
3334                vec![
3335                    Arc::new(Int32Array::from(vec![2, 3, 4])),
3336                    Arc::new(Int32Array::from(vec![0, 0, 0])),
3337                ],
3338            )
3339            .unwrap(),
3340            RecordBatch::try_new(
3341                Arc::clone(&schema),
3342                vec![
3343                    Arc::new(Int32Array::from(vec![2, 3, 4])),
3344                    Arc::new(Int32Array::from(vec![0, 0, 0])),
3345                ],
3346            )
3347            .unwrap(),
3348        ];
3349
3350        let input =
3351            TestMemoryExec::try_new_exec(&[input_data], Arc::clone(&schema), None)?;
3352        let aggregate_exec = Arc::new(AggregateExec::try_new(
3353            AggregateMode::Partial,
3354            group_by,
3355            aggr_expr,
3356            vec![None],
3357            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
3358            schema,
3359        )?);
3360
3361        let mut session_config = SessionConfig::default();
3362        session_config = session_config.set(
3363            "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
3364            &ScalarValue::Int64(Some(5)),
3365        );
3366        session_config = session_config.set(
3367            "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
3368            &ScalarValue::Float64(Some(0.1)),
3369        );
3370
3371        let ctx = TaskContext::default().with_session_config(session_config);
3372        let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;
3373
3374        allow_duplicates! {
3375            assert_snapshot!(batches_to_string(&output), @r"
3376            +-----+-------------------+
3377            | key | COUNT(val)[count] |
3378            +-----+-------------------+
3379            | 1   | 1                 |
3380            | 2   | 2                 |
3381            | 3   | 2                 |
3382            | 4   | 1                 |
3383            | 2   | 1                 |
3384            | 3   | 1                 |
3385            | 4   | 1                 |
3386            +-----+-------------------+
3387            ");
3388        }
3389
3390        Ok(())
3391    }
3392
3393    #[test]
3394    fn group_exprs_nullable() -> Result<()> {
3395        let input_schema = Arc::new(Schema::new(vec![
3396            Field::new("a", DataType::Float32, false),
3397            Field::new("b", DataType::Float32, false),
3398        ]));
3399
3400        let aggr_expr = vec![
3401            AggregateExprBuilder::new(count_udaf(), vec![col("a", &input_schema)?])
3402                .schema(Arc::clone(&input_schema))
3403                .alias("COUNT(a)")
3404                .build()
3405                .map(Arc::new)?,
3406        ];
3407
3408        let grouping_set = PhysicalGroupBy::new(
3409            vec![
3410                (col("a", &input_schema)?, "a".to_string()),
3411                (col("b", &input_schema)?, "b".to_string()),
3412            ],
3413            vec![
3414                (lit(ScalarValue::Float32(None)), "a".to_string()),
3415                (lit(ScalarValue::Float32(None)), "b".to_string()),
3416            ],
3417            vec![
3418                vec![false, true],  // (a, NULL)
3419                vec![false, false], // (a,b)
3420            ],
3421            true,
3422        );
3423        let aggr_schema = create_schema(
3424            &input_schema,
3425            &grouping_set,
3426            &aggr_expr,
3427            AggregateMode::Final,
3428        )?;
3429        let expected_schema = Schema::new(vec![
3430            Field::new("a", DataType::Float32, false),
3431            Field::new("b", DataType::Float32, true),
3432            Field::new("__grouping_id", DataType::UInt8, false),
3433            Field::new("COUNT(a)", DataType::Int64, false),
3434        ]);
3435        assert_eq!(aggr_schema, expected_schema);
3436        Ok(())
3437    }
3438
    // test for https://github.com/apache/datafusion/issues/13949
    //
    // Runs a single-mode grouped aggregation under a FairSpillPool of
    // `pool_size` bytes and asserts (via `assert_spill_count_metric`) that
    // spilling did or did not happen as `expect_spill` demands. The final
    // results must be identical either way.
    async fn run_test_with_spill_pool_if_necessary(
        pool_size: usize,
        expect_spill: bool,
    ) -> Result<()> {
        // Builds one (UInt32, Float64) batch from the given column data.
        fn create_record_batch(
            schema: &Arc<Schema>,
            data: (Vec<u32>, Vec<f64>),
        ) -> Result<RecordBatch> {
            Ok(RecordBatch::try_new(
                Arc::clone(schema),
                vec![
                    Arc::new(UInt32Array::from(data.0)),
                    Arc::new(Float64Array::from(data.1)),
                ],
            )?)
        }

        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::UInt32, false),
            Field::new("b", DataType::Float64, false),
        ]));

        // Two identical batches so every group (a = 2, 3, 4) spans both.
        let batches = vec![
            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
        ];
        let plan: Arc<dyn ExecutionPlan> =
            TestMemoryExec::try_new_exec(&[batches], Arc::clone(&schema), None)?;

        // Plain GROUP BY a (no grouping sets).
        let grouping_set = PhysicalGroupBy::new(
            vec![(col("a", &schema)?, "a".to_string())],
            vec![],
            vec![vec![false]],
            false,
        );

        // Test with MIN for simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count).
        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
            Arc::new(
                AggregateExprBuilder::new(
                    datafusion_functions_aggregate::min_max::min_udaf(),
                    vec![col("b", &schema)?],
                )
                .schema(Arc::clone(&schema))
                .alias("MIN(b)")
                .build()?,
            ),
            Arc::new(
                AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
                    .schema(Arc::clone(&schema))
                    .alias("AVG(b)")
                    .build()?,
            ),
        ];

        let single_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Single,
            grouping_set,
            aggregates,
            vec![None, None],
            plan,
            Arc::clone(&schema),
        )?);

        // Small batch size plus the caller-chosen pool size controls whether
        // the hash aggregation is forced to spill to disk.
        let batch_size = 2;
        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
        let task_ctx = Arc::new(
            TaskContext::default()
                .with_session_config(SessionConfig::new().with_batch_size(batch_size))
                .with_runtime(Arc::new(
                    RuntimeEnvBuilder::new()
                        .with_memory_pool(memory_pool)
                        .build()?,
                )),
        );

        let result = collect(single_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;

        assert_spill_count_metric(expect_spill, single_aggregate);

        // Spill or not, the aggregated values must be the same.
        allow_duplicates! {
            assert_snapshot!(batches_to_string(&result), @r"
            +---+--------+--------+
            | a | MIN(b) | AVG(b) |
            +---+--------+--------+
            | 2 | 1.0    | 1.0    |
            | 3 | 2.0    | 2.0    |
            | 4 | 3.0    | 3.5    |
            +---+--------+--------+
            ");
        }

        Ok(())
    }
3534
3535    fn assert_spill_count_metric(
3536        expect_spill: bool,
3537        single_aggregate: Arc<AggregateExec>,
3538    ) {
3539        if let Some(metrics_set) = single_aggregate.metrics() {
3540            let mut spill_count = 0;
3541
3542            // Inspect metrics for SpillCount
3543            for metric in metrics_set.iter() {
3544                if let MetricValue::SpillCount(count) = metric.value() {
3545                    spill_count = count.value();
3546                    break;
3547                }
3548            }
3549
3550            if expect_spill && spill_count == 0 {
3551                panic!(
3552                    "Expected spill but SpillCount metric not found or SpillCount was 0."
3553                );
3554            } else if !expect_spill && spill_count > 0 {
3555                panic!(
3556                    "Expected no spill but found SpillCount metric with value greater than 0."
3557                );
3558            }
3559        } else {
3560            panic!("No metrics returned from the operator; cannot verify spilling.");
3561        }
3562    }
3563
3564    #[tokio::test]
3565    async fn test_aggregate_with_spill_if_necessary() -> Result<()> {
3566        // test with spill
3567        run_test_with_spill_pool_if_necessary(2_000, true).await?;
3568        // test without spill
3569        run_test_with_spill_pool_if_necessary(20_000, false).await?;
3570        Ok(())
3571    }
3572
    // Grouped aggregation under a tight FairSpillPool must either finish by
    // spilling (and still produce correct results) or fail with a
    // ResourcesExhausted error — it must never silently exceed the limit.
    #[tokio::test]
    async fn test_grouped_aggregation_respects_memory_limit() -> Result<()> {
        // test with spill
        // Builds one (UInt32, Float64) batch from the given column data.
        fn create_record_batch(
            schema: &Arc<Schema>,
            data: (Vec<u32>, Vec<f64>),
        ) -> Result<RecordBatch> {
            Ok(RecordBatch::try_new(
                Arc::clone(schema),
                vec![
                    Arc::new(UInt32Array::from(data.0)),
                    Arc::new(Float64Array::from(data.1)),
                ],
            )?)
        }

        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::UInt32, false),
            Field::new("b", DataType::Float64, false),
        ]));

        // Two identical batches so every group spans both.
        let batches = vec![
            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
        ];
        let plan: Arc<dyn ExecutionPlan> =
            TestMemoryExec::try_new_exec(&[batches], Arc::clone(&schema), None)?;
        // Prepend a constant string column "l" so the group keys include a
        // variable-length (heap-backed) value, exercising memory accounting.
        let proj = ProjectionExec::try_new(
            vec![
                ProjectionExpr::new(lit("0"), "l".to_string()),
                ProjectionExpr::new_from_expression(col("a", &schema)?, &schema)?,
                ProjectionExpr::new_from_expression(col("b", &schema)?, &schema)?,
            ],
            plan,
        )?;
        let plan: Arc<dyn ExecutionPlan> = Arc::new(proj);
        let schema = plan.schema();

        // GROUP BY (l, a) — no grouping sets.
        let grouping_set = PhysicalGroupBy::new(
            vec![
                (col("l", &schema)?, "l".to_string()),
                (col("a", &schema)?, "a".to_string()),
            ],
            vec![],
            vec![vec![false, false]],
            false,
        );

        // Test with MIN for simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count).
        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
            Arc::new(
                AggregateExprBuilder::new(
                    datafusion_functions_aggregate::min_max::min_udaf(),
                    vec![col("b", &schema)?],
                )
                .schema(Arc::clone(&schema))
                .alias("MIN(b)")
                .build()?,
            ),
            Arc::new(
                AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
                    .schema(Arc::clone(&schema))
                    .alias("AVG(b)")
                    .build()?,
            ),
        ];

        let single_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Single,
            grouping_set,
            aggregates,
            vec![None, None],
            plan,
            Arc::clone(&schema),
        )?);

        // 2000-byte pool with a 2-row batch size: small enough to pressure
        // the accumulators into spilling (or erroring).
        let batch_size = 2;
        let memory_pool = Arc::new(FairSpillPool::new(2000));
        let task_ctx = Arc::new(
            TaskContext::default()
                .with_session_config(SessionConfig::new().with_batch_size(batch_size))
                .with_runtime(Arc::new(
                    RuntimeEnvBuilder::new()
                        .with_memory_pool(memory_pool)
                        .build()?,
                )),
        );

        let result = collect(single_aggregate.execute(0, Arc::clone(&task_ctx))?).await;
        match result {
            // If the query completes, it must have spilled and still produced
            // the correct aggregates.
            Ok(result) => {
                assert_spill_count_metric(true, single_aggregate);

                allow_duplicates! {
                    assert_snapshot!(batches_to_string(&result), @r"
                +---+---+--------+--------+
                | l | a | MIN(b) | AVG(b) |
                +---+---+--------+--------+
                | 0 | 2 | 1.0    | 1.0    |
                | 0 | 3 | 2.0    | 2.0    |
                | 0 | 4 | 3.0    | 3.5    |
                +---+---+--------+--------+
            ");
                }
            }
            // Failing with ResourcesExhausted is also acceptable — the point
            // is that the limit is respected, not how it is satisfied.
            Err(e) => assert!(matches!(e, DataFusionError::ResourcesExhausted(_))),
        }

        Ok(())
    }
3683
3684    #[tokio::test]
3685    async fn test_aggregate_statistics_edge_cases() -> Result<()> {
3686        use crate::test::exec::StatisticsExec;
3687        use datafusion_common::ColumnStatistics;
3688
3689        let schema = Arc::new(Schema::new(vec![
3690            Field::new("a", DataType::Int32, false),
3691            Field::new("b", DataType::Float64, false),
3692        ]));
3693
3694        // Test 1: Absent statistics remain absent
3695        let input = Arc::new(StatisticsExec::new(
3696            Statistics {
3697                num_rows: Precision::Exact(100),
3698                total_byte_size: Precision::Absent,
3699                column_statistics: vec![
3700                    ColumnStatistics::new_unknown(),
3701                    ColumnStatistics::new_unknown(),
3702                ],
3703            },
3704            (*schema).clone(),
3705        )) as Arc<dyn ExecutionPlan>;
3706
3707        let agg = Arc::new(AggregateExec::try_new(
3708            AggregateMode::Final,
3709            PhysicalGroupBy::default(),
3710            vec![Arc::new(
3711                AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema)?])
3712                    .schema(Arc::clone(&schema))
3713                    .alias("COUNT(a)")
3714                    .build()?,
3715            )],
3716            vec![None],
3717            input,
3718            Arc::clone(&schema),
3719        )?);
3720
3721        let stats = agg.partition_statistics(None)?;
3722        assert_eq!(stats.total_byte_size, Precision::Absent);
3723
3724        // Test 2: Zero rows returns Absent (can't estimate output size from zero input)
3725        let input_zero = Arc::new(StatisticsExec::new(
3726            Statistics {
3727                num_rows: Precision::Exact(0),
3728                total_byte_size: Precision::Exact(0),
3729                column_statistics: vec![
3730                    ColumnStatistics::new_unknown(),
3731                    ColumnStatistics::new_unknown(),
3732                ],
3733            },
3734            (*schema).clone(),
3735        )) as Arc<dyn ExecutionPlan>;
3736
3737        let agg_zero = Arc::new(AggregateExec::try_new(
3738            AggregateMode::Final,
3739            PhysicalGroupBy::default(),
3740            vec![Arc::new(
3741                AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema)?])
3742                    .schema(Arc::clone(&schema))
3743                    .alias("COUNT(a)")
3744                    .build()?,
3745            )],
3746            vec![None],
3747            input_zero,
3748            Arc::clone(&schema),
3749        )?);
3750
3751        let stats_zero = agg_zero.partition_statistics(None)?;
3752        assert_eq!(stats_zero.total_byte_size, Precision::Absent);
3753
3754        Ok(())
3755    }
3756
    // A grouped aggregation over sorted input (descending on `b`) that is
    // forced to spill must still emit groups in the input's sort order.
    #[tokio::test]
    async fn test_order_is_retained_when_spilling() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Int64, false),
            Field::new("c", DataType::Int64, false),
        ]));

        // Three single-row batches with `b` = 2, 1, 0 — i.e. already sorted
        // descending on `b`.
        let batches = vec![vec![
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int64Array::from(vec![2])),
                    Arc::new(Int64Array::from(vec![2])),
                    Arc::new(Int64Array::from(vec![1])),
                ],
            )?,
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int64Array::from(vec![1])),
                    Arc::new(Int64Array::from(vec![1])),
                    Arc::new(Int64Array::from(vec![1])),
                ],
            )?,
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int64Array::from(vec![0])),
                    Arc::new(Int64Array::from(vec![0])),
                    Arc::new(Int64Array::from(vec![1])),
                ],
            )?,
        ]];
        let scan = TestMemoryExec::try_new(&batches, Arc::clone(&schema), None)?;
        // Declare `b DESC` as the scan's sort order so the aggregate can rely
        // on (and is expected to retain) that ordering.
        let scan = scan.try_with_sort_information(vec![
            LexOrdering::new([PhysicalSortExpr::new(
                col("b", schema.as_ref())?,
                SortOptions::default().desc(),
            )])
            .unwrap(),
        ])?;

        // GROUP BY (b, c) with SUM(c) in Single mode.
        let aggr = Arc::new(AggregateExec::try_new(
            AggregateMode::Single,
            PhysicalGroupBy::new(
                vec![
                    (col("b", schema.as_ref())?, "b".to_string()),
                    (col("c", schema.as_ref())?, "c".to_string()),
                ],
                vec![],
                vec![vec![false, false]],
                false,
            ),
            vec![Arc::new(
                AggregateExprBuilder::new(sum_udaf(), vec![col("c", schema.as_ref())?])
                    .schema(Arc::clone(&schema))
                    .alias("SUM(c)")
                    .build()?,
            )],
            vec![None],
            Arc::new(scan) as Arc<dyn ExecutionPlan>,
            Arc::clone(&schema),
        )?);

        // A 600-byte pool is small enough to force the aggregation to spill.
        let task_ctx = new_spill_ctx(1, 600);
        let result = collect(aggr.execute(0, Arc::clone(&task_ctx))?).await?;
        assert_spill_count_metric(true, aggr);

        // Groups come out in descending `b` order, matching the input.
        allow_duplicates! {
            assert_snapshot!(batches_to_string(&result), @r"
            +---+---+--------+
            | b | c | SUM(c) |
            +---+---+--------+
            | 2 | 1 | 1      |
            | 1 | 1 | 1      |
            | 0 | 1 | 1      |
            +---+---+--------+
        ");
        }
        Ok(())
    }
3839
3840    /// Tests that when the memory pool is too small to accommodate the sort
3841    /// reservation during spill, the error is properly propagated as
3842    /// ResourcesExhausted rather than silently exceeding memory limits.
3843    #[tokio::test]
3844    async fn test_sort_reservation_fails_during_spill() -> Result<()> {
3845        let schema = Arc::new(Schema::new(vec![
3846            Field::new("g", DataType::Int64, false),
3847            Field::new("a", DataType::Float64, false),
3848            Field::new("b", DataType::Float64, false),
3849            Field::new("c", DataType::Float64, false),
3850            Field::new("d", DataType::Float64, false),
3851            Field::new("e", DataType::Float64, false),
3852        ]));
3853
3854        let batches = vec![vec![
3855            RecordBatch::try_new(
3856                Arc::clone(&schema),
3857                vec![
3858                    Arc::new(Int64Array::from(vec![1])),
3859                    Arc::new(Float64Array::from(vec![10.0])),
3860                    Arc::new(Float64Array::from(vec![20.0])),
3861                    Arc::new(Float64Array::from(vec![30.0])),
3862                    Arc::new(Float64Array::from(vec![40.0])),
3863                    Arc::new(Float64Array::from(vec![50.0])),
3864                ],
3865            )?,
3866            RecordBatch::try_new(
3867                Arc::clone(&schema),
3868                vec![
3869                    Arc::new(Int64Array::from(vec![2])),
3870                    Arc::new(Float64Array::from(vec![11.0])),
3871                    Arc::new(Float64Array::from(vec![21.0])),
3872                    Arc::new(Float64Array::from(vec![31.0])),
3873                    Arc::new(Float64Array::from(vec![41.0])),
3874                    Arc::new(Float64Array::from(vec![51.0])),
3875                ],
3876            )?,
3877            RecordBatch::try_new(
3878                Arc::clone(&schema),
3879                vec![
3880                    Arc::new(Int64Array::from(vec![3])),
3881                    Arc::new(Float64Array::from(vec![12.0])),
3882                    Arc::new(Float64Array::from(vec![22.0])),
3883                    Arc::new(Float64Array::from(vec![32.0])),
3884                    Arc::new(Float64Array::from(vec![42.0])),
3885                    Arc::new(Float64Array::from(vec![52.0])),
3886                ],
3887            )?,
3888        ]];
3889
3890        let scan = TestMemoryExec::try_new(&batches, Arc::clone(&schema), None)?;
3891
3892        let aggr = Arc::new(AggregateExec::try_new(
3893            AggregateMode::Single,
3894            PhysicalGroupBy::new(
3895                vec![(col("g", schema.as_ref())?, "g".to_string())],
3896                vec![],
3897                vec![vec![false]],
3898                false,
3899            ),
3900            vec![
3901                Arc::new(
3902                    AggregateExprBuilder::new(
3903                        avg_udaf(),
3904                        vec![col("a", schema.as_ref())?],
3905                    )
3906                    .schema(Arc::clone(&schema))
3907                    .alias("AVG(a)")
3908                    .build()?,
3909                ),
3910                Arc::new(
3911                    AggregateExprBuilder::new(
3912                        avg_udaf(),
3913                        vec![col("b", schema.as_ref())?],
3914                    )
3915                    .schema(Arc::clone(&schema))
3916                    .alias("AVG(b)")
3917                    .build()?,
3918                ),
3919                Arc::new(
3920                    AggregateExprBuilder::new(
3921                        avg_udaf(),
3922                        vec![col("c", schema.as_ref())?],
3923                    )
3924                    .schema(Arc::clone(&schema))
3925                    .alias("AVG(c)")
3926                    .build()?,
3927                ),
3928                Arc::new(
3929                    AggregateExprBuilder::new(
3930                        avg_udaf(),
3931                        vec![col("d", schema.as_ref())?],
3932                    )
3933                    .schema(Arc::clone(&schema))
3934                    .alias("AVG(d)")
3935                    .build()?,
3936                ),
3937                Arc::new(
3938                    AggregateExprBuilder::new(
3939                        avg_udaf(),
3940                        vec![col("e", schema.as_ref())?],
3941                    )
3942                    .schema(Arc::clone(&schema))
3943                    .alias("AVG(e)")
3944                    .build()?,
3945                ),
3946            ],
3947            vec![None, None, None, None, None],
3948            Arc::new(scan) as Arc<dyn ExecutionPlan>,
3949            Arc::clone(&schema),
3950        )?);
3951
3952        // Pool must be large enough for accumulation to start but too small for
3953        // sort_memory after clearing.
3954        let task_ctx = new_spill_ctx(1, 500);
3955        let result = collect(aggr.execute(0, Arc::clone(&task_ctx))?).await;
3956
3957        match &result {
3958            Ok(_) => panic!("Expected ResourcesExhausted error but query succeeded"),
3959            Err(e) => {
3960                let root = e.find_root();
3961                assert!(
3962                    matches!(root, DataFusionError::ResourcesExhausted(_)),
3963                    "Expected ResourcesExhausted, got: {root}",
3964                );
3965                let msg = root.to_string();
3966                assert!(
3967                    msg.contains("Failed to reserve memory for sort during spill"),
3968                    "Expected sort reservation error, got: {msg}",
3969                );
3970            }
3971        }
3972
3973        Ok(())
3974    }
3975
3976    /// Tests that PartialReduce mode:
3977    /// 1. Accepts state as input (like Final)
3978    /// 2. Produces state as output (like Partial)
3979    /// 3. Can be followed by a Final stage to get the correct result
3980    ///
3981    /// This simulates a tree-reduce pattern:
3982    ///   Partial -> PartialReduce -> Final
3983    #[tokio::test]
3984    async fn test_partial_reduce_mode() -> Result<()> {
3985        let schema = Arc::new(Schema::new(vec![
3986            Field::new("a", DataType::UInt32, false),
3987            Field::new("b", DataType::Float64, false),
3988        ]));
3989
3990        // Produce two partitions of input data
3991        let batch1 = RecordBatch::try_new(
3992            Arc::clone(&schema),
3993            vec![
3994                Arc::new(UInt32Array::from(vec![1, 2, 3])),
3995                Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0])),
3996            ],
3997        )?;
3998        let batch2 = RecordBatch::try_new(
3999            Arc::clone(&schema),
4000            vec![
4001                Arc::new(UInt32Array::from(vec![1, 2, 3])),
4002                Arc::new(Float64Array::from(vec![40.0, 50.0, 60.0])),
4003            ],
4004        )?;
4005
4006        let groups =
4007            PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);
4008        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
4009            AggregateExprBuilder::new(sum_udaf(), vec![col("b", &schema)?])
4010                .schema(Arc::clone(&schema))
4011                .alias("SUM(b)")
4012                .build()?,
4013        )];
4014
4015        // Step 1: Partial aggregation on partition 1
4016        let input1 =
4017            TestMemoryExec::try_new_exec(&[vec![batch1]], Arc::clone(&schema), None)?;
4018        let partial1 = Arc::new(AggregateExec::try_new(
4019            AggregateMode::Partial,
4020            groups.clone(),
4021            aggregates.clone(),
4022            vec![None],
4023            input1,
4024            Arc::clone(&schema),
4025        )?);
4026
4027        // Step 2: Partial aggregation on partition 2
4028        let input2 =
4029            TestMemoryExec::try_new_exec(&[vec![batch2]], Arc::clone(&schema), None)?;
4030        let partial2 = Arc::new(AggregateExec::try_new(
4031            AggregateMode::Partial,
4032            groups.clone(),
4033            aggregates.clone(),
4034            vec![None],
4035            input2,
4036            Arc::clone(&schema),
4037        )?);
4038
4039        // Collect partial results
4040        let task_ctx = Arc::new(TaskContext::default());
4041        let partial_result1 =
4042            crate::collect(Arc::clone(&partial1) as _, Arc::clone(&task_ctx)).await?;
4043        let partial_result2 =
4044            crate::collect(Arc::clone(&partial2) as _, Arc::clone(&task_ctx)).await?;
4045
4046        // The partial results have state schema (group cols + accumulator state)
4047        let partial_schema = partial1.schema();
4048
4049        // Step 3: PartialReduce — combine partial results, still producing state
4050        let combined_input = TestMemoryExec::try_new_exec(
4051            &[partial_result1, partial_result2],
4052            Arc::clone(&partial_schema),
4053            None,
4054        )?;
4055        // Coalesce into a single partition for the PartialReduce
4056        let coalesced = Arc::new(CoalescePartitionsExec::new(combined_input));
4057
4058        let partial_reduce = Arc::new(AggregateExec::try_new(
4059            AggregateMode::PartialReduce,
4060            groups.clone(),
4061            aggregates.clone(),
4062            vec![None],
4063            coalesced,
4064            Arc::clone(&partial_schema),
4065        )?);
4066
4067        // Verify PartialReduce output schema matches Partial output schema
4068        // (both produce state, not final values)
4069        assert_eq!(partial_reduce.schema(), partial_schema);
4070
4071        // Collect PartialReduce results
4072        let reduce_result =
4073            crate::collect(Arc::clone(&partial_reduce) as _, Arc::clone(&task_ctx))
4074                .await?;
4075
4076        // Step 4: Final aggregation on the PartialReduce output
4077        let final_input = TestMemoryExec::try_new_exec(
4078            &[reduce_result],
4079            Arc::clone(&partial_schema),
4080            None,
4081        )?;
4082        let final_agg = Arc::new(AggregateExec::try_new(
4083            AggregateMode::Final,
4084            groups.clone(),
4085            aggregates.clone(),
4086            vec![None],
4087            final_input,
4088            Arc::clone(&partial_schema),
4089        )?);
4090
4091        let result = crate::collect(final_agg, Arc::clone(&task_ctx)).await?;
4092
4093        // Expected: group 1 -> 10+40=50, group 2 -> 20+50=70, group 3 -> 30+60=90
4094        assert_snapshot!(batches_to_sort_string(&result), @r"
4095            +---+--------+
4096            | a | SUM(b) |
4097            +---+--------+
4098            | 1 | 50.0   |
4099            | 2 | 70.0   |
4100            | 3 | 90.0   |
4101            +---+--------+
4102        ");
4103
4104        Ok(())
4105    }
4106}