1use std::any::Any;
21use std::sync::Arc;
22
23use super::{DisplayAs, ExecutionPlanProperties, PlanProperties};
24use crate::aggregates::{
25 no_grouping::AggregateStream, row_hash::GroupedHashAggregateStream,
26 topk_stream::GroupedTopKAggregateStream,
27};
28use crate::execution_plan::{CardinalityEffect, EmissionType};
29use crate::filter_pushdown::{
30 ChildFilterDescription, ChildPushdownResult, FilterDescription, FilterPushdownPhase,
31 FilterPushdownPropagation, PushedDownPredicate,
32};
33use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
34use crate::{
35 DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode,
36 SendableRecordBatchStream, Statistics,
37};
38use datafusion_common::config::ConfigOptions;
39use datafusion_physical_expr::utils::collect_columns;
40use parking_lot::Mutex;
41use std::collections::HashSet;
42
43use arrow::array::{ArrayRef, UInt8Array, UInt16Array, UInt32Array, UInt64Array};
44use arrow::datatypes::{Field, Schema, SchemaRef};
45use arrow::record_batch::RecordBatch;
46use arrow_schema::FieldRef;
47use datafusion_common::stats::Precision;
48use datafusion_common::{
49 Constraint, Constraints, Result, ScalarValue, assert_eq_or_internal_err, not_impl_err,
50};
51use datafusion_execution::TaskContext;
52use datafusion_expr::{Accumulator, Aggregate};
53use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
54use datafusion_physical_expr::equivalence::ProjectionMapping;
55use datafusion_physical_expr::expressions::{Column, DynamicFilterPhysicalExpr, lit};
56use datafusion_physical_expr::{
57 ConstExpr, EquivalenceProperties, physical_exprs_contains,
58};
59use datafusion_physical_expr_common::physical_expr::{PhysicalExpr, fmt_sql};
60use datafusion_physical_expr_common::sort_expr::{
61 LexOrdering, LexRequirement, OrderingRequirements, PhysicalSortRequirement,
62};
63
64use datafusion_expr::utils::AggregateOrderSensitivity;
65use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays;
66use itertools::Itertools;
67
/// Storage/management of distinct group key values.
pub mod group_values;
/// Stream used when there are no GROUP BY expressions.
mod no_grouping;
/// Ordering-related helpers for grouped aggregation.
pub mod order;
/// Hash-based grouped aggregation stream.
mod row_hash;
/// Top-K (limited) aggregation support.
mod topk;
/// Stream driving top-K aggregation.
mod topk_stream;
74
/// Fixed hash seed ('A','G','G','R') so group hashing is deterministic
/// across runs and partitions.
const AGGREGATION_HASH_SEED: ahash::RandomState =
    ahash::RandomState::with_seeds('A' as u64, 'G' as u64, 'G' as u64, 'R' as u64);
78
/// Which phase of a (possibly multi-stage) aggregation this operator performs.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum AggregateMode {
    /// First stage: emits partial (intermediate) aggregate state.
    Partial,
    /// Final stage: merges partial states; requires a single input partition
    /// (see `required_input_distribution`).
    Final,
    /// Final stage over input hash-partitioned on the group keys.
    FinalPartitioned,
    /// Single-stage aggregation over a single input partition.
    Single,
    /// Single-stage aggregation over input hash-partitioned on the group keys.
    SinglePartitioned,
}
135
136impl AggregateMode {
137 pub fn is_first_stage(&self) -> bool {
141 match self {
142 AggregateMode::Partial
143 | AggregateMode::Single
144 | AggregateMode::SinglePartitioned => true,
145 AggregateMode::Final | AggregateMode::FinalPartitioned => false,
146 }
147 }
148}
149
/// Physical GROUP BY specification, including grouping-set support.
#[derive(Clone, Debug, Default)]
pub struct PhysicalGroupBy {
    /// Group-by expressions with their output aliases.
    expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
    /// Per-expression "null" stand-ins used to materialize grouping sets.
    null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
    /// One mask per grouping set: `true` at index `i` means expression `i`
    /// is replaced by its null expression in that set.
    groups: Vec<Vec<bool>>,
    /// Whether this GROUP BY stems from GROUPING SETS / CUBE / ROLLUP.
    has_grouping_set: bool,
}
183
impl PhysicalGroupBy {
    /// Create a new `PhysicalGroupBy` from explicit parts; used when grouping
    /// sets (GROUPING SETS / CUBE / ROLLUP) are involved.
    pub fn new(
        expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
        null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
        groups: Vec<Vec<bool>>,
        has_grouping_set: bool,
    ) -> Self {
        Self {
            expr,
            null_expr,
            groups,
            has_grouping_set,
        }
    }

    /// Create a plain (single grouping set) GROUP BY: one group in which no
    /// expression is null-substituted.
    pub fn new_single(expr: Vec<(Arc<dyn PhysicalExpr>, String)>) -> Self {
        let num_exprs = expr.len();
        Self {
            expr,
            null_expr: vec![],
            groups: vec![vec![false; num_exprs]],
            has_grouping_set: false,
        }
    }

    /// For each group expression, whether any grouping set replaces it with
    /// its null expression (which makes the output column nullable).
    pub fn exprs_nullable(&self) -> Vec<bool> {
        let mut exprs_nullable = vec![false; self.expr.len()];
        for group in self.groups.iter() {
            group.iter().enumerate().for_each(|(index, is_null)| {
                if *is_null {
                    exprs_nullable[index] = true;
                }
            })
        }
        exprs_nullable
    }

    /// True only when there is genuinely no grouping: no expressions *and*
    /// no grouping set machinery.
    pub fn is_true_no_grouping(&self) -> bool {
        self.is_empty() && !self.has_grouping_set
    }

    /// The group-by expressions with their aliases.
    pub fn expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
        &self.expr
    }

    /// The null-substitution expressions used by grouping sets.
    pub fn null_expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
        &self.null_expr
    }

    /// The per-grouping-set null masks.
    pub fn groups(&self) -> &[Vec<bool>] {
        &self.groups
    }

    /// Whether this GROUP BY comes from GROUPING SETS / CUBE / ROLLUP.
    pub fn has_grouping_set(&self) -> bool {
        self.has_grouping_set
    }

    /// Whether there are no group-by expressions.
    pub fn is_empty(&self) -> bool {
        self.expr.is_empty()
    }

    /// Whether this is a single grouping set (no grouping-set machinery).
    pub fn is_single(&self) -> bool {
        !self.has_grouping_set
    }

    /// The group-by expressions as evaluated against the *input* schema.
    pub fn input_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        self.expr
            .iter()
            .map(|(expr, _alias)| Arc::clone(expr))
            .collect()
    }

    /// Number of columns this GROUP BY produces: one per expression, plus
    /// the internal grouping-id column when grouping sets are present.
    fn num_output_exprs(&self) -> usize {
        let mut num_exprs = self.expr.len();
        if self.has_grouping_set {
            num_exprs += 1
        }
        num_exprs
    }

    /// `Column` expressions referencing this GROUP BY's *output*: one per
    /// group expression, plus the grouping-id column (placed after all group
    /// expressions) when grouping sets are present.
    pub fn output_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        let num_output_exprs = self.num_output_exprs();
        let mut output_exprs = Vec::with_capacity(num_output_exprs);
        output_exprs.extend(
            self.expr
                .iter()
                .enumerate()
                .take(num_output_exprs)
                .map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _),
        );
        if self.has_grouping_set {
            output_exprs.push(Arc::new(Column::new(
                Aggregate::INTERNAL_GROUPING_ID,
                self.expr.len(),
            )) as _);
        }
        output_exprs
    }

    /// Total number of group columns, including the internal grouping id.
    pub fn num_group_exprs(&self) -> usize {
        self.expr.len() + usize::from(self.has_grouping_set)
    }

    /// Schema consisting of just the group columns (see [`Self::group_fields`]).
    pub fn group_schema(&self, schema: &Schema) -> Result<SchemaRef> {
        Ok(Arc::new(Schema::new(self.group_fields(schema)?)))
    }

    /// Fields for each group column. Nullability combines the expression's
    /// own nullability with null-substitution by any grouping set, and the
    /// expression's return-field metadata is carried over. The non-nullable
    /// grouping-id field, if any, is appended last.
    fn group_fields(&self, input_schema: &Schema) -> Result<Vec<FieldRef>> {
        let mut fields = Vec::with_capacity(self.num_group_exprs());
        for ((expr, name), group_expr_nullable) in
            self.expr.iter().zip(self.exprs_nullable().into_iter())
        {
            fields.push(
                Field::new(
                    name,
                    expr.data_type(input_schema)?,
                    group_expr_nullable || expr.nullable(input_schema)?,
                )
                .with_metadata(expr.return_field(input_schema)?.metadata().clone())
                .into(),
            );
        }
        if self.has_grouping_set {
            fields.push(
                Field::new(
                    Aggregate::INTERNAL_GROUPING_ID,
                    Aggregate::grouping_id_type(self.expr.len()),
                    false,
                )
                .into(),
            );
        }
        Ok(fields)
    }

    /// Fields this GROUP BY contributes to the operator's output schema:
    /// the group fields, truncated to the output expression count.
    fn output_fields(&self, input_schema: &Schema) -> Result<Vec<FieldRef>> {
        let mut fields = self.group_fields(input_schema)?;
        fields.truncate(self.num_output_exprs());
        Ok(fields)
    }

    /// Derive the GROUP BY to use in a Final/FinalPartitioned stage: it
    /// groups on this stage's *output* columns and has no grouping sets
    /// (the grouping id, if present, becomes an ordinary group column).
    /// Note: `zip` truncates to the shorter side, so the extra grouping-id
    /// name is only consumed when `output_exprs` actually emits that column.
    pub fn as_final(&self) -> PhysicalGroupBy {
        let expr: Vec<_> =
            self.output_exprs()
                .into_iter()
                .zip(
                    self.expr.iter().map(|t| t.1.clone()).chain(std::iter::once(
                        Aggregate::INTERNAL_GROUPING_ID.to_owned(),
                    )),
                )
                .collect();
        let num_exprs = expr.len();
        let groups = if self.expr.is_empty() && !self.has_grouping_set {
            vec![]
        } else {
            vec![vec![false; num_exprs]]
        };
        Self {
            expr,
            null_expr: vec![],
            groups,
            has_grouping_set: false,
        }
    }
}
373
374impl PartialEq for PhysicalGroupBy {
375 fn eq(&self, other: &PhysicalGroupBy) -> bool {
376 self.expr.len() == other.expr.len()
377 && self
378 .expr
379 .iter()
380 .zip(other.expr.iter())
381 .all(|((expr1, name1), (expr2, name2))| expr1.eq(expr2) && name1 == name2)
382 && self.null_expr.len() == other.null_expr.len()
383 && self
384 .null_expr
385 .iter()
386 .zip(other.null_expr.iter())
387 .all(|((expr1, name1), (expr2, name2))| expr1.eq(expr2) && name1 == name2)
388 && self.groups == other.groups
389 && self.has_grouping_set == other.has_grouping_set
390 }
391}
392
/// The concrete stream implementation chosen by `AggregateExec::execute_typed`.
#[expect(clippy::large_enum_variant)]
enum StreamType {
    /// No GROUP BY: a single global set of accumulators.
    AggregateStream(AggregateStream),
    /// Hash-based grouped aggregation.
    GroupedHash(GroupedHashAggregateStream),
    /// Top-K (limited) grouped aggregation via a priority queue.
    GroupedPriorityQueue(GroupedTopKAggregateStream),
}
399
impl From<StreamType> for SendableRecordBatchStream {
    /// Erase the concrete stream type behind a pinned trait object.
    fn from(stream: StreamType) -> Self {
        match stream {
            StreamType::AggregateStream(stream) => Box::pin(stream),
            StreamType::GroupedHash(stream) => Box::pin(stream),
            StreamType::GroupedPriorityQueue(stream) => Box::pin(stream),
        }
    }
}
409
/// Dynamic filter state shared between this aggregate and the plan below it,
/// used to prune rows that cannot change MIN/MAX results.
#[derive(Debug, Clone)]
struct AggrDynFilter {
    /// The filter expression pushed to the child; initialized to `lit(true)`
    /// in `init_dynamic_filter` and tightened during execution.
    filter: Arc<DynamicFilterPhysicalExpr>,
    /// Bookkeeping for each MIN/MAX aggregate that participates in the filter.
    supported_accumulators_info: Vec<PerAccumulatorDynFilter>,
}
467
/// Per-aggregate bookkeeping for the dynamic MIN/MAX filter.
#[derive(Debug, Clone)]
struct PerAccumulatorDynFilter {
    /// Whether the tracked aggregate is a MIN or a MAX.
    aggr_type: DynamicFilterAggregateType,
    /// Index of the aggregate within `AggregateExec::aggr_expr`.
    aggr_index: usize,
    /// Current shared bound for this aggregate; starts as `ScalarValue::Null`.
    shared_bound: Arc<Mutex<ScalarValue>>,
}
484
/// Kind of aggregate supported by the dynamic filter optimization.
#[derive(Debug, Clone)]
enum DynamicFilterAggregateType {
    /// A `min` aggregate (matched case-insensitively by function name).
    Min,
    /// A `max` aggregate (matched case-insensitively by function name).
    Max,
}
491
/// Physical operator that aggregates its input according to `mode` and
/// `group_by`.
#[derive(Debug, Clone)]
pub struct AggregateExec {
    /// Aggregation stage (partial/final/single, partitioned or not).
    mode: AggregateMode,
    /// Group-by specification (expressions and grouping sets).
    group_by: PhysicalGroupBy,
    /// Aggregate functions to evaluate.
    aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
    /// Optional filter predicate associated with each aggregate expression
    /// (parallel to `aggr_expr`).
    filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
    /// Optional limit enabling the top-K stream in `execute_typed`.
    limit: Option<usize>,
    /// Child plan producing the rows to aggregate.
    pub input: Arc<dyn ExecutionPlan>,
    /// Output schema of this operator.
    schema: SchemaRef,
    /// Schema of the child plan, kept so the node can be reconstructed.
    pub input_schema: SchemaRef,
    /// Execution metrics.
    metrics: ExecutionPlanMetricsSet,
    /// Input ordering this operator requests from its child, if any.
    required_input_ordering: Option<OrderingRequirements>,
    /// How much of the group keys the input is already sorted on.
    input_order_mode: InputOrderMode,
    /// Cached plan properties (equivalences, partitioning, emission type).
    cache: PlanProperties,
    /// MIN/MAX dynamic filter shared with the child, when applicable.
    dynamic_filter: Option<Arc<AggrDynFilter>>,
}
529
impl AggregateExec {
    /// Copy of this node with new aggregate expressions; metrics are reset,
    /// everything else (mode, grouping, cache, dynamic filter, ...) is kept.
    pub fn with_new_aggr_exprs(
        &self,
        aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
    ) -> Self {
        Self {
            aggr_expr,
            required_input_ordering: self.required_input_ordering.clone(),
            metrics: ExecutionPlanMetricsSet::new(),
            input_order_mode: self.input_order_mode.clone(),
            cache: self.cache.clone(),
            mode: self.mode,
            group_by: self.group_by.clone(),
            filter_expr: self.filter_expr.clone(),
            limit: self.limit,
            input: Arc::clone(&self.input),
            schema: Arc::clone(&self.schema),
            input_schema: Arc::clone(&self.input_schema),
            dynamic_filter: self.dynamic_filter.clone(),
        }
    }

