// datafusion_physical_plan/aggregates/mod.rs
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Aggregates functionalities

20use std::any::Any;
21use std::sync::Arc;
22
23use super::{DisplayAs, ExecutionPlanProperties, PlanProperties};
24use crate::aggregates::{
25    no_grouping::AggregateStream, row_hash::GroupedHashAggregateStream,
26    topk_stream::GroupedTopKAggregateStream,
27};
28use crate::execution_plan::{CardinalityEffect, EmissionType};
29use crate::filter_pushdown::{
30    ChildFilterDescription, ChildPushdownResult, FilterDescription, FilterPushdownPhase,
31    FilterPushdownPropagation, PushedDownPredicate,
32};
33use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
34use crate::{
35    DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode,
36    SendableRecordBatchStream, Statistics,
37};
38use datafusion_common::config::ConfigOptions;
39use datafusion_physical_expr::utils::collect_columns;
40use parking_lot::Mutex;
41use std::collections::HashSet;
42
43use arrow::array::{ArrayRef, UInt8Array, UInt16Array, UInt32Array, UInt64Array};
44use arrow::datatypes::{Field, Schema, SchemaRef};
45use arrow::record_batch::RecordBatch;
46use arrow_schema::FieldRef;
47use datafusion_common::stats::Precision;
48use datafusion_common::{
49    Constraint, Constraints, Result, ScalarValue, assert_eq_or_internal_err, not_impl_err,
50};
51use datafusion_execution::TaskContext;
52use datafusion_expr::{Accumulator, Aggregate};
53use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
54use datafusion_physical_expr::equivalence::ProjectionMapping;
55use datafusion_physical_expr::expressions::{Column, DynamicFilterPhysicalExpr, lit};
56use datafusion_physical_expr::{
57    ConstExpr, EquivalenceProperties, physical_exprs_contains,
58};
59use datafusion_physical_expr_common::physical_expr::{PhysicalExpr, fmt_sql};
60use datafusion_physical_expr_common::sort_expr::{
61    LexOrdering, LexRequirement, OrderingRequirements, PhysicalSortRequirement,
62};
63
64use datafusion_expr::utils::AggregateOrderSensitivity;
65use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays;
66use itertools::Itertools;
67
68pub mod group_values;
69mod no_grouping;
70pub mod order;
71mod row_hash;
72mod topk;
73mod topk_stream;
74
75/// Hard-coded seed for aggregations to ensure hash values differ from `RepartitionExec`, avoiding collisions.
76const AGGREGATION_HASH_SEED: ahash::RandomState =
77    ahash::RandomState::with_seeds('A' as u64, 'G' as u64, 'G' as u64, 'R' as u64);
78
/// Aggregation modes
///
/// See [`Accumulator::state`] for background information on multi-phase
/// aggregation and how these modes are used.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum AggregateMode {
    /// One of multiple layers of aggregation, any input partitioning
    ///
    /// Partial aggregate that can be applied in parallel across input
    /// partitions.
    ///
    /// This is the first phase of a multi-phase aggregation.
    Partial,
    /// *Final* of multiple layers of aggregation, in exactly one partition
    ///
    /// Final aggregate that produces a single partition of output by combining
    /// the output of multiple partial aggregates.
    ///
    /// This is the second phase of a multi-phase aggregation.
    ///
    /// This mode requires that the input is a single partition
    ///
    /// Note: Adjacent `Partial` and `Final` mode aggregation is equivalent to a `Single`
    /// mode aggregation node. The `Final` mode is required since this is used in an
    /// intermediate step. The [`CombinePartialFinalAggregate`] physical optimizer rule
    /// will replace this combination with `Single` mode for more efficient execution.
    ///
    /// [`CombinePartialFinalAggregate`]: https://docs.rs/datafusion/latest/datafusion/physical_optimizer/combine_partial_final_agg/struct.CombinePartialFinalAggregate.html
    Final,
    /// *Final* of multiple layers of aggregation, input is *Partitioned*
    ///
    /// Final aggregate that works on pre-partitioned data.
    ///
    /// This mode requires that all rows with a particular grouping key are in
    /// the same partition, such as is the case with Hash repartitioning on the
    /// group keys. If a group key is duplicated, duplicate groups would be
    /// produced
    FinalPartitioned,
    /// *Single* layer of Aggregation, input is exactly one partition
    ///
    /// Applies the entire logical aggregation operation in a single operator,
    /// as opposed to Partial / Final modes which apply the logical aggregation using
    /// two operators.
    ///
    /// This mode requires that the input is a single partition (like Final)
    Single,
    /// *Single* layer of Aggregation, input is *Partitioned*
    ///
    /// Applies the entire logical aggregation operation in a single operator,
    /// as opposed to Partial / Final modes which apply the logical aggregation
    /// using two operators.
    ///
    /// This mode requires that the input has more than one partition, and is
    /// partitioned by group key (like FinalPartitioned).
    SinglePartitioned,
}
135
136impl AggregateMode {
137    /// Checks whether this aggregation step describes a "first stage" calculation.
138    /// In other words, its input is not another aggregation result and the
139    /// `merge_batch` method will not be called for these modes.
140    pub fn is_first_stage(&self) -> bool {
141        match self {
142            AggregateMode::Partial
143            | AggregateMode::Single
144            | AggregateMode::SinglePartitioned => true,
145            AggregateMode::Final | AggregateMode::FinalPartitioned => false,
146        }
147    }
148}
149
150/// Represents `GROUP BY` clause in the plan (including the more general GROUPING SET)
151/// In the case of a simple `GROUP BY a, b` clause, this will contain the expression [a, b]
152/// and a single group [false, false].
153/// In the case of `GROUP BY GROUPING SETS/CUBE/ROLLUP` the planner will expand the expression
154/// into multiple groups, using null expressions to align each group.
155/// For example, with a group by clause `GROUP BY GROUPING SETS ((a,b),(a),(b))` the planner should
156/// create a `PhysicalGroupBy` like
157/// ```text
158/// PhysicalGroupBy {
159///     expr: [(col(a), a), (col(b), b)],
160///     null_expr: [(NULL, a), (NULL, b)],
161///     groups: [
162///         [false, false], // (a,b)
163///         [false, true],  // (a) <=> (a, NULL)
164///         [true, false]   // (b) <=> (NULL, b)
165///     ]
166/// }
167/// ```
168#[derive(Clone, Debug, Default)]
169pub struct PhysicalGroupBy {
170    /// Distinct (Physical Expr, Alias) in the grouping set
171    expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
172    /// Corresponding NULL expressions for expr
173    null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
174    /// Null mask for each group in this grouping set. Each group is
175    /// composed of either one of the group expressions in expr or a null
176    /// expression in null_expr. If `groups[i][j]` is true, then the
177    /// j-th expression in the i-th group is NULL, otherwise it is `expr[j]`.
178    groups: Vec<Vec<bool>>,
179    /// True when GROUPING SETS/CUBE/ROLLUP are used so `__grouping_id` should
180    /// be included in the output schema.
181    has_grouping_set: bool,
182}
183
184impl PhysicalGroupBy {
185    /// Create a new `PhysicalGroupBy`
186    pub fn new(
187        expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
188        null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
189        groups: Vec<Vec<bool>>,
190        has_grouping_set: bool,
191    ) -> Self {
192        Self {
193            expr,
194            null_expr,
195            groups,
196            has_grouping_set,
197        }
198    }
199
200    /// Create a GROUPING SET with only a single group. This is the "standard"
201    /// case when building a plan from an expression such as `GROUP BY a,b,c`
202    pub fn new_single(expr: Vec<(Arc<dyn PhysicalExpr>, String)>) -> Self {
203        let num_exprs = expr.len();
204        Self {
205            expr,
206            null_expr: vec![],
207            groups: vec![vec![false; num_exprs]],
208            has_grouping_set: false,
209        }
210    }
211
212    /// Calculate GROUP BY expressions nullable
213    pub fn exprs_nullable(&self) -> Vec<bool> {
214        let mut exprs_nullable = vec![false; self.expr.len()];
215        for group in self.groups.iter() {
216            group.iter().enumerate().for_each(|(index, is_null)| {
217                if *is_null {
218                    exprs_nullable[index] = true;
219                }
220            })
221        }
222        exprs_nullable
223    }
224
225    /// Returns true if this has no grouping at all (including no GROUPING SETS)
226    pub fn is_true_no_grouping(&self) -> bool {
227        self.is_empty() && !self.has_grouping_set
228    }
229
230    /// Returns the group expressions
231    pub fn expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
232        &self.expr
233    }
234
235    /// Returns the null expressions
236    pub fn null_expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
237        &self.null_expr
238    }
239
240    /// Returns the group null masks
241    pub fn groups(&self) -> &[Vec<bool>] {
242        &self.groups
243    }
244
245    /// Returns true if this grouping uses GROUPING SETS, CUBE or ROLLUP.
246    pub fn has_grouping_set(&self) -> bool {
247        self.has_grouping_set
248    }
249
250    /// Returns true if this `PhysicalGroupBy` has no group expressions
251    pub fn is_empty(&self) -> bool {
252        self.expr.is_empty()
253    }
254
255    /// Returns true if this is a "simple" GROUP BY (not using GROUPING SETS/CUBE/ROLLUP).
256    /// This determines whether the `__grouping_id` column is included in the output schema.
257    pub fn is_single(&self) -> bool {
258        !self.has_grouping_set
259    }
260
261    /// Calculate GROUP BY expressions according to input schema.
262    pub fn input_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
263        self.expr
264            .iter()
265            .map(|(expr, _alias)| Arc::clone(expr))
266            .collect()
267    }
268
269    /// The number of expressions in the output schema.
270    fn num_output_exprs(&self) -> usize {
271        let mut num_exprs = self.expr.len();
272        if self.has_grouping_set {
273            num_exprs += 1
274        }
275        num_exprs
276    }
277
278    /// Return grouping expressions as they occur in the output schema.
279    pub fn output_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
280        let num_output_exprs = self.num_output_exprs();
281        let mut output_exprs = Vec::with_capacity(num_output_exprs);
282        output_exprs.extend(
283            self.expr
284                .iter()
285                .enumerate()
286                .take(num_output_exprs)
287                .map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _),
288        );
289        if self.has_grouping_set {
290            output_exprs.push(Arc::new(Column::new(
291                Aggregate::INTERNAL_GROUPING_ID,
292                self.expr.len(),
293            )) as _);
294        }
295        output_exprs
296    }
297
298    /// Returns the number expression as grouping keys.
299    pub fn num_group_exprs(&self) -> usize {
300        self.expr.len() + usize::from(self.has_grouping_set)
301    }
302
303    pub fn group_schema(&self, schema: &Schema) -> Result<SchemaRef> {
304        Ok(Arc::new(Schema::new(self.group_fields(schema)?)))
305    }
306
307    /// Returns the fields that are used as the grouping keys.
308    fn group_fields(&self, input_schema: &Schema) -> Result<Vec<FieldRef>> {
309        let mut fields = Vec::with_capacity(self.num_group_exprs());
310        for ((expr, name), group_expr_nullable) in
311            self.expr.iter().zip(self.exprs_nullable().into_iter())
312        {
313            fields.push(
314                Field::new(
315                    name,
316                    expr.data_type(input_schema)?,
317                    group_expr_nullable || expr.nullable(input_schema)?,
318                )
319                .with_metadata(expr.return_field(input_schema)?.metadata().clone())
320                .into(),
321            );
322        }
323        if self.has_grouping_set {
324            fields.push(
325                Field::new(
326                    Aggregate::INTERNAL_GROUPING_ID,
327                    Aggregate::grouping_id_type(self.expr.len()),
328                    false,
329                )
330                .into(),
331            );
332        }
333        Ok(fields)
334    }
335
336    /// Returns the output fields of the group by.
337    ///
338    /// This might be different from the `group_fields` that might contain internal expressions that
339    /// should not be part of the output schema.
340    fn output_fields(&self, input_schema: &Schema) -> Result<Vec<FieldRef>> {
341        let mut fields = self.group_fields(input_schema)?;
342        fields.truncate(self.num_output_exprs());
343        Ok(fields)
344    }
345
346    /// Returns the `PhysicalGroupBy` for a final aggregation if `self` is used for a partial
347    /// aggregation.
348    pub fn as_final(&self) -> PhysicalGroupBy {
349        let expr: Vec<_> =
350            self.output_exprs()
351                .into_iter()
352                .zip(
353                    self.expr.iter().map(|t| t.1.clone()).chain(std::iter::once(
354                        Aggregate::INTERNAL_GROUPING_ID.to_owned(),
355                    )),
356                )
357                .collect();
358        let num_exprs = expr.len();
359        let groups = if self.expr.is_empty() && !self.has_grouping_set {
360            // No GROUP BY expressions - should have no groups
361            vec![]
362        } else {
363            vec![vec![false; num_exprs]]
364        };
365        Self {
366            expr,
367            null_expr: vec![],
368            groups,
369            has_grouping_set: false,
370        }
371    }
372}
373
374impl PartialEq for PhysicalGroupBy {
375    fn eq(&self, other: &PhysicalGroupBy) -> bool {
376        self.expr.len() == other.expr.len()
377            && self
378                .expr
379                .iter()
380                .zip(other.expr.iter())
381                .all(|((expr1, name1), (expr2, name2))| expr1.eq(expr2) && name1 == name2)
382            && self.null_expr.len() == other.null_expr.len()
383            && self
384                .null_expr
385                .iter()
386                .zip(other.null_expr.iter())
387                .all(|((expr1, name1), (expr2, name2))| expr1.eq(expr2) && name1 == name2)
388            && self.groups == other.groups
389            && self.has_grouping_set == other.has_grouping_set
390    }
391}
392
393#[expect(clippy::large_enum_variant)]
394enum StreamType {
395    AggregateStream(AggregateStream),
396    GroupedHash(GroupedHashAggregateStream),
397    GroupedPriorityQueue(GroupedTopKAggregateStream),
398}
399
400impl From<StreamType> for SendableRecordBatchStream {
401    fn from(stream: StreamType) -> Self {
402        match stream {
403            StreamType::AggregateStream(stream) => Box::pin(stream),
404            StreamType::GroupedHash(stream) => Box::pin(stream),
405            StreamType::GroupedPriorityQueue(stream) => Box::pin(stream),
406        }
407    }
408}
409
410/// # Aggregate Dynamic Filter Pushdown Overview
411///
412/// For queries like
413///   -- `example_table(type TEXT, val INT)`
414///   SELECT min(val)
415///   FROM example_table
416///   WHERE type='A';
417///
418/// And `example_table`'s physical representation is a partitioned parquet file with
419/// column statistics
420/// - part-0.parquet: val {min=0, max=100}
421/// - part-1.parquet: val {min=100, max=200}
422/// - ...
423/// - part-100.parquet: val {min=10000, max=10100}
424///
425/// After scanning the 1st file, we know we only have to read files if their minimal
426/// value on `val` column is less than 0, the minimal `val` value in the 1st file.
427///
428/// We can skip scanning the remaining file by implementing dynamic filter, the
429/// intuition is we keep a shared data structure for current min in both `AggregateExec`
430/// and `DataSourceExec`, and let it update during execution, so the scanner can
431/// know during execution if it's possible to skip scanning certain files. See
432/// physical optimizer rule `FilterPushdown` for details.
433///
434/// # Implementation
435///
436/// ## Enable Condition
437/// - No grouping (no `GROUP BY` clause in the sql, only a single global group to aggregate)
438/// - The aggregate expression must be `min`/`max`, and evaluate directly on columns.
439///   Note multiple aggregate expressions that satisfy this requirement are allowed,
440///   and a dynamic filter will be constructed combining all applicable expr's
441///   states. See more in the following example with dynamic filter on multiple columns.
442///
443/// ## Filter Construction
444/// The filter is kept in the `DataSourceExec`, and it will gets update during execution,
445/// the reader will interpret it as "the upstream only needs rows that such filter
446/// predicate is evaluated to true", and certain scanner implementation like `parquet`
447/// can evalaute column statistics on those dynamic filters, to decide if they can
448/// prune a whole range.
449///
450/// ### Examples
451/// - Expr: `min(a)`, Dynamic Filter: `a < a_cur_min`
452/// - Expr: `min(a), max(a), min(b)`, Dynamic Filter: `(a < a_cur_min) OR (a > a_cur_max) OR (b < b_cur_min)`
453#[derive(Debug, Clone)]
454struct AggrDynFilter {
455    /// The physical expr for the dynamic filter shared between the `AggregateExec`
456    /// and the parquet scanner.
457    filter: Arc<DynamicFilterPhysicalExpr>,
458    /// The current bounds for the dynamic filter, updates during the execution to
459    /// tighten the bound for more effective pruning.
460    ///
461    /// Each vector element is for the accumulators that support dynamic filter.
462    /// e.g. This `AggregateExec` has accumulator:
463    /// min(a), avg(a), max(b)
464    /// And this field stores [PerAccumulatorDynFilter(min(a)), PerAccumulatorDynFilter(min(b))]
465    supported_accumulators_info: Vec<PerAccumulatorDynFilter>,
466}
467
468// ---- Aggregate Dynamic Filter Utility Structs ----
469
470/// Aggregate expressions that support the dynamic filter pushdown in aggregation.
471/// See comments in [`AggrDynFilter`] for conditions.
472#[derive(Debug, Clone)]
473struct PerAccumulatorDynFilter {
474    aggr_type: DynamicFilterAggregateType,
475    /// During planning and optimization, the parent structure is kept in `AggregateExec`,
476    /// this index is into `aggr_expr` vec inside `AggregateExec`.
477    /// During execution, the parent struct is moved into `AggregateStream` (stream
478    /// for no grouping aggregate execution), and this index is into    `aggregate_expressions`
479    /// vec inside `AggregateStreamInner`
480    aggr_index: usize,
481    // The current bound. Shared among all streams.
482    shared_bound: Arc<Mutex<ScalarValue>>,
483}
484
/// Aggregate types that are supported for dynamic filter in `AggregateExec`
#[derive(Debug, Clone)]
enum DynamicFilterAggregateType {
    Min,
    Max,
}
491
492/// Hash aggregate execution plan
493#[derive(Debug, Clone)]
494pub struct AggregateExec {
495    /// Aggregation mode (full, partial)
496    mode: AggregateMode,
497    /// Group by expressions
498    group_by: PhysicalGroupBy,
499    /// Aggregate expressions
500    aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
501    /// FILTER (WHERE clause) expression for each aggregate expression
502    filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
503    /// Set if the output of this aggregation is truncated by a upstream sort/limit clause
504    limit: Option<usize>,
505    /// Input plan, could be a partial aggregate or the input to the aggregate
506    pub input: Arc<dyn ExecutionPlan>,
507    /// Schema after the aggregate is applied
508    schema: SchemaRef,
509    /// Input schema before any aggregation is applied. For partial aggregate this will be the
510    /// same as input.schema() but for the final aggregate it will be the same as the input
511    /// to the partial aggregate, i.e., partial and final aggregates have same `input_schema`.
512    /// We need the input schema of partial aggregate to be able to deserialize aggregate
513    /// expressions from protobuf for final aggregate.
514    pub input_schema: SchemaRef,
515    /// Execution metrics
516    metrics: ExecutionPlanMetricsSet,
517    required_input_ordering: Option<OrderingRequirements>,
518    /// Describes how the input is ordered relative to the group by columns
519    input_order_mode: InputOrderMode,
520    cache: PlanProperties,
521    /// During initialization, if the plan supports dynamic filtering (see [`AggrDynFilter`]),
522    /// it is set to `Some(..)` regardless of whether it can be pushed down to a child node.
523    ///
524    /// During filter pushdown optimization, if a child node can accept this filter,
525    /// it remains `Some(..)` to enable dynamic filtering during aggregate execution;
526    /// otherwise, it is cleared to `None`.
527    dynamic_filter: Option<Arc<AggrDynFilter>>,
528}
529
530impl AggregateExec {
531    /// Function used in `OptimizeAggregateOrder` optimizer rule,
532    /// where we need parts of the new value, others cloned from the old one
533    /// Rewrites aggregate exec with new aggregate expressions.
534    pub fn with_new_aggr_exprs(
535        &self,
536        aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
537    ) -> Self {
538        Self {
539            aggr_expr,
540            // clone the rest of the fields
541            required_input_ordering: self.required_input_ordering.clone(),
542            metrics: ExecutionPlanMetricsSet::new(),
543            input_order_mode: self.input_order_mode.clone(),
544            cache: self.cache.clone(),
545            mode: self.mode,
546            group_by: self.group_by.clone(),
547            filter_expr: self.filter_expr.clone(),
548            limit: self.limit,
549            input: Arc::clone(&self.input),
550            schema: Arc::clone(&self.schema),
551            input_schema: Arc::clone(&self.input_schema),
552            dynamic_filter: self.dynamic_filter.clone(),
553        }
554    }
555
556    pub fn cache(&self) -> &PlanProperties {
557        &self.cache
558    }
559
560    /// Create a new hash aggregate execution plan
561    pub fn try_new(
562        mode: AggregateMode,
563        group_by: PhysicalGroupBy,
564        aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
565        filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
566        input: Arc<dyn ExecutionPlan>,
567        input_schema: SchemaRef,
568    ) -> Result<Self> {
569        let schema = create_schema(&input.schema(), &group_by, &aggr_expr, mode)?;
570
571        let schema = Arc::new(schema);
572        AggregateExec::try_new_with_schema(
573            mode,
574            group_by,
575            aggr_expr,
576            filter_expr,
577            input,
578            input_schema,
579            schema,
580        )
581    }
582
583    /// Create a new hash aggregate execution plan with the given schema.
584    /// This constructor isn't part of the public API, it is used internally
585    /// by DataFusion to enforce schema consistency during when re-creating
586    /// `AggregateExec`s inside optimization rules. Schema field names of an
587    /// `AggregateExec` depends on the names of aggregate expressions. Since
588    /// a rule may re-write aggregate expressions (e.g. reverse them) during
589    /// initialization, field names may change inadvertently if one re-creates
590    /// the schema in such cases.
591    fn try_new_with_schema(
592        mode: AggregateMode,
593        group_by: PhysicalGroupBy,
594        mut aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
595        filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
596        input: Arc<dyn ExecutionPlan>,
597        input_schema: SchemaRef,
598        schema: SchemaRef,
599    ) -> Result<Self> {
600        // Make sure arguments are consistent in size
601        assert_eq_or_internal_err!(
602            aggr_expr.len(),
603            filter_expr.len(),
604            "Inconsistent aggregate expr: {:?} and filter expr: {:?} for AggregateExec, their size should match",
605            aggr_expr,
606            filter_expr
607        );
608
609        let input_eq_properties = input.equivalence_properties();
610        // Get GROUP BY expressions:
611        let groupby_exprs = group_by.input_exprs();
612        // If existing ordering satisfies a prefix of the GROUP BY expressions,
613        // prefix requirements with this section. In this case, aggregation will
614        // work more efficiently.
615        // Copy the `PhysicalSortExpr`s to retain the sort options.
616        let (new_sort_exprs, indices) =
617            input_eq_properties.find_longest_permutation(&groupby_exprs)?;
618
619        let mut new_requirements = new_sort_exprs
620            .into_iter()
621            .map(PhysicalSortRequirement::from)
622            .collect::<Vec<_>>();
623
624        let req = get_finer_aggregate_exprs_requirement(
625            &mut aggr_expr,
626            &group_by,
627            input_eq_properties,
628            &mode,
629        )?;
630        new_requirements.extend(req);
631
632        let required_input_ordering =
633            LexRequirement::new(new_requirements).map(OrderingRequirements::new_soft);
634
635        // If our aggregation has grouping sets then our base grouping exprs will
636        // be expanded based on the flags in `group_by.groups` where for each
637        // group we swap the grouping expr for `null` if the flag is `true`
638        // That means that each index in `indices` is valid if and only if
639        // it is not null in every group
640        let indices: Vec<usize> = indices
641            .into_iter()
642            .filter(|idx| group_by.groups.iter().all(|group| !group[*idx]))
643            .collect();
644
645        let input_order_mode = if indices.len() == groupby_exprs.len()
646            && !indices.is_empty()
647            && group_by.groups.len() == 1
648        {
649            InputOrderMode::Sorted
650        } else if !indices.is_empty() {
651            InputOrderMode::PartiallySorted(indices)
652        } else {
653            InputOrderMode::Linear
654        };
655
656        // construct a map from the input expression to the output expression of the Aggregation group by
657        let group_expr_mapping =
658            ProjectionMapping::try_new(group_by.expr.clone(), &input.schema())?;
659
660        let cache = Self::compute_properties(
661            &input,
662            Arc::clone(&schema),
663            &group_expr_mapping,
664            &mode,
665            &input_order_mode,
666            aggr_expr.as_slice(),
667        )?;
668
669        let mut exec = AggregateExec {
670            mode,
671            group_by,
672            aggr_expr,
673            filter_expr,
674            input,
675            schema,
676            input_schema,
677            metrics: ExecutionPlanMetricsSet::new(),
678            required_input_ordering,
679            limit: None,
680            input_order_mode,
681            cache,
682            dynamic_filter: None,
683        };
684
685        exec.init_dynamic_filter();
686
687        Ok(exec)
688    }
689
690    /// Aggregation mode (full, partial)
691    pub fn mode(&self) -> &AggregateMode {
692        &self.mode
693    }
694
695    /// Set the `limit` of this AggExec
696    pub fn with_limit(mut self, limit: Option<usize>) -> Self {
697        self.limit = limit;
698        self
699    }
700    /// Grouping expressions
701    pub fn group_expr(&self) -> &PhysicalGroupBy {
702        &self.group_by
703    }
704
705    /// Grouping expressions as they occur in the output schema
706    pub fn output_group_expr(&self) -> Vec<Arc<dyn PhysicalExpr>> {
707        self.group_by.output_exprs()
708    }
709
710    /// Aggregate expressions
711    pub fn aggr_expr(&self) -> &[Arc<AggregateFunctionExpr>] {
712        &self.aggr_expr
713    }
714
715    /// FILTER (WHERE clause) expression for each aggregate expression
716    pub fn filter_expr(&self) -> &[Option<Arc<dyn PhysicalExpr>>] {
717        &self.filter_expr
718    }
719
720    /// Input plan
721    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
722        &self.input
723    }
724
725    /// Get the input schema before any aggregates are applied
726    pub fn input_schema(&self) -> SchemaRef {
727        Arc::clone(&self.input_schema)
728    }
729
730    /// number of rows soft limit of the AggregateExec
731    pub fn limit(&self) -> Option<usize> {
732        self.limit
733    }
734
735    fn execute_typed(
736        &self,
737        partition: usize,
738        context: &Arc<TaskContext>,
739    ) -> Result<StreamType> {
740        if self.group_by.is_true_no_grouping() {
741            return Ok(StreamType::AggregateStream(AggregateStream::new(
742                self, context, partition,
743            )?));
744        }
745
746        // grouping by an expression that has a sort/limit upstream
747        if let Some(limit) = self.limit
748            && !self.is_unordered_unfiltered_group_by_distinct()
749        {
750            return Ok(StreamType::GroupedPriorityQueue(
751                GroupedTopKAggregateStream::new(self, context, partition, limit)?,
752            ));
753        }
754
755        // grouping by something else and we need to just materialize all results
756        Ok(StreamType::GroupedHash(GroupedHashAggregateStream::new(
757            self, context, partition,
758        )?))
759    }
760
761    /// Finds the DataType and SortDirection for this Aggregate, if there is one
762    pub fn get_minmax_desc(&self) -> Option<(FieldRef, bool)> {
763        let agg_expr = self.aggr_expr.iter().exactly_one().ok()?;
764        agg_expr.get_minmax_desc()
765    }
766
767    /// true, if this Aggregate has a group-by with no required or explicit ordering,
768    /// no filtering and no aggregate expressions
769    /// This method qualifies the use of the LimitedDistinctAggregation rewrite rule
770    /// on an AggregateExec.
771    pub fn is_unordered_unfiltered_group_by_distinct(&self) -> bool {
772        // ensure there is a group by
773        if self.group_expr().is_empty() && !self.group_expr().has_grouping_set() {
774            return false;
775        }
776        // ensure there are no aggregate expressions
777        if !self.aggr_expr().is_empty() {
778            return false;
779        }
780        // ensure there are no filters on aggregate expressions; the above check
781        // may preclude this case
782        if self.filter_expr().iter().any(|e| e.is_some()) {
783            return false;
784        }
785        // ensure there are no order by expressions
786        if !self.aggr_expr().iter().all(|e| e.order_bys().is_empty()) {
787            return false;
788        }
789        // ensure there is no output ordering; can this rule be relaxed?
790        if self.properties().output_ordering().is_some() {
791            return false;
792        }
793        // ensure no ordering is required on the input
794        if let Some(requirement) = self.required_input_ordering().swap_remove(0) {
795            return matches!(requirement, OrderingRequirements::Hard(_));
796        }
797        true
798    }
799
    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
    ///
    /// # Parameters
    /// - `input`: child plan whose properties are projected and propagated.
    /// - `schema`: output schema of the aggregation.
    /// - `group_expr_mapping`: mapping from input expressions to output
    ///   group-by columns.
    /// - `mode`: aggregation mode (partial / final / single variants).
    /// - `input_order_mode`: relationship between input ordering and group-by.
    /// - `aggr_exprs`: the aggregate expressions computed by this node.
    pub fn compute_properties(
        input: &Arc<dyn ExecutionPlan>,
        schema: SchemaRef,
        group_expr_mapping: &ProjectionMapping,
        mode: &AggregateMode,
        input_order_mode: &InputOrderMode,
        aggr_exprs: &[Arc<AggregateFunctionExpr>],
    ) -> Result<PlanProperties> {
        // Construct equivalence properties:
        let mut eq_properties = input
            .equivalence_properties()
            .project(group_expr_mapping, schema);

        // If the group by is empty, then we ensure that the operator will produce
        // only one row, and mark the generated result as a constant value.
        if group_expr_mapping.is_empty() {
            let new_constants = aggr_exprs.iter().enumerate().map(|(idx, func)| {
                let column = Arc::new(Column::new(func.name(), idx));
                ConstExpr::from(column as Arc<dyn PhysicalExpr>)
            });
            eq_properties.add_constants(new_constants)?;
        }

        // Group by expression will be a distinct value after the aggregation.
        // Add it into the constraint set.
        let mut constraints = eq_properties.constraints().to_vec();
        let new_constraint = Constraint::Unique(
            group_expr_mapping
                .iter()
                .flat_map(|(_, target_cols)| {
                    // Only direct column references participate in the
                    // uniqueness constraint; other expressions are skipped.
                    target_cols.iter().flat_map(|(expr, _)| {
                        expr.as_any().downcast_ref::<Column>().map(|c| c.index())
                    })
                })
                .collect(),
        );
        constraints.push(new_constraint);
        eq_properties =
            eq_properties.with_constraints(Constraints::new_unverified(constraints));

        // Get output partitioning:
        let input_partitioning = input.output_partitioning().clone();
        let output_partitioning = if mode.is_first_stage() {
            // First stage aggregation will not change the output partitioning,
            // but needs to respect aliases (e.g. mapping in the GROUP BY
            // expression).
            let input_eq_properties = input.equivalence_properties();
            input_partitioning.project(group_expr_mapping, input_eq_properties)
        } else {
            input_partitioning.clone()
        };

        // TODO: Emission type and boundedness information can be enhanced here
        let emission_type = if *input_order_mode == InputOrderMode::Linear {
            // With a linear (unordered) input, no group is complete until all
            // input has been consumed, so results emit only at the end.
            EmissionType::Final
        } else {
            input.pipeline_behavior()
        };

        Ok(PlanProperties::new(
            eq_properties,
            output_partitioning,
            emission_type,
            input.boundedness(),
        ))
    }
