datafusion_physical_plan/aggregates/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Aggregates functionalities
19
20use std::any::Any;
21use std::sync::Arc;
22
23use super::{DisplayAs, ExecutionPlanProperties, PlanProperties};
24use crate::aggregates::{
25    no_grouping::AggregateStream, row_hash::GroupedHashAggregateStream,
26    topk_stream::GroupedTopKAggregateStream,
27};
28use crate::execution_plan::{CardinalityEffect, EmissionType};
29use crate::filter_pushdown::{
30    ChildFilterDescription, FilterDescription, FilterPushdownPhase, PushedDownPredicate,
31};
32use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
33use crate::windows::get_ordered_partition_by_indices;
34use crate::{
35    DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode,
36    SendableRecordBatchStream, Statistics,
37};
38use datafusion_common::config::ConfigOptions;
39use datafusion_physical_expr::utils::collect_columns;
40use std::collections::HashSet;
41
42use arrow::array::{ArrayRef, UInt16Array, UInt32Array, UInt64Array, UInt8Array};
43use arrow::datatypes::{Field, Schema, SchemaRef};
44use arrow::record_batch::RecordBatch;
45use arrow_schema::FieldRef;
46use datafusion_common::stats::Precision;
47use datafusion_common::{internal_err, not_impl_err, Constraint, Constraints, Result};
48use datafusion_execution::TaskContext;
49use datafusion_expr::{Accumulator, Aggregate};
50use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
51use datafusion_physical_expr::equivalence::ProjectionMapping;
52use datafusion_physical_expr::expressions::Column;
53use datafusion_physical_expr::{
54    physical_exprs_contains, ConstExpr, EquivalenceProperties,
55};
56use datafusion_physical_expr_common::physical_expr::{fmt_sql, PhysicalExpr};
57use datafusion_physical_expr_common::sort_expr::{
58    LexOrdering, LexRequirement, OrderingRequirements, PhysicalSortRequirement,
59};
60
61use datafusion_expr::utils::AggregateOrderSensitivity;
62use itertools::Itertools;
63
64pub mod group_values;
65mod no_grouping;
66pub mod order;
67mod row_hash;
68mod topk;
69mod topk_stream;
70
71/// Hard-coded seed for aggregations to ensure hash values differ from `RepartitionExec`, avoiding collisions.
72const AGGREGATION_HASH_SEED: ahash::RandomState =
73    ahash::RandomState::with_seeds('A' as u64, 'G' as u64, 'G' as u64, 'R' as u64);
74
/// The role an aggregation operator plays in a (possibly multi-phase)
/// aggregation.
///
/// See [`Accumulator::state`] for background information on multi-phase
/// aggregation and how these modes are used.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum AggregateMode {
    /// First phase of a multi-phase aggregation; accepts any input partitioning.
    ///
    /// Computes partial aggregates that can be applied in parallel across
    /// input partitions.
    Partial,
    /// *Final* phase of a multi-phase aggregation, over exactly one partition.
    ///
    /// Combines the output of multiple partial aggregates into a single
    /// partition of output. This is the second phase of a multi-phase
    /// aggregation, and it requires its input to be a single partition.
    ///
    /// Note: adjacent `Partial` and `Final` aggregations are equivalent to one
    /// `Single` aggregation. `Final` is still needed because it is used as an
    /// intermediate planning step; the [`CombinePartialFinalAggregate`]
    /// physical optimizer rule replaces the pair with `Single` for more
    /// efficient execution.
    ///
    /// [`CombinePartialFinalAggregate`]: https://docs.rs/datafusion/latest/datafusion/physical_optimizer/combine_partial_final_agg/struct.CombinePartialFinalAggregate.html
    Final,
    /// *Final* phase of a multi-phase aggregation, over *partitioned* input.
    ///
    /// Requires every row with a given grouping key to be in the same input
    /// partition, as is the case after hash repartitioning on the group keys.
    /// If a group key is spread over several partitions, duplicate groups are
    /// produced.
    FinalPartitioned,
    /// *Single* aggregation phase over exactly one input partition.
    ///
    /// Performs the whole logical aggregation in one operator, instead of
    /// splitting it across the two `Partial` / `Final` operators. Like
    /// `Final`, the input must be a single partition.
    Single,
    /// *Single* aggregation phase over *partitioned* input.
    ///
    /// Performs the whole logical aggregation in one operator, instead of
    /// splitting it across the two `Partial` / `Final` operators. The input
    /// must have more than one partition and be partitioned by group key
    /// (like `FinalPartitioned`).
    SinglePartitioned,
}

