Skip to main content

datafusion_physical_plan/aggregates/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Aggregates functionalities
19
20use std::borrow::Cow;
21use std::sync::Arc;
22
23use super::{DisplayAs, ExecutionPlanProperties, PlanProperties};
24use crate::aggregates::{
25    no_grouping::AggregateStream, row_hash::GroupedHashAggregateStream,
26    topk_stream::GroupedTopKAggregateStream,
27};
28use crate::execution_plan::{CardinalityEffect, EmissionType};
29use crate::filter_pushdown::{
30    ChildFilterDescription, ChildPushdownResult, FilterDescription, FilterPushdownPhase,
31    FilterPushdownPropagation, PushedDownPredicate,
32};
33use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
34use crate::{
35    DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode,
36    SendableRecordBatchStream, Statistics, check_if_same_properties,
37};
38use datafusion_common::config::ConfigOptions;
39use datafusion_physical_expr::utils::collect_columns;
40use parking_lot::Mutex;
41use std::collections::{HashMap, HashSet};
42
43use arrow::array::{ArrayRef, UInt8Array, UInt16Array, UInt32Array, UInt64Array};
44use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
45use arrow::record_batch::RecordBatch;
46use arrow_schema::FieldRef;
47use datafusion_common::stats::Precision;
48use datafusion_common::{
49    Constraint, Constraints, Result, ScalarValue, assert_eq_or_internal_err,
50    internal_err, not_impl_err,
51};
52use datafusion_execution::TaskContext;
53use datafusion_expr::{Accumulator, Aggregate};
54use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
55use datafusion_physical_expr::equivalence::ProjectionMapping;
56use datafusion_physical_expr::expressions::{Column, DynamicFilterPhysicalExpr, lit};
57use datafusion_physical_expr::{
58    ConstExpr, EquivalenceProperties, physical_exprs_contains,
59};
60use datafusion_physical_expr_common::physical_expr::{PhysicalExpr, fmt_sql};
61use datafusion_physical_expr_common::sort_expr::{
62    LexOrdering, LexRequirement, OrderingRequirements, PhysicalSortRequirement,
63};
64
65use datafusion_expr::utils::AggregateOrderSensitivity;
66use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays;
67use itertools::Itertools;
68use topk::hash_table::is_supported_hash_key_type;
69use topk::heap::is_supported_heap_type;
70
71pub mod group_values;
72mod no_grouping;
73pub mod order;
74mod row_hash;
75mod topk;
76mod topk_stream;
77
78/// Returns true if TopK aggregation data structures support the provided key and value types.
79///
80/// This function checks whether both the key type (used for grouping) and value type
81/// (used in min/max aggregation) can be handled by the TopK aggregation heap and hash table.
82/// Supported types include Arrow primitives (integers, floats, decimals, intervals) and
83/// UTF-8 strings (`Utf8`, `LargeUtf8`, `Utf8View`).
84/// ```text
85pub fn topk_types_supported(key_type: &DataType, value_type: &DataType) -> bool {
86    is_supported_hash_key_type(key_type) && is_supported_heap_type(value_type)
87}
88
89/// Hard-coded seed for aggregations to ensure hash values differ from `RepartitionExec`, avoiding collisions.
90const AGGREGATION_HASH_SEED: datafusion_common::hash_utils::RandomState =
91    // This seed is chosen to be a large 64-bit number
92    datafusion_common::hash_utils::RandomState::with_seed(15395726432021054657);
93
/// Describes what an aggregate stage reads: raw, unaggregated input rows, or the
/// intermediate accumulator state emitted by an earlier aggregation stage.
///
/// The [table on `AggregateMode`](AggregateMode#variants-and-their-inputoutput-modes)
/// shows which aggregate modes use which kind of input.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum AggregateInputMode {
    /// Raw, unaggregated input; the stage feeds it to
    /// [`Accumulator::update_batch`].
    Raw,
    /// Intermediate accumulator state from an earlier aggregation stage; the
    /// stage feeds it to [`Accumulator::merge_batch`].
    Partial,
}
108
/// Describes what an aggregate stage emits: intermediate accumulator state for a
/// later stage, or the final aggregate values.
///
/// The [table on `AggregateMode`](AggregateMode#variants-and-their-inputoutput-modes)
/// shows which aggregate modes produce which kind of output.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum AggregateOutputMode {
    /// Intermediate accumulator state, serialized with
    /// [`Accumulator::state`].
    Partial,
    /// Final output values, computed with
    /// [`Accumulator::evaluate`].
    Final,
}
123
/// Aggregation modes.
///
/// See [`Accumulator::state`] for background on multi-phase aggregation and on
/// how these modes are used in it.
///
/// # Variants and their input/output modes
///
/// Each variant is characterized by an [`AggregateInputMode`] together with an
/// [`AggregateOutputMode`]:
///
/// ```text
///                       | Input: Raw data           | Input: Partial state
/// Output: Final values  | Single, SinglePartitioned | Final, FinalPartitioned
/// Output: Partial state | Partial                   | PartialReduce
/// ```
///
/// These properties can be queried with [`AggregateMode::input_mode`] and
/// [`AggregateMode::output_mode`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum AggregateMode {
    /// One of multiple layers of aggregation, with any input partitioning.
    ///
    /// A partial aggregate that can be applied in parallel across input
    /// partitions.
    ///
    /// This is the first phase of a multi-phase aggregation.
    Partial,
    /// *Final* of multiple layers of aggregation, in exactly one partition.
    ///
    /// A final aggregate that combines the output of multiple partial aggregates
    /// into a single output partition.
    ///
    /// This is the second phase of a multi-phase aggregation.
    ///
    /// This mode requires that the input is a single partition.
    ///
    /// Note: adjacent `Partial` and `Final` mode aggregations are equivalent to one
    /// `Single` mode aggregation node. `Final` mode is needed because it is used as
    /// an intermediate step; the [`CombinePartialFinalAggregate`] physical optimizer
    /// rule replaces this combination with `Single` mode for more efficient execution.
    ///
    /// [`CombinePartialFinalAggregate`]: https://docs.rs/datafusion/latest/datafusion/physical_optimizer/combine_partial_final_agg/struct.CombinePartialFinalAggregate.html
    Final,
    /// *Final* of multiple layers of aggregation, with *Partitioned* input.
    ///
    /// A final aggregate that works on pre-partitioned data.
    ///
    /// This mode requires that all rows with a particular grouping key are in the
    /// same partition, as is the case with hash repartitioning on the group keys.
    /// If rows of one group key appear in more than one partition, duplicate groups
    /// are produced.
    FinalPartitioned,
    /// *Single* layer of aggregation, with input in exactly one partition.
    ///
    /// Applies the entire logical aggregation in a single operator, as opposed to
    /// the `Partial` / `Final` modes, which split it across two operators.
    ///
    /// This mode requires that the input is a single partition (like `Final`).
    Single,
    /// *Single* layer of aggregation, with *Partitioned* input.
    ///
    /// Applies the entire logical aggregation in a single operator, as opposed to
    /// the `Partial` / `Final` modes, which split it across two operators.
    ///
    /// This mode requires that the input has more than one partition and is
    /// partitioned by the group key (like `FinalPartitioned`).
    SinglePartitioned,
    /// Combines multiple partial aggregations into a new partial aggregation.
    ///
    /// The input is intermediate accumulator state (like `Final`), and the output
    /// is intermediate accumulator state as well (like `Partial`). This enables
    /// tree-reduce aggregation strategies, in which partial results from multiple
    /// workers are combined over several stages before a final evaluation.
    ///
    /// ```text
    ///               Final
    ///            /        \
    ///     PartialReduce   PartialReduce
    ///     /         \      /         \
    ///  Partial   Partial  Partial   Partial
    /// ```
    PartialReduce,
}
211
212impl AggregateMode {
213    /// Returns the [`AggregateInputMode`] for this mode: whether this
214    /// stage consumes raw input data or intermediate accumulator state.
215    ///
216    /// See the [table above](AggregateMode#variants-and-their-inputoutput-modes)
217    /// for details.
218    pub fn input_mode(&self) -> AggregateInputMode {
219        match self {
220            AggregateMode::Partial
221            | AggregateMode::Single
222            | AggregateMode::SinglePartitioned => AggregateInputMode::Raw,
223            AggregateMode::Final
224            | AggregateMode::FinalPartitioned
225            | AggregateMode::PartialReduce => AggregateInputMode::Partial,
226        }
227    }
228
229    /// Returns the [`AggregateOutputMode`] for this mode: whether this
230    /// stage produces intermediate accumulator state or final output values.
231    ///
232    /// See the [table above](AggregateMode#variants-and-their-inputoutput-modes)
233    /// for details.
234    pub fn output_mode(&self) -> AggregateOutputMode {
235        match self {
236            AggregateMode::Final
237            | AggregateMode::FinalPartitioned
238            | AggregateMode::Single
239            | AggregateMode::SinglePartitioned => AggregateOutputMode::Final,
240            AggregateMode::Partial | AggregateMode::PartialReduce => {
241                AggregateOutputMode::Partial
242            }
243        }
244    }
245}
246
247/// Represents `GROUP BY` clause in the plan (including the more general GROUPING SET)
248/// In the case of a simple `GROUP BY a, b` clause, this will contain the expression [a, b]
249/// and a single group [false, false].
250/// In the case of `GROUP BY GROUPING SETS/CUBE/ROLLUP` the planner will expand the expression
251/// into multiple groups, using null expressions to align each group.
252/// For example, with a group by clause `GROUP BY GROUPING SETS ((a,b),(a),(b))` the planner should
253/// create a `PhysicalGroupBy` like
254/// ```text
255/// PhysicalGroupBy {
256///     expr: [(col(a), a), (col(b), b)],
257///     null_expr: [(NULL, a), (NULL, b)],
258///     groups: [
259///         [false, false], // (a,b)
260///         [false, true],  // (a) <=> (a, NULL)
261///         [true, false]   // (b) <=> (NULL, b)
262///     ]
263/// }
264/// ```
265#[derive(Clone, Debug, Default)]
266pub struct PhysicalGroupBy {
267    /// Distinct (Physical Expr, Alias) in the grouping set
268    expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
269    /// Corresponding NULL expressions for expr
270    null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
271    /// Null mask for each group in this grouping set. Each group is
272    /// composed of either one of the group expressions in expr or a null
273    /// expression in null_expr. If `groups[i][j]` is true, then the
274    /// j-th expression in the i-th group is NULL, otherwise it is `expr[j]`.
275    groups: Vec<Vec<bool>>,
276    /// True when GROUPING SETS/CUBE/ROLLUP are used so `__grouping_id` should
277    /// be included in the output schema.
278    has_grouping_set: bool,
279}
280
281impl PhysicalGroupBy {
282    /// Create a new `PhysicalGroupBy`
283    pub fn new(
284        expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
285        null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
286        groups: Vec<Vec<bool>>,
287        has_grouping_set: bool,
288    ) -> Self {
289        Self {
290            expr,
291            null_expr,
292            groups,
293            has_grouping_set,
294        }
295    }
296
297    /// Create a GROUPING SET with only a single group. This is the "standard"
298    /// case when building a plan from an expression such as `GROUP BY a,b,c`
299    pub fn new_single(expr: Vec<(Arc<dyn PhysicalExpr>, String)>) -> Self {
300        let num_exprs = expr.len();
301        Self {
302            expr,
303            null_expr: vec![],
304            groups: vec![vec![false; num_exprs]],
305            has_grouping_set: false,
306        }
307    }
308
309    /// Calculate GROUP BY expressions nullable
310    pub fn exprs_nullable(&self) -> Vec<bool> {
311        let mut exprs_nullable = vec![false; self.expr.len()];
312        for group in self.groups.iter() {
313            group.iter().enumerate().for_each(|(index, is_null)| {
314                if *is_null {
315                    exprs_nullable[index] = true;
316                }
317            })
318        }
319        exprs_nullable
320    }
321
322    /// Returns true if this has no grouping at all (including no GROUPING SETS)
323    pub fn is_true_no_grouping(&self) -> bool {
324        self.is_empty() && !self.has_grouping_set
325    }
326
327    /// Returns the group expressions
328    pub fn expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
329        &self.expr
330    }
331
332    /// Returns the null expressions
333    pub fn null_expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
334        &self.null_expr
335    }
336
337    /// Returns the group null masks
338    pub fn groups(&self) -> &[Vec<bool>] {
339        &self.groups
340    }
341
342    /// Returns true if this grouping uses GROUPING SETS, CUBE or ROLLUP.
343    pub fn has_grouping_set(&self) -> bool {
344        self.has_grouping_set
345    }
346
347    /// Returns true if this `PhysicalGroupBy` has no group expressions
348    pub fn is_empty(&self) -> bool {
349        self.expr.is_empty()
350    }
351
352    /// Returns true if this is a "simple" GROUP BY (not using GROUPING SETS/CUBE/ROLLUP).
353    /// This determines whether the `__grouping_id` column is included in the output schema.
354    pub fn is_single(&self) -> bool {
355        !self.has_grouping_set
356    }
357
358    /// Calculate GROUP BY expressions according to input schema.
359    pub fn input_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
360        self.expr
361            .iter()
362            .map(|(expr, _alias)| Arc::clone(expr))
363            .collect()
364    }
365
366    /// The number of expressions in the output schema.
367    fn num_output_exprs(&self) -> usize {
368        let mut num_exprs = self.expr.len();
369        if self.has_grouping_set {
370            num_exprs += 1
371        }
372        num_exprs
373    }
374
375    /// Return grouping expressions as they occur in the output schema.
376    pub fn output_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
377        let num_output_exprs = self.num_output_exprs();
378        let mut output_exprs = Vec::with_capacity(num_output_exprs);
379        output_exprs.extend(
380            self.expr
381                .iter()
382                .enumerate()
383                .take(num_output_exprs)
384                .map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _),
385        );
386        if self.has_grouping_set {
387            output_exprs.push(Arc::new(Column::new(
388                Aggregate::INTERNAL_GROUPING_ID,
389                self.expr.len(),
390            )) as _);
391        }
392        output_exprs
393    }
394
395    /// Returns the number expression as grouping keys.
396    pub fn num_group_exprs(&self) -> usize {
397        self.expr.len() + usize::from(self.has_grouping_set)
398    }
399
400    /// Returns the Arrow data type of the `__grouping_id` column.
401    ///
402    /// The type is chosen to be wide enough to hold both the semantic bitmask
403    /// (in the low `n` bits, where `n` is the number of grouping expressions)
404    /// and the duplicate ordinal (in the high bits).
405    fn grouping_id_data_type(&self) -> DataType {
406        Aggregate::grouping_id_type(self.expr.len(), max_duplicate_ordinal(&self.groups))
407    }
408
409    pub fn group_schema(&self, schema: &Schema) -> Result<SchemaRef> {
410        Ok(Arc::new(Schema::new(self.group_fields(schema)?)))
411    }
412
413    /// Returns the fields that are used as the grouping keys.
414    fn group_fields(&self, input_schema: &Schema) -> Result<Vec<FieldRef>> {
415        let mut fields = Vec::with_capacity(self.num_group_exprs());
416        for ((expr, name), group_expr_nullable) in
417            self.expr.iter().zip(self.exprs_nullable())
418        {
419            fields.push(
420                Field::new(
421                    name,
422                    expr.data_type(input_schema)?,
423                    group_expr_nullable || expr.nullable(input_schema)?,
424                )
425                .with_metadata(expr.return_field(input_schema)?.metadata().clone())
426                .into(),
427            );
428        }
429        if self.has_grouping_set {
430            fields.push(
431                Field::new(
432                    Aggregate::INTERNAL_GROUPING_ID,
433                    self.grouping_id_data_type(),
434                    false,
435                )
436                .into(),
437            );
438        }
439        Ok(fields)
440    }
441
442    /// Returns the output fields of the group by.
443    ///
444    /// This might be different from the `group_fields` that might contain internal expressions that
445    /// should not be part of the output schema.
446    fn output_fields(&self, input_schema: &Schema) -> Result<Vec<FieldRef>> {
447        let mut fields = self.group_fields(input_schema)?;
448        fields.truncate(self.num_output_exprs());
449        Ok(fields)
450    }
451
452    /// Returns the `PhysicalGroupBy` for a final aggregation if `self` is used for a partial
453    /// aggregation.
454    pub fn as_final(&self) -> PhysicalGroupBy {
455        let expr: Vec<_> =
456            self.output_exprs()
457                .into_iter()
458                .zip(
459                    self.expr.iter().map(|t| t.1.clone()).chain(std::iter::once(
460                        Aggregate::INTERNAL_GROUPING_ID.to_owned(),
461                    )),
462                )
463                .collect();
464        let num_exprs = expr.len();
465        let groups = if self.expr.is_empty() && !self.has_grouping_set {
466            // No GROUP BY expressions - should have no groups
467            vec![]
468        } else {
469            vec![vec![false; num_exprs]]
470        };
471        Self {
472            expr,
473            null_expr: vec![],
474            groups,
475            has_grouping_set: false,
476        }
477    }
478}
479
480impl PartialEq for PhysicalGroupBy {
481    fn eq(&self, other: &PhysicalGroupBy) -> bool {
482        self.expr.len() == other.expr.len()
483            && self
484                .expr
485                .iter()
486                .zip(other.expr.iter())
487                .all(|((expr1, name1), (expr2, name2))| expr1.eq(expr2) && name1 == name2)
488            && self.null_expr.len() == other.null_expr.len()
489            && self
490                .null_expr
491                .iter()
492                .zip(other.null_expr.iter())
493                .all(|((expr1, name1), (expr2, name2))| expr1.eq(expr2) && name1 == name2)
494            && self.groups == other.groups
495            && self.has_grouping_set == other.has_grouping_set
496    }
497}
498
499#[expect(clippy::large_enum_variant)]
500enum StreamType {
501    AggregateStream(AggregateStream),
502    GroupedHash(GroupedHashAggregateStream),
503    GroupedPriorityQueue(GroupedTopKAggregateStream),
504}
505
506impl From<StreamType> for SendableRecordBatchStream {
507    fn from(stream: StreamType) -> Self {
508        match stream {
509            StreamType::AggregateStream(stream) => Box::pin(stream),
510            StreamType::GroupedHash(stream) => Box::pin(stream),
511            StreamType::GroupedPriorityQueue(stream) => Box::pin(stream),
512        }
513    }
514}
515
516/// # Aggregate Dynamic Filter Pushdown Overview
517///
518/// For queries like
519///   -- `example_table(type TEXT, val INT)`
520///   SELECT min(val)
521///   FROM example_table
522///   WHERE type='A';
523///
524/// And `example_table`'s physical representation is a partitioned parquet file with
525/// column statistics
526/// - part-0.parquet: val {min=0, max=100}
527/// - part-1.parquet: val {min=100, max=200}
528/// - ...
529/// - part-100.parquet: val {min=10000, max=10100}
530///
531/// After scanning the 1st file, we know we only have to read files if their minimal
532/// value on `val` column is less than 0, the minimal `val` value in the 1st file.
533///
534/// We can skip scanning the remaining file by implementing dynamic filter, the
535/// intuition is we keep a shared data structure for current min in both `AggregateExec`
536/// and `DataSourceExec`, and let it update during execution, so the scanner can
537/// know during execution if it's possible to skip scanning certain files. See
538/// physical optimizer rule `FilterPushdown` for details.
539///
540/// # Implementation
541///
542/// ## Enable Condition
543/// - No grouping (no `GROUP BY` clause in the sql, only a single global group to aggregate)
544/// - The aggregate expression must be `min`/`max`, and evaluate directly on columns.
545///   Note multiple aggregate expressions that satisfy this requirement are allowed,
546///   and a dynamic filter will be constructed combining all applicable expr's
547///   states. See more in the following example with dynamic filter on multiple columns.
548///
549/// ## Filter Construction
550/// The filter is kept in the `DataSourceExec`, and it will gets update during execution,
551/// the reader will interpret it as "the upstream only needs rows that such filter
552/// predicate is evaluated to true", and certain scanner implementation like `parquet`
553/// can evalaute column statistics on those dynamic filters, to decide if they can
554/// prune a whole range.
555///
556/// ### Examples
557/// - Expr: `min(a)`, Dynamic Filter: `a < a_cur_min`
558/// - Expr: `min(a), max(a), min(b)`, Dynamic Filter: `(a < a_cur_min) OR (a > a_cur_max) OR (b < b_cur_min)`
559#[derive(Debug, Clone)]
560struct AggrDynFilter {
561    /// The physical expr for the dynamic filter shared between the `AggregateExec`
562    /// and the parquet scanner.
563    filter: Arc<DynamicFilterPhysicalExpr>,
564    /// The current bounds for the dynamic filter, updates during the execution to
565    /// tighten the bound for more effective pruning.
566    ///
567    /// Each vector element is for the accumulators that support dynamic filter.
568    /// e.g. This `AggregateExec` has accumulator:
569    /// min(a), avg(a), max(b)
570    /// And this field stores [PerAccumulatorDynFilter(min(a)), PerAccumulatorDynFilter(min(b))]
571    supported_accumulators_info: Vec<PerAccumulatorDynFilter>,
572}
573
574// ---- Aggregate Dynamic Filter Utility Structs ----
575
576/// Aggregate expressions that support the dynamic filter pushdown in aggregation.
577/// See comments in [`AggrDynFilter`] for conditions.
578#[derive(Debug, Clone)]
579struct PerAccumulatorDynFilter {
580    aggr_type: DynamicFilterAggregateType,
581    /// During planning and optimization, the parent structure is kept in `AggregateExec`,
582    /// this index is into `aggr_expr` vec inside `AggregateExec`.
583    /// During execution, the parent struct is moved into `AggregateStream` (stream
584    /// for no grouping aggregate execution), and this index is into    `aggregate_expressions`
585    /// vec inside `AggregateStreamInner`
586    aggr_index: usize,
587    // The current bound. Shared among all streams.
588    shared_bound: Arc<Mutex<ScalarValue>>,
589}
590
/// The kinds of aggregate from which `AggregateExec` supports building a
/// dynamic filter.
#[derive(Clone, Debug)]
enum DynamicFilterAggregateType {
    /// A `min` aggregate.
    Min,
    /// A `max` aggregate.
    Max,
}
597
/// Configuration for limit-based optimizations in aggregation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct LimitOptions {
    /// The maximum number of rows to return.
    pub limit: usize,
    /// Optional ordering direction: `Some(true)` means descending, `Some(false)`
    /// means ascending, and `None` means no particular direction was requested.
    ///
    /// TopK aggregation uses this to maintain a priority queue with the correct
    /// ordering.
    pub descending: Option<bool>,
}

impl LimitOptions {
    /// Creates options with the given `limit` and no ordering direction.
    pub fn new(limit: usize) -> Self {
        Self {
            descending: None,
            limit,
        }
    }

    /// Creates options with the given `limit` and ordering direction.
    pub fn new_with_order(limit: usize, descending: bool) -> Self {
        Self {
            descending: Some(descending),
            ..Self::new(limit)
        }
    }

    /// Returns the maximum number of rows to return.
    pub fn limit(&self) -> usize {
        self.limit
    }