867
    /// Returns how the input's ordering relates to the group-by expressions
    /// (`Ordered`, `PartiallyOrdered`, or `Linear`).
    pub fn input_order_mode(&self) -> &InputOrderMode {
        &self.input_order_mode
    }
871
    /// Derives this aggregation's output statistics from the child's
    /// statistics.
    ///
    /// Group-by columns that are plain column references inherit min/max from
    /// the corresponding input column. Row counts are exact only where the
    /// output size is trivially known (global final aggregate => 1 row, or
    /// input of 0/1 rows); otherwise precision is degraded to inexact.
    fn statistics_inner(&self, child_statistics: &Statistics) -> Result<Statistics> {
        // TODO stats: group expressions:
        // - once expressions will be able to compute their own stats, use it here
        // - case where we group by on a column for which with have the `distinct` stat
        // TODO stats: aggr expression:
        // - aggregations sometimes also preserve invariants such as min, max...

        let column_statistics = {
            // self.schema: [<group by exprs>, <aggregate exprs>]
            let mut column_statistics = Statistics::unknown_column(&self.schema());

            // Min/max of a group-by column are preserved by grouping, so copy
            // them from the matching input column when the expression is a
            // direct column reference.
            for (idx, (expr, _)) in self.group_by.expr.iter().enumerate() {
                if let Some(col) = expr.as_any().downcast_ref::<Column>() {
                    column_statistics[idx].max_value = child_statistics.column_statistics
                        [col.index()]
                    .max_value
                    .clone();

                    column_statistics[idx].min_value = child_statistics.column_statistics
                        [col.index()]
                    .min_value
                    .clone();
                }
            }

            column_statistics
        };
        match self.mode {
            // A final aggregation with no group-by always produces exactly
            // one output row.
            AggregateMode::Final | AggregateMode::FinalPartitioned
                if self.group_by.expr.is_empty() =>
            {
                let total_byte_size =
                    Self::calculate_scaled_byte_size(child_statistics, 1);

                Ok(Statistics {
                    num_rows: Precision::Exact(1),
                    column_statistics,
                    total_byte_size,
                })
            }
            _ => {
                // When the input row count is 1, we can adopt that statistic keeping its reliability.
                // When it is larger than 1, we degrade the precision since it may decrease after aggregation.
                let num_rows = if let Some(value) = child_statistics.num_rows.get_value()
                {
                    if *value > 1 {
                        child_statistics.num_rows.to_inexact()
                    } else if *value == 0 {
                        child_statistics.num_rows
                    } else {
                        // num_rows = 1 case
                        // A single input row can still expand to one output row
                        // per grouping set.
                        let grouping_set_num = self.group_by.groups.len();
                        child_statistics.num_rows.map(|x| x * grouping_set_num)
                    }
                } else {
                    Precision::Absent
                };

                // Scale the byte size proportionally to the estimated output
                // row count; always inexact since row sizes are assumed uniform.
                let total_byte_size = num_rows
                    .get_value()
                    .and_then(|&output_rows| {
                        Self::calculate_scaled_byte_size(child_statistics, output_rows)
                            .get_value()
                            .map(|&bytes| Precision::Inexact(bytes))
                    })
                    .unwrap_or(Precision::Absent);

                Ok(Statistics {
                    num_rows,
                    column_statistics,
                    total_byte_size,
                })
            }
        }
    }