impl AggregateMode {
    /// Returns `true` if this mode describes a "first stage" calculation:
    /// its input is not the result of another aggregation, and the
    /// `merge_batch` method will not be called for it.
    pub fn is_first_stage(&self) -> bool {
        match self {
            Self::Final | Self::FinalPartitioned => false,
            Self::Partial | Self::Single | Self::SinglePartitioned => true,
        }
    }
}
145
146/// Represents `GROUP BY` clause in the plan (including the more general GROUPING SET)
147/// In the case of a simple `GROUP BY a, b` clause, this will contain the expression [a, b]
148/// and a single group [false, false].
149/// In the case of `GROUP BY GROUPING SETS/CUBE/ROLLUP` the planner will expand the expression
150/// into multiple groups, using null expressions to align each group.
151/// For example, with a group by clause `GROUP BY GROUPING SETS ((a,b),(a),(b))` the planner should
152/// create a `PhysicalGroupBy` like
153/// ```text
154/// PhysicalGroupBy {
155///     expr: [(col(a), a), (col(b), b)],
156///     null_expr: [(NULL, a), (NULL, b)],
157///     groups: [
158///         [false, false], // (a,b)
159///         [false, true],  // (a) <=> (a, NULL)
160///         [true, false]   // (b) <=> (NULL, b)
161///     ]
162/// }
163/// ```
164#[derive(Clone, Debug, Default)]
165pub struct PhysicalGroupBy {
166    /// Distinct (Physical Expr, Alias) in the grouping set
167    expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
168    /// Corresponding NULL expressions for expr
169    null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
170    /// Null mask for each group in this grouping set. Each group is
171    /// composed of either one of the group expressions in expr or a null
172    /// expression in null_expr. If `groups[i][j]` is true, then the
173    /// j-th expression in the i-th group is NULL, otherwise it is `expr[j]`.
174    groups: Vec<Vec<bool>>,
175}
176
177impl PhysicalGroupBy {
178    /// Create a new `PhysicalGroupBy`
179    pub fn new(
180        expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
181        null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
182        groups: Vec<Vec<bool>>,
183    ) -> Self {
184        Self {
185            expr,
186            null_expr,
187            groups,
188        }
189    }
190
191    /// Create a GROUPING SET with only a single group. This is the "standard"
192    /// case when building a plan from an expression such as `GROUP BY a,b,c`
193    pub fn new_single(expr: Vec<(Arc<dyn PhysicalExpr>, String)>) -> Self {
194        let num_exprs = expr.len();
195        Self {
196            expr,
197            null_expr: vec![],
198            groups: vec![vec![false; num_exprs]],
199        }
200    }
201
202    /// Calculate GROUP BY expressions nullable
203    pub fn exprs_nullable(&self) -> Vec<bool> {
204        let mut exprs_nullable = vec![false; self.expr.len()];
205        for group in self.groups.iter() {
206            group.iter().enumerate().for_each(|(index, is_null)| {
207                if *is_null {
208                    exprs_nullable[index] = true;
209                }
210            })
211        }
212        exprs_nullable
213    }
214
215    /// Returns the group expressions
216    pub fn expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
217        &self.expr
218    }
219
220    /// Returns the null expressions
221    pub fn null_expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
222        &self.null_expr
223    }
224
225    /// Returns the group null masks
226    pub fn groups(&self) -> &[Vec<bool>] {
227        &self.groups
228    }
229
230    /// Returns true if this `PhysicalGroupBy` has no group expressions
231    pub fn is_empty(&self) -> bool {
232        self.expr.is_empty()
233    }
234
235    /// Check whether grouping set is single group
236    pub fn is_single(&self) -> bool {
237        self.null_expr.is_empty()
238    }
239
240    /// Calculate GROUP BY expressions according to input schema.
241    pub fn input_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
242        self.expr
243            .iter()
244            .map(|(expr, _alias)| Arc::clone(expr))
245            .collect()
246    }
247
248    /// The number of expressions in the output schema.
249    fn num_output_exprs(&self) -> usize {
250        let mut num_exprs = self.expr.len();
251        if !self.is_single() {
252            num_exprs += 1
253        }
254        num_exprs
255    }
256
257    /// Return grouping expressions as they occur in the output schema.
258    pub fn output_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
259        let num_output_exprs = self.num_output_exprs();
260        let mut output_exprs = Vec::with_capacity(num_output_exprs);
261        output_exprs.extend(
262            self.expr
263                .iter()
264                .enumerate()
265                .take(num_output_exprs)
266                .map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _),
267        );
268        if !self.is_single() {
269            output_exprs.push(Arc::new(Column::new(
270                Aggregate::INTERNAL_GROUPING_ID,
271                self.expr.len(),
272            )) as _);
273        }
274        output_exprs
275    }
276
277    /// Returns the number expression as grouping keys.
278    pub fn num_group_exprs(&self) -> usize {
279        if self.is_single() {
280            self.expr.len()
281        } else {
282            self.expr.len() + 1
283        }
284    }
285
286    pub fn group_schema(&self, schema: &Schema) -> Result<SchemaRef> {
287        Ok(Arc::new(Schema::new(self.group_fields(schema)?)))
288    }
289
290    /// Returns the fields that are used as the grouping keys.
291    fn group_fields(&self, input_schema: &Schema) -> Result<Vec<FieldRef>> {
292        let mut fields = Vec::with_capacity(self.num_group_exprs());
293        for ((expr, name), group_expr_nullable) in
294            self.expr.iter().zip(self.exprs_nullable().into_iter())
295        {
296            fields.push(
297                Field::new(
298                    name,
299                    expr.data_type(input_schema)?,
300                    group_expr_nullable || expr.nullable(input_schema)?,
301                )
302                .with_metadata(expr.return_field(input_schema)?.metadata().clone())
303                .into(),
304            );
305        }
306        if !self.is_single() {
307            fields.push(
308                Field::new(
309                    Aggregate::INTERNAL_GROUPING_ID,
310                    Aggregate::grouping_id_type(self.expr.len()),
311                    false,
312                )
313                .into(),
314            );
315        }
316        Ok(fields)
317    }
318
319    /// Returns the output fields of the group by.
320    ///
321    /// This might be different from the `group_fields` that might contain internal expressions that
322    /// should not be part of the output schema.
323    fn output_fields(&self, input_schema: &Schema) -> Result<Vec<FieldRef>> {
324        let mut fields = self.group_fields(input_schema)?;
325        fields.truncate(self.num_output_exprs());
326        Ok(fields)
327    }
328
329    /// Returns the `PhysicalGroupBy` for a final aggregation if `self` is used for a partial
330    /// aggregation.
331    pub fn as_final(&self) -> PhysicalGroupBy {
332        let expr: Vec<_> =
333            self.output_exprs()
334                .into_iter()
335                .zip(
336                    self.expr.iter().map(|t| t.1.clone()).chain(std::iter::once(
337                        Aggregate::INTERNAL_GROUPING_ID.to_owned(),
338                    )),
339                )
340                .collect();
341        let num_exprs = expr.len();
342        let groups = if self.expr.is_empty() {
343            // No GROUP BY expressions - should have no groups
344            vec![]
345        } else {
346            // Has GROUP BY expressions - create a single group
347            vec![vec![false; num_exprs]]
348        };
349        Self {
350            expr,
351            null_expr: vec![],
352            groups,
353        }
354    }
355}
356
357impl PartialEq for PhysicalGroupBy {
358    fn eq(&self, other: &PhysicalGroupBy) -> bool {
359        self.expr.len() == other.expr.len()
360            && self
361                .expr
362                .iter()
363                .zip(other.expr.iter())
364                .all(|((expr1, name1), (expr2, name2))| expr1.eq(expr2) && name1 == name2)
365            && self.null_expr.len() == other.null_expr.len()
366            && self
367                .null_expr
368                .iter()
369                .zip(other.null_expr.iter())
370                .all(|((expr1, name1), (expr2, name2))| expr1.eq(expr2) && name1 == name2)
371            && self.groups == other.groups
372    }
373}
374
375#[allow(clippy::large_enum_variant)]
376enum StreamType {
377    AggregateStream(AggregateStream),
378    GroupedHash(GroupedHashAggregateStream),
379    GroupedPriorityQueue(GroupedTopKAggregateStream),
380}
381
382impl From<StreamType> for SendableRecordBatchStream {
383    fn from(stream: StreamType) -> Self {
384        match stream {
385            StreamType::AggregateStream(stream) => Box::pin(stream),
386            StreamType::GroupedHash(stream) => Box::pin(stream),
387            StreamType::GroupedPriorityQueue(stream) => Box::pin(stream),
388        }
389    }
390}
391
392/// Hash aggregate execution plan
393#[derive(Debug, Clone)]
394pub struct AggregateExec {
395    /// Aggregation mode (full, partial)
396    mode: AggregateMode,
397    /// Group by expressions
398    group_by: PhysicalGroupBy,
399    /// Aggregate expressions
400    aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
401    /// FILTER (WHERE clause) expression for each aggregate expression
402    filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
403    /// Set if the output of this aggregation is truncated by a upstream sort/limit clause
404    limit: Option<usize>,
405    /// Input plan, could be a partial aggregate or the input to the aggregate
406    pub input: Arc<dyn ExecutionPlan>,
407    /// Schema after the aggregate is applied
408    schema: SchemaRef,
409    /// Input schema before any aggregation is applied. For partial aggregate this will be the
410    /// same as input.schema() but for the final aggregate it will be the same as the input
411    /// to the partial aggregate, i.e., partial and final aggregates have same `input_schema`.
412    /// We need the input schema of partial aggregate to be able to deserialize aggregate
413    /// expressions from protobuf for final aggregate.
414    pub input_schema: SchemaRef,
415    /// Execution metrics
416    metrics: ExecutionPlanMetricsSet,
417    required_input_ordering: Option<OrderingRequirements>,
418    /// Describes how the input is ordered relative to the group by columns
419    input_order_mode: InputOrderMode,
420    cache: PlanProperties,
421}
422
423impl AggregateExec {
424    /// Function used in `OptimizeAggregateOrder` optimizer rule,
425    /// where we need parts of the new value, others cloned from the old one
426    /// Rewrites aggregate exec with new aggregate expressions.
427    pub fn with_new_aggr_exprs(
428        &self,
429        aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
430    ) -> Self {
431        Self {
432            aggr_expr,
433            // clone the rest of the fields
434            required_input_ordering: self.required_input_ordering.clone(),
435            metrics: ExecutionPlanMetricsSet::new(),
436            input_order_mode: self.input_order_mode.clone(),
437            cache: self.cache.clone(),
438            mode: self.mode,
439            group_by: self.group_by.clone(),
440            filter_expr: self.filter_expr.clone(),
441            limit: self.limit,
442            input: Arc::clone(&self.input),
443            schema: Arc::clone(&self.schema),
444            input_schema: Arc::clone(&self.input_schema),
445        }
446    }
447
448    pub fn cache(&self) -> &PlanProperties {
449        &self.cache
450    }
451
452    /// Create a new hash aggregate execution plan
453    pub fn try_new(
454        mode: AggregateMode,
455        group_by: PhysicalGroupBy,
456        aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
457        filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
458        input: Arc<dyn ExecutionPlan>,
459        input_schema: SchemaRef,
460    ) -> Result<Self> {
461        let schema = create_schema(&input.schema(), &group_by, &aggr_expr, mode)?;
462
463        let schema = Arc::new(schema);
464        AggregateExec::try_new_with_schema(
465            mode,
466            group_by,
467            aggr_expr,
468            filter_expr,
469            input,
470            input_schema,
471            schema,
472        )
473    }
474
475    /// Create a new hash aggregate execution plan with the given schema.
476    /// This constructor isn't part of the public API, it is used internally
477    /// by DataFusion to enforce schema consistency during when re-creating
478    /// `AggregateExec`s inside optimization rules. Schema field names of an
479    /// `AggregateExec` depends on the names of aggregate expressions. Since
480    /// a rule may re-write aggregate expressions (e.g. reverse them) during
481    /// initialization, field names may change inadvertently if one re-creates
482    /// the schema in such cases.
483    #[allow(clippy::too_many_arguments)]
484    fn try_new_with_schema(
485        mode: AggregateMode,
486        group_by: PhysicalGroupBy,
487        mut aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
488        filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
489        input: Arc<dyn ExecutionPlan>,
490        input_schema: SchemaRef,
491        schema: SchemaRef,
492    ) -> Result<Self> {
493        // Make sure arguments are consistent in size
494        if aggr_expr.len() != filter_expr.len() {
495            return internal_err!("Inconsistent aggregate expr: {:?} and filter expr: {:?} for AggregateExec, their size should match", aggr_expr, filter_expr);
496        }
497
498        let input_eq_properties = input.equivalence_properties();
499        // Get GROUP BY expressions:
500        let groupby_exprs = group_by.input_exprs();
501        // If existing ordering satisfies a prefix of the GROUP BY expressions,
502        // prefix requirements with this section. In this case, aggregation will
503        // work more efficiently.
504        let indices = get_ordered_partition_by_indices(&groupby_exprs, &input)?;
505        let mut new_requirements = indices
506            .iter()
507            .map(|&idx| {
508                PhysicalSortRequirement::new(Arc::clone(&groupby_exprs[idx]), None)
509            })
510            .collect::<Vec<_>>();
511
512        let req = get_finer_aggregate_exprs_requirement(
513            &mut aggr_expr,
514            &group_by,
515            input_eq_properties,
516            &mode,
517        )?;
518        new_requirements.extend(req);
519
520        let required_input_ordering =
521            LexRequirement::new(new_requirements).map(OrderingRequirements::new_soft);
522
523        // If our aggregation has grouping sets then our base grouping exprs will
524        // be expanded based on the flags in `group_by.groups` where for each
525        // group we swap the grouping expr for `null` if the flag is `true`
526        // That means that each index in `indices` is valid if and only if
527        // it is not null in every group
528        let indices: Vec<usize> = indices
529            .into_iter()
530            .filter(|idx| group_by.groups.iter().all(|group| !group[*idx]))
531            .collect();
532
533        let input_order_mode = if indices.len() == groupby_exprs.len()
534            && !indices.is_empty()
535            && group_by.groups.len() == 1
536        {
537            InputOrderMode::Sorted
538        } else if !indices.is_empty() {
539            InputOrderMode::PartiallySorted(indices)
540        } else {
541            InputOrderMode::Linear
542        };
543
544        // construct a map from the input expression to the output expression of the Aggregation group by
545        let group_expr_mapping =
546            ProjectionMapping::try_new(group_by.expr.clone(), &input.schema())?;
547
548        let cache = Self::compute_properties(
549            &input,
550            Arc::clone(&schema),
551            &group_expr_mapping,
552            &mode,
553            &input_order_mode,
554            aggr_expr.as_slice(),
555        )?;
556
557        Ok(AggregateExec {
558            mode,
559            group_by,
560            aggr_expr,
561            filter_expr,
562            input,
563            schema,
564            input_schema,
565            metrics: ExecutionPlanMetricsSet::new(),
566            required_input_ordering,
567            limit: None,
568            input_order_mode,
569            cache,
570        })
571    }
572
573    /// Aggregation mode (full, partial)
574    pub fn mode(&self) -> &AggregateMode {
575        &self.mode
576    }
577
578    /// Set the `limit` of this AggExec
579    pub fn with_limit(mut self, limit: Option<usize>) -> Self {
580        self.limit = limit;
581        self
582    }
583    /// Grouping expressions
584    pub fn group_expr(&self) -> &PhysicalGroupBy {
585        &self.group_by
586    }
587
588    /// Grouping expressions as they occur in the output schema
589    pub fn output_group_expr(&self) -> Vec<Arc<dyn PhysicalExpr>> {
590        self.group_by.output_exprs()
591    }
592
593    /// Aggregate expressions
594    pub fn aggr_expr(&self) -> &[Arc<AggregateFunctionExpr>] {
595        &self.aggr_expr
596    }
597
598    /// FILTER (WHERE clause) expression for each aggregate expression
599    pub fn filter_expr(&self) -> &[Option<Arc<dyn PhysicalExpr>>] {
600        &self.filter_expr
601    }
602
603    /// Input plan
604    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
605        &self.input
606    }
607
608    /// Get the input schema before any aggregates are applied
609    pub fn input_schema(&self) -> SchemaRef {
610        Arc::clone(&self.input_schema)
611    }
612
613    /// number of rows soft limit of the AggregateExec
614    pub fn limit(&self) -> Option<usize> {
615        self.limit
616    }
617
618    fn execute_typed(
619        &self,
620        partition: usize,
621        context: Arc<TaskContext>,
622    ) -> Result<StreamType> {
623        // no group by at all
624        if self.group_by.expr.is_empty() {
625            return Ok(StreamType::AggregateStream(AggregateStream::new(
626                self, context, partition,
627            )?));
628        }
629
630        // grouping by an expression that has a sort/limit upstream
631        if let Some(limit) = self.limit {
632            if !self.is_unordered_unfiltered_group_by_distinct() {
633                return Ok(StreamType::GroupedPriorityQueue(
634                    GroupedTopKAggregateStream::new(self, context, partition, limit)?,
635                ));
636            }
637        }
638
639        // grouping by something else and we need to just materialize all results
640        Ok(StreamType::GroupedHash(GroupedHashAggregateStream::new(
641            self, context, partition,
642        )?))
643    }
644
645    /// Finds the DataType and SortDirection for this Aggregate, if there is one
646    pub fn get_minmax_desc(&self) -> Option<(FieldRef, bool)> {
647        let agg_expr = self.aggr_expr.iter().exactly_one().ok()?;
648        agg_expr.get_minmax_desc()
649    }
650
651    /// true, if this Aggregate has a group-by with no required or explicit ordering,
652    /// no filtering and no aggregate expressions
653    /// This method qualifies the use of the LimitedDistinctAggregation rewrite rule
654    /// on an AggregateExec.
655    pub fn is_unordered_unfiltered_group_by_distinct(&self) -> bool {
656        // ensure there is a group by
657        if self.group_expr().is_empty() {
658            return false;
659        }
660        // ensure there are no aggregate expressions
661        if !self.aggr_expr().is_empty() {
662            return false;
663        }
664        // ensure there are no filters on aggregate expressions; the above check
665        // may preclude this case
666        if self.filter_expr().iter().any(|e| e.is_some()) {
667            return false;
668        }
669        // ensure there are no order by expressions
670        if !self.aggr_expr().iter().all(|e| e.order_bys().is_empty()) {
671            return false;
672        }
673        // ensure there is no output ordering; can this rule be relaxed?
674        if self.properties().output_ordering().is_some() {
675            return false;
676        }
677        // ensure no ordering is required on the input
678        if let Some(requirement) = self.required_input_ordering().swap_remove(0) {
679            return matches!(requirement, OrderingRequirements::Hard(_));
680        }
681        true
682    }
683
684    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
685    pub fn compute_properties(
686        input: &Arc<dyn ExecutionPlan>,
687        schema: SchemaRef,
688        group_expr_mapping: &ProjectionMapping,
689        mode: &AggregateMode,
690        input_order_mode: &InputOrderMode,
691        aggr_exprs: &[Arc<AggregateFunctionExpr>],
692    ) -> Result<PlanProperties> {
693        // Construct equivalence properties:
694        let mut eq_properties = input
695            .equivalence_properties()
696            .project(group_expr_mapping, schema);
697
698        // If the group by is empty, then we ensure that the operator will produce
699        // only one row, and mark the generated result as a constant value.
700        if group_expr_mapping.is_empty() {
701            let new_constants = aggr_exprs.iter().enumerate().map(|(idx, func)| {
702                let column = Arc::new(Column::new(func.name(), idx));
703                ConstExpr::from(column as Arc<dyn PhysicalExpr>)
704            });
705            eq_properties.add_constants(new_constants)?;
706        }
707
708        // Group by expression will be a distinct value after the aggregation.
709        // Add it into the constraint set.
710        let mut constraints = eq_properties.constraints().to_vec();
711        let new_constraint = Constraint::Unique(
712            group_expr_mapping
713                .iter()
714                .flat_map(|(_, target_cols)| {
715                    target_cols.iter().flat_map(|(expr, _)| {
716                        expr.as_any().downcast_ref::<Column>().map(|c| c.index())
717                    })
718                })
719                .collect(),
720        );
721        constraints.push(new_constraint);
722        eq_properties =
723            eq_properties.with_constraints(Constraints::new_unverified(constraints));
724
725        // Get output partitioning:
726        let input_partitioning = input.output_partitioning().clone();
727        let output_partitioning = if mode.is_first_stage() {
728            // First stage aggregation will not change the output partitioning,
729            // but needs to respect aliases (e.g. mapping in the GROUP BY
730            // expression).
731            let input_eq_properties = input.equivalence_properties();
732            input_partitioning.project(group_expr_mapping, input_eq_properties)
733        } else {
734            input_partitioning.clone()
735        };
736
737        // TODO: Emission type and boundedness information can be enhanced here
738        let emission_type = if *input_order_mode == InputOrderMode::Linear {
739            EmissionType::Final
740        } else {
741            input.pipeline_behavior()
742        };
743
744        Ok(PlanProperties::new(
745            eq_properties,
746            output_partitioning,
747            emission_type,
748            input.boundedness(),
749        ))
750    }
751
752    pub fn input_order_mode(&self) -> &InputOrderMode {
753        &self.input_order_mode
754    }
755
756    fn statistics_inner(&self, child_statistics: Statistics) -> Result<Statistics> {
757        // TODO stats: group expressions:
758        // - once expressions will be able to compute their own stats, use it here
759        // - case where we group by on a column for which with have the `distinct` stat
760        // TODO stats: aggr expression:
761        // - aggregations sometimes also preserve invariants such as min, max...
762
763        let column_statistics = {
764            // self.schema: [<group by exprs>, <aggregate exprs>]
765            let mut column_statistics = Statistics::unknown_column(&self.schema());
766
767            for (idx, (expr, _)) in self.group_by.expr.iter().enumerate() {
768                if let Some(col) = expr.as_any().downcast_ref::<Column>() {
769                    column_statistics[idx].max_value = child_statistics.column_statistics
770                        [col.index()]
771                    .max_value
772                    .clone();
773
774                    column_statistics[idx].min_value = child_statistics.column_statistics
775                        [col.index()]
776                    .min_value
777                    .clone();
778                }
779            }
780
781            column_statistics
782        };
783        match self.mode {
784            AggregateMode::Final | AggregateMode::FinalPartitioned
785                if self.group_by.expr.is_empty() =>
786            {
787                Ok(Statistics {
788                    num_rows: Precision::Exact(1),
789                    column_statistics,
790                    total_byte_size: Precision::Absent,
791                })
792            }
793            _ => {
794                // When the input row count is 1, we can adopt that statistic keeping its reliability.
795                // When it is larger than 1, we degrade the precision since it may decrease after aggregation.
796                let num_rows = if let Some(value) = child_statistics.num_rows.get_value()
797                {
798                    if *value > 1 {
799                        child_statistics.num_rows.to_inexact()
800                    } else if *value == 0 {
801                        child_statistics.num_rows
802                    } else {
803                        // num_rows = 1 case
804                        let grouping_set_num = self.group_by.groups.len();
805                        child_statistics.num_rows.map(|x| x * grouping_set_num)
806                    }
807                } else {
808                    Precision::Absent
809                };
810                Ok(Statistics {
811                    num_rows,
812                    column_statistics,
813                    total_byte_size: Precision::Absent,
814                })
815            }
816        }
817    }
818}
819
820impl DisplayAs for AggregateExec {
821    fn fmt_as(
822        &self,
823        t: DisplayFormatType,
824        f: &mut std::fmt::Formatter,
825    ) -> std::fmt::Result {
826        match t {
827            DisplayFormatType::Default | DisplayFormatType::Verbose => {
828                let format_expr_with_alias =
829                    |(e, alias): &(Arc<dyn PhysicalExpr>, String)| -> String {
830                        let e = e.to_string();
831                        if &e != alias {
832                            format!("{e} as {alias}")
833                        } else {
834                            e
835                        }
836                    };
837
838                write!(f, "AggregateExec: mode={:?}", self.mode)?;
839                let g: Vec<String> = if self.group_by.is_single() {
840                    self.group_by
841                        .expr
842                        .iter()
843                        .map(format_expr_with_alias)
844                        .collect()
845                } else {
846                    self.group_by
847                        .groups
848                        .iter()
849                        .map(|group| {
850                            let terms = group
851                                .iter()
852                                .enumerate()
853                                .map(|(idx, is_null)| {
854                                    if *is_null {
855                                        format_expr_with_alias(
856                                            &self.group_by.null_expr[idx],
857                                        )
858                                    } else {
859                                        format_expr_with_alias(&self.group_by.expr[idx])
860                                    }
861                                })
862                                .collect::<Vec<String>>()
863                                .join(", ");
864                            format!("({terms})")
865                        })
866                        .collect()
867                };
868
869                write!(f, ", gby=[{}]", g.join(", "))?;
870
871                let a: Vec<String> = self
872                    .aggr_expr
873                    .iter()
874                    .map(|agg| agg.name().to_string())
875                    .collect();
876                write!(f, ", aggr=[{}]", a.join(", "))?;
877                if let Some(limit) = self.limit {
878                    write!(f, ", lim=[{limit}]")?;
879                }
880
881                if self.input_order_mode != InputOrderMode::Linear {
882                    write!(f, ", ordering_mode={:?}", self.input_order_mode)?;
883                }
884            }
885            DisplayFormatType::TreeRender => {
886                let format_expr_with_alias =
887                    |(e, alias): &(Arc<dyn PhysicalExpr>, String)| -> String {
888                        let expr_sql = fmt_sql(e.as_ref()).to_string();
889                        if &expr_sql != alias {
890                            format!("{expr_sql} as {alias}")
891                        } else {
892                            expr_sql
893                        }
894                    };
895
896                let g: Vec<String> = if self.group_by.is_single() {
897                    self.group_by
898                        .expr
899                        .iter()
900                        .map(format_expr_with_alias)
901                        .collect()
902                } else {
903                    self.group_by
904                        .groups
905                        .iter()
906                        .map(|group| {
907                            let terms = group
908                                .iter()
909                                .enumerate()
910                                .map(|(idx, is_null)| {
911                                    if *is_null {
912                                        format_expr_with_alias(
913                                            &self.group_by.null_expr[idx],
914                                        )
915                                    } else {
916                                        format_expr_with_alias(&self.group_by.expr[idx])
917                                    }
918                                })
919                                .collect::<Vec<String>>()
920                                .join(", ");
921                            format!("({terms})")
922                        })
923                        .collect()
924                };
925                let a: Vec<String> = self
926                    .aggr_expr
927                    .iter()
928                    .map(|agg| agg.human_display().to_string())
929                    .collect();
930                writeln!(f, "mode={:?}", self.mode)?;
931                if !g.is_empty() {
932                    writeln!(f, "group_by={}", g.join(", "))?;
933                }
934                if !a.is_empty() {
935                    writeln!(f, "aggr={}", a.join(", "))?;
936                }
937            }
938        }
939        Ok(())
940    }
941}
942
943impl ExecutionPlan for AggregateExec {
944    fn name(&self) -> &'static str {
945        "AggregateExec"
946    }
947
948    /// Return a reference to Any that can be used for down-casting
949    fn as_any(&self) -> &dyn Any {
950        self
951    }
952
953    fn properties(&self) -> &PlanProperties {
954        &self.cache
955    }
956
957    fn required_input_distribution(&self) -> Vec<Distribution> {
958        match &self.mode {
959            AggregateMode::Partial => {
960                vec![Distribution::UnspecifiedDistribution]
961            }
962            AggregateMode::FinalPartitioned | AggregateMode::SinglePartitioned => {
963                vec![Distribution::HashPartitioned(self.group_by.input_exprs())]
964            }
965            AggregateMode::Final | AggregateMode::Single => {
966                vec![Distribution::SinglePartition]
967            }
968        }
969    }
970
971    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
972        vec![self.required_input_ordering.clone()]
973    }
974
975    /// The output ordering of [`AggregateExec`] is determined by its `group_by`
976    /// columns. Although this method is not explicitly used by any optimizer
977    /// rules yet, overriding the default implementation ensures that it
978    /// accurately reflects the actual behavior.
979    ///
980    /// If the [`InputOrderMode`] is `Linear`, the `group_by` columns don't have
981    /// an ordering, which means the results do not either. However, in the
982    /// `Ordered` and `PartiallyOrdered` cases, the `group_by` columns do have
983    /// an ordering, which is preserved in the output.
984    fn maintains_input_order(&self) -> Vec<bool> {
985        vec![self.input_order_mode != InputOrderMode::Linear]
986    }
987
988    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
989        vec![&self.input]
990    }
991
992    fn with_new_children(
993        self: Arc<Self>,
994        children: Vec<Arc<dyn ExecutionPlan>>,
995    ) -> Result<Arc<dyn ExecutionPlan>> {
996        let mut me = AggregateExec::try_new_with_schema(
997            self.mode,
998            self.group_by.clone(),
999            self.aggr_expr.clone(),
1000            self.filter_expr.clone(),
1001            Arc::clone(&children[0]),
1002            Arc::clone(&self.input_schema),
1003            Arc::clone(&self.schema),
1004        )?;
1005        me.limit = self.limit;
1006
1007        Ok(Arc::new(me))
1008    }
1009
1010    fn execute(
1011        &self,
1012        partition: usize,
1013        context: Arc<TaskContext>,
1014    ) -> Result<SendableRecordBatchStream> {
1015        self.execute_typed(partition, context)
1016            .map(|stream| stream.into())
1017    }
1018
1019    fn metrics(&self) -> Option<MetricsSet> {
1020        Some(self.metrics.clone_inner())
1021    }
1022
1023    fn statistics(&self) -> Result<Statistics> {
1024        self.partition_statistics(None)
1025    }
1026
1027    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
1028        self.statistics_inner(self.input().partition_statistics(partition)?)
1029    }
1030
1031    fn cardinality_effect(&self) -> CardinalityEffect {
1032        CardinalityEffect::LowerEqual
1033    }
1034
1035    /// Push down parent filters when possible (see implementation comment for details),
1036    /// but do not introduce any new self filters.
1037    fn gather_filters_for_pushdown(
1038        &self,
1039        _phase: FilterPushdownPhase,
1040        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
1041        _config: &ConfigOptions,
1042    ) -> Result<FilterDescription> {
1043        // It's safe to push down filters through aggregates when filters only reference
1044        // grouping columns, because such filters determine which groups to compute, not
1045        // *how* to compute them. Each group's aggregate values (SUM, COUNT, etc.) are
1046        // calculated from the same input rows regardless of whether we filter before or
1047        // after grouping - filtering before just eliminates entire groups early.
1048        // This optimization is NOT safe for filters on aggregated columns (like filtering on
1049        // the result of SUM or COUNT), as those require computing all groups first.
1050
1051        let grouping_columns: HashSet<_> = self
1052            .group_by
1053            .expr()
1054            .iter()
1055            .flat_map(|(expr, _)| collect_columns(expr))
1056            .collect();
1057
1058        // Analyze each filter separately to determine if it can be pushed down
1059        let mut safe_filters = Vec::new();
1060        let mut unsafe_filters = Vec::new();
1061
1062        for filter in parent_filters {
1063            let filter_columns: HashSet<_> =
1064                collect_columns(&filter).into_iter().collect();
1065
1066            // Check if this filter references non-grouping columns
1067            let references_non_grouping = !grouping_columns.is_empty()
1068                && !filter_columns.is_subset(&grouping_columns);
1069
1070            if references_non_grouping {
1071                unsafe_filters.push(filter);
1072                continue;
1073            }
1074
1075            // For GROUPING SETS, verify this filter's columns appear in all grouping sets
1076            if self.group_by.groups().len() > 1 {
1077                let filter_column_indices: Vec<usize> = filter_columns
1078                    .iter()
1079                    .filter_map(|filter_col| {
1080                        self.group_by.expr().iter().position(|(expr, _)| {
1081                            collect_columns(expr).contains(filter_col)
1082                        })
1083                    })
1084                    .collect();
1085
1086                // Check if any of this filter's columns are missing from any grouping set
1087                let has_missing_column = self.group_by.groups().iter().any(|null_mask| {
1088                    filter_column_indices
1089                        .iter()
1090                        .any(|&idx| null_mask.get(idx) == Some(&true))
1091                });
1092
1093                if has_missing_column {
1094                    unsafe_filters.push(filter);
1095                    continue;
1096                }
1097            }
1098
1099            // This filter is safe to push down
1100            safe_filters.push(filter);
1101        }
1102
1103        // Build child filter description with both safe and unsafe filters
1104        let child = self.children()[0];
1105        let mut child_desc = ChildFilterDescription::from_child(&safe_filters, child)?;
1106
1107        // Add unsafe filters as unsupported
1108        child_desc.parent_filters.extend(
1109            unsafe_filters
1110                .into_iter()
1111                .map(PushedDownPredicate::unsupported),
1112        );
1113
1114        Ok(FilterDescription::new().with_child(child_desc))
1115    }
1116}
1117
1118fn create_schema(
1119    input_schema: &Schema,
1120    group_by: &PhysicalGroupBy,
1121    aggr_expr: &[Arc<AggregateFunctionExpr>],
1122    mode: AggregateMode,
1123) -> Result<Schema> {
1124    let mut fields = Vec::with_capacity(group_by.num_output_exprs() + aggr_expr.len());
1125    fields.extend(group_by.output_fields(input_schema)?);
1126
1127    match mode {
1128        AggregateMode::Partial => {
1129            // in partial mode, the fields of the accumulator's state
1130            for expr in aggr_expr {
1131                fields.extend(expr.state_fields()?.iter().cloned());
1132            }
1133        }
1134        AggregateMode::Final
1135        | AggregateMode::FinalPartitioned
1136        | AggregateMode::Single
1137        | AggregateMode::SinglePartitioned => {
1138            // in final mode, the field with the final result of the accumulator
1139            for expr in aggr_expr {
1140                fields.push(expr.field())
1141            }
1142        }
1143    }
1144
1145    Ok(Schema::new_with_metadata(
1146        fields,
1147        input_schema.metadata().clone(),
1148    ))
1149}
1150
1151/// Determines the lexical ordering requirement for an aggregate expression.
1152///
1153/// # Parameters
1154///
1155/// - `aggr_expr`: A reference to an `AggregateFunctionExpr` representing the
1156///   aggregate expression.
1157/// - `group_by`: A reference to a `PhysicalGroupBy` instance representing the
1158///   physical GROUP BY expression.
1159/// - `agg_mode`: A reference to an `AggregateMode` instance representing the
1160///   mode of aggregation.
1161/// - `include_soft_requirement`: When `false`, only hard requirements are
1162///   considered, as indicated by [`AggregateFunctionExpr::order_sensitivity`]
1163///   returning [`AggregateOrderSensitivity::HardRequirement`].
1164///   Otherwise, also soft requirements ([`AggregateOrderSensitivity::SoftRequirement`])
1165///   are considered.
1166///
1167/// # Returns
1168///
1169/// A `LexOrdering` instance indicating the lexical ordering requirement for
1170/// the aggregate expression.
1171fn get_aggregate_expr_req(
1172    aggr_expr: &AggregateFunctionExpr,
1173    group_by: &PhysicalGroupBy,
1174    agg_mode: &AggregateMode,
1175    include_soft_requirement: bool,
1176) -> Option<LexOrdering> {
1177    // If the aggregation is performing a "second stage" calculation,
1178    // then ignore the ordering requirement. Ordering requirement applies
1179    // only to the aggregation input data.
1180    if !agg_mode.is_first_stage() {
1181        return None;
1182    }
1183
1184    match aggr_expr.order_sensitivity() {
1185        AggregateOrderSensitivity::Insensitive => return None,
1186        AggregateOrderSensitivity::HardRequirement => {}
1187        AggregateOrderSensitivity::SoftRequirement => {
1188            if !include_soft_requirement {
1189                return None;
1190            }
1191        }
1192        AggregateOrderSensitivity::Beneficial => return None,
1193    }
1194
1195    let mut sort_exprs = aggr_expr.order_bys().to_vec();
1196    // In non-first stage modes, we accumulate data (using `merge_batch`) from
1197    // different partitions (i.e. merge partial results). During this merge, we
1198    // consider the ordering of each partial result. Hence, we do not need to
1199    // use the ordering requirement in such modes as long as partial results are
1200    // generated with the correct ordering.
1201    if group_by.is_single() {
1202        // Remove all orderings that occur in the group by. These requirements
1203        // will definitely be satisfied -- Each group by expression will have
1204        // distinct values per group, hence all requirements are satisfied.
1205        let physical_exprs = group_by.input_exprs();
1206        sort_exprs.retain(|sort_expr| {
1207            !physical_exprs_contains(&physical_exprs, &sort_expr.expr)
1208        });
1209    }
1210    LexOrdering::new(sort_exprs)
1211}
1212
/// Returns a new `Vec` holding clones of all elements of `lhs`, followed by
/// clones of all elements of `rhs`.
pub fn concat_slices<T: Clone>(lhs: &[T], rhs: &[T]) -> Vec<T> {
    let mut joined = Vec::with_capacity(lhs.len() + rhs.len());
    joined.extend_from_slice(lhs);
    joined.extend_from_slice(rhs);
    joined
}
1217
1218// Determines if the candidate ordering is finer than the current ordering.
1219// Returns `None` if they are incomparable, `Some(true)` if there is no current
1220// ordering or candidate ordering is finer, and `Some(false)` otherwise.
1221fn determine_finer(
1222    current: &Option<LexOrdering>,
1223    candidate: &LexOrdering,
1224) -> Option<bool> {
1225    if let Some(ordering) = current {
1226        candidate.partial_cmp(ordering).map(|cmp| cmp.is_gt())
1227    } else {
1228        Some(true)
1229    }
1230}
1231
1232/// Gets the common requirement that satisfies all the aggregate expressions.
1233/// When possible, chooses the requirement that is already satisfied by the
1234/// equivalence properties.
1235///
1236/// # Parameters
1237///
1238/// - `aggr_exprs`: A slice of `AggregateFunctionExpr` containing all the
1239///   aggregate expressions.
1240/// - `group_by`: A reference to a `PhysicalGroupBy` instance representing the
1241///   physical GROUP BY expression.
1242/// - `eq_properties`: A reference to an `EquivalenceProperties` instance
1243///   representing equivalence properties for ordering.
1244/// - `agg_mode`: A reference to an `AggregateMode` instance representing the
1245///   mode of aggregation.
1246///
1247/// # Returns
1248///
1249/// A `Result<Vec<PhysicalSortRequirement>>` instance, which is the requirement
1250/// that satisfies all the aggregate requirements. Returns an error in case of
1251/// conflicting requirements.
1252pub fn get_finer_aggregate_exprs_requirement(
1253    aggr_exprs: &mut [Arc<AggregateFunctionExpr>],
1254    group_by: &PhysicalGroupBy,
1255    eq_properties: &EquivalenceProperties,
1256    agg_mode: &AggregateMode,
1257) -> Result<Vec<PhysicalSortRequirement>> {
1258    let mut requirement = None;
1259
1260    // First try and find a match for all hard and soft requirements.
1261    // If a match can't be found, try a second time just matching hard
1262    // requirements.
1263    for include_soft_requirement in [false, true] {
1264        for aggr_expr in aggr_exprs.iter_mut() {
1265            let Some(aggr_req) = get_aggregate_expr_req(
1266                aggr_expr,
1267                group_by,
1268                agg_mode,
1269                include_soft_requirement,
1270            )
1271            .and_then(|o| eq_properties.normalize_sort_exprs(o)) else {
1272                // There is no aggregate ordering requirement, or it is trivially
1273                // satisfied -- we can skip this expression.
1274                continue;
1275            };
1276            // If the common requirement is finer than the current expression's,
1277            // we can skip this expression. If the latter is finer than the former,
1278            // adopt it if it is satisfied by the equivalence properties. Otherwise,
1279            // defer the analysis to the reverse expression.
1280            let forward_finer = determine_finer(&requirement, &aggr_req);
1281            if let Some(finer) = forward_finer {
1282                if !finer {
1283                    continue;
1284                } else if eq_properties.ordering_satisfy(aggr_req.clone())? {
1285                    requirement = Some(aggr_req);
1286                    continue;
1287                }
1288            }
1289            if let Some(reverse_aggr_expr) = aggr_expr.reverse_expr() {
1290                let Some(rev_aggr_req) = get_aggregate_expr_req(
1291                    &reverse_aggr_expr,
1292                    group_by,
1293                    agg_mode,
1294                    include_soft_requirement,
1295                )
1296                .and_then(|o| eq_properties.normalize_sort_exprs(o)) else {
1297                    // The reverse requirement is trivially satisfied -- just reverse
1298                    // the expression and continue with the next one:
1299                    *aggr_expr = Arc::new(reverse_aggr_expr);
1300                    continue;
1301                };
1302                // If the common requirement is finer than the reverse expression's,
1303                // just reverse it and continue the loop with the next aggregate
1304                // expression. If the latter is finer than the former, adopt it if
1305                // it is satisfied by the equivalence properties. Otherwise, adopt
1306                // the forward expression.
1307                if let Some(finer) = determine_finer(&requirement, &rev_aggr_req) {
1308                    if !finer {
1309                        *aggr_expr = Arc::new(reverse_aggr_expr);
1310                    } else if eq_properties.ordering_satisfy(rev_aggr_req.clone())? {
1311                        *aggr_expr = Arc::new(reverse_aggr_expr);
1312                        requirement = Some(rev_aggr_req);
1313                    } else {
1314                        requirement = Some(aggr_req);
1315                    }
1316                } else if forward_finer.is_some() {
1317                    requirement = Some(aggr_req);
1318                } else {
1319                    // Neither the existing requirement nor the current aggregate
1320                    // requirement satisfy the other (forward or reverse), this
1321                    // means they are conflicting. This is a problem only for hard
1322                    // requirements. Unsatisfied soft requirements can be ignored.
1323                    if !include_soft_requirement {
1324                        return not_impl_err!(
1325                            "Conflicting ordering requirements in aggregate functions is not supported"
1326                        );
1327                    }
1328                }
1329            }
1330        }
1331    }
1332
1333    Ok(requirement.map_or_else(Vec::new, |o| o.into_iter().map(Into::into).collect()))
1334}
1335
1336/// Returns physical expressions for arguments to evaluate against a batch.
1337///
1338/// The expressions are different depending on `mode`:
1339/// * Partial: AggregateFunctionExpr::expressions
1340/// * Final: columns of `AggregateFunctionExpr::state_fields()`
1341pub fn aggregate_expressions(
1342    aggr_expr: &[Arc<AggregateFunctionExpr>],
1343    mode: &AggregateMode,
1344    col_idx_base: usize,
1345) -> Result<Vec<Vec<Arc<dyn PhysicalExpr>>>> {
1346    match mode {
1347        AggregateMode::Partial
1348        | AggregateMode::Single
1349        | AggregateMode::SinglePartitioned => Ok(aggr_expr
1350            .iter()
1351            .map(|agg| {
1352                let mut result = agg.expressions();
1353                // Append ordering requirements to expressions' results. This
1354                // way order sensitive aggregators can satisfy requirement
1355                // themselves.
1356                result.extend(agg.order_bys().iter().map(|item| Arc::clone(&item.expr)));
1357                result
1358            })
1359            .collect()),
1360        // In this mode, we build the merge expressions of the aggregation.
1361        AggregateMode::Final | AggregateMode::FinalPartitioned => {
1362            let mut col_idx_base = col_idx_base;
1363            aggr_expr
1364                .iter()
1365                .map(|agg| {
1366                    let exprs = merge_expressions(col_idx_base, agg)?;
1367                    col_idx_base += exprs.len();
1368                    Ok(exprs)
1369                })
1370                .collect()
1371        }
1372    }
1373}
1374
1375/// uses `state_fields` to build a vec of physical column expressions required to merge the
1376/// AggregateFunctionExpr' accumulator's state.
1377///
1378/// `index_base` is the starting physical column index for the next expanded state field.
1379fn merge_expressions(
1380    index_base: usize,
1381    expr: &AggregateFunctionExpr,
1382) -> Result<Vec<Arc<dyn PhysicalExpr>>> {
1383    expr.state_fields().map(|fields| {
1384        fields
1385            .iter()
1386            .enumerate()
1387            .map(|(idx, f)| Arc::new(Column::new(f.name(), index_base + idx)) as _)
1388            .collect()
1389    })
1390}
1391
/// A heap-allocated, dynamically dispatched [`Accumulator`], as produced by
/// `create_accumulators` and consumed by `finalize_aggregation`.
pub type AccumulatorItem = Box<dyn Accumulator>;
1393
1394pub fn create_accumulators(
1395    aggr_expr: &[Arc<AggregateFunctionExpr>],
1396) -> Result<Vec<AccumulatorItem>> {
1397    aggr_expr
1398        .iter()
1399        .map(|expr| expr.create_accumulator())
1400        .collect()
1401}
1402
1403/// returns a vector of ArrayRefs, where each entry corresponds to either the
1404/// final value (mode = Final, FinalPartitioned and Single) or states (mode = Partial)
1405pub fn finalize_aggregation(
1406    accumulators: &mut [AccumulatorItem],
1407    mode: &AggregateMode,
1408) -> Result<Vec<ArrayRef>> {
1409    match mode {
1410        AggregateMode::Partial => {
1411            // Build the vector of states
1412            accumulators
1413                .iter_mut()
1414                .map(|accumulator| {
1415                    accumulator.state().and_then(|e| {
1416                        e.iter()
1417                            .map(|v| v.to_array())
1418                            .collect::<Result<Vec<ArrayRef>>>()
1419                    })
1420                })
1421                .flatten_ok()
1422                .collect()
1423        }
1424        AggregateMode::Final
1425        | AggregateMode::FinalPartitioned
1426        | AggregateMode::Single
1427        | AggregateMode::SinglePartitioned => {
1428            // Merge the state to the final value
1429            accumulators
1430                .iter_mut()
1431                .map(|accumulator| accumulator.evaluate().and_then(|v| v.to_array()))
1432                .collect()
1433        }
1434    }
1435}
1436
1437/// Evaluates expressions against a record batch.
1438fn evaluate(
1439    expr: &[Arc<dyn PhysicalExpr>],
1440    batch: &RecordBatch,
1441) -> Result<Vec<ArrayRef>> {
1442    expr.iter()
1443        .map(|expr| {
1444            expr.evaluate(batch)
1445                .and_then(|v| v.into_array(batch.num_rows()))
1446        })
1447        .collect()
1448}
1449
1450/// Evaluates expressions against a record batch.
1451pub fn evaluate_many(
1452    expr: &[Vec<Arc<dyn PhysicalExpr>>],
1453    batch: &RecordBatch,
1454) -> Result<Vec<Vec<ArrayRef>>> {
1455    expr.iter().map(|expr| evaluate(expr, batch)).collect()
1456}
1457
1458fn evaluate_optional(
1459    expr: &[Option<Arc<dyn PhysicalExpr>>],
1460    batch: &RecordBatch,
1461) -> Result<Vec<Option<ArrayRef>>> {
1462    expr.iter()
1463        .map(|expr| {
1464            expr.as_ref()
1465                .map(|expr| {
1466                    expr.evaluate(batch)
1467                        .and_then(|v| v.into_array(batch.num_rows()))
1468                })
1469                .transpose()
1470        })
1471        .collect()
1472}
1473
1474fn group_id_array(group: &[bool], batch: &RecordBatch) -> Result<ArrayRef> {
1475    if group.len() > 64 {
1476        return not_impl_err!(
1477            "Grouping sets with more than 64 columns are not supported"
1478        );
1479    }
1480    let group_id = group.iter().fold(0u64, |acc, &is_null| {
1481        (acc << 1) | if is_null { 1 } else { 0 }
1482    });
1483    let num_rows = batch.num_rows();
1484    if group.len() <= 8 {
1485        Ok(Arc::new(UInt8Array::from(vec![group_id as u8; num_rows])))
1486    } else if group.len() <= 16 {
1487        Ok(Arc::new(UInt16Array::from(vec![group_id as u16; num_rows])))
1488    } else if group.len() <= 32 {
1489        Ok(Arc::new(UInt32Array::from(vec![group_id as u32; num_rows])))
1490    } else {
1491        Ok(Arc::new(UInt64Array::from(vec![group_id; num_rows])))
1492    }
1493}
1494
1495/// Evaluate a group by expression against a `RecordBatch`
1496///
1497/// Arguments:
1498/// - `group_by`: the expression to evaluate
1499/// - `batch`: the `RecordBatch` to evaluate against
1500///
1501/// Returns: A Vec of Vecs of Array of results
1502/// The outer Vec appears to be for grouping sets
1503/// The inner Vec contains the results per expression
1504/// The inner-inner Array contains the results per row
1505pub fn evaluate_group_by(
1506    group_by: &PhysicalGroupBy,
1507    batch: &RecordBatch,
1508) -> Result<Vec<Vec<ArrayRef>>> {
1509    let exprs: Vec<ArrayRef> = group_by
1510        .expr
1511        .iter()
1512        .map(|(expr, _)| {
1513            let value = expr.evaluate(batch)?;
1514            value.into_array(batch.num_rows())
1515        })
1516        .collect::<Result<Vec<_>>>()?;
1517
1518    let null_exprs: Vec<ArrayRef> = group_by
1519        .null_expr
1520        .iter()
1521        .map(|(expr, _)| {
1522            let value = expr.evaluate(batch)?;
1523            value.into_array(batch.num_rows())
1524        })
1525        .collect::<Result<Vec<_>>>()?;
1526
1527    group_by
1528        .groups
1529        .iter()
1530        .map(|group| {
1531            let mut group_values = Vec::with_capacity(group_by.num_group_exprs());
1532            group_values.extend(group.iter().enumerate().map(|(idx, is_null)| {
1533                if *is_null {
1534                    Arc::clone(&null_exprs[idx])
1535                } else {
1536                    Arc::clone(&exprs[idx])
1537                }
1538            }));
1539            if !group_by.is_single() {
1540                group_values.push(group_id_array(group, batch)?);
1541            }
1542            Ok(group_values)
1543        })
1544        .collect()
1545}
1546
1547#[cfg(test)]
1548mod tests {
1549    use std::task::{Context, Poll};
1550
1551    use super::*;
1552    use crate::coalesce_batches::CoalesceBatchesExec;
1553    use crate::coalesce_partitions::CoalescePartitionsExec;
1554    use crate::common;
1555    use crate::common::collect;
1556    use crate::execution_plan::Boundedness;
1557    use crate::expressions::col;
1558    use crate::metrics::MetricValue;
1559    use crate::test::assert_is_pending;
1560    use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
1561    use crate::test::TestMemoryExec;
1562    use crate::RecordBatchStream;
1563
1564    use arrow::array::{
1565        DictionaryArray, Float32Array, Float64Array, Int32Array, StructArray,
1566        UInt32Array, UInt64Array,
1567    };
1568    use arrow::compute::{concat_batches, SortOptions};
1569    use arrow::datatypes::{DataType, Int32Type};
1570    use datafusion_common::test_util::{batches_to_sort_string, batches_to_string};
1571    use datafusion_common::{internal_err, DataFusionError, ScalarValue};
1572    use datafusion_execution::config::SessionConfig;
1573    use datafusion_execution::memory_pool::FairSpillPool;
1574    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
1575    use datafusion_functions_aggregate::array_agg::array_agg_udaf;
1576    use datafusion_functions_aggregate::average::avg_udaf;
1577    use datafusion_functions_aggregate::count::count_udaf;
1578    use datafusion_functions_aggregate::first_last::{first_value_udaf, last_value_udaf};
1579    use datafusion_functions_aggregate::median::median_udaf;
1580    use datafusion_functions_aggregate::sum::sum_udaf;
1581    use datafusion_physical_expr::aggregate::AggregateExprBuilder;
1582    use datafusion_physical_expr::expressions::lit;
1583    use datafusion_physical_expr::expressions::Literal;
1584    use datafusion_physical_expr::Partitioning;
1585    use datafusion_physical_expr::PhysicalSortExpr;
1586
1587    use futures::{FutureExt, Stream};
1588    use insta::{allow_duplicates, assert_snapshot};
1589
1590    // Generate a schema which consists of 5 columns (a, b, c, d, e)
1591    fn create_test_schema() -> Result<SchemaRef> {
1592        let a = Field::new("a", DataType::Int32, true);
1593        let b = Field::new("b", DataType::Int32, true);
1594        let c = Field::new("c", DataType::Int32, true);
1595        let d = Field::new("d", DataType::Int32, true);
1596        let e = Field::new("e", DataType::Int32, true);
1597        let schema = Arc::new(Schema::new(vec![a, b, c, d, e]));
1598
1599        Ok(schema)
1600    }
1601
1602    /// some mock data to aggregates
1603    fn some_data() -> (Arc<Schema>, Vec<RecordBatch>) {
1604        // define a schema.
1605        let schema = Arc::new(Schema::new(vec![
1606            Field::new("a", DataType::UInt32, false),
1607            Field::new("b", DataType::Float64, false),
1608        ]));
1609
1610        // define data.
1611        (
1612            Arc::clone(&schema),
1613            vec![
1614                RecordBatch::try_new(
1615                    Arc::clone(&schema),
1616                    vec![
1617                        Arc::new(UInt32Array::from(vec![2, 3, 4, 4])),
1618                        Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
1619                    ],
1620                )
1621                .unwrap(),
1622                RecordBatch::try_new(
1623                    schema,
1624                    vec![
1625                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
1626                        Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
1627                    ],
1628                )
1629                .unwrap(),
1630            ],
1631        )
1632    }
1633
1634    /// Generates some mock data for aggregate tests.
1635    fn some_data_v2() -> (Arc<Schema>, Vec<RecordBatch>) {
1636        // Define a schema:
1637        let schema = Arc::new(Schema::new(vec![
1638            Field::new("a", DataType::UInt32, false),
1639            Field::new("b", DataType::Float64, false),
1640        ]));
1641
1642        // Generate data so that first and last value results are at 2nd and
1643        // 3rd partitions.  With this construction, we guarantee we don't receive
1644        // the expected result by accident, but merging actually works properly;
1645        // i.e. it doesn't depend on the data insertion order.
1646        (
1647            Arc::clone(&schema),
1648            vec![
1649                RecordBatch::try_new(
1650                    Arc::clone(&schema),
1651                    vec![
1652                        Arc::new(UInt32Array::from(vec![2, 3, 4, 4])),
1653                        Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
1654                    ],
1655                )
1656                .unwrap(),
1657                RecordBatch::try_new(
1658                    Arc::clone(&schema),
1659                    vec![
1660                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
1661                        Arc::new(Float64Array::from(vec![0.0, 1.0, 2.0, 3.0])),
1662                    ],
1663                )
1664                .unwrap(),
1665                RecordBatch::try_new(
1666                    Arc::clone(&schema),
1667                    vec![
1668                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
1669                        Arc::new(Float64Array::from(vec![3.0, 4.0, 5.0, 6.0])),
1670                    ],
1671                )
1672                .unwrap(),
1673                RecordBatch::try_new(
1674                    schema,
1675                    vec![
1676                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
1677                        Arc::new(Float64Array::from(vec![2.0, 3.0, 4.0, 5.0])),
1678                    ],
1679                )
1680                .unwrap(),
1681            ],
1682        )
1683    }
1684
1685    fn new_spill_ctx(batch_size: usize, max_memory: usize) -> Arc<TaskContext> {
1686        let session_config = SessionConfig::new().with_batch_size(batch_size);
1687        let runtime = RuntimeEnvBuilder::new()
1688            .with_memory_pool(Arc::new(FairSpillPool::new(max_memory)))
1689            .build_arc()
1690            .unwrap();
1691        let task_ctx = TaskContext::default()
1692            .with_session_config(session_config)
1693            .with_runtime(runtime);
1694        Arc::new(task_ctx)
1695    }
1696
    /// Runs `COUNT(1)` over the grouping sets `(a, NULL)`, `(NULL, b)` and
    /// `(a, b)` in two phases and snapshots the output of both phases.
    ///
    /// The first phase is a `Partial` `AggregateExec` directly on `input`. The
    /// second is a `Final` `AggregateExec` on top of a `CoalescePartitionsExec`.
    ///
    /// `input` must expose columns `a` and `b`. The NULL placeholders are typed
    /// `UInt32` and `Float64` respectively, and the snapshots match the data of
    /// `some_data()`.
    ///
    /// When `spill` is true, each phase runs under a small `FairSpillPool` built
    /// by `new_spill_ctx` with a batch size of 4: 500 bytes for the partial
    /// phase and 3160 bytes for the final phase.
    /// NOTE(review): these byte limits are tuned to the current memory
    /// accounting; the partial-phase spill snapshot depends on them.
    async fn check_grouping_sets(
        input: Arc<dyn ExecutionPlan>,
        spill: bool,
    ) -> Result<()> {
        let input_schema = input.schema();

        // Arguments, in order:
        //   1. the group expressions;
        //   2. the typed NULL substituted for each expression when a grouping
        //      set excludes it;
        //   3. the grouping sets themselves (`true` = use that expression's NULL).
        let grouping_set = PhysicalGroupBy::new(
            vec![
                (col("a", &input_schema)?, "a".to_string()),
                (col("b", &input_schema)?, "b".to_string()),
            ],
            vec![
                (lit(ScalarValue::UInt32(None)), "a".to_string()),
                (lit(ScalarValue::Float64(None)), "b".to_string()),
            ],
            vec![
                vec![false, true],  // (a, NULL)
                vec![true, false],  // (NULL, b)
                vec![false, false], // (a,b)
            ],
        );

        let aggregates = vec![Arc::new(
            AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)])
                .schema(Arc::clone(&input_schema))
                .alias("COUNT(1)")
                .build()?,
        )];

        let task_ctx = if spill {
            // adjust the max memory size to have the partial aggregate result for spill mode.
            new_spill_ctx(4, 500)
        } else {
            Arc::new(TaskContext::default())
        };

        let partial_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            grouping_set.clone(),
            aggregates.clone(),
            vec![None],
            input,
            Arc::clone(&input_schema),
        )?);

        let result =
            collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;

        // The partial output carries an extra `__grouping_id` column plus the
        // `COUNT(1)[count]` state column. In these snapshots `__grouping_id` is
        // 1 for `(a, NULL)`, 2 for `(NULL, b)` and 0 for `(a, b)`.
        if spill {
            // In spill mode, we test with the limited memory, if the mem usage exceeds,
            // we trigger the early emit rule, which turns out the partial aggregate result.
            allow_duplicates! {
            assert_snapshot!(batches_to_sort_string(&result),
            @r"
+---+-----+---------------+-----------------+
| a | b   | __grouping_id | COUNT(1)[count] |
+---+-----+---------------+-----------------+
|   | 1.0 | 2             | 1               |
|   | 1.0 | 2             | 1               |
|   | 2.0 | 2             | 1               |
|   | 2.0 | 2             | 1               |
|   | 3.0 | 2             | 1               |
|   | 3.0 | 2             | 1               |
|   | 4.0 | 2             | 1               |
|   | 4.0 | 2             | 1               |
| 2 |     | 1             | 1               |
| 2 |     | 1             | 1               |
| 2 | 1.0 | 0             | 1               |
| 2 | 1.0 | 0             | 1               |
| 3 |     | 1             | 1               |
| 3 |     | 1             | 2               |
| 3 | 2.0 | 0             | 2               |
| 3 | 3.0 | 0             | 1               |
| 4 |     | 1             | 1               |
| 4 |     | 1             | 2               |
| 4 | 3.0 | 0             | 1               |
| 4 | 4.0 | 0             | 2               |
+---+-----+---------------+-----------------+
            "
            );
            }
        } else {
            allow_duplicates! {
            assert_snapshot!(batches_to_sort_string(&result),
            @r"
+---+-----+---------------+-----------------+
| a | b   | __grouping_id | COUNT(1)[count] |
+---+-----+---------------+-----------------+
|   | 1.0 | 2             | 2               |
|   | 2.0 | 2             | 2               |
|   | 3.0 | 2             | 2               |
|   | 4.0 | 2             | 2               |
| 2 |     | 1             | 2               |
| 2 | 1.0 | 0             | 2               |
| 3 |     | 1             | 3               |
| 3 | 2.0 | 0             | 2               |
| 3 | 3.0 | 0             | 1               |
| 4 |     | 1             | 3               |
| 4 | 3.0 | 0             | 1               |
| 4 | 4.0 | 0             | 2               |
+---+-----+---------------+-----------------+
            "
            );
            }
        };

        // Merge every partial partition into one stream for the final phase.
        let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));

        let final_grouping_set = grouping_set.as_final();

        // The final phase gets a larger budget so it can finish.
        let task_ctx = if spill {
            new_spill_ctx(4, 3160)
        } else {
            task_ctx
        };

        let merged_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Final,
            final_grouping_set,
            aggregates,
            vec![None],
            merge,
            input_schema,
        )?);

        let result = collect(merged_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
        let batch = concat_batches(&result[0].schema(), &result)?;
        // a, b, __grouping_id, COUNT(1)
        assert_eq!(batch.num_columns(), 4);
        assert_eq!(batch.num_rows(), 12);

        // Regardless of spilling, the final result matches the non-spill
        // partial snapshot with the state column finalized to `COUNT(1)`.
        allow_duplicates! {
        assert_snapshot!(
            batches_to_sort_string(&result),
            @r"
            +---+-----+---------------+----------+
            | a | b   | __grouping_id | COUNT(1) |
            +---+-----+---------------+----------+
            |   | 1.0 | 2             | 2        |
            |   | 2.0 | 2             | 2        |
            |   | 3.0 | 2             | 2        |
            |   | 4.0 | 2             | 2        |
            | 2 |     | 1             | 2        |
            | 2 | 1.0 | 0             | 2        |
            | 3 |     | 1             | 3        |
            | 3 | 2.0 | 0             | 2        |
            | 3 | 3.0 | 0             | 1        |
            | 4 |     | 1             | 3        |
            | 4 | 3.0 | 0             | 1        |
            | 4 | 4.0 | 0             | 2        |
            +---+-----+---------------+----------+
            "
        );
        }

        let metrics = merged_aggregate.metrics().unwrap();
        let output_rows = metrics.output_rows().unwrap();
        assert_eq!(12, output_rows);

        Ok(())
    }