    /// Returns the requested ordering direction, if any
    /// (`Some(true)` = descending, `Some(false)` = ascending).
    pub fn descending(&self) -> Option<bool> {
        self.descending
    }
}
633
634/// Hash aggregate execution plan
635#[derive(Debug, Clone)]
636pub struct AggregateExec {
637    /// Aggregation mode (full, partial)
638    mode: AggregateMode,
639    /// Group by expressions
640    /// [`Arc`] used for a cheap clone, which improves physical plan optimization performance.
641    group_by: Arc<PhysicalGroupBy>,
642    /// Aggregate expressions
643    /// The same reason to [`Arc`] it as for [`Self::group_by`].
644    aggr_expr: Arc<[Arc<AggregateFunctionExpr>]>,
645    /// FILTER (WHERE clause) expression for each aggregate expression
646    /// The same reason to [`Arc`] it as for [`Self::group_by`].
647    filter_expr: Arc<[Option<Arc<dyn PhysicalExpr>>]>,
648    /// Configuration for limit-based optimizations
649    limit_options: Option<LimitOptions>,
650    /// Input plan, could be a partial aggregate or the input to the aggregate
651    pub input: Arc<dyn ExecutionPlan>,
652    /// Schema after the aggregate is applied. Contains the group by columns followed by the
653    /// aggregate outputs.
654    schema: SchemaRef,
655    /// Input schema before any aggregation is applied. For partial aggregate this will be the
656    /// same as input.schema() but for the final aggregate it will be the same as the input
657    /// to the partial aggregate, i.e., partial and final aggregates have same `input_schema`.
658    /// We need the input schema of partial aggregate to be able to deserialize aggregate
659    /// expressions from protobuf for final aggregate.
660    pub input_schema: SchemaRef,
661    /// Execution metrics
662    metrics: ExecutionPlanMetricsSet,
663    required_input_ordering: Option<OrderingRequirements>,
664    /// Describes how the input is ordered relative to the group by columns
665    input_order_mode: InputOrderMode,
666    cache: Arc<PlanProperties>,
667    /// During initialization, if the plan supports dynamic filtering (see [`AggrDynFilter`]),
668    /// it is set to `Some(..)` regardless of whether it can be pushed down to a child node.
669    ///
670    /// During filter pushdown optimization, if a child node can accept this filter,
671    /// it remains `Some(..)` to enable dynamic filtering during aggregate execution;
672    /// otherwise, it is cleared to `None`.
673    dynamic_filter: Option<Arc<AggrDynFilter>>,
674}
675
676impl AggregateExec {
677    /// Function used in `OptimizeAggregateOrder` optimizer rule,
678    /// where we need parts of the new value, others cloned from the old one
679    /// Rewrites aggregate exec with new aggregate expressions.
680    pub fn with_new_aggr_exprs(
681        &self,
682        aggr_expr: impl Into<Arc<[Arc<AggregateFunctionExpr>]>>,
683    ) -> Self {
684        Self {
685            aggr_expr: aggr_expr.into(),
686            // clone the rest of the fields
687            required_input_ordering: self.required_input_ordering.clone(),
688            metrics: ExecutionPlanMetricsSet::new(),
689            input_order_mode: self.input_order_mode.clone(),
690            cache: Arc::clone(&self.cache),
691            mode: self.mode,
692            group_by: Arc::clone(&self.group_by),
693            filter_expr: Arc::clone(&self.filter_expr),
694            limit_options: self.limit_options,
695            input: Arc::clone(&self.input),
696            schema: Arc::clone(&self.schema),
697            input_schema: Arc::clone(&self.input_schema),
698            dynamic_filter: self.dynamic_filter.clone(),
699        }
700    }
701
702    /// Clone this exec, overriding only the limit hint.
703    pub fn with_new_limit_options(&self, limit_options: Option<LimitOptions>) -> Self {
704        Self {
705            limit_options,
706            // clone the rest of the fields
707            required_input_ordering: self.required_input_ordering.clone(),
708            metrics: ExecutionPlanMetricsSet::new(),
709            input_order_mode: self.input_order_mode.clone(),
710            cache: Arc::clone(&self.cache),
711            mode: self.mode,
712            group_by: Arc::clone(&self.group_by),
713            aggr_expr: Arc::clone(&self.aggr_expr),
714            filter_expr: Arc::clone(&self.filter_expr),
715            input: Arc::clone(&self.input),
716            schema: Arc::clone(&self.schema),
717            input_schema: Arc::clone(&self.input_schema),
718            dynamic_filter: self.dynamic_filter.clone(),
719        }
720    }
721
722    pub fn cache(&self) -> &PlanProperties {
723        &self.cache
724    }
725
726    /// Create a new hash aggregate execution plan
727    pub fn try_new(
728        mode: AggregateMode,
729        group_by: impl Into<Arc<PhysicalGroupBy>>,
730        aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
731        filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
732        input: Arc<dyn ExecutionPlan>,
733        input_schema: SchemaRef,
734    ) -> Result<Self> {
735        let group_by = group_by.into();
736        let schema = create_schema(&input.schema(), &group_by, &aggr_expr, mode)?;
737
738        let schema = Arc::new(schema);
739        AggregateExec::try_new_with_schema(
740            mode,
741            group_by,
742            aggr_expr,
743            filter_expr,
744            input,
745            input_schema,
746            schema,
747        )
748    }
749
    /// Create a new hash aggregate execution plan with the given schema.
    /// This constructor isn't part of the public API, it is used internally
    /// by DataFusion to enforce schema consistency when re-creating
    /// `AggregateExec`s inside optimization rules. Schema field names of an
    /// `AggregateExec` depend on the names of aggregate expressions. Since
    /// a rule may re-write aggregate expressions (e.g. reverse them) during
    /// initialization, field names may change inadvertently if one re-creates
    /// the schema in such cases.
    ///
    /// Besides storing its arguments, this computes the soft input-ordering
    /// requirement, the [`InputOrderMode`], the cached plan properties and
    /// (when supported) the dynamic filter.
    ///
    /// # Errors
    /// Returns an internal error if `aggr_expr` and `filter_expr` differ in
    /// length, and propagates errors from ordering analysis,
    /// `ProjectionMapping::try_new` and `compute_properties`.
    fn try_new_with_schema(
        mode: AggregateMode,
        group_by: impl Into<Arc<PhysicalGroupBy>>,
        mut aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
        filter_expr: impl Into<Arc<[Option<Arc<dyn PhysicalExpr>>]>>,
        input: Arc<dyn ExecutionPlan>,
        input_schema: SchemaRef,
        schema: SchemaRef,
    ) -> Result<Self> {
        let group_by = group_by.into();
        let filter_expr = filter_expr.into();

        // Make sure arguments are consistent in size
        // (one optional FILTER expression per aggregate expression).
        assert_eq_or_internal_err!(
            aggr_expr.len(),
            filter_expr.len(),
            "Inconsistent aggregate expr: {:?} and filter expr: {:?} for AggregateExec, their size should match",
            aggr_expr,
            filter_expr
        );

        let input_eq_properties = input.equivalence_properties();
        // Get GROUP BY expressions:
        let groupby_exprs = group_by.input_exprs();
        // If existing ordering satisfies a prefix of the GROUP BY expressions,
        // prefix requirements with this section. In this case, aggregation will
        // work more efficiently.
        // Copy the `PhysicalSortExpr`s to retain the sort options.
        // `indices` are the positions (in `groupby_exprs`) of the GROUP BY
        // expressions covered by that ordering.
        let (new_sort_exprs, indices) =
            input_eq_properties.find_longest_permutation(&groupby_exprs)?;

        let mut new_requirements = new_sort_exprs
            .into_iter()
            .map(PhysicalSortRequirement::from)
            .collect::<Vec<_>>();

        // Append any ordering requirement derived from the aggregate
        // expressions. `aggr_expr` is passed mutably and may be rewritten by
        // this call; the (possibly rewritten) expressions are the ones used by
        // `compute_properties` and stored in the node below.
        let req = get_finer_aggregate_exprs_requirement(
            &mut aggr_expr,
            &group_by,
            input_eq_properties,
            &mode,
        )?;
        new_requirements.extend(req);

        // The combined requirement is registered as a *soft* requirement
        // (presumably `None` when `new_requirements` is empty — TODO confirm
        // against `LexRequirement::new`).
        let required_input_ordering =
            LexRequirement::new(new_requirements).map(OrderingRequirements::new_soft);

        // If our aggregation has grouping sets then our base grouping exprs will
        // be expanded based on the flags in `group_by.groups` where for each
        // group we swap the grouping expr for `null` if the flag is `true`
        // That means that each index in `indices` is valid if and only if
        // it is not null in every group
        let indices: Vec<usize> = indices
            .into_iter()
            .filter(|idx| group_by.groups.iter().all(|group| !group[*idx]))
            .collect();

        // Sorted: every GROUP BY expression is covered by the input ordering
        //   and there is exactly one grouping set.
        // PartiallySorted: only the listed GROUP BY expressions are covered.
        // Linear: none are covered.
        let input_order_mode = if indices.len() == groupby_exprs.len()
            && !indices.is_empty()
            && group_by.groups.len() == 1
        {
            InputOrderMode::Sorted
        } else if !indices.is_empty() {
            InputOrderMode::PartiallySorted(indices)
        } else {
            InputOrderMode::Linear
        };

        // construct a map from the input expression to the output expression of the Aggregation group by
        let group_expr_mapping =
            ProjectionMapping::try_new(group_by.expr.clone(), &input.schema())?;

        let cache = Self::compute_properties(
            &input,
            Arc::clone(&schema),
            &group_expr_mapping,
            &mode,
            &input_order_mode,
            aggr_expr.as_ref(),
        )?;

        // `limit_options` starts unset (see `with_limit_options`) and
        // `dynamic_filter` is filled in by `init_dynamic_filter` below when
        // this aggregate supports it.
        let mut exec = AggregateExec {
            mode,
            group_by,
            aggr_expr: aggr_expr.into(),
            filter_expr,
            input,
            schema,
            input_schema,
            metrics: ExecutionPlanMetricsSet::new(),
            required_input_ordering,
            limit_options: None,
            input_order_mode,
            cache: Arc::new(cache),
            dynamic_filter: None,
        };

        exec.init_dynamic_filter();

        Ok(exec)
    }
859
860    /// Aggregation mode (full, partial)
861    pub fn mode(&self) -> &AggregateMode {
862        &self.mode
863    }
864
865    /// Set the limit options for this AggExec
866    pub fn with_limit_options(mut self, limit_options: Option<LimitOptions>) -> Self {
867        self.limit_options = limit_options;
868        self
869    }
870
871    /// Get the limit options (if set)
872    pub fn limit_options(&self) -> Option<LimitOptions> {
873        self.limit_options
874    }
875
876    /// Grouping expressions
877    pub fn group_expr(&self) -> &PhysicalGroupBy {
878        &self.group_by
879    }
880
881    /// Grouping expressions as they occur in the output schema
882    pub fn output_group_expr(&self) -> Vec<Arc<dyn PhysicalExpr>> {
883        self.group_by.output_exprs()
884    }
885
886    /// Aggregate expressions
887    pub fn aggr_expr(&self) -> &[Arc<AggregateFunctionExpr>] {
888        &self.aggr_expr
889    }
890
891    /// FILTER (WHERE clause) expression for each aggregate expression
892    pub fn filter_expr(&self) -> &[Option<Arc<dyn PhysicalExpr>>] {
893        &self.filter_expr
894    }
895
896    /// Returns the dynamic filter expression for this aggregate, if set.
897    pub fn dynamic_filter_expr(&self) -> Option<&Arc<DynamicFilterPhysicalExpr>> {
898        self.dynamic_filter.as_ref().map(|df| &df.filter)
899    }
900
901    /// Replace the dynamic filter expression. This method errors if the aggregate does not
902    /// support dynamic filtering or if the filter expression is incompatible with this
903    /// [`AggregateExec`].
904    pub fn with_dynamic_filter_expr(
905        mut self,
906        filter: Arc<DynamicFilterPhysicalExpr>,
907    ) -> Result<Self> {
908        // If there is no dynamic filter state initialized via `try_new`, then
909        // we can safely assume that the aggregate does not support dynamic filtering.
910        let Some(dyn_filter) = self.dynamic_filter.as_ref() else {
911            return internal_err!("Aggregate does not support dynamic filtering");
912        };
913
914        // Validate that the filter is compatible with the aggregation columns.
915        let cols = self.cols_for_dynamic_filter(&dyn_filter.supported_accumulators_info);
916        if cols.len() != filter.children().len() {
917            return internal_err!(
918                "Dynamic filter expression is incompatible with aggregate due to mismatched number of columns"
919            );
920        }
921        for (col, child) in cols.iter().zip(filter.children()) {
922            if !col.eq(child) {
923                return internal_err!(
924                    "Dynamic filter expression is incompatible with aggregate due to mismatched column references {col} != {child}"
925                );
926            }
927        }
928
929        // Overwrite our filter
930        self.dynamic_filter = Some(Arc::new(AggrDynFilter {
931            filter,
932            supported_accumulators_info: dyn_filter.supported_accumulators_info.clone(),
933        }));
934        Ok(self)
935    }
936
937    /// Input plan
938    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
939        &self.input
940    }
941
942    /// Get the input schema before any aggregates are applied
943    pub fn input_schema(&self) -> SchemaRef {
944        Arc::clone(&self.input_schema)
945    }
946
947    fn execute_typed(
948        &self,
949        partition: usize,
950        context: &Arc<TaskContext>,
951    ) -> Result<StreamType> {
952        if self.group_by.is_true_no_grouping() {
953            return Ok(StreamType::AggregateStream(AggregateStream::new(
954                self, context, partition,
955            )?));
956        }
957
958        // grouping by an expression that has a sort/limit upstream
959        if let Some(config) = self.limit_options
960            && !self.is_unordered_unfiltered_group_by_distinct()
961        {
962            return Ok(StreamType::GroupedPriorityQueue(
963                GroupedTopKAggregateStream::new(self, context, partition, config.limit)?,
964            ));
965        }
966
967        // grouping by something else and we need to just materialize all results
968        Ok(StreamType::GroupedHash(GroupedHashAggregateStream::new(
969            self, context, partition,
970        )?))
971    }
972
973    /// Finds the DataType and SortDirection for this Aggregate, if there is one
974    pub fn get_minmax_desc(&self) -> Option<(FieldRef, bool)> {
975        let agg_expr = self.aggr_expr.iter().exactly_one().ok()?;
976        agg_expr.get_minmax_desc()
977    }
978
979    /// true, if this Aggregate has a group-by with no required or explicit ordering,
980    /// no filtering and no aggregate expressions
981    /// This method qualifies the use of the LimitedDistinctAggregation rewrite rule
982    /// on an AggregateExec.
983    pub fn is_unordered_unfiltered_group_by_distinct(&self) -> bool {
984        if self
985            .limit_options()
986            .and_then(|config| config.descending)
987            .is_some()
988        {
989            return false;
990        }
991        // ensure there is a group by
992        if self.group_expr().is_empty() && !self.group_expr().has_grouping_set() {
993            return false;
994        }
995        // ensure there are no aggregate expressions
996        if !self.aggr_expr().is_empty() {
997            return false;
998        }
999        // ensure there are no filters on aggregate expressions; the above check
1000        // may preclude this case
1001        if self.filter_expr().iter().any(|e| e.is_some()) {
1002            return false;
1003        }
1004        // ensure there are no order by expressions
1005        if !self.aggr_expr().iter().all(|e| e.order_bys().is_empty()) {
1006            return false;
1007        }
1008        // ensure there is no output ordering; can this rule be relaxed?
1009        if self.properties().output_ordering().is_some() {
1010            return false;
1011        }
1012        // ensure no ordering is required on the input
1013        if let Some(requirement) = self.required_input_ordering().swap_remove(0) {
1014            return matches!(requirement, OrderingRequirements::Hard(_));
1015        }
1016        true
1017    }
1018
1019    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
1020    pub fn compute_properties(
1021        input: &Arc<dyn ExecutionPlan>,
1022        schema: SchemaRef,
1023        group_expr_mapping: &ProjectionMapping,
1024        mode: &AggregateMode,
1025        input_order_mode: &InputOrderMode,
1026        aggr_exprs: &[Arc<AggregateFunctionExpr>],
1027    ) -> Result<PlanProperties> {
1028        // Construct equivalence properties:
1029        let mut eq_properties = input
1030            .equivalence_properties()
1031            .project(group_expr_mapping, schema);
1032
1033        // If the group by is empty, then we ensure that the operator will produce
1034        // only one row, and mark the generated result as a constant value.
1035        if group_expr_mapping.is_empty() {
1036            let new_constants = aggr_exprs.iter().enumerate().map(|(idx, func)| {
1037                let column = Arc::new(Column::new(func.name(), idx));
1038                ConstExpr::from(column as Arc<dyn PhysicalExpr>)
1039            });
1040            eq_properties.add_constants(new_constants)?;
1041        }
1042
1043        // Group by expression will be a distinct value after the aggregation.
1044        // Add it into the constraint set.
1045        let mut constraints = eq_properties.constraints().to_vec();
1046        let new_constraint = Constraint::Unique(
1047            group_expr_mapping
1048                .iter()
1049                .flat_map(|(_, target_cols)| {
1050                    target_cols.iter().flat_map(|(expr, _)| {
1051                        expr.downcast_ref::<Column>().map(|c| c.index())
1052                    })
1053                })
1054                .collect(),
1055        );
1056        constraints.push(new_constraint);
1057        eq_properties =
1058            eq_properties.with_constraints(Constraints::new_unverified(constraints));
1059
1060        // Get output partitioning:
1061        let input_partitioning = input.output_partitioning().clone();
1062        let output_partitioning = match mode.input_mode() {
1063            AggregateInputMode::Raw => {
1064                // First stage aggregation will not change the output partitioning,
1065                // but needs to respect aliases (e.g. mapping in the GROUP BY
1066                // expression).
1067                let input_eq_properties = input.equivalence_properties();
1068                input_partitioning.project(group_expr_mapping, input_eq_properties)
1069            }
1070            AggregateInputMode::Partial => input_partitioning.clone(),
1071        };
1072
1073        // TODO: Emission type and boundedness information can be enhanced here
1074        let emission_type = if *input_order_mode == InputOrderMode::Linear {
1075            EmissionType::Final
1076        } else {
1077            input.pipeline_behavior()
1078        };
1079
1080        Ok(PlanProperties::new(
1081            eq_properties,
1082            output_partitioning,
1083            emission_type,
1084            input.boundedness(),
1085        ))
1086    }
1087
1088    pub fn input_order_mode(&self) -> &InputOrderMode {
1089        &self.input_order_mode
1090    }
1091
1092    /// Estimates output statistics for this aggregate node.
1093    ///
1094    /// For grouped aggregations with known input row count > 1, the output row
1095    /// count is estimated as:
1096    ///
1097    /// ```text
1098    /// ndv        = sum over each grouping set of product(max(NDV_i + nulls_i, 1))
1099    /// output_rows = input_rows                       // baseline
1100    /// output_rows = min(output_rows, ndv)             // if NDV available
1101    /// output_rows = min(output_rows, limit)           // if TopK active
1102    /// ```
1103    ///
1104    /// **Example 1 — single group key:**
1105    /// `GROUP BY city` where input_rows = 10,000, NDV(city) = 200
1106    /// → output_rows = min(10_000, 200) = 200
1107    ///
1108    /// **Example 2 — two group keys with TopK:**
1109    /// `GROUP BY city, category` where input_rows = 10,000, NDV(city) = 200,
1110    /// NDV(category) = 5, limit = 100
1111    /// → ndv = 200 × 5 = 1,000
1112    /// → output_rows = min(10_000, 1_000) = 1,000
1113    /// → output_rows = min(1_000, 100) = 100
1114    ///
1115    /// When `input_rows` is absent but NDV is available, falls back to:
1116    ///
1117    /// ```text
1118    /// output_rows = min(ndv, limit)   // if both available
1119    /// output_rows = ndv               // if only NDV available
1120    /// output_rows = limit             // if only limit available
1121    /// ```
1122    ///
1123    /// NDV estimation details (see [`Self::compute_group_ndv`]):
1124    /// - For each grouping set, only active (non-NULL) columns contribute
1125    /// - Per-column contribution is `max(NDV + null_adj, 1)` where `null_adj`
1126    ///   is 1 when nulls are present, 0 otherwise (a null group is a distinct
1127    ///   output row; `.max(1)` prevents a zero NDV from zeroing the product)
1128    /// - Per-set products are summed across all grouping sets
1129    /// - Requires NDV stats for ALL active group-by columns; if any lacks stats,
1130    ///   falls back to `input_rows` (or `Absent` if that is also unknown)
1131    fn statistics_inner(&self, child_statistics: &Statistics) -> Result<Statistics> {
1132        // TODO stats: group expressions:
1133        // - once expressions will be able to compute their own stats, use it here
1134        // - case where we group by on a column for which with have the `distinct` stat
1135        // TODO stats: aggr expression:
1136        // - aggregations sometimes also preserve invariants such as min, max...
1137
1138        let column_statistics = {
1139            // self.schema: [<group by exprs>, <aggregate exprs>]
1140            let mut column_statistics = Statistics::unknown_column(&self.schema());
1141
1142            for (idx, (expr, _)) in self.group_by.expr.iter().enumerate() {
1143                if let Some(col) = expr.downcast_ref::<Column>() {
1144                    let child_col_stats =
1145                        &child_statistics.column_statistics[col.index()];
1146                    column_statistics[idx].max_value = child_col_stats.max_value.clone();
1147                    column_statistics[idx].min_value = child_col_stats.min_value.clone();
1148                    column_statistics[idx].distinct_count =
1149                        child_col_stats.distinct_count;
1150                }
1151            }
1152
1153            column_statistics
1154        };
1155        match self.mode {
1156            AggregateMode::Final | AggregateMode::FinalPartitioned
1157                if self.group_by.expr.is_empty() =>
1158            {
1159                let total_byte_size =
1160                    Self::calculate_scaled_byte_size(child_statistics, 1);
1161
1162                Ok(Statistics {
1163                    num_rows: Precision::Exact(1),
1164                    column_statistics,
1165                    total_byte_size,
1166                })
1167            }
1168            _ => {
1169                let num_rows = self.estimate_num_rows(child_statistics);
1170
1171                let total_byte_size = num_rows
1172                    .get_value()
1173                    .and_then(|&output_rows| {
1174                        Self::calculate_scaled_byte_size(child_statistics, output_rows)
1175                            .get_value()
1176                            .map(|&bytes| Precision::Inexact(bytes))
1177                    })
1178                    .unwrap_or(Precision::Absent);
1179
1180                Ok(Statistics {
1181                    num_rows,
1182                    column_statistics,
1183                    total_byte_size,
1184                })
1185            }
1186        }
1187    }
1188
1189    /// Estimates the output row count for grouped aggregations, combining NDV,
1190    /// input row count, and TopK limit into a single [`Precision<usize>`].
1191    fn estimate_num_rows(&self, child_statistics: &Statistics) -> Precision<usize> {
1192        let ndv = if !self.group_by.expr.is_empty() {
1193            self.compute_group_ndv(child_statistics)
1194        } else {
1195            None
1196        };
1197        let limit = self.limit_options.as_ref().map(|lo| lo.limit);
1198
1199        if let Some(&value) = child_statistics.num_rows.get_value() {
1200            if value > 1 {
1201                let mut num_rows = child_statistics.num_rows.to_inexact();
1202                if let Some(ndv) = ndv {
1203                    num_rows = num_rows.map(|n| n.min(ndv));
1204                }
1205                if let Some(limit) = limit {
1206                    num_rows = num_rows.map(|n| n.min(limit));
1207                }
1208                num_rows
1209            } else if value == 0 {
1210                child_statistics.num_rows
1211            } else {
1212                let grouping_set_num = self.group_by.groups.len();
1213                let mut num_rows =
1214                    child_statistics.num_rows.map(|x| x * grouping_set_num);
1215                if let Some(limit) = limit {
1216                    num_rows = num_rows.map(|n| n.min(limit));
1217                }
1218                num_rows
1219            }
1220        } else {
1221            match (ndv, limit) {
1222                (Some(n), Some(l)) => Precision::Inexact(n.min(l)),
1223                (Some(n), None) => Precision::Inexact(n),
1224                (None, Some(l)) => Precision::Inexact(l),
1225                (None, None) => Precision::Absent,
1226            }
1227        }
1228    }
1229
1230    /// Computes the estimated number of distinct groups across all grouping sets.
1231    /// For each grouping set, computes `product(NDV_i + null_adj_i)` for active columns,
1232    /// then sums across all sets. Returns `None` if any active column is not a direct
1233    /// column reference or lacks `distinct_count` stats. Non-column expressions
1234    /// (e.g. `abs(a)`) are not yet supported because expression-level statistics
1235    /// propagation is still in progress (see <https://github.com/apache/datafusion/pull/21122>).
1236    /// When `null_count` is absent or unknown, null_adjustment defaults to 0.
1237    ///
1238    /// **Single key:** `GROUP BY a` where NDV(a) = 100, null_count(a) = 5
1239    /// → product = max(100 + 1, 1) = 101, total = 101
1240    ///
1241    /// **Two keys:** `GROUP BY a, b` where NDV(a) = 100, NDV(b) = 50, no nulls
1242    /// → product = 100 × 50 = 5,000, total = 5,000
1243    ///
1244    /// **Grouping sets:** `GROUPING SETS ((a), (b), (a, b))` with NDV(a) = 100, NDV(b) = 50
1245    /// → set(a) = 100, set(b) = 50, set(a, b) = 100 × 50 = 5,000
1246    /// → total = 100 + 50 + 5,000 = 5,150
1247    fn compute_group_ndv(&self, child_statistics: &Statistics) -> Option<usize> {
1248        let mut total: usize = 0;
1249        for group_mask in &self.group_by.groups {
1250            let mut set_product: usize = 1;
1251            for (j, (expr, _)) in self.group_by.expr.iter().enumerate() {
1252                if group_mask[j] {
1253                    continue;
1254                }
1255                let col = expr.downcast_ref::<Column>()?;
1256                let col_stats = &child_statistics.column_statistics[col.index()];
1257                let ndv = *col_stats.distinct_count.get_value()?;
1258                let null_adjustment = match col_stats.null_count.get_value() {
1259                    Some(&n) if n > 0 => 1usize,
1260                    _ => 0,
1261                };
1262                set_product = set_product
1263                    .saturating_mul(ndv.saturating_add(null_adjustment).max(1));
1264            }
1265            total = total.saturating_add(set_product);
1266        }
1267        Some(total)
1268    }
1269
1270    /// Check if dynamic filter is possible for the current plan node.
1271    /// - If yes, init one inside `AggregateExec`'s `dynamic_filter` field.
1272    /// - If not supported, `self.dynamic_filter` should be kept `None`
1273    fn init_dynamic_filter(&mut self) {
1274        if (!self.group_by.is_empty()) || (self.mode != AggregateMode::Partial) {
1275            debug_assert!(
1276                self.dynamic_filter.is_none(),
1277                "The current operator node does not support dynamic filter"
1278            );
1279            return;
1280        }
1281
1282        // Already initialized.
1283        if self.dynamic_filter.is_some() {
1284            return;
1285        }
1286
1287        // Collect supported accumulators
1288        // It is assumed the order of aggregate expressions are not changed from `AggregateExec`
1289        // to `AggregateStream`
1290        let mut aggr_dyn_filters = Vec::new();
1291        // All column references in the dynamic filter, used when initializing the dynamic
1292        // filter, and it's used to decide if this dynamic filter is able to get push
1293        // through certain node during optimization.
1294        let mut all_cols: Vec<Arc<dyn PhysicalExpr>> = Vec::new();
1295        for (i, aggr_expr) in self.aggr_expr.iter().enumerate() {
1296            // 1. Only `min` or `max` aggregate function
1297            let fun_name = aggr_expr.fun().name();
1298            // HACK: Should check the function type more precisely
1299            // Issue: <https://github.com/apache/datafusion/issues/18643>
1300            let aggr_type = if fun_name.eq_ignore_ascii_case("min") {
1301                DynamicFilterAggregateType::Min
1302            } else if fun_name.eq_ignore_ascii_case("max") {
1303                DynamicFilterAggregateType::Max
1304            } else {
1305                return;
1306            };
1307
1308            // 2. arg should be only 1 column reference
1309            if let [arg] = aggr_expr.expressions().as_slice()
1310                && arg.is::<Column>()
1311            {
1312                all_cols.push(Arc::clone(arg));
1313                aggr_dyn_filters.push(PerAccumulatorDynFilter {
1314                    aggr_type,
1315                    aggr_index: i,
1316                    shared_bound: Arc::new(Mutex::new(ScalarValue::Null)),
1317                });
1318            }
1319        }
1320
1321        if !aggr_dyn_filters.is_empty() {
1322            self.dynamic_filter = Some(Arc::new(AggrDynFilter {
1323                filter: Arc::new(DynamicFilterPhysicalExpr::new(all_cols, lit(true))),
1324                supported_accumulators_info: aggr_dyn_filters,
1325            }))
1326        }
1327    }
1328
1329    // Collect column references for the dynamic filter expression from the supported accumulators.
1330    fn cols_for_dynamic_filter(
1331        &self,
1332        supported_accumulators_info: &[PerAccumulatorDynFilter],
1333    ) -> Vec<Arc<dyn PhysicalExpr>> {
1334        let all_cols: Vec<Arc<dyn PhysicalExpr>> = supported_accumulators_info
1335            .iter()
1336            .filter_map(|info| {
1337                // This should always be true due to how the supported accumulators
1338                // are constructed. See `init_dynamic_filter` for more details.
1339                if let [arg] = &self.aggr_expr[info.aggr_index].expressions().as_slice()
1340                    && arg.is::<Column>()
1341                {
1342                    return Some(Arc::clone(arg));
1343                }
1344                None
1345            })
1346            .collect();
1347        debug_assert!(all_cols.len() == supported_accumulators_info.len());
1348        all_cols
1349    }
1350
1351    /// Calculate scaled byte size based on row count ratio.
1352    /// Returns `Precision::Absent` if input statistics are insufficient.
1353    /// Returns `Precision::Inexact` with the scaled value otherwise.
1354    ///
1355    /// This is a simple heuristic that assumes uniform row sizes.
1356    #[inline]
1357    fn calculate_scaled_byte_size(
1358        input_stats: &Statistics,
1359        target_row_count: usize,
1360    ) -> Precision<usize> {
1361        match (
1362            input_stats.num_rows.get_value(),
1363            input_stats.total_byte_size.get_value(),
1364        ) {
1365            (Some(&input_rows), Some(&input_bytes)) if input_rows > 0 => {
1366                let bytes_per_row = input_bytes as f64 / input_rows as f64;
1367                let scaled_bytes =
1368                    (bytes_per_row * target_row_count as f64).ceil() as usize;
1369                Precision::Inexact(scaled_bytes)
1370            }
1371            _ => Precision::Absent,
1372        }
1373    }
1374
1375    fn with_new_children_and_same_properties(
1376        &self,
1377        mut children: Vec<Arc<dyn ExecutionPlan>>,
1378    ) -> Self {
1379        Self {
1380            input: children.swap_remove(0),
1381            metrics: ExecutionPlanMetricsSet::new(),
1382            ..Self::clone(self)
1383        }
1384    }
1385}
1386
1387impl DisplayAs for AggregateExec {
1388    fn fmt_as(
1389        &self,
1390        t: DisplayFormatType,
1391        f: &mut std::fmt::Formatter,
1392    ) -> std::fmt::Result {
1393        match t {
1394            DisplayFormatType::Default | DisplayFormatType::Verbose => {
1395                let format_expr_with_alias =
1396                    |(e, alias): &(Arc<dyn PhysicalExpr>, String)| -> String {
1397                        let e = e.to_string();
1398                        if &e != alias {
1399                            format!("{e} as {alias}")
1400                        } else {
1401                            e
1402                        }
1403                    };
1404
1405                write!(f, "AggregateExec: mode={:?}", self.mode)?;
1406                let g: Vec<String> = if self.group_by.is_single() {
1407                    self.group_by
1408                        .expr
1409                        .iter()
1410                        .map(format_expr_with_alias)
1411                        .collect()
1412                } else {
1413                    self.group_by
1414                        .groups
1415                        .iter()
1416                        .map(|group| {
1417                            let terms = group
1418                                .iter()
1419                                .enumerate()
1420                                .map(|(idx, is_null)| {
1421                                    if *is_null {
1422                                        format_expr_with_alias(
1423                                            &self.group_by.null_expr[idx],
1424                                        )
1425                                    } else {
1426                                        format_expr_with_alias(&self.group_by.expr[idx])
1427                                    }
1428                                })
1429                                .collect::<Vec<String>>()
1430                                .join(", ");
1431                            format!("({terms})")
1432                        })
1433                        .collect()
1434                };
1435
1436                write!(f, ", gby=[{}]", g.join(", "))?;
1437
1438                let a: Vec<String> = self
1439                    .aggr_expr
1440                    .iter()
1441                    .map(|agg| format_aggregate_exec_expr(agg).to_string())
1442                    .collect();
1443                write!(f, ", aggr=[{}]", a.join(", "))?;
1444                if let Some(config) = self.limit_options {
1445                    write!(f, ", lim=[{}]", config.limit)?;
1446                }
1447
1448                if self.input_order_mode != InputOrderMode::Linear {
1449                    write!(f, ", ordering_mode={:?}", self.input_order_mode)?;
1450                }
1451            }
1452            DisplayFormatType::TreeRender => {
1453                let format_expr_with_alias =
1454                    |(e, alias): &(Arc<dyn PhysicalExpr>, String)| -> String {
1455                        let expr_sql = fmt_sql(e.as_ref()).to_string();
1456                        if &expr_sql != alias {
1457                            format!("{expr_sql} as {alias}")
1458                        } else {
1459                            expr_sql
1460                        }
1461                    };
1462
1463                let g: Vec<String> = if self.group_by.is_single() {
1464                    self.group_by
1465                        .expr
1466                        .iter()
1467                        .map(format_expr_with_alias)
1468                        .collect()
1469                } else {
1470                    self.group_by
1471                        .groups
1472                        .iter()
1473                        .map(|group| {
1474                            let terms = group
1475                                .iter()
1476                                .enumerate()
1477                                .map(|(idx, is_null)| {
1478                                    if *is_null {
1479                                        format_expr_with_alias(
1480                                            &self.group_by.null_expr[idx],
1481                                        )
1482                                    } else {
1483                                        format_expr_with_alias(&self.group_by.expr[idx])
1484                                    }
1485                                })
1486                                .collect::<Vec<String>>()
1487                                .join(", ");
1488                            format!("({terms})")
1489                        })
1490                        .collect()
1491                };
1492                let a: Vec<String> = self
1493                    .aggr_expr
1494                    .iter()
1495                    .map(|agg| format_tree_aggregate_expr(agg).to_string())
1496                    .collect();
1497                writeln!(f, "mode={:?}", self.mode)?;
1498                if !g.is_empty() {
1499                    writeln!(f, "group_by={}", g.join(", "))?;
1500                }
1501                if !a.is_empty() {
1502                    writeln!(f, "aggr={}", a.join(", "))?;
1503                }
1504                if let Some(config) = self.limit_options {
1505                    writeln!(f, "limit={}", config.limit)?;
1506                }
1507            }
1508        }
1509        Ok(())
1510    }
1511}
1512
1513fn format_aggregate_exec_expr(agg: &AggregateFunctionExpr) -> Cow<'_, str> {
1514    match agg.human_display_alias() {
1515        Some(_) => format_human_display(agg.human_display(), agg.human_display_alias())
1516            .unwrap_or_else(|| Cow::Borrowed(agg.name())),
1517        None => Cow::Borrowed(agg.name()),
1518    }
1519}
1520
1521fn format_tree_aggregate_expr(agg: &AggregateFunctionExpr) -> Cow<'_, str> {
1522    format_human_display(agg.human_display(), agg.human_display_alias())
1523        .unwrap_or_else(|| Cow::Borrowed(agg.name()))
1524}
1525
/// Renders a human-readable aggregate display, optionally followed by its alias.
///
/// Returns `None` when there is no human display. Without an alias the display
/// string is borrowed as-is; with an alias a new `"<display> as <alias>"` string
/// is allocated.
fn format_human_display<'a>(
    human_display: Option<&'a str>,
    alias: Option<&'a str>,
) -> Option<Cow<'a, str>> {
    let display = human_display?;
    let rendered = if let Some(alias) = alias {
        Cow::Owned(format!("{display} as {alias}"))
    } else {
        Cow::Borrowed(display)
    };
    Some(rendered)
}
1535
1536impl ExecutionPlan for AggregateExec {
1537    fn name(&self) -> &'static str {
1538        "AggregateExec"
1539    }
1540
1541    /// Return a reference to Any that can be used for down-casting
1542    fn properties(&self) -> &Arc<PlanProperties> {
1543        &self.cache
1544    }
1545
1546    fn required_input_distribution(&self) -> Vec<Distribution> {
1547        match &self.mode {
1548            AggregateMode::Partial | AggregateMode::PartialReduce => {
1549                vec![Distribution::UnspecifiedDistribution]
1550            }
1551            AggregateMode::FinalPartitioned | AggregateMode::SinglePartitioned => {
1552                vec![Distribution::HashPartitioned(self.group_by.input_exprs())]
1553            }
1554            AggregateMode::Final | AggregateMode::Single => {
1555                vec![Distribution::SinglePartition]
1556            }
1557        }
1558    }
1559
1560    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
1561        vec![self.required_input_ordering.clone()]
1562    }
1563
1564    /// The output ordering of [`AggregateExec`] is determined by its `group_by`
1565    /// columns. Although this method is not explicitly used by any optimizer
1566    /// rules yet, overriding the default implementation ensures that it
1567    /// accurately reflects the actual behavior.
1568    ///
1569    /// If the [`InputOrderMode`] is `Linear`, the `group_by` columns don't have
1570    /// an ordering, which means the results do not either. However, in the
1571    /// `Ordered` and `PartiallyOrdered` cases, the `group_by` columns do have
1572    /// an ordering, which is preserved in the output.
1573    fn maintains_input_order(&self) -> Vec<bool> {
1574        vec![self.input_order_mode != InputOrderMode::Linear]
1575    }
1576
1577    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
1578        vec![&self.input]
1579    }
1580
1581    fn with_new_children(
1582        self: Arc<Self>,
1583        children: Vec<Arc<dyn ExecutionPlan>>,
1584    ) -> Result<Arc<dyn ExecutionPlan>> {
1585        check_if_same_properties!(self, children);
1586
1587        let mut me = AggregateExec::try_new_with_schema(
1588            self.mode,
1589            Arc::clone(&self.group_by),
1590            self.aggr_expr.to_vec(),
1591            Arc::clone(&self.filter_expr),
1592            Arc::clone(&children[0]),
1593            Arc::clone(&self.input_schema),
1594            Arc::clone(&self.schema),
1595        )?;
1596        me.limit_options = self.limit_options;
1597        me.dynamic_filter.clone_from(&self.dynamic_filter);
1598
1599        Ok(Arc::new(me))
1600    }
1601
1602    fn execute(
1603        &self,
1604        partition: usize,
1605        context: Arc<TaskContext>,
1606    ) -> Result<SendableRecordBatchStream> {
1607        self.execute_typed(partition, &context)
1608            .map(|stream| stream.into())
1609    }
1610
1611    fn metrics(&self) -> Option<MetricsSet> {
1612        Some(self.metrics.clone_inner())
1613    }
1614
1615    fn partition_statistics(&self, partition: Option<usize>) -> Result<Arc<Statistics>> {
1616        let child_statistics = self.input().partition_statistics(partition)?;
1617        Ok(Arc::new(self.statistics_inner(&child_statistics)?))
1618    }
1619
1620    fn cardinality_effect(&self) -> CardinalityEffect {
1621        CardinalityEffect::LowerEqual
1622    }
1623
1624    /// Push down parent filters when possible (see implementation comment for details),
1625    /// and also pushdown self dynamic filters (see `AggrDynFilter` for details)
1626    fn gather_filters_for_pushdown(
1627        &self,
1628        phase: FilterPushdownPhase,
1629        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
1630        config: &ConfigOptions,
1631    ) -> Result<FilterDescription> {
1632        // It's safe to push down filters through aggregates when filters only reference
1633        // grouping columns, because such filters determine which groups to compute, not
1634        // *how* to compute them. Each group's aggregate values (SUM, COUNT, etc.) are
1635        // calculated from the same input rows regardless of whether we filter before or
1636        // after grouping - filtering before just eliminates entire groups early.
1637        // This optimization is NOT safe for filters on aggregated columns (like filtering on
1638        // the result of SUM or COUNT), as those require computing all groups first.
1639
1640        // Build grouping columns using output indices because parent filters reference the
1641        // AggregateExec's output schema where grouping columns in the output schema. The
1642        // grouping expressions reference input columns which may not match the output schema.
1643        //
1644        // It is safe to assume that the output_schema contains group by columns in the same order
1645        // as the group by expression. See [`create_schema`] and [`AggregateExec`].
1646        let output_schema = self.schema();
1647        let grouping_columns: HashSet<_> = (0..self.group_by.expr().len())
1648            .map(|i| Column::new(output_schema.field(i).name(), i))
1649            .collect();
1650
1651        // Analyze each filter separately to determine if it can be pushed down
1652        let mut safe_filters = Vec::new();
1653        let mut unsafe_filters = Vec::new();
1654
1655        for filter in parent_filters {
1656            let filter_columns: HashSet<_> =
1657                collect_columns(&filter).into_iter().collect();
1658
1659            // Check if this filter references non-grouping columns
1660            let references_non_grouping = !grouping_columns.is_empty()
1661                && !filter_columns.is_subset(&grouping_columns);
1662
1663            if references_non_grouping {
1664                unsafe_filters.push(filter);
1665                continue;
1666            }
1667
1668            // For GROUPING SETS, verify this filter's columns appear in all grouping sets
1669            if self.group_by.groups().len() > 1 {
1670                let filter_column_indices: Vec<usize> = filter_columns
1671                    .iter()
1672                    .filter_map(|filter_col| {
1673                        grouping_columns.get(filter_col).map(|col| col.index())
1674                    })
1675                    .collect();
1676
1677                // Check if any of this filter's columns are missing from any grouping set
1678                let has_missing_column = self.group_by.groups().iter().any(|null_mask| {
1679                    filter_column_indices
1680                        .iter()
1681                        .any(|&idx| null_mask.get(idx) == Some(&true))
1682                });
1683
1684                if has_missing_column {
1685                    unsafe_filters.push(filter);
1686                    continue;
1687                }
1688            }
1689
1690            // This filter is safe to push down
1691            safe_filters.push(filter);
1692        }
1693
1694        // Build child filter description with both safe and unsafe filters
1695        let child = self.children()[0];
1696        let mut child_desc = ChildFilterDescription::from_child(&safe_filters, child)?;
1697
1698        // Add unsafe filters as unsupported
1699        child_desc.parent_filters.extend(
1700            unsafe_filters
1701                .into_iter()
1702                .map(PushedDownPredicate::unsupported),
1703        );
1704
1705        // Include self dynamic filter when it's possible
1706        if phase == FilterPushdownPhase::Post
1707            && config.optimizer.enable_aggregate_dynamic_filter_pushdown
1708            && let Some(self_dyn_filter) = &self.dynamic_filter
1709        {
1710            let dyn_filter = Arc::clone(&self_dyn_filter.filter);
1711            child_desc = child_desc.with_self_filter(dyn_filter);
1712        }
1713
1714        Ok(FilterDescription::new().with_child(child_desc))
1715    }
1716
1717    /// If child accepts self's dynamic filter, keep `self.dynamic_filter` with Some,
1718    /// otherwise clear it to None.
1719    fn handle_child_pushdown_result(
1720        &self,
1721        phase: FilterPushdownPhase,
1722        child_pushdown_result: ChildPushdownResult,
1723        _config: &ConfigOptions,
1724    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
1725        let mut result = FilterPushdownPropagation::if_any(child_pushdown_result.clone());
1726
1727        // If this node tried to pushdown some dynamic filter before, now we check
1728        // if the child accept the filter
1729        if phase == FilterPushdownPhase::Post
1730            && let Some(dyn_filter) = &self.dynamic_filter
1731        {
1732            // let child_accepts_dyn_filter = child_pushdown_result
1733            //     .self_filters
1734            //     .first()
1735            //     .map(|filters| {
1736            //         assert_eq_or_internal_err!(
1737            //             filters.len(),
1738            //             1,
1739            //             "Aggregate only pushdown one self dynamic filter"
1740            //         );
1741            //         let filter = filters.get(0).unwrap(); // Asserted above
1742            //         Ok(matches!(filter.discriminant, PushedDown::Yes))
1743            //     })
1744            //     .unwrap_or_else(|| internal_err!("The length of self filters equals to the number of child of this ExecutionPlan, so it must be 1"))?;
1745
1746            // HACK: The above snippet should be used, however, now the child reply
1747            // `PushDown::No` can indicate they're not able to push down row-level
1748            // filter, but still keep the filter for statistics pruning.
1749            // So here, we try to use ref count to determine if the dynamic filter
1750            // has actually be pushed down.
1751            // Issue: <https://github.com/apache/datafusion/issues/18856>
1752            let child_accepts_dyn_filter = Arc::strong_count(dyn_filter) > 1;
1753
1754            if !child_accepts_dyn_filter {
1755                // Child can't consume the self dynamic filter, so disable it by setting
1756                // to `None`
1757                let mut new_node = self.clone();
1758                new_node.dynamic_filter = None;
1759
1760                result = result
1761                    .with_updated_node(Arc::new(new_node) as Arc<dyn ExecutionPlan>);
1762            }
1763        }
1764
1765        Ok(result)
1766    }
1767}
1768
1769/// Creates the output schema for an [`AggregateExec`] containing the group by columns followed
1770/// by the aggregate columns.
1771fn create_schema(
1772    input_schema: &Schema,
1773    group_by: &PhysicalGroupBy,
1774    aggr_expr: &[Arc<AggregateFunctionExpr>],
1775    mode: AggregateMode,
1776) -> Result<Schema> {
1777    let mut fields = Vec::with_capacity(group_by.num_output_exprs() + aggr_expr.len());
1778    fields.extend(group_by.output_fields(input_schema)?);
1779
1780    match mode.output_mode() {
1781        AggregateOutputMode::Final => {
1782            // in final mode, the field with the final result of the accumulator
1783            for expr in aggr_expr {
1784                fields.push(expr.field())
1785            }
1786        }
1787        AggregateOutputMode::Partial => {
1788            // in partial mode, the fields of the accumulator's state
1789            for expr in aggr_expr {
1790                fields.extend(expr.state_fields()?.iter().cloned());
1791            }
1792        }
1793    }
1794
1795    Ok(Schema::new_with_metadata(
1796        fields,
1797        input_schema.metadata().clone(),
1798    ))
1799}
1800
1801/// Determines the lexical ordering requirement for an aggregate expression.
1802///
1803/// # Parameters
1804///
1805/// - `aggr_expr`: A reference to an `AggregateFunctionExpr` representing the
1806///   aggregate expression.
1807/// - `group_by`: A reference to a `PhysicalGroupBy` instance representing the
1808///   physical GROUP BY expression.
1809/// - `agg_mode`: A reference to an `AggregateMode` instance representing the
1810///   mode of aggregation.
1811/// - `include_soft_requirement`: When `false`, only hard requirements are
1812///   considered, as indicated by [`AggregateFunctionExpr::order_sensitivity`]
1813///   returning [`AggregateOrderSensitivity::HardRequirement`].
1814///   Otherwise, also soft requirements ([`AggregateOrderSensitivity::SoftRequirement`])
1815///   are considered.
1816///
1817/// # Returns
1818///
1819/// A `LexOrdering` instance indicating the lexical ordering requirement for
1820/// the aggregate expression.
1821fn get_aggregate_expr_req(
1822    aggr_expr: &AggregateFunctionExpr,
1823    group_by: &PhysicalGroupBy,
1824    agg_mode: &AggregateMode,
1825    include_soft_requirement: bool,
1826) -> Option<LexOrdering> {
1827    // If the aggregation is performing a "second stage" calculation,
1828    // then ignore the ordering requirement. Ordering requirement applies
1829    // only to the aggregation input data.
1830    if agg_mode.input_mode() == AggregateInputMode::Partial {
1831        return None;
1832    }
1833
1834    match aggr_expr.order_sensitivity() {
1835        AggregateOrderSensitivity::Insensitive => return None,
1836        AggregateOrderSensitivity::HardRequirement => {}
1837        AggregateOrderSensitivity::SoftRequirement => {
1838            if !include_soft_requirement {
1839                return None;
1840            }
1841        }
1842        AggregateOrderSensitivity::Beneficial => return None,
1843    }
1844
1845    let mut sort_exprs = aggr_expr.order_bys().to_vec();
1846    // In non-first stage modes, we accumulate data (using `merge_batch`) from
1847    // different partitions (i.e. merge partial results). During this merge, we
1848    // consider the ordering of each partial result. Hence, we do not need to
1849    // use the ordering requirement in such modes as long as partial results are
1850    // generated with the correct ordering.
1851    if group_by.is_single() {
1852        // Remove all orderings that occur in the group by. These requirements
1853        // will definitely be satisfied -- Each group by expression will have
1854        // distinct values per group, hence all requirements are satisfied.
1855        let physical_exprs = group_by.input_exprs();
1856        sort_exprs.retain(|sort_expr| {
1857            !physical_exprs_contains(&physical_exprs, &sort_expr.expr)
1858        });
1859    }
1860    LexOrdering::new(sort_exprs)
1861}
1862
/// Returns a new vector holding clones of the elements of `lhs` followed by
/// those of `rhs`.
pub fn concat_slices<T: Clone>(lhs: &[T], rhs: &[T]) -> Vec<T> {
    let mut joined = Vec::with_capacity(lhs.len() + rhs.len());
    joined.extend_from_slice(lhs);
    joined.extend_from_slice(rhs);
    joined
}
1867
1868// Determines if the candidate ordering is finer than the current ordering.
1869// Returns `None` if they are incomparable, `Some(true)` if there is no current
1870// ordering or candidate ordering is finer, and `Some(false)` otherwise.
1871fn determine_finer(
1872    current: &Option<LexOrdering>,
1873    candidate: &LexOrdering,
1874) -> Option<bool> {
1875    if let Some(ordering) = current {
1876        candidate.partial_cmp(ordering).map(|cmp| cmp.is_gt())
1877    } else {
1878        Some(true)
1879    }
1880}
1881
/// Gets the common requirement that satisfies all the aggregate expressions.
/// When possible, chooses the requirement that is already satisfied by the
/// equivalence properties.
///
/// Entries of `aggr_exprs` may be replaced in place by their reversed form
/// ([`AggregateFunctionExpr::reverse_expr`]) when the reversed requirement fits
/// the common requirement (or is trivially satisfied).
///
/// # Parameters
///
/// - `aggr_exprs`: A mutable slice of `AggregateFunctionExpr` containing all the
///   aggregate expressions; individual entries may be swapped for their reverse.
/// - `group_by`: A reference to a `PhysicalGroupBy` instance representing the
///   physical GROUP BY expression.
/// - `eq_properties`: A reference to an `EquivalenceProperties` instance
///   representing equivalence properties for ordering.
/// - `agg_mode`: A reference to an `AggregateMode` instance representing the
///   mode of aggregation.
///
/// # Returns
///
/// A `Result<Vec<PhysicalSortRequirement>>` instance, which is the requirement
/// that satisfies all the aggregate requirements (empty when there is none).
///
/// # Errors
///
/// Returns a "not implemented" error when, during the hard-requirements pass,
/// an expression's forward and reverse requirements are both incomparable with
/// the common requirement found so far. Such conflicts in the second pass
/// (which also considers soft requirements) are ignored.
pub fn get_finer_aggregate_exprs_requirement(
    aggr_exprs: &mut [Arc<AggregateFunctionExpr>],
    group_by: &PhysicalGroupBy,
    eq_properties: &EquivalenceProperties,
    agg_mode: &AggregateMode,
) -> Result<Vec<PhysicalSortRequirement>> {
    let mut requirement = None;

    // Two passes over all expressions, sharing `requirement`:
    // - first (`include_soft_requirement == false`): only hard requirements are
    //   considered, and conflicts between them are reported as errors;
    // - second (`include_soft_requirement == true`): soft requirements are also
    //   considered and may refine the requirement; conflicts are not errors.
    for include_soft_requirement in [false, true] {
        for aggr_expr in aggr_exprs.iter_mut() {
            let Some(aggr_req) = get_aggregate_expr_req(
                aggr_expr,
                group_by,
                agg_mode,
                include_soft_requirement,
            )
            .and_then(|o| eq_properties.normalize_sort_exprs(o)) else {
                // There is no aggregate ordering requirement, or it is trivially
                // satisfied -- we can skip this expression.
                continue;
            };
            // If the common requirement is finer than the current expression's,
            // we can skip this expression. If the latter is finer than the former,
            // adopt it if it is satisfied by the equivalence properties. Otherwise,
            // defer the analysis to the reverse expression.
            let forward_finer = determine_finer(&requirement, &aggr_req);
            if let Some(finer) = forward_finer {
                if !finer {
                    continue;
                } else if eq_properties.ordering_satisfy(aggr_req.clone())? {
                    requirement = Some(aggr_req);
                    continue;
                }
            }
            // NOTE(review): when `forward_finer` is `None` and the expression has no
            // reverse, the requirement is skipped without adopting it or raising
            // the conflict error below -- confirm this is intended.
            if let Some(reverse_aggr_expr) = aggr_expr.reverse_expr() {
                let Some(rev_aggr_req) = get_aggregate_expr_req(
                    &reverse_aggr_expr,
                    group_by,
                    agg_mode,
                    include_soft_requirement,
                )
                .and_then(|o| eq_properties.normalize_sort_exprs(o)) else {
                    // The reverse requirement is trivially satisfied -- just reverse
                    // the expression and continue with the next one:
                    *aggr_expr = Arc::new(reverse_aggr_expr);
                    continue;
                };
                // If the common requirement is finer than the reverse expression's,
                // just reverse it and continue the loop with the next aggregate
                // expression. If the latter is finer than the former, adopt it if
                // it is satisfied by the equivalence properties. Otherwise, adopt
                // the forward expression.
                if let Some(finer) = determine_finer(&requirement, &rev_aggr_req) {
                    if !finer {
                        *aggr_expr = Arc::new(reverse_aggr_expr);
                    } else if eq_properties.ordering_satisfy(rev_aggr_req.clone())? {
                        *aggr_expr = Arc::new(reverse_aggr_expr);
                        requirement = Some(rev_aggr_req);
                    } else {
                        requirement = Some(aggr_req);
                    }
                } else if forward_finer.is_some() {
                    // Only reached with `forward_finer == Some(true)` (the
                    // `Some(false)` case continued above): adopt the forward one.
                    requirement = Some(aggr_req);
                } else {
                    // Neither the existing requirement nor the current aggregate
                    // requirement satisfy the other (forward or reverse), this
                    // means they are conflicting. This is a problem only for hard
                    // requirements. Unsatisfied soft requirements can be ignored.
                    if !include_soft_requirement {
                        return not_impl_err!(
                            "Conflicting ordering requirements in aggregate functions is not supported"
                        );
                    }
                }
            }
        }
    }

    Ok(requirement.map_or_else(Vec::new, |o| o.into_iter().map(Into::into).collect()))
}
1985
1986/// Returns physical expressions for arguments to evaluate against a batch.
1987///
1988/// The expressions are different depending on `mode`:
1989/// * Partial: AggregateFunctionExpr::expressions
1990/// * Final: columns of `AggregateFunctionExpr::state_fields()`
1991pub fn aggregate_expressions(
1992    aggr_expr: &[Arc<AggregateFunctionExpr>],
1993    mode: &AggregateMode,
1994    col_idx_base: usize,
1995) -> Result<Vec<Vec<Arc<dyn PhysicalExpr>>>> {
1996    match mode.input_mode() {
1997        AggregateInputMode::Raw => Ok(aggr_expr
1998            .iter()
1999            .map(|agg| {
2000                let mut result = agg.expressions();
2001                // Append ordering requirements to expressions' results. This
2002                // way order sensitive aggregators can satisfy requirement
2003                // themselves.
2004                result.extend(agg.order_bys().iter().map(|item| Arc::clone(&item.expr)));
2005                result
2006            })
2007            .collect()),
2008        AggregateInputMode::Partial => {
2009            // In merge mode, we build the merge expressions of the aggregation.
2010            let mut col_idx_base = col_idx_base;
2011            aggr_expr
2012                .iter()
2013                .map(|agg| {
2014                    let exprs = merge_expressions(col_idx_base, agg)?;
2015                    col_idx_base += exprs.len();
2016                    Ok(exprs)
2017                })
2018                .collect()
2019        }
2020    }
2021}
2022
2023/// uses `state_fields` to build a vec of physical column expressions required to merge the
2024/// AggregateFunctionExpr' accumulator's state.
2025///
2026/// `index_base` is the starting physical column index for the next expanded state field.
2027fn merge_expressions(
2028    index_base: usize,
2029    expr: &AggregateFunctionExpr,
2030) -> Result<Vec<Arc<dyn PhysicalExpr>>> {
2031    expr.state_fields().map(|fields| {
2032        fields
2033            .iter()
2034            .enumerate()
2035            .map(|(idx, f)| Arc::new(Column::new(f.name(), index_base + idx)) as _)
2036            .collect()
2037    })
2038}
2039
2040pub type AccumulatorItem = Box<dyn Accumulator>;
2041
2042pub fn create_accumulators(
2043    aggr_expr: &[Arc<AggregateFunctionExpr>],
2044) -> Result<Vec<AccumulatorItem>> {
2045    aggr_expr
2046        .iter()
2047        .map(|expr| expr.create_accumulator())
2048        .collect()
2049}
2050
2051/// returns a vector of ArrayRefs, where each entry corresponds to either the
2052/// final value (mode = Final, FinalPartitioned and Single) or states (mode = Partial)
2053pub fn finalize_aggregation(
2054    accumulators: &mut [AccumulatorItem],
2055    mode: &AggregateMode,
2056) -> Result<Vec<ArrayRef>> {
2057    match mode.output_mode() {
2058        AggregateOutputMode::Final => {
2059            // Merge the state to the final value
2060            accumulators
2061                .iter_mut()
2062                .map(|accumulator| accumulator.evaluate().and_then(|v| v.to_array()))
2063                .collect()
2064        }
2065        AggregateOutputMode::Partial => {
2066            // Build the vector of states
2067            accumulators
2068                .iter_mut()
2069                .map(|accumulator| {
2070                    accumulator.state().and_then(|e| {
2071                        e.iter()
2072                            .map(|v| v.to_array())
2073                            .collect::<Result<Vec<ArrayRef>>>()
2074                    })
2075                })
2076                .flatten_ok()
2077                .collect()
2078        }
2079    }
2080}
2081
2082/// Evaluates groups of expressions against a record batch.
2083pub fn evaluate_many(
2084    expr: &[Vec<Arc<dyn PhysicalExpr>>],
2085    batch: &RecordBatch,
2086) -> Result<Vec<Vec<ArrayRef>>> {
2087    expr.iter()
2088        .map(|expr| evaluate_expressions_to_arrays(expr, batch))
2089        .collect()
2090}
2091
2092fn evaluate_optional(
2093    expr: &[Option<Arc<dyn PhysicalExpr>>],
2094    batch: &RecordBatch,
2095) -> Result<Vec<Option<ArrayRef>>> {
2096    expr.iter()
2097        .map(|expr| {
2098            expr.as_ref()
2099                .map(|expr| {
2100                    expr.evaluate(batch)
2101                        .and_then(|v| v.into_array(batch.num_rows()))
2102                })
2103                .transpose()
2104        })
2105        .collect()
2106}
2107
2108/// Builds the internal `__grouping_id` array for a single grouping set.
2109///
2110/// The returned array packs two values into a single integer:
2111///
2112/// - Low `n` bits (positions 0 .. n-1): the semantic bitmask.  A `1` bit
2113///   at position `i` means that the `i`-th grouping column (counting from the
2114///   least significant bit, i.e. the *last* column in the `group` slice) is
2115///   `NULL` for this grouping set.
2116/// - High bits (positions n and above): the duplicate `ordinal`, which
2117///   distinguishes multiple occurrences of the same grouping-set pattern.  The
2118///   ordinal is `0` for the first occurrence, `1` for the second, and so on.
2119///
2120/// The integer type is chosen to be the smallest `UInt8 / UInt16 / UInt32 /
2121/// UInt64` that can represent both parts.  It matches the type returned by
2122/// [`Aggregate::grouping_id_type`].
2123pub(crate) fn group_id_array(
2124    group: &[bool],
2125    ordinal: usize,
2126    max_ordinal: usize,
2127    num_rows: usize,
2128) -> Result<ArrayRef> {
2129    let n = group.len();
2130    if n > 64 {
2131        return not_impl_err!(
2132            "Grouping sets with more than 64 columns are not supported"
2133        );
2134    }
2135    let ordinal_bits = usize::BITS as usize - max_ordinal.leading_zeros() as usize;
2136    let total_bits = n + ordinal_bits;
2137    if total_bits > 64 {
2138        return not_impl_err!(
2139            "Grouping sets with {n} columns and a maximum duplicate ordinal of \
2140             {max_ordinal} require {total_bits} bits, which exceeds 64"
2141        );
2142    }
2143    let semantic_id = group.iter().fold(0u64, |acc, &is_null| {
2144        (acc << 1) | if is_null { 1 } else { 0 }
2145    });
2146    let full_id = semantic_id | ((ordinal as u64) << n);
2147    if total_bits <= 8 {
2148        Ok(Arc::new(UInt8Array::from(vec![full_id as u8; num_rows])))
2149    } else if total_bits <= 16 {
2150        Ok(Arc::new(UInt16Array::from(vec![full_id as u16; num_rows])))
2151    } else if total_bits <= 32 {
2152        Ok(Arc::new(UInt32Array::from(vec![full_id as u32; num_rows])))
2153    } else {
2154        Ok(Arc::new(UInt64Array::from(vec![full_id; num_rows])))
2155    }
2156}
2157
/// Returns the highest duplicate ordinal across all grouping sets.
///
/// Every occurrence of a grouping-set pattern is assigned a 0-based ordinal:
/// the first occurrence gets 0, the second 1, and so on. If the same
/// `Vec<bool>` appears three times its ordinals are 0, 1, 2 and this function
/// returns 2. Returns 0 when no grouping set is duplicated (or `groups` is empty).
pub(crate) fn max_duplicate_ordinal(groups: &[Vec<bool>]) -> usize {
    let mut occurrences: HashMap<&[bool], usize> = HashMap::with_capacity(groups.len());
    let mut highest = 0;
    for group in groups {
        // `seen` is the number of earlier occurrences, i.e. this one's ordinal.
        let seen = occurrences.entry(group.as_slice()).or_default();
        highest = highest.max(*seen);
        *seen += 1;
    }
    highest
}
2172
2173/// Evaluate a group by expression against a `RecordBatch`
2174///
2175/// Arguments:
2176/// - `group_by`: the expression to evaluate
2177/// - `batch`: the `RecordBatch` to evaluate against
2178///
2179/// Returns: A Vec of Vecs of Array of results
2180/// The outer Vec appears to be for grouping sets
2181/// The inner Vec contains the results per expression
2182/// The inner-inner Array contains the results per row
2183pub fn evaluate_group_by(
2184    group_by: &PhysicalGroupBy,
2185    batch: &RecordBatch,
2186) -> Result<Vec<Vec<ArrayRef>>> {
2187    let max_ordinal = max_duplicate_ordinal(&group_by.groups);
2188    let mut ordinal_per_pattern: HashMap<&[bool], usize> = HashMap::new();
2189    let exprs = evaluate_expressions_to_arrays(
2190        group_by.expr.iter().map(|(expr, _)| expr),
2191        batch,
2192    )?;
2193    let null_exprs = evaluate_expressions_to_arrays(
2194        group_by.null_expr.iter().map(|(expr, _)| expr),
2195        batch,
2196    )?;
2197
2198    group_by
2199        .groups
2200        .iter()
2201        .map(|group| {
2202            let ordinal = ordinal_per_pattern.entry(group).or_insert(0);
2203            let current_ordinal = *ordinal;
2204            *ordinal += 1;
2205
2206            let mut group_values = Vec::with_capacity(group_by.num_group_exprs());
2207            group_values.extend(group.iter().enumerate().map(|(idx, is_null)| {
2208                if *is_null {
2209                    Arc::clone(&null_exprs[idx])
2210                } else {
2211                    Arc::clone(&exprs[idx])
2212                }
2213            }));
2214            if !group_by.is_single() {
2215                group_values.push(group_id_array(
2216                    group,
2217                    current_ordinal,
2218                    max_ordinal,
2219                    batch.num_rows(),
2220                )?);
2221            }
2222            Ok(group_values)
2223        })
2224        .collect()
2225}
2226
2227#[cfg(test)]
2228mod tests {
2229    use std::task::{Context, Poll};
2230
2231    use super::*;
2232    use crate::RecordBatchStream;
2233    use crate::coalesce_partitions::CoalescePartitionsExec;
2234    use crate::common;
2235    use crate::common::collect;
2236    use crate::empty::EmptyExec;
2237    use crate::execution_plan::Boundedness;
2238    use crate::expressions::col;
2239    use crate::metrics::MetricValue;
2240    use crate::test::TestMemoryExec;
2241    use crate::test::assert_is_pending;
2242    use crate::test::exec::{
2243        BlockingExec, StatisticsExec, assert_strong_count_converges_to_zero,
2244    };
2245
2246    use arrow::array::{
2247        DictionaryArray, Float32Array, Float64Array, Int32Array, Int64Array, StructArray,
2248        UInt32Array, UInt64Array,
2249    };
2250    use arrow::compute::{SortOptions, concat_batches};
2251    use arrow::datatypes::Int32Type;
2252    use datafusion_common::test_util::{batches_to_sort_string, batches_to_string};
2253    use datafusion_common::{DataFusionError, internal_err};
2254    use datafusion_execution::config::SessionConfig;
2255    use datafusion_execution::memory_pool::FairSpillPool;
2256    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
2257    use datafusion_functions_aggregate::array_agg::array_agg_udaf;
2258    use datafusion_functions_aggregate::average::avg_udaf;
2259    use datafusion_functions_aggregate::count::count_udaf;
2260    use datafusion_functions_aggregate::first_last::{first_value_udaf, last_value_udaf};
2261    use datafusion_functions_aggregate::median::median_udaf;
2262    use datafusion_functions_aggregate::min_max::min_udaf;
2263    use datafusion_functions_aggregate::sum::sum_udaf;
2264    use datafusion_physical_expr::Partitioning;
2265    use datafusion_physical_expr::PhysicalSortExpr;
2266    use datafusion_physical_expr::aggregate::AggregateExprBuilder;
2267    use datafusion_physical_expr::expressions::Literal;
2268
2269    use crate::projection::ProjectionExec;
2270    use datafusion_physical_expr::projection::ProjectionExpr;
2271    use futures::{FutureExt, Stream};
2272    use insta::{allow_duplicates, assert_snapshot};
2273
2274    // Generate a schema which consists of 5 columns (a, b, c, d, e)
2275    fn create_test_schema() -> Result<SchemaRef> {
2276        let a = Field::new("a", DataType::Int32, true);
2277        let b = Field::new("b", DataType::Int32, true);
2278        let c = Field::new("c", DataType::Int32, true);
2279        let d = Field::new("d", DataType::Int32, true);
2280        let e = Field::new("e", DataType::Int32, true);
2281        let schema = Arc::new(Schema::new(vec![a, b, c, d, e]));
2282
2283        Ok(schema)
2284    }
2285
2286    /// some mock data to aggregates
2287    fn some_data() -> (Arc<Schema>, Vec<RecordBatch>) {
2288        // define a schema.
2289        let schema = Arc::new(Schema::new(vec![
2290            Field::new("a", DataType::UInt32, false),
2291            Field::new("b", DataType::Float64, false),
2292        ]));
2293
2294        // define data.
2295        (
2296            Arc::clone(&schema),
2297            vec![
2298                RecordBatch::try_new(
2299                    Arc::clone(&schema),
2300                    vec![
2301                        Arc::new(UInt32Array::from(vec![2, 3, 4, 4])),
2302                        Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
2303                    ],
2304                )
2305                .unwrap(),
2306                RecordBatch::try_new(
2307                    schema,
2308                    vec![
2309                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
2310                        Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
2311                    ],
2312                )
2313                .unwrap(),
2314            ],
2315        )
2316    }
2317
2318    /// Generates some mock data for aggregate tests.
2319    fn some_data_v2() -> (Arc<Schema>, Vec<RecordBatch>) {
2320        // Define a schema:
2321        let schema = Arc::new(Schema::new(vec![
2322            Field::new("a", DataType::UInt32, false),
2323            Field::new("b", DataType::Float64, false),
2324        ]));
2325
2326        // Generate data so that first and last value results are at 2nd and
2327        // 3rd partitions.  With this construction, we guarantee we don't receive
2328        // the expected result by accident, but merging actually works properly;
2329        // i.e. it doesn't depend on the data insertion order.
2330        (
2331            Arc::clone(&schema),
2332            vec![
2333                RecordBatch::try_new(
2334                    Arc::clone(&schema),
2335                    vec![
2336                        Arc::new(UInt32Array::from(vec![2, 3, 4, 4])),
2337                        Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
2338                    ],
2339                )
2340                .unwrap(),
2341                RecordBatch::try_new(
2342                    Arc::clone(&schema),
2343                    vec![
2344                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
2345                        Arc::new(Float64Array::from(vec![0.0, 1.0, 2.0, 3.0])),
2346                    ],
2347                )
2348                .unwrap(),
2349                RecordBatch::try_new(
2350                    Arc::clone(&schema),
2351                    vec![
2352                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
2353                        Arc::new(Float64Array::from(vec![3.0, 4.0, 5.0, 6.0])),
2354                    ],
2355                )
2356                .unwrap(),
2357                RecordBatch::try_new(
2358                    schema,
2359                    vec![
2360                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
2361                        Arc::new(Float64Array::from(vec![2.0, 3.0, 4.0, 5.0])),
2362                    ],
2363                )
2364                .unwrap(),
2365            ],
2366        )
2367    }
2368
2369    fn new_spill_ctx(batch_size: usize, max_memory: usize) -> Arc<TaskContext> {
2370        let session_config = SessionConfig::new().with_batch_size(batch_size);
2371        let runtime = RuntimeEnvBuilder::new()
2372            .with_memory_pool(Arc::new(FairSpillPool::new(max_memory)))
2373            .build_arc()
2374            .unwrap();
2375        let task_ctx = TaskContext::default()
2376            .with_session_config(session_config)
2377            .with_runtime(runtime);
2378        Arc::new(task_ctx)
2379    }
2380
2381    async fn check_grouping_sets(
2382        input: Arc<dyn ExecutionPlan>,
2383        spill: bool,
2384    ) -> Result<()> {
2385        let input_schema = input.schema();
2386
2387        let grouping_set = PhysicalGroupBy::new(
2388            vec![
2389                (col("a", &input_schema)?, "a".to_string()),
2390                (col("b", &input_schema)?, "b".to_string()),
2391            ],
2392            vec![
2393                (lit(ScalarValue::UInt32(None)), "a".to_string()),
2394                (lit(ScalarValue::Float64(None)), "b".to_string()),
2395            ],
2396            vec![
2397                vec![false, true],  // (a, NULL)
2398                vec![true, false],  // (NULL, b)
2399                vec![false, false], // (a,b)
2400            ],
2401            true,
2402        );
2403
2404        let aggregates = vec![Arc::new(
2405            AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)])
2406                .schema(Arc::clone(&input_schema))
2407                .alias("COUNT(1)")
2408                .build()?,
2409        )];
2410
2411        let task_ctx = if spill {
2412            // adjust the max memory size to have the partial aggregate result for spill mode.
2413            new_spill_ctx(4, 500)
2414        } else {
2415            Arc::new(TaskContext::default())
2416        };
2417
2418        let partial_aggregate = Arc::new(AggregateExec::try_new(
2419            AggregateMode::Partial,
2420            grouping_set.clone(),
2421            aggregates.clone(),
2422            vec![None],
2423            input,
2424            Arc::clone(&input_schema),
2425        )?);
2426
2427        let result =
2428            collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
2429
2430        if spill {
2431            // In spill mode, we test with the limited memory, if the mem usage exceeds,
2432            // we trigger the early emit rule, which turns out the partial aggregate result.
2433            allow_duplicates! {
2434            assert_snapshot!(batches_to_sort_string(&result),
2435            @r"
2436            +---+-----+---------------+-----------------+
2437            | a | b   | __grouping_id | COUNT(1)[count] |
2438            +---+-----+---------------+-----------------+
2439            |   | 1.0 | 2             | 1               |
2440            |   | 1.0 | 2             | 1               |
2441            |   | 2.0 | 2             | 1               |
2442            |   | 2.0 | 2             | 1               |
2443            |   | 3.0 | 2             | 1               |
2444            |   | 3.0 | 2             | 1               |
2445            |   | 4.0 | 2             | 1               |
2446            |   | 4.0 | 2             | 1               |
2447            | 2 |     | 1             | 1               |
2448            | 2 |     | 1             | 1               |
2449            | 2 | 1.0 | 0             | 1               |
2450            | 2 | 1.0 | 0             | 1               |
2451            | 3 |     | 1             | 1               |
2452            | 3 |     | 1             | 2               |
2453            | 3 | 2.0 | 0             | 2               |
2454            | 3 | 3.0 | 0             | 1               |
2455            | 4 |     | 1             | 1               |
2456            | 4 |     | 1             | 2               |
2457            | 4 | 3.0 | 0             | 1               |
2458            | 4 | 4.0 | 0             | 2               |
2459            +---+-----+---------------+-----------------+
2460            "
2461            );
2462            }
2463        } else {
2464            allow_duplicates! {
2465            assert_snapshot!(batches_to_sort_string(&result),
2466            @r"
2467            +---+-----+---------------+-----------------+
2468            | a | b   | __grouping_id | COUNT(1)[count] |
2469            +---+-----+---------------+-----------------+
2470            |   | 1.0 | 2             | 2               |
2471            |   | 2.0 | 2             | 2               |
2472            |   | 3.0 | 2             | 2               |
2473            |   | 4.0 | 2             | 2               |
2474            | 2 |     | 1             | 2               |
2475            | 2 | 1.0 | 0             | 2               |
2476            | 3 |     | 1             | 3               |
2477            | 3 | 2.0 | 0             | 2               |
2478            | 3 | 3.0 | 0             | 1               |
2479            | 4 |     | 1             | 3               |
2480            | 4 | 3.0 | 0             | 1               |
2481            | 4 | 4.0 | 0             | 2               |
2482            +---+-----+---------------+-----------------+
2483            "
2484            );
2485            }
2486        };
2487
2488        let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));
2489
2490        let final_grouping_set = grouping_set.as_final();
2491
2492        let task_ctx = if spill {
2493            new_spill_ctx(4, 3160)
2494        } else {
2495            task_ctx
2496        };
2497
2498        let merged_aggregate = Arc::new(AggregateExec::try_new(
2499            AggregateMode::Final,
2500            final_grouping_set,
2501            aggregates,
2502            vec![None],
2503            merge,
2504            input_schema,
2505        )?);
2506
2507        let result = collect(merged_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
2508        let batch = concat_batches(&result[0].schema(), &result)?;
2509        assert_eq!(batch.num_columns(), 4);
2510        assert_eq!(batch.num_rows(), 12);
2511
2512        allow_duplicates! {
2513        assert_snapshot!(
2514            batches_to_sort_string(&result),
2515            @r"
2516        +---+-----+---------------+----------+
2517        | a | b   | __grouping_id | COUNT(1) |
2518        +---+-----+---------------+----------+
2519        |   | 1.0 | 2             | 2        |
2520        |   | 2.0 | 2             | 2        |
2521        |   | 3.0 | 2             | 2        |
2522        |   | 4.0 | 2             | 2        |
2523        | 2 |     | 1             | 2        |
2524        | 2 | 1.0 | 0             | 2        |
2525        | 3 |     | 1             | 3        |
2526        | 3 | 2.0 | 0             | 2        |
2527        | 3 | 3.0 | 0             | 1        |
2528        | 4 |     | 1             | 3        |
2529        | 4 | 3.0 | 0             | 1        |
2530        | 4 | 4.0 | 0             | 2        |
2531        +---+-----+---------------+----------+
2532        "
2533        );
2534        }
2535
2536        let metrics = merged_aggregate.metrics().unwrap();
2537        let output_rows = metrics.output_rows().unwrap();
2538        assert_eq!(12, output_rows);
2539
2540        Ok(())
2541    }
2542
2543    /// build the aggregates on the data from some_data() and check the results
2544    async fn check_aggregates(input: Arc<dyn ExecutionPlan>, spill: bool) -> Result<()> {
2545        let input_schema = input.schema();
2546
2547        let grouping_set = PhysicalGroupBy::new(
2548            vec![(col("a", &input_schema)?, "a".to_string())],
2549            vec![],
2550            vec![vec![false]],
2551            false,
2552        );
2553
2554        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
2555            AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?])
2556                .schema(Arc::clone(&input_schema))
2557                .alias("AVG(b)")
2558                .build()?,
2559        )];
2560
2561        let task_ctx = if spill {
2562            // set to an appropriate value to trigger spill
2563            new_spill_ctx(2, 1600)
2564        } else {
2565            Arc::new(TaskContext::default())
2566        };
2567
2568        let partial_aggregate = Arc::new(AggregateExec::try_new(
2569            AggregateMode::Partial,
2570            grouping_set.clone(),
2571            aggregates.clone(),
2572            vec![None],
2573            input,
2574            Arc::clone(&input_schema),
2575        )?);
2576
2577        let result =
2578            collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
2579
2580        if spill {
2581            allow_duplicates! {
2582            assert_snapshot!(batches_to_sort_string(&result), @r"
2583            +---+---------------+-------------+
2584            | a | AVG(b)[count] | AVG(b)[sum] |
2585            +---+---------------+-------------+
2586            | 2 | 1             | 1.0         |
2587            | 2 | 1             | 1.0         |
2588            | 3 | 1             | 2.0         |
2589            | 3 | 2             | 5.0         |
2590            | 4 | 3             | 11.0        |
2591            +---+---------------+-------------+
2592            ");
2593            }
2594        } else {
2595            allow_duplicates! {
2596            assert_snapshot!(batches_to_sort_string(&result), @r"
2597            +---+---------------+-------------+
2598            | a | AVG(b)[count] | AVG(b)[sum] |
2599            +---+---------------+-------------+
2600            | 2 | 2             | 2.0         |
2601            | 3 | 3             | 7.0         |
2602            | 4 | 3             | 11.0        |
2603            +---+---------------+-------------+
2604            ");
2605            }
2606        };
2607
2608        let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));
2609
2610        let final_grouping_set = grouping_set.as_final();
2611
2612        let merged_aggregate = Arc::new(AggregateExec::try_new(
2613            AggregateMode::Final,
2614            final_grouping_set,
2615            aggregates,
2616            vec![None],
2617            merge,
2618            input_schema,
2619        )?);
2620
2621        // Verify statistics are preserved proportionally through aggregation
2622        let final_stats = merged_aggregate.partition_statistics(None)?;
2623        assert!(final_stats.total_byte_size.get_value().is_some());
2624
2625        let task_ctx = if spill {
2626            // enlarge memory limit to let the final aggregation finish
2627            new_spill_ctx(2, 2600)
2628        } else {
2629            Arc::clone(&task_ctx)
2630        };
2631        let result = collect(merged_aggregate.execute(0, task_ctx)?).await?;
2632        let batch = concat_batches(&result[0].schema(), &result)?;
2633        assert_eq!(batch.num_columns(), 2);
2634        assert_eq!(batch.num_rows(), 3);
2635
2636        allow_duplicates! {
2637        assert_snapshot!(batches_to_sort_string(&result), @r"
2638        +---+--------------------+
2639        | a | AVG(b)             |
2640        +---+--------------------+
2641        | 2 | 1.0                |
2642        | 3 | 2.3333333333333335 |
2643        | 4 | 3.6666666666666665 |
2644        +---+--------------------+
2645        ");
2646            // For row 2: 3, (2 + 3 + 2) / 3
2647            // For row 3: 4, (3 + 4 + 4) / 3
2648        }
2649
2650        let metrics = merged_aggregate.metrics().unwrap();
2651        let output_rows = metrics.output_rows().unwrap();
2652        let spill_count = metrics.spill_count().unwrap();
2653        let spilled_bytes = metrics.spilled_bytes().unwrap();
2654        let spilled_rows = metrics.spilled_rows().unwrap();
2655
2656        if spill {
2657            // When spilling, the output rows metrics become partial output size + final output size
2658            // This is because final aggregation starts while partial aggregation is still emitting
2659            assert_eq!(8, output_rows);
2660
2661            assert!(spill_count > 0);
2662            assert!(spilled_bytes > 0);
2663            assert!(spilled_rows > 0);
2664        } else {
2665            assert_eq!(3, output_rows);
2666
2667            assert_eq!(0, spill_count);
2668            assert_eq!(0, spilled_bytes);
2669            assert_eq!(0, spilled_rows);
2670        }
2671
2672        Ok(())
2673    }
2674
2675    /// Define a test source that can yield back to runtime before returning its first item ///
2676
2677    #[derive(Debug)]
2678    struct TestYieldingExec {
2679        /// True if this exec should yield back to runtime the first time it is polled
2680        pub yield_first: bool,
2681        cache: Arc<PlanProperties>,
2682    }
2683
2684    impl TestYieldingExec {
2685        fn new(yield_first: bool) -> Self {
2686            let schema = some_data().0;
2687            let cache = Self::compute_properties(schema);
2688            Self {
2689                yield_first,
2690                cache: Arc::new(cache),
2691            }
2692        }
2693
2694        /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
2695        fn compute_properties(schema: SchemaRef) -> PlanProperties {
2696            PlanProperties::new(
2697                EquivalenceProperties::new(schema),
2698                Partitioning::UnknownPartitioning(1),
2699                EmissionType::Incremental,
2700                Boundedness::Bounded,
2701            )
2702        }
2703    }
2704
2705    impl DisplayAs for TestYieldingExec {
2706        fn fmt_as(
2707            &self,
2708            t: DisplayFormatType,
2709            f: &mut std::fmt::Formatter,
2710        ) -> std::fmt::Result {
2711            match t {
2712                DisplayFormatType::Default | DisplayFormatType::Verbose => {
2713                    write!(f, "TestYieldingExec")
2714                }
2715                DisplayFormatType::TreeRender => {
2716                    // TODO: collect info
2717                    write!(f, "")
2718                }
2719            }
2720        }
2721    }
2722
2723    impl ExecutionPlan for TestYieldingExec {
2724        fn name(&self) -> &'static str {
2725            "TestYieldingExec"
2726        }
2727
2728        fn properties(&self) -> &Arc<PlanProperties> {
2729            &self.cache
2730        }
2731
2732        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
2733            vec![]
2734        }
2735
2736        fn with_new_children(
2737            self: Arc<Self>,
2738            _: Vec<Arc<dyn ExecutionPlan>>,
2739        ) -> Result<Arc<dyn ExecutionPlan>> {
2740            internal_err!("Children cannot be replaced in {self:?}")
2741        }
2742
2743        fn execute(
2744            &self,
2745            _partition: usize,
2746            _context: Arc<TaskContext>,
2747        ) -> Result<SendableRecordBatchStream> {
2748            let stream = if self.yield_first {
2749                TestYieldingStream::New
2750            } else {
2751                TestYieldingStream::Yielded
2752            };
2753
2754            Ok(Box::pin(stream))
2755        }
2756
2757        fn partition_statistics(
2758            &self,
2759            partition: Option<usize>,
2760        ) -> Result<Arc<Statistics>> {
2761            if partition.is_some() {
2762                return Ok(Arc::new(Statistics::new_unknown(self.schema().as_ref())));
2763            }
2764            let (_, batches) = some_data();
2765            Ok(Arc::new(common::compute_record_batch_statistics(
2766                &[batches],
2767                &self.schema(),
2768                None,
2769            )))
2770        }
2771    }
2772
2773    /// A stream using the demo data. If inited as new, it will first yield to runtime before returning records
2774    enum TestYieldingStream {
2775        New,
2776        Yielded,
2777        ReturnedBatch1,
2778        ReturnedBatch2,
2779    }
2780
2781    impl Stream for TestYieldingStream {
2782        type Item = Result<RecordBatch>;
2783
2784        fn poll_next(
2785            mut self: std::pin::Pin<&mut Self>,
2786            cx: &mut Context<'_>,
2787        ) -> Poll<Option<Self::Item>> {
2788            match &*self {
2789                TestYieldingStream::New => {
2790                    *(self.as_mut()) = TestYieldingStream::Yielded;
2791                    cx.waker().wake_by_ref();
2792                    Poll::Pending
2793                }
2794                TestYieldingStream::Yielded => {
2795                    *(self.as_mut()) = TestYieldingStream::ReturnedBatch1;
2796                    Poll::Ready(Some(Ok(some_data().1[0].clone())))
2797                }
2798                TestYieldingStream::ReturnedBatch1 => {
2799                    *(self.as_mut()) = TestYieldingStream::ReturnedBatch2;
2800                    Poll::Ready(Some(Ok(some_data().1[1].clone())))
2801                }
2802                TestYieldingStream::ReturnedBatch2 => Poll::Ready(None),
2803            }
2804        }
2805    }
2806
2807    impl RecordBatchStream for TestYieldingStream {
2808        fn schema(&self) -> SchemaRef {
2809            some_data().0
2810        }
2811    }
2812
2813    //--- Tests ---//
2814
2815    #[tokio::test]
2816    async fn aggregate_source_not_yielding() -> Result<()> {
2817        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
2818
2819        check_aggregates(input, false).await
2820    }
2821
2822    #[tokio::test]
2823    async fn aggregate_grouping_sets_source_not_yielding() -> Result<()> {
2824        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
2825
2826        check_grouping_sets(input, false).await
2827    }
2828
2829    #[tokio::test]
2830    async fn aggregate_source_with_yielding() -> Result<()> {
2831        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
2832
2833        check_aggregates(input, false).await
2834    }
2835
2836    #[tokio::test]
2837    async fn aggregate_grouping_sets_with_yielding() -> Result<()> {
2838        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
2839
2840        check_grouping_sets(input, false).await
2841    }
2842
2843    #[tokio::test]
2844    async fn aggregate_source_not_yielding_with_spill() -> Result<()> {
2845        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
2846
2847        check_aggregates(input, true).await
2848    }
2849
2850    #[tokio::test]
2851    async fn aggregate_grouping_sets_source_not_yielding_with_spill() -> Result<()> {
2852        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
2853
2854        check_grouping_sets(input, true).await
2855    }
2856
2857    #[tokio::test]
2858    async fn aggregate_source_with_yielding_with_spill() -> Result<()> {
2859        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
2860
2861        check_aggregates(input, true).await
2862    }
2863
2864    #[tokio::test]
2865    async fn aggregate_grouping_sets_with_yielding_with_spill() -> Result<()> {
2866        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
2867
2868        check_grouping_sets(input, true).await
2869    }
2870
2871    // Median(a)
2872    fn test_median_agg_expr(schema: SchemaRef) -> Result<AggregateFunctionExpr> {
2873        AggregateExprBuilder::new(median_udaf(), vec![col("a", &schema)?])
2874            .schema(schema)
2875            .alias("MEDIAN(a)")
2876            .build()
2877    }
2878
2879    #[tokio::test]
2880    async fn test_oom() -> Result<()> {
2881        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
2882        let input_schema = input.schema();
2883
2884        let runtime = RuntimeEnvBuilder::new()
2885            .with_memory_limit(1, 1.0)
2886            .build_arc()?;
2887        let task_ctx = TaskContext::default().with_runtime(runtime);
2888        let task_ctx = Arc::new(task_ctx);
2889
2890        let groups_none = PhysicalGroupBy::default();
2891        let groups_some = PhysicalGroupBy::new(
2892            vec![(col("a", &input_schema)?, "a".to_string())],
2893            vec![],
2894            vec![vec![false]],
2895            false,
2896        );
2897
2898        // something that allocates within the aggregator
2899        let aggregates_v0: Vec<Arc<AggregateFunctionExpr>> =
2900            vec![Arc::new(test_median_agg_expr(Arc::clone(&input_schema))?)];
2901
2902        // use fast-path in `row_hash.rs`.
2903        let aggregates_v2: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
2904            AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?])
2905                .schema(Arc::clone(&input_schema))
2906                .alias("AVG(b)")
2907                .build()?,
2908        )];
2909
2910        for (version, groups, aggregates) in [
2911            (0, groups_none, aggregates_v0),
2912            (2, groups_some, aggregates_v2),
2913        ] {
2914            let n_aggr = aggregates.len();
2915            let partial_aggregate = Arc::new(AggregateExec::try_new(
2916                AggregateMode::Single,
2917                groups,
2918                aggregates,
2919                vec![None; n_aggr],
2920                Arc::clone(&input),
2921                Arc::clone(&input_schema),
2922            )?);
2923
2924            let stream = partial_aggregate.execute_typed(0, &task_ctx)?;
2925
2926            // ensure that we really got the version we wanted
2927            match version {
2928                0 => {
2929                    assert!(matches!(stream, StreamType::AggregateStream(_)));
2930                }
2931                1 => {
2932                    assert!(matches!(stream, StreamType::GroupedHash(_)));
2933                }
2934                2 => {
2935                    assert!(matches!(stream, StreamType::GroupedHash(_)));
2936                }
2937                _ => panic!("Unknown version: {version}"),
2938            }
2939
2940            let stream: SendableRecordBatchStream = stream.into();
2941            let err = collect(stream).await.unwrap_err();
2942
2943            // error root cause traversal is a bit complicated, see #4172.
2944            let err = err.find_root();
2945            assert!(
2946                matches!(err, DataFusionError::ResourcesExhausted(_)),
2947                "Wrong error type: {err}",
2948            );
2949        }
2950
2951        Ok(())
2952    }
2953
2954    #[tokio::test]
2955    async fn test_drop_cancel_without_groups() -> Result<()> {
2956        let task_ctx = Arc::new(TaskContext::default());
2957        let schema =
2958            Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, true)]));
2959
2960        let groups = PhysicalGroupBy::default();
2961
2962        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
2963            AggregateExprBuilder::new(avg_udaf(), vec![col("a", &schema)?])
2964                .schema(Arc::clone(&schema))
2965                .alias("AVG(a)")
2966                .build()?,
2967        )];
2968
2969        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
2970        let refs = blocking_exec.refs();
2971        let aggregate_exec = Arc::new(AggregateExec::try_new(
2972            AggregateMode::Partial,
2973            groups.clone(),
2974            aggregates.clone(),
2975            vec![None],
2976            blocking_exec,
2977            schema,
2978        )?);
2979
2980        let fut = crate::collect(aggregate_exec, task_ctx);
2981        let mut fut = fut.boxed();
2982
2983        assert_is_pending(&mut fut);
2984        drop(fut);
2985        assert_strong_count_converges_to_zero(refs).await;
2986
2987        Ok(())
2988    }
2989
2990    #[tokio::test]
2991    async fn test_drop_cancel_with_groups() -> Result<()> {
2992        let task_ctx = Arc::new(TaskContext::default());
2993        let schema = Arc::new(Schema::new(vec![
2994            Field::new("a", DataType::Float64, true),
2995            Field::new("b", DataType::Float64, true),
2996        ]));
2997
2998        let groups =
2999            PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);
3000
3001        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
3002            AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
3003                .schema(Arc::clone(&schema))
3004                .alias("AVG(b)")
3005                .build()?,
3006        )];
3007
3008        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
3009        let refs = blocking_exec.refs();
3010        let aggregate_exec = Arc::new(AggregateExec::try_new(
3011            AggregateMode::Partial,
3012            groups,
3013            aggregates.clone(),
3014            vec![None],
3015            blocking_exec,
3016            schema,
3017        )?);
3018
3019        let fut = crate::collect(aggregate_exec, task_ctx);
3020        let mut fut = fut.boxed();
3021
3022        assert_is_pending(&mut fut);
3023        drop(fut);
3024        assert_strong_count_converges_to_zero(refs).await;
3025
3026        Ok(())
3027    }
3028
3029    #[tokio::test]
3030    async fn run_first_last_multi_partitions() -> Result<()> {
3031        for is_first_acc in [false, true] {
3032            for spill in [false, true] {
3033                first_last_multi_partitions(is_first_acc, spill, 4200).await?
3034            }
3035        }
3036        Ok(())
3037    }
3038
3039    // FIRST_VALUE(b ORDER BY b <SortOptions>)
3040    fn test_first_value_agg_expr(
3041        schema: &Schema,
3042        sort_options: SortOptions,
3043    ) -> Result<Arc<AggregateFunctionExpr>> {
3044        let order_bys = vec![PhysicalSortExpr {
3045            expr: col("b", schema)?,
3046            options: sort_options,
3047        }];
3048        let args = [col("b", schema)?];
3049
3050        AggregateExprBuilder::new(first_value_udaf(), args.to_vec())
3051            .order_by(order_bys)
3052            .schema(Arc::new(schema.clone()))
3053            .alias(String::from("first_value(b) ORDER BY [b ASC NULLS LAST]"))
3054            .build()
3055            .map(Arc::new)
3056    }
3057
3058    // LAST_VALUE(b ORDER BY b <SortOptions>)
3059    fn test_last_value_agg_expr(
3060        schema: &Schema,
3061        sort_options: SortOptions,
3062    ) -> Result<Arc<AggregateFunctionExpr>> {
3063        let order_bys = vec![PhysicalSortExpr {
3064            expr: col("b", schema)?,
3065            options: sort_options,
3066        }];
3067        let args = [col("b", schema)?];
3068        AggregateExprBuilder::new(last_value_udaf(), args.to_vec())
3069            .order_by(order_bys)
3070            .schema(Arc::new(schema.clone()))
3071            .alias(String::from("last_value(b) ORDER BY [b ASC NULLS LAST]"))
3072            .build()
3073            .map(Arc::new)
3074    }
3075
3076    fn first_value_agg_expr(
3077        schema: &SchemaRef,
3078        column: &str,
3079        alias: &str,
3080        human_display: Option<&str>,
3081        human_display_alias: Option<&str>,
3082    ) -> Result<AggregateFunctionExpr> {
3083        let mut builder =
3084            AggregateExprBuilder::new(first_value_udaf(), vec![col(column, schema)?])
3085                .order_by(vec![PhysicalSortExpr {
3086                    expr: col(column, schema)?,
3087                    options: SortOptions::new(false, false),
3088                }])
3089                .schema(Arc::clone(schema))
3090                .alias(alias);
3091
3092        if let Some(human_display) = human_display {
3093            builder = builder.human_display(human_display);
3094        }
3095        if let Some(human_display_alias) = human_display_alias {
3096            builder = builder.human_display_alias(human_display_alias);
3097        }
3098
3099        builder.build()
3100    }
3101
3102    #[test]
3103    fn test_reverse_expr_preserves_aliased_human_display() -> Result<()> {
3104        let schema = create_test_schema()?;
3105        let agg = first_value_agg_expr(
3106            &schema,
3107            "b",
3108            "agg",
3109            Some("first_value(b) ORDER BY [b ASC NULLS LAST]"),
3110            Some("agg"),
3111        )?;
3112
3113        let reversed = agg.reverse_expr().expect("expected reverse expr");
3114
3115        assert_eq!(reversed.name(), "agg");
3116        assert_eq!(reversed.human_display_alias(), Some("agg"));
3117        assert_eq!(
3118            format_tree_aggregate_expr(&reversed),
3119            "last_value(b) ORDER BY [b DESC NULLS FIRST] as agg"
3120        );
3121        assert_eq!(
3122            reversed.human_display(),
3123            Some("last_value(b) ORDER BY [b DESC NULLS FIRST]")
3124        );
3125
3126        Ok(())
3127    }
3128
3129    #[test]
3130    fn test_reverse_expr_does_not_rewrite_column_names_in_human_display() -> Result<()> {
3131        let schema = Arc::new(Schema::new(vec![Field::new(
3132            "first_value_col",
3133            DataType::Int32,
3134            true,
3135        )]));
3136        let agg = first_value_agg_expr(
3137            &schema,
3138            "first_value_col",
3139            "agg",
3140            Some(
3141                "first_value(first_value_col) ORDER BY [first_value_col ASC NULLS LAST]",
3142            ),
3143            Some("agg"),
3144        )?;
3145
3146        let reversed = agg.reverse_expr().expect("expected reverse expr");
3147
3148        assert_eq!(reversed.name(), "agg");
3149        assert_eq!(
3150            reversed.human_display(),
3151            Some(
3152                "last_value(first_value_col) ORDER BY [first_value_col DESC NULLS FIRST]"
3153            )
3154        );
3155        assert_eq!(
3156            format_tree_aggregate_expr(&reversed),
3157            "last_value(first_value_col) ORDER BY [first_value_col DESC NULLS FIRST] as agg"
3158        );
3159
3160        Ok(())
3161    }
3162
3163    #[test]
3164    fn test_empty_human_display_is_treated_as_absent() -> Result<()> {
3165        let schema = create_test_schema()?;
3166        let agg = first_value_agg_expr(&schema, "b", "agg", Some(""), None)?;
3167
3168        assert_eq!(agg.human_display(), None);
3169        assert_eq!(format_tree_aggregate_expr(&agg), "agg");
3170
3171        Ok(())
3172    }
3173
3174    #[test]
3175    fn test_human_display_alias_must_match_name() -> Result<()> {
3176        let schema = create_test_schema()?;
3177        let error = first_value_agg_expr(
3178            &schema,
3179            "b",
3180            "agg",
3181            Some("first_value(b) ORDER BY [b ASC NULLS LAST]"),
3182            Some("other_alias"),
3183        )
3184        .unwrap_err();
3185
3186        assert!(
3187            error
3188                .to_string()
3189                .contains("aggregate human_display_alias must match")
3190        );
3191
3192        Ok(())
3193    }
3194
3195    #[test]
3196    fn test_reverse_expr_preserves_non_aliased_display_path() -> Result<()> {
3197        let schema = create_test_schema()?;
3198        let agg = first_value_agg_expr(
3199            &schema,
3200            "b",
3201            "first_value(b) ORDER BY [b ASC NULLS LAST]",
3202            None,
3203            None,
3204        )?;
3205
3206        let reversed = agg.reverse_expr().expect("expected reverse expr");
3207
3208        assert_eq!(
3209            reversed.name(),
3210            "last_value(b) ORDER BY [b DESC NULLS FIRST]"
3211        );
3212        assert_eq!(reversed.human_display(), None);
3213
3214        Ok(())
3215    }
3216
3217    // This function constructs the physical plan below,
3218    //
3219    // "AggregateExec: mode=Final, gby=[a@0 as a], aggr=[FIRST_VALUE(b)]",
3220    // "  CoalescePartitionsExec",
3221    // "    AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[FIRST_VALUE(b)], ordering_mode=None",
3222    // "      DataSourceExec: partitions=4, partition_sizes=[1, 1, 1, 1]",
3223    //
3224    // and checks whether the function `merge_batch` works correctly for
3225    // FIRST_VALUE and LAST_VALUE functions.
3226    async fn first_last_multi_partitions(
3227        is_first_acc: bool,
3228        spill: bool,
3229        max_memory: usize,
3230    ) -> Result<()> {
3231        let task_ctx = if spill {
3232            new_spill_ctx(2, max_memory)
3233        } else {
3234            Arc::new(TaskContext::default())
3235        };
3236
3237        let (schema, data) = some_data_v2();
3238        let partition1 = data[0].clone();
3239        let partition2 = data[1].clone();
3240        let partition3 = data[2].clone();
3241        let partition4 = data[3].clone();
3242
3243        let groups =
3244            PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);
3245
3246        let sort_options = SortOptions {
3247            descending: false,
3248            nulls_first: false,
3249        };
3250        let aggregates: Vec<Arc<AggregateFunctionExpr>> = if is_first_acc {
3251            vec![test_first_value_agg_expr(&schema, sort_options)?]
3252        } else {
3253            vec![test_last_value_agg_expr(&schema, sort_options)?]
3254        };
3255
3256        let memory_exec = TestMemoryExec::try_new_exec(
3257            &[
3258                vec![partition1],
3259                vec![partition2],
3260                vec![partition3],
3261                vec![partition4],
3262            ],
3263            Arc::clone(&schema),
3264            None,
3265        )?;
3266        let aggregate_exec = Arc::new(AggregateExec::try_new(
3267            AggregateMode::Partial,
3268            groups.clone(),
3269            aggregates.clone(),
3270            vec![None],
3271            memory_exec,
3272            Arc::clone(&schema),
3273        )?);
3274        let coalesce = Arc::new(CoalescePartitionsExec::new(aggregate_exec))
3275            as Arc<dyn ExecutionPlan>;
3276        let aggregate_final = Arc::new(AggregateExec::try_new(
3277            AggregateMode::Final,
3278            groups,
3279            aggregates.clone(),
3280            vec![None],
3281            coalesce,
3282            schema,
3283        )?) as Arc<dyn ExecutionPlan>;
3284
3285        let result = crate::collect(aggregate_final, task_ctx).await?;
3286        if is_first_acc {
3287            allow_duplicates! {
3288            assert_snapshot!(batches_to_string(&result), @r"
3289            +---+--------------------------------------------+
3290            | a | first_value(b) ORDER BY [b ASC NULLS LAST] |
3291            +---+--------------------------------------------+
3292            | 2 | 0.0                                        |
3293            | 3 | 1.0                                        |
3294            | 4 | 3.0                                        |
3295            +---+--------------------------------------------+
3296            ");
3297            }
3298        } else {
3299            allow_duplicates! {
3300            assert_snapshot!(batches_to_string(&result), @r"
3301            +---+-------------------------------------------+
3302            | a | last_value(b) ORDER BY [b ASC NULLS LAST] |
3303            +---+-------------------------------------------+
3304            | 2 | 3.0                                       |
3305            | 3 | 5.0                                       |
3306            | 4 | 6.0                                       |
3307            +---+-------------------------------------------+
3308            ");
3309            }
3310        };
3311        Ok(())
3312    }
3313
3314    #[tokio::test]
3315    async fn test_get_finest_requirements() -> Result<()> {
3316        let test_schema = create_test_schema()?;
3317
3318        let options = SortOptions {
3319            descending: false,
3320            nulls_first: false,
3321        };
3322        let col_a = &col("a", &test_schema)?;
3323        let col_b = &col("b", &test_schema)?;
3324        let col_c = &col("c", &test_schema)?;
3325        let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema));
3326        // Columns a and b are equal.
3327        eq_properties.add_equal_conditions(Arc::clone(col_a), Arc::clone(col_b))?;
3328        // Aggregate requirements are
3329        // [None], [a ASC], [a ASC, b ASC, c ASC], [a ASC, b ASC] respectively
3330        let order_by_exprs = vec![
3331            vec![],
3332            vec![PhysicalSortExpr {
3333                expr: Arc::clone(col_a),
3334                options,
3335            }],
3336            vec![
3337                PhysicalSortExpr {
3338                    expr: Arc::clone(col_a),
3339                    options,
3340                },
3341                PhysicalSortExpr {
3342                    expr: Arc::clone(col_b),
3343                    options,
3344                },
3345                PhysicalSortExpr {
3346                    expr: Arc::clone(col_c),
3347                    options,
3348                },
3349            ],
3350            vec![
3351                PhysicalSortExpr {
3352                    expr: Arc::clone(col_a),
3353                    options,
3354                },
3355                PhysicalSortExpr {
3356                    expr: Arc::clone(col_b),
3357                    options,
3358                },
3359            ],
3360        ];
3361
3362        let common_requirement = vec![
3363            PhysicalSortRequirement::new(Arc::clone(col_a), Some(options)),
3364            PhysicalSortRequirement::new(Arc::clone(col_c), Some(options)),
3365        ];
3366        let mut aggr_exprs = order_by_exprs
3367            .into_iter()
3368            .map(|order_by_expr| {
3369                AggregateExprBuilder::new(array_agg_udaf(), vec![Arc::clone(col_a)])
3370                    .alias("a")
3371                    .order_by(order_by_expr)
3372                    .schema(Arc::clone(&test_schema))
3373                    .build()
3374                    .map(Arc::new)
3375                    .unwrap()
3376            })
3377            .collect::<Vec<_>>();
3378        let group_by = PhysicalGroupBy::new_single(vec![]);
3379        let result = get_finer_aggregate_exprs_requirement(
3380            &mut aggr_exprs,
3381            &group_by,
3382            &eq_properties,
3383            &AggregateMode::Partial,
3384        )?;
3385        assert_eq!(result, common_requirement);
3386        Ok(())
3387    }
3388
3389    #[test]
3390    fn test_agg_exec_same_schema() -> Result<()> {
3391        let schema = Arc::new(Schema::new(vec![
3392            Field::new("a", DataType::Float32, true),
3393            Field::new("b", DataType::Float32, true),
3394        ]));
3395
3396        let col_a = col("a", &schema)?;
3397        let option_desc = SortOptions {
3398            descending: true,
3399            nulls_first: true,
3400        };
3401        let groups = PhysicalGroupBy::new_single(vec![(col_a, "a".to_string())]);
3402
3403        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
3404            test_first_value_agg_expr(&schema, option_desc)?,
3405            test_last_value_agg_expr(&schema, option_desc)?,
3406        ];
3407        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
3408        let aggregate_exec = Arc::new(AggregateExec::try_new(
3409            AggregateMode::Partial,
3410            groups,
3411            aggregates,
3412            vec![None, None],
3413            Arc::clone(&blocking_exec) as Arc<dyn ExecutionPlan>,
3414            schema,
3415        )?);
3416        let new_agg =
3417            Arc::clone(&aggregate_exec).with_new_children(vec![blocking_exec])?;
3418        assert_eq!(new_agg.schema(), aggregate_exec.schema());
3419        Ok(())
3420    }
3421
3422    #[tokio::test]
3423    async fn test_agg_exec_group_by_const() -> Result<()> {
3424        let schema = Arc::new(Schema::new(vec![
3425            Field::new("a", DataType::Float32, true),
3426            Field::new("b", DataType::Float32, true),
3427            Field::new("const", DataType::Int32, false),
3428        ]));
3429
3430        let col_a = col("a", &schema)?;
3431        let col_b = col("b", &schema)?;
3432        let const_expr = Arc::new(Literal::new(ScalarValue::Int32(Some(1))));
3433
3434        let groups = PhysicalGroupBy::new(
3435            vec![
3436                (col_a, "a".to_string()),
3437                (col_b, "b".to_string()),
3438                (const_expr, "const".to_string()),
3439            ],
3440            vec![
3441                (
3442                    Arc::new(Literal::new(ScalarValue::Float32(None))),
3443                    "a".to_string(),
3444                ),
3445                (
3446                    Arc::new(Literal::new(ScalarValue::Float32(None))),
3447                    "b".to_string(),
3448                ),
3449                (
3450                    Arc::new(Literal::new(ScalarValue::Int32(None))),
3451                    "const".to_string(),
3452                ),
3453            ],
3454            vec![
3455                vec![false, true, true],
3456                vec![true, false, true],
3457                vec![true, true, false],
3458            ],
3459            true,
3460        );
3461
3462        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
3463            AggregateExprBuilder::new(count_udaf(), vec![lit(1)])
3464                .schema(Arc::clone(&schema))
3465                .alias("1")
3466                .build()
3467                .map(Arc::new)?,
3468        ];
3469
3470        let input_batches = (0..4)
3471            .map(|_| {
3472                let a = Arc::new(Float32Array::from(vec![0.; 8192]));
3473                let b = Arc::new(Float32Array::from(vec![0.; 8192]));
3474                let c = Arc::new(Int32Array::from(vec![1; 8192]));
3475
3476                RecordBatch::try_new(Arc::clone(&schema), vec![a, b, c]).unwrap()
3477            })
3478            .collect();
3479
3480        let input =
3481            TestMemoryExec::try_new_exec(&[input_batches], Arc::clone(&schema), None)?;
3482
3483        let aggregate_exec = Arc::new(AggregateExec::try_new(
3484            AggregateMode::Single,
3485            groups,
3486            aggregates.clone(),
3487            vec![None],
3488            input,
3489            schema,
3490        )?);
3491
3492        let output =
3493            collect(aggregate_exec.execute(0, Arc::new(TaskContext::default()))?).await?;
3494
3495        allow_duplicates! {
3496        assert_snapshot!(batches_to_sort_string(&output), @r"
3497        +-----+-----+-------+---------------+-------+
3498        | a   | b   | const | __grouping_id | 1     |
3499        +-----+-----+-------+---------------+-------+
3500        |     |     | 1     | 6             | 32768 |
3501        |     | 0.0 |       | 5             | 32768 |
3502        | 0.0 |     |       | 3             | 32768 |
3503        +-----+-----+-------+---------------+-------+
3504        ");
3505        }
3506
3507        Ok(())
3508    }
3509
3510    #[tokio::test]
3511    async fn test_agg_exec_struct_of_dicts() -> Result<()> {
3512        let batch = RecordBatch::try_new(
3513            Arc::new(Schema::new(vec![
3514                Field::new(
3515                    "labels".to_string(),
3516                    DataType::Struct(
3517                        vec![
3518                            Field::new(
3519                                "a".to_string(),
3520                                DataType::Dictionary(
3521                                    Box::new(DataType::Int32),
3522                                    Box::new(DataType::Utf8),
3523                                ),
3524                                true,
3525                            ),
3526                            Field::new(
3527                                "b".to_string(),
3528                                DataType::Dictionary(
3529                                    Box::new(DataType::Int32),
3530                                    Box::new(DataType::Utf8),
3531                                ),
3532                                true,
3533                            ),
3534                        ]
3535                        .into(),
3536                    ),
3537                    false,
3538                ),
3539                Field::new("value", DataType::UInt64, false),
3540            ])),
3541            vec![
3542                Arc::new(StructArray::from(vec![
3543                    (
3544                        Arc::new(Field::new(
3545                            "a".to_string(),
3546                            DataType::Dictionary(
3547                                Box::new(DataType::Int32),
3548                                Box::new(DataType::Utf8),
3549                            ),
3550                            true,
3551                        )),
3552                        Arc::new(
3553                            vec![Some("a"), None, Some("a")]
3554                                .into_iter()
3555                                .collect::<DictionaryArray<Int32Type>>(),
3556                        ) as ArrayRef,
3557                    ),
3558                    (
3559                        Arc::new(Field::new(
3560                            "b".to_string(),
3561                            DataType::Dictionary(
3562                                Box::new(DataType::Int32),
3563                                Box::new(DataType::Utf8),
3564                            ),
3565                            true,
3566                        )),
3567                        Arc::new(
3568                            vec![Some("b"), Some("c"), Some("b")]
3569                                .into_iter()
3570                                .collect::<DictionaryArray<Int32Type>>(),
3571                        ) as ArrayRef,
3572                    ),
3573                ])),
3574                Arc::new(UInt64Array::from(vec![1, 1, 1])),
3575            ],
3576        )
3577        .expect("Failed to create RecordBatch");
3578
3579        let group_by = PhysicalGroupBy::new_single(vec![(
3580            col("labels", &batch.schema())?,
3581            "labels".to_string(),
3582        )]);
3583
3584        let aggr_expr = vec![
3585            AggregateExprBuilder::new(sum_udaf(), vec![col("value", &batch.schema())?])
3586                .schema(Arc::clone(&batch.schema()))
3587                .alias(String::from("SUM(value)"))
3588                .build()
3589                .map(Arc::new)?,
3590        ];
3591
3592        let input = TestMemoryExec::try_new_exec(
3593            &[vec![batch.clone()]],
3594            Arc::<Schema>::clone(&batch.schema()),
3595            None,
3596        )?;
3597        let aggregate_exec = Arc::new(AggregateExec::try_new(
3598            AggregateMode::FinalPartitioned,
3599            group_by,
3600            aggr_expr,
3601            vec![None],
3602            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
3603            batch.schema(),
3604        )?);
3605
3606        let session_config = SessionConfig::default();
3607        let ctx = TaskContext::default().with_session_config(session_config);
3608        let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;
3609
3610        allow_duplicates! {
3611        assert_snapshot!(batches_to_string(&output), @r"
3612        +--------------+------------+
3613        | labels       | SUM(value) |
3614        +--------------+------------+
3615        | {a: a, b: b} | 2          |
3616        | {a: , b: c}  | 1          |
3617        +--------------+------------+
3618        ");
3619        }
3620
3621        Ok(())
3622    }
3623
3624    #[tokio::test]
3625    async fn test_skip_aggregation_after_first_batch() -> Result<()> {
3626        let schema = Arc::new(Schema::new(vec![
3627            Field::new("key", DataType::Int32, true),
3628            Field::new("val", DataType::Int32, true),
3629        ]));
3630
3631        let group_by =
3632            PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]);
3633
3634        let aggr_expr = vec![
3635            AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?])
3636                .schema(Arc::clone(&schema))
3637                .alias(String::from("COUNT(val)"))
3638                .build()
3639                .map(Arc::new)?,
3640        ];
3641
3642        let input_data = vec![
3643            RecordBatch::try_new(
3644                Arc::clone(&schema),
3645                vec![
3646                    Arc::new(Int32Array::from(vec![1, 2, 3])),
3647                    Arc::new(Int32Array::from(vec![0, 0, 0])),
3648                ],
3649            )
3650            .unwrap(),
3651            RecordBatch::try_new(
3652                Arc::clone(&schema),
3653                vec![
3654                    Arc::new(Int32Array::from(vec![2, 3, 4])),
3655                    Arc::new(Int32Array::from(vec![0, 0, 0])),
3656                ],
3657            )
3658            .unwrap(),
3659        ];
3660
3661        let input =
3662            TestMemoryExec::try_new_exec(&[input_data], Arc::clone(&schema), None)?;
3663        let aggregate_exec = Arc::new(AggregateExec::try_new(
3664            AggregateMode::Partial,
3665            group_by,
3666            aggr_expr,
3667            vec![None],
3668            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
3669            schema,
3670        )?);
3671
3672        let mut session_config = SessionConfig::default();
3673        session_config = session_config.set(
3674            "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
3675            &ScalarValue::Int64(Some(2)),
3676        );
3677        session_config = session_config.set(
3678            "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
3679            &ScalarValue::Float64(Some(0.1)),
3680        );
3681
3682        let ctx = TaskContext::default().with_session_config(session_config);
3683        let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;
3684
3685        allow_duplicates! {
3686            assert_snapshot!(batches_to_string(&output), @r"
3687            +-----+-------------------+
3688            | key | COUNT(val)[count] |
3689            +-----+-------------------+
3690            | 1   | 1                 |
3691            | 2   | 1                 |
3692            | 3   | 1                 |
3693            | 2   | 1                 |
3694            | 3   | 1                 |
3695            | 4   | 1                 |
3696            +-----+-------------------+
3697            ");
3698        }
3699
3700        Ok(())
3701    }
3702
3703    #[tokio::test]
3704    async fn test_skip_aggregation_after_threshold() -> Result<()> {
3705        let schema = Arc::new(Schema::new(vec![
3706            Field::new("key", DataType::Int32, true),
3707            Field::new("val", DataType::Int32, true),
3708        ]));
3709
3710        let group_by =
3711            PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]);
3712
3713        let aggr_expr = vec![
3714            AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?])
3715                .schema(Arc::clone(&schema))
3716                .alias(String::from("COUNT(val)"))
3717                .build()
3718                .map(Arc::new)?,
3719        ];
3720
3721        let input_data = vec![
3722            RecordBatch::try_new(
3723                Arc::clone(&schema),
3724                vec![
3725                    Arc::new(Int32Array::from(vec![1, 2, 3])),
3726                    Arc::new(Int32Array::from(vec![0, 0, 0])),
3727                ],
3728            )
3729            .unwrap(),
3730            RecordBatch::try_new(
3731                Arc::clone(&schema),
3732                vec![
3733                    Arc::new(Int32Array::from(vec![2, 3, 4])),
3734                    Arc::new(Int32Array::from(vec![0, 0, 0])),
3735                ],
3736            )
3737            .unwrap(),
3738            RecordBatch::try_new(
3739                Arc::clone(&schema),
3740                vec![
3741                    Arc::new(Int32Array::from(vec![2, 3, 4])),
3742                    Arc::new(Int32Array::from(vec![0, 0, 0])),
3743                ],
3744            )
3745            .unwrap(),
3746        ];
3747
3748        let input =
3749            TestMemoryExec::try_new_exec(&[input_data], Arc::clone(&schema), None)?;
3750        let aggregate_exec = Arc::new(AggregateExec::try_new(
3751            AggregateMode::Partial,
3752            group_by,
3753            aggr_expr,
3754            vec![None],
3755            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
3756            schema,
3757        )?);
3758
3759        let mut session_config = SessionConfig::default();
3760        session_config = session_config.set(
3761            "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
3762            &ScalarValue::Int64(Some(5)),
3763        );
3764        session_config = session_config.set(
3765            "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
3766            &ScalarValue::Float64(Some(0.1)),
3767        );
3768
3769        let ctx = TaskContext::default().with_session_config(session_config);
3770        let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;
3771
3772        allow_duplicates! {
3773            assert_snapshot!(batches_to_string(&output), @r"
3774            +-----+-------------------+
3775            | key | COUNT(val)[count] |
3776            +-----+-------------------+
3777            | 1   | 1                 |
3778            | 2   | 2                 |
3779            | 3   | 2                 |
3780            | 4   | 1                 |
3781            | 2   | 1                 |
3782            | 3   | 1                 |
3783            | 4   | 1                 |
3784            +-----+-------------------+
3785            ");
3786        }
3787
3788        Ok(())
3789    }
3790
3791    #[test]
3792    fn group_exprs_nullable() -> Result<()> {
3793        let input_schema = Arc::new(Schema::new(vec![
3794            Field::new("a", DataType::Float32, false),
3795            Field::new("b", DataType::Float32, false),
3796        ]));
3797
3798        let aggr_expr = vec![
3799            AggregateExprBuilder::new(count_udaf(), vec![col("a", &input_schema)?])
3800                .schema(Arc::clone(&input_schema))
3801                .alias("COUNT(a)")
3802                .build()
3803                .map(Arc::new)?,
3804        ];
3805
3806        let grouping_set = PhysicalGroupBy::new(
3807            vec![
3808                (col("a", &input_schema)?, "a".to_string()),
3809                (col("b", &input_schema)?, "b".to_string()),
3810            ],
3811            vec![
3812                (lit(ScalarValue::Float32(None)), "a".to_string()),
3813                (lit(ScalarValue::Float32(None)), "b".to_string()),
3814            ],
3815            vec![
3816                vec![false, true],  // (a, NULL)
3817                vec![false, false], // (a,b)
3818            ],
3819            true,
3820        );
3821        let aggr_schema = create_schema(
3822            &input_schema,
3823            &grouping_set,
3824            &aggr_expr,
3825            AggregateMode::Final,
3826        )?;
3827        let expected_schema = Schema::new(vec![
3828            Field::new("a", DataType::Float32, false),
3829            Field::new("b", DataType::Float32, true),
3830            Field::new("__grouping_id", DataType::UInt8, false),
3831            Field::new("COUNT(a)", DataType::Int64, false),
3832        ]);
3833        assert_eq!(aggr_schema, expected_schema);
3834        Ok(())
3835    }
3836
3837    // test for https://github.com/apache/datafusion/issues/13949
3838    async fn run_test_with_spill_pool_if_necessary(
3839        pool_size: usize,
3840        expect_spill: bool,
3841    ) -> Result<()> {
3842        fn create_record_batch(
3843            schema: &Arc<Schema>,
3844            data: (Vec<u32>, Vec<f64>),
3845        ) -> Result<RecordBatch> {
3846            Ok(RecordBatch::try_new(
3847                Arc::clone(schema),
3848                vec![
3849                    Arc::new(UInt32Array::from(data.0)),
3850                    Arc::new(Float64Array::from(data.1)),
3851                ],
3852            )?)
3853        }
3854
3855        let schema = Arc::new(Schema::new(vec![
3856            Field::new("a", DataType::UInt32, false),
3857            Field::new("b", DataType::Float64, false),
3858        ]));
3859
3860        let batches = vec![
3861            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
3862            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
3863        ];
3864        let plan: Arc<dyn ExecutionPlan> =
3865            TestMemoryExec::try_new_exec(&[batches], Arc::clone(&schema), None)?;
3866
3867        let grouping_set = PhysicalGroupBy::new(
3868            vec![(col("a", &schema)?, "a".to_string())],
3869            vec![],
3870            vec![vec![false]],
3871            false,
3872        );
3873
3874        // Test with MIN for simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count).
3875        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
3876            Arc::new(
3877                AggregateExprBuilder::new(min_udaf(), vec![col("b", &schema)?])
3878                    .schema(Arc::clone(&schema))
3879                    .alias("MIN(b)")
3880                    .build()?,
3881            ),
3882            Arc::new(
3883                AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
3884                    .schema(Arc::clone(&schema))
3885                    .alias("AVG(b)")
3886                    .build()?,
3887            ),
3888        ];
3889
3890        let single_aggregate = Arc::new(AggregateExec::try_new(
3891            AggregateMode::Single,
3892            grouping_set,
3893            aggregates,
3894            vec![None, None],
3895            plan,
3896            Arc::clone(&schema),
3897        )?);
3898
3899        let batch_size = 2;
3900        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
3901        let task_ctx = Arc::new(
3902            TaskContext::default()
3903                .with_session_config(SessionConfig::new().with_batch_size(batch_size))
3904                .with_runtime(Arc::new(
3905                    RuntimeEnvBuilder::new()
3906                        .with_memory_pool(memory_pool)
3907                        .build()?,
3908                )),
3909        );
3910
3911        let result = collect(single_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
3912
3913        assert_spill_count_metric(expect_spill, single_aggregate);
3914
3915        allow_duplicates! {
3916            assert_snapshot!(batches_to_string(&result), @r"
3917            +---+--------+--------+
3918            | a | MIN(b) | AVG(b) |
3919            +---+--------+--------+
3920            | 2 | 1.0    | 1.0    |
3921            | 3 | 2.0    | 2.0    |
3922            | 4 | 3.0    | 3.5    |
3923            +---+--------+--------+
3924            ");
3925        }
3926
3927        Ok(())
3928    }
3929
3930    fn assert_spill_count_metric(
3931        expect_spill: bool,
3932        single_aggregate: Arc<AggregateExec>,
3933    ) {
3934        if let Some(metrics_set) = single_aggregate.metrics() {
3935            let mut spill_count = 0;
3936
3937            // Inspect metrics for SpillCount
3938            for metric in metrics_set.iter() {
3939                if let MetricValue::SpillCount(count) = metric.value() {
3940                    spill_count = count.value();
3941                    break;
3942                }
3943            }
3944
3945            if expect_spill && spill_count == 0 {
3946                panic!(
3947                    "Expected spill but SpillCount metric not found or SpillCount was 0."
3948                );
3949            } else if !expect_spill && spill_count > 0 {
3950                panic!(
3951                    "Expected no spill but found SpillCount metric with value greater than 0."
3952                );
3953            }
3954        } else {
3955            panic!("No metrics returned from the operator; cannot verify spilling.");
3956        }
3957    }
3958
3959    #[tokio::test]
3960    async fn test_aggregate_with_spill_if_necessary() -> Result<()> {
3961        // test with spill
3962        run_test_with_spill_pool_if_necessary(2_000, true).await?;
3963        // test without spill
3964        run_test_with_spill_pool_if_necessary(20_000, false).await?;
3965        Ok(())
3966    }
3967
3968    #[tokio::test]
3969    async fn test_grouped_aggregation_respects_memory_limit() -> Result<()> {
3970        // test with spill
3971        fn create_record_batch(
3972            schema: &Arc<Schema>,
3973            data: (Vec<u32>, Vec<f64>),
3974        ) -> Result<RecordBatch> {
3975            Ok(RecordBatch::try_new(
3976                Arc::clone(schema),
3977                vec![
3978                    Arc::new(UInt32Array::from(data.0)),
3979                    Arc::new(Float64Array::from(data.1)),
3980                ],
3981            )?)
3982        }
3983
3984        let schema = Arc::new(Schema::new(vec![
3985            Field::new("a", DataType::UInt32, false),
3986            Field::new("b", DataType::Float64, false),
3987        ]));
3988
3989        let batches = vec![
3990            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
3991            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
3992        ];
3993        let plan: Arc<dyn ExecutionPlan> =
3994            TestMemoryExec::try_new_exec(&[batches], Arc::clone(&schema), None)?;
3995        let proj = ProjectionExec::try_new(
3996            vec![
3997                ProjectionExpr::new(lit("0"), "l".to_string()),
3998                ProjectionExpr::new_from_expression(col("a", &schema)?, &schema)?,
3999                ProjectionExpr::new_from_expression(col("b", &schema)?, &schema)?,
4000            ],
4001            plan,
4002        )?;
4003        let plan: Arc<dyn ExecutionPlan> = Arc::new(proj);
4004        let schema = plan.schema();
4005
4006        let grouping_set = PhysicalGroupBy::new(
4007            vec![
4008                (col("l", &schema)?, "l".to_string()),
4009                (col("a", &schema)?, "a".to_string()),
4010            ],
4011            vec![],
4012            vec![vec![false, false]],
4013            false,
4014        );
4015
4016        // Test with MIN for simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count).
4017        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
4018            Arc::new(
4019                AggregateExprBuilder::new(min_udaf(), vec![col("b", &schema)?])
4020                    .schema(Arc::clone(&schema))
4021                    .alias("MIN(b)")
4022                    .build()?,
4023            ),
4024            Arc::new(
4025                AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
4026                    .schema(Arc::clone(&schema))
4027                    .alias("AVG(b)")
4028                    .build()?,
4029            ),
4030        ];
4031
4032        let single_aggregate = Arc::new(AggregateExec::try_new(
4033            AggregateMode::Single,
4034            grouping_set,
4035            aggregates,
4036            vec![None, None],
4037            plan,
4038            Arc::clone(&schema),
4039        )?);
4040
4041        let batch_size = 2;
4042        let memory_pool = Arc::new(FairSpillPool::new(2000));
4043        let task_ctx = Arc::new(
4044            TaskContext::default()
4045                .with_session_config(SessionConfig::new().with_batch_size(batch_size))
4046                .with_runtime(Arc::new(
4047                    RuntimeEnvBuilder::new()
4048                        .with_memory_pool(memory_pool)
4049                        .build()?,
4050                )),
4051        );
4052
4053        let result = collect(single_aggregate.execute(0, Arc::clone(&task_ctx))?).await;
4054        match result {
4055            Ok(result) => {
4056                assert_spill_count_metric(true, single_aggregate);
4057
4058                allow_duplicates! {
4059                    assert_snapshot!(batches_to_string(&result), @r"
4060                +---+---+--------+--------+
4061                | l | a | MIN(b) | AVG(b) |
4062                +---+---+--------+--------+
4063                | 0 | 2 | 1.0    | 1.0    |
4064                | 0 | 3 | 2.0    | 2.0    |
4065                | 0 | 4 | 3.0    | 3.5    |
4066                +---+---+--------+--------+
4067            ");
4068                }
4069            }
4070            Err(e) => assert!(matches!(e, DataFusionError::ResourcesExhausted(_))),
4071        }
4072
4073        Ok(())
4074    }
4075
4076    #[tokio::test]
4077    async fn test_aggregate_statistics_edge_cases() -> Result<()> {
4078        use datafusion_common::ColumnStatistics;
4079
4080        let schema = Arc::new(Schema::new(vec![
4081            Field::new("a", DataType::Int32, false),
4082            Field::new("b", DataType::Float64, false),
4083        ]));
4084
4085        let absent_byte_stats = Statistics {
4086            num_rows: Precision::Exact(100),
4087            total_byte_size: Precision::Absent,
4088            column_statistics: vec![
4089                ColumnStatistics::new_unknown(),
4090                ColumnStatistics::new_unknown(),
4091            ],
4092        };
4093        let agg = build_test_aggregate(
4094            &schema,
4095            absent_byte_stats,
4096            PhysicalGroupBy::default(),
4097            None,
4098        )?;
4099        let stats = agg.partition_statistics(None)?;
4100        assert_eq!(stats.total_byte_size, Precision::Absent);
4101
4102        let zero_row_stats = Statistics {
4103            num_rows: Precision::Exact(0),
4104            total_byte_size: Precision::Exact(0),
4105            column_statistics: vec![
4106                ColumnStatistics::new_unknown(),
4107                ColumnStatistics::new_unknown(),
4108            ],
4109        };
4110        let agg_zero = build_test_aggregate(
4111            &schema,
4112            zero_row_stats,
4113            PhysicalGroupBy::default(),
4114            None,
4115        )?;
4116        let stats_zero = agg_zero.partition_statistics(None)?;
4117        assert_eq!(stats_zero.total_byte_size, Precision::Absent);
4118
4119        Ok(())
4120    }
4121
4122    fn build_test_aggregate(
4123        schema: &SchemaRef,
4124        stats: Statistics,
4125        group_by: PhysicalGroupBy,
4126        limit: Option<LimitOptions>,
4127    ) -> Result<AggregateExec> {
4128        let input = Arc::new(StatisticsExec::new(stats, (**schema).clone()))
4129            as Arc<dyn ExecutionPlan>;
4130
4131        let mut agg = AggregateExec::try_new(
4132            AggregateMode::Final,
4133            group_by,
4134            vec![Arc::new(
4135                AggregateExprBuilder::new(count_udaf(), vec![col("a", schema)?])
4136                    .schema(Arc::clone(schema))
4137                    .alias("COUNT(a)")
4138                    .build()?,
4139            )],
4140            vec![None],
4141            input,
4142            Arc::clone(schema),
4143        )?;
4144
4145        if let Some(limit) = limit {
4146            agg = agg.with_limit_options(Some(limit));
4147        }
4148
4149        Ok(agg)
4150    }
4151
4152    fn simple_group_by(schema: &SchemaRef, cols: &[&str]) -> PhysicalGroupBy {
4153        if cols.is_empty() {
4154            PhysicalGroupBy::default()
4155        } else {
4156            PhysicalGroupBy::new_single(
4157                cols.iter()
4158                    .map(|name| {
4159                        (
4160                            col(name, schema).unwrap() as Arc<dyn PhysicalExpr>,
4161                            name.to_string(),
4162                        )
4163                    })
4164                    .collect(),
4165            )
4166        }
4167    }
4168
4169    #[test]
4170    fn test_aggregate_cardinality_estimation() -> Result<()> {
4171        use datafusion_common::ColumnStatistics;
4172
4173        let schema = Arc::new(Schema::new(vec![
4174            Field::new("a", DataType::Int32, true),
4175            Field::new("b", DataType::Int32, true),
4176        ]));
4177
4178        struct TestCase {
4179            name: &'static str,
4180            input_rows: Precision<usize>,
4181            col_a_stats: ColumnStatistics,
4182            col_b_stats: ColumnStatistics,
4183            group_by_cols: Vec<&'static str>,
4184            limit_options: Option<LimitOptions>,
4185            expected_num_rows: Precision<usize>,
4186        }
4187
4188        let cases = vec![
4189            // --- NDV-based estimation ---
4190            TestCase {
4191                name: "single group-by col with NDV tightens estimate",
4192                input_rows: Precision::Exact(1_000_000),
4193                col_a_stats: ColumnStatistics {
4194                    distinct_count: Precision::Exact(500),
4195                    ..ColumnStatistics::new_unknown()
4196                },
4197                col_b_stats: ColumnStatistics::new_unknown(),
4198                group_by_cols: vec!["a"],
4199                limit_options: None,
4200                expected_num_rows: Precision::Inexact(500),
4201            },
4202            TestCase {
4203                name: "multi-col group-by multiplies NDVs",
4204                input_rows: Precision::Exact(1_000_000),
4205                col_a_stats: ColumnStatistics {
4206                    distinct_count: Precision::Exact(100),
4207                    ..ColumnStatistics::new_unknown()
4208                },
4209                col_b_stats: ColumnStatistics {
4210                    distinct_count: Precision::Exact(50),
4211                    ..ColumnStatistics::new_unknown()
4212                },
4213                group_by_cols: vec!["a", "b"],
4214                limit_options: None,
4215                expected_num_rows: Precision::Inexact(5_000),
4216            },
4217            TestCase {
4218                name: "NDV product capped by input rows",
4219                input_rows: Precision::Exact(200),
4220                col_a_stats: ColumnStatistics {
4221                    distinct_count: Precision::Exact(100),
4222                    ..ColumnStatistics::new_unknown()
4223                },
4224                col_b_stats: ColumnStatistics {
4225                    distinct_count: Precision::Exact(50),
4226                    ..ColumnStatistics::new_unknown()
4227                },
4228                group_by_cols: vec!["a", "b"],
4229                limit_options: None,
4230                expected_num_rows: Precision::Inexact(200),
4231            },
4232            TestCase {
4233                name: "null adjustment adds +1 per column",
4234                input_rows: Precision::Exact(1_000_000),
4235                col_a_stats: ColumnStatistics {
4236                    distinct_count: Precision::Exact(99),
4237                    null_count: Precision::Exact(10),
4238                    ..ColumnStatistics::new_unknown()
4239                },
4240                col_b_stats: ColumnStatistics::new_unknown(),
4241                group_by_cols: vec!["a"],
4242                limit_options: None,
4243                // 99 + 1 (null adjustment) = 100
4244                expected_num_rows: Precision::Inexact(100),
4245            },
4246            TestCase {
4247                name: "null adjustment on multiple columns",
4248                input_rows: Precision::Exact(1_000_000),
4249                col_a_stats: ColumnStatistics {
4250                    distinct_count: Precision::Exact(99),
4251                    null_count: Precision::Exact(5),
4252                    ..ColumnStatistics::new_unknown()
4253                },
4254                col_b_stats: ColumnStatistics {
4255                    distinct_count: Precision::Exact(49),
4256                    null_count: Precision::Exact(3),
4257                    ..ColumnStatistics::new_unknown()
4258                },
4259                group_by_cols: vec!["a", "b"],
4260                limit_options: None,
4261                // (99+1) * (49+1) = 100 * 50 = 5000
4262                expected_num_rows: Precision::Inexact(5_000),
4263            },
4264            TestCase {
4265                name: "zero null_count means no adjustment",
4266                input_rows: Precision::Exact(1_000_000),
4267                col_a_stats: ColumnStatistics {
4268                    distinct_count: Precision::Exact(100),
4269                    null_count: Precision::Exact(0),
4270                    ..ColumnStatistics::new_unknown()
4271                },
4272                col_b_stats: ColumnStatistics::new_unknown(),
4273                group_by_cols: vec!["a"],
4274                limit_options: None,
4275                expected_num_rows: Precision::Inexact(100),
4276            },
4277            // --- Bail-out: partial NDV stats (Spark-style) ---
4278            TestCase {
4279                name: "bail out when one group-by col lacks NDV",
4280                input_rows: Precision::Exact(1_000_000),
4281                col_a_stats: ColumnStatistics {
4282                    distinct_count: Precision::Exact(100),
4283                    ..ColumnStatistics::new_unknown()
4284                },
4285                col_b_stats: ColumnStatistics::new_unknown(),
4286                group_by_cols: vec!["a", "b"],
4287                limit_options: None,
4288                expected_num_rows: Precision::Inexact(1_000_000),
4289            },
4290            TestCase {
4291                name: "bail out when all group-by cols lack NDV",
4292                input_rows: Precision::Exact(1_000_000),
4293                col_a_stats: ColumnStatistics::new_unknown(),
4294                col_b_stats: ColumnStatistics::new_unknown(),
4295                group_by_cols: vec!["a"],
4296                limit_options: None,
4297                expected_num_rows: Precision::Inexact(1_000_000),
4298            },
4299            // --- TopK limit capping ---
4300            TestCase {
4301                name: "TopK limit caps output rows",
4302                input_rows: Precision::Exact(1_000_000),
4303                col_a_stats: ColumnStatistics::new_unknown(),
4304                col_b_stats: ColumnStatistics::new_unknown(),
4305                group_by_cols: vec!["a"],
4306                limit_options: Some(LimitOptions::new(10)),
4307                expected_num_rows: Precision::Inexact(10),
4308            },
4309            TestCase {
4310                name: "NDV + TopK limit: min(NDV, limit) when NDV < limit",
4311                input_rows: Precision::Exact(1_000_000),
4312                col_a_stats: ColumnStatistics {
4313                    distinct_count: Precision::Exact(5),
4314                    ..ColumnStatistics::new_unknown()
4315                },
4316                col_b_stats: ColumnStatistics::new_unknown(),
4317                group_by_cols: vec!["a"],
4318                limit_options: Some(LimitOptions::new(10)),
4319                expected_num_rows: Precision::Inexact(5),
4320            },
4321            TestCase {
4322                name: "NDV + TopK limit: min(NDV, limit) when limit < NDV",
4323                input_rows: Precision::Exact(1_000_000),
4324                col_a_stats: ColumnStatistics {
4325                    distinct_count: Precision::Exact(500),
4326                    ..ColumnStatistics::new_unknown()
4327                },
4328                col_b_stats: ColumnStatistics::new_unknown(),
4329                group_by_cols: vec!["a"],
4330                limit_options: Some(LimitOptions::new(10)),
4331                expected_num_rows: Precision::Inexact(10),
4332            },
4333            // --- Absent input rows ---
4334            TestCase {
4335                name: "absent input rows without limit stays absent",
4336                input_rows: Precision::Absent,
4337                col_a_stats: ColumnStatistics::new_unknown(),
4338                col_b_stats: ColumnStatistics::new_unknown(),
4339                group_by_cols: vec!["a"],
4340                limit_options: None,
4341                expected_num_rows: Precision::Absent,
4342            },
4343            TestCase {
4344                name: "absent input rows with TopK limit gives inexact(limit)",
4345                input_rows: Precision::Absent,
4346                col_a_stats: ColumnStatistics::new_unknown(),
4347                col_b_stats: ColumnStatistics::new_unknown(),
4348                group_by_cols: vec!["a"],
4349                limit_options: Some(LimitOptions::new(10)),
4350                expected_num_rows: Precision::Inexact(10),
4351            },
4352            // --- No group-by (global aggregation) ---
4353            TestCase {
4354                name: "no group-by cols (Final mode) returns Exact(1)",
4355                input_rows: Precision::Exact(1_000_000),
4356                col_a_stats: ColumnStatistics::new_unknown(),
4357                col_b_stats: ColumnStatistics::new_unknown(),
4358                group_by_cols: vec![],
4359                limit_options: None,
4360                expected_num_rows: Precision::Exact(1),
4361            },
4362            // --- One input row ---
4363            TestCase {
4364                name: "one input row returns Exact(1)",
4365                input_rows: Precision::Exact(1),
4366                col_a_stats: ColumnStatistics {
4367                    distinct_count: Precision::Exact(1),
4368                    ..ColumnStatistics::new_unknown()
4369                },
4370                col_b_stats: ColumnStatistics::new_unknown(),
4371                group_by_cols: vec!["a"],
4372                limit_options: None,
4373                expected_num_rows: Precision::Exact(1),
4374            },
4375            // --- Zero input rows ---
4376            TestCase {
4377                name: "zero input rows returns Exact(0)",
4378                input_rows: Precision::Exact(0),
4379                col_a_stats: ColumnStatistics::new_unknown(),
4380                col_b_stats: ColumnStatistics::new_unknown(),
4381                group_by_cols: vec!["a"],
4382                limit_options: None,
4383                expected_num_rows: Precision::Exact(0),
4384            },
4385            // --- Inexact NDV stats ---
4386            TestCase {
4387                name: "inexact NDV still used for estimation",
4388                input_rows: Precision::Exact(1_000_000),
4389                col_a_stats: ColumnStatistics {
4390                    distinct_count: Precision::Inexact(200),
4391                    ..ColumnStatistics::new_unknown()
4392                },
4393                col_b_stats: ColumnStatistics::new_unknown(),
4394                group_by_cols: vec!["a"],
4395                limit_options: None,
4396                expected_num_rows: Precision::Inexact(200),
4397            },
4398            TestCase {
4399                name: "inexact NDV combined with limit",
4400                input_rows: Precision::Exact(1_000_000),
4401                col_a_stats: ColumnStatistics {
4402                    distinct_count: Precision::Inexact(200),
4403                    ..ColumnStatistics::new_unknown()
4404                },
4405                col_b_stats: ColumnStatistics::new_unknown(),
4406                group_by_cols: vec!["a"],
4407                limit_options: Some(LimitOptions::new(10)),
4408                expected_num_rows: Precision::Inexact(10),
4409            },
4410            // --- NDV zero column (all-null) ---
4411            TestCase {
4412                name: "all-null column contributes 1 to the product, not 0",
4413                input_rows: Precision::Exact(1_000),
4414                col_a_stats: ColumnStatistics {
4415                    distinct_count: Precision::Exact(0),
4416                    null_count: Precision::Exact(1_000),
4417                    ..ColumnStatistics::new_unknown()
4418                },
4419                col_b_stats: ColumnStatistics {
4420                    distinct_count: Precision::Exact(50),
4421                    ..ColumnStatistics::new_unknown()
4422                },
4423                group_by_cols: vec!["a", "b"],
4424                limit_options: None,
4425                // NDV(a)=0 with nulls => max(0+1, 1)=1, NDV(b)=50 => 1*50=50
4426                expected_num_rows: Precision::Inexact(50),
4427            },
4428            // --- Absent num_rows with NDV ---
4429            TestCase {
4430                name: "absent num_rows falls back to NDV estimate",
4431                input_rows: Precision::Absent,
4432                col_a_stats: ColumnStatistics {
4433                    distinct_count: Precision::Exact(100),
4434                    ..ColumnStatistics::new_unknown()
4435                },
4436                col_b_stats: ColumnStatistics::new_unknown(),
4437                group_by_cols: vec!["a"],
4438                limit_options: None,
4439                expected_num_rows: Precision::Inexact(100),
4440            },
4441            TestCase {
4442                name: "absent num_rows with NDV and limit returns min(ndv, limit)",
4443                input_rows: Precision::Absent,
4444                col_a_stats: ColumnStatistics {
4445                    distinct_count: Precision::Exact(100),
4446                    ..ColumnStatistics::new_unknown()
4447                },
4448                col_b_stats: ColumnStatistics::new_unknown(),
4449                group_by_cols: vec!["a"],
4450                limit_options: Some(LimitOptions::new(10)),
4451                expected_num_rows: Precision::Inexact(10),
4452            },
4453        ];
4454
4455        for case in cases {
4456            let input_stats = Statistics {
4457                num_rows: case.input_rows,
4458                total_byte_size: Precision::Inexact(1_000_000),
4459                column_statistics: vec![
4460                    case.col_a_stats.clone(),
4461                    case.col_b_stats.clone(),
4462                ],
4463            };
4464
4465            let group_by = simple_group_by(&schema, &case.group_by_cols);
4466            let agg =
4467                build_test_aggregate(&schema, input_stats, group_by, case.limit_options)?;
4468
4469            let stats = agg.partition_statistics(None)?;
4470            assert_eq!(
4471                stats.num_rows, case.expected_num_rows,
4472                "FAILED: '{}' — expected {:?}, got {:?}",
4473                case.name, case.expected_num_rows, stats.num_rows
4474            );
4475        }
4476
4477        Ok(())
4478    }
4479
4480    #[test]
4481    fn test_aggregate_stats_distinct_count_propagation() -> Result<()> {
4482        use datafusion_common::ColumnStatistics;
4483
4484        let schema = Arc::new(Schema::new(vec![
4485            Field::new("a", DataType::Int32, true),
4486            Field::new("b", DataType::Int32, true),
4487        ]));
4488
4489        let input_stats = Statistics {
4490            num_rows: Precision::Exact(1000),
4491            total_byte_size: Precision::Inexact(10000),
4492            column_statistics: vec![
4493                ColumnStatistics {
4494                    distinct_count: Precision::Exact(100),
4495                    null_count: Precision::Exact(5),
4496                    ..ColumnStatistics::new_unknown()
4497                },
4498                ColumnStatistics::new_unknown(),
4499            ],
4500        };
4501        let agg = build_test_aggregate(
4502            &schema,
4503            input_stats,
4504            simple_group_by(&schema, &["a"]),
4505            None,
4506        )?;
4507
4508        let stats = agg.partition_statistics(None)?;
4509        assert_eq!(
4510            stats.column_statistics[0].distinct_count,
4511            Precision::Exact(100),
4512            "distinct_count should be propagated from child for group-by columns"
4513        );
4514
4515        Ok(())
4516    }
4517
4518    #[test]
4519    fn test_aggregate_stats_grouping_sets() -> Result<()> {
4520        use datafusion_common::ColumnStatistics;
4521
4522        let schema = Arc::new(Schema::new(vec![
4523            Field::new("a", DataType::Int32, true),
4524            Field::new("b", DataType::Int32, true),
4525        ]));
4526
4527        let input_stats = Statistics {
4528            num_rows: Precision::Exact(1_000_000),
4529            total_byte_size: Precision::Inexact(1_000_000),
4530            column_statistics: vec![
4531                ColumnStatistics {
4532                    distinct_count: Precision::Exact(100),
4533                    ..ColumnStatistics::new_unknown()
4534                },
4535                ColumnStatistics {
4536                    distinct_count: Precision::Exact(50),
4537                    ..ColumnStatistics::new_unknown()
4538                },
4539            ],
4540        };
4541
4542        // CUBE-like grouping set: (a, NULL), (NULL, b), (a, b) — 3 groups
4543        let grouping_set = PhysicalGroupBy::new(
4544            vec![
4545                (col("a", &schema)? as Arc<dyn PhysicalExpr>, "a".to_string()),
4546                (col("b", &schema)? as Arc<dyn PhysicalExpr>, "b".to_string()),
4547            ],
4548            vec![
4549                (lit(ScalarValue::Int32(None)), "a".to_string()),
4550                (lit(ScalarValue::Int32(None)), "b".to_string()),
4551            ],
4552            vec![
4553                vec![false, true],  // (a, NULL)
4554                vec![true, false],  // (NULL, b)
4555                vec![false, false], // (a, b)
4556            ],
4557            true,
4558        );
4559
4560        let agg = build_test_aggregate(&schema, input_stats, grouping_set, None)?;
4561
4562        let stats = agg.partition_statistics(None)?;
4563        // Per-set NDV: (a,NULL)=100, (NULL,b)=50, (a,b)=100*50=5000
4564        // Total = 100 + 50 + 5000 = 5150
4565        assert_eq!(
4566            stats.num_rows,
4567            Precision::Inexact(5_150),
4568            "grouping sets should sum per-set NDV products"
4569        );
4570
4571        Ok(())
4572    }
4573
4574    #[test]
4575    fn test_aggregate_stats_non_column_expr_bails_out() -> Result<()> {
4576        use datafusion_common::ColumnStatistics;
4577        use datafusion_expr::Operator;
4578        use datafusion_physical_expr::expressions::BinaryExpr;
4579
4580        let schema = Arc::new(Schema::new(vec![
4581            Field::new("a", DataType::Int32, true),
4582            Field::new("b", DataType::Int32, true),
4583        ]));
4584
4585        let input_stats = Statistics {
4586            num_rows: Precision::Exact(1_000_000),
4587            total_byte_size: Precision::Inexact(1_000_000),
4588            column_statistics: vec![
4589                ColumnStatistics {
4590                    distinct_count: Precision::Exact(100),
4591                    ..ColumnStatistics::new_unknown()
4592                },
4593                ColumnStatistics {
4594                    distinct_count: Precision::Exact(50),
4595                    ..ColumnStatistics::new_unknown()
4596                },
4597            ],
4598        };
4599
4600        // GROUP BY (a + b) — not a direct column reference
4601        let expr_a_plus_b: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
4602            col("a", &schema)?,
4603            Operator::Plus,
4604            col("b", &schema)?,
4605        ));
4606
4607        let group_by =
4608            PhysicalGroupBy::new_single(vec![(expr_a_plus_b, "a+b".to_string())]);
4609        let agg = build_test_aggregate(&schema, input_stats, group_by, None)?;
4610
4611        let stats = agg.partition_statistics(None)?;
4612        assert_eq!(
4613            stats.num_rows,
4614            Precision::Inexact(1_000_000),
4615            "non-column group-by expression should bail out to input_rows"
4616        );
4617
4618        Ok(())
4619    }
4620
4621    #[tokio::test]
4622    async fn test_order_is_retained_when_spilling() -> Result<()> {
4623        let schema = Arc::new(Schema::new(vec![
4624            Field::new("a", DataType::Int64, false),
4625            Field::new("b", DataType::Int64, false),
4626            Field::new("c", DataType::Int64, false),
4627        ]));
4628
4629        let batches = vec![vec![
4630            RecordBatch::try_new(
4631                Arc::clone(&schema),
4632                vec![
4633                    Arc::new(Int64Array::from(vec![2])),
4634                    Arc::new(Int64Array::from(vec![2])),
4635                    Arc::new(Int64Array::from(vec![1])),
4636                ],
4637            )?,
4638            RecordBatch::try_new(
4639                Arc::clone(&schema),
4640                vec![
4641                    Arc::new(Int64Array::from(vec![1])),
4642                    Arc::new(Int64Array::from(vec![1])),
4643                    Arc::new(Int64Array::from(vec![1])),
4644                ],
4645            )?,
4646            RecordBatch::try_new(
4647                Arc::clone(&schema),
4648                vec![
4649                    Arc::new(Int64Array::from(vec![0])),
4650                    Arc::new(Int64Array::from(vec![0])),
4651                    Arc::new(Int64Array::from(vec![1])),
4652                ],
4653            )?,
4654        ]];
4655        let scan = TestMemoryExec::try_new(&batches, Arc::clone(&schema), None)?;
4656        let scan = scan.try_with_sort_information(vec![
4657            LexOrdering::new([PhysicalSortExpr::new(
4658                col("b", schema.as_ref())?,
4659                SortOptions::default().desc(),
4660            )])
4661            .unwrap(),
4662        ])?;
4663
4664        let aggr = Arc::new(AggregateExec::try_new(
4665            AggregateMode::Single,
4666            PhysicalGroupBy::new(
4667                vec![
4668                    (col("b", schema.as_ref())?, "b".to_string()),
4669                    (col("c", schema.as_ref())?, "c".to_string()),
4670                ],
4671                vec![],
4672                vec![vec![false, false]],
4673                false,
4674            ),
4675            vec![Arc::new(
4676                AggregateExprBuilder::new(sum_udaf(), vec![col("c", schema.as_ref())?])
4677                    .schema(Arc::clone(&schema))
4678                    .alias("SUM(c)")
4679                    .build()?,
4680            )],
4681            vec![None],
4682            Arc::new(scan) as Arc<dyn ExecutionPlan>,
4683            Arc::clone(&schema),
4684        )?);
4685
4686        let task_ctx = new_spill_ctx(1, 600);
4687        let result = collect(aggr.execute(0, Arc::clone(&task_ctx))?).await?;
4688        assert_spill_count_metric(true, aggr);
4689
4690        allow_duplicates! {
4691            assert_snapshot!(batches_to_string(&result), @r"
4692            +---+---+--------+
4693            | b | c | SUM(c) |
4694            +---+---+--------+
4695            | 2 | 1 | 1      |
4696            | 1 | 1 | 1      |
4697            | 0 | 1 | 1      |
4698            +---+---+--------+
4699        ");
4700        }
4701        Ok(())
4702    }
4703
4704    /// Tests that when the memory pool is too small to accommodate the sort
4705    /// reservation during spill, the error is properly propagated as
4706    /// ResourcesExhausted rather than silently exceeding memory limits.
4707    #[tokio::test]
4708    async fn test_sort_reservation_fails_during_spill() -> Result<()> {
4709        let schema = Arc::new(Schema::new(vec![
4710            Field::new("g", DataType::Int64, false),
4711            Field::new("a", DataType::Float64, false),
4712            Field::new("b", DataType::Float64, false),
4713            Field::new("c", DataType::Float64, false),
4714            Field::new("d", DataType::Float64, false),
4715            Field::new("e", DataType::Float64, false),
4716        ]));
4717
4718        let batches = vec![vec![
4719            RecordBatch::try_new(
4720                Arc::clone(&schema),
4721                vec![
4722                    Arc::new(Int64Array::from(vec![1])),
4723                    Arc::new(Float64Array::from(vec![10.0])),
4724                    Arc::new(Float64Array::from(vec![20.0])),
4725                    Arc::new(Float64Array::from(vec![30.0])),
4726                    Arc::new(Float64Array::from(vec![40.0])),
4727                    Arc::new(Float64Array::from(vec![50.0])),
4728                ],
4729            )?,
4730            RecordBatch::try_new(
4731                Arc::clone(&schema),
4732                vec![
4733                    Arc::new(Int64Array::from(vec![2])),
4734                    Arc::new(Float64Array::from(vec![11.0])),
4735                    Arc::new(Float64Array::from(vec![21.0])),
4736                    Arc::new(Float64Array::from(vec![31.0])),
4737                    Arc::new(Float64Array::from(vec![41.0])),
4738                    Arc::new(Float64Array::from(vec![51.0])),
4739                ],
4740            )?,
4741            RecordBatch::try_new(
4742                Arc::clone(&schema),
4743                vec![
4744                    Arc::new(Int64Array::from(vec![3])),
4745                    Arc::new(Float64Array::from(vec![12.0])),
4746                    Arc::new(Float64Array::from(vec![22.0])),
4747                    Arc::new(Float64Array::from(vec![32.0])),
4748                    Arc::new(Float64Array::from(vec![42.0])),
4749                    Arc::new(Float64Array::from(vec![52.0])),
4750                ],
4751            )?,
4752        ]];
4753
4754        let scan = TestMemoryExec::try_new(&batches, Arc::clone(&schema), None)?;
4755
4756        let aggr = Arc::new(AggregateExec::try_new(
4757            AggregateMode::Single,
4758            PhysicalGroupBy::new(
4759                vec![(col("g", schema.as_ref())?, "g".to_string())],
4760                vec![],
4761                vec![vec![false]],
4762                false,
4763            ),
4764            vec![
4765                Arc::new(
4766                    AggregateExprBuilder::new(
4767                        avg_udaf(),
4768                        vec![col("a", schema.as_ref())?],
4769                    )
4770                    .schema(Arc::clone(&schema))
4771                    .alias("AVG(a)")
4772                    .build()?,
4773                ),
4774                Arc::new(
4775                    AggregateExprBuilder::new(
4776                        avg_udaf(),
4777                        vec![col("b", schema.as_ref())?],
4778                    )
4779                    .schema(Arc::clone(&schema))
4780                    .alias("AVG(b)")
4781                    .build()?,
4782                ),
4783                Arc::new(
4784                    AggregateExprBuilder::new(
4785                        avg_udaf(),
4786                        vec![col("c", schema.as_ref())?],
4787                    )
4788                    .schema(Arc::clone(&schema))
4789                    .alias("AVG(c)")
4790                    .build()?,
4791                ),
4792                Arc::new(
4793                    AggregateExprBuilder::new(
4794                        avg_udaf(),
4795                        vec![col("d", schema.as_ref())?],
4796                    )
4797                    .schema(Arc::clone(&schema))
4798                    .alias("AVG(d)")
4799                    .build()?,
4800                ),
4801                Arc::new(
4802                    AggregateExprBuilder::new(
4803                        avg_udaf(),
4804                        vec![col("e", schema.as_ref())?],
4805                    )
4806                    .schema(Arc::clone(&schema))
4807                    .alias("AVG(e)")
4808                    .build()?,
4809                ),
4810            ],
4811            vec![None, None, None, None, None],
4812            Arc::new(scan) as Arc<dyn ExecutionPlan>,
4813            Arc::clone(&schema),
4814        )?);
4815
4816        // Pool must be large enough for accumulation to start but too small for
4817        // sort_memory after clearing.
4818        let task_ctx = new_spill_ctx(1, 500);
4819        let result = collect(aggr.execute(0, Arc::clone(&task_ctx))?).await;
4820
4821        match &result {
4822            Ok(_) => panic!("Expected ResourcesExhausted error but query succeeded"),
4823            Err(e) => {
4824                let root = e.find_root();
4825                assert!(
4826                    matches!(root, DataFusionError::ResourcesExhausted(_)),
4827                    "Expected ResourcesExhausted, got: {root}",
4828                );
4829                let msg = root.to_string();
4830                assert!(
4831                    msg.contains("Failed to reserve memory for sort during spill"),
4832                    "Expected sort reservation error, got: {msg}",
4833                );
4834            }
4835        }
4836
4837        Ok(())
4838    }
4839
4840    /// Tests that PartialReduce mode:
4841    /// 1. Accepts state as input (like Final)
4842    /// 2. Produces state as output (like Partial)
4843    /// 3. Can be followed by a Final stage to get the correct result
4844    ///
4845    /// This simulates a tree-reduce pattern:
4846    ///   Partial -> PartialReduce -> Final
4847    #[tokio::test]
4848    async fn test_partial_reduce_mode() -> Result<()> {
4849        let schema = Arc::new(Schema::new(vec![
4850            Field::new("a", DataType::UInt32, false),
4851            Field::new("b", DataType::Float64, false),
4852        ]));
4853
4854        // Produce two partitions of input data
4855        let batch1 = RecordBatch::try_new(
4856            Arc::clone(&schema),
4857            vec![
4858                Arc::new(UInt32Array::from(vec![1, 2, 3])),
4859                Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0])),
4860            ],
4861        )?;
4862        let batch2 = RecordBatch::try_new(
4863            Arc::clone(&schema),
4864            vec![
4865                Arc::new(UInt32Array::from(vec![1, 2, 3])),
4866                Arc::new(Float64Array::from(vec![40.0, 50.0, 60.0])),
4867            ],
4868        )?;
4869
4870        let groups =
4871            PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);
4872        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
4873            AggregateExprBuilder::new(sum_udaf(), vec![col("b", &schema)?])
4874                .schema(Arc::clone(&schema))
4875                .alias("SUM(b)")
4876                .build()?,
4877        )];
4878
4879        // Step 1: Partial aggregation on partition 1
4880        let input1 =
4881            TestMemoryExec::try_new_exec(&[vec![batch1]], Arc::clone(&schema), None)?;
4882        let partial1 = Arc::new(AggregateExec::try_new(
4883            AggregateMode::Partial,
4884            groups.clone(),
4885            aggregates.clone(),
4886            vec![None],
4887            input1,
4888            Arc::clone(&schema),
4889        )?);
4890
4891        // Step 2: Partial aggregation on partition 2
4892        let input2 =
4893            TestMemoryExec::try_new_exec(&[vec![batch2]], Arc::clone(&schema), None)?;
4894        let partial2 = Arc::new(AggregateExec::try_new(
4895            AggregateMode::Partial,
4896            groups.clone(),
4897            aggregates.clone(),
4898            vec![None],
4899            input2,
4900            Arc::clone(&schema),
4901        )?);
4902
4903        // Collect partial results
4904        let task_ctx = Arc::new(TaskContext::default());
4905        let partial_result1 =
4906            crate::collect(Arc::clone(&partial1) as _, Arc::clone(&task_ctx)).await?;
4907        let partial_result2 =
4908            crate::collect(Arc::clone(&partial2) as _, Arc::clone(&task_ctx)).await?;
4909
4910        // The partial results have state schema (group cols + accumulator state)
4911        let partial_schema = partial1.schema();
4912
4913        // Step 3: PartialReduce — combine partial results, still producing state
4914        let combined_input = TestMemoryExec::try_new_exec(
4915            &[partial_result1, partial_result2],
4916            Arc::clone(&partial_schema),
4917            None,
4918        )?;
4919        // Coalesce into a single partition for the PartialReduce
4920        let coalesced = Arc::new(CoalescePartitionsExec::new(combined_input));
4921
4922        let partial_reduce = Arc::new(AggregateExec::try_new(
4923            AggregateMode::PartialReduce,
4924            groups.clone(),
4925            aggregates.clone(),
4926            vec![None],
4927            coalesced,
4928            Arc::clone(&partial_schema),
4929        )?);
4930
4931        // Verify PartialReduce output schema matches Partial output schema
4932        // (both produce state, not final values)
4933        assert_eq!(partial_reduce.schema(), partial_schema);
4934
4935        // Collect PartialReduce results
4936        let reduce_result =
4937            crate::collect(Arc::clone(&partial_reduce) as _, Arc::clone(&task_ctx))
4938                .await?;
4939
4940        // Step 4: Final aggregation on the PartialReduce output
4941        let final_input = TestMemoryExec::try_new_exec(
4942            &[reduce_result],
4943            Arc::clone(&partial_schema),
4944            None,
4945        )?;
4946        let final_agg = Arc::new(AggregateExec::try_new(
4947            AggregateMode::Final,
4948            groups.clone(),
4949            aggregates.clone(),
4950            vec![None],
4951            final_input,
4952            Arc::clone(&partial_schema),
4953        )?);
4954
4955        let result = crate::collect(final_agg, Arc::clone(&task_ctx)).await?;
4956
4957        // Expected: group 1 -> 10+40=50, group 2 -> 20+50=70, group 3 -> 30+60=90
4958        assert_snapshot!(batches_to_sort_string(&result), @r"
4959            +---+--------+
4960            | a | SUM(b) |
4961            +---+--------+
4962            | 1 | 50.0   |
4963            | 2 | 70.0   |
4964            | 3 | 90.0   |
4965            +---+--------+
4966        ");
4967
4968        Ok(())
4969    }
4970
4971    /// Test that [`AggregateExec::with_dynamic_filter_expr`] overrides the existing dynamic filter
4972    #[test]
4973    fn test_with_dynamic_filter() -> Result<()> {
4974        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
4975        let child = Arc::new(EmptyExec::new(Arc::clone(&schema)));
4976
4977        // Partial min aggregate supports dynamic filtering
4978        let agg = AggregateExec::try_new(
4979            AggregateMode::Partial,
4980            PhysicalGroupBy::new_single(vec![]),
4981            vec![Arc::new(
4982                AggregateExprBuilder::new(min_udaf(), vec![col("a", &schema)?])
4983                    .schema(Arc::clone(&schema))
4984                    .alias("min_a")
4985                    .build()?,
4986            )],
4987            vec![None],
4988            child,
4989            Arc::clone(&schema),
4990        )?;
4991
4992        // Assertion 1: A filter with the same children can override the existing
4993        // dynamic filter.
4994        let new_df = Arc::new(DynamicFilterPhysicalExpr::new(
4995            vec![col("a", &schema)?],
4996            lit(false),
4997        ));
4998        let agg = agg.with_dynamic_filter_expr(Arc::clone(&new_df))?;
4999
5000        // The aggregate's filter should now resolve to the new inner expression.
5001        let swapped = agg
5002            .dynamic_filter_expr()
5003            .expect("should still have dynamic filter")
5004            .current()?;
5005        assert_eq!(format!("{swapped}"), format!("{}", lit(false)));
5006
5007        // Assertion 2: A filter that has been through `PhysicalExpr::with_new_children`
5008        // should still be accepted when the new children are equivalent to the originals.
5009        let new_df_as_pexpr: Arc<dyn PhysicalExpr> =
5010            Arc::<DynamicFilterPhysicalExpr>::clone(&new_df);
5011        let remapped_pexpr =
5012            new_df_as_pexpr.with_new_children(vec![col("a", &schema)?])?;
5013        let Ok(remapped_df) = (remapped_pexpr as Arc<dyn std::any::Any + Send + Sync>)
5014            .downcast::<DynamicFilterPhysicalExpr>()
5015        else {
5016            panic!("should be DynamicFilterPhysicalExpr after with_new_children");
5017        };
5018        // Hard to assert this because the filter is identical. No error means
5019        // the filter was accepted. That's a good enough assertion for now.
5020        let _agg = agg.with_dynamic_filter_expr(remapped_df)?;
5021        Ok(())
5022    }
5023
5024    /// Test that [`AggregateExec::with_dynamic_filter_expr`] errors when the aggregate does not support dynamic filtering
5025    #[test]
5026    fn test_with_dynamic_filter_error_unsupported() -> Result<()> {
5027        let schema = Arc::new(Schema::new(vec![
5028            Field::new("a", DataType::Int64, false),
5029            Field::new("b", DataType::Int64, false),
5030        ]));
5031        let child = Arc::new(EmptyExec::new(Arc::clone(&schema)));
5032
5033        // Final mode with a group-by does not support dynamic filters.
5034        let agg = AggregateExec::try_new(
5035            AggregateMode::Final,
5036            PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]),
5037            vec![Arc::new(
5038                AggregateExprBuilder::new(sum_udaf(), vec![col("b", &schema)?])
5039                    .schema(Arc::clone(&schema))
5040                    .alias("sum_b")
5041                    .build()?,
5042            )],
5043            vec![None],
5044            child,
5045            Arc::clone(&schema),
5046        )?;
5047        assert!(agg.dynamic_filter_expr().is_none());
5048
5049        let df = Arc::new(DynamicFilterPhysicalExpr::new(
5050            vec![col("a", &schema)?],
5051            lit(true),
5052        ));
5053        assert!(agg.with_dynamic_filter_expr(df).is_err());
5054        Ok(())
5055    }
5056
5057    /// Test that [`AggregateExec::with_dynamic_filter_expr`] errors when the column is not in the schema
5058    #[test]
5059    fn test_with_dynamic_filter_error_column_mismatch() -> Result<()> {
5060        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
5061        let child = Arc::new(EmptyExec::new(Arc::clone(&schema)));
5062
5063        let agg = AggregateExec::try_new(
5064            AggregateMode::Partial,
5065            PhysicalGroupBy::new_single(vec![]),
5066            vec![Arc::new(
5067                AggregateExprBuilder::new(min_udaf(), vec![col("a", &schema)?])
5068                    .schema(Arc::clone(&schema))
5069                    .alias("min_a")
5070                    .build()?,
5071            )],
5072            vec![None],
5073            child,
5074            Arc::clone(&schema),
5075        )?;
5076
5077        let df = Arc::new(DynamicFilterPhysicalExpr::new(
5078            vec![Arc::new(Column::new("bad", 99)) as _],
5079            lit(true),
5080        ));
5081        assert!(agg.with_dynamic_filter_expr(df).is_err());
5082        Ok(())
5083    }
5084}