    /// Cached plan properties of this node.
    pub fn cache(&self) -> &PlanProperties {
        &self.cache
    }

    /// Create a new aggregation operator, deriving the output schema from
    /// the input schema, group-by and aggregate expressions.
    pub fn try_new(
        mode: AggregateMode,
        group_by: PhysicalGroupBy,
        aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
        filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
        input: Arc<dyn ExecutionPlan>,
        input_schema: SchemaRef,
    ) -> Result<Self> {
        let schema = create_schema(&input.schema(), &group_by, &aggr_expr, mode)?;

        let schema = Arc::new(schema);
        AggregateExec::try_new_with_schema(
            mode,
            group_by,
            aggr_expr,
            filter_expr,
            input,
            input_schema,
            schema,
        )
    }

    /// Like [`Self::try_new`], but with a precomputed output schema.
    /// Derives the required input ordering and order mode, computes the plan
    /// properties, and initializes the dynamic filter.
    fn try_new_with_schema(
        mode: AggregateMode,
        group_by: PhysicalGroupBy,
        mut aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
        filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
        input: Arc<dyn ExecutionPlan>,
        input_schema: SchemaRef,
        schema: SchemaRef,
    ) -> Result<Self> {
        // One (possibly absent) filter per aggregate expression.
        assert_eq_or_internal_err!(
            aggr_expr.len(),
            filter_expr.len(),
            "Inconsistent aggregate expr: {:?} and filter expr: {:?} for AggregateExec, their size should match",
            aggr_expr,
            filter_expr
        );

        let input_eq_properties = input.equivalence_properties();
        let groupby_exprs = group_by.input_exprs();
        // Find the longest permutation of group-by expressions already
        // ordered in the input; `indices` are the matched expression indices.
        let (new_sort_exprs, indices) =
            input_eq_properties.find_longest_permutation(&groupby_exprs)?;

        let mut new_requirements = new_sort_exprs
            .into_iter()
            .map(PhysicalSortRequirement::from)
            .collect::<Vec<_>>();

        // Append the ordering the aggregate functions themselves require;
        // note this may rewrite `aggr_expr` in place (reversed aggregates).
        let req = get_finer_aggregate_exprs_requirement(
            &mut aggr_expr,
            &group_by,
            input_eq_properties,
            &mode,
        )?;
        new_requirements.extend(req);

        // The combined requirement is advisory ("soft"), not mandatory.
        let required_input_ordering =
            LexRequirement::new(new_requirements).map(OrderingRequirements::new_soft);

        // Keep only indices of group expressions that are never
        // null-substituted by any grouping set.
        let indices: Vec<usize> = indices
            .into_iter()
            .filter(|idx| group_by.groups.iter().all(|group| !group[*idx]))
            .collect();

        // Classify how sorted the input is w.r.t. the group keys.
        let input_order_mode = if indices.len() == groupby_exprs.len()
            && !indices.is_empty()
            && group_by.groups.len() == 1
        {
            InputOrderMode::Sorted
        } else if !indices.is_empty() {
            InputOrderMode::PartiallySorted(indices)
        } else {
            InputOrderMode::Linear
        };

        // Maps input-side group expressions to output columns; used to
        // project equivalence properties and partitioning.
        let group_expr_mapping =
            ProjectionMapping::try_new(group_by.expr.clone(), &input.schema())?;

        let cache = Self::compute_properties(
            &input,
            Arc::clone(&schema),
            &group_expr_mapping,
            &mode,
            &input_order_mode,
            aggr_expr.as_slice(),
        )?;

        let mut exec = AggregateExec {
            mode,
            group_by,
            aggr_expr,
            filter_expr,
            input,
            schema,
            input_schema,
            metrics: ExecutionPlanMetricsSet::new(),
            required_input_ordering,
            limit: None,
            input_order_mode,
            cache,
            dynamic_filter: None,
        };

        exec.init_dynamic_filter();

        Ok(exec)
    }

    /// Aggregation mode of this node.
    pub fn mode(&self) -> &AggregateMode {
        &self.mode
    }

    /// Set the `limit` of this node (enables the top-K stream when eligible).
    pub fn with_limit(mut self, limit: Option<usize>) -> Self {
        self.limit = limit;
        self
    }
    /// Group-by specification.
    pub fn group_expr(&self) -> &PhysicalGroupBy {
        &self.group_by
    }

    /// Column expressions referencing the group-by output columns.
    pub fn output_group_expr(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        self.group_by.output_exprs()
    }

    /// Aggregate expressions evaluated by this node.
    pub fn aggr_expr(&self) -> &[Arc<AggregateFunctionExpr>] {
        &self.aggr_expr
    }

    /// Per-aggregate optional filter predicates (parallel to `aggr_expr`).
    pub fn filter_expr(&self) -> &[Option<Arc<dyn PhysicalExpr>>] {
        &self.filter_expr
    }

    /// Child plan of this node.
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

    /// Schema of the child plan.
    pub fn input_schema(&self) -> SchemaRef {
        Arc::clone(&self.input_schema)
    }

    /// Optional limit of this node.
    pub fn limit(&self) -> Option<usize> {
        self.limit
    }

    /// Pick and build the concrete stream implementation for `partition`.
    fn execute_typed(
        &self,
        partition: usize,
        context: &Arc<TaskContext>,
    ) -> Result<StreamType> {
        // No grouping at all: one global set of accumulators suffices.
        if self.group_by.is_true_no_grouping() {
            return Ok(StreamType::AggregateStream(AggregateStream::new(
                self, context, partition,
            )?));
        }

        // With a limit, use the top-K stream — unless this is the
        // "unordered unfiltered group-by distinct" case (see below).
        if let Some(limit) = self.limit
            && !self.is_unordered_unfiltered_group_by_distinct()
        {
            return Ok(StreamType::GroupedPriorityQueue(
                GroupedTopKAggregateStream::new(self, context, partition, limit)?,
            ));
        }

        // Default: hash-based grouped aggregation.
        Ok(StreamType::GroupedHash(GroupedHashAggregateStream::new(
            self, context, partition,
        )?))
    }

    /// When there is exactly one aggregate expression, returns its min/max
    /// descriptor (delegates to `AggregateFunctionExpr::get_minmax_desc`);
    /// `None` otherwise.
    pub fn get_minmax_desc(&self) -> Option<(FieldRef, bool)> {
        let agg_expr = self.aggr_expr.iter().exactly_one().ok()?;
        agg_expr.get_minmax_desc()
    }

    /// Detects a `SELECT DISTINCT`-like plan: grouped, but with no aggregate
    /// functions, no per-aggregate filters, and no output ordering. Used by
    /// [`Self::execute_typed`] to decide against the top-K stream.
    pub fn is_unordered_unfiltered_group_by_distinct(&self) -> bool {
        // Must actually group on something (expressions or grouping sets).
        if self.group_expr().is_empty() && !self.group_expr().has_grouping_set() {
            return false;
        }
        // DISTINCT has no aggregate functions.
        if !self.aggr_expr().is_empty() {
            return false;
        }
        // No filtered aggregates.
        if self.filter_expr().iter().any(|e| e.is_some()) {
            return false;
        }
        // No ordered aggregates. (Vacuously true here since `aggr_expr` is
        // empty at this point; kept as a defensive check.)
        if !self.aggr_expr().iter().all(|e| e.order_bys().is_empty()) {
            return false;
        }
        // Output must be unordered.
        if self.properties().output_ordering().is_some() {
            return false;
        }
        // If the (single) child has an ordering requirement, classify as
        // distinct only when that requirement is hard.
        if let Some(requirement) = self.required_input_ordering().swap_remove(0) {
            return matches!(requirement, OrderingRequirements::Hard(_));
        }
        true
    }

    /// Compute the cached [`PlanProperties`]: projected equivalence
    /// properties (plus constants/constraints), output partitioning, and
    /// emission type.
    pub fn compute_properties(
        input: &Arc<dyn ExecutionPlan>,
        schema: SchemaRef,
        group_expr_mapping: &ProjectionMapping,
        mode: &AggregateMode,
        input_order_mode: &InputOrderMode,
        aggr_exprs: &[Arc<AggregateFunctionExpr>],
    ) -> Result<PlanProperties> {
        // Project the input equivalences through the group-by mapping.
        let mut eq_properties = input
            .equivalence_properties()
            .project(group_expr_mapping, schema);

        // With no group-by columns the output is a single row, so every
        // aggregate output column is constant.
        if group_expr_mapping.is_empty() {
            let new_constants = aggr_exprs.iter().enumerate().map(|(idx, func)| {
                let column = Arc::new(Column::new(func.name(), idx));
                ConstExpr::from(column as Arc<dyn PhysicalExpr>)
            });
            eq_properties.add_constants(new_constants)?;
        }

        // Group keys are unique in the output; record that as a constraint.
        let mut constraints = eq_properties.constraints().to_vec();
        let new_constraint = Constraint::Unique(
            group_expr_mapping
                .iter()
                .flat_map(|(_, target_cols)| {
                    target_cols.iter().flat_map(|(expr, _)| {
                        expr.as_any().downcast_ref::<Column>().map(|c| c.index())
                    })
                })
                .collect(),
        );
        constraints.push(new_constraint);
        eq_properties =
            eq_properties.with_constraints(Constraints::new_unverified(constraints));

        // First-stage modes keep the input partitioning projected onto the
        // output columns; merging stages pass it through unchanged.
        let input_partitioning = input.output_partitioning().clone();
        let output_partitioning = if mode.is_first_stage() {
            let input_eq_properties = input.equivalence_properties();
            input_partitioning.project(group_expr_mapping, input_eq_properties)
        } else {
            input_partitioning.clone()
        };

        // Linear (unsorted) input forces buffering until end of input.
        let emission_type = if *input_order_mode == InputOrderMode::Linear {
            EmissionType::Final
        } else {
            input.pipeline_behavior()
        };

        Ok(PlanProperties::new(
            eq_properties,
            output_partitioning,
            emission_type,
            input.boundedness(),
        ))
    }

    /// How much of the group keys the input is already sorted on.
    pub fn input_order_mode(&self) -> &InputOrderMode {
        &self.input_order_mode
    }

    /// Derive this node's output statistics from the child's statistics.
    fn statistics_inner(&self, child_statistics: &Statistics) -> Result<Statistics> {
        let column_statistics = {
            let mut column_statistics = Statistics::unknown_column(&self.schema());

            // Group columns preserve min/max of the source column when the
            // group expression is a direct column reference.
            for (idx, (expr, _)) in self.group_by.expr.iter().enumerate() {
                if let Some(col) = expr.as_any().downcast_ref::<Column>() {
                    column_statistics[idx].max_value = child_statistics.column_statistics
                        [col.index()]
                    .max_value
                    .clone();

                    column_statistics[idx].min_value = child_statistics.column_statistics
                        [col.index()]
                    .min_value
                    .clone();
                }
            }

            column_statistics
        };
        match self.mode {
            // Final aggregation with no grouping emits exactly one row.
            AggregateMode::Final | AggregateMode::FinalPartitioned
                if self.group_by.expr.is_empty() =>
            {
                let total_byte_size =
                    Self::calculate_scaled_byte_size(child_statistics, 1);

                Ok(Statistics {
                    num_rows: Precision::Exact(1),
                    column_statistics,
                    total_byte_size,
                })
            }
            _ => {
                let num_rows = if let Some(value) = child_statistics.num_rows.get_value()
                {
                    if *value > 1 {
                        // Output is between 1 and `value` rows: inexact.
                        child_statistics.num_rows.to_inexact()
                    } else if *value == 0 {
                        // Empty input stays empty.
                        child_statistics.num_rows
                    } else {
                        // One input row yields one row per grouping set.
                        let grouping_set_num = self.group_by.groups.len();
                        child_statistics.num_rows.map(|x| x * grouping_set_num)
                    }
                } else {
                    Precision::Absent
                };

                // Scale byte size by the estimated output row count.
                let total_byte_size = num_rows
                    .get_value()
                    .and_then(|&output_rows| {
                        Self::calculate_scaled_byte_size(child_statistics, output_rows)
                            .get_value()
                            .map(|&bytes| Precision::Inexact(bytes))
                    })
                    .unwrap_or(Precision::Absent);

                Ok(Statistics {
                    num_rows,
                    column_statistics,
                    total_byte_size,
                })
            }
        }
    }