1857
    /// Runs `AVG(b) GROUP BY a` over `input` in two phases and snapshots the
    /// output of each phase. It also checks the final phase's output-row and
    /// spill metrics.
    ///
    /// The first phase is a `Partial` `AggregateExec` directly on `input`. The
    /// second is a `Final` `AggregateExec` on top of a `CoalescePartitionsExec`.
    ///
    /// `input` must expose columns `a` and `b`; the snapshots match the data of
    /// `some_data()`.
    ///
    /// When `spill` is true, both phases run with a batch size of 2 under a
    /// `FairSpillPool`: 1600 bytes for the partial phase and 2600 bytes for the
    /// final phase. The final phase must then report spills.
    /// NOTE(review): these byte limits are tuned to the current memory
    /// accounting; the spill-mode snapshot and metrics depend on them.
    async fn check_aggregates(input: Arc<dyn ExecutionPlan>, spill: bool) -> Result<()> {
        let input_schema = input.schema();

        // A single grouping set containing just `a` (no NULL substitution).
        let grouping_set = PhysicalGroupBy::new(
            vec![(col("a", &input_schema)?, "a".to_string())],
            vec![],
            vec![vec![false]],
        );

        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
            AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?])
                .schema(Arc::clone(&input_schema))
                .alias("AVG(b)")
                .build()?,
        )];

        let task_ctx = if spill {
            // set to an appropriate value to trigger spill
            new_spill_ctx(2, 1600)
        } else {
            Arc::new(TaskContext::default())
        };

        let partial_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            grouping_set.clone(),
            aggregates.clone(),
            vec![None],
            input,
            Arc::clone(&input_schema),
        )?);

        let result =
            collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;

        // The partial phase emits AVG's intermediate state as `[count]` and `[sum]`
        // columns. Under the spill budget the same `a` can appear in several
        // partial rows (5 rows here instead of 3).
        if spill {
            allow_duplicates! {
            assert_snapshot!(batches_to_sort_string(&result), @r"
                +---+---------------+-------------+
                | a | AVG(b)[count] | AVG(b)[sum] |
                +---+---------------+-------------+
                | 2 | 1             | 1.0         |
                | 2 | 1             | 1.0         |
                | 3 | 1             | 2.0         |
                | 3 | 2             | 5.0         |
                | 4 | 3             | 11.0        |
                +---+---------------+-------------+
            ");
            }
        } else {
            allow_duplicates! {
            assert_snapshot!(batches_to_sort_string(&result), @r"
                +---+---------------+-------------+
                | a | AVG(b)[count] | AVG(b)[sum] |
                +---+---------------+-------------+
                | 2 | 2             | 2.0         |
                | 3 | 3             | 7.0         |
                | 4 | 3             | 11.0        |
                +---+---------------+-------------+
            ");
            }
        };

        // Merge every partial partition into one stream for the final phase.
        let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));

        let final_grouping_set = grouping_set.as_final();

        let merged_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Final,
            final_grouping_set,
            aggregates,
            vec![None],
            merge,
            input_schema,
        )?);

        let task_ctx = if spill {
            // enlarge memory limit to let the final aggregation finish
            new_spill_ctx(2, 2600)
        } else {
            Arc::clone(&task_ctx)
        };
        let result = collect(merged_aggregate.execute(0, task_ctx)?).await?;
        let batch = concat_batches(&result[0].schema(), &result)?;
        // a, AVG(b)
        assert_eq!(batch.num_columns(), 2);
        assert_eq!(batch.num_rows(), 3);

        allow_duplicates! {
        assert_snapshot!(batches_to_sort_string(&result), @r"
            +---+--------------------+
            | a | AVG(b)             |
            +---+--------------------+
            | 2 | 1.0                |
            | 3 | 2.3333333333333335 |
            | 4 | 3.6666666666666665 |
            +---+--------------------+
            ");
            // a = 3: (2.0 + 2.0 + 3.0) / 3
            // a = 4: (3.0 + 4.0 + 4.0) / 3
        }

        let metrics = merged_aggregate.metrics().unwrap();
        let output_rows = metrics.output_rows().unwrap();
        let spill_count = metrics.spill_count().unwrap();
        let spilled_bytes = metrics.spilled_bytes().unwrap();
        let spilled_rows = metrics.spilled_rows().unwrap();

        if spill {
            // When spilling, the output rows metrics become partial output size + final output size
            // This is because final aggregation starts while partial aggregation is still emitting
            // (here: 5 partial rows + 3 final rows).
            assert_eq!(8, output_rows);

            assert!(spill_count > 0);
            assert!(spilled_bytes > 0);
            assert!(spilled_rows > 0);
        } else {
            assert_eq!(3, output_rows);

            assert_eq!(0, spill_count);
            assert_eq!(0, spilled_bytes);
            assert_eq!(0, spilled_rows);
        }

        Ok(())
    }