947
    /// Check if dynamic filter is possible for the current plan node.
    /// - If yes, init one inside `AggregateExec`'s `dynamic_filter` field.
    /// - If not supported, `self.dynamic_filter` should be kept `None`
    ///
    /// Only a partial-mode aggregation with no group-by qualifies, and only
    /// `min`/`max` aggregates over a single column reference contribute to
    /// the filter.
    fn init_dynamic_filter(&mut self) {
        // Unsupported shape: anything with a group-by, or any non-partial mode.
        if (!self.group_by.is_empty()) || (!matches!(self.mode, AggregateMode::Partial)) {
            debug_assert!(
                self.dynamic_filter.is_none(),
                "The current operator node does not support dynamic filter"
            );
            return;
        }

        // Already initialized.
        if self.dynamic_filter.is_some() {
            return;
        }

        // Collect supported accumulators
        // It is assumed the order of aggregate expressions are not changed from `AggregateExec`
        // to `AggregateStream`
        let mut aggr_dyn_filters = Vec::new();
        // All column references in the dynamic filter, used when initializing the dynamic
        // filter, and it's used to decide if this dynamic filter is able to get push
        // through certain node during optimization.
        let mut all_cols: Vec<Arc<dyn PhysicalExpr>> = Vec::new();
        for (i, aggr_expr) in self.aggr_expr.iter().enumerate() {
            // 1. Only `min` or `max` aggregate function
            let fun_name = aggr_expr.fun().name();
            // HACK: Should check the function type more precisely
            // Issue: <https://github.com/apache/datafusion/issues/18643>
            let aggr_type = if fun_name.eq_ignore_ascii_case("min") {
                DynamicFilterAggregateType::Min
            } else if fun_name.eq_ignore_ascii_case("max") {
                DynamicFilterAggregateType::Max
            } else {
                continue;
            };

            // 2. arg should be only 1 column reference
            if let [arg] = aggr_expr.expressions().as_slice()
                && arg.as_any().is::<Column>()
            {
                all_cols.push(Arc::clone(arg));
                aggr_dyn_filters.push(PerAccumulatorDynFilter {
                    aggr_type,
                    aggr_index: i,
                    // Bound starts as Null (i.e. "no bound observed yet").
                    shared_bound: Arc::new(Mutex::new(ScalarValue::Null)),
                });
            }
        }

        // Start with a trivially-true filter; accumulators tighten it at
        // runtime via `shared_bound`.
        if !aggr_dyn_filters.is_empty() {
            self.dynamic_filter = Some(Arc::new(AggrDynFilter {
                filter: Arc::new(DynamicFilterPhysicalExpr::new(all_cols, lit(true))),
                supported_accumulators_info: aggr_dyn_filters,
            }))
        }
    }