    /// Set up the MIN/MAX dynamic filter when applicable: only a `Partial`
    /// aggregation with no group-by qualifies, and only aggregates that are
    /// a case-insensitive `min`/`max` over a single bare column participate.
    fn init_dynamic_filter(&mut self) {
        if (!self.group_by.is_empty()) || (!matches!(self.mode, AggregateMode::Partial)) {
            debug_assert!(
                self.dynamic_filter.is_none(),
                "The current operator node does not support dynamic filter"
            );
            return;
        }

        // Already initialized; nothing to do.
        if self.dynamic_filter.is_some() {
            return;
        }

        let mut aggr_dyn_filters = Vec::new();
        // Columns referenced by the filter (one per qualifying aggregate).
        let mut all_cols: Vec<Arc<dyn PhysicalExpr>> = Vec::new();
        for (i, aggr_expr) in self.aggr_expr.iter().enumerate() {
            let fun_name = aggr_expr.fun().name();
            let aggr_type = if fun_name.eq_ignore_ascii_case("min") {
                DynamicFilterAggregateType::Min
            } else if fun_name.eq_ignore_ascii_case("max") {
                DynamicFilterAggregateType::Max
            } else {
                continue;
            };

            // Only a single bare-column argument is supported.
            if let [arg] = aggr_expr.expressions().as_slice()
                && arg.as_any().is::<Column>()
            {
                all_cols.push(Arc::clone(arg));
                aggr_dyn_filters.push(PerAccumulatorDynFilter {
                    aggr_type,
                    aggr_index: i,
                    shared_bound: Arc::new(Mutex::new(ScalarValue::Null)),
                });
            }
        }

        if !aggr_dyn_filters.is_empty() {
            // The filter starts as `true` and is tightened during execution.
            self.dynamic_filter = Some(Arc::new(AggrDynFilter {
                filter: Arc::new(DynamicFilterPhysicalExpr::new(all_cols, lit(true))),
                supported_accumulators_info: aggr_dyn_filters,
            }))
        }
    }

    /// Estimate total byte size for `target_row_count` rows by scaling the
    /// input's bytes-per-row; `Absent` when input stats are missing or the
    /// input is empty.
    #[inline]
    fn calculate_scaled_byte_size(
        input_stats: &Statistics,
        target_row_count: usize,
    ) -> Precision<usize> {
        match (
            input_stats.num_rows.get_value(),
            input_stats.total_byte_size.get_value(),
        ) {
            (Some(&input_rows), Some(&input_bytes)) if input_rows > 0 => {
                let bytes_per_row = input_bytes as f64 / input_rows as f64;
                let scaled_bytes =
                    (bytes_per_row * target_row_count as f64).ceil() as usize;
                Precision::Inexact(scaled_bytes)
            }
            _ => Precision::Absent,
        }
    }
}
1031
impl DisplayAs for AggregateExec {
    /// Render this operator for `EXPLAIN` output in the requested format.
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                // Show "expr as alias" only when the alias differs.
                let format_expr_with_alias =
                    |(e, alias): &(Arc<dyn PhysicalExpr>, String)| -> String {
                        let e = e.to_string();
                        if &e != alias {
                            format!("{e} as {alias}")
                        } else {
                            e
                        }
                    };

                write!(f, "AggregateExec: mode={:?}", self.mode)?;
                let g: Vec<String> = if self.group_by.is_single() {
                    self.group_by
                        .expr
                        .iter()
                        .map(format_expr_with_alias)
                        .collect()
                } else {
                    // Grouping sets: render each set, substituting the null
                    // expression where the set's mask says so.
                    self.group_by
                        .groups
                        .iter()
                        .map(|group| {
                            let terms = group
                                .iter()
                                .enumerate()
                                .map(|(idx, is_null)| {
                                    if *is_null {
                                        format_expr_with_alias(
                                            &self.group_by.null_expr[idx],
                                        )
                                    } else {
                                        format_expr_with_alias(&self.group_by.expr[idx])
                                    }
                                })
                                .collect::<Vec<String>>()
                                .join(", ");
                            format!("({terms})")
                        })
                        .collect()
                };

                write!(f, ", gby=[{}]", g.join(", "))?;

                let a: Vec<String> = self
                    .aggr_expr
                    .iter()
                    .map(|agg| agg.name().to_string())
                    .collect();
                write!(f, ", aggr=[{}]", a.join(", "))?;
                if let Some(limit) = self.limit {
                    write!(f, ", lim=[{limit}]")?;
                }