1984
1985    /// Define a test source that can yield back to runtime before returning its first item ///
1986
1987    #[derive(Debug)]
1988    struct TestYieldingExec {
1989        /// True if this exec should yield back to runtime the first time it is polled
1990        pub yield_first: bool,
1991        cache: PlanProperties,
1992    }
1993
1994    impl TestYieldingExec {
1995        fn new(yield_first: bool) -> Self {
1996            let schema = some_data().0;
1997            let cache = Self::compute_properties(schema);
1998            Self { yield_first, cache }
1999        }
2000
2001        /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
2002        fn compute_properties(schema: SchemaRef) -> PlanProperties {
2003            PlanProperties::new(
2004                EquivalenceProperties::new(schema),
2005                Partitioning::UnknownPartitioning(1),
2006                EmissionType::Incremental,
2007                Boundedness::Bounded,
2008            )
2009        }
2010    }
2011
2012    impl DisplayAs for TestYieldingExec {
2013        fn fmt_as(
2014            &self,
2015            t: DisplayFormatType,
2016            f: &mut std::fmt::Formatter,
2017        ) -> std::fmt::Result {
2018            match t {
2019                DisplayFormatType::Default | DisplayFormatType::Verbose => {
2020                    write!(f, "TestYieldingExec")
2021                }
2022                DisplayFormatType::TreeRender => {
2023                    // TODO: collect info
2024                    write!(f, "")
2025                }
2026            }
2027        }
2028    }
2029
2030    impl ExecutionPlan for TestYieldingExec {
2031        fn name(&self) -> &'static str {
2032            "TestYieldingExec"
2033        }
2034
2035        fn as_any(&self) -> &dyn Any {
2036            self
2037        }
2038
2039        fn properties(&self) -> &PlanProperties {
2040            &self.cache
2041        }
2042
2043        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
2044            vec![]
2045        }
2046
2047        fn with_new_children(
2048            self: Arc<Self>,
2049            _: Vec<Arc<dyn ExecutionPlan>>,
2050        ) -> Result<Arc<dyn ExecutionPlan>> {
2051            internal_err!("Children cannot be replaced in {self:?}")
2052        }
2053
2054        fn execute(
2055            &self,
2056            _partition: usize,
2057            _context: Arc<TaskContext>,
2058        ) -> Result<SendableRecordBatchStream> {
2059            let stream = if self.yield_first {
2060                TestYieldingStream::New
2061            } else {
2062                TestYieldingStream::Yielded
2063            };
2064
2065            Ok(Box::pin(stream))
2066        }
2067
2068        fn statistics(&self) -> Result<Statistics> {
2069            self.partition_statistics(None)
2070        }
2071
2072        fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
2073            if partition.is_some() {
2074                return Ok(Statistics::new_unknown(self.schema().as_ref()));
2075            }
2076            let (_, batches) = some_data();
2077            Ok(common::compute_record_batch_statistics(
2078                &[batches],
2079                &self.schema(),
2080                None,
2081            ))
2082        }
2083    }
2084
2085    /// A stream using the demo data. If inited as new, it will first yield to runtime before returning records
2086    enum TestYieldingStream {
2087        New,
2088        Yielded,
2089        ReturnedBatch1,
2090        ReturnedBatch2,
2091    }
2092
2093    impl Stream for TestYieldingStream {
2094        type Item = Result<RecordBatch>;
2095
2096        fn poll_next(
2097            mut self: std::pin::Pin<&mut Self>,
2098            cx: &mut Context<'_>,
2099        ) -> Poll<Option<Self::Item>> {
2100            match &*self {
2101                TestYieldingStream::New => {
2102                    *(self.as_mut()) = TestYieldingStream::Yielded;
2103                    cx.waker().wake_by_ref();
2104                    Poll::Pending
2105                }
2106                TestYieldingStream::Yielded => {
2107                    *(self.as_mut()) = TestYieldingStream::ReturnedBatch1;
2108                    Poll::Ready(Some(Ok(some_data().1[0].clone())))
2109                }
2110                TestYieldingStream::ReturnedBatch1 => {
2111                    *(self.as_mut()) = TestYieldingStream::ReturnedBatch2;
2112                    Poll::Ready(Some(Ok(some_data().1[1].clone())))
2113                }
2114                TestYieldingStream::ReturnedBatch2 => Poll::Ready(None),
2115            }
2116        }
2117    }
2118
2119    impl RecordBatchStream for TestYieldingStream {
2120        fn schema(&self) -> SchemaRef {
2121            some_data().0
2122        }
2123    }
2124
2125    //--- Tests ---//
2126
2127    #[tokio::test]
2128    async fn aggregate_source_not_yielding() -> Result<()> {
2129        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
2130
2131        check_aggregates(input, false).await
2132    }
2133
2134    #[tokio::test]
2135    async fn aggregate_grouping_sets_source_not_yielding() -> Result<()> {
2136        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
2137
2138        check_grouping_sets(input, false).await
2139    }
2140
2141    #[tokio::test]
2142    async fn aggregate_source_with_yielding() -> Result<()> {
2143        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
2144
2145        check_aggregates(input, false).await
2146    }
2147
2148    #[tokio::test]
2149    async fn aggregate_grouping_sets_with_yielding() -> Result<()> {
2150        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
2151
2152        check_grouping_sets(input, false).await
2153    }
2154
2155    #[tokio::test]
2156    async fn aggregate_source_not_yielding_with_spill() -> Result<()> {
2157        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
2158
2159        check_aggregates(input, true).await
2160    }
2161
2162    #[tokio::test]
2163    async fn aggregate_grouping_sets_source_not_yielding_with_spill() -> Result<()> {
2164        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
2165
2166        check_grouping_sets(input, true).await
2167    }
2168
2169    #[tokio::test]
2170    async fn aggregate_source_with_yielding_with_spill() -> Result<()> {
2171        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
2172
2173        check_aggregates(input, true).await
2174    }
2175
2176    #[tokio::test]
2177    async fn aggregate_grouping_sets_with_yielding_with_spill() -> Result<()> {
2178        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
2179
2180        check_grouping_sets(input, true).await
2181    }
2182
2183    // Median(a)
2184    fn test_median_agg_expr(schema: SchemaRef) -> Result<AggregateFunctionExpr> {
2185        AggregateExprBuilder::new(median_udaf(), vec![col("a", &schema)?])
2186            .schema(schema)
2187            .alias("MEDIAN(a)")
2188            .build()
2189    }
2190
2191    #[tokio::test]
2192    async fn test_oom() -> Result<()> {
2193        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
2194        let input_schema = input.schema();
2195
2196        let runtime = RuntimeEnvBuilder::new()
2197            .with_memory_limit(1, 1.0)
2198            .build_arc()?;
2199        let task_ctx = TaskContext::default().with_runtime(runtime);
2200        let task_ctx = Arc::new(task_ctx);
2201
2202        let groups_none = PhysicalGroupBy::default();
2203        let groups_some = PhysicalGroupBy::new(
2204            vec![(col("a", &input_schema)?, "a".to_string())],
2205            vec![],
2206            vec![vec![false]],
2207        );
2208
2209        // something that allocates within the aggregator
2210        let aggregates_v0: Vec<Arc<AggregateFunctionExpr>> =
2211            vec![Arc::new(test_median_agg_expr(Arc::clone(&input_schema))?)];
2212
2213        // use fast-path in `row_hash.rs`.
2214        let aggregates_v2: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
2215            AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?])
2216                .schema(Arc::clone(&input_schema))
2217                .alias("AVG(b)")
2218                .build()?,
2219        )];
2220
2221        for (version, groups, aggregates) in [
2222            (0, groups_none, aggregates_v0),
2223            (2, groups_some, aggregates_v2),
2224        ] {
2225            let n_aggr = aggregates.len();
2226            let partial_aggregate = Arc::new(AggregateExec::try_new(
2227                AggregateMode::Partial,
2228                groups,
2229                aggregates,
2230                vec![None; n_aggr],
2231                Arc::clone(&input),
2232                Arc::clone(&input_schema),
2233            )?);
2234
2235            let stream = partial_aggregate.execute_typed(0, Arc::clone(&task_ctx))?;
2236
2237            // ensure that we really got the version we wanted
2238            match version {
2239                0 => {
2240                    assert!(matches!(stream, StreamType::AggregateStream(_)));
2241                }
2242                1 => {
2243                    assert!(matches!(stream, StreamType::GroupedHash(_)));
2244                }
2245                2 => {
2246                    assert!(matches!(stream, StreamType::GroupedHash(_)));
2247                }
2248                _ => panic!("Unknown version: {version}"),
2249            }
2250
2251            let stream: SendableRecordBatchStream = stream.into();
2252            let err = collect(stream).await.unwrap_err();
2253
2254            // error root cause traversal is a bit complicated, see #4172.
2255            let err = err.find_root();
2256            assert!(
2257                matches!(err, DataFusionError::ResourcesExhausted(_)),
2258                "Wrong error type: {err}",
2259            );
2260        }
2261
2262        Ok(())
2263    }
2264
2265    #[tokio::test]
2266    async fn test_drop_cancel_without_groups() -> Result<()> {
2267        let task_ctx = Arc::new(TaskContext::default());
2268        let schema =
2269            Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, true)]));
2270
2271        let groups = PhysicalGroupBy::default();
2272
2273        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
2274            AggregateExprBuilder::new(avg_udaf(), vec![col("a", &schema)?])
2275                .schema(Arc::clone(&schema))
2276                .alias("AVG(a)")
2277                .build()?,
2278        )];
2279
2280        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
2281        let refs = blocking_exec.refs();
2282        let aggregate_exec = Arc::new(AggregateExec::try_new(
2283            AggregateMode::Partial,
2284            groups.clone(),
2285            aggregates.clone(),
2286            vec![None],
2287            blocking_exec,
2288            schema,
2289        )?);
2290
2291        let fut = crate::collect(aggregate_exec, task_ctx);
2292        let mut fut = fut.boxed();
2293
2294        assert_is_pending(&mut fut);
2295        drop(fut);
2296        assert_strong_count_converges_to_zero(refs).await;
2297
2298        Ok(())
2299    }
2300
2301    #[tokio::test]
2302    async fn test_drop_cancel_with_groups() -> Result<()> {
2303        let task_ctx = Arc::new(TaskContext::default());
2304        let schema = Arc::new(Schema::new(vec![
2305            Field::new("a", DataType::Float64, true),
2306            Field::new("b", DataType::Float64, true),
2307        ]));
2308
2309        let groups =
2310            PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);
2311
2312        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
2313            AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
2314                .schema(Arc::clone(&schema))
2315                .alias("AVG(b)")
2316                .build()?,
2317        )];
2318
2319        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
2320        let refs = blocking_exec.refs();
2321        let aggregate_exec = Arc::new(AggregateExec::try_new(
2322            AggregateMode::Partial,
2323            groups,
2324            aggregates.clone(),
2325            vec![None],
2326            blocking_exec,
2327            schema,
2328        )?);
2329
2330        let fut = crate::collect(aggregate_exec, task_ctx);
2331        let mut fut = fut.boxed();
2332
2333        assert_is_pending(&mut fut);
2334        drop(fut);
2335        assert_strong_count_converges_to_zero(refs).await;
2336
2337        Ok(())
2338    }
2339
2340    #[tokio::test]
2341    async fn run_first_last_multi_partitions() -> Result<()> {
2342        for use_coalesce_batches in [false, true] {
2343            for is_first_acc in [false, true] {
2344                for spill in [false, true] {
2345                    first_last_multi_partitions(
2346                        use_coalesce_batches,
2347                        is_first_acc,
2348                        spill,
2349                        4200,
2350                    )
2351                    .await?
2352                }
2353            }
2354        }
2355        Ok(())
2356    }
2357
2358    // FIRST_VALUE(b ORDER BY b <SortOptions>)
2359    fn test_first_value_agg_expr(
2360        schema: &Schema,
2361        sort_options: SortOptions,
2362    ) -> Result<Arc<AggregateFunctionExpr>> {
2363        let order_bys = vec![PhysicalSortExpr {
2364            expr: col("b", schema)?,
2365            options: sort_options,
2366        }];
2367        let args = [col("b", schema)?];
2368
2369        AggregateExprBuilder::new(first_value_udaf(), args.to_vec())
2370            .order_by(order_bys)
2371            .schema(Arc::new(schema.clone()))
2372            .alias(String::from("first_value(b) ORDER BY [b ASC NULLS LAST]"))
2373            .build()
2374            .map(Arc::new)
2375    }
2376
2377    // LAST_VALUE(b ORDER BY b <SortOptions>)
2378    fn test_last_value_agg_expr(
2379        schema: &Schema,
2380        sort_options: SortOptions,
2381    ) -> Result<Arc<AggregateFunctionExpr>> {
2382        let order_bys = vec![PhysicalSortExpr {
2383            expr: col("b", schema)?,
2384            options: sort_options,
2385        }];
2386        let args = [col("b", schema)?];
2387        AggregateExprBuilder::new(last_value_udaf(), args.to_vec())
2388            .order_by(order_bys)
2389            .schema(Arc::new(schema.clone()))
2390            .alias(String::from("last_value(b) ORDER BY [b ASC NULLS LAST]"))
2391            .build()
2392            .map(Arc::new)
2393    }
2394
2395    // This function either constructs the physical plan below,
2396    //
2397    // "AggregateExec: mode=Final, gby=[a@0 as a], aggr=[FIRST_VALUE(b)]",
2398    // "  CoalesceBatchesExec: target_batch_size=1024",
2399    // "    CoalescePartitionsExec",
2400    // "      AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[FIRST_VALUE(b)], ordering_mode=None",
2401    // "        DataSourceExec: partitions=4, partition_sizes=[1, 1, 1, 1]",
2402    //
2403    // or
2404    //
2405    // "AggregateExec: mode=Final, gby=[a@0 as a], aggr=[FIRST_VALUE(b)]",
2406    // "  CoalescePartitionsExec",
2407    // "    AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[FIRST_VALUE(b)], ordering_mode=None",
2408    // "      DataSourceExec: partitions=4, partition_sizes=[1, 1, 1, 1]",
2409    //
2410    // and checks whether the function `merge_batch` works correctly for
2411    // FIRST_VALUE and LAST_VALUE functions.
2412    async fn first_last_multi_partitions(
2413        use_coalesce_batches: bool,
2414        is_first_acc: bool,
2415        spill: bool,
2416        max_memory: usize,
2417    ) -> Result<()> {
2418        let task_ctx = if spill {
2419            new_spill_ctx(2, max_memory)
2420        } else {
2421            Arc::new(TaskContext::default())
2422        };
2423
2424        let (schema, data) = some_data_v2();
2425        let partition1 = data[0].clone();
2426        let partition2 = data[1].clone();
2427        let partition3 = data[2].clone();
2428        let partition4 = data[3].clone();
2429
2430        let groups =
2431            PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);
2432
2433        let sort_options = SortOptions {
2434            descending: false,
2435            nulls_first: false,
2436        };
2437        let aggregates: Vec<Arc<AggregateFunctionExpr>> = if is_first_acc {
2438            vec![test_first_value_agg_expr(&schema, sort_options)?]
2439        } else {
2440            vec![test_last_value_agg_expr(&schema, sort_options)?]
2441        };
2442
2443        let memory_exec = TestMemoryExec::try_new_exec(
2444            &[
2445                vec![partition1],
2446                vec![partition2],
2447                vec![partition3],
2448                vec![partition4],
2449            ],
2450            Arc::clone(&schema),
2451            None,
2452        )?;
2453        let aggregate_exec = Arc::new(AggregateExec::try_new(
2454            AggregateMode::Partial,
2455            groups.clone(),
2456            aggregates.clone(),
2457            vec![None],
2458            memory_exec,
2459            Arc::clone(&schema),
2460        )?);
2461        let coalesce = if use_coalesce_batches {
2462            let coalesce = Arc::new(CoalescePartitionsExec::new(aggregate_exec));
2463            Arc::new(CoalesceBatchesExec::new(coalesce, 1024)) as Arc<dyn ExecutionPlan>
2464        } else {
2465            Arc::new(CoalescePartitionsExec::new(aggregate_exec))
2466                as Arc<dyn ExecutionPlan>
2467        };
2468        let aggregate_final = Arc::new(AggregateExec::try_new(
2469            AggregateMode::Final,
2470            groups,
2471            aggregates.clone(),
2472            vec![None],
2473            coalesce,
2474            schema,
2475        )?) as Arc<dyn ExecutionPlan>;
2476
2477        let result = crate::collect(aggregate_final, task_ctx).await?;
2478        if is_first_acc {
2479            allow_duplicates! {
2480            assert_snapshot!(batches_to_string(&result), @r"
2481                +---+--------------------------------------------+
2482                | a | first_value(b) ORDER BY [b ASC NULLS LAST] |
2483                +---+--------------------------------------------+
2484                | 2 | 0.0                                        |
2485                | 3 | 1.0                                        |
2486                | 4 | 3.0                                        |
2487                +---+--------------------------------------------+
2488                ");
2489            }
2490        } else {
2491            allow_duplicates! {
2492            assert_snapshot!(batches_to_string(&result), @r"
2493                +---+-------------------------------------------+
2494                | a | last_value(b) ORDER BY [b ASC NULLS LAST] |
2495                +---+-------------------------------------------+
2496                | 2 | 3.0                                       |
2497                | 3 | 5.0                                       |
2498                | 4 | 6.0                                       |
2499                +---+-------------------------------------------+
2500                ");
2501            }
2502        };
2503        Ok(())
2504    }
2505
2506    #[tokio::test]
2507    async fn test_get_finest_requirements() -> Result<()> {
2508        let test_schema = create_test_schema()?;
2509
2510        let options = SortOptions {
2511            descending: false,
2512            nulls_first: false,
2513        };
2514        let col_a = &col("a", &test_schema)?;
2515        let col_b = &col("b", &test_schema)?;
2516        let col_c = &col("c", &test_schema)?;
2517        let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema));
2518        // Columns a and b are equal.
2519        eq_properties.add_equal_conditions(Arc::clone(col_a), Arc::clone(col_b))?;
2520        // Aggregate requirements are
2521        // [None], [a ASC], [a ASC, b ASC, c ASC], [a ASC, b ASC] respectively
2522        let order_by_exprs = vec![
2523            vec![],
2524            vec![PhysicalSortExpr {
2525                expr: Arc::clone(col_a),
2526                options,
2527            }],
2528            vec![
2529                PhysicalSortExpr {
2530                    expr: Arc::clone(col_a),
2531                    options,
2532                },
2533                PhysicalSortExpr {
2534                    expr: Arc::clone(col_b),
2535                    options,
2536                },
2537                PhysicalSortExpr {
2538                    expr: Arc::clone(col_c),
2539                    options,
2540                },
2541            ],
2542            vec![
2543                PhysicalSortExpr {
2544                    expr: Arc::clone(col_a),
2545                    options,
2546                },
2547                PhysicalSortExpr {
2548                    expr: Arc::clone(col_b),
2549                    options,
2550                },
2551            ],
2552        ];
2553
2554        let common_requirement = vec![
2555            PhysicalSortRequirement::new(Arc::clone(col_a), Some(options)),
2556            PhysicalSortRequirement::new(Arc::clone(col_c), Some(options)),
2557        ];
2558        let mut aggr_exprs = order_by_exprs
2559            .into_iter()
2560            .map(|order_by_expr| {
2561                AggregateExprBuilder::new(array_agg_udaf(), vec![Arc::clone(col_a)])
2562                    .alias("a")
2563                    .order_by(order_by_expr)
2564                    .schema(Arc::clone(&test_schema))
2565                    .build()
2566                    .map(Arc::new)
2567                    .unwrap()
2568            })
2569            .collect::<Vec<_>>();
2570        let group_by = PhysicalGroupBy::new_single(vec![]);
2571        let result = get_finer_aggregate_exprs_requirement(
2572            &mut aggr_exprs,
2573            &group_by,
2574            &eq_properties,
2575            &AggregateMode::Partial,
2576        )?;
2577        assert_eq!(result, common_requirement);
2578        Ok(())
2579    }
2580
2581    #[test]
2582    fn test_agg_exec_same_schema() -> Result<()> {
2583        let schema = Arc::new(Schema::new(vec![
2584            Field::new("a", DataType::Float32, true),
2585            Field::new("b", DataType::Float32, true),
2586        ]));
2587
2588        let col_a = col("a", &schema)?;
2589        let option_desc = SortOptions {
2590            descending: true,
2591            nulls_first: true,
2592        };
2593        let groups = PhysicalGroupBy::new_single(vec![(col_a, "a".to_string())]);
2594
2595        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
2596            test_first_value_agg_expr(&schema, option_desc)?,
2597            test_last_value_agg_expr(&schema, option_desc)?,
2598        ];
2599        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
2600        let aggregate_exec = Arc::new(AggregateExec::try_new(
2601            AggregateMode::Partial,
2602            groups,
2603            aggregates,
2604            vec![None, None],
2605            Arc::clone(&blocking_exec) as Arc<dyn ExecutionPlan>,
2606            schema,
2607        )?);
2608        let new_agg =
2609            Arc::clone(&aggregate_exec).with_new_children(vec![blocking_exec])?;
2610        assert_eq!(new_agg.schema(), aggregate_exec.schema());
2611        Ok(())
2612    }
2613
2614    #[tokio::test]
2615    async fn test_agg_exec_group_by_const() -> Result<()> {
2616        let schema = Arc::new(Schema::new(vec![
2617            Field::new("a", DataType::Float32, true),
2618            Field::new("b", DataType::Float32, true),
2619            Field::new("const", DataType::Int32, false),
2620        ]));
2621
2622        let col_a = col("a", &schema)?;
2623        let col_b = col("b", &schema)?;
2624        let const_expr = Arc::new(Literal::new(ScalarValue::Int32(Some(1))));
2625
2626        let groups = PhysicalGroupBy::new(
2627            vec![
2628                (col_a, "a".to_string()),
2629                (col_b, "b".to_string()),
2630                (const_expr, "const".to_string()),
2631            ],
2632            vec![
2633                (
2634                    Arc::new(Literal::new(ScalarValue::Float32(None))),
2635                    "a".to_string(),
2636                ),
2637                (
2638                    Arc::new(Literal::new(ScalarValue::Float32(None))),
2639                    "b".to_string(),
2640                ),
2641                (
2642                    Arc::new(Literal::new(ScalarValue::Int32(None))),
2643                    "const".to_string(),
2644                ),
2645            ],
2646            vec![
2647                vec![false, true, true],
2648                vec![true, false, true],
2649                vec![true, true, false],
2650            ],
2651        );
2652
2653        let aggregates: Vec<Arc<AggregateFunctionExpr>> =
2654            vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1)])
2655                .schema(Arc::clone(&schema))
2656                .alias("1")
2657                .build()
2658                .map(Arc::new)?];
2659
2660        let input_batches = (0..4)
2661            .map(|_| {
2662                let a = Arc::new(Float32Array::from(vec![0.; 8192]));
2663                let b = Arc::new(Float32Array::from(vec![0.; 8192]));
2664                let c = Arc::new(Int32Array::from(vec![1; 8192]));
2665
2666                RecordBatch::try_new(Arc::clone(&schema), vec![a, b, c]).unwrap()
2667            })
2668            .collect();
2669
2670        let input =
2671            TestMemoryExec::try_new_exec(&[input_batches], Arc::clone(&schema), None)?;
2672
2673        let aggregate_exec = Arc::new(AggregateExec::try_new(
2674            AggregateMode::Single,
2675            groups,
2676            aggregates.clone(),
2677            vec![None],
2678            input,
2679            schema,
2680        )?);
2681
2682        let output =
2683            collect(aggregate_exec.execute(0, Arc::new(TaskContext::default()))?).await?;
2684
2685        allow_duplicates! {
2686        assert_snapshot!(batches_to_sort_string(&output), @r"
2687            +-----+-----+-------+---------------+-------+
2688            | a   | b   | const | __grouping_id | 1     |
2689            +-----+-----+-------+---------------+-------+
2690            |     |     | 1     | 6             | 32768 |
2691            |     | 0.0 |       | 5             | 32768 |
2692            | 0.0 |     |       | 3             | 32768 |
2693            +-----+-----+-------+---------------+-------+
2694        ");
2695        }
2696
2697        Ok(())
2698    }
2699
2700    #[tokio::test]
2701    async fn test_agg_exec_struct_of_dicts() -> Result<()> {
2702        let batch = RecordBatch::try_new(
2703            Arc::new(Schema::new(vec![
2704                Field::new(
2705                    "labels".to_string(),
2706                    DataType::Struct(
2707                        vec![
2708                            Field::new(
2709                                "a".to_string(),
2710                                DataType::Dictionary(
2711                                    Box::new(DataType::Int32),
2712                                    Box::new(DataType::Utf8),
2713                                ),
2714                                true,
2715                            ),
2716                            Field::new(
2717                                "b".to_string(),
2718                                DataType::Dictionary(
2719                                    Box::new(DataType::Int32),
2720                                    Box::new(DataType::Utf8),
2721                                ),
2722                                true,
2723                            ),
2724                        ]
2725                        .into(),
2726                    ),
2727                    false,
2728                ),
2729                Field::new("value", DataType::UInt64, false),
2730            ])),
2731            vec![
2732                Arc::new(StructArray::from(vec![
2733                    (
2734                        Arc::new(Field::new(
2735                            "a".to_string(),
2736                            DataType::Dictionary(
2737                                Box::new(DataType::Int32),
2738                                Box::new(DataType::Utf8),
2739                            ),
2740                            true,
2741                        )),
2742                        Arc::new(
2743                            vec![Some("a"), None, Some("a")]
2744                                .into_iter()
2745                                .collect::<DictionaryArray<Int32Type>>(),
2746                        ) as ArrayRef,
2747                    ),
2748                    (
2749                        Arc::new(Field::new(
2750                            "b".to_string(),
2751                            DataType::Dictionary(
2752                                Box::new(DataType::Int32),
2753                                Box::new(DataType::Utf8),
2754                            ),
2755                            true,
2756                        )),
2757                        Arc::new(
2758                            vec![Some("b"), Some("c"), Some("b")]
2759                                .into_iter()
2760                                .collect::<DictionaryArray<Int32Type>>(),
2761                        ) as ArrayRef,
2762                    ),
2763                ])),
2764                Arc::new(UInt64Array::from(vec![1, 1, 1])),
2765            ],
2766        )
2767        .expect("Failed to create RecordBatch");
2768
2769        let group_by = PhysicalGroupBy::new_single(vec![(
2770            col("labels", &batch.schema())?,
2771            "labels".to_string(),
2772        )]);
2773
2774        let aggr_expr = vec![AggregateExprBuilder::new(
2775            sum_udaf(),
2776            vec![col("value", &batch.schema())?],
2777        )
2778        .schema(Arc::clone(&batch.schema()))
2779        .alias(String::from("SUM(value)"))
2780        .build()
2781        .map(Arc::new)?];
2782
2783        let input = TestMemoryExec::try_new_exec(
2784            &[vec![batch.clone()]],
2785            Arc::<Schema>::clone(&batch.schema()),
2786            None,
2787        )?;
2788        let aggregate_exec = Arc::new(AggregateExec::try_new(
2789            AggregateMode::FinalPartitioned,
2790            group_by,
2791            aggr_expr,
2792            vec![None],
2793            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
2794            batch.schema(),
2795        )?);
2796
2797        let session_config = SessionConfig::default();
2798        let ctx = TaskContext::default().with_session_config(session_config);
2799        let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;
2800
2801        allow_duplicates! {
2802        assert_snapshot!(batches_to_string(&output), @r"
2803            +--------------+------------+
2804            | labels       | SUM(value) |
2805            +--------------+------------+
2806            | {a: a, b: b} | 2          |
2807            | {a: , b: c}  | 1          |
2808            +--------------+------------+
2809            ");
2810        }
2811
2812        Ok(())
2813    }
2814
2815    #[tokio::test]
2816    async fn test_skip_aggregation_after_first_batch() -> Result<()> {
2817        let schema = Arc::new(Schema::new(vec![
2818            Field::new("key", DataType::Int32, true),
2819            Field::new("val", DataType::Int32, true),
2820        ]));
2821
2822        let group_by =
2823            PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]);
2824
2825        let aggr_expr =
2826            vec![
2827                AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?])
2828                    .schema(Arc::clone(&schema))
2829                    .alias(String::from("COUNT(val)"))
2830                    .build()
2831                    .map(Arc::new)?,
2832            ];
2833
2834        let input_data = vec![
2835            RecordBatch::try_new(
2836                Arc::clone(&schema),
2837                vec![
2838                    Arc::new(Int32Array::from(vec![1, 2, 3])),
2839                    Arc::new(Int32Array::from(vec![0, 0, 0])),
2840                ],
2841            )
2842            .unwrap(),
2843            RecordBatch::try_new(
2844                Arc::clone(&schema),
2845                vec![
2846                    Arc::new(Int32Array::from(vec![2, 3, 4])),
2847                    Arc::new(Int32Array::from(vec![0, 0, 0])),
2848                ],
2849            )
2850            .unwrap(),
2851        ];
2852
2853        let input =
2854            TestMemoryExec::try_new_exec(&[input_data], Arc::clone(&schema), None)?;
2855        let aggregate_exec = Arc::new(AggregateExec::try_new(
2856            AggregateMode::Partial,
2857            group_by,
2858            aggr_expr,
2859            vec![None],
2860            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
2861            schema,
2862        )?);
2863
2864        let mut session_config = SessionConfig::default();
2865        session_config = session_config.set(
2866            "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
2867            &ScalarValue::Int64(Some(2)),
2868        );
2869        session_config = session_config.set(
2870            "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
2871            &ScalarValue::Float64(Some(0.1)),
2872        );
2873
2874        let ctx = TaskContext::default().with_session_config(session_config);
2875        let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;
2876
2877        allow_duplicates! {
2878            assert_snapshot!(batches_to_string(&output), @r"
2879            +-----+-------------------+
2880            | key | COUNT(val)[count] |
2881            +-----+-------------------+
2882            | 1   | 1                 |
2883            | 2   | 1                 |
2884            | 3   | 1                 |
2885            | 2   | 1                 |
2886            | 3   | 1                 |
2887            | 4   | 1                 |
2888            +-----+-------------------+
2889            ");
2890        }
2891
2892        Ok(())
2893    }
2894
2895    #[tokio::test]
2896    async fn test_skip_aggregation_after_threshold() -> Result<()> {
2897        let schema = Arc::new(Schema::new(vec![
2898            Field::new("key", DataType::Int32, true),
2899            Field::new("val", DataType::Int32, true),
2900        ]));
2901
2902        let group_by =
2903            PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]);
2904
2905        let aggr_expr =
2906            vec![
2907                AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?])
2908                    .schema(Arc::clone(&schema))
2909                    .alias(String::from("COUNT(val)"))
2910                    .build()
2911                    .map(Arc::new)?,
2912            ];
2913
2914        let input_data = vec![
2915            RecordBatch::try_new(
2916                Arc::clone(&schema),
2917                vec![
2918                    Arc::new(Int32Array::from(vec![1, 2, 3])),
2919                    Arc::new(Int32Array::from(vec![0, 0, 0])),
2920                ],
2921            )
2922            .unwrap(),
2923            RecordBatch::try_new(
2924                Arc::clone(&schema),
2925                vec![
2926                    Arc::new(Int32Array::from(vec![2, 3, 4])),
2927                    Arc::new(Int32Array::from(vec![0, 0, 0])),
2928                ],
2929            )
2930            .unwrap(),
2931            RecordBatch::try_new(
2932                Arc::clone(&schema),
2933                vec![
2934                    Arc::new(Int32Array::from(vec![2, 3, 4])),
2935                    Arc::new(Int32Array::from(vec![0, 0, 0])),
2936                ],
2937            )
2938            .unwrap(),
2939        ];
2940
2941        let input =
2942            TestMemoryExec::try_new_exec(&[input_data], Arc::clone(&schema), None)?;
2943        let aggregate_exec = Arc::new(AggregateExec::try_new(
2944            AggregateMode::Partial,
2945            group_by,
2946            aggr_expr,
2947            vec![None],
2948            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
2949            schema,
2950        )?);
2951
2952        let mut session_config = SessionConfig::default();
2953        session_config = session_config.set(
2954            "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
2955            &ScalarValue::Int64(Some(5)),
2956        );
2957        session_config = session_config.set(
2958            "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
2959            &ScalarValue::Float64(Some(0.1)),
2960        );
2961
2962        let ctx = TaskContext::default().with_session_config(session_config);
2963        let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;
2964
2965        allow_duplicates! {
2966            assert_snapshot!(batches_to_string(&output), @r"
2967            +-----+-------------------+
2968            | key | COUNT(val)[count] |
2969            +-----+-------------------+
2970            | 1   | 1                 |
2971            | 2   | 2                 |
2972            | 3   | 2                 |
2973            | 4   | 1                 |
2974            | 2   | 1                 |
2975            | 3   | 1                 |
2976            | 4   | 1                 |
2977            +-----+-------------------+
2978            ");
2979        }
2980
2981        Ok(())
2982    }
2983
2984    #[test]
2985    fn group_exprs_nullable() -> Result<()> {
2986        let input_schema = Arc::new(Schema::new(vec![
2987            Field::new("a", DataType::Float32, false),
2988            Field::new("b", DataType::Float32, false),
2989        ]));
2990
2991        let aggr_expr =
2992            vec![
2993                AggregateExprBuilder::new(count_udaf(), vec![col("a", &input_schema)?])
2994                    .schema(Arc::clone(&input_schema))
2995                    .alias("COUNT(a)")
2996                    .build()
2997                    .map(Arc::new)?,
2998            ];
2999
3000        let grouping_set = PhysicalGroupBy::new(
3001            vec![
3002                (col("a", &input_schema)?, "a".to_string()),
3003                (col("b", &input_schema)?, "b".to_string()),
3004            ],
3005            vec![
3006                (lit(ScalarValue::Float32(None)), "a".to_string()),
3007                (lit(ScalarValue::Float32(None)), "b".to_string()),
3008            ],
3009            vec![
3010                vec![false, true],  // (a, NULL)
3011                vec![false, false], // (a,b)
3012            ],
3013        );
3014        let aggr_schema = create_schema(
3015            &input_schema,
3016            &grouping_set,
3017            &aggr_expr,
3018            AggregateMode::Final,
3019        )?;
3020        let expected_schema = Schema::new(vec![
3021            Field::new("a", DataType::Float32, false),
3022            Field::new("b", DataType::Float32, true),
3023            Field::new("__grouping_id", DataType::UInt8, false),
3024            Field::new("COUNT(a)", DataType::Int64, false),
3025        ]);
3026        assert_eq!(aggr_schema, expected_schema);
3027        Ok(())
3028    }
3029
3030    // test for https://github.com/apache/datafusion/issues/13949
3031    async fn run_test_with_spill_pool_if_necessary(
3032        pool_size: usize,
3033        expect_spill: bool,
3034    ) -> Result<()> {
3035        fn create_record_batch(
3036            schema: &Arc<Schema>,
3037            data: (Vec<u32>, Vec<f64>),
3038        ) -> Result<RecordBatch> {
3039            Ok(RecordBatch::try_new(
3040                Arc::clone(schema),
3041                vec![
3042                    Arc::new(UInt32Array::from(data.0)),
3043                    Arc::new(Float64Array::from(data.1)),
3044                ],
3045            )?)
3046        }
3047
3048        let schema = Arc::new(Schema::new(vec![
3049            Field::new("a", DataType::UInt32, false),
3050            Field::new("b", DataType::Float64, false),
3051        ]));
3052
3053        let batches = vec![
3054            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
3055            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
3056        ];
3057        let plan: Arc<dyn ExecutionPlan> =
3058            TestMemoryExec::try_new_exec(&[batches], Arc::clone(&schema), None)?;
3059
3060        let grouping_set = PhysicalGroupBy::new(
3061            vec![(col("a", &schema)?, "a".to_string())],
3062            vec![],
3063            vec![vec![false]],
3064        );
3065
3066        // Test with MIN for simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count).
3067        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
3068            Arc::new(
3069                AggregateExprBuilder::new(
3070                    datafusion_functions_aggregate::min_max::min_udaf(),
3071                    vec![col("b", &schema)?],
3072                )
3073                .schema(Arc::clone(&schema))
3074                .alias("MIN(b)")
3075                .build()?,
3076            ),
3077            Arc::new(
3078                AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
3079                    .schema(Arc::clone(&schema))
3080                    .alias("AVG(b)")
3081                    .build()?,
3082            ),
3083        ];
3084
3085        let single_aggregate = Arc::new(AggregateExec::try_new(
3086            AggregateMode::Single,
3087            grouping_set,
3088            aggregates,
3089            vec![None, None],
3090            plan,
3091            Arc::clone(&schema),
3092        )?);
3093
3094        let batch_size = 2;
3095        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
3096        let task_ctx = Arc::new(
3097            TaskContext::default()
3098                .with_session_config(SessionConfig::new().with_batch_size(batch_size))
3099                .with_runtime(Arc::new(
3100                    RuntimeEnvBuilder::new()
3101                        .with_memory_pool(memory_pool)
3102                        .build()?,
3103                )),
3104        );
3105
3106        let result = collect(single_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
3107
3108        assert_spill_count_metric(expect_spill, single_aggregate);
3109
3110        allow_duplicates! {
3111            assert_snapshot!(batches_to_string(&result), @r"
3112                +---+--------+--------+
3113                | a | MIN(b) | AVG(b) |
3114                +---+--------+--------+
3115                | 2 | 1.0    | 1.0    |
3116                | 3 | 2.0    | 2.0    |
3117                | 4 | 3.0    | 3.5    |
3118                +---+--------+--------+
3119            ");
3120        }
3121
3122        Ok(())
3123    }
3124
3125    fn assert_spill_count_metric(
3126        expect_spill: bool,
3127        single_aggregate: Arc<AggregateExec>,
3128    ) {
3129        if let Some(metrics_set) = single_aggregate.metrics() {
3130            let mut spill_count = 0;
3131
3132            // Inspect metrics for SpillCount
3133            for metric in metrics_set.iter() {
3134                if let MetricValue::SpillCount(count) = metric.value() {
3135                    spill_count = count.value();
3136                    break;
3137                }
3138            }
3139
3140            if expect_spill && spill_count == 0 {
3141                panic!(
3142                    "Expected spill but SpillCount metric not found or SpillCount was 0."
3143                );
3144            } else if !expect_spill && spill_count > 0 {
3145                panic!("Expected no spill but found SpillCount metric with value greater than 0.");
3146            }
3147        } else {
3148            panic!("No metrics returned from the operator; cannot verify spilling.");
3149        }
3150    }
3151
3152    #[tokio::test]
3153    async fn test_aggregate_with_spill_if_necessary() -> Result<()> {
3154        // test with spill
3155        run_test_with_spill_pool_if_necessary(2_000, true).await?;
3156        // test without spill
3157        run_test_with_spill_pool_if_necessary(20_000, false).await?;
3158        Ok(())
3159    }
3160}