1006
1007    /// Calculate scaled byte size based on row count ratio.
1008    /// Returns `Precision::Absent` if input statistics are insufficient.
1009    /// Returns `Precision::Inexact` with the scaled value otherwise.
1010    ///
1011    /// This is a simple heuristic that assumes uniform row sizes.
1012    #[inline]
1013    fn calculate_scaled_byte_size(
1014        input_stats: &Statistics,
1015        target_row_count: usize,
1016    ) -> Precision<usize> {
1017        match (
1018            input_stats.num_rows.get_value(),
1019            input_stats.total_byte_size.get_value(),
1020        ) {
1021            (Some(&input_rows), Some(&input_bytes)) if input_rows > 0 => {
1022                let bytes_per_row = input_bytes as f64 / input_rows as f64;
1023                let scaled_bytes =
1024                    (bytes_per_row * target_row_count as f64).ceil() as usize;
1025                Precision::Inexact(scaled_bytes)
1026            }
1027            _ => Precision::Absent,
1028        }
1029    }
1030}
1031
1032impl DisplayAs for AggregateExec {
1033    fn fmt_as(
1034        &self,
1035        t: DisplayFormatType,
1036        f: &mut std::fmt::Formatter,
1037    ) -> std::fmt::Result {
1038        match t {
1039            DisplayFormatType::Default | DisplayFormatType::Verbose => {
1040                let format_expr_with_alias =
1041                    |(e, alias): &(Arc<dyn PhysicalExpr>, String)| -> String {
1042                        let e = e.to_string();
1043                        if &e != alias {
1044                            format!("{e} as {alias}")
1045                        } else {
1046                            e
1047                        }
1048                    };
1049
1050                write!(f, "AggregateExec: mode={:?}", self.mode)?;
1051                let g: Vec<String> = if self.group_by.is_single() {
1052                    self.group_by
1053                        .expr
1054                        .iter()
1055                        .map(format_expr_with_alias)
1056                        .collect()
1057                } else {
1058                    self.group_by
1059                        .groups
1060                        .iter()
1061                        .map(|group| {
1062                            let terms = group
1063                                .iter()
1064                                .enumerate()
1065                                .map(|(idx, is_null)| {
1066                                    if *is_null {
1067                                        format_expr_with_alias(
1068                                            &self.group_by.null_expr[idx],
1069                                        )
1070                                    } else {
1071                                        format_expr_with_alias(&self.group_by.expr[idx])
1072                                    }
1073                                })
1074                                .collect::<Vec<String>>()
1075                                .join(", ");
1076                            format!("({terms})")
1077                        })
1078                        .collect()
1079                };
1080
1081                write!(f, ", gby=[{}]", g.join(", "))?;
1082
1083                let a: Vec<String> = self
1084                    .aggr_expr
1085                    .iter()
1086                    .map(|agg| agg.name().to_string())
1087                    .collect();
1088                write!(f, ", aggr=[{}]", a.join(", "))?;
1089                if let Some(limit) = self.limit {
1090                    write!(f, ", lim=[{limit}]")?;
1091                }
1092
1093                if self.input_order_mode != InputOrderMode::Linear {
1094                    write!(f, ", ordering_mode={:?}", self.input_order_mode)?;
1095                }
1096            }
1097            DisplayFormatType::TreeRender => {
1098                let format_expr_with_alias =
1099                    |(e, alias): &(Arc<dyn PhysicalExpr>, String)| -> String {
1100                        let expr_sql = fmt_sql(e.as_ref()).to_string();
1101                        if &expr_sql != alias {
1102                            format!("{expr_sql} as {alias}")
1103                        } else {
1104                            expr_sql
1105                        }
1106                    };
1107
1108                let g: Vec<String> = if self.group_by.is_single() {
1109                    self.group_by
1110                        .expr
1111                        .iter()
1112                        .map(format_expr_with_alias)
1113                        .collect()
1114                } else {
1115                    self.group_by
1116                        .groups
1117                        .iter()
1118                        .map(|group| {
1119                            let terms = group
1120                                .iter()
1121                                .enumerate()
1122                                .map(|(idx, is_null)| {
1123                                    if *is_null {
1124                                        format_expr_with_alias(
1125                                            &self.group_by.null_expr[idx],
1126                                        )
1127                                    } else {
1128                                        format_expr_with_alias(&self.group_by.expr[idx])
1129                                    }
1130                                })
1131                                .collect::<Vec<String>>()
1132                                .join(", ");
1133                            format!("({terms})")
1134                        })
1135                        .collect()
1136                };
1137                let a: Vec<String> = self
1138                    .aggr_expr
1139                    .iter()
1140                    .map(|agg| agg.human_display().to_string())
1141                    .collect();
1142                writeln!(f, "mode={:?}", self.mode)?;
1143                if !g.is_empty() {
1144                    writeln!(f, "group_by={}", g.join(", "))?;
1145                }
1146                if !a.is_empty() {
1147                    writeln!(f, "aggr={}", a.join(", "))?;
1148                }
1149            }
1150        }
1151        Ok(())
1152    }
1153}
1154
impl ExecutionPlan for AggregateExec {
    fn name(&self) -> &'static str {
        "AggregateExec"
    }

    /// Return a reference to Any that can be used for down-casting
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns the plan properties precomputed by `compute_properties` and
    /// cached at construction time.
    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    /// Required input distribution depends on the aggregation mode:
    /// - `Partial`: each partition pre-aggregates independently, so any
    ///   distribution works.
    /// - `FinalPartitioned`/`SinglePartitioned`: input must be hash-partitioned
    ///   on the group-by expressions so each group lands in one partition.
    /// - `Final`/`Single`: all input must be in a single partition.
    fn required_input_distribution(&self) -> Vec<Distribution> {
        match &self.mode {
            AggregateMode::Partial => {
                vec![Distribution::UnspecifiedDistribution]
            }
            AggregateMode::FinalPartitioned | AggregateMode::SinglePartitioned => {
                vec![Distribution::HashPartitioned(self.group_by.input_exprs())]
            }
            AggregateMode::Final | AggregateMode::Single => {
                vec![Distribution::SinglePartition]
            }
        }
    }

    /// The ordering requirement (if any) computed at construction time for
    /// the single child.
    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
        vec![self.required_input_ordering.clone()]
    }

    /// The output ordering of [`AggregateExec`] is determined by its `group_by`
    /// columns. Although this method is not explicitly used by any optimizer
    /// rules yet, overriding the default implementation ensures that it
    /// accurately reflects the actual behavior.
    ///
    /// If the [`InputOrderMode`] is `Linear`, the `group_by` columns don't have
    /// an ordering, which means the results do not either. However, in the
    /// `Ordered` and `PartiallyOrdered` cases, the `group_by` columns do have
    /// an ordering, which is preserved in the output.
    fn maintains_input_order(&self) -> Vec<bool> {
        vec![self.input_order_mode != InputOrderMode::Linear]
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    /// Rebuilds this node with a new input. `limit` and `dynamic_filter` are
    /// not constructor arguments, so they are carried over explicitly.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let mut me = AggregateExec::try_new_with_schema(
            self.mode,
            self.group_by.clone(),
            self.aggr_expr.clone(),
            self.filter_expr.clone(),
            Arc::clone(&children[0]),
            Arc::clone(&self.input_schema),
            Arc::clone(&self.schema),
        )?;
        me.limit = self.limit;
        me.dynamic_filter = self.dynamic_filter.clone();

        Ok(Arc::new(me))
    }

    /// Delegates to `execute_typed` and erases the concrete stream type.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        self.execute_typed(partition, &context)
            .map(|stream| stream.into())
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    /// Whole-plan statistics: equivalent to `partition_statistics(None)`.
    fn statistics(&self) -> Result<Statistics> {
        self.partition_statistics(None)
    }

    /// Per-partition (or whole-plan, when `partition` is `None`) statistics,
    /// derived from the child's statistics via `statistics_inner`.
    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
        let child_statistics = self.input().partition_statistics(partition)?;
        self.statistics_inner(&child_statistics)
    }

    /// Aggregation never produces more rows than its input.
    fn cardinality_effect(&self) -> CardinalityEffect {
        CardinalityEffect::LowerEqual
    }

    /// Push down parent filters when possible (see implementation comment for details),
    /// and also pushdown self dynamic filters (see `AggrDynFilter` for details)
    fn gather_filters_for_pushdown(
        &self,
        phase: FilterPushdownPhase,
        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
        config: &ConfigOptions,
    ) -> Result<FilterDescription> {
        // It's safe to push down filters through aggregates when filters only reference
        // grouping columns, because such filters determine which groups to compute, not
        // *how* to compute them. Each group's aggregate values (SUM, COUNT, etc.) are
        // calculated from the same input rows regardless of whether we filter before or
        // after grouping - filtering before just eliminates entire groups early.
        // This optimization is NOT safe for filters on aggregated columns (like filtering on
        // the result of SUM or COUNT), as those require computing all groups first.

        let grouping_columns: HashSet<_> = self
            .group_by
            .expr()
            .iter()
            .flat_map(|(expr, _)| collect_columns(expr))
            .collect();

        // Analyze each filter separately to determine if it can be pushed down
        let mut safe_filters = Vec::new();
        let mut unsafe_filters = Vec::new();

        for filter in parent_filters {
            let filter_columns: HashSet<_> =
                collect_columns(&filter).into_iter().collect();

            // Check if this filter references non-grouping columns
            let references_non_grouping = !grouping_columns.is_empty()
                && !filter_columns.is_subset(&grouping_columns);

            if references_non_grouping {
                unsafe_filters.push(filter);
                continue;
            }

            // For GROUPING SETS, verify this filter's columns appear in all grouping sets
            if self.group_by.groups().len() > 1 {
                // Map each filter column to the index of the first group-by
                // expression that contains it.
                let filter_column_indices: Vec<usize> = filter_columns
                    .iter()
                    .filter_map(|filter_col| {
                        self.group_by.expr().iter().position(|(expr, _)| {
                            collect_columns(expr).contains(filter_col)
                        })
                    })
                    .collect();

                // Check if any of this filter's columns are missing from any grouping set
                // (a `true` entry in a null mask means the expression is
                // nulled-out, i.e. absent, in that grouping set).
                let has_missing_column = self.group_by.groups().iter().any(|null_mask| {
                    filter_column_indices
                        .iter()
                        .any(|&idx| null_mask.get(idx) == Some(&true))
                });

                if has_missing_column {
                    unsafe_filters.push(filter);
                    continue;
                }
            }

            // This filter is safe to push down
            safe_filters.push(filter);
        }

        // Build child filter description with both safe and unsafe filters
        let child = self.children()[0];
        let mut child_desc = ChildFilterDescription::from_child(&safe_filters, child)?;

        // Add unsafe filters as unsupported
        child_desc.parent_filters.extend(
            unsafe_filters
                .into_iter()
                .map(PushedDownPredicate::unsupported),
        );

        // Include self dynamic filter when it's possible
        if matches!(phase, FilterPushdownPhase::Post)
            && config.optimizer.enable_aggregate_dynamic_filter_pushdown
            && let Some(self_dyn_filter) = &self.dynamic_filter
        {
            let dyn_filter = Arc::clone(&self_dyn_filter.filter);
            child_desc = child_desc.with_self_filter(dyn_filter);
        }

        Ok(FilterDescription::new().with_child(child_desc))
    }

    /// If child accepts self's dynamic filter, keep `self.dynamic_filter` with Some,
    /// otherwise clear it to None.
    fn handle_child_pushdown_result(
        &self,
        phase: FilterPushdownPhase,
        child_pushdown_result: ChildPushdownResult,
        _config: &ConfigOptions,
    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
        let mut result = FilterPushdownPropagation::if_any(child_pushdown_result.clone());

        // If this node tried to pushdown some dynamic filter before, now we check
        // if the child accept the filter
        if matches!(phase, FilterPushdownPhase::Post) && self.dynamic_filter.is_some() {
            // let child_accepts_dyn_filter = child_pushdown_result
            //     .self_filters
            //     .first()
            //     .map(|filters| {
            //         assert_eq_or_internal_err!(
            //             filters.len(),
            //             1,
            //             "Aggregate only pushdown one self dynamic filter"
            //         );
            //         let filter = filters.get(0).unwrap(); // Asserted above
            //         Ok(matches!(filter.discriminant, PushedDown::Yes))
            //     })
            //     .unwrap_or_else(|| internal_err!("The length of self filters equals to the number of child of this ExecutionPlan, so it must be 1"))?;

            // HACK: The above snippet should be used, however, now the child reply
            // `PushDown::No` can indicate they're not able to push down row-level
            // filter, but still keep the filter for statistics pruning.
            // So here, we try to use ref count to determine if the dynamic filter
            // has actually be pushed down.
            // Issue: <https://github.com/apache/datafusion/issues/18856>
            let dyn_filter = self.dynamic_filter.as_ref().unwrap();
            let child_accepts_dyn_filter = Arc::strong_count(dyn_filter) > 1;

            if !child_accepts_dyn_filter {
                // Child can't consume the self dynamic filter, so disable it by setting
                // to `None`
                let mut new_node = self.clone();
                new_node.dynamic_filter = None;

                result = result
                    .with_updated_node(Arc::new(new_node) as Arc<dyn ExecutionPlan>);
            }
        }

        Ok(result)
    }
}
1390
1391fn create_schema(
1392    input_schema: &Schema,
1393    group_by: &PhysicalGroupBy,
1394    aggr_expr: &[Arc<AggregateFunctionExpr>],
1395    mode: AggregateMode,
1396) -> Result<Schema> {
1397    let mut fields = Vec::with_capacity(group_by.num_output_exprs() + aggr_expr.len());
1398    fields.extend(group_by.output_fields(input_schema)?);
1399
1400    match mode {
1401        AggregateMode::Partial => {
1402            // in partial mode, the fields of the accumulator's state
1403            for expr in aggr_expr {
1404                fields.extend(expr.state_fields()?.iter().cloned());
1405            }
1406        }
1407        AggregateMode::Final
1408        | AggregateMode::FinalPartitioned
1409        | AggregateMode::Single
1410        | AggregateMode::SinglePartitioned => {
1411            // in final mode, the field with the final result of the accumulator
1412            for expr in aggr_expr {
1413                fields.push(expr.field())
1414            }
1415        }
1416    }
1417
1418    Ok(Schema::new_with_metadata(
1419        fields,
1420        input_schema.metadata().clone(),
1421    ))
1422}
1423
1424/// Determines the lexical ordering requirement for an aggregate expression.
1425///
1426/// # Parameters
1427///
1428/// - `aggr_expr`: A reference to an `AggregateFunctionExpr` representing the
1429///   aggregate expression.
1430/// - `group_by`: A reference to a `PhysicalGroupBy` instance representing the
1431///   physical GROUP BY expression.
1432/// - `agg_mode`: A reference to an `AggregateMode` instance representing the
1433///   mode of aggregation.
1434/// - `include_soft_requirement`: When `false`, only hard requirements are
1435///   considered, as indicated by [`AggregateFunctionExpr::order_sensitivity`]
1436///   returning [`AggregateOrderSensitivity::HardRequirement`].
1437///   Otherwise, also soft requirements ([`AggregateOrderSensitivity::SoftRequirement`])
1438///   are considered.
1439///
1440/// # Returns
1441///
1442/// A `LexOrdering` instance indicating the lexical ordering requirement for
1443/// the aggregate expression.
1444fn get_aggregate_expr_req(
1445    aggr_expr: &AggregateFunctionExpr,
1446    group_by: &PhysicalGroupBy,
1447    agg_mode: &AggregateMode,
1448    include_soft_requirement: bool,
1449) -> Option<LexOrdering> {
1450    // If the aggregation is performing a "second stage" calculation,
1451    // then ignore the ordering requirement. Ordering requirement applies
1452    // only to the aggregation input data.
1453    if !agg_mode.is_first_stage() {
1454        return None;
1455    }
1456
1457    match aggr_expr.order_sensitivity() {
1458        AggregateOrderSensitivity::Insensitive => return None,
1459        AggregateOrderSensitivity::HardRequirement => {}
1460        AggregateOrderSensitivity::SoftRequirement => {
1461            if !include_soft_requirement {
1462                return None;
1463            }
1464        }
1465        AggregateOrderSensitivity::Beneficial => return None,
1466    }
1467
1468    let mut sort_exprs = aggr_expr.order_bys().to_vec();
1469    // In non-first stage modes, we accumulate data (using `merge_batch`) from
1470    // different partitions (i.e. merge partial results). During this merge, we
1471    // consider the ordering of each partial result. Hence, we do not need to
1472    // use the ordering requirement in such modes as long as partial results are
1473    // generated with the correct ordering.
1474    if group_by.is_single() {
1475        // Remove all orderings that occur in the group by. These requirements
1476        // will definitely be satisfied -- Each group by expression will have
1477        // distinct values per group, hence all requirements are satisfied.
1478        let physical_exprs = group_by.input_exprs();
1479        sort_exprs.retain(|sort_expr| {
1480            !physical_exprs_contains(&physical_exprs, &sort_expr.expr)
1481        });
1482    }
1483    LexOrdering::new(sort_exprs)
1484}
1485
/// Concatenates the given slices into a freshly-allocated `Vec`.
pub fn concat_slices<T: Clone>(lhs: &[T], rhs: &[T]) -> Vec<T> {
    // Reserve the exact final size up front, then append both halves.
    let mut result = Vec::with_capacity(lhs.len() + rhs.len());
    result.extend_from_slice(lhs);
    result.extend_from_slice(rhs);
    result
}
1490
1491// Determines if the candidate ordering is finer than the current ordering.
1492// Returns `None` if they are incomparable, `Some(true)` if there is no current
1493// ordering or candidate ordering is finer, and `Some(false)` otherwise.
1494fn determine_finer(
1495    current: &Option<LexOrdering>,
1496    candidate: &LexOrdering,
1497) -> Option<bool> {
1498    if let Some(ordering) = current {
1499        candidate.partial_cmp(ordering).map(|cmp| cmp.is_gt())
1500    } else {
1501        Some(true)
1502    }
1503}
1504
/// Gets the common requirement that satisfies all the aggregate expressions.
/// When possible, chooses the requirement that is already satisfied by the
/// equivalence properties.
///
/// # Parameters
///
/// - `aggr_exprs`: A slice of `AggregateFunctionExpr` containing all the
///   aggregate expressions. Entries may be replaced in-place with their
///   reverse expressions when doing so lets all requirements agree.
/// - `group_by`: A reference to a `PhysicalGroupBy` instance representing the
///   physical GROUP BY expression.
/// - `eq_properties`: A reference to an `EquivalenceProperties` instance
///   representing equivalence properties for ordering.
/// - `agg_mode`: A reference to an `AggregateMode` instance representing the
///   mode of aggregation.
///
/// # Returns
///
/// A `Result<Vec<PhysicalSortRequirement>>` instance, which is the requirement
/// that satisfies all the aggregate requirements. Returns an error in case of
/// conflicting requirements.
pub fn get_finer_aggregate_exprs_requirement(
    aggr_exprs: &mut [Arc<AggregateFunctionExpr>],
    group_by: &PhysicalGroupBy,
    eq_properties: &EquivalenceProperties,
    agg_mode: &AggregateMode,
) -> Result<Vec<PhysicalSortRequirement>> {
    let mut requirement = None;

    // First try and find a match for all hard and soft requirements.
    // If a match can't be found, try a second time just matching hard
    // requirements.
    // NOTE(review): the loop visits `include_soft_requirement = false` (hard
    // only) before `true` (hard + soft), which appears reversed relative to
    // the comment above -- confirm which order is intended.
    for include_soft_requirement in [false, true] {
        for aggr_expr in aggr_exprs.iter_mut() {
            let Some(aggr_req) = get_aggregate_expr_req(
                aggr_expr,
                group_by,
                agg_mode,
                include_soft_requirement,
            )
            .and_then(|o| eq_properties.normalize_sort_exprs(o)) else {
                // There is no aggregate ordering requirement, or it is trivially
                // satisfied -- we can skip this expression.
                continue;
            };
            // If the common requirement is finer than the current expression's,
            // we can skip this expression. If the latter is finer than the former,
            // adopt it if it is satisfied by the equivalence properties. Otherwise,
            // defer the analysis to the reverse expression.
            let forward_finer = determine_finer(&requirement, &aggr_req);
            if let Some(finer) = forward_finer {
                if !finer {
                    continue;
                } else if eq_properties.ordering_satisfy(aggr_req.clone())? {
                    requirement = Some(aggr_req);
                    continue;
                }
            }
            if let Some(reverse_aggr_expr) = aggr_expr.reverse_expr() {
                let Some(rev_aggr_req) = get_aggregate_expr_req(
                    &reverse_aggr_expr,
                    group_by,
                    agg_mode,
                    include_soft_requirement,
                )
                .and_then(|o| eq_properties.normalize_sort_exprs(o)) else {
                    // The reverse requirement is trivially satisfied -- just reverse
                    // the expression and continue with the next one:
                    *aggr_expr = Arc::new(reverse_aggr_expr);
                    continue;
                };
                // If the common requirement is finer than the reverse expression's,
                // just reverse it and continue the loop with the next aggregate
                // expression. If the latter is finer than the former, adopt it if
                // it is satisfied by the equivalence properties. Otherwise, adopt
                // the forward expression.
                if let Some(finer) = determine_finer(&requirement, &rev_aggr_req) {
                    if !finer {
                        *aggr_expr = Arc::new(reverse_aggr_expr);
                    } else if eq_properties.ordering_satisfy(rev_aggr_req.clone())? {
                        *aggr_expr = Arc::new(reverse_aggr_expr);
                        requirement = Some(rev_aggr_req);
                    } else {
                        requirement = Some(aggr_req);
                    }
                } else if forward_finer.is_some() {
                    // The forward requirement was comparable (but unsatisfied
                    // by the equivalence properties), so fall back to it.
                    requirement = Some(aggr_req);
                } else {
                    // Neither the existing requirement nor the current aggregate
                    // requirement satisfy the other (forward or reverse), this
                    // means they are conflicting. This is a problem only for hard
                    // requirements. Unsatisfied soft requirements can be ignored.
                    if !include_soft_requirement {
                        return not_impl_err!(
                            "Conflicting ordering requirements in aggregate functions is not supported"
                        );
                    }
                }
            }
        }
    }

    Ok(requirement.map_or_else(Vec::new, |o| o.into_iter().map(Into::into).collect()))
}
1608
1609/// Returns physical expressions for arguments to evaluate against a batch.
1610///
1611/// The expressions are different depending on `mode`:
1612/// * Partial: AggregateFunctionExpr::expressions
1613/// * Final: columns of `AggregateFunctionExpr::state_fields()`
1614pub fn aggregate_expressions(
1615    aggr_expr: &[Arc<AggregateFunctionExpr>],
1616    mode: &AggregateMode,
1617    col_idx_base: usize,
1618) -> Result<Vec<Vec<Arc<dyn PhysicalExpr>>>> {
1619    match mode {
1620        AggregateMode::Partial
1621        | AggregateMode::Single
1622        | AggregateMode::SinglePartitioned => Ok(aggr_expr
1623            .iter()
1624            .map(|agg| {
1625                let mut result = agg.expressions();
1626                // Append ordering requirements to expressions' results. This
1627                // way order sensitive aggregators can satisfy requirement
1628                // themselves.
1629                result.extend(agg.order_bys().iter().map(|item| Arc::clone(&item.expr)));
1630                result
1631            })
1632            .collect()),
1633        // In this mode, we build the merge expressions of the aggregation.
1634        AggregateMode::Final | AggregateMode::FinalPartitioned => {
1635            let mut col_idx_base = col_idx_base;
1636            aggr_expr
1637                .iter()
1638                .map(|agg| {
1639                    let exprs = merge_expressions(col_idx_base, agg)?;
1640                    col_idx_base += exprs.len();
1641                    Ok(exprs)
1642                })
1643                .collect()
1644        }
1645    }
1646}
1647
1648/// uses `state_fields` to build a vec of physical column expressions required to merge the
1649/// AggregateFunctionExpr' accumulator's state.
1650///
1651/// `index_base` is the starting physical column index for the next expanded state field.
1652fn merge_expressions(
1653    index_base: usize,
1654    expr: &AggregateFunctionExpr,
1655) -> Result<Vec<Arc<dyn PhysicalExpr>>> {
1656    expr.state_fields().map(|fields| {
1657        fields
1658            .iter()
1659            .enumerate()
1660            .map(|(idx, f)| Arc::new(Column::new(f.name(), index_base + idx)) as _)
1661            .collect()
1662    })
1663}
1664
/// Boxed accumulator holding the in-progress state for one aggregate expression.
pub type AccumulatorItem = Box<dyn Accumulator>;
1666
1667pub fn create_accumulators(
1668    aggr_expr: &[Arc<AggregateFunctionExpr>],
1669) -> Result<Vec<AccumulatorItem>> {
1670    aggr_expr
1671        .iter()
1672        .map(|expr| expr.create_accumulator())
1673        .collect()
1674}
1675
1676/// returns a vector of ArrayRefs, where each entry corresponds to either the
1677/// final value (mode = Final, FinalPartitioned and Single) or states (mode = Partial)
1678pub fn finalize_aggregation(
1679    accumulators: &mut [AccumulatorItem],
1680    mode: &AggregateMode,
1681) -> Result<Vec<ArrayRef>> {
1682    match mode {
1683        AggregateMode::Partial => {
1684            // Build the vector of states
1685            accumulators
1686                .iter_mut()
1687                .map(|accumulator| {
1688                    accumulator.state().and_then(|e| {
1689                        e.iter()
1690                            .map(|v| v.to_array())
1691                            .collect::<Result<Vec<ArrayRef>>>()
1692                    })
1693                })
1694                .flatten_ok()
1695                .collect()
1696        }
1697        AggregateMode::Final
1698        | AggregateMode::FinalPartitioned
1699        | AggregateMode::Single
1700        | AggregateMode::SinglePartitioned => {
1701            // Merge the state to the final value
1702            accumulators
1703                .iter_mut()
1704                .map(|accumulator| accumulator.evaluate().and_then(|v| v.to_array()))
1705                .collect()
1706        }
1707    }
1708}
1709
1710/// Evaluates groups of expressions against a record batch.
1711pub fn evaluate_many(
1712    expr: &[Vec<Arc<dyn PhysicalExpr>>],
1713    batch: &RecordBatch,
1714) -> Result<Vec<Vec<ArrayRef>>> {
1715    expr.iter()
1716        .map(|expr| evaluate_expressions_to_arrays(expr, batch))
1717        .collect()
1718}
1719
1720fn evaluate_optional(
1721    expr: &[Option<Arc<dyn PhysicalExpr>>],
1722    batch: &RecordBatch,
1723) -> Result<Vec<Option<ArrayRef>>> {
1724    expr.iter()
1725        .map(|expr| {
1726            expr.as_ref()
1727                .map(|expr| {
1728                    expr.evaluate(batch)
1729                        .and_then(|v| v.into_array(batch.num_rows()))
1730                })
1731                .transpose()
1732        })
1733        .collect()
1734}
1735
1736fn group_id_array(group: &[bool], batch: &RecordBatch) -> Result<ArrayRef> {
1737    if group.len() > 64 {
1738        return not_impl_err!(
1739            "Grouping sets with more than 64 columns are not supported"
1740        );
1741    }
1742    let group_id = group.iter().fold(0u64, |acc, &is_null| {
1743        (acc << 1) | if is_null { 1 } else { 0 }
1744    });
1745    let num_rows = batch.num_rows();
1746    if group.len() <= 8 {
1747        Ok(Arc::new(UInt8Array::from(vec![group_id as u8; num_rows])))
1748    } else if group.len() <= 16 {
1749        Ok(Arc::new(UInt16Array::from(vec![group_id as u16; num_rows])))
1750    } else if group.len() <= 32 {
1751        Ok(Arc::new(UInt32Array::from(vec![group_id as u32; num_rows])))
1752    } else {
1753        Ok(Arc::new(UInt64Array::from(vec![group_id; num_rows])))
1754    }
1755}
1756
1757/// Evaluate a group by expression against a `RecordBatch`
1758///
1759/// Arguments:
1760/// - `group_by`: the expression to evaluate
1761/// - `batch`: the `RecordBatch` to evaluate against
1762///
1763/// Returns: A Vec of Vecs of Array of results
1764/// The outer Vec appears to be for grouping sets
1765/// The inner Vec contains the results per expression
1766/// The inner-inner Array contains the results per row
1767pub fn evaluate_group_by(
1768    group_by: &PhysicalGroupBy,
1769    batch: &RecordBatch,
1770) -> Result<Vec<Vec<ArrayRef>>> {
1771    let exprs = evaluate_expressions_to_arrays(
1772        group_by.expr.iter().map(|(expr, _)| expr),
1773        batch,
1774    )?;
1775    let null_exprs = evaluate_expressions_to_arrays(
1776        group_by.null_expr.iter().map(|(expr, _)| expr),
1777        batch,
1778    )?;
1779
1780    group_by
1781        .groups
1782        .iter()
1783        .map(|group| {
1784            let mut group_values = Vec::with_capacity(group_by.num_group_exprs());
1785            group_values.extend(group.iter().enumerate().map(|(idx, is_null)| {
1786                if *is_null {
1787                    Arc::clone(&null_exprs[idx])
1788                } else {
1789                    Arc::clone(&exprs[idx])
1790                }
1791            }));
1792            if !group_by.is_single() {
1793                group_values.push(group_id_array(group, batch)?);
1794            }
1795            Ok(group_values)
1796        })
1797        .collect()
1798}
1799
1800#[cfg(test)]
1801mod tests {
1802    use std::task::{Context, Poll};
1803
1804    use super::*;
1805    use crate::RecordBatchStream;
1806    use crate::coalesce_batches::CoalesceBatchesExec;
1807    use crate::coalesce_partitions::CoalescePartitionsExec;
1808    use crate::common;
1809    use crate::common::collect;
1810    use crate::execution_plan::Boundedness;
1811    use crate::expressions::col;
1812    use crate::metrics::MetricValue;
1813    use crate::test::TestMemoryExec;
1814    use crate::test::assert_is_pending;
1815    use crate::test::exec::{BlockingExec, assert_strong_count_converges_to_zero};
1816
1817    use arrow::array::{
1818        DictionaryArray, Float32Array, Float64Array, Int32Array, Int64Array, StructArray,
1819        UInt32Array, UInt64Array,
1820    };
1821    use arrow::compute::{SortOptions, concat_batches};
1822    use arrow::datatypes::{DataType, Int32Type};
1823    use datafusion_common::test_util::{batches_to_sort_string, batches_to_string};
1824    use datafusion_common::{DataFusionError, ScalarValue, internal_err};
1825    use datafusion_execution::config::SessionConfig;
1826    use datafusion_execution::memory_pool::FairSpillPool;
1827    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
1828    use datafusion_functions_aggregate::array_agg::array_agg_udaf;
1829    use datafusion_functions_aggregate::average::avg_udaf;
1830    use datafusion_functions_aggregate::count::count_udaf;
1831    use datafusion_functions_aggregate::first_last::{first_value_udaf, last_value_udaf};
1832    use datafusion_functions_aggregate::median::median_udaf;
1833    use datafusion_functions_aggregate::sum::sum_udaf;
1834    use datafusion_physical_expr::Partitioning;
1835    use datafusion_physical_expr::PhysicalSortExpr;
1836    use datafusion_physical_expr::aggregate::AggregateExprBuilder;
1837    use datafusion_physical_expr::expressions::Literal;
1838    use datafusion_physical_expr::expressions::lit;
1839
1840    use crate::projection::ProjectionExec;
1841    use datafusion_physical_expr::projection::ProjectionExpr;
1842    use futures::{FutureExt, Stream};
1843    use insta::{allow_duplicates, assert_snapshot};
1844
1845    // Generate a schema which consists of 5 columns (a, b, c, d, e)
1846    fn create_test_schema() -> Result<SchemaRef> {
1847        let a = Field::new("a", DataType::Int32, true);
1848        let b = Field::new("b", DataType::Int32, true);
1849        let c = Field::new("c", DataType::Int32, true);
1850        let d = Field::new("d", DataType::Int32, true);
1851        let e = Field::new("e", DataType::Int32, true);
1852        let schema = Arc::new(Schema::new(vec![a, b, c, d, e]));
1853
1854        Ok(schema)
1855    }
1856
1857    /// some mock data to aggregates
1858    fn some_data() -> (Arc<Schema>, Vec<RecordBatch>) {
1859        // define a schema.
1860        let schema = Arc::new(Schema::new(vec![
1861            Field::new("a", DataType::UInt32, false),
1862            Field::new("b", DataType::Float64, false),
1863        ]));
1864
1865        // define data.
1866        (
1867            Arc::clone(&schema),
1868            vec![
1869                RecordBatch::try_new(
1870                    Arc::clone(&schema),
1871                    vec![
1872                        Arc::new(UInt32Array::from(vec![2, 3, 4, 4])),
1873                        Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
1874                    ],
1875                )
1876                .unwrap(),
1877                RecordBatch::try_new(
1878                    schema,
1879                    vec![
1880                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
1881                        Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
1882                    ],
1883                )
1884                .unwrap(),
1885            ],
1886        )
1887    }
1888
1889    /// Generates some mock data for aggregate tests.
1890    fn some_data_v2() -> (Arc<Schema>, Vec<RecordBatch>) {
1891        // Define a schema:
1892        let schema = Arc::new(Schema::new(vec![
1893            Field::new("a", DataType::UInt32, false),
1894            Field::new("b", DataType::Float64, false),
1895        ]));
1896
1897        // Generate data so that first and last value results are at 2nd and
1898        // 3rd partitions.  With this construction, we guarantee we don't receive
1899        // the expected result by accident, but merging actually works properly;
1900        // i.e. it doesn't depend on the data insertion order.
1901        (
1902            Arc::clone(&schema),
1903            vec![
1904                RecordBatch::try_new(
1905                    Arc::clone(&schema),
1906                    vec![
1907                        Arc::new(UInt32Array::from(vec![2, 3, 4, 4])),
1908                        Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
1909                    ],
1910                )
1911                .unwrap(),
1912                RecordBatch::try_new(
1913                    Arc::clone(&schema),
1914                    vec![
1915                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
1916                        Arc::new(Float64Array::from(vec![0.0, 1.0, 2.0, 3.0])),
1917                    ],
1918                )
1919                .unwrap(),
1920                RecordBatch::try_new(
1921                    Arc::clone(&schema),
1922                    vec![
1923                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
1924                        Arc::new(Float64Array::from(vec![3.0, 4.0, 5.0, 6.0])),
1925                    ],
1926                )
1927                .unwrap(),
1928                RecordBatch::try_new(
1929                    schema,
1930                    vec![
1931                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
1932                        Arc::new(Float64Array::from(vec![2.0, 3.0, 4.0, 5.0])),
1933                    ],
1934                )
1935                .unwrap(),
1936            ],
1937        )
1938    }
1939
1940    fn new_spill_ctx(batch_size: usize, max_memory: usize) -> Arc<TaskContext> {
1941        let session_config = SessionConfig::new().with_batch_size(batch_size);
1942        let runtime = RuntimeEnvBuilder::new()
1943            .with_memory_pool(Arc::new(FairSpillPool::new(max_memory)))
1944            .build_arc()
1945            .unwrap();
1946        let task_ctx = TaskContext::default()
1947            .with_session_config(session_config)
1948            .with_runtime(runtime);
1949        Arc::new(task_ctx)
1950    }
1951
    /// Runs a grouping-sets aggregation (partial, then final) over `input` and
    /// verifies both the intermediate and merged outputs against snapshots.
    ///
    /// When `spill` is true, the aggregates run under a constrained memory
    /// pool so the early-emit/spill path is exercised; partial results may
    /// then contain duplicate, un-merged groups.
    async fn check_grouping_sets(
        input: Arc<dyn ExecutionPlan>,
        spill: bool,
    ) -> Result<()> {
        let input_schema = input.schema();

        // GROUPING SETS ((a), (b), (a, b)): the null expressions supply the
        // NULL placeholder used when a column is absent from a grouping set.
        let grouping_set = PhysicalGroupBy::new(
            vec![
                (col("a", &input_schema)?, "a".to_string()),
                (col("b", &input_schema)?, "b".to_string()),
            ],
            vec![
                (lit(ScalarValue::UInt32(None)), "a".to_string()),
                (lit(ScalarValue::Float64(None)), "b".to_string()),
            ],
            vec![
                vec![false, true],  // (a, NULL)
                vec![true, false],  // (NULL, b)
                vec![false, false], // (a,b)
            ],
            true,
        );

        // A single COUNT(1) aggregate over the grouping sets.
        let aggregates = vec![Arc::new(
            AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)])
                .schema(Arc::clone(&input_schema))
                .alias("COUNT(1)")
                .build()?,
        )];

        let task_ctx = if spill {
            // adjust the max memory size to have the partial aggregate result for spill mode.
            new_spill_ctx(4, 500)
        } else {
            Arc::new(TaskContext::default())
        };

        let partial_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            grouping_set.clone(),
            aggregates.clone(),
            vec![None],
            input,
            Arc::clone(&input_schema),
        )?);

        let result =
            collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;

        if spill {
            // In spill mode, we test with the limited memory, if the mem usage exceeds,
            // we trigger the early emit rule, which turns out the partial aggregate result.
            allow_duplicates! {
            assert_snapshot!(batches_to_sort_string(&result),
            @r"
            +---+-----+---------------+-----------------+
            | a | b   | __grouping_id | COUNT(1)[count] |
            +---+-----+---------------+-----------------+
            |   | 1.0 | 2             | 1               |
            |   | 1.0 | 2             | 1               |
            |   | 2.0 | 2             | 1               |
            |   | 2.0 | 2             | 1               |
            |   | 3.0 | 2             | 1               |
            |   | 3.0 | 2             | 1               |
            |   | 4.0 | 2             | 1               |
            |   | 4.0 | 2             | 1               |
            | 2 |     | 1             | 1               |
            | 2 |     | 1             | 1               |
            | 2 | 1.0 | 0             | 1               |
            | 2 | 1.0 | 0             | 1               |
            | 3 |     | 1             | 1               |
            | 3 |     | 1             | 2               |
            | 3 | 2.0 | 0             | 2               |
            | 3 | 3.0 | 0             | 1               |
            | 4 |     | 1             | 1               |
            | 4 |     | 1             | 2               |
            | 4 | 3.0 | 0             | 1               |
            | 4 | 4.0 | 0             | 2               |
            +---+-----+---------------+-----------------+
            "
            );
            }
        } else {
            allow_duplicates! {
            assert_snapshot!(batches_to_sort_string(&result),
            @r"
            +---+-----+---------------+-----------------+
            | a | b   | __grouping_id | COUNT(1)[count] |
            +---+-----+---------------+-----------------+
            |   | 1.0 | 2             | 2               |
            |   | 2.0 | 2             | 2               |
            |   | 3.0 | 2             | 2               |
            |   | 4.0 | 2             | 2               |
            | 2 |     | 1             | 2               |
            | 2 | 1.0 | 0             | 2               |
            | 3 |     | 1             | 3               |
            | 3 | 2.0 | 0             | 2               |
            | 3 | 3.0 | 0             | 1               |
            | 4 |     | 1             | 3               |
            | 4 | 3.0 | 0             | 1               |
            | 4 | 4.0 | 0             | 2               |
            +---+-----+---------------+-----------------+
            "
            );
            }
        };

        // Merge all partitions, then run the final aggregation over the
        // partial results.
        let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));

        let final_grouping_set = grouping_set.as_final();

        let task_ctx = if spill {
            new_spill_ctx(4, 3160)
        } else {
            task_ctx
        };

        let merged_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Final,
            final_grouping_set,
            aggregates,
            vec![None],
            merge,
            input_schema,
        )?);

        let result = collect(merged_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
        let batch = concat_batches(&result[0].schema(), &result)?;
        assert_eq!(batch.num_columns(), 4);
        assert_eq!(batch.num_rows(), 12);

        allow_duplicates! {
        assert_snapshot!(
            batches_to_sort_string(&result),
            @r"
        +---+-----+---------------+----------+
        | a | b   | __grouping_id | COUNT(1) |
        +---+-----+---------------+----------+
        |   | 1.0 | 2             | 2        |
        |   | 2.0 | 2             | 2        |
        |   | 3.0 | 2             | 2        |
        |   | 4.0 | 2             | 2        |
        | 2 |     | 1             | 2        |
        | 2 | 1.0 | 0             | 2        |
        | 3 |     | 1             | 3        |
        | 3 | 2.0 | 0             | 2        |
        | 3 | 3.0 | 0             | 1        |
        | 4 |     | 1             | 3        |
        | 4 | 3.0 | 0             | 1        |
        | 4 | 4.0 | 0             | 2        |
        +---+-----+---------------+----------+
        "
        );
        }

        // The merged aggregate's metrics should reflect the 12 output rows.
        let metrics = merged_aggregate.metrics().unwrap();
        let output_rows = metrics.output_rows().unwrap();
        assert_eq!(12, output_rows);

        Ok(())
    }