                // Linear is the default mode, so it is not worth printing.
                if self.input_order_mode != InputOrderMode::Linear {
                    write!(f, ", ordering_mode={:?}", self.input_order_mode)?;
                }
            }
            DisplayFormatType::TreeRender => {
                // Tree rendering uses SQL-like expression text (`fmt_sql`)
                // instead of the physical expression's Display impl.
                let format_expr_with_alias =
                    |(e, alias): &(Arc<dyn PhysicalExpr>, String)| -> String {
                        let expr_sql = fmt_sql(e.as_ref()).to_string();
                        if &expr_sql != alias {
                            format!("{expr_sql} as {alias}")
                        } else {
                            expr_sql
                        }
                    };

                let g: Vec<String> = if self.group_by.is_single() {
                    self.group_by
                        .expr
                        .iter()
                        .map(format_expr_with_alias)
                        .collect()
                } else {
                    // Same grouping-set rendering as the Default branch.
                    self.group_by
                        .groups
                        .iter()
                        .map(|group| {
                            let terms = group
                                .iter()
                                .enumerate()
                                .map(|(idx, is_null)| {
                                    if *is_null {
                                        format_expr_with_alias(
                                            &self.group_by.null_expr[idx],
                                        )
                                    } else {
                                        format_expr_with_alias(&self.group_by.expr[idx])
                                    }
                                })
                                .collect::<Vec<String>>()
                                .join(", ");
                            format!("({terms})")
                        })
                        .collect()
                };
                let a: Vec<String> = self
                    .aggr_expr
                    .iter()
                    .map(|agg| agg.human_display().to_string())
                    .collect();
                writeln!(f, "mode={:?}", self.mode)?;
                if !g.is_empty() {
                    writeln!(f, "group_by={}", g.join(", "))?;
                }
                if !a.is_empty() {
                    writeln!(f, "aggr={}", a.join(", "))?;
                }
            }
        }
        Ok(())
    }
}
1154
impl ExecutionPlan for AggregateExec {
    fn name(&self) -> &'static str {
        "AggregateExec"
    }

    /// Return self as `Any` so callers can downcast.
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    /// Partial mode accepts any distribution; partitioned final/single modes
    /// require hash partitioning on the group keys; plain final/single modes
    /// require all data in a single partition.
    fn required_input_distribution(&self) -> Vec<Distribution> {
        match &self.mode {
            AggregateMode::Partial => {
                vec![Distribution::UnspecifiedDistribution]
            }
            AggregateMode::FinalPartitioned | AggregateMode::SinglePartitioned => {
                vec![Distribution::HashPartitioned(self.group_by.input_exprs())]
            }
            AggregateMode::Final | AggregateMode::Single => {
                vec![Distribution::SinglePartition]
            }
        }
    }

    /// The ordering requirement derived in `try_new_with_schema`, if any.
    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
        vec![self.required_input_ordering.clone()]
    }

    /// Input order survives unless we aggregate in `Linear` (unsorted) mode.
    fn maintains_input_order(&self) -> Vec<bool> {
        vec![self.input_order_mode != InputOrderMode::Linear]
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    /// Rebuild with a new child, carrying over the limit and dynamic filter
    /// (which `try_new_with_schema` does not preserve).
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let mut me = AggregateExec::try_new_with_schema(
            self.mode,
            self.group_by.clone(),
            self.aggr_expr.clone(),
            self.filter_expr.clone(),
            Arc::clone(&children[0]),
            Arc::clone(&self.input_schema),
            Arc::clone(&self.schema),
        )?;
        me.limit = self.limit;
        me.dynamic_filter = self.dynamic_filter.clone();

        Ok(Arc::new(me))
    }

    /// Build the appropriate aggregation stream and type-erase it.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        self.execute_typed(partition, &context)
            .map(|stream| stream.into())
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        self.partition_statistics(None)
    }

    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
        let child_statistics = self.input().partition_statistics(partition)?;
        self.statistics_inner(&child_statistics)
    }

    /// Aggregation never produces more rows than its input.
    fn cardinality_effect(&self) -> CardinalityEffect {
        CardinalityEffect::LowerEqual
    }

    /// Decide which parent filters may be pushed below this aggregate: a
    /// filter is safe only when it references group-by columns exclusively
    /// (and, with grouping sets, only columns present in every set). In the
    /// Post phase, also attaches our own MIN/MAX dynamic filter when enabled.
    fn gather_filters_for_pushdown(
        &self,
        phase: FilterPushdownPhase,
        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
        config: &ConfigOptions,
    ) -> Result<FilterDescription> {
        // All columns referenced by the group-by expressions.
        let grouping_columns: HashSet<_> = self
            .group_by
            .expr()
            .iter()
            .flat_map(|(expr, _)| collect_columns(expr))
            .collect();

        let mut safe_filters = Vec::new();
        let mut unsafe_filters = Vec::new();

        for filter in parent_filters {
            let filter_columns: HashSet<_> =
                collect_columns(&filter).into_iter().collect();

            // Filters on non-group columns would change aggregate results if
            // applied below the aggregation.
            let references_non_grouping = !grouping_columns.is_empty()
                && !filter_columns.is_subset(&grouping_columns);

            if references_non_grouping {
                unsafe_filters.push(filter);
                continue;
            }

            if self.group_by.groups().len() > 1 {
                // Map each filter column back to its group expression index.
                let filter_column_indices: Vec<usize> = filter_columns
                    .iter()
                    .filter_map(|filter_col| {
                        self.group_by.expr().iter().position(|(expr, _)| {
                            collect_columns(expr).contains(filter_col)
                        })
                    })
                    .collect();

                // Unsafe if any grouping set nulls out one of those columns.
                let has_missing_column = self.group_by.groups().iter().any(|null_mask| {
                    filter_column_indices
                        .iter()
                        .any(|&idx| null_mask.get(idx) == Some(&true))
                });

                if has_missing_column {
                    unsafe_filters.push(filter);
                    continue;
                }
            }

            safe_filters.push(filter);
        }

        let child = self.children()[0];
        let mut child_desc = ChildFilterDescription::from_child(&safe_filters, child)?;

        // Report the unsafe filters as unsupported so parents keep them.
        child_desc.parent_filters.extend(
            unsafe_filters
                .into_iter()
                .map(PushedDownPredicate::unsupported),
        );

        if matches!(phase, FilterPushdownPhase::Post)
            && config.optimizer.enable_aggregate_dynamic_filter_pushdown
            && let Some(self_dyn_filter) = &self.dynamic_filter
        {
            let dyn_filter = Arc::clone(&self_dyn_filter.filter);
            child_desc = child_desc.with_self_filter(dyn_filter);
        }

        Ok(FilterDescription::new().with_child(child_desc))
    }

    /// After pushdown: if no child retained a reference to our dynamic
    /// filter (we hold the only `Arc`), drop it from a fresh copy of this
    /// node so execution does not maintain a filter nobody consumes.
    fn handle_child_pushdown_result(
        &self,
        phase: FilterPushdownPhase,
        child_pushdown_result: ChildPushdownResult,
        _config: &ConfigOptions,
    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
        let mut result = FilterPushdownPropagation::if_any(child_pushdown_result.clone());

        if matches!(phase, FilterPushdownPhase::Post) && self.dynamic_filter.is_some() {
            let dyn_filter = self.dynamic_filter.as_ref().unwrap();
            // Strong count > 1 means some child kept a clone of the filter.
            let child_accepts_dyn_filter = Arc::strong_count(dyn_filter) > 1;

            if !child_accepts_dyn_filter {
                let mut new_node = self.clone();
                new_node.dynamic_filter = None;

                result = result
                    .with_updated_node(Arc::new(new_node) as Arc<dyn ExecutionPlan>);
            }
        }

        Ok(result)
    }
}
1390
1391fn create_schema(
1392 input_schema: &Schema,
1393 group_by: &PhysicalGroupBy,
1394 aggr_expr: &[Arc<AggregateFunctionExpr>],
1395 mode: AggregateMode,
1396) -> Result<Schema> {
1397 let mut fields = Vec::with_capacity(group_by.num_output_exprs() + aggr_expr.len());
1398 fields.extend(group_by.output_fields(input_schema)?);
1399
1400 match mode {
1401 AggregateMode::Partial => {
1402 for expr in aggr_expr {
1404 fields.extend(expr.state_fields()?.iter().cloned());
1405 }
1406 }
1407 AggregateMode::Final
1408 | AggregateMode::FinalPartitioned
1409 | AggregateMode::Single
1410 | AggregateMode::SinglePartitioned => {
1411 for expr in aggr_expr {
1413 fields.push(expr.field())
1414 }
1415 }
1416 }
1417
1418 Ok(Schema::new_with_metadata(
1419 fields,
1420 input_schema.metadata().clone(),
1421 ))
1422}
1423
1424fn get_aggregate_expr_req(
1445 aggr_expr: &AggregateFunctionExpr,
1446 group_by: &PhysicalGroupBy,
1447 agg_mode: &AggregateMode,
1448 include_soft_requirement: bool,
1449) -> Option<LexOrdering> {
1450 if !agg_mode.is_first_stage() {
1454 return None;
1455 }
1456
1457 match aggr_expr.order_sensitivity() {
1458 AggregateOrderSensitivity::Insensitive => return None,
1459 AggregateOrderSensitivity::HardRequirement => {}
1460 AggregateOrderSensitivity::SoftRequirement => {
1461 if !include_soft_requirement {
1462 return None;
1463 }
1464 }
1465 AggregateOrderSensitivity::Beneficial => return None,
1466 }
1467
1468 let mut sort_exprs = aggr_expr.order_bys().to_vec();
1469 if group_by.is_single() {
1475 let physical_exprs = group_by.input_exprs();
1479 sort_exprs.retain(|sort_expr| {
1480 !physical_exprs_contains(&physical_exprs, &sort_expr.expr)
1481 });
1482 }
1483 LexOrdering::new(sort_exprs)
1484}
1485
/// Returns a new `Vec` containing the elements of `lhs` followed by the
/// elements of `rhs`.
pub fn concat_slices<T: Clone>(lhs: &[T], rhs: &[T]) -> Vec<T> {
    let mut result = Vec::with_capacity(lhs.len() + rhs.len());
    result.extend_from_slice(lhs);
    result.extend_from_slice(rhs);
    result
}
1490
1491fn determine_finer(
1495 current: &Option<LexOrdering>,
1496 candidate: &LexOrdering,
1497) -> Option<bool> {
1498 if let Some(ordering) = current {
1499 candidate.partial_cmp(ordering).map(|cmp| cmp.is_gt())
1500 } else {
1501 Some(true)
1502 }
1503}
1504
/// Compute the finest input ordering requirement that serves as many
/// aggregate functions as possible, possibly replacing aggregates with their
/// reverse form when the reversed ordering fits better. Hard requirements
/// are processed first (`include_soft_requirement == false`), then soft
/// ones. Returns the chosen requirement as sort requirements (empty when
/// none), or `not_impl_err` when hard requirements conflict irreconcilably.
pub fn get_finer_aggregate_exprs_requirement(
    aggr_exprs: &mut [Arc<AggregateFunctionExpr>],
    group_by: &PhysicalGroupBy,
    eq_properties: &EquivalenceProperties,
    agg_mode: &AggregateMode,
) -> Result<Vec<PhysicalSortRequirement>> {
    let mut requirement = None;

    for include_soft_requirement in [false, true] {
        for aggr_expr in aggr_exprs.iter_mut() {
            // Ordering this aggregate needs, normalized through the input's
            // equivalence classes; skip when it has none.
            let Some(aggr_req) = get_aggregate_expr_req(
                aggr_expr,
                group_by,
                agg_mode,
                include_soft_requirement,
            )
            .and_then(|o| eq_properties.normalize_sort_exprs(o)) else {
                continue;
            };
            // Is the forward requirement finer than the one collected so far?
            let forward_finer = determine_finer(&requirement, &aggr_req);
            if let Some(finer) = forward_finer {
                if !finer {
                    // The existing requirement already covers this aggregate.
                    continue;
                } else if eq_properties.ordering_satisfy(aggr_req.clone())? {
                    // Input satisfies the finer forward requirement: adopt it.
                    requirement = Some(aggr_req);
                    continue;
                }
            }
            // Try the reversed aggregate, whose requirement may fit better.
            if let Some(reverse_aggr_expr) = aggr_expr.reverse_expr() {
                let Some(rev_aggr_req) = get_aggregate_expr_req(
                    &reverse_aggr_expr,
                    group_by,
                    agg_mode,
                    include_soft_requirement,
                )
                .and_then(|o| eq_properties.normalize_sort_exprs(o)) else {
                    // Reverse form has no requirement at all: prefer it.
                    *aggr_expr = Arc::new(reverse_aggr_expr);
                    continue;
                };
                if let Some(finer) = determine_finer(&requirement, &rev_aggr_req) {
                    if !finer {
                        // Existing requirement covers the reverse form: use it.
                        *aggr_expr = Arc::new(reverse_aggr_expr);
                    } else if eq_properties.ordering_satisfy(rev_aggr_req.clone())? {
                        // Input satisfies the reverse requirement: adopt both.
                        *aggr_expr = Arc::new(reverse_aggr_expr);
                        requirement = Some(rev_aggr_req);
                    } else {
                        requirement = Some(aggr_req);
                    }
                } else if forward_finer.is_some() {
                    requirement = Some(aggr_req);
                } else {
                    // Neither direction is comparable with the current
                    // requirement; this is only fatal for hard requirements.
                    if !include_soft_requirement {
                        return not_impl_err!(
                            "Conflicting ordering requirements in aggregate functions is not supported"
                        );
                    }
                }
            }
        }
    }

    Ok(requirement.map_or_else(Vec::new, |o| o.into_iter().map(Into::into).collect()))
}
1608
1609pub fn aggregate_expressions(
1615 aggr_expr: &[Arc<AggregateFunctionExpr>],
1616 mode: &AggregateMode,
1617 col_idx_base: usize,
1618) -> Result<Vec<Vec<Arc<dyn PhysicalExpr>>>> {
1619 match mode {
1620 AggregateMode::Partial
1621 | AggregateMode::Single
1622 | AggregateMode::SinglePartitioned => Ok(aggr_expr
1623 .iter()
1624 .map(|agg| {
1625 let mut result = agg.expressions();
1626 result.extend(agg.order_bys().iter().map(|item| Arc::clone(&item.expr)));
1630 result
1631 })
1632 .collect()),
1633 AggregateMode::Final | AggregateMode::FinalPartitioned => {
1635 let mut col_idx_base = col_idx_base;
1636 aggr_expr
1637 .iter()
1638 .map(|agg| {
1639 let exprs = merge_expressions(col_idx_base, agg)?;
1640 col_idx_base += exprs.len();
1641 Ok(exprs)
1642 })
1643 .collect()
1644 }
1645 }
1646}
1647
1648fn merge_expressions(
1653 index_base: usize,
1654 expr: &AggregateFunctionExpr,
1655) -> Result<Vec<Arc<dyn PhysicalExpr>>> {
1656 expr.state_fields().map(|fields| {
1657 fields
1658 .iter()
1659 .enumerate()
1660 .map(|(idx, f)| Arc::new(Column::new(f.name(), index_base + idx)) as _)
1661 .collect()
1662 })
1663}
1664
1665pub type AccumulatorItem = Box<dyn Accumulator>;
1666
1667pub fn create_accumulators(
1668 aggr_expr: &[Arc<AggregateFunctionExpr>],
1669) -> Result<Vec<AccumulatorItem>> {
1670 aggr_expr
1671 .iter()
1672 .map(|expr| expr.create_accumulator())
1673 .collect()
1674}
1675
1676pub fn finalize_aggregation(
1679 accumulators: &mut [AccumulatorItem],
1680 mode: &AggregateMode,
1681) -> Result<Vec<ArrayRef>> {
1682 match mode {
1683 AggregateMode::Partial => {
1684 accumulators
1686 .iter_mut()
1687 .map(|accumulator| {
1688 accumulator.state().and_then(|e| {
1689 e.iter()
1690 .map(|v| v.to_array())
1691 .collect::<Result<Vec<ArrayRef>>>()
1692 })
1693 })
1694 .flatten_ok()
1695 .collect()
1696 }
1697 AggregateMode::Final
1698 | AggregateMode::FinalPartitioned
1699 | AggregateMode::Single
1700 | AggregateMode::SinglePartitioned => {
1701 accumulators
1703 .iter_mut()
1704 .map(|accumulator| accumulator.evaluate().and_then(|v| v.to_array()))
1705 .collect()
1706 }
1707 }
1708}
1709
1710pub fn evaluate_many(
1712 expr: &[Vec<Arc<dyn PhysicalExpr>>],
1713 batch: &RecordBatch,
1714) -> Result<Vec<Vec<ArrayRef>>> {
1715 expr.iter()
1716 .map(|expr| evaluate_expressions_to_arrays(expr, batch))
1717 .collect()
1718}
1719
1720fn evaluate_optional(
1721 expr: &[Option<Arc<dyn PhysicalExpr>>],
1722 batch: &RecordBatch,
1723) -> Result<Vec<Option<ArrayRef>>> {
1724 expr.iter()
1725 .map(|expr| {
1726 expr.as_ref()
1727 .map(|expr| {
1728 expr.evaluate(batch)
1729 .and_then(|v| v.into_array(batch.num_rows()))
1730 })
1731 .transpose()
1732 })
1733 .collect()
1734}
1735
1736fn group_id_array(group: &[bool], batch: &RecordBatch) -> Result<ArrayRef> {
1737 if group.len() > 64 {
1738 return not_impl_err!(
1739 "Grouping sets with more than 64 columns are not supported"
1740 );
1741 }
1742 let group_id = group.iter().fold(0u64, |acc, &is_null| {
1743 (acc << 1) | if is_null { 1 } else { 0 }
1744 });
1745 let num_rows = batch.num_rows();
1746 if group.len() <= 8 {
1747 Ok(Arc::new(UInt8Array::from(vec![group_id as u8; num_rows])))
1748 } else if group.len() <= 16 {
1749 Ok(Arc::new(UInt16Array::from(vec![group_id as u16; num_rows])))
1750 } else if group.len() <= 32 {
1751 Ok(Arc::new(UInt32Array::from(vec![group_id as u32; num_rows])))
1752 } else {
1753 Ok(Arc::new(UInt64Array::from(vec![group_id; num_rows])))
1754 }
1755}
1756
/// Evaluates all group-by expressions (and their null-literal counterparts)
/// against `batch`, then assembles one column set per grouping set.
///
/// For each grouping set, a column comes from `null_expr` when the set marks
/// it as null and from `expr` otherwise. When there is more than one grouping
/// set, a trailing grouping-id column is appended via [`group_id_array`].
pub fn evaluate_group_by(
    group_by: &PhysicalGroupBy,
    batch: &RecordBatch,
) -> Result<Vec<Vec<ArrayRef>>> {
    // Evaluate every expression once; results are shared by all grouping sets.
    let exprs = evaluate_expressions_to_arrays(
        group_by.expr.iter().map(|(expr, _)| expr),
        batch,
    )?;
    let null_exprs = evaluate_expressions_to_arrays(
        group_by.null_expr.iter().map(|(expr, _)| expr),
        batch,
    )?;

    group_by
        .groups
        .iter()
        .map(|group| {
            let mut group_values = Vec::with_capacity(group_by.num_group_exprs());
            group_values.extend(group.iter().enumerate().map(|(idx, is_null)| {
                if *is_null {
                    Arc::clone(&null_exprs[idx])
                } else {
                    Arc::clone(&exprs[idx])
                }
            }));
            if !group_by.is_single() {
                // Multiple grouping sets need the id column to distinguish them.
                group_values.push(group_id_array(group, batch)?);
            }
            Ok(group_values)
        })
        .collect()
}
1799
1800#[cfg(test)]
1801mod tests {
1802 use std::task::{Context, Poll};
1803
1804 use super::*;
1805 use crate::RecordBatchStream;
1806 use crate::coalesce_batches::CoalesceBatchesExec;
1807 use crate::coalesce_partitions::CoalescePartitionsExec;
1808 use crate::common;
1809 use crate::common::collect;
1810 use crate::execution_plan::Boundedness;
1811 use crate::expressions::col;
1812 use crate::metrics::MetricValue;
1813 use crate::test::TestMemoryExec;
1814 use crate::test::assert_is_pending;
1815 use crate::test::exec::{BlockingExec, assert_strong_count_converges_to_zero};
1816
1817 use arrow::array::{
1818 DictionaryArray, Float32Array, Float64Array, Int32Array, Int64Array, StructArray,
1819 UInt32Array, UInt64Array,
1820 };
1821 use arrow::compute::{SortOptions, concat_batches};
1822 use arrow::datatypes::{DataType, Int32Type};
1823 use datafusion_common::test_util::{batches_to_sort_string, batches_to_string};
1824 use datafusion_common::{DataFusionError, ScalarValue, internal_err};
1825 use datafusion_execution::config::SessionConfig;
1826 use datafusion_execution::memory_pool::FairSpillPool;
1827 use datafusion_execution::runtime_env::RuntimeEnvBuilder;
1828 use datafusion_functions_aggregate::array_agg::array_agg_udaf;
1829 use datafusion_functions_aggregate::average::avg_udaf;
1830 use datafusion_functions_aggregate::count::count_udaf;
1831 use datafusion_functions_aggregate::first_last::{first_value_udaf, last_value_udaf};
1832 use datafusion_functions_aggregate::median::median_udaf;
1833 use datafusion_functions_aggregate::sum::sum_udaf;
1834 use datafusion_physical_expr::Partitioning;
1835 use datafusion_physical_expr::PhysicalSortExpr;
1836 use datafusion_physical_expr::aggregate::AggregateExprBuilder;
1837 use datafusion_physical_expr::expressions::Literal;
1838 use datafusion_physical_expr::expressions::lit;
1839
1840 use crate::projection::ProjectionExec;
1841 use datafusion_physical_expr::projection::ProjectionExpr;
1842 use futures::{FutureExt, Stream};
1843 use insta::{allow_duplicates, assert_snapshot};
1844
1845 fn create_test_schema() -> Result<SchemaRef> {
1847 let a = Field::new("a", DataType::Int32, true);
1848 let b = Field::new("b", DataType::Int32, true);
1849 let c = Field::new("c", DataType::Int32, true);
1850 let d = Field::new("d", DataType::Int32, true);
1851 let e = Field::new("e", DataType::Int32, true);
1852 let schema = Arc::new(Schema::new(vec![a, b, c, d, e]));
1853
1854 Ok(schema)
1855 }
1856
    /// Two 4-row batches over schema (a: UInt32, b: Float64); the `a` values
    /// overlap across batches so merged aggregation produces non-trivial groups.
    fn some_data() -> (Arc<Schema>, Vec<RecordBatch>) {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::UInt32, false),
            Field::new("b", DataType::Float64, false),
        ]));

        (
            Arc::clone(&schema),
            vec![
                RecordBatch::try_new(
                    Arc::clone(&schema),
                    vec![
                        Arc::new(UInt32Array::from(vec![2, 3, 4, 4])),
                        Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
                    ],
                )
                .unwrap(),
                RecordBatch::try_new(
                    schema,
                    vec![
                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
                        Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
                    ],
                )
                .unwrap(),
            ],
        )
    }