2113
    /// build the aggregates on the data from some_data() and check the results
    ///
    /// Runs AVG(b) GROUP BY a as a partial + final aggregation pair. When
    /// `spill` is true, memory limits are sized to force spilling in the
    /// partial phase, and the spill metrics are asserted afterwards.
    async fn check_aggregates(input: Arc<dyn ExecutionPlan>, spill: bool) -> Result<()> {
        let input_schema = input.schema();

        // Plain GROUP BY a (single grouping set, no NULL placeholders).
        let grouping_set = PhysicalGroupBy::new(
            vec![(col("a", &input_schema)?, "a".to_string())],
            vec![],
            vec![vec![false]],
            false,
        );

        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
            AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?])
                .schema(Arc::clone(&input_schema))
                .alias("AVG(b)")
                .build()?,
        )];

        let task_ctx = if spill {
            // set to an appropriate value to trigger spill
            new_spill_ctx(2, 1600)
        } else {
            Arc::new(TaskContext::default())
        };

        let partial_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            grouping_set.clone(),
            aggregates.clone(),
            vec![None],
            input,
            Arc::clone(&input_schema),
        )?);

        let result =
            collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;

        // Partial AVG state is (count, sum); spill mode may emit the same
        // group more than once before the final merge.
        if spill {
            allow_duplicates! {
            assert_snapshot!(batches_to_sort_string(&result), @r"
            +---+---------------+-------------+
            | a | AVG(b)[count] | AVG(b)[sum] |
            +---+---------------+-------------+
            | 2 | 1             | 1.0         |
            | 2 | 1             | 1.0         |
            | 3 | 1             | 2.0         |
            | 3 | 2             | 5.0         |
            | 4 | 3             | 11.0        |
            +---+---------------+-------------+
            ");
            }
        } else {
            allow_duplicates! {
            assert_snapshot!(batches_to_sort_string(&result), @r"
            +---+---------------+-------------+
            | a | AVG(b)[count] | AVG(b)[sum] |
            +---+---------------+-------------+
            | 2 | 2             | 2.0         |
            | 3 | 3             | 7.0         |
            | 4 | 3             | 11.0        |
            +---+---------------+-------------+
            ");
            }
        };

        let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));

        let final_grouping_set = grouping_set.as_final();

        let merged_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Final,
            final_grouping_set,
            aggregates,
            vec![None],
            merge,
            input_schema,
        )?);

        // Verify statistics are preserved proportionally through aggregation
        let final_stats = merged_aggregate.partition_statistics(None)?;
        assert!(final_stats.total_byte_size.get_value().is_some());

        let task_ctx = if spill {
            // enlarge memory limit to let the final aggregation finish
            new_spill_ctx(2, 2600)
        } else {
            Arc::clone(&task_ctx)
        };
        let result = collect(merged_aggregate.execute(0, task_ctx)?).await?;
        let batch = concat_batches(&result[0].schema(), &result)?;
        assert_eq!(batch.num_columns(), 2);
        assert_eq!(batch.num_rows(), 3);

        allow_duplicates! {
        assert_snapshot!(batches_to_sort_string(&result), @r"
        +---+--------------------+
        | a | AVG(b)             |
        +---+--------------------+
        | 2 | 1.0                |
        | 3 | 2.3333333333333335 |
        | 4 | 3.6666666666666665 |
        +---+--------------------+
        ");
            // For row 2: 3, (2 + 3 + 2) / 3
            // For row 3: 4, (3 + 4 + 4) / 3
        }

        let metrics = merged_aggregate.metrics().unwrap();
        let output_rows = metrics.output_rows().unwrap();
        let spill_count = metrics.spill_count().unwrap();
        let spilled_bytes = metrics.spilled_bytes().unwrap();
        let spilled_rows = metrics.spilled_rows().unwrap();

        if spill {
            // When spilling, the output rows metrics become partial output size + final output size
            // This is because final aggregation starts while partial aggregation is still emitting
            assert_eq!(8, output_rows);

            assert!(spill_count > 0);
            assert!(spilled_bytes > 0);
            assert!(spilled_rows > 0);
        } else {
            assert_eq!(3, output_rows);

            assert_eq!(0, spill_count);
            assert_eq!(0, spilled_bytes);
            assert_eq!(0, spilled_rows);
        }

        Ok(())
    }