1888
    /// Four 4-row batches over schema (a: UInt32, b: Float64), used where a
    /// batch per partition is needed (see `first_last_multi_partitions`).
    fn some_data_v2() -> (Arc<Schema>, Vec<RecordBatch>) {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::UInt32, false),
            Field::new("b", DataType::Float64, false),
        ]));

        (
            Arc::clone(&schema),
            vec![
                RecordBatch::try_new(
                    Arc::clone(&schema),
                    vec![
                        Arc::new(UInt32Array::from(vec![2, 3, 4, 4])),
                        Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
                    ],
                )
                .unwrap(),
                RecordBatch::try_new(
                    Arc::clone(&schema),
                    vec![
                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
                        Arc::new(Float64Array::from(vec![0.0, 1.0, 2.0, 3.0])),
                    ],
                )
                .unwrap(),
                RecordBatch::try_new(
                    Arc::clone(&schema),
                    vec![
                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
                        Arc::new(Float64Array::from(vec![3.0, 4.0, 5.0, 6.0])),
                    ],
                )
                .unwrap(),
                RecordBatch::try_new(
                    schema,
                    vec![
                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
                        Arc::new(Float64Array::from(vec![2.0, 3.0, 4.0, 5.0])),
                    ],
                )
                .unwrap(),
            ],
        )
    }
1939
1940 fn new_spill_ctx(batch_size: usize, max_memory: usize) -> Arc<TaskContext> {
1941 let session_config = SessionConfig::new().with_batch_size(batch_size);
1942 let runtime = RuntimeEnvBuilder::new()
1943 .with_memory_pool(Arc::new(FairSpillPool::new(max_memory)))
1944 .build_arc()
1945 .unwrap();
1946 let task_ctx = TaskContext::default()
1947 .with_session_config(session_config)
1948 .with_runtime(runtime);
1949 Arc::new(task_ctx)
1950 }
1951
    /// Runs a grouping-sets aggregation (sets: {a}, {b}, {a, b}) over `input`
    /// in two stages (Partial, then Final after coalescing partitions) and
    /// checks the results and output-row metric.
    ///
    /// When `spill` is set, a tiny memory pool forces the partial stage to
    /// spill, so the partial output may contain duplicate group keys.
    async fn check_grouping_sets(
        input: Arc<dyn ExecutionPlan>,
        spill: bool,
    ) -> Result<()> {
        let input_schema = input.schema();

        let grouping_set = PhysicalGroupBy::new(
            vec![
                (col("a", &input_schema)?, "a".to_string()),
                (col("b", &input_schema)?, "b".to_string()),
            ],
            // Null literals stand in for a column when its grouping set
            // marks it as null.
            vec![
                (lit(ScalarValue::UInt32(None)), "a".to_string()),
                (lit(ScalarValue::Float64(None)), "b".to_string()),
            ],
            vec![
                vec![false, true], vec![true, false], vec![false, false], ],
            true,
        );

        let aggregates = vec![Arc::new(
            AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)])
                .schema(Arc::clone(&input_schema))
                .alias("COUNT(1)")
                .build()?,
        )];

        let task_ctx = if spill {
            new_spill_ctx(4, 500)
        } else {
            Arc::new(TaskContext::default())
        };

        let partial_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            grouping_set.clone(),
            aggregates.clone(),
            vec![None],
            input,
            Arc::clone(&input_schema),
        )?);

        let result =
            collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;

        if spill {
            // Spilling flushes partial groups early, hence duplicate keys.
            allow_duplicates! {
                assert_snapshot!(batches_to_sort_string(&result),
                @r"
                +---+-----+---------------+-----------------+
                | a | b   | __grouping_id | COUNT(1)[count] |
                +---+-----+---------------+-----------------+
                |   | 1.0 | 2             | 1               |
                |   | 1.0 | 2             | 1               |
                |   | 2.0 | 2             | 1               |
                |   | 2.0 | 2             | 1               |
                |   | 3.0 | 2             | 1               |
                |   | 3.0 | 2             | 1               |
                |   | 4.0 | 2             | 1               |
                |   | 4.0 | 2             | 1               |
                | 2 |     | 1             | 1               |
                | 2 |     | 1             | 1               |
                | 2 | 1.0 | 0             | 1               |
                | 2 | 1.0 | 0             | 1               |
                | 3 |     | 1             | 1               |
                | 3 |     | 1             | 2               |
                | 3 | 2.0 | 0             | 2               |
                | 3 | 3.0 | 0             | 1               |
                | 4 |     | 1             | 1               |
                | 4 |     | 1             | 2               |
                | 4 | 3.0 | 0             | 1               |
                | 4 | 4.0 | 0             | 2               |
                +---+-----+---------------+-----------------+
                "
                );
            }
        } else {
            allow_duplicates! {
                assert_snapshot!(batches_to_sort_string(&result),
                @r"
                +---+-----+---------------+-----------------+
                | a | b   | __grouping_id | COUNT(1)[count] |
                +---+-----+---------------+-----------------+
                |   | 1.0 | 2             | 2               |
                |   | 2.0 | 2             | 2               |
                |   | 3.0 | 2             | 2               |
                |   | 4.0 | 2             | 2               |
                | 2 |     | 1             | 2               |
                | 2 | 1.0 | 0             | 2               |
                | 3 |     | 1             | 3               |
                | 3 | 2.0 | 0             | 2               |
                | 3 | 3.0 | 0             | 1               |
                | 4 |     | 1             | 3               |
                | 4 | 3.0 | 0             | 1               |
                | 4 | 4.0 | 0             | 2               |
                +---+-----+---------------+-----------------+
                "
                );
            }
        };

        let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));

        let final_grouping_set = grouping_set.as_final();

        // Final stage gets a larger (but still bounded) budget when spilling.
        let task_ctx = if spill {
            new_spill_ctx(4, 3160)
        } else {
            task_ctx
        };

        let merged_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Final,
            final_grouping_set,
            aggregates,
            vec![None],
            merge,
            input_schema,
        )?);

        let result = collect(merged_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
        let batch = concat_batches(&result[0].schema(), &result)?;
        assert_eq!(batch.num_columns(), 4);
        assert_eq!(batch.num_rows(), 12);

        // The final result is identical with and without spilling.
        allow_duplicates! {
            assert_snapshot!(
                batches_to_sort_string(&result),
                @r"
                +---+-----+---------------+----------+
                | a | b   | __grouping_id | COUNT(1) |
                +---+-----+---------------+----------+
                |   | 1.0 | 2             | 2        |
                |   | 2.0 | 2             | 2        |
                |   | 3.0 | 2             | 2        |
                |   | 4.0 | 2             | 2        |
                | 2 |     | 1             | 2        |
                | 2 | 1.0 | 0             | 2        |
                | 3 |     | 1             | 3        |
                | 3 | 2.0 | 0             | 2        |
                | 3 | 3.0 | 0             | 1        |
                | 4 |     | 1             | 3        |
                | 4 | 3.0 | 0             | 1        |
                | 4 | 4.0 | 0             | 2        |
                +---+-----+---------------+----------+
                "
            );
        }

        let metrics = merged_aggregate.metrics().unwrap();
        let output_rows = metrics.output_rows().unwrap();
        assert_eq!(12, output_rows);

        Ok(())
    }
2113
    /// Runs a two-stage AVG(b) GROUP BY a aggregation over `input` and checks
    /// results, statistics, and (when `spill` is set) that spill metrics are
    /// populated by the memory-constrained context.
    async fn check_aggregates(input: Arc<dyn ExecutionPlan>, spill: bool) -> Result<()> {
        let input_schema = input.schema();

        let grouping_set = PhysicalGroupBy::new(
            vec![(col("a", &input_schema)?, "a".to_string())],
            vec![],
            vec![vec![false]],
            false,
        );

        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
            AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?])
                .schema(Arc::clone(&input_schema))
                .alias("AVG(b)")
                .build()?,
        )];

        let task_ctx = if spill {
            new_spill_ctx(2, 1600)
        } else {
            Arc::new(TaskContext::default())
        };

        let partial_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            grouping_set.clone(),
            aggregates.clone(),
            vec![None],
            input,
            Arc::clone(&input_schema),
        )?);

        let result =
            collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;

        if spill {
            // Spilling flushes partial groups early, hence duplicate keys.
            allow_duplicates! {
                assert_snapshot!(batches_to_sort_string(&result), @r"
                +---+---------------+-------------+
                | a | AVG(b)[count] | AVG(b)[sum] |
                +---+---------------+-------------+
                | 2 | 1             | 1.0         |
                | 2 | 1             | 1.0         |
                | 3 | 1             | 2.0         |
                | 3 | 2             | 5.0         |
                | 4 | 3             | 11.0        |
                +---+---------------+-------------+
                ");
            }
        } else {
            allow_duplicates! {
                assert_snapshot!(batches_to_sort_string(&result), @r"
                +---+---------------+-------------+
                | a | AVG(b)[count] | AVG(b)[sum] |
                +---+---------------+-------------+
                | 2 | 2             | 2.0         |
                | 3 | 3             | 7.0         |
                | 4 | 3             | 11.0        |
                +---+---------------+-------------+
                ");
            }
        };

        let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));

        let final_grouping_set = grouping_set.as_final();

        let merged_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Final,
            final_grouping_set,
            aggregates,
            vec![None],
            merge,
            input_schema,
        )?);

        // Final-stage statistics must carry a byte-size estimate.
        let final_stats = merged_aggregate.partition_statistics(None)?;
        assert!(final_stats.total_byte_size.get_value().is_some());

        let task_ctx = if spill {
            new_spill_ctx(2, 2600)
        } else {
            Arc::clone(&task_ctx)
        };
        let result = collect(merged_aggregate.execute(0, task_ctx)?).await?;
        let batch = concat_batches(&result[0].schema(), &result)?;
        assert_eq!(batch.num_columns(), 2);
        assert_eq!(batch.num_rows(), 3);

        allow_duplicates! {
            assert_snapshot!(batches_to_sort_string(&result), @r"
            +---+--------------------+
            | a | AVG(b)             |
            +---+--------------------+
            | 2 | 1.0                |
            | 3 | 2.3333333333333335 |
            | 4 | 3.6666666666666665 |
            +---+--------------------+
            ");
        }

        let metrics = merged_aggregate.metrics().unwrap();
        let output_rows = metrics.output_rows().unwrap();
        let spill_count = metrics.spill_count().unwrap();
        let spilled_bytes = metrics.spilled_bytes().unwrap();
        let spilled_rows = metrics.spilled_rows().unwrap();

        if spill {
            // Spilled run emits more intermediate rows and must record spills.
            assert_eq!(8, output_rows);

            assert!(spill_count > 0);
            assert!(spilled_bytes > 0);
            assert!(spilled_rows > 0);
        } else {
            assert_eq!(3, output_rows);

            assert_eq!(0, spill_count);
            assert_eq!(0, spilled_bytes);
            assert_eq!(0, spilled_rows);
        }

        Ok(())
    }
2245
    /// Single-partition leaf plan producing the `some_data` batches; when
    /// `yield_first` is set, its stream returns `Pending` once before the
    /// first batch (exercises repolling in downstream operators).
    #[derive(Debug)]
    struct TestYieldingExec {
        // Whether the stream yields (returns Pending) before the first batch.
        pub yield_first: bool,
        // Cached plan properties (schema, partitioning, emission, boundedness).
        cache: PlanProperties,
    }

    impl TestYieldingExec {
        fn new(yield_first: bool) -> Self {
            let schema = some_data().0;
            let cache = Self::compute_properties(schema);
            Self { yield_first, cache }
        }

        /// Builds the fixed properties: one unknown partition, incremental
        /// emission, bounded input.
        fn compute_properties(schema: SchemaRef) -> PlanProperties {
            PlanProperties::new(
                EquivalenceProperties::new(schema),
                Partitioning::UnknownPartitioning(1),
                EmissionType::Incremental,
                Boundedness::Bounded,
            )
        }
    }
2272
    impl DisplayAs for TestYieldingExec {
        fn fmt_as(
            &self,
            t: DisplayFormatType,
            f: &mut std::fmt::Formatter,
        ) -> std::fmt::Result {
            match t {
                DisplayFormatType::Default | DisplayFormatType::Verbose => {
                    write!(f, "TestYieldingExec")
                }
                DisplayFormatType::TreeRender => {
                    // Tree rendering shows no extra detail for this test node.
                    write!(f, "")
                }
            }
        }
    }
2290
    impl ExecutionPlan for TestYieldingExec {
        fn name(&self) -> &'static str {
            "TestYieldingExec"
        }

        fn as_any(&self) -> &dyn Any {
            self
        }

        fn properties(&self) -> &PlanProperties {
            &self.cache
        }

        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
            // Leaf node: generates its own data, no inputs.
            vec![]
        }

        fn with_new_children(
            self: Arc<Self>,
            _: Vec<Arc<dyn ExecutionPlan>>,
        ) -> Result<Arc<dyn ExecutionPlan>> {
            // A leaf has no children to replace.
            internal_err!("Children cannot be replaced in {self:?}")
        }

        fn execute(
            &self,
            _partition: usize,
            _context: Arc<TaskContext>,
        ) -> Result<SendableRecordBatchStream> {
            // Start in `New` to yield Pending once, or in `Yielded` to emit
            // the first batch immediately.
            let stream = if self.yield_first {
                TestYieldingStream::New
            } else {
                TestYieldingStream::Yielded
            };

            Ok(Box::pin(stream))
        }

        fn statistics(&self) -> Result<Statistics> {
            self.partition_statistics(None)
        }

        fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
            // Only whole-plan statistics are known; per-partition is unknown.
            if partition.is_some() {
                return Ok(Statistics::new_unknown(self.schema().as_ref()));
            }
            let (_, batches) = some_data();
            Ok(common::compute_record_batch_statistics(
                &[batches],
                &self.schema(),
                None,
            ))
        }
    }