2245
2246    /// Define a test source that can yield back to runtime before returning its first item ///
2247
    /// Test source over the `some_data()` batches that can optionally yield
    /// back to the runtime before returning its first item.
    #[derive(Debug)]
    struct TestYieldingExec {
        /// True if this exec should yield back to runtime the first time it is polled
        pub yield_first: bool,
        // Cached plan properties (schema, partitioning, emission, boundedness).
        cache: PlanProperties,
    }
2254
2255    impl TestYieldingExec {
2256        fn new(yield_first: bool) -> Self {
2257            let schema = some_data().0;
2258            let cache = Self::compute_properties(schema);
2259            Self { yield_first, cache }
2260        }
2261
2262        /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
2263        fn compute_properties(schema: SchemaRef) -> PlanProperties {
2264            PlanProperties::new(
2265                EquivalenceProperties::new(schema),
2266                Partitioning::UnknownPartitioning(1),
2267                EmissionType::Incremental,
2268                Boundedness::Bounded,
2269            )
2270        }
2271    }
2272
2273    impl DisplayAs for TestYieldingExec {
2274        fn fmt_as(
2275            &self,
2276            t: DisplayFormatType,
2277            f: &mut std::fmt::Formatter,
2278        ) -> std::fmt::Result {
2279            match t {
2280                DisplayFormatType::Default | DisplayFormatType::Verbose => {
2281                    write!(f, "TestYieldingExec")
2282                }
2283                DisplayFormatType::TreeRender => {
2284                    // TODO: collect info
2285                    write!(f, "")
2286                }
2287            }
2288        }
2289    }
2290
2291    impl ExecutionPlan for TestYieldingExec {
2292        fn name(&self) -> &'static str {
2293            "TestYieldingExec"
2294        }
2295
2296        fn as_any(&self) -> &dyn Any {
2297            self
2298        }
2299
2300        fn properties(&self) -> &PlanProperties {
2301            &self.cache
2302        }
2303
2304        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
2305            vec![]
2306        }
2307
2308        fn with_new_children(
2309            self: Arc<Self>,
2310            _: Vec<Arc<dyn ExecutionPlan>>,
2311        ) -> Result<Arc<dyn ExecutionPlan>> {
2312            internal_err!("Children cannot be replaced in {self:?}")
2313        }
2314
2315        fn execute(
2316            &self,
2317            _partition: usize,
2318            _context: Arc<TaskContext>,
2319        ) -> Result<SendableRecordBatchStream> {
2320            let stream = if self.yield_first {
2321                TestYieldingStream::New
2322            } else {
2323                TestYieldingStream::Yielded
2324            };
2325
2326            Ok(Box::pin(stream))
2327        }
2328
2329        fn statistics(&self) -> Result<Statistics> {
2330            self.partition_statistics(None)
2331        }
2332
2333        fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
2334            if partition.is_some() {
2335                return Ok(Statistics::new_unknown(self.schema().as_ref()));
2336            }
2337            let (_, batches) = some_data();
2338            Ok(common::compute_record_batch_statistics(
2339                &[batches],
2340                &self.schema(),
2341                None,
2342            ))
2343        }
2344    }
2345
    /// A stream using the demo data. If inited as new, it will first yield to runtime before returning records
    enum TestYieldingStream {
        /// Initial state: the next poll yields back to the runtime once
        New,
        /// Has yielded (or was started past that point); next poll emits batch 1
        Yielded,
        /// Batch 1 emitted; next poll emits batch 2
        ReturnedBatch1,
        /// Batch 2 emitted; the stream is exhausted
        ReturnedBatch2,
    }
2353
2354    impl Stream for TestYieldingStream {
2355        type Item = Result<RecordBatch>;
2356
2357        fn poll_next(
2358            mut self: std::pin::Pin<&mut Self>,
2359            cx: &mut Context<'_>,
2360        ) -> Poll<Option<Self::Item>> {
2361            match &*self {
2362                TestYieldingStream::New => {
2363                    *(self.as_mut()) = TestYieldingStream::Yielded;
2364                    cx.waker().wake_by_ref();
2365                    Poll::Pending
2366                }
2367                TestYieldingStream::Yielded => {
2368                    *(self.as_mut()) = TestYieldingStream::ReturnedBatch1;
2369                    Poll::Ready(Some(Ok(some_data().1[0].clone())))
2370                }
2371                TestYieldingStream::ReturnedBatch1 => {
2372                    *(self.as_mut()) = TestYieldingStream::ReturnedBatch2;
2373                    Poll::Ready(Some(Ok(some_data().1[1].clone())))
2374                }
2375                TestYieldingStream::ReturnedBatch2 => Poll::Ready(None),
2376            }
2377        }
2378    }
2379
2380    impl RecordBatchStream for TestYieldingStream {
2381        fn schema(&self) -> SchemaRef {
2382            some_data().0
2383        }
2384    }
2385
2386    //--- Tests ---//
2387
2388    #[tokio::test]
2389    async fn aggregate_source_not_yielding() -> Result<()> {
2390        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
2391
2392        check_aggregates(input, false).await
2393    }
2394
2395    #[tokio::test]
2396    async fn aggregate_grouping_sets_source_not_yielding() -> Result<()> {
2397        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
2398
2399        check_grouping_sets(input, false).await
2400    }
2401
2402    #[tokio::test]
2403    async fn aggregate_source_with_yielding() -> Result<()> {
2404        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
2405
2406        check_aggregates(input, false).await
2407    }
2408
2409    #[tokio::test]
2410    async fn aggregate_grouping_sets_with_yielding() -> Result<()> {
2411        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
2412
2413        check_grouping_sets(input, false).await
2414    }
2415
2416    #[tokio::test]
2417    async fn aggregate_source_not_yielding_with_spill() -> Result<()> {
2418        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
2419
2420        check_aggregates(input, true).await
2421    }
2422
2423    #[tokio::test]
2424    async fn aggregate_grouping_sets_source_not_yielding_with_spill() -> Result<()> {
2425        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
2426
2427        check_grouping_sets(input, true).await
2428    }
2429
2430    #[tokio::test]
2431    async fn aggregate_source_with_yielding_with_spill() -> Result<()> {
2432        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
2433
2434        check_aggregates(input, true).await
2435    }
2436
2437    #[tokio::test]
2438    async fn aggregate_grouping_sets_with_yielding_with_spill() -> Result<()> {
2439        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
2440
2441        check_grouping_sets(input, true).await
2442    }
2443
2444    // Median(a)
2445    fn test_median_agg_expr(schema: SchemaRef) -> Result<AggregateFunctionExpr> {
2446        AggregateExprBuilder::new(median_udaf(), vec![col("a", &schema)?])
2447            .schema(schema)
2448            .alias("MEDIAN(a)")
2449            .build()
2450    }
2451
    #[tokio::test]
    async fn test_oom() -> Result<()> {
        // Verify that both aggregation stream implementations surface a
        // `ResourcesExhausted` error (rather than aborting) when the memory
        // budget (a single byte here) is exceeded.
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
        let input_schema = input.schema();

        let runtime = RuntimeEnvBuilder::new()
            .with_memory_limit(1, 1.0)
            .build_arc()?;
        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        // No grouping -> exercises `AggregateStream` (no_grouping.rs).
        let groups_none = PhysicalGroupBy::default();
        // Group by "a" -> exercises `GroupedHashAggregateStream` (row_hash.rs).
        let groups_some = PhysicalGroupBy::new(
            vec![(col("a", &input_schema)?, "a".to_string())],
            vec![],
            vec![vec![false]],
            false,
        );

        // something that allocates within the aggregator
        let aggregates_v0: Vec<Arc<AggregateFunctionExpr>> =
            vec![Arc::new(test_median_agg_expr(Arc::clone(&input_schema))?)];

        // use fast-path in `row_hash.rs`.
        let aggregates_v2: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
            AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?])
                .schema(Arc::clone(&input_schema))
                .alias("AVG(b)")
                .build()?,
        )];

        for (version, groups, aggregates) in [
            (0, groups_none, aggregates_v0),
            (2, groups_some, aggregates_v2),
        ] {
            let n_aggr = aggregates.len();
            let partial_aggregate = Arc::new(AggregateExec::try_new(
                AggregateMode::Single,
                groups,
                aggregates,
                vec![None; n_aggr],
                Arc::clone(&input),
                Arc::clone(&input_schema),
            )?);

            let stream = partial_aggregate.execute_typed(0, &task_ctx)?;

            // ensure that we really got the version we wanted
            match version {
                0 => {
                    assert!(matches!(stream, StreamType::AggregateStream(_)));
                }
                // NOTE(review): version 1 is never produced by the loop above,
                // so this arm is currently unreachable — confirm whether it is
                // intentionally kept for a removed/future test case.
                1 => {
                    assert!(matches!(stream, StreamType::GroupedHash(_)));
                }
                2 => {
                    assert!(matches!(stream, StreamType::GroupedHash(_)));
                }
                _ => panic!("Unknown version: {version}"),
            }

            let stream: SendableRecordBatchStream = stream.into();
            // Collecting must fail: the 1-byte budget cannot hold the
            // aggregation state.
            let err = collect(stream).await.unwrap_err();

            // error root cause traversal is a bit complicated, see #4172.
            let err = err.find_root();
            assert!(
                matches!(err, DataFusionError::ResourcesExhausted(_)),
                "Wrong error type: {err}",
            );
        }

        Ok(())
    }
2526
2527    #[tokio::test]
2528    async fn test_drop_cancel_without_groups() -> Result<()> {
2529        let task_ctx = Arc::new(TaskContext::default());
2530        let schema =
2531            Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, true)]));
2532
2533        let groups = PhysicalGroupBy::default();
2534
2535        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
2536            AggregateExprBuilder::new(avg_udaf(), vec![col("a", &schema)?])
2537                .schema(Arc::clone(&schema))
2538                .alias("AVG(a)")
2539                .build()?,
2540        )];
2541
2542        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
2543        let refs = blocking_exec.refs();
2544        let aggregate_exec = Arc::new(AggregateExec::try_new(
2545            AggregateMode::Partial,
2546            groups.clone(),
2547            aggregates.clone(),
2548            vec![None],
2549            blocking_exec,
2550            schema,
2551        )?);
2552
2553        let fut = crate::collect(aggregate_exec, task_ctx);
2554        let mut fut = fut.boxed();
2555
2556        assert_is_pending(&mut fut);
2557        drop(fut);
2558        assert_strong_count_converges_to_zero(refs).await;
2559
2560        Ok(())
2561    }
2562
2563    #[tokio::test]
2564    async fn test_drop_cancel_with_groups() -> Result<()> {
2565        let task_ctx = Arc::new(TaskContext::default());
2566        let schema = Arc::new(Schema::new(vec![
2567            Field::new("a", DataType::Float64, true),
2568            Field::new("b", DataType::Float64, true),
2569        ]));
2570
2571        let groups =
2572            PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);
2573
2574        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
2575            AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
2576                .schema(Arc::clone(&schema))
2577                .alias("AVG(b)")
2578                .build()?,
2579        )];
2580
2581        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
2582        let refs = blocking_exec.refs();
2583        let aggregate_exec = Arc::new(AggregateExec::try_new(
2584            AggregateMode::Partial,
2585            groups,
2586            aggregates.clone(),
2587            vec![None],
2588            blocking_exec,
2589            schema,
2590        )?);
2591
2592        let fut = crate::collect(aggregate_exec, task_ctx);
2593        let mut fut = fut.boxed();
2594
2595        assert_is_pending(&mut fut);
2596        drop(fut);
2597        assert_strong_count_converges_to_zero(refs).await;
2598
2599        Ok(())
2600    }
2601
2602    #[tokio::test]
2603    async fn run_first_last_multi_partitions() -> Result<()> {
2604        for use_coalesce_batches in [false, true] {
2605            for is_first_acc in [false, true] {
2606                for spill in [false, true] {
2607                    first_last_multi_partitions(
2608                        use_coalesce_batches,
2609                        is_first_acc,
2610                        spill,
2611                        4200,
2612                    )
2613                    .await?
2614                }
2615            }
2616        }
2617        Ok(())
2618    }
2619
2620    // FIRST_VALUE(b ORDER BY b <SortOptions>)
2621    fn test_first_value_agg_expr(
2622        schema: &Schema,
2623        sort_options: SortOptions,
2624    ) -> Result<Arc<AggregateFunctionExpr>> {
2625        let order_bys = vec![PhysicalSortExpr {
2626            expr: col("b", schema)?,
2627            options: sort_options,
2628        }];
2629        let args = [col("b", schema)?];
2630
2631        AggregateExprBuilder::new(first_value_udaf(), args.to_vec())
2632            .order_by(order_bys)
2633            .schema(Arc::new(schema.clone()))
2634            .alias(String::from("first_value(b) ORDER BY [b ASC NULLS LAST]"))
2635            .build()
2636            .map(Arc::new)
2637    }
2638
2639    // LAST_VALUE(b ORDER BY b <SortOptions>)
2640    fn test_last_value_agg_expr(
2641        schema: &Schema,
2642        sort_options: SortOptions,
2643    ) -> Result<Arc<AggregateFunctionExpr>> {
2644        let order_bys = vec![PhysicalSortExpr {
2645            expr: col("b", schema)?,
2646            options: sort_options,
2647        }];
2648        let args = [col("b", schema)?];
2649        AggregateExprBuilder::new(last_value_udaf(), args.to_vec())
2650            .order_by(order_bys)
2651            .schema(Arc::new(schema.clone()))
2652            .alias(String::from("last_value(b) ORDER BY [b ASC NULLS LAST]"))
2653            .build()
2654            .map(Arc::new)
2655    }
2656
    // This function either constructs the physical plan below,
    //
    // "AggregateExec: mode=Final, gby=[a@0 as a], aggr=[FIRST_VALUE(b)]",
    // "  CoalesceBatchesExec: target_batch_size=1024",
    // "    CoalescePartitionsExec",
    // "      AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[FIRST_VALUE(b)], ordering_mode=None",
    // "        DataSourceExec: partitions=4, partition_sizes=[1, 1, 1, 1]",
    //
    // or
    //
    // "AggregateExec: mode=Final, gby=[a@0 as a], aggr=[FIRST_VALUE(b)]",
    // "  CoalescePartitionsExec",
    // "    AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[FIRST_VALUE(b)], ordering_mode=None",
    // "      DataSourceExec: partitions=4, partition_sizes=[1, 1, 1, 1]",
    //
    // and checks whether the function `merge_batch` works correctly for
    // FIRST_VALUE and LAST_VALUE functions.
    //
    // Parameters:
    // - `use_coalesce_batches`: insert a `CoalesceBatchesExec` above the
    //   `CoalescePartitionsExec` (first plan shape above)
    // - `is_first_acc`: test FIRST_VALUE when true, LAST_VALUE otherwise
    // - `spill`: run with a memory-limited context that forces spilling
    // - `max_memory`: memory budget (bytes) used when `spill` is true
    async fn first_last_multi_partitions(
        use_coalesce_batches: bool,
        is_first_acc: bool,
        spill: bool,
        max_memory: usize,
    ) -> Result<()> {
        let task_ctx = if spill {
            new_spill_ctx(2, max_memory)
        } else {
            Arc::new(TaskContext::default())
        };

        // Four single-batch input partitions of the demo data.
        let (schema, data) = some_data_v2();
        let partition1 = data[0].clone();
        let partition2 = data[1].clone();
        let partition3 = data[2].clone();
        let partition4 = data[3].clone();

        let groups =
            PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);

        let sort_options = SortOptions {
            descending: false,
            nulls_first: false,
        };
        let aggregates: Vec<Arc<AggregateFunctionExpr>> = if is_first_acc {
            vec![test_first_value_agg_expr(&schema, sort_options)?]
        } else {
            vec![test_last_value_agg_expr(&schema, sort_options)?]
        };

        let memory_exec = TestMemoryExec::try_new_exec(
            &[
                vec![partition1],
                vec![partition2],
                vec![partition3],
                vec![partition4],
            ],
            Arc::clone(&schema),
            None,
        )?;
        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            groups.clone(),
            aggregates.clone(),
            vec![None],
            memory_exec,
            Arc::clone(&schema),
        )?);
        // Merge the partial results from all partitions into one stream,
        // optionally re-batching through `CoalesceBatchesExec`.
        let coalesce = if use_coalesce_batches {
            let coalesce = Arc::new(CoalescePartitionsExec::new(aggregate_exec));
            Arc::new(CoalesceBatchesExec::new(coalesce, 1024)) as Arc<dyn ExecutionPlan>
        } else {
            Arc::new(CoalescePartitionsExec::new(aggregate_exec))
                as Arc<dyn ExecutionPlan>
        };
        // Final aggregation merges the partial states via `merge_batch`.
        let aggregate_final = Arc::new(AggregateExec::try_new(
            AggregateMode::Final,
            groups,
            aggregates.clone(),
            vec![None],
            coalesce,
            schema,
        )?) as Arc<dyn ExecutionPlan>;

        let result = crate::collect(aggregate_final, task_ctx).await?;
        if is_first_acc {
            allow_duplicates! {
            assert_snapshot!(batches_to_string(&result), @r"
            +---+--------------------------------------------+
            | a | first_value(b) ORDER BY [b ASC NULLS LAST] |
            +---+--------------------------------------------+
            | 2 | 0.0                                        |
            | 3 | 1.0                                        |
            | 4 | 3.0                                        |
            +---+--------------------------------------------+
            ");
            }
        } else {
            allow_duplicates! {
            assert_snapshot!(batches_to_string(&result), @r"
            +---+-------------------------------------------+
            | a | last_value(b) ORDER BY [b ASC NULLS LAST] |
            +---+-------------------------------------------+
            | 2 | 3.0                                       |
            | 3 | 5.0                                       |
            | 4 | 6.0                                       |
            +---+-------------------------------------------+
            ");
            }
        };
        Ok(())
    }
2767
2768    #[tokio::test]
2769    async fn test_get_finest_requirements() -> Result<()> {
2770        let test_schema = create_test_schema()?;
2771
2772        let options = SortOptions {
2773            descending: false,
2774            nulls_first: false,
2775        };
2776        let col_a = &col("a", &test_schema)?;
2777        let col_b = &col("b", &test_schema)?;
2778        let col_c = &col("c", &test_schema)?;
2779        let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema));
2780        // Columns a and b are equal.
2781        eq_properties.add_equal_conditions(Arc::clone(col_a), Arc::clone(col_b))?;
2782        // Aggregate requirements are
2783        // [None], [a ASC], [a ASC, b ASC, c ASC], [a ASC, b ASC] respectively
2784        let order_by_exprs = vec![
2785            vec![],
2786            vec![PhysicalSortExpr {
2787                expr: Arc::clone(col_a),
2788                options,
2789            }],
2790            vec![
2791                PhysicalSortExpr {
2792                    expr: Arc::clone(col_a),
2793                    options,
2794                },
2795                PhysicalSortExpr {
2796                    expr: Arc::clone(col_b),
2797                    options,
2798                },
2799                PhysicalSortExpr {
2800                    expr: Arc::clone(col_c),
2801                    options,
2802                },
2803            ],
2804            vec![
2805                PhysicalSortExpr {
2806                    expr: Arc::clone(col_a),
2807                    options,
2808                },
2809                PhysicalSortExpr {
2810                    expr: Arc::clone(col_b),
2811                    options,
2812                },
2813            ],
2814        ];
2815
2816        let common_requirement = vec![
2817            PhysicalSortRequirement::new(Arc::clone(col_a), Some(options)),
2818            PhysicalSortRequirement::new(Arc::clone(col_c), Some(options)),
2819        ];
2820        let mut aggr_exprs = order_by_exprs
2821            .into_iter()
2822            .map(|order_by_expr| {
2823                AggregateExprBuilder::new(array_agg_udaf(), vec![Arc::clone(col_a)])
2824                    .alias("a")
2825                    .order_by(order_by_expr)
2826                    .schema(Arc::clone(&test_schema))
2827                    .build()
2828                    .map(Arc::new)
2829                    .unwrap()
2830            })
2831            .collect::<Vec<_>>();
2832        let group_by = PhysicalGroupBy::new_single(vec![]);
2833        let result = get_finer_aggregate_exprs_requirement(
2834            &mut aggr_exprs,
2835            &group_by,
2836            &eq_properties,
2837            &AggregateMode::Partial,
2838        )?;
2839        assert_eq!(result, common_requirement);
2840        Ok(())
2841    }
2842
2843    #[test]
2844    fn test_agg_exec_same_schema() -> Result<()> {
2845        let schema = Arc::new(Schema::new(vec![
2846            Field::new("a", DataType::Float32, true),
2847            Field::new("b", DataType::Float32, true),
2848        ]));
2849
2850        let col_a = col("a", &schema)?;
2851        let option_desc = SortOptions {
2852            descending: true,
2853            nulls_first: true,
2854        };
2855        let groups = PhysicalGroupBy::new_single(vec![(col_a, "a".to_string())]);
2856
2857        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
2858            test_first_value_agg_expr(&schema, option_desc)?,
2859            test_last_value_agg_expr(&schema, option_desc)?,
2860        ];
2861        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
2862        let aggregate_exec = Arc::new(AggregateExec::try_new(
2863            AggregateMode::Partial,
2864            groups,
2865            aggregates,
2866            vec![None, None],
2867            Arc::clone(&blocking_exec) as Arc<dyn ExecutionPlan>,
2868            schema,
2869        )?);
2870        let new_agg =
2871            Arc::clone(&aggregate_exec).with_new_children(vec![blocking_exec])?;
2872        assert_eq!(new_agg.schema(), aggregate_exec.schema());
2873        Ok(())
2874    }
2875
    #[tokio::test]
    async fn test_agg_exec_group_by_const() -> Result<()> {
        // Grouping-sets aggregation where one GROUP BY expression is a literal
        // constant; checks group columns, `__grouping_id` and counts.
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Float32, true),
            Field::new("b", DataType::Float32, true),
            Field::new("const", DataType::Int32, false),
        ]));

        let col_a = col("a", &schema)?;
        let col_b = col("b", &schema)?;
        let const_expr = Arc::new(Literal::new(ScalarValue::Int32(Some(1))));

        let groups = PhysicalGroupBy::new(
            // GROUP BY expressions (the third is the constant literal 1)
            vec![
                (col_a, "a".to_string()),
                (col_b, "b".to_string()),
                (const_expr, "const".to_string()),
            ],
            // Null placeholders emitted when a column is absent from a set
            vec![
                (
                    Arc::new(Literal::new(ScalarValue::Float32(None))),
                    "a".to_string(),
                ),
                (
                    Arc::new(Literal::new(ScalarValue::Float32(None))),
                    "b".to_string(),
                ),
                (
                    Arc::new(Literal::new(ScalarValue::Int32(None))),
                    "const".to_string(),
                ),
            ],
            // Grouping sets: each mask marks which columns are null'd out,
            // i.e. GROUPING SETS ((a), (b), (const))
            vec![
                vec![false, true, true],
                vec![true, false, true],
                vec![true, true, false],
            ],
            true,
        );

        // COUNT(1) aliased as "1"
        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
            AggregateExprBuilder::new(count_udaf(), vec![lit(1)])
                .schema(Arc::clone(&schema))
                .alias("1")
                .build()
                .map(Arc::new)?,
        ];

        // 4 batches x 8192 identical rows = 32768 rows per group.
        let input_batches = (0..4)
            .map(|_| {
                let a = Arc::new(Float32Array::from(vec![0.; 8192]));
                let b = Arc::new(Float32Array::from(vec![0.; 8192]));
                let c = Arc::new(Int32Array::from(vec![1; 8192]));

                RecordBatch::try_new(Arc::clone(&schema), vec![a, b, c]).unwrap()
            })
            .collect();

        let input =
            TestMemoryExec::try_new_exec(&[input_batches], Arc::clone(&schema), None)?;

        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::Single,
            groups,
            aggregates.clone(),
            vec![None],
            input,
            schema,
        )?);

        let output =
            collect(aggregate_exec.execute(0, Arc::new(TaskContext::default()))?).await?;

        allow_duplicates! {
        assert_snapshot!(batches_to_sort_string(&output), @r"
        +-----+-----+-------+---------------+-------+
        | a   | b   | const | __grouping_id | 1     |
        +-----+-----+-------+---------------+-------+
        |     |     | 1     | 6             | 32768 |
        |     | 0.0 |       | 5             | 32768 |
        | 0.0 |     |       | 3             | 32768 |
        +-----+-----+-------+---------------+-------+
        ");
        }

        Ok(())
    }
2963
    #[tokio::test]
    async fn test_agg_exec_struct_of_dicts() -> Result<()> {
        // Group on a struct column whose fields are dictionary-encoded strings;
        // rows 0 and 2 share {a: "a", b: "b"} and should be merged.
        let batch = RecordBatch::try_new(
            Arc::new(Schema::new(vec![
                Field::new(
                    "labels".to_string(),
                    DataType::Struct(
                        vec![
                            Field::new(
                                "a".to_string(),
                                DataType::Dictionary(
                                    Box::new(DataType::Int32),
                                    Box::new(DataType::Utf8),
                                ),
                                true,
                            ),
                            Field::new(
                                "b".to_string(),
                                DataType::Dictionary(
                                    Box::new(DataType::Int32),
                                    Box::new(DataType::Utf8),
                                ),
                                true,
                            ),
                        ]
                        .into(),
                    ),
                    false,
                ),
                Field::new("value", DataType::UInt64, false),
            ])),
            vec![
                // labels: struct<a: dict, b: dict> with three rows
                Arc::new(StructArray::from(vec![
                    (
                        Arc::new(Field::new(
                            "a".to_string(),
                            DataType::Dictionary(
                                Box::new(DataType::Int32),
                                Box::new(DataType::Utf8),
                            ),
                            true,
                        )),
                        Arc::new(
                            vec![Some("a"), None, Some("a")]
                                .into_iter()
                                .collect::<DictionaryArray<Int32Type>>(),
                        ) as ArrayRef,
                    ),
                    (
                        Arc::new(Field::new(
                            "b".to_string(),
                            DataType::Dictionary(
                                Box::new(DataType::Int32),
                                Box::new(DataType::Utf8),
                            ),
                            true,
                        )),
                        Arc::new(
                            vec![Some("b"), Some("c"), Some("b")]
                                .into_iter()
                                .collect::<DictionaryArray<Int32Type>>(),
                        ) as ArrayRef,
                    ),
                ])),
                // value: 1 for every row, so SUM == group row count
                Arc::new(UInt64Array::from(vec![1, 1, 1])),
            ],
        )
        .expect("Failed to create RecordBatch");

        let group_by = PhysicalGroupBy::new_single(vec![(
            col("labels", &batch.schema())?,
            "labels".to_string(),
        )]);

        // SUM(value) GROUP BY labels
        let aggr_expr = vec![
            AggregateExprBuilder::new(sum_udaf(), vec![col("value", &batch.schema())?])
                .schema(Arc::clone(&batch.schema()))
                .alias(String::from("SUM(value)"))
                .build()
                .map(Arc::new)?,
        ];

        let input = TestMemoryExec::try_new_exec(
            &[vec![batch.clone()]],
            Arc::<Schema>::clone(&batch.schema()),
            None,
        )?;
        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::FinalPartitioned,
            group_by,
            aggr_expr,
            vec![None],
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            batch.schema(),
        )?);

        let session_config = SessionConfig::default();
        let ctx = TaskContext::default().with_session_config(session_config);
        let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;

        allow_duplicates! {
        assert_snapshot!(batches_to_string(&output), @r"
        +--------------+------------+
        | labels       | SUM(value) |
        +--------------+------------+
        | {a: a, b: b} | 2          |
        | {a: , b: c}  | 1          |
        +--------------+------------+
        ");
        }

        Ok(())
    }
3077
    #[tokio::test]
    async fn test_skip_aggregation_after_first_batch() -> Result<()> {
        // With very low probe thresholds, partial aggregation should give up
        // after the first batch and pass later rows through un-aggregated:
        // keys 2 and 3 appear twice in the output instead of being merged.
        let schema = Arc::new(Schema::new(vec![
            Field::new("key", DataType::Int32, true),
            Field::new("val", DataType::Int32, true),
        ]));

        let group_by =
            PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]);

        // COUNT(val) GROUP BY key
        let aggr_expr = vec![
            AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?])
                .schema(Arc::clone(&schema))
                .alias(String::from("COUNT(val)"))
                .build()
                .map(Arc::new)?,
        ];

        // Two 3-row batches; keys 2 and 3 occur in both batches.
        let input_data = vec![
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int32Array::from(vec![1, 2, 3])),
                    Arc::new(Int32Array::from(vec![0, 0, 0])),
                ],
            )
            .unwrap(),
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int32Array::from(vec![2, 3, 4])),
                    Arc::new(Int32Array::from(vec![0, 0, 0])),
                ],
            )
            .unwrap(),
        ];

        let input =
            TestMemoryExec::try_new_exec(&[input_data], Arc::clone(&schema), None)?;
        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            group_by,
            aggr_expr,
            vec![None],
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            schema,
        )?);

        // Thresholds chosen so the probe decides to skip after 2 rows with a
        // (groups / rows) ratio above 0.1.
        let mut session_config = SessionConfig::default();
        session_config = session_config.set(
            "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
            &ScalarValue::Int64(Some(2)),
        );
        session_config = session_config.set(
            "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
            &ScalarValue::Float64(Some(0.1)),
        );

        let ctx = TaskContext::default().with_session_config(session_config);
        let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;

        allow_duplicates! {
            assert_snapshot!(batches_to_string(&output), @r"
            +-----+-------------------+
            | key | COUNT(val)[count] |
            +-----+-------------------+
            | 1   | 1                 |
            | 2   | 1                 |
            | 3   | 1                 |
            | 2   | 1                 |
            | 3   | 1                 |
            | 4   | 1                 |
            +-----+-------------------+
            ");
        }

        Ok(())
    }