2345
    /// State machine for [`TestYieldingExec`]'s stream: optionally yields
    /// `Pending` once, then returns the two `some_data` batches, then ends.
    enum TestYieldingStream {
        New,
        Yielded,
        ReturnedBatch1,
        ReturnedBatch2,
    }

    impl Stream for TestYieldingStream {
        type Item = Result<RecordBatch>;

        fn poll_next(
            mut self: std::pin::Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Self::Item>> {
            match &*self {
                TestYieldingStream::New => {
                    // Yield once; wake immediately so the task is repolled.
                    *(self.as_mut()) = TestYieldingStream::Yielded;
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
                TestYieldingStream::Yielded => {
                    *(self.as_mut()) = TestYieldingStream::ReturnedBatch1;
                    Poll::Ready(Some(Ok(some_data().1[0].clone())))
                }
                TestYieldingStream::ReturnedBatch1 => {
                    *(self.as_mut()) = TestYieldingStream::ReturnedBatch2;
                    Poll::Ready(Some(Ok(some_data().1[1].clone())))
                }
                TestYieldingStream::ReturnedBatch2 => Poll::Ready(None),
            }
        }
    }

    impl RecordBatchStream for TestYieldingStream {
        fn schema(&self) -> SchemaRef {
            some_data().0
        }
    }
2385
    // The following tests run `check_aggregates` / `check_grouping_sets` over
    // the full matrix of {yielding, non-yielding} input x {spill, no-spill}.

    #[tokio::test]
    async fn aggregate_source_not_yielding() -> Result<()> {
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));

        check_aggregates(input, false).await
    }

    #[tokio::test]
    async fn aggregate_grouping_sets_source_not_yielding() -> Result<()> {
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));

        check_grouping_sets(input, false).await
    }

    #[tokio::test]
    async fn aggregate_source_with_yielding() -> Result<()> {
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));

        check_aggregates(input, false).await
    }

    #[tokio::test]
    async fn aggregate_grouping_sets_with_yielding() -> Result<()> {
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));

        check_grouping_sets(input, false).await
    }

    #[tokio::test]
    async fn aggregate_source_not_yielding_with_spill() -> Result<()> {
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));

        check_aggregates(input, true).await
    }

    #[tokio::test]
    async fn aggregate_grouping_sets_source_not_yielding_with_spill() -> Result<()> {
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));

        check_grouping_sets(input, true).await
    }

    #[tokio::test]
    async fn aggregate_source_with_yielding_with_spill() -> Result<()> {
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));

        check_aggregates(input, true).await
    }

    #[tokio::test]
    async fn aggregate_grouping_sets_with_yielding_with_spill() -> Result<()> {
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));

        check_grouping_sets(input, true).await
    }

    /// Builds a MEDIAN(a) aggregate over the given schema.
    fn test_median_agg_expr(schema: SchemaRef) -> Result<AggregateFunctionExpr> {
        AggregateExprBuilder::new(median_udaf(), vec![col("a", &schema)?])
            .schema(schema)
            .alias("MEDIAN(a)")
            .build()
    }
2451
    /// With a 1-byte memory limit, both the no-grouping and grouped-hash
    /// aggregation streams must fail with `ResourcesExhausted`.
    #[tokio::test]
    async fn test_oom() -> Result<()> {
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
        let input_schema = input.schema();

        let runtime = RuntimeEnvBuilder::new()
            .with_memory_limit(1, 1.0)
            .build_arc()?;
        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        let groups_none = PhysicalGroupBy::default();
        let groups_some = PhysicalGroupBy::new(
            vec![(col("a", &input_schema)?, "a".to_string())],
            vec![],
            vec![vec![false]],
            false,
        );

        let aggregates_v0: Vec<Arc<AggregateFunctionExpr>> =
            vec![Arc::new(test_median_agg_expr(Arc::clone(&input_schema))?)];

        let aggregates_v2: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
            AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?])
                .schema(Arc::clone(&input_schema))
                .alias("AVG(b)")
                .build()?,
        )];

        for (version, groups, aggregates) in [
            (0, groups_none, aggregates_v0),
            (2, groups_some, aggregates_v2),
        ] {
            let n_aggr = aggregates.len();
            let partial_aggregate = Arc::new(AggregateExec::try_new(
                AggregateMode::Single,
                groups,
                aggregates,
                vec![None; n_aggr],
                Arc::clone(&input),
                Arc::clone(&input_schema),
            )?);

            let stream = partial_aggregate.execute_typed(0, &task_ctx)?;

            // Check the chosen stream implementation for each variant.
            match version {
                0 => {
                    assert!(matches!(stream, StreamType::AggregateStream(_)));
                }
                1 => {
                    assert!(matches!(stream, StreamType::GroupedHash(_)));
                }
                2 => {
                    assert!(matches!(stream, StreamType::GroupedHash(_)));
                }
                _ => panic!("Unknown version: {version}"),
            }

            let stream: SendableRecordBatchStream = stream.into();
            let err = collect(stream).await.unwrap_err();

            // The root cause must be the memory limit, not a wrapped error.
            let err = err.find_root();
            assert!(
                matches!(err, DataFusionError::ResourcesExhausted(_)),
                "Wrong error type: {err}",
            );
        }

        Ok(())
    }
2526
    /// Dropping the collect future of an ungrouped aggregate over a blocked
    /// input must release all references to the input (no leaked tasks).
    #[tokio::test]
    async fn test_drop_cancel_without_groups() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema =
            Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, true)]));

        let groups = PhysicalGroupBy::default();

        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
            AggregateExprBuilder::new(avg_udaf(), vec![col("a", &schema)?])
                .schema(Arc::clone(&schema))
                .alias("AVG(a)")
                .build()?,
        )];

        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
        let refs = blocking_exec.refs();
        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            groups.clone(),
            aggregates.clone(),
            vec![None],
            blocking_exec,
            schema,
        )?);

        let fut = crate::collect(aggregate_exec, task_ctx);
        let mut fut = fut.boxed();

        // The blocked input keeps the future pending; dropping it must then
        // drop every strong reference to the input plan.
        assert_is_pending(&mut fut);
        drop(fut);
        assert_strong_count_converges_to_zero(refs).await;

        Ok(())
    }
2562
    /// Same as `test_drop_cancel_without_groups`, but with a GROUP BY column,
    /// exercising the grouped-hash stream's cancellation path.
    #[tokio::test]
    async fn test_drop_cancel_with_groups() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Float64, true),
            Field::new("b", DataType::Float64, true),
        ]));

        let groups =
            PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);

        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
            AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
                .schema(Arc::clone(&schema))
                .alias("AVG(b)")
                .build()?,
        )];

        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
        let refs = blocking_exec.refs();
        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            groups,
            aggregates.clone(),
            vec![None],
            blocking_exec,
            schema,
        )?);

        let fut = crate::collect(aggregate_exec, task_ctx);
        let mut fut = fut.boxed();

        // Dropping the pending future must release the input plan.
        assert_is_pending(&mut fut);
        drop(fut);
        assert_strong_count_converges_to_zero(refs).await;

        Ok(())
    }
2601
    /// Runs `first_last_multi_partitions` over the full matrix of
    /// {coalesce-batches, no-coalesce} x {FIRST_VALUE, LAST_VALUE} x
    /// {spill, no-spill}.
    #[tokio::test]
    async fn run_first_last_multi_partitions() -> Result<()> {
        for use_coalesce_batches in [false, true] {
            for is_first_acc in [false, true] {
                for spill in [false, true] {
                    first_last_multi_partitions(
                        use_coalesce_batches,
                        is_first_acc,
                        spill,
                        4200,
                    )
                    .await?
                }
            }
        }
        Ok(())
    }
2619
    /// Builds FIRST_VALUE(b) ordered by `b` with the given sort options.
    fn test_first_value_agg_expr(
        schema: &Schema,
        sort_options: SortOptions,
    ) -> Result<Arc<AggregateFunctionExpr>> {
        let order_bys = vec![PhysicalSortExpr {
            expr: col("b", schema)?,
            options: sort_options,
        }];
        let args = [col("b", schema)?];

        AggregateExprBuilder::new(first_value_udaf(), args.to_vec())
            .order_by(order_bys)
            .schema(Arc::new(schema.clone()))
            .alias(String::from("first_value(b) ORDER BY [b ASC NULLS LAST]"))
            .build()
            .map(Arc::new)
    }

    /// Builds LAST_VALUE(b) ordered by `b` with the given sort options.
    fn test_last_value_agg_expr(
        schema: &Schema,
        sort_options: SortOptions,
    ) -> Result<Arc<AggregateFunctionExpr>> {
        let order_bys = vec![PhysicalSortExpr {
            expr: col("b", schema)?,
            options: sort_options,
        }];
        let args = [col("b", schema)?];
        AggregateExprBuilder::new(last_value_udaf(), args.to_vec())
            .order_by(order_bys)
            .schema(Arc::new(schema.clone()))
            .alias(String::from("last_value(b) ORDER BY [b ASC NULLS LAST]"))
            .build()
            .map(Arc::new)
    }
2656
    /// Runs a two-stage FIRST_VALUE/LAST_VALUE(b ORDER BY b) GROUP BY a over
    /// four single-batch partitions and checks the merged result.
    ///
    /// `use_coalesce_batches` inserts a `CoalesceBatchesExec` between stages,
    /// `is_first_acc` selects FIRST_VALUE vs LAST_VALUE, and `spill` bounds
    /// the memory pool at `max_memory` to force spilling.
    async fn first_last_multi_partitions(
        use_coalesce_batches: bool,
        is_first_acc: bool,
        spill: bool,
        max_memory: usize,
    ) -> Result<()> {
        let task_ctx = if spill {
            new_spill_ctx(2, max_memory)
        } else {
            Arc::new(TaskContext::default())
        };

        let (schema, data) = some_data_v2();
        let partition1 = data[0].clone();
        let partition2 = data[1].clone();
        let partition3 = data[2].clone();
        let partition4 = data[3].clone();

        let groups =
            PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);

        let sort_options = SortOptions {
            descending: false,
            nulls_first: false,
        };
        let aggregates: Vec<Arc<AggregateFunctionExpr>> = if is_first_acc {
            vec![test_first_value_agg_expr(&schema, sort_options)?]
        } else {
            vec![test_last_value_agg_expr(&schema, sort_options)?]
        };

        // One batch per partition.
        let memory_exec = TestMemoryExec::try_new_exec(
            &[
                vec![partition1],
                vec![partition2],
                vec![partition3],
                vec![partition4],
            ],
            Arc::clone(&schema),
            None,
        )?;
        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            groups.clone(),
            aggregates.clone(),
            vec![None],
            memory_exec,
            Arc::clone(&schema),
        )?);
        let coalesce = if use_coalesce_batches {
            let coalesce = Arc::new(CoalescePartitionsExec::new(aggregate_exec));
            Arc::new(CoalesceBatchesExec::new(coalesce, 1024)) as Arc<dyn ExecutionPlan>
        } else {
            Arc::new(CoalescePartitionsExec::new(aggregate_exec))
                as Arc<dyn ExecutionPlan>
        };
        let aggregate_final = Arc::new(AggregateExec::try_new(
            AggregateMode::Final,
            groups,
            aggregates.clone(),
            vec![None],
            coalesce,
            schema,
        )?) as Arc<dyn ExecutionPlan>;

        let result = crate::collect(aggregate_final, task_ctx).await?;
        if is_first_acc {
            allow_duplicates! {
                assert_snapshot!(batches_to_string(&result), @r"
                +---+--------------------------------------------+
                | a | first_value(b) ORDER BY [b ASC NULLS LAST] |
                +---+--------------------------------------------+
                | 2 | 0.0                                        |
                | 3 | 1.0                                        |
                | 4 | 3.0                                        |
                +---+--------------------------------------------+
                ");
            }
        } else {
            allow_duplicates! {
                assert_snapshot!(batches_to_string(&result), @r"
                +---+-------------------------------------------+
                | a | last_value(b) ORDER BY [b ASC NULLS LAST] |
                +---+-------------------------------------------+
                | 2 | 3.0                                       |
                | 3 | 5.0                                       |
                | 4 | 6.0                                       |
                +---+-------------------------------------------+
                ");
            }
        };
        Ok(())
    }
2767
    /// Checks that `get_finer_aggregate_exprs_requirement` merges the ORDER BY
    /// requirements of several ARRAY_AGG expressions (with `a = b` known via
    /// equivalence) into the single finest requirement `[a ASC, c ASC]`.
    #[tokio::test]
    async fn test_get_finest_requirements() -> Result<()> {
        let test_schema = create_test_schema()?;

        let options = SortOptions {
            descending: false,
            nulls_first: false,
        };
        let col_a = &col("a", &test_schema)?;
        let col_b = &col("b", &test_schema)?;
        let col_c = &col("c", &test_schema)?;
        let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema));
        // a and b are equal, so requirements on b normalize to a.
        eq_properties.add_equal_conditions(Arc::clone(col_a), Arc::clone(col_b))?;
        let order_by_exprs = vec![
            vec![],
            vec![PhysicalSortExpr {
                expr: Arc::clone(col_a),
                options,
            }],
            vec![
                PhysicalSortExpr {
                    expr: Arc::clone(col_a),
                    options,
                },
                PhysicalSortExpr {
                    expr: Arc::clone(col_b),
                    options,
                },
                PhysicalSortExpr {
                    expr: Arc::clone(col_c),
                    options,
                },
            ],
            vec![
                PhysicalSortExpr {
                    expr: Arc::clone(col_a),
                    options,
                },
                PhysicalSortExpr {
                    expr: Arc::clone(col_b),
                    options,
                },
            ],
        ];

        let common_requirement = vec![
            PhysicalSortRequirement::new(Arc::clone(col_a), Some(options)),
            PhysicalSortRequirement::new(Arc::clone(col_c), Some(options)),
        ];
        let mut aggr_exprs = order_by_exprs
            .into_iter()
            .map(|order_by_expr| {
                AggregateExprBuilder::new(array_agg_udaf(), vec![Arc::clone(col_a)])
                    .alias("a")
                    .order_by(order_by_expr)
                    .schema(Arc::clone(&test_schema))
                    .build()
                    .map(Arc::new)
                    .unwrap()
            })
            .collect::<Vec<_>>();
        let group_by = PhysicalGroupBy::new_single(vec![]);
        let result = get_finer_aggregate_exprs_requirement(
            &mut aggr_exprs,
            &group_by,
            &eq_properties,
            &AggregateMode::Partial,
        )?;
        assert_eq!(result, common_requirement);
        Ok(())
    }
2842
2843 #[test]
2844 fn test_agg_exec_same_schema() -> Result<()> {
2845 let schema = Arc::new(Schema::new(vec![
2846 Field::new("a", DataType::Float32, true),
2847 Field::new("b", DataType::Float32, true),
2848 ]));
2849
2850 let col_a = col("a", &schema)?;
2851 let option_desc = SortOptions {
2852 descending: true,
2853 nulls_first: true,
2854 };
2855 let groups = PhysicalGroupBy::new_single(vec![(col_a, "a".to_string())]);
2856
2857 let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
2858 test_first_value_agg_expr(&schema, option_desc)?,
2859 test_last_value_agg_expr(&schema, option_desc)?,
2860 ];
2861 let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
2862 let aggregate_exec = Arc::new(AggregateExec::try_new(
2863 AggregateMode::Partial,
2864 groups,
2865 aggregates,
2866 vec![None, None],
2867 Arc::clone(&blocking_exec) as Arc<dyn ExecutionPlan>,
2868 schema,
2869 )?);
2870 let new_agg =
2871 Arc::clone(&aggregate_exec).with_new_children(vec![blocking_exec])?;
2872 assert_eq!(new_agg.schema(), aggregate_exec.schema());
2873 Ok(())
2874 }
2875
    #[tokio::test]
    async fn test_agg_exec_group_by_const() -> Result<()> {
        // GROUPING SETS (a), (b), (const) where `const` is a literal group
        // expression. Every input row has a=0.0, b=0.0, const=1, so each
        // grouping set collapses to a single group of 4 * 8192 = 32768 rows.
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Float32, true),
            Field::new("b", DataType::Float32, true),
            Field::new("const", DataType::Int32, false),
        ]));

        let col_a = col("a", &schema)?;
        let col_b = col("b", &schema)?;
        let const_expr = Arc::new(Literal::new(ScalarValue::Int32(Some(1))));

        let groups = PhysicalGroupBy::new(
            vec![
                (col_a, "a".to_string()),
                (col_b, "b".to_string()),
                (const_expr, "const".to_string()),
            ],
            // Null literals substituted when a column is absent from a set.
            vec![
                (
                    Arc::new(Literal::new(ScalarValue::Float32(None))),
                    "a".to_string(),
                ),
                (
                    Arc::new(Literal::new(ScalarValue::Float32(None))),
                    "b".to_string(),
                ),
                (
                    Arc::new(Literal::new(ScalarValue::Int32(None))),
                    "const".to_string(),
                ),
            ],
            // One row per grouping set; `true` marks a null-masked column.
            vec![
                vec![false, true, true],
                vec![true, false, true],
                vec![true, true, false],
            ],
            true,
        );

        // COUNT(1), aliased "1" in the output schema.
        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
            AggregateExprBuilder::new(count_udaf(), vec![lit(1)])
                .schema(Arc::clone(&schema))
                .alias("1")
                .build()
                .map(Arc::new)?,
        ];

        // Four identical batches of 8192 constant rows each.
        let input_batches = (0..4)
            .map(|_| {
                let a = Arc::new(Float32Array::from(vec![0.; 8192]));
                let b = Arc::new(Float32Array::from(vec![0.; 8192]));
                let c = Arc::new(Int32Array::from(vec![1; 8192]));

                RecordBatch::try_new(Arc::clone(&schema), vec![a, b, c]).unwrap()
            })
            .collect();

        let input =
            TestMemoryExec::try_new_exec(&[input_batches], Arc::clone(&schema), None)?;

        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::Single,
            groups,
            aggregates.clone(),
            vec![None],
            input,
            schema,
        )?);

        let output =
            collect(aggregate_exec.execute(0, Arc::new(TaskContext::default()))?).await?;

        // One output row per grouping set; __grouping_id encodes the mask.
        allow_duplicates! {
            assert_snapshot!(batches_to_sort_string(&output), @r"
            +-----+-----+-------+---------------+-------+
            | a   | b   | const | __grouping_id | 1     |
            +-----+-----+-------+---------------+-------+
            |     |     | 1     | 6             | 32768 |
            |     | 0.0 |       | 5             | 32768 |
            | 0.0 |     |       | 3             | 32768 |
            +-----+-----+-------+---------------+-------+
            ");
        }

        Ok(())
    }
2963
    #[tokio::test]
    async fn test_agg_exec_struct_of_dicts() -> Result<()> {
        // Grouping on a Struct column whose fields are dictionary-encoded
        // strings. Rows 0 and 2 share {a: "a", b: "b"} and merge into one
        // group (SUM = 2); row 1 ({a: NULL, b: "c"}) forms its own group.
        let batch = RecordBatch::try_new(
            Arc::new(Schema::new(vec![
                Field::new(
                    "labels".to_string(),
                    DataType::Struct(
                        vec![
                            Field::new(
                                "a".to_string(),
                                DataType::Dictionary(
                                    Box::new(DataType::Int32),
                                    Box::new(DataType::Utf8),
                                ),
                                true,
                            ),
                            Field::new(
                                "b".to_string(),
                                DataType::Dictionary(
                                    Box::new(DataType::Int32),
                                    Box::new(DataType::Utf8),
                                ),
                                true,
                            ),
                        ]
                        .into(),
                    ),
                    false,
                ),
                Field::new("value", DataType::UInt64, false),
            ])),
            vec![
                // Struct column: fields must match the schema declared above.
                Arc::new(StructArray::from(vec![
                    (
                        Arc::new(Field::new(
                            "a".to_string(),
                            DataType::Dictionary(
                                Box::new(DataType::Int32),
                                Box::new(DataType::Utf8),
                            ),
                            true,
                        )),
                        Arc::new(
                            vec![Some("a"), None, Some("a")]
                                .into_iter()
                                .collect::<DictionaryArray<Int32Type>>(),
                        ) as ArrayRef,
                    ),
                    (
                        Arc::new(Field::new(
                            "b".to_string(),
                            DataType::Dictionary(
                                Box::new(DataType::Int32),
                                Box::new(DataType::Utf8),
                            ),
                            true,
                        )),
                        Arc::new(
                            vec![Some("b"), Some("c"), Some("b")]
                                .into_iter()
                                .collect::<DictionaryArray<Int32Type>>(),
                        ) as ArrayRef,
                    ),
                ])),
                // Aggregated column: value = 1 for every row.
                Arc::new(UInt64Array::from(vec![1, 1, 1])),
            ],
        )
        .expect("Failed to create RecordBatch");

        let group_by = PhysicalGroupBy::new_single(vec![(
            col("labels", &batch.schema())?,
            "labels".to_string(),
        )]);

        let aggr_expr = vec![
            AggregateExprBuilder::new(sum_udaf(), vec![col("value", &batch.schema())?])
                .schema(Arc::clone(&batch.schema()))
                .alias(String::from("SUM(value)"))
                .build()
                .map(Arc::new)?,
        ];

        let input = TestMemoryExec::try_new_exec(
            &[vec![batch.clone()]],
            Arc::<Schema>::clone(&batch.schema()),
            None,
        )?;
        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::FinalPartitioned,
            group_by,
            aggr_expr,
            vec![None],
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            batch.schema(),
        )?);

        let session_config = SessionConfig::default();
        let ctx = TaskContext::default().with_session_config(session_config);
        let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;

        allow_duplicates! {
            assert_snapshot!(batches_to_string(&output), @r"
            +--------------+------------+
            | labels       | SUM(value) |
            +--------------+------------+
            | {a: a, b: b} | 2          |
            | {a: , b: c}  | 1          |
            +--------------+------------+
            ");
        }

        Ok(())
    }
3077
    #[tokio::test]
    async fn test_skip_aggregation_after_first_batch() -> Result<()> {
        // Partial-aggregation "probe": with the rows threshold set to 2 and
        // the ratio threshold to 0.1, the first batch (3 rows, 3 distinct
        // keys) trips the probe, so the second batch is emitted without being
        // aggregated — its keys reappear in the output, each with count 1.
        let schema = Arc::new(Schema::new(vec![
            Field::new("key", DataType::Int32, true),
            Field::new("val", DataType::Int32, true),
        ]));

        let group_by =
            PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]);

        let aggr_expr = vec![
            AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?])
                .schema(Arc::clone(&schema))
                .alias(String::from("COUNT(val)"))
                .build()
                .map(Arc::new)?,
        ];

        // Two batches; keys 2 and 3 appear in both.
        let input_data = vec![
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int32Array::from(vec![1, 2, 3])),
                    Arc::new(Int32Array::from(vec![0, 0, 0])),
                ],
            )
            .unwrap(),
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int32Array::from(vec![2, 3, 4])),
                    Arc::new(Int32Array::from(vec![0, 0, 0])),
                ],
            )
            .unwrap(),
        ];

        let input =
            TestMemoryExec::try_new_exec(&[input_data], Arc::clone(&schema), None)?;
        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            group_by,
            aggr_expr,
            vec![None],
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            schema,
        )?);

        // Configure the skip-partial-aggregation probe thresholds.
        let mut session_config = SessionConfig::default();
        session_config = session_config.set(
            "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
            &ScalarValue::Int64(Some(2)),
        );
        session_config = session_config.set(
            "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
            &ScalarValue::Float64(Some(0.1)),
        );

        let ctx = TaskContext::default().with_session_config(session_config);
        let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;

        // First batch is aggregated; second batch is passed through raw,
        // so keys 2 and 3 appear twice with partial counts.
        allow_duplicates! {
            assert_snapshot!(batches_to_string(&output), @r"
            +-----+-------------------+
            | key | COUNT(val)[count] |
            +-----+-------------------+
            | 1   | 1                 |
            | 2   | 1                 |
            | 3   | 1                 |
            | 2   | 1                 |
            | 3   | 1                 |
            | 4   | 1                 |
            +-----+-------------------+
            ");
        }

        Ok(())
    }
3156
    #[tokio::test]
    async fn test_skip_aggregation_after_threshold() -> Result<()> {
        // Partial-aggregation probe with a rows threshold of 5: the first two
        // batches (6 rows total) are aggregated together before the probe
        // decides to skip; the third batch is then passed through without
        // aggregation, so its keys reappear with count 1 each.
        let schema = Arc::new(Schema::new(vec![
            Field::new("key", DataType::Int32, true),
            Field::new("val", DataType::Int32, true),
        ]));

        let group_by =
            PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]);

        let aggr_expr = vec![
            AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?])
                .schema(Arc::clone(&schema))
                .alias(String::from("COUNT(val)"))
                .build()
                .map(Arc::new)?,
        ];

        // Three batches; keys 2, 3, 4 repeat across batches.
        let input_data = vec![
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int32Array::from(vec![1, 2, 3])),
                    Arc::new(Int32Array::from(vec![0, 0, 0])),
                ],
            )
            .unwrap(),
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int32Array::from(vec![2, 3, 4])),
                    Arc::new(Int32Array::from(vec![0, 0, 0])),
                ],
            )
            .unwrap(),
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int32Array::from(vec![2, 3, 4])),
                    Arc::new(Int32Array::from(vec![0, 0, 0])),
                ],
            )
            .unwrap(),
        ];

        let input =
            TestMemoryExec::try_new_exec(&[input_data], Arc::clone(&schema), None)?;
        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            group_by,
            aggr_expr,
            vec![None],
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            schema,
        )?);

        // Configure the skip-partial-aggregation probe thresholds.
        let mut session_config = SessionConfig::default();
        session_config = session_config.set(
            "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
            &ScalarValue::Int64(Some(5)),
        );
        session_config = session_config.set(
            "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
            &ScalarValue::Float64(Some(0.1)),
        );

        let ctx = TaskContext::default().with_session_config(session_config);
        let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;

        // Rows 1-6 aggregated (counts up to 2); third batch passed through.
        allow_duplicates! {
            assert_snapshot!(batches_to_string(&output), @r"
            +-----+-------------------+
            | key | COUNT(val)[count] |
            +-----+-------------------+
            | 1   | 1                 |
            | 2   | 2                 |
            | 3   | 2                 |
            | 4   | 1                 |
            | 2   | 1                 |
            | 3   | 1                 |
            | 4   | 1                 |
            +-----+-------------------+
            ");
        }

        Ok(())
    }