3156
3157    #[tokio::test]
3158    async fn test_skip_aggregation_after_threshold() -> Result<()> {
3159        let schema = Arc::new(Schema::new(vec![
3160            Field::new("key", DataType::Int32, true),
3161            Field::new("val", DataType::Int32, true),
3162        ]));
3163
3164        let group_by =
3165            PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]);
3166
3167        let aggr_expr = vec![
3168            AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?])
3169                .schema(Arc::clone(&schema))
3170                .alias(String::from("COUNT(val)"))
3171                .build()
3172                .map(Arc::new)?,
3173        ];
3174
3175        let input_data = vec![
3176            RecordBatch::try_new(
3177                Arc::clone(&schema),
3178                vec![
3179                    Arc::new(Int32Array::from(vec![1, 2, 3])),
3180                    Arc::new(Int32Array::from(vec![0, 0, 0])),
3181                ],
3182            )
3183            .unwrap(),
3184            RecordBatch::try_new(
3185                Arc::clone(&schema),
3186                vec![
3187                    Arc::new(Int32Array::from(vec![2, 3, 4])),
3188                    Arc::new(Int32Array::from(vec![0, 0, 0])),
3189                ],
3190            )
3191            .unwrap(),
3192            RecordBatch::try_new(
3193                Arc::clone(&schema),
3194                vec![
3195                    Arc::new(Int32Array::from(vec![2, 3, 4])),
3196                    Arc::new(Int32Array::from(vec![0, 0, 0])),
3197                ],
3198            )
3199            .unwrap(),
3200        ];
3201
3202        let input =
3203            TestMemoryExec::try_new_exec(&[input_data], Arc::clone(&schema), None)?;
3204        let aggregate_exec = Arc::new(AggregateExec::try_new(
3205            AggregateMode::Partial,
3206            group_by,
3207            aggr_expr,
3208            vec![None],
3209            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
3210            schema,
3211        )?);
3212
3213        let mut session_config = SessionConfig::default();
3214        session_config = session_config.set(
3215            "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
3216            &ScalarValue::Int64(Some(5)),
3217        );
3218        session_config = session_config.set(
3219            "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
3220            &ScalarValue::Float64(Some(0.1)),
3221        );
3222
3223        let ctx = TaskContext::default().with_session_config(session_config);
3224        let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;
3225
3226        allow_duplicates! {
3227            assert_snapshot!(batches_to_string(&output), @r"
3228            +-----+-------------------+
3229            | key | COUNT(val)[count] |
3230            +-----+-------------------+
3231            | 1   | 1                 |
3232            | 2   | 2                 |
3233            | 3   | 2                 |
3234            | 4   | 1                 |
3235            | 2   | 1                 |
3236            | 3   | 1                 |
3237            | 4   | 1                 |
3238            +-----+-------------------+
3239            ");
3240        }
3241
3242        Ok(())
3243    }
3244
3245    #[test]
3246    fn group_exprs_nullable() -> Result<()> {
3247        let input_schema = Arc::new(Schema::new(vec![
3248            Field::new("a", DataType::Float32, false),
3249            Field::new("b", DataType::Float32, false),
3250        ]));
3251
3252        let aggr_expr = vec![
3253            AggregateExprBuilder::new(count_udaf(), vec![col("a", &input_schema)?])
3254                .schema(Arc::clone(&input_schema))
3255                .alias("COUNT(a)")
3256                .build()
3257                .map(Arc::new)?,
3258        ];
3259
3260        let grouping_set = PhysicalGroupBy::new(
3261            vec![
3262                (col("a", &input_schema)?, "a".to_string()),
3263                (col("b", &input_schema)?, "b".to_string()),
3264            ],
3265            vec![
3266                (lit(ScalarValue::Float32(None)), "a".to_string()),
3267                (lit(ScalarValue::Float32(None)), "b".to_string()),
3268            ],
3269            vec![
3270                vec![false, true],  // (a, NULL)
3271                vec![false, false], // (a,b)
3272            ],
3273            true,
3274        );
3275        let aggr_schema = create_schema(
3276            &input_schema,
3277            &grouping_set,
3278            &aggr_expr,
3279            AggregateMode::Final,
3280        )?;
3281        let expected_schema = Schema::new(vec![
3282            Field::new("a", DataType::Float32, false),
3283            Field::new("b", DataType::Float32, true),
3284            Field::new("__grouping_id", DataType::UInt8, false),
3285            Field::new("COUNT(a)", DataType::Int64, false),
3286        ]);
3287        assert_eq!(aggr_schema, expected_schema);
3288        Ok(())
3289    }
3290
3291    // test for https://github.com/apache/datafusion/issues/13949
3292    async fn run_test_with_spill_pool_if_necessary(
3293        pool_size: usize,
3294        expect_spill: bool,
3295    ) -> Result<()> {
3296        fn create_record_batch(
3297            schema: &Arc<Schema>,
3298            data: (Vec<u32>, Vec<f64>),
3299        ) -> Result<RecordBatch> {
3300            Ok(RecordBatch::try_new(
3301                Arc::clone(schema),
3302                vec![
3303                    Arc::new(UInt32Array::from(data.0)),
3304                    Arc::new(Float64Array::from(data.1)),
3305                ],
3306            )?)
3307        }
3308
3309        let schema = Arc::new(Schema::new(vec![
3310            Field::new("a", DataType::UInt32, false),
3311            Field::new("b", DataType::Float64, false),
3312        ]));
3313
3314        let batches = vec![
3315            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
3316            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
3317        ];
3318        let plan: Arc<dyn ExecutionPlan> =
3319            TestMemoryExec::try_new_exec(&[batches], Arc::clone(&schema), None)?;
3320
3321        let grouping_set = PhysicalGroupBy::new(
3322            vec![(col("a", &schema)?, "a".to_string())],
3323            vec![],
3324            vec![vec![false]],
3325            false,
3326        );
3327
3328        // Test with MIN for simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count).
3329        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
3330            Arc::new(
3331                AggregateExprBuilder::new(
3332                    datafusion_functions_aggregate::min_max::min_udaf(),
3333                    vec![col("b", &schema)?],
3334                )
3335                .schema(Arc::clone(&schema))
3336                .alias("MIN(b)")
3337                .build()?,
3338            ),
3339            Arc::new(
3340                AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
3341                    .schema(Arc::clone(&schema))
3342                    .alias("AVG(b)")
3343                    .build()?,
3344            ),
3345        ];
3346
3347        let single_aggregate = Arc::new(AggregateExec::try_new(
3348            AggregateMode::Single,
3349            grouping_set,
3350            aggregates,
3351            vec![None, None],
3352            plan,
3353            Arc::clone(&schema),
3354        )?);
3355
3356        let batch_size = 2;
3357        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
3358        let task_ctx = Arc::new(
3359            TaskContext::default()
3360                .with_session_config(SessionConfig::new().with_batch_size(batch_size))
3361                .with_runtime(Arc::new(
3362                    RuntimeEnvBuilder::new()
3363                        .with_memory_pool(memory_pool)
3364                        .build()?,
3365                )),
3366        );
3367
3368        let result = collect(single_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
3369
3370        assert_spill_count_metric(expect_spill, single_aggregate);
3371
3372        allow_duplicates! {
3373            assert_snapshot!(batches_to_string(&result), @r"
3374            +---+--------+--------+
3375            | a | MIN(b) | AVG(b) |
3376            +---+--------+--------+
3377            | 2 | 1.0    | 1.0    |
3378            | 3 | 2.0    | 2.0    |
3379            | 4 | 3.0    | 3.5    |
3380            +---+--------+--------+
3381            ");
3382        }
3383
3384        Ok(())
3385    }
3386
3387    fn assert_spill_count_metric(
3388        expect_spill: bool,
3389        single_aggregate: Arc<AggregateExec>,
3390    ) {
3391        if let Some(metrics_set) = single_aggregate.metrics() {
3392            let mut spill_count = 0;
3393
3394            // Inspect metrics for SpillCount
3395            for metric in metrics_set.iter() {
3396                if let MetricValue::SpillCount(count) = metric.value() {
3397                    spill_count = count.value();
3398                    break;
3399                }
3400            }
3401
3402            if expect_spill && spill_count == 0 {
3403                panic!(
3404                    "Expected spill but SpillCount metric not found or SpillCount was 0."
3405                );
3406            } else if !expect_spill && spill_count > 0 {
3407                panic!(
3408                    "Expected no spill but found SpillCount metric with value greater than 0."
3409                );
3410            }
3411        } else {
3412            panic!("No metrics returned from the operator; cannot verify spilling.");
3413        }
3414    }
3415
3416    #[tokio::test]
3417    async fn test_aggregate_with_spill_if_necessary() -> Result<()> {
3418        // test with spill
3419        run_test_with_spill_pool_if_necessary(2_000, true).await?;
3420        // test without spill
3421        run_test_with_spill_pool_if_necessary(20_000, false).await?;
3422        Ok(())
3423    }
3424
    #[tokio::test]
    async fn test_grouped_aggregation_respects_memory_limit() -> Result<()> {
        // Runs a grouped aggregation (with an extra projected literal group
        // column) under a tight 2000-byte memory pool. Either outcome is
        // acceptable: the query completes (having spilled) with correct
        // results, or it fails with ResourcesExhausted — see the match below.

        // Build a two-column batch (`a: u32`, `b: f64`) from paired vectors.
        fn create_record_batch(
            schema: &Arc<Schema>,
            data: (Vec<u32>, Vec<f64>),
        ) -> Result<RecordBatch> {
            Ok(RecordBatch::try_new(
                Arc::clone(schema),
                vec![
                    Arc::new(UInt32Array::from(data.0)),
                    Arc::new(Float64Array::from(data.1)),
                ],
            )?)
        }

        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::UInt32, false),
            Field::new("b", DataType::Float64, false),
        ]));

        let batches = vec![
            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
        ];
        let plan: Arc<dyn ExecutionPlan> =
            TestMemoryExec::try_new_exec(&[batches], Arc::clone(&schema), None)?;
        // Prepend a constant string column `l` so the group-by has a mix of a
        // literal-valued key and a data-driven key.
        let proj = ProjectionExec::try_new(
            vec![
                ProjectionExpr::new(lit("0"), "l".to_string()),
                ProjectionExpr::new_from_expression(col("a", &schema)?, &schema)?,
                ProjectionExpr::new_from_expression(col("b", &schema)?, &schema)?,
            ],
            plan,
        )?;
        let plan: Arc<dyn ExecutionPlan> = Arc::new(proj);
        // Rebind `schema` to the projection's output schema (l, a, b).
        let schema = plan.schema();

        let grouping_set = PhysicalGroupBy::new(
            vec![
                (col("l", &schema)?, "l".to_string()),
                (col("a", &schema)?, "a".to_string()),
            ],
            vec![],
            vec![vec![false, false]],
            false,
        );

        // Test with MIN for simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count).
        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
            Arc::new(
                AggregateExprBuilder::new(
                    datafusion_functions_aggregate::min_max::min_udaf(),
                    vec![col("b", &schema)?],
                )
                .schema(Arc::clone(&schema))
                .alias("MIN(b)")
                .build()?,
            ),
            Arc::new(
                AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
                    .schema(Arc::clone(&schema))
                    .alias("AVG(b)")
                    .build()?,
            ),
        ];

        let single_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Single,
            grouping_set,
            aggregates,
            vec![None, None],
            plan,
            Arc::clone(&schema),
        )?);

        // Tight memory budget (2000 bytes) with a tiny batch size.
        let batch_size = 2;
        let memory_pool = Arc::new(FairSpillPool::new(2000));
        let task_ctx = Arc::new(
            TaskContext::default()
                .with_session_config(SessionConfig::new().with_batch_size(batch_size))
                .with_runtime(Arc::new(
                    RuntimeEnvBuilder::new()
                        .with_memory_pool(memory_pool)
                        .build()?,
                )),
        );

        let result = collect(single_aggregate.execute(0, Arc::clone(&task_ctx))?).await;
        match result {
            // If the query completed, it must have spilled and still produced
            // the correct aggregated output.
            Ok(result) => {
                assert_spill_count_metric(true, single_aggregate);

                allow_duplicates! {
                    assert_snapshot!(batches_to_string(&result), @r"
                +---+---+--------+--------+
                | l | a | MIN(b) | AVG(b) |
                +---+---+--------+--------+
                | 0 | 2 | 1.0    | 1.0    |
                | 0 | 3 | 2.0    | 2.0    |
                | 0 | 4 | 3.0    | 3.5    |
                +---+---+--------+--------+
            ");
                }
            }
            // Otherwise the memory limit must be the reason for the failure.
            Err(e) => assert!(matches!(e, DataFusionError::ResourcesExhausted(_))),
        }

        Ok(())
    }
3535
3536    #[tokio::test]
3537    async fn test_aggregate_statistics_edge_cases() -> Result<()> {
3538        use crate::test::exec::StatisticsExec;
3539        use datafusion_common::ColumnStatistics;
3540
3541        let schema = Arc::new(Schema::new(vec![
3542            Field::new("a", DataType::Int32, false),
3543            Field::new("b", DataType::Float64, false),
3544        ]));
3545
3546        // Test 1: Absent statistics remain absent
3547        let input = Arc::new(StatisticsExec::new(
3548            Statistics {
3549                num_rows: Precision::Exact(100),
3550                total_byte_size: Precision::Absent,
3551                column_statistics: vec![
3552                    ColumnStatistics::new_unknown(),
3553                    ColumnStatistics::new_unknown(),
3554                ],
3555            },
3556            (*schema).clone(),
3557        )) as Arc<dyn ExecutionPlan>;
3558
3559        let agg = Arc::new(AggregateExec::try_new(
3560            AggregateMode::Final,
3561            PhysicalGroupBy::default(),
3562            vec![Arc::new(
3563                AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema)?])
3564                    .schema(Arc::clone(&schema))
3565                    .alias("COUNT(a)")
3566                    .build()?,
3567            )],
3568            vec![None],
3569            input,
3570            Arc::clone(&schema),
3571        )?);
3572
3573        let stats = agg.partition_statistics(None)?;
3574        assert_eq!(stats.total_byte_size, Precision::Absent);
3575
3576        // Test 2: Zero rows returns Absent (can't estimate output size from zero input)
3577        let input_zero = Arc::new(StatisticsExec::new(
3578            Statistics {
3579                num_rows: Precision::Exact(0),
3580                total_byte_size: Precision::Exact(0),
3581                column_statistics: vec![
3582                    ColumnStatistics::new_unknown(),
3583                    ColumnStatistics::new_unknown(),
3584                ],
3585            },
3586            (*schema).clone(),
3587        )) as Arc<dyn ExecutionPlan>;
3588
3589        let agg_zero = Arc::new(AggregateExec::try_new(
3590            AggregateMode::Final,
3591            PhysicalGroupBy::default(),
3592            vec![Arc::new(
3593                AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema)?])
3594                    .schema(Arc::clone(&schema))
3595                    .alias("COUNT(a)")
3596                    .build()?,
3597            )],
3598            vec![None],
3599            input_zero,
3600            Arc::clone(&schema),
3601        )?);
3602
3603        let stats_zero = agg_zero.partition_statistics(None)?;
3604        assert_eq!(stats_zero.total_byte_size, Precision::Absent);
3605
3606        Ok(())
3607    }
3608
3609    #[tokio::test]
3610    async fn test_order_is_retained_when_spilling() -> Result<()> {
3611        let schema = Arc::new(Schema::new(vec![
3612            Field::new("a", DataType::Int64, false),
3613            Field::new("b", DataType::Int64, false),
3614            Field::new("c", DataType::Int64, false),
3615        ]));
3616
3617        let batches = vec![vec![
3618            RecordBatch::try_new(
3619                Arc::clone(&schema),
3620                vec![
3621                    Arc::new(Int64Array::from(vec![2])),
3622                    Arc::new(Int64Array::from(vec![2])),
3623                    Arc::new(Int64Array::from(vec![1])),
3624                ],
3625            )?,
3626            RecordBatch::try_new(
3627                Arc::clone(&schema),
3628                vec![
3629                    Arc::new(Int64Array::from(vec![1])),
3630                    Arc::new(Int64Array::from(vec![1])),
3631                    Arc::new(Int64Array::from(vec![1])),
3632                ],
3633            )?,
3634            RecordBatch::try_new(
3635                Arc::clone(&schema),
3636                vec![
3637                    Arc::new(Int64Array::from(vec![0])),
3638                    Arc::new(Int64Array::from(vec![0])),
3639                    Arc::new(Int64Array::from(vec![1])),
3640                ],
3641            )?,
3642        ]];
3643        let scan = TestMemoryExec::try_new(&batches, Arc::clone(&schema), None)?;
3644        let scan = scan.try_with_sort_information(vec![
3645            LexOrdering::new([PhysicalSortExpr::new(
3646                col("b", schema.as_ref())?,
3647                SortOptions::default().desc(),
3648            )])
3649            .unwrap(),
3650        ])?;
3651
3652        let aggr = Arc::new(AggregateExec::try_new(
3653            AggregateMode::Single,
3654            PhysicalGroupBy::new(
3655                vec![
3656                    (col("b", schema.as_ref())?, "b".to_string()),
3657                    (col("c", schema.as_ref())?, "c".to_string()),
3658                ],
3659                vec![],
3660                vec![vec![false, false]],
3661                false,
3662            ),
3663            vec![Arc::new(
3664                AggregateExprBuilder::new(sum_udaf(), vec![col("c", schema.as_ref())?])
3665                    .schema(Arc::clone(&schema))
3666                    .alias("SUM(c)")
3667                    .build()?,
3668            )],
3669            vec![None],
3670            Arc::new(scan) as Arc<dyn ExecutionPlan>,
3671            Arc::clone(&schema),
3672        )?);
3673
3674        let task_ctx = new_spill_ctx(1, 600);
3675        let result = collect(aggr.execute(0, Arc::clone(&task_ctx))?).await?;
3676        assert_spill_count_metric(true, aggr);
3677
3678        allow_duplicates! {
3679            assert_snapshot!(batches_to_string(&result), @r"
3680            +---+---+--------+
3681            | b | c | SUM(c) |
3682            +---+---+--------+
3683            | 2 | 1 | 1      |
3684            | 1 | 1 | 1      |
3685            | 0 | 1 | 1      |
3686            +---+---+--------+
3687        ");
3688        }
3689        Ok(())
3690    }
3691}