3244
3245 #[test]
3246 fn group_exprs_nullable() -> Result<()> {
3247 let input_schema = Arc::new(Schema::new(vec![
3248 Field::new("a", DataType::Float32, false),
3249 Field::new("b", DataType::Float32, false),
3250 ]));
3251
3252 let aggr_expr = vec![
3253 AggregateExprBuilder::new(count_udaf(), vec![col("a", &input_schema)?])
3254 .schema(Arc::clone(&input_schema))
3255 .alias("COUNT(a)")
3256 .build()
3257 .map(Arc::new)?,
3258 ];
3259
3260 let grouping_set = PhysicalGroupBy::new(
3261 vec![
3262 (col("a", &input_schema)?, "a".to_string()),
3263 (col("b", &input_schema)?, "b".to_string()),
3264 ],
3265 vec![
3266 (lit(ScalarValue::Float32(None)), "a".to_string()),
3267 (lit(ScalarValue::Float32(None)), "b".to_string()),
3268 ],
3269 vec![
3270 vec![false, true], vec![false, false], ],
3273 true,
3274 );
3275 let aggr_schema = create_schema(
3276 &input_schema,
3277 &grouping_set,
3278 &aggr_expr,
3279 AggregateMode::Final,
3280 )?;
3281 let expected_schema = Schema::new(vec![
3282 Field::new("a", DataType::Float32, false),
3283 Field::new("b", DataType::Float32, true),
3284 Field::new("__grouping_id", DataType::UInt8, false),
3285 Field::new("COUNT(a)", DataType::Int64, false),
3286 ]);
3287 assert_eq!(aggr_schema, expected_schema);
3288 Ok(())
3289 }
3290
    /// Shared driver: runs a grouped MIN/AVG aggregation with a
    /// `FairSpillPool` of `pool_size` bytes and batch size 2, asserts that
    /// spilling occurred iff `expect_spill`, and checks the results are the
    /// same regardless of spilling.
    async fn run_test_with_spill_pool_if_necessary(
        pool_size: usize,
        expect_spill: bool,
    ) -> Result<()> {
        // Builds a two-column (UInt32, Float64) batch from parallel vectors.
        fn create_record_batch(
            schema: &Arc<Schema>,
            data: (Vec<u32>, Vec<f64>),
        ) -> Result<RecordBatch> {
            Ok(RecordBatch::try_new(
                Arc::clone(schema),
                vec![
                    Arc::new(UInt32Array::from(data.0)),
                    Arc::new(Float64Array::from(data.1)),
                ],
            )?)
        }

        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::UInt32, false),
            Field::new("b", DataType::Float64, false),
        ]));

        // Two identical batches; key 4 appears twice per batch.
        let batches = vec![
            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
        ];
        let plan: Arc<dyn ExecutionPlan> =
            TestMemoryExec::try_new_exec(&[batches], Arc::clone(&schema), None)?;

        let grouping_set = PhysicalGroupBy::new(
            vec![(col("a", &schema)?, "a".to_string())],
            vec![],
            vec![vec![false]],
            false,
        );

        // MIN(b) and AVG(b), grouped by `a`.
        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
            Arc::new(
                AggregateExprBuilder::new(
                    datafusion_functions_aggregate::min_max::min_udaf(),
                    vec![col("b", &schema)?],
                )
                .schema(Arc::clone(&schema))
                .alias("MIN(b)")
                .build()?,
            ),
            Arc::new(
                AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
                    .schema(Arc::clone(&schema))
                    .alias("AVG(b)")
                    .build()?,
            ),
        ];

        let single_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Single,
            grouping_set,
            aggregates,
            vec![None, None],
            plan,
            Arc::clone(&schema),
        )?);

        // Tiny batch size plus a bounded memory pool drives spilling when
        // `pool_size` is small enough.
        let batch_size = 2;
        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
        let task_ctx = Arc::new(
            TaskContext::default()
                .with_session_config(SessionConfig::new().with_batch_size(batch_size))
                .with_runtime(Arc::new(
                    RuntimeEnvBuilder::new()
                        .with_memory_pool(memory_pool)
                        .build()?,
                )),
        );

        let result = collect(single_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;

        assert_spill_count_metric(expect_spill, single_aggregate);

        // Results must be correct whether or not spilling happened.
        allow_duplicates! {
            assert_snapshot!(batches_to_string(&result), @r"
            +---+--------+--------+
            | a | MIN(b) | AVG(b) |
            +---+--------+--------+
            | 2 | 1.0    | 1.0    |
            | 3 | 2.0    | 2.0    |
            | 4 | 3.0    | 3.5    |
            +---+--------+--------+
            ");
        }

        Ok(())
    }
3386
3387 fn assert_spill_count_metric(
3388 expect_spill: bool,
3389 single_aggregate: Arc<AggregateExec>,
3390 ) {
3391 if let Some(metrics_set) = single_aggregate.metrics() {
3392 let mut spill_count = 0;
3393
3394 for metric in metrics_set.iter() {
3396 if let MetricValue::SpillCount(count) = metric.value() {
3397 spill_count = count.value();
3398 break;
3399 }
3400 }
3401
3402 if expect_spill && spill_count == 0 {
3403 panic!(
3404 "Expected spill but SpillCount metric not found or SpillCount was 0."
3405 );
3406 } else if !expect_spill && spill_count > 0 {
3407 panic!(
3408 "Expected no spill but found SpillCount metric with value greater than 0."
3409 );
3410 }
3411 } else {
3412 panic!("No metrics returned from the operator; cannot verify spilling.");
3413 }
3414 }
3415
3416 #[tokio::test]
3417 async fn test_aggregate_with_spill_if_necessary() -> Result<()> {
3418 run_test_with_spill_pool_if_necessary(2_000, true).await?;
3420 run_test_with_spill_pool_if_necessary(20_000, false).await?;
3422 Ok(())
3423 }
3424
    #[tokio::test]
    async fn test_grouped_aggregation_respects_memory_limit() -> Result<()> {
        // Grouped aggregation under a small (2000-byte) memory pool. The
        // projection prepends a constant string column `l` to widen the group
        // keys. Either outcome is acceptable: the query completes (in which
        // case it must have spilled and produced correct results) or it fails
        // with ResourcesExhausted.

        // Builds a two-column (UInt32, Float64) batch from parallel vectors.
        fn create_record_batch(
            schema: &Arc<Schema>,
            data: (Vec<u32>, Vec<f64>),
        ) -> Result<RecordBatch> {
            Ok(RecordBatch::try_new(
                Arc::clone(schema),
                vec![
                    Arc::new(UInt32Array::from(data.0)),
                    Arc::new(Float64Array::from(data.1)),
                ],
            )?)
        }

        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::UInt32, false),
            Field::new("b", DataType::Float64, false),
        ]));

        let batches = vec![
            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
        ];
        let plan: Arc<dyn ExecutionPlan> =
            TestMemoryExec::try_new_exec(&[batches], Arc::clone(&schema), None)?;

        // Project (l = "0", a, b) so the group key includes a string column.
        let proj = ProjectionExec::try_new(
            vec![
                ProjectionExpr::new(lit("0"), "l".to_string()),
                ProjectionExpr::new_from_expression(col("a", &schema)?, &schema)?,
                ProjectionExpr::new_from_expression(col("b", &schema)?, &schema)?,
            ],
            plan,
        )?;
        let plan: Arc<dyn ExecutionPlan> = Arc::new(proj);
        let schema = plan.schema();

        let grouping_set = PhysicalGroupBy::new(
            vec![
                (col("l", &schema)?, "l".to_string()),
                (col("a", &schema)?, "a".to_string()),
            ],
            vec![],
            vec![vec![false, false]],
            false,
        );

        // MIN(b) and AVG(b), grouped by (l, a).
        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
            Arc::new(
                AggregateExprBuilder::new(
                    datafusion_functions_aggregate::min_max::min_udaf(),
                    vec![col("b", &schema)?],
                )
                .schema(Arc::clone(&schema))
                .alias("MIN(b)")
                .build()?,
            ),
            Arc::new(
                AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
                    .schema(Arc::clone(&schema))
                    .alias("AVG(b)")
                    .build()?,
            ),
        ];

        let single_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Single,
            grouping_set,
            aggregates,
            vec![None, None],
            plan,
            Arc::clone(&schema),
        )?);

        // Small batch size + small pool to force memory pressure.
        let batch_size = 2;
        let memory_pool = Arc::new(FairSpillPool::new(2000));
        let task_ctx = Arc::new(
            TaskContext::default()
                .with_session_config(SessionConfig::new().with_batch_size(batch_size))
                .with_runtime(Arc::new(
                    RuntimeEnvBuilder::new()
                        .with_memory_pool(memory_pool)
                        .build()?,
                )),
        );

        let result = collect(single_aggregate.execute(0, Arc::clone(&task_ctx))?).await;
        match result {
            // Completed: must have spilled and produced correct output.
            Ok(result) => {
                assert_spill_count_metric(true, single_aggregate);

                allow_duplicates! {
                    assert_snapshot!(batches_to_string(&result), @r"
                    +---+---+--------+--------+
                    | l | a | MIN(b) | AVG(b) |
                    +---+---+--------+--------+
                    | 0 | 2 | 1.0    | 1.0    |
                    | 0 | 3 | 2.0    | 2.0    |
                    | 0 | 4 | 3.0    | 3.5    |
                    +---+---+--------+--------+
                    ");
                }
            }
            // Failed: only a resources-exhausted error is acceptable.
            Err(e) => assert!(matches!(e, DataFusionError::ResourcesExhausted(_))),
        }

        Ok(())
    }
3535
3536 #[tokio::test]
3537 async fn test_aggregate_statistics_edge_cases() -> Result<()> {
3538 use crate::test::exec::StatisticsExec;
3539 use datafusion_common::ColumnStatistics;
3540
3541 let schema = Arc::new(Schema::new(vec![
3542 Field::new("a", DataType::Int32, false),
3543 Field::new("b", DataType::Float64, false),
3544 ]));
3545
3546 let input = Arc::new(StatisticsExec::new(
3548 Statistics {
3549 num_rows: Precision::Exact(100),
3550 total_byte_size: Precision::Absent,
3551 column_statistics: vec![
3552 ColumnStatistics::new_unknown(),
3553 ColumnStatistics::new_unknown(),
3554 ],
3555 },
3556 (*schema).clone(),
3557 )) as Arc<dyn ExecutionPlan>;
3558
3559 let agg = Arc::new(AggregateExec::try_new(
3560 AggregateMode::Final,
3561 PhysicalGroupBy::default(),
3562 vec![Arc::new(
3563 AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema)?])
3564 .schema(Arc::clone(&schema))
3565 .alias("COUNT(a)")
3566 .build()?,
3567 )],
3568 vec![None],
3569 input,
3570 Arc::clone(&schema),
3571 )?);
3572
3573 let stats = agg.partition_statistics(None)?;
3574 assert_eq!(stats.total_byte_size, Precision::Absent);
3575
3576 let input_zero = Arc::new(StatisticsExec::new(
3578 Statistics {
3579 num_rows: Precision::Exact(0),
3580 total_byte_size: Precision::Exact(0),
3581 column_statistics: vec![
3582 ColumnStatistics::new_unknown(),
3583 ColumnStatistics::new_unknown(),
3584 ],
3585 },
3586 (*schema).clone(),
3587 )) as Arc<dyn ExecutionPlan>;
3588
3589 let agg_zero = Arc::new(AggregateExec::try_new(
3590 AggregateMode::Final,
3591 PhysicalGroupBy::default(),
3592 vec![Arc::new(
3593 AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema)?])
3594 .schema(Arc::clone(&schema))
3595 .alias("COUNT(a)")
3596 .build()?,
3597 )],
3598 vec![None],
3599 input_zero,
3600 Arc::clone(&schema),
3601 )?);
3602
3603 let stats_zero = agg_zero.partition_statistics(None)?;
3604 assert_eq!(stats_zero.total_byte_size, Precision::Absent);
3605
3606 Ok(())
3607 }
3608
    #[tokio::test]
    async fn test_order_is_retained_when_spilling() -> Result<()> {
        // The input is declared sorted on `b DESC`; even when the aggregation
        // spills (forced via a tiny memory budget), the output must come back
        // in that same `b DESC` order.
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Int64, false),
            Field::new("c", DataType::Int64, false),
        ]));

        // Three single-row batches with `b` descending: 2, 1, 0.
        let batches = vec![vec![
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int64Array::from(vec![2])),
                    Arc::new(Int64Array::from(vec![2])),
                    Arc::new(Int64Array::from(vec![1])),
                ],
            )?,
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int64Array::from(vec![1])),
                    Arc::new(Int64Array::from(vec![1])),
                    Arc::new(Int64Array::from(vec![1])),
                ],
            )?,
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int64Array::from(vec![0])),
                    Arc::new(Int64Array::from(vec![0])),
                    Arc::new(Int64Array::from(vec![1])),
                ],
            )?,
        ]];
        let scan = TestMemoryExec::try_new(&batches, Arc::clone(&schema), None)?;
        // Declare the scan's output ordering as `b DESC`.
        let scan = scan.try_with_sort_information(vec![
            LexOrdering::new([PhysicalSortExpr::new(
                col("b", schema.as_ref())?,
                SortOptions::default().desc(),
            )])
            .unwrap(),
        ])?;

        // SUM(c) grouped by (b, c), single-phase aggregation.
        let aggr = Arc::new(AggregateExec::try_new(
            AggregateMode::Single,
            PhysicalGroupBy::new(
                vec![
                    (col("b", schema.as_ref())?, "b".to_string()),
                    (col("c", schema.as_ref())?, "c".to_string()),
                ],
                vec![],
                vec![vec![false, false]],
                false,
            ),
            vec![Arc::new(
                AggregateExprBuilder::new(sum_udaf(), vec![col("c", schema.as_ref())?])
                    .schema(Arc::clone(&schema))
                    .alias("SUM(c)")
                    .build()?,
            )],
            vec![None],
            Arc::new(scan) as Arc<dyn ExecutionPlan>,
            Arc::clone(&schema),
        )?);

        // A 600-byte budget forces spilling during execution.
        let task_ctx = new_spill_ctx(1, 600);
        let result = collect(aggr.execute(0, Arc::clone(&task_ctx))?).await?;
        assert_spill_count_metric(true, aggr);

        // Output preserves the input's `b DESC` ordering: 2, 1, 0.
        allow_duplicates! {
            assert_snapshot!(batches_to_string(&result), @r"
            +---+---+--------+
            | b | c | SUM(c) |
            +---+---+--------+
            | 2 | 1 | 1      |
            | 1 | 1 | 1      |
            | 0 | 1 | 1      |
            +---+---+--------+
            ");
        }
        Ok(())
    }
3691}