1use std::any::Any;
21use std::sync::Arc;
22
23use super::{DisplayAs, ExecutionPlanProperties, PlanProperties};
24use crate::aggregates::{
25 no_grouping::AggregateStream, row_hash::GroupedHashAggregateStream,
26 topk_stream::GroupedTopKAggregateStream,
27};
28use crate::execution_plan::{CardinalityEffect, EmissionType};
29use crate::filter_pushdown::{
30 ChildFilterDescription, ChildPushdownResult, FilterDescription, FilterPushdownPhase,
31 FilterPushdownPropagation, PushedDownPredicate,
32};
33use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
34use crate::{
35 DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode,
36 SendableRecordBatchStream, Statistics, check_if_same_properties,
37};
38use datafusion_common::config::ConfigOptions;
39use datafusion_physical_expr::utils::collect_columns;
40use parking_lot::Mutex;
41use std::collections::HashSet;
42
43use arrow::array::{ArrayRef, UInt8Array, UInt16Array, UInt32Array, UInt64Array};
44use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
45use arrow::record_batch::RecordBatch;
46use arrow_schema::FieldRef;
47use datafusion_common::stats::Precision;
48use datafusion_common::{
49 Constraint, Constraints, Result, ScalarValue, assert_eq_or_internal_err, not_impl_err,
50};
51use datafusion_execution::TaskContext;
52use datafusion_expr::{Accumulator, Aggregate};
53use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
54use datafusion_physical_expr::equivalence::ProjectionMapping;
55use datafusion_physical_expr::expressions::{Column, DynamicFilterPhysicalExpr, lit};
56use datafusion_physical_expr::{
57 ConstExpr, EquivalenceProperties, physical_exprs_contains,
58};
59use datafusion_physical_expr_common::physical_expr::{PhysicalExpr, fmt_sql};
60use datafusion_physical_expr_common::sort_expr::{
61 LexOrdering, LexRequirement, OrderingRequirements, PhysicalSortRequirement,
62};
63
64use datafusion_expr::utils::AggregateOrderSensitivity;
65use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays;
66use itertools::Itertools;
67use topk::hash_table::is_supported_hash_key_type;
68use topk::heap::is_supported_heap_type;
69
70pub mod group_values;
71mod no_grouping;
72pub mod order;
73mod row_hash;
74mod topk;
75mod topk_stream;
76
77pub fn topk_types_supported(key_type: &DataType, value_type: &DataType) -> bool {
85 is_supported_hash_key_type(key_type) && is_supported_heap_type(value_type)
86}
87
// Fixed-seed hash state (the seeds spell "AGGR") so that aggregation hashing
// is deterministic across runs and across plan nodes that must agree on
// hash values.
const AGGREGATION_HASH_SEED: ahash::RandomState =
    ahash::RandomState::with_seeds('A' as u64, 'G' as u64, 'G' as u64, 'R' as u64);
91
/// The kind of input an aggregation consumes, derived from its
/// [`AggregateMode`] via [`AggregateMode::input_mode`].
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum AggregateInputMode {
    /// Input rows are raw, not-yet-aggregated values.
    Raw,
    /// Input rows are intermediate (partial) aggregate state.
    Partial,
}
106
/// The kind of output an aggregation emits, derived from its
/// [`AggregateMode`] via [`AggregateMode::output_mode`].
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum AggregateOutputMode {
    /// Emits intermediate (partial) aggregate state.
    Partial,
    /// Emits final aggregate values.
    Final,
}
121
/// Execution mode of an [`AggregateExec`]. The mode determines whether the
/// node consumes raw rows or partial aggregate state (see
/// [`Self::input_mode`]), whether it emits partial state or final values (see
/// [`Self::output_mode`]), and the input distribution it requires (see
/// `ExecutionPlan::required_input_distribution`).
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum AggregateMode {
    /// Consumes raw input, emits partial state; accepts any input distribution.
    Partial,
    /// Consumes partial state from a single partition, emits final values.
    Final,
    /// Consumes partial state hash-partitioned on the group keys, emits final
    /// values.
    FinalPartitioned,
    /// Consumes raw input from a single partition, emits final values in one
    /// step.
    Single,
    /// Consumes raw input hash-partitioned on the group keys, emits final
    /// values in one step.
    SinglePartitioned,
    /// Consumes partial state and re-emits (further reduced) partial state;
    /// accepts any input distribution.
    PartialReduce,
}
209
210impl AggregateMode {
211 pub fn input_mode(&self) -> AggregateInputMode {
217 match self {
218 AggregateMode::Partial
219 | AggregateMode::Single
220 | AggregateMode::SinglePartitioned => AggregateInputMode::Raw,
221 AggregateMode::Final
222 | AggregateMode::FinalPartitioned
223 | AggregateMode::PartialReduce => AggregateInputMode::Partial,
224 }
225 }
226
227 pub fn output_mode(&self) -> AggregateOutputMode {
233 match self {
234 AggregateMode::Final
235 | AggregateMode::FinalPartitioned
236 | AggregateMode::Single
237 | AggregateMode::SinglePartitioned => AggregateOutputMode::Final,
238 AggregateMode::Partial | AggregateMode::PartialReduce => {
239 AggregateOutputMode::Partial
240 }
241 }
242 }
243}
244
/// Physical GROUP BY specification: the grouping expressions plus, for
/// grouping sets (CUBE / ROLLUP / GROUPING SETS), per-set null masks and the
/// expressions substituted where a set null's a column out.
#[derive(Clone, Debug, Default)]
pub struct PhysicalGroupBy {
    // Grouping expressions together with their output column names.
    expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
    // Expressions substituted for a group column when a grouping set null's
    // it out; indexed in step with `expr`.
    null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
    // One boolean mask per grouping set; `true` at index i means `expr[i]` is
    // replaced by `null_expr[i]` for that set.
    groups: Vec<Vec<bool>>,
    // Whether this grouping uses grouping sets (adds the internal grouping-id
    // output column).
    has_grouping_set: bool,
}
278
impl PhysicalGroupBy {
    /// Creates a grouping with explicit null expressions and grouping-set
    /// masks (the general, grouping-set-capable form).
    pub fn new(
        expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
        null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
        groups: Vec<Vec<bool>>,
        has_grouping_set: bool,
    ) -> Self {
        Self {
            expr,
            null_expr,
            groups,
            has_grouping_set,
        }
    }

    /// Creates a plain (non-grouping-set) GROUP BY over `expr`: a single
    /// group in which no expression is null'd out.
    pub fn new_single(expr: Vec<(Arc<dyn PhysicalExpr>, String)>) -> Self {
        let num_exprs = expr.len();
        Self {
            expr,
            null_expr: vec![],
            groups: vec![vec![false; num_exprs]],
            has_grouping_set: false,
        }
    }

    /// For each group expression, whether any grouping set null's it out —
    /// which forces the corresponding output column to be nullable.
    pub fn exprs_nullable(&self) -> Vec<bool> {
        let mut exprs_nullable = vec![false; self.expr.len()];
        for group in self.groups.iter() {
            group.iter().enumerate().for_each(|(index, is_null)| {
                if *is_null {
                    exprs_nullable[index] = true;
                }
            })
        }
        exprs_nullable
    }

    /// True when there is genuinely no grouping at all: no group expressions
    /// and no grouping set (e.g. `SELECT sum(x) FROM t`).
    pub fn is_true_no_grouping(&self) -> bool {
        self.is_empty() && !self.has_grouping_set
    }

    /// The grouping expressions with their output column names.
    pub fn expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
        &self.expr
    }

    /// The null expressions used by grouping sets.
    pub fn null_expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
        &self.null_expr
    }

    /// The per-grouping-set null masks.
    pub fn groups(&self) -> &[Vec<bool>] {
        &self.groups
    }

    /// Whether this grouping uses grouping sets.
    pub fn has_grouping_set(&self) -> bool {
        self.has_grouping_set
    }

    /// True when there are no group expressions (note: a grouping set may
    /// still be present — see [`Self::is_true_no_grouping`]).
    pub fn is_empty(&self) -> bool {
        self.expr.is_empty()
    }

    /// True when this is a single grouping (no grouping sets).
    pub fn is_single(&self) -> bool {
        !self.has_grouping_set
    }

    /// The grouping expressions (to be evaluated against the *input* schema),
    /// without their aliases.
    pub fn input_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        self.expr
            .iter()
            .map(|(expr, _alias)| Arc::clone(expr))
            .collect()
    }

    /// Number of grouping-related output columns: one per group expression,
    /// plus the internal grouping-id column when grouping sets are used.
    fn num_output_exprs(&self) -> usize {
        let mut num_exprs = self.expr.len();
        if self.has_grouping_set {
            num_exprs += 1
        }
        num_exprs
    }

    /// Column references into this node's *output* schema for the group
    /// columns, followed by the internal grouping-id column when present.
    pub fn output_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        let num_output_exprs = self.num_output_exprs();
        let mut output_exprs = Vec::with_capacity(num_output_exprs);
        output_exprs.extend(
            self.expr
                .iter()
                .enumerate()
                .take(num_output_exprs)
                .map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _),
        );
        if self.has_grouping_set {
            // The grouping id sits immediately after the group expressions.
            output_exprs.push(Arc::new(Column::new(
                Aggregate::INTERNAL_GROUPING_ID,
                self.expr.len(),
            )) as _);
        }
        output_exprs
    }

    /// Number of expressions tracked per group value, including the internal
    /// grouping id for grouping sets.
    pub fn num_group_exprs(&self) -> usize {
        self.expr.len() + usize::from(self.has_grouping_set)
    }

    /// Schema consisting solely of the group fields (see
    /// [`Self::group_fields`]), given the input `schema`.
    pub fn group_schema(&self, schema: &Schema) -> Result<SchemaRef> {
        Ok(Arc::new(Schema::new(self.group_fields(schema)?)))
    }

    /// One field per group expression — nullable when either the expression
    /// itself is nullable or some grouping set null's it out — plus the
    /// non-nullable internal grouping-id field when grouping sets are used.
    /// Each expression's return-field metadata is preserved.
    fn group_fields(&self, input_schema: &Schema) -> Result<Vec<FieldRef>> {
        let mut fields = Vec::with_capacity(self.num_group_exprs());
        for ((expr, name), group_expr_nullable) in
            self.expr.iter().zip(self.exprs_nullable().into_iter())
        {
            fields.push(
                Field::new(
                    name,
                    expr.data_type(input_schema)?,
                    group_expr_nullable || expr.nullable(input_schema)?,
                )
                .with_metadata(expr.return_field(input_schema)?.metadata().clone())
                .into(),
            );
        }
        if self.has_grouping_set {
            fields.push(
                Field::new(
                    Aggregate::INTERNAL_GROUPING_ID,
                    Aggregate::grouping_id_type(self.expr.len()),
                    false,
                )
                .into(),
            );
        }
        Ok(fields)
    }

    /// The fields this grouping contributes to the output schema: the group
    /// fields truncated to the number of output expressions.
    fn output_fields(&self, input_schema: &Schema) -> Result<Vec<FieldRef>> {
        let mut fields = self.group_fields(input_schema)?;
        fields.truncate(self.num_output_exprs());
        Ok(fields)
    }

    /// Returns the grouping as seen by a final-mode aggregation consuming
    /// this grouping's output: group columns become plain column references
    /// into that output (including the internal grouping id), grouping sets
    /// collapse to a single all-false group, and null expressions are dropped.
    pub fn as_final(&self) -> PhysicalGroupBy {
        // `zip` truncates: the grouping-id name is only consumed when
        // `output_exprs` actually includes the grouping-id column.
        let expr: Vec<_> =
            self.output_exprs()
                .into_iter()
                .zip(
                    self.expr.iter().map(|t| t.1.clone()).chain(std::iter::once(
                        Aggregate::INTERNAL_GROUPING_ID.to_owned(),
                    )),
                )
                .collect();
        let num_exprs = expr.len();
        let groups = if self.expr.is_empty() && !self.has_grouping_set {
            vec![]
        } else {
            vec![vec![false; num_exprs]]
        };
        Self {
            expr,
            null_expr: vec![],
            groups,
            has_grouping_set: false,
        }
    }
}
468
469impl PartialEq for PhysicalGroupBy {
470 fn eq(&self, other: &PhysicalGroupBy) -> bool {
471 self.expr.len() == other.expr.len()
472 && self
473 .expr
474 .iter()
475 .zip(other.expr.iter())
476 .all(|((expr1, name1), (expr2, name2))| expr1.eq(expr2) && name1 == name2)
477 && self.null_expr.len() == other.null_expr.len()
478 && self
479 .null_expr
480 .iter()
481 .zip(other.null_expr.iter())
482 .all(|((expr1, name1), (expr2, name2))| expr1.eq(expr2) && name1 == name2)
483 && self.groups == other.groups
484 && self.has_grouping_set == other.has_grouping_set
485 }
486}
487
/// The concrete stream implementations an [`AggregateExec`] can produce
/// (see `AggregateExec::execute_typed`).
#[expect(clippy::large_enum_variant)]
enum StreamType {
    /// No grouping at all: a single accumulator set over the whole input.
    AggregateStream(AggregateStream),
    /// Hash-based grouped aggregation — the general case.
    GroupedHash(GroupedHashAggregateStream),
    /// Limited ("top-K") grouped aggregation backed by a priority queue.
    GroupedPriorityQueue(GroupedTopKAggregateStream),
}
494
495impl From<StreamType> for SendableRecordBatchStream {
496 fn from(stream: StreamType) -> Self {
497 match stream {
498 StreamType::AggregateStream(stream) => Box::pin(stream),
499 StreamType::GroupedHash(stream) => Box::pin(stream),
500 StreamType::GroupedPriorityQueue(stream) => Box::pin(stream),
501 }
502 }
503}
504
/// Dynamic-filter state for an aggregation eligible for MIN/MAX filter
/// pushdown (no grouping, `Partial` mode — see
/// `AggregateExec::init_dynamic_filter`).
#[derive(Debug, Clone)]
struct AggrDynFilter {
    // The filter expression offered to the child for pushdown; created as
    // `lit(true)`. NOTE(review): presumably tightened at runtime from the
    // accumulators' shared bounds by the execution streams — confirm there.
    filter: Arc<DynamicFilterPhysicalExpr>,
    // Bookkeeping for the accumulators that participate in the filter.
    supported_accumulators_info: Vec<PerAccumulatorDynFilter>,
}
562
/// Per-accumulator state for one MIN/MAX aggregate participating in the
/// dynamic filter.
#[derive(Debug, Clone)]
struct PerAccumulatorDynFilter {
    // Whether the accumulator computes a MIN or a MAX.
    aggr_type: DynamicFilterAggregateType,
    // Index of the accumulator within `AggregateExec::aggr_expr`.
    aggr_index: usize,
    // Bound shared across partitions; initialized to `ScalarValue::Null`.
    shared_bound: Arc<Mutex<ScalarValue>>,
}
579
/// The aggregate function kinds supported by the dynamic filter.
#[derive(Debug, Clone)]
enum DynamicFilterAggregateType {
    /// A `min` aggregate.
    Min,
    /// A `max` aggregate.
    Max,
}
586
/// Options for limiting ("top-K") grouped aggregation output.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LimitOptions {
    /// Maximum number of groups to emit.
    pub limit: usize,
    /// Sort direction associated with the limit, when known
    /// (`None` when undetermined).
    pub descending: Option<bool>,
}
596
597impl LimitOptions {
598 pub fn new(limit: usize) -> Self {
600 Self {
601 limit,
602 descending: None,
603 }
604 }
605
606 pub fn new_with_order(limit: usize, descending: bool) -> Self {
608 Self {
609 limit,
610 descending: Some(descending),
611 }
612 }
613
614 pub fn limit(&self) -> usize {
615 self.limit
616 }
617
618 pub fn descending(&self) -> Option<bool> {
619 self.descending
620 }
621}
622
/// Physical operator that evaluates aggregate functions, optionally grouped
/// by a set of expressions (including grouping sets).
#[derive(Debug, Clone)]
pub struct AggregateExec {
    // Execution mode (partial / final / single, etc.).
    mode: AggregateMode,
    // Grouping specification.
    group_by: Arc<PhysicalGroupBy>,
    // Aggregate expressions to evaluate.
    aggr_expr: Arc<[Arc<AggregateFunctionExpr>]>,
    // Per-aggregate FILTER expressions (`None` where absent), indexed in
    // step with `aggr_expr`.
    filter_expr: Arc<[Option<Arc<dyn PhysicalExpr>>]>,
    // Optional limit ("top-K") configuration.
    limit_options: Option<LimitOptions>,
    /// The input (child) plan.
    pub input: Arc<dyn ExecutionPlan>,
    // Output schema of this node.
    schema: SchemaRef,
    /// The input schema this node was created with (kept separately from
    /// `input.schema()`).
    pub input_schema: SchemaRef,
    // Execution metrics.
    metrics: ExecutionPlanMetricsSet,
    // Input ordering requirement derived at construction time.
    required_input_ordering: Option<OrderingRequirements>,
    // How the input ordering relates to the group-by expressions.
    input_order_mode: InputOrderMode,
    // Cached plan properties (equivalences, partitioning, emission type).
    cache: Arc<PlanProperties>,
    // Optional MIN/MAX dynamic filter (see `init_dynamic_filter`).
    dynamic_filter: Option<Arc<AggrDynFilter>>,
}
663
664impl AggregateExec {
665 pub fn with_new_aggr_exprs(
669 &self,
670 aggr_expr: impl Into<Arc<[Arc<AggregateFunctionExpr>]>>,
671 ) -> Self {
672 Self {
673 aggr_expr: aggr_expr.into(),
674 required_input_ordering: self.required_input_ordering.clone(),
676 metrics: ExecutionPlanMetricsSet::new(),
677 input_order_mode: self.input_order_mode.clone(),
678 cache: Arc::clone(&self.cache),
679 mode: self.mode,
680 group_by: Arc::clone(&self.group_by),
681 filter_expr: Arc::clone(&self.filter_expr),
682 limit_options: self.limit_options,
683 input: Arc::clone(&self.input),
684 schema: Arc::clone(&self.schema),
685 input_schema: Arc::clone(&self.input_schema),
686 dynamic_filter: self.dynamic_filter.clone(),
687 }
688 }
689
690 pub fn with_new_limit_options(&self, limit_options: Option<LimitOptions>) -> Self {
692 Self {
693 limit_options,
694 required_input_ordering: self.required_input_ordering.clone(),
696 metrics: ExecutionPlanMetricsSet::new(),
697 input_order_mode: self.input_order_mode.clone(),
698 cache: Arc::clone(&self.cache),
699 mode: self.mode,
700 group_by: Arc::clone(&self.group_by),
701 aggr_expr: Arc::clone(&self.aggr_expr),
702 filter_expr: Arc::clone(&self.filter_expr),
703 input: Arc::clone(&self.input),
704 schema: Arc::clone(&self.schema),
705 input_schema: Arc::clone(&self.input_schema),
706 dynamic_filter: self.dynamic_filter.clone(),
707 }
708 }
709
    /// The precomputed [`PlanProperties`] for this node.
    pub fn cache(&self) -> &PlanProperties {
        &self.cache
    }
713
714 pub fn try_new(
716 mode: AggregateMode,
717 group_by: impl Into<Arc<PhysicalGroupBy>>,
718 aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
719 filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
720 input: Arc<dyn ExecutionPlan>,
721 input_schema: SchemaRef,
722 ) -> Result<Self> {
723 let group_by = group_by.into();
724 let schema = create_schema(&input.schema(), &group_by, &aggr_expr, mode)?;
725
726 let schema = Arc::new(schema);
727 AggregateExec::try_new_with_schema(
728 mode,
729 group_by,
730 aggr_expr,
731 filter_expr,
732 input,
733 input_schema,
734 schema,
735 )
736 }
737
    /// Creates an `AggregateExec` with a pre-computed output `schema`.
    ///
    /// Besides storing the inputs, this derives:
    /// - the (soft) required input ordering, from the longest ordered
    ///   permutation of the group-by expressions plus any ordering the
    ///   aggregate expressions themselves require,
    /// - the [`InputOrderMode`] (`Sorted` / `PartiallySorted` / `Linear`),
    /// - the cached [`PlanProperties`] (see `compute_properties`), and
    /// - the optional MIN/MAX dynamic filter (see `init_dynamic_filter`).
    ///
    /// # Errors
    /// Returns an internal error if `aggr_expr` and `filter_expr` differ in
    /// length, or an error if any of the derivations above fail.
    fn try_new_with_schema(
        mode: AggregateMode,
        group_by: impl Into<Arc<PhysicalGroupBy>>,
        mut aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
        filter_expr: impl Into<Arc<[Option<Arc<dyn PhysicalExpr>>]>>,
        input: Arc<dyn ExecutionPlan>,
        input_schema: SchemaRef,
        schema: SchemaRef,
    ) -> Result<Self> {
        let group_by = group_by.into();
        let filter_expr = filter_expr.into();

        // Every aggregate expression has a (possibly absent) FILTER slot.
        assert_eq_or_internal_err!(
            aggr_expr.len(),
            filter_expr.len(),
            "Inconsistent aggregate expr: {:?} and filter expr: {:?} for AggregateExec, their size should match",
            aggr_expr,
            filter_expr
        );

        let input_eq_properties = input.equivalence_properties();
        let groupby_exprs = group_by.input_exprs();
        // The longest permutation of group-by expressions the input is
        // already ordered by; `indices` refer into `groupby_exprs`.
        let (new_sort_exprs, indices) =
            input_eq_properties.find_longest_permutation(&groupby_exprs)?;

        let mut new_requirements = new_sort_exprs
            .into_iter()
            .map(PhysicalSortRequirement::from)
            .collect::<Vec<_>>();

        // Append the finest common ordering required by the aggregate
        // expressions themselves (takes `aggr_expr` mutably, so it may
        // adjust the expressions in place).
        let req = get_finer_aggregate_exprs_requirement(
            &mut aggr_expr,
            &group_by,
            input_eq_properties,
            &mode,
        )?;
        new_requirements.extend(req);

        let required_input_ordering =
            LexRequirement::new(new_requirements).map(OrderingRequirements::new_soft);

        // Keep only the expressions no grouping set null's out: an
        // expression that is null'd in some set cannot anchor input ordering.
        let indices: Vec<usize> = indices
            .into_iter()
            .filter(|idx| group_by.groups.iter().all(|group| !group[*idx]))
            .collect();

        // Fully ordered single grouping -> Sorted; some ordered prefix ->
        // PartiallySorted; otherwise Linear.
        let input_order_mode = if indices.len() == groupby_exprs.len()
            && !indices.is_empty()
            && group_by.groups.len() == 1
        {
            InputOrderMode::Sorted
        } else if !indices.is_empty() {
            InputOrderMode::PartiallySorted(indices)
        } else {
            InputOrderMode::Linear
        };

        // Maps group-by input expressions onto output columns, used to
        // project the input's equivalences and partitioning.
        let group_expr_mapping =
            ProjectionMapping::try_new(group_by.expr.clone(), &input.schema())?;

        let cache = Self::compute_properties(
            &input,
            Arc::clone(&schema),
            &group_expr_mapping,
            &mode,
            &input_order_mode,
            aggr_expr.as_ref(),
        )?;

        let mut exec = AggregateExec {
            mode,
            group_by,
            aggr_expr: aggr_expr.into(),
            filter_expr,
            input,
            schema,
            input_schema,
            metrics: ExecutionPlanMetricsSet::new(),
            required_input_ordering,
            limit_options: None,
            input_order_mode,
            cache: Arc::new(cache),
            dynamic_filter: None,
        };

        exec.init_dynamic_filter();

        Ok(exec)
    }
847
    /// The execution mode of this aggregation.
    pub fn mode(&self) -> &AggregateMode {
        &self.mode
    }

    /// Builder-style setter for the limit ("top-K") options.
    pub fn with_limit_options(mut self, limit_options: Option<LimitOptions>) -> Self {
        self.limit_options = limit_options;
        self
    }

    /// The limit ("top-K") options, if any.
    pub fn limit_options(&self) -> Option<LimitOptions> {
        self.limit_options
    }

    /// The grouping specification.
    pub fn group_expr(&self) -> &PhysicalGroupBy {
        &self.group_by
    }

    /// Column references to the group values in this node's output schema.
    pub fn output_group_expr(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        self.group_by.output_exprs()
    }

    /// The aggregate expressions.
    pub fn aggr_expr(&self) -> &[Arc<AggregateFunctionExpr>] {
        &self.aggr_expr
    }

    /// Per-aggregate FILTER expressions (`None` where absent), indexed in
    /// step with [`Self::aggr_expr`].
    pub fn filter_expr(&self) -> &[Option<Arc<dyn PhysicalExpr>>] {
        &self.filter_expr
    }

    /// The input (child) plan.
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

    /// The input schema this node was created with (kept separately from
    /// `input.schema()`).
    pub fn input_schema(&self) -> SchemaRef {
        Arc::clone(&self.input_schema)
    }
893
    /// Chooses and constructs the concrete stream implementation for
    /// `partition`:
    /// - no grouping at all -> [`AggregateStream`];
    /// - a limit is set and this is not an unordered, unfiltered
    ///   DISTINCT-style group by -> [`GroupedTopKAggregateStream`];
    /// - otherwise -> [`GroupedHashAggregateStream`].
    fn execute_typed(
        &self,
        partition: usize,
        context: &Arc<TaskContext>,
    ) -> Result<StreamType> {
        if self.group_by.is_true_no_grouping() {
            return Ok(StreamType::AggregateStream(AggregateStream::new(
                self, context, partition,
            )?));
        }

        if let Some(config) = self.limit_options
            && !self.is_unordered_unfiltered_group_by_distinct()
        {
            return Ok(StreamType::GroupedPriorityQueue(
                GroupedTopKAggregateStream::new(self, context, partition, config.limit)?,
            ));
        }

        Ok(StreamType::GroupedHash(GroupedHashAggregateStream::new(
            self, context, partition,
        )?))
    }
919
920 pub fn get_minmax_desc(&self) -> Option<(FieldRef, bool)> {
922 let agg_expr = self.aggr_expr.iter().exactly_one().ok()?;
923 agg_expr.get_minmax_desc()
924 }
925
    /// True when this aggregation is effectively an unordered, unfiltered
    /// `GROUP BY` with no aggregate functions — i.e. a DISTINCT-style
    /// de-duplication (used to decide against the top-K stream in
    /// `execute_typed`).
    pub fn is_unordered_unfiltered_group_by_distinct(&self) -> bool {
        // A limit with a known sort direction implies an ordered top-K query.
        if self
            .limit_options()
            .and_then(|config| config.descending)
            .is_some()
        {
            return false;
        }
        // Must actually group by something (expressions or a grouping set).
        if self.group_expr().is_empty() && !self.group_expr().has_grouping_set() {
            return false;
        }
        // Any aggregate function disqualifies the DISTINCT interpretation.
        if !self.aggr_expr().is_empty() {
            return false;
        }
        // As does any per-aggregate FILTER clause.
        if self.filter_expr().iter().any(|e| e.is_some()) {
            return false;
        }
        // NOTE(review): vacuous — `aggr_expr` is known empty at this point,
        // so `all(..)` is trivially true and this branch never fires.
        if !self.aggr_expr().iter().all(|e| e.order_bys().is_empty()) {
            return false;
        }
        // An ordered output means downstream relies on ordering.
        if self.properties().output_ordering().is_some() {
            return false;
        }
        // If an input-ordering requirement exists, classify as DISTINCT-style
        // only when it is a hard requirement (a soft one returns false here).
        if let Some(requirement) = self.required_input_ordering().swap_remove(0) {
            return matches!(requirement, OrderingRequirements::Hard(_));
        }
        true
    }
965
    /// Computes the cached [`PlanProperties`] — equivalence properties,
    /// output partitioning, emission type, and boundedness — for an
    /// aggregation node.
    pub fn compute_properties(
        input: &Arc<dyn ExecutionPlan>,
        schema: SchemaRef,
        group_expr_mapping: &ProjectionMapping,
        mode: &AggregateMode,
        input_order_mode: &InputOrderMode,
        aggr_exprs: &[Arc<AggregateFunctionExpr>],
    ) -> Result<PlanProperties> {
        // Project the input's equivalence properties through the group-by
        // expressions onto the output schema.
        let mut eq_properties = input
            .equivalence_properties()
            .project(group_expr_mapping, schema);

        // Without any grouping the result is a single row, so every
        // aggregate output column is a constant.
        if group_expr_mapping.is_empty() {
            let new_constants = aggr_exprs.iter().enumerate().map(|(idx, func)| {
                let column = Arc::new(Column::new(func.name(), idx));
                ConstExpr::from(column as Arc<dyn PhysicalExpr>)
            });
            eq_properties.add_constants(new_constants)?;
        }

        // The group-by output columns identify each output row uniquely;
        // record that as an (unverified) uniqueness constraint.
        let mut constraints = eq_properties.constraints().to_vec();
        let new_constraint = Constraint::Unique(
            group_expr_mapping
                .iter()
                .flat_map(|(_, target_cols)| {
                    target_cols.iter().flat_map(|(expr, _)| {
                        expr.as_any().downcast_ref::<Column>().map(|c| c.index())
                    })
                })
                .collect(),
        );
        constraints.push(new_constraint);
        eq_properties =
            eq_properties.with_constraints(Constraints::new_unverified(constraints));

        // Raw-input modes project the input partitioning through the
        // group-by mapping; partial-input modes keep it as-is.
        let input_partitioning = input.output_partitioning().clone();
        let output_partitioning = match mode.input_mode() {
            AggregateInputMode::Raw => {
                let input_eq_properties = input.equivalence_properties();
                input_partitioning.project(group_expr_mapping, input_eq_properties)
            }
            AggregateInputMode::Partial => input_partitioning.clone(),
        };

        // A Linear (unordered) aggregation must consume all input before it
        // can emit anything; otherwise inherit the input's behavior.
        let emission_type = if *input_order_mode == InputOrderMode::Linear {
            EmissionType::Final
        } else {
            input.pipeline_behavior()
        };

        Ok(PlanProperties::new(
            eq_properties,
            output_partitioning,
            emission_type,
            input.boundedness(),
        ))
    }
1034
    /// How the input's ordering relates to the group-by expressions
    /// (`Sorted` / `PartiallySorted` / `Linear`).
    pub fn input_order_mode(&self) -> &InputOrderMode {
        &self.input_order_mode
    }
1038
    /// Derives this node's statistics from the child's statistics.
    ///
    /// Group-by columns that are direct column references inherit the
    /// child's min/max. For row counts: a no-group-by final aggregation
    /// yields exactly one row; otherwise the child's count serves as an
    /// (inexact) bound, and the byte size is scaled by the child's average
    /// row width.
    fn statistics_inner(&self, child_statistics: &Statistics) -> Result<Statistics> {
        let column_statistics = {
            let mut column_statistics = Statistics::unknown_column(&self.schema());

            // A group-by output column that is a bare column reference has
            // the same value domain as the child column, so min/max carry
            // over unchanged.
            for (idx, (expr, _)) in self.group_by.expr.iter().enumerate() {
                if let Some(col) = expr.as_any().downcast_ref::<Column>() {
                    column_statistics[idx].max_value = child_statistics.column_statistics
                        [col.index()]
                    .max_value
                    .clone();

                    column_statistics[idx].min_value = child_statistics.column_statistics
                        [col.index()]
                    .min_value
                    .clone();
                }
            }

            column_statistics
        };
        match self.mode {
            // Final aggregation without grouping emits exactly one row.
            AggregateMode::Final | AggregateMode::FinalPartitioned
                if self.group_by.expr.is_empty() =>
            {
                let total_byte_size =
                    Self::calculate_scaled_byte_size(child_statistics, 1);

                Ok(Statistics {
                    num_rows: Precision::Exact(1),
                    column_statistics,
                    total_byte_size,
                })
            }
            _ => {
                let num_rows = if let Some(value) = child_statistics.num_rows.get_value()
                {
                    if *value > 1 {
                        // Aggregation cannot emit more rows than it reads,
                        // but may emit fewer: downgrade to inexact.
                        child_statistics.num_rows.to_inexact()
                    } else if *value == 0 {
                        // Empty input stays empty.
                        child_statistics.num_rows
                    } else {
                        // A single input row yields one output row per
                        // grouping set.
                        let grouping_set_num = self.group_by.groups.len();
                        child_statistics.num_rows.map(|x| x * grouping_set_num)
                    }
                } else {
                    Precision::Absent
                };

                // Scale the child's bytes-per-row by the estimated output
                // row count; always inexact.
                let total_byte_size = num_rows
                    .get_value()
                    .and_then(|&output_rows| {
                        Self::calculate_scaled_byte_size(child_statistics, output_rows)
                            .get_value()
                            .map(|&bytes| Precision::Inexact(bytes))
                    })
                    .unwrap_or(Precision::Absent);

                Ok(Statistics {
                    num_rows,
                    column_statistics,
                    total_byte_size,
                })
            }
        }
    }
1114
    /// Sets up the MIN/MAX dynamic filter, when applicable.
    ///
    /// Only a `Partial`-mode aggregation with no group-by expressions is
    /// eligible. Every aggregate must be a `min` or `max`; encountering any
    /// other function aborts without installing a filter. Each eligible
    /// aggregate whose single argument is a bare column gets a shared bound
    /// (initialized to `ScalarValue::Null`) and participates in one shared
    /// [`DynamicFilterPhysicalExpr`] created as `lit(true)`.
    ///
    /// NOTE(review): a min/max whose argument is NOT a bare `Column` is
    /// silently excluded from the filter while the others still participate —
    /// confirm the runtime-updated predicate remains conservative with
    /// respect to the excluded accumulator.
    fn init_dynamic_filter(&mut self) {
        // Dynamic filtering only applies to a no-grouping Partial aggregation.
        if (!self.group_by.is_empty()) || (self.mode != AggregateMode::Partial) {
            debug_assert!(
                self.dynamic_filter.is_none(),
                "The current operator node does not support dynamic filter"
            );
            return;
        }

        // Already initialized; keep the existing filter.
        if self.dynamic_filter.is_some() {
            return;
        }

        let mut aggr_dyn_filters = Vec::new();
        // Columns referenced by the participating accumulators; they become
        // the child expressions of the dynamic filter.
        let mut all_cols: Vec<Arc<dyn PhysicalExpr>> = Vec::new();
        for (i, aggr_expr) in self.aggr_expr.iter().enumerate() {
            let fun_name = aggr_expr.fun().name();
            let aggr_type = if fun_name.eq_ignore_ascii_case("min") {
                DynamicFilterAggregateType::Min
            } else if fun_name.eq_ignore_ascii_case("max") {
                DynamicFilterAggregateType::Max
            } else {
                // Any non-min/max aggregate disables the optimization
                // entirely for this node.
                return;
            };

            // Only a single, bare-column argument participates.
            if let [arg] = aggr_expr.expressions().as_slice()
                && arg.as_any().is::<Column>()
            {
                all_cols.push(Arc::clone(arg));
                aggr_dyn_filters.push(PerAccumulatorDynFilter {
                    aggr_type,
                    aggr_index: i,
                    shared_bound: Arc::new(Mutex::new(ScalarValue::Null)),
                });
            }
        }

        if !aggr_dyn_filters.is_empty() {
            self.dynamic_filter = Some(Arc::new(AggrDynFilter {
                filter: Arc::new(DynamicFilterPhysicalExpr::new(all_cols, lit(true))),
                supported_accumulators_info: aggr_dyn_filters,
            }))
        }
    }
1173
1174 #[inline]
1180 fn calculate_scaled_byte_size(
1181 input_stats: &Statistics,
1182 target_row_count: usize,
1183 ) -> Precision<usize> {
1184 match (
1185 input_stats.num_rows.get_value(),
1186 input_stats.total_byte_size.get_value(),
1187 ) {
1188 (Some(&input_rows), Some(&input_bytes)) if input_rows > 0 => {
1189 let bytes_per_row = input_bytes as f64 / input_rows as f64;
1190 let scaled_bytes =
1191 (bytes_per_row * target_row_count as f64).ceil() as usize;
1192 Precision::Inexact(scaled_bytes)
1193 }
1194 _ => Precision::Absent,
1195 }
1196 }
1197
1198 fn with_new_children_and_same_properties(
1199 &self,
1200 mut children: Vec<Arc<dyn ExecutionPlan>>,
1201 ) -> Self {
1202 Self {
1203 input: children.swap_remove(0),
1204 metrics: ExecutionPlanMetricsSet::new(),
1205 ..Self::clone(self)
1206 }
1207 }
1208}
1209
1210impl DisplayAs for AggregateExec {
1211 fn fmt_as(
1212 &self,
1213 t: DisplayFormatType,
1214 f: &mut std::fmt::Formatter,
1215 ) -> std::fmt::Result {
1216 match t {
1217 DisplayFormatType::Default | DisplayFormatType::Verbose => {
1218 let format_expr_with_alias =
1219 |(e, alias): &(Arc<dyn PhysicalExpr>, String)| -> String {
1220 let e = e.to_string();
1221 if &e != alias {
1222 format!("{e} as {alias}")
1223 } else {
1224 e
1225 }
1226 };
1227
1228 write!(f, "AggregateExec: mode={:?}", self.mode)?;
1229 let g: Vec<String> = if self.group_by.is_single() {
1230 self.group_by
1231 .expr
1232 .iter()
1233 .map(format_expr_with_alias)
1234 .collect()
1235 } else {
1236 self.group_by
1237 .groups
1238 .iter()
1239 .map(|group| {
1240 let terms = group
1241 .iter()
1242 .enumerate()
1243 .map(|(idx, is_null)| {
1244 if *is_null {
1245 format_expr_with_alias(
1246 &self.group_by.null_expr[idx],
1247 )
1248 } else {
1249 format_expr_with_alias(&self.group_by.expr[idx])
1250 }
1251 })
1252 .collect::<Vec<String>>()
1253 .join(", ");
1254 format!("({terms})")
1255 })
1256 .collect()
1257 };
1258
1259 write!(f, ", gby=[{}]", g.join(", "))?;
1260
1261 let a: Vec<String> = self
1262 .aggr_expr
1263 .iter()
1264 .map(|agg| agg.name().to_string())
1265 .collect();
1266 write!(f, ", aggr=[{}]", a.join(", "))?;
1267 if let Some(config) = self.limit_options {
1268 write!(f, ", lim=[{}]", config.limit)?;
1269 }
1270
1271 if self.input_order_mode != InputOrderMode::Linear {
1272 write!(f, ", ordering_mode={:?}", self.input_order_mode)?;
1273 }
1274 }
1275 DisplayFormatType::TreeRender => {
1276 let format_expr_with_alias =
1277 |(e, alias): &(Arc<dyn PhysicalExpr>, String)| -> String {
1278 let expr_sql = fmt_sql(e.as_ref()).to_string();
1279 if &expr_sql != alias {
1280 format!("{expr_sql} as {alias}")
1281 } else {
1282 expr_sql
1283 }
1284 };
1285
1286 let g: Vec<String> = if self.group_by.is_single() {
1287 self.group_by
1288 .expr
1289 .iter()
1290 .map(format_expr_with_alias)
1291 .collect()
1292 } else {
1293 self.group_by
1294 .groups
1295 .iter()
1296 .map(|group| {
1297 let terms = group
1298 .iter()
1299 .enumerate()
1300 .map(|(idx, is_null)| {
1301 if *is_null {
1302 format_expr_with_alias(
1303 &self.group_by.null_expr[idx],
1304 )
1305 } else {
1306 format_expr_with_alias(&self.group_by.expr[idx])
1307 }
1308 })
1309 .collect::<Vec<String>>()
1310 .join(", ");
1311 format!("({terms})")
1312 })
1313 .collect()
1314 };
1315 let a: Vec<String> = self
1316 .aggr_expr
1317 .iter()
1318 .map(|agg| agg.human_display().to_string())
1319 .collect();
1320 writeln!(f, "mode={:?}", self.mode)?;
1321 if !g.is_empty() {
1322 writeln!(f, "group_by={}", g.join(", "))?;
1323 }
1324 if !a.is_empty() {
1325 writeln!(f, "aggr={}", a.join(", "))?;
1326 }
1327 if let Some(config) = self.limit_options {
1328 writeln!(f, "limit={}", config.limit)?;
1329 }
1330 }
1331 }
1332 Ok(())
1333 }
1334}
1335
1336impl ExecutionPlan for AggregateExec {
    /// The display name of this operator.
    fn name(&self) -> &'static str {
        "AggregateExec"
    }

    /// Enables downcasting from `dyn ExecutionPlan` back to `AggregateExec`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// The plan properties precomputed at construction (see
    /// `compute_properties`).
    fn properties(&self) -> &Arc<PlanProperties> {
        &self.cache
    }
1349
    /// The input distribution each mode requires: partial modes accept any
    /// distribution, "partitioned" modes need input hash-partitioned on the
    /// group keys, and the remaining final/single modes need all input in a
    /// single partition.
    fn required_input_distribution(&self) -> Vec<Distribution> {
        match &self.mode {
            AggregateMode::Partial | AggregateMode::PartialReduce => {
                vec![Distribution::UnspecifiedDistribution]
            }
            AggregateMode::FinalPartitioned | AggregateMode::SinglePartitioned => {
                vec![Distribution::HashPartitioned(self.group_by.input_exprs())]
            }
            AggregateMode::Final | AggregateMode::Single => {
                vec![Distribution::SinglePartition]
            }
        }
    }
1363
    /// The (soft) input ordering requirement derived at construction time.
    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
        vec![self.required_input_ordering.clone()]
    }

    /// Input order is maintained unless the aggregation runs in Linear mode.
    fn maintains_input_order(&self) -> Vec<bool> {
        vec![self.input_order_mode != InputOrderMode::Linear]
    }

    /// The single input plan.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }
1384
    /// Rebuilds this node on top of a new child, reusing the existing output
    /// schema and carrying over the limit options and dynamic filter.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // NOTE(review): presumably short-circuits when the new children's
        // properties match the current input's — confirm at the macro's
        // definition.
        check_if_same_properties!(self, children);

        let mut me = AggregateExec::try_new_with_schema(
            self.mode,
            Arc::clone(&self.group_by),
            self.aggr_expr.to_vec(),
            Arc::clone(&self.filter_expr),
            Arc::clone(&children[0]),
            Arc::clone(&self.input_schema),
            Arc::clone(&self.schema),
        )?;
        // `try_new_with_schema` resets these two fields; restore them.
        me.limit_options = self.limit_options;
        me.dynamic_filter = self.dynamic_filter.clone();

        Ok(Arc::new(me))
    }
1405
1406 fn execute(
1407 &self,
1408 partition: usize,
1409 context: Arc<TaskContext>,
1410 ) -> Result<SendableRecordBatchStream> {
1411 self.execute_typed(partition, &context)
1412 .map(|stream| stream.into())
1413 }
1414
    /// A snapshot of this node's execution metrics.
    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    /// Statistics for one partition (or for all, when `partition` is
    /// `None`), derived from the child's statistics via `statistics_inner`.
    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
        let child_statistics = self.input().partition_statistics(partition)?;
        self.statistics_inner(&child_statistics)
    }

    /// Aggregation never emits more rows than it consumes.
    fn cardinality_effect(&self) -> CardinalityEffect {
        CardinalityEffect::LowerEqual
    }
1427
    /// Decides which parent filters may be pushed below this aggregation and,
    /// in the Post phase, offers the node's own dynamic filter for pushdown.
    ///
    /// A parent filter is safe to push down only if it references group-by
    /// columns exclusively. With multiple grouping sets it must additionally
    /// avoid columns that any set null's out, since those output NULLs do not
    /// correspond to input values.
    fn gather_filters_for_pushdown(
        &self,
        phase: FilterPushdownPhase,
        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
        config: &ConfigOptions,
    ) -> Result<FilterDescription> {
        // All columns used by the group-by expressions.
        let grouping_columns: HashSet<_> = self
            .group_by
            .expr()
            .iter()
            .flat_map(|(expr, _)| collect_columns(expr))
            .collect();

        let mut safe_filters = Vec::new();
        let mut unsafe_filters = Vec::new();

        for filter in parent_filters {
            let filter_columns: HashSet<_> =
                collect_columns(&filter).into_iter().collect();

            // A filter touching non-grouping columns cannot be pushed down.
            let references_non_grouping = !grouping_columns.is_empty()
                && !filter_columns.is_subset(&grouping_columns);

            if references_non_grouping {
                unsafe_filters.push(filter);
                continue;
            }

            if self.group_by.groups().len() > 1 {
                // Map each filter column to the index of the group-by
                // expression that uses it.
                let filter_column_indices: Vec<usize> = filter_columns
                    .iter()
                    .filter_map(|filter_col| {
                        self.group_by.expr().iter().position(|(expr, _)| {
                            collect_columns(expr).contains(filter_col)
                        })
                    })
                    .collect();

                // Unsafe if any grouping set null's out a referenced column.
                let has_missing_column = self.group_by.groups().iter().any(|null_mask| {
                    filter_column_indices
                        .iter()
                        .any(|&idx| null_mask.get(idx) == Some(&true))
                });

                if has_missing_column {
                    unsafe_filters.push(filter);
                    continue;
                }
            }

            safe_filters.push(filter);
        }

        let child = self.children()[0];
        let mut child_desc = ChildFilterDescription::from_child(&safe_filters, child)?;

        // Unsafe filters are reported back as unsupported, not dropped.
        child_desc.parent_filters.extend(
            unsafe_filters
                .into_iter()
                .map(PushedDownPredicate::unsupported),
        );

        // In the Post phase, also offer this node's MIN/MAX dynamic filter
        // (when the optimizer option is enabled and the filter exists).
        if phase == FilterPushdownPhase::Post
            && config.optimizer.enable_aggregate_dynamic_filter_pushdown
            && let Some(self_dyn_filter) = &self.dynamic_filter
        {
            let dyn_filter = Arc::clone(&self_dyn_filter.filter);
            child_desc = child_desc.with_self_filter(dyn_filter);
        }

        Ok(FilterDescription::new().with_child(child_desc))
    }
1518
    /// Reacts to the child's pushdown outcome. In the Post phase, if nothing
    /// else holds a reference to our dynamic-filter state, the filter is
    /// useless and this node is replaced by a copy without it.
    fn handle_child_pushdown_result(
        &self,
        phase: FilterPushdownPhase,
        child_pushdown_result: ChildPushdownResult,
        _config: &ConfigOptions,
    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
        let mut result = FilterPushdownPropagation::if_any(child_pushdown_result.clone());

        if phase == FilterPushdownPhase::Post
            && let Some(dyn_filter) = &self.dynamic_filter
        {
            // NOTE(review): this counts clones of the outer `AggrDynFilter`
            // handle, not of the inner `filter` expression that
            // `gather_filters_for_pushdown` hands to children — confirm this
            // is the intended acceptance-detection mechanism.
            let child_accepts_dyn_filter = Arc::strong_count(dyn_filter) > 1;

            if !child_accepts_dyn_filter {
                // Nobody consumes the filter: drop it from a fresh copy of
                // this node and report the updated node upward.
                let mut new_node = self.clone();
                new_node.dynamic_filter = None;

                result = result
                    .with_updated_node(Arc::new(new_node) as Arc<dyn ExecutionPlan>);
            }
        }

        Ok(result)
    }
1569}
1570
1571fn create_schema(
1572 input_schema: &Schema,
1573 group_by: &PhysicalGroupBy,
1574 aggr_expr: &[Arc<AggregateFunctionExpr>],
1575 mode: AggregateMode,
1576) -> Result<Schema> {
1577 let mut fields = Vec::with_capacity(group_by.num_output_exprs() + aggr_expr.len());
1578 fields.extend(group_by.output_fields(input_schema)?);
1579
1580 match mode.output_mode() {
1581 AggregateOutputMode::Final => {
1582 for expr in aggr_expr {
1584 fields.push(expr.field())
1585 }
1586 }
1587 AggregateOutputMode::Partial => {
1588 for expr in aggr_expr {
1590 fields.extend(expr.state_fields()?.iter().cloned());
1591 }
1592 }
1593 }
1594
1595 Ok(Schema::new_with_metadata(
1596 fields,
1597 input_schema.metadata().clone(),
1598 ))
1599}
1600
1601fn get_aggregate_expr_req(
1622 aggr_expr: &AggregateFunctionExpr,
1623 group_by: &PhysicalGroupBy,
1624 agg_mode: &AggregateMode,
1625 include_soft_requirement: bool,
1626) -> Option<LexOrdering> {
1627 if agg_mode.input_mode() == AggregateInputMode::Partial {
1631 return None;
1632 }
1633
1634 match aggr_expr.order_sensitivity() {
1635 AggregateOrderSensitivity::Insensitive => return None,
1636 AggregateOrderSensitivity::HardRequirement => {}
1637 AggregateOrderSensitivity::SoftRequirement => {
1638 if !include_soft_requirement {
1639 return None;
1640 }
1641 }
1642 AggregateOrderSensitivity::Beneficial => return None,
1643 }
1644
1645 let mut sort_exprs = aggr_expr.order_bys().to_vec();
1646 if group_by.is_single() {
1652 let physical_exprs = group_by.input_exprs();
1656 sort_exprs.retain(|sort_expr| {
1657 !physical_exprs_contains(&physical_exprs, &sort_expr.expr)
1658 });
1659 }
1660 LexOrdering::new(sort_exprs)
1661}
1662
/// Returns a new `Vec` containing the elements of `lhs` followed by the
/// elements of `rhs`, cloning each element.
pub fn concat_slices<T: Clone>(lhs: &[T], rhs: &[T]) -> Vec<T> {
    // Reserve the exact final size up front to avoid re-allocation.
    let mut combined = Vec::with_capacity(lhs.len() + rhs.len());
    combined.extend_from_slice(lhs);
    combined.extend_from_slice(rhs);
    combined
}
1667
1668fn determine_finer(
1672 current: &Option<LexOrdering>,
1673 candidate: &LexOrdering,
1674) -> Option<bool> {
1675 if let Some(ordering) = current {
1676 candidate.partial_cmp(ordering).map(|cmp| cmp.is_gt())
1677 } else {
1678 Some(true)
1679 }
1680}
1681
/// Computes the finest input ordering requirement that satisfies all
/// order-sensitive aggregate expressions, possibly swapping individual
/// aggregates for their reverse expression (as reported by `reverse_expr`)
/// when the reversed ordering fits the already-chosen requirement better.
/// Returns the chosen requirement converted to sort requirements (empty when
/// no ordering is needed). Hard requirements are processed before soft ones.
pub fn get_finer_aggregate_exprs_requirement(
    aggr_exprs: &mut [Arc<AggregateFunctionExpr>],
    group_by: &PhysicalGroupBy,
    eq_properties: &EquivalenceProperties,
    agg_mode: &AggregateMode,
) -> Result<Vec<PhysicalSortRequirement>> {
    let mut requirement = None;

    // First pass considers only hard requirements; the second pass folds in
    // soft ones, which are allowed to lose conflicts silently.
    for include_soft_requirement in [false, true] {
        for aggr_expr in aggr_exprs.iter_mut() {
            // Aggregates without a (normalizable) ordering need are skipped.
            let Some(aggr_req) = get_aggregate_expr_req(
                aggr_expr,
                group_by,
                agg_mode,
                include_soft_requirement,
            )
            .and_then(|o| eq_properties.normalize_sort_exprs(o)) else {
                continue;
            };
            let forward_finer = determine_finer(&requirement, &aggr_req);
            if let Some(finer) = forward_finer {
                if !finer {
                    // Current requirement already covers this aggregate.
                    continue;
                } else if eq_properties.ordering_satisfy(aggr_req.clone())? {
                    // Forward ordering is finer and achievable — adopt it.
                    requirement = Some(aggr_req);
                    continue;
                }
            }
            // Forward ordering is either incomparable or not satisfied by the
            // input; try the aggregate's reversed form instead.
            if let Some(reverse_aggr_expr) = aggr_expr.reverse_expr() {
                let Some(rev_aggr_req) = get_aggregate_expr_req(
                    &reverse_aggr_expr,
                    group_by,
                    agg_mode,
                    include_soft_requirement,
                )
                .and_then(|o| eq_properties.normalize_sort_exprs(o)) else {
                    // Reversed form has no ordering need at all — prefer it.
                    *aggr_expr = Arc::new(reverse_aggr_expr);
                    continue;
                };
                if let Some(finer) = determine_finer(&requirement, &rev_aggr_req) {
                    if !finer {
                        // Reversed need is subsumed by the current requirement.
                        *aggr_expr = Arc::new(reverse_aggr_expr);
                    } else if eq_properties.ordering_satisfy(rev_aggr_req.clone())? {
                        // Reversed ordering is finer and achievable.
                        *aggr_expr = Arc::new(reverse_aggr_expr);
                        requirement = Some(rev_aggr_req);
                    } else {
                        requirement = Some(aggr_req);
                    }
                } else if forward_finer.is_some() {
                    requirement = Some(aggr_req);
                } else {
                    // Neither direction is comparable with the requirement so
                    // far: a hard conflict is unsupported, a soft one ignored.
                    if !include_soft_requirement {
                        return not_impl_err!(
                            "Conflicting ordering requirements in aggregate functions is not supported"
                        );
                    }
                }
            }
        }
    }

    Ok(requirement.map_or_else(Vec::new, |o| o.into_iter().map(Into::into).collect()))
}
1785
1786pub fn aggregate_expressions(
1792 aggr_expr: &[Arc<AggregateFunctionExpr>],
1793 mode: &AggregateMode,
1794 col_idx_base: usize,
1795) -> Result<Vec<Vec<Arc<dyn PhysicalExpr>>>> {
1796 match mode.input_mode() {
1797 AggregateInputMode::Raw => Ok(aggr_expr
1798 .iter()
1799 .map(|agg| {
1800 let mut result = agg.expressions();
1801 result.extend(agg.order_bys().iter().map(|item| Arc::clone(&item.expr)));
1805 result
1806 })
1807 .collect()),
1808 AggregateInputMode::Partial => {
1809 let mut col_idx_base = col_idx_base;
1811 aggr_expr
1812 .iter()
1813 .map(|agg| {
1814 let exprs = merge_expressions(col_idx_base, agg)?;
1815 col_idx_base += exprs.len();
1816 Ok(exprs)
1817 })
1818 .collect()
1819 }
1820 }
1821}
1822
1823fn merge_expressions(
1828 index_base: usize,
1829 expr: &AggregateFunctionExpr,
1830) -> Result<Vec<Arc<dyn PhysicalExpr>>> {
1831 expr.state_fields().map(|fields| {
1832 fields
1833 .iter()
1834 .enumerate()
1835 .map(|(idx, f)| Arc::new(Column::new(f.name(), index_base + idx)) as _)
1836 .collect()
1837 })
1838}
1839
1840pub type AccumulatorItem = Box<dyn Accumulator>;
1841
1842pub fn create_accumulators(
1843 aggr_expr: &[Arc<AggregateFunctionExpr>],
1844) -> Result<Vec<AccumulatorItem>> {
1845 aggr_expr
1846 .iter()
1847 .map(|expr| expr.create_accumulator())
1848 .collect()
1849}
1850
1851pub fn finalize_aggregation(
1854 accumulators: &mut [AccumulatorItem],
1855 mode: &AggregateMode,
1856) -> Result<Vec<ArrayRef>> {
1857 match mode.output_mode() {
1858 AggregateOutputMode::Final => {
1859 accumulators
1861 .iter_mut()
1862 .map(|accumulator| accumulator.evaluate().and_then(|v| v.to_array()))
1863 .collect()
1864 }
1865 AggregateOutputMode::Partial => {
1866 accumulators
1868 .iter_mut()
1869 .map(|accumulator| {
1870 accumulator.state().and_then(|e| {
1871 e.iter()
1872 .map(|v| v.to_array())
1873 .collect::<Result<Vec<ArrayRef>>>()
1874 })
1875 })
1876 .flatten_ok()
1877 .collect()
1878 }
1879 }
1880}
1881
1882pub fn evaluate_many(
1884 expr: &[Vec<Arc<dyn PhysicalExpr>>],
1885 batch: &RecordBatch,
1886) -> Result<Vec<Vec<ArrayRef>>> {
1887 expr.iter()
1888 .map(|expr| evaluate_expressions_to_arrays(expr, batch))
1889 .collect()
1890}
1891
1892fn evaluate_optional(
1893 expr: &[Option<Arc<dyn PhysicalExpr>>],
1894 batch: &RecordBatch,
1895) -> Result<Vec<Option<ArrayRef>>> {
1896 expr.iter()
1897 .map(|expr| {
1898 expr.as_ref()
1899 .map(|expr| {
1900 expr.evaluate(batch)
1901 .and_then(|v| v.into_array(batch.num_rows()))
1902 })
1903 .transpose()
1904 })
1905 .collect()
1906}
1907
1908fn group_id_array(group: &[bool], batch: &RecordBatch) -> Result<ArrayRef> {
1909 if group.len() > 64 {
1910 return not_impl_err!(
1911 "Grouping sets with more than 64 columns are not supported"
1912 );
1913 }
1914 let group_id = group.iter().fold(0u64, |acc, &is_null| {
1915 (acc << 1) | if is_null { 1 } else { 0 }
1916 });
1917 let num_rows = batch.num_rows();
1918 if group.len() <= 8 {
1919 Ok(Arc::new(UInt8Array::from(vec![group_id as u8; num_rows])))
1920 } else if group.len() <= 16 {
1921 Ok(Arc::new(UInt16Array::from(vec![group_id as u16; num_rows])))
1922 } else if group.len() <= 32 {
1923 Ok(Arc::new(UInt32Array::from(vec![group_id as u32; num_rows])))
1924 } else {
1925 Ok(Arc::new(UInt64Array::from(vec![group_id; num_rows])))
1926 }
1927}
1928
1929pub fn evaluate_group_by(
1940 group_by: &PhysicalGroupBy,
1941 batch: &RecordBatch,
1942) -> Result<Vec<Vec<ArrayRef>>> {
1943 let exprs = evaluate_expressions_to_arrays(
1944 group_by.expr.iter().map(|(expr, _)| expr),
1945 batch,
1946 )?;
1947 let null_exprs = evaluate_expressions_to_arrays(
1948 group_by.null_expr.iter().map(|(expr, _)| expr),
1949 batch,
1950 )?;
1951
1952 group_by
1953 .groups
1954 .iter()
1955 .map(|group| {
1956 let mut group_values = Vec::with_capacity(group_by.num_group_exprs());
1957 group_values.extend(group.iter().enumerate().map(|(idx, is_null)| {
1958 if *is_null {
1959 Arc::clone(&null_exprs[idx])
1960 } else {
1961 Arc::clone(&exprs[idx])
1962 }
1963 }));
1964 if !group_by.is_single() {
1965 group_values.push(group_id_array(group, batch)?);
1966 }
1967 Ok(group_values)
1968 })
1969 .collect()
1970}
1971
1972#[cfg(test)]
1973mod tests {
1974 use std::task::{Context, Poll};
1975
1976 use super::*;
1977 use crate::RecordBatchStream;
1978 use crate::coalesce_partitions::CoalescePartitionsExec;
1979 use crate::common;
1980 use crate::common::collect;
1981 use crate::execution_plan::Boundedness;
1982 use crate::expressions::col;
1983 use crate::metrics::MetricValue;
1984 use crate::test::TestMemoryExec;
1985 use crate::test::assert_is_pending;
1986 use crate::test::exec::{BlockingExec, assert_strong_count_converges_to_zero};
1987
1988 use arrow::array::{
1989 DictionaryArray, Float32Array, Float64Array, Int32Array, Int64Array, StructArray,
1990 UInt32Array, UInt64Array,
1991 };
1992 use arrow::compute::{SortOptions, concat_batches};
1993 use arrow::datatypes::{DataType, Int32Type};
1994 use datafusion_common::test_util::{batches_to_sort_string, batches_to_string};
1995 use datafusion_common::{DataFusionError, ScalarValue, internal_err};
1996 use datafusion_execution::config::SessionConfig;
1997 use datafusion_execution::memory_pool::FairSpillPool;
1998 use datafusion_execution::runtime_env::RuntimeEnvBuilder;
1999 use datafusion_functions_aggregate::array_agg::array_agg_udaf;
2000 use datafusion_functions_aggregate::average::avg_udaf;
2001 use datafusion_functions_aggregate::count::count_udaf;
2002 use datafusion_functions_aggregate::first_last::{first_value_udaf, last_value_udaf};
2003 use datafusion_functions_aggregate::median::median_udaf;
2004 use datafusion_functions_aggregate::sum::sum_udaf;
2005 use datafusion_physical_expr::Partitioning;
2006 use datafusion_physical_expr::PhysicalSortExpr;
2007 use datafusion_physical_expr::aggregate::AggregateExprBuilder;
2008 use datafusion_physical_expr::expressions::Literal;
2009 use datafusion_physical_expr::expressions::lit;
2010
2011 use crate::projection::ProjectionExec;
2012 use datafusion_physical_expr::projection::ProjectionExpr;
2013 use futures::{FutureExt, Stream};
2014 use insta::{allow_duplicates, assert_snapshot};
2015
2016 fn create_test_schema() -> Result<SchemaRef> {
2018 let a = Field::new("a", DataType::Int32, true);
2019 let b = Field::new("b", DataType::Int32, true);
2020 let c = Field::new("c", DataType::Int32, true);
2021 let d = Field::new("d", DataType::Int32, true);
2022 let e = Field::new("e", DataType::Int32, true);
2023 let schema = Arc::new(Schema::new(vec![a, b, c, d, e]));
2024
2025 Ok(schema)
2026 }
2027
2028 fn some_data() -> (Arc<Schema>, Vec<RecordBatch>) {
2030 let schema = Arc::new(Schema::new(vec![
2032 Field::new("a", DataType::UInt32, false),
2033 Field::new("b", DataType::Float64, false),
2034 ]));
2035
2036 (
2038 Arc::clone(&schema),
2039 vec![
2040 RecordBatch::try_new(
2041 Arc::clone(&schema),
2042 vec![
2043 Arc::new(UInt32Array::from(vec![2, 3, 4, 4])),
2044 Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
2045 ],
2046 )
2047 .unwrap(),
2048 RecordBatch::try_new(
2049 schema,
2050 vec![
2051 Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
2052 Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
2053 ],
2054 )
2055 .unwrap(),
2056 ],
2057 )
2058 }
2059
2060 fn some_data_v2() -> (Arc<Schema>, Vec<RecordBatch>) {
2062 let schema = Arc::new(Schema::new(vec![
2064 Field::new("a", DataType::UInt32, false),
2065 Field::new("b", DataType::Float64, false),
2066 ]));
2067
2068 (
2073 Arc::clone(&schema),
2074 vec![
2075 RecordBatch::try_new(
2076 Arc::clone(&schema),
2077 vec![
2078 Arc::new(UInt32Array::from(vec![2, 3, 4, 4])),
2079 Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
2080 ],
2081 )
2082 .unwrap(),
2083 RecordBatch::try_new(
2084 Arc::clone(&schema),
2085 vec![
2086 Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
2087 Arc::new(Float64Array::from(vec![0.0, 1.0, 2.0, 3.0])),
2088 ],
2089 )
2090 .unwrap(),
2091 RecordBatch::try_new(
2092 Arc::clone(&schema),
2093 vec![
2094 Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
2095 Arc::new(Float64Array::from(vec![3.0, 4.0, 5.0, 6.0])),
2096 ],
2097 )
2098 .unwrap(),
2099 RecordBatch::try_new(
2100 schema,
2101 vec![
2102 Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
2103 Arc::new(Float64Array::from(vec![2.0, 3.0, 4.0, 5.0])),
2104 ],
2105 )
2106 .unwrap(),
2107 ],
2108 )
2109 }
2110
2111 fn new_spill_ctx(batch_size: usize, max_memory: usize) -> Arc<TaskContext> {
2112 let session_config = SessionConfig::new().with_batch_size(batch_size);
2113 let runtime = RuntimeEnvBuilder::new()
2114 .with_memory_pool(Arc::new(FairSpillPool::new(max_memory)))
2115 .build_arc()
2116 .unwrap();
2117 let task_ctx = TaskContext::default()
2118 .with_session_config(session_config)
2119 .with_runtime(runtime);
2120 Arc::new(task_ctx)
2121 }
2122
    /// Runs COUNT(1) over GROUPING SETS on `input` as a Partial stage followed
    /// by a Final stage over the merged partitions, verifying both the
    /// intermediate and final outputs plus the output-row metric.
    ///
    /// With `spill = true` the stages get small memory pools, forcing spills
    /// (spilled runs re-emit groups, which shows up as duplicated partial
    /// rows in the first snapshot).
    async fn check_grouping_sets(
        input: Arc<dyn ExecutionPlan>,
        spill: bool,
    ) -> Result<()> {
        let input_schema = input.schema();

        // Each inner bool vector is a null-mask: `true` replaces that group
        // column with its NULL literal for the grouping set.
        let grouping_set = PhysicalGroupBy::new(
            vec![
                (col("a", &input_schema)?, "a".to_string()),
                (col("b", &input_schema)?, "b".to_string()),
            ],
            vec![
                (lit(ScalarValue::UInt32(None)), "a".to_string()),
                (lit(ScalarValue::Float64(None)), "b".to_string()),
            ],
            vec![
                vec![false, true], vec![true, false], vec![false, false], ],
            true,
        );

        let aggregates = vec![Arc::new(
            AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)])
                .schema(Arc::clone(&input_schema))
                .alias("COUNT(1)")
                .build()?,
        )];

        let task_ctx = if spill {
            new_spill_ctx(4, 500)
        } else {
            Arc::new(TaskContext::default())
        };

        let partial_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            grouping_set.clone(),
            aggregates.clone(),
            vec![None],
            input,
            Arc::clone(&input_schema),
        )?);

        let result =
            collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;

        if spill {
            allow_duplicates! {
            assert_snapshot!(batches_to_sort_string(&result),
            @r"
            +---+-----+---------------+-----------------+
            | a | b   | __grouping_id | COUNT(1)[count] |
            +---+-----+---------------+-----------------+
            |   | 1.0 | 2             | 1               |
            |   | 1.0 | 2             | 1               |
            |   | 2.0 | 2             | 1               |
            |   | 2.0 | 2             | 1               |
            |   | 3.0 | 2             | 1               |
            |   | 3.0 | 2             | 1               |
            |   | 4.0 | 2             | 1               |
            |   | 4.0 | 2             | 1               |
            | 2 |     | 1             | 1               |
            | 2 |     | 1             | 1               |
            | 2 | 1.0 | 0             | 1               |
            | 2 | 1.0 | 0             | 1               |
            | 3 |     | 1             | 1               |
            | 3 |     | 1             | 2               |
            | 3 | 2.0 | 0             | 2               |
            | 3 | 3.0 | 0             | 1               |
            | 4 |     | 1             | 1               |
            | 4 |     | 1             | 2               |
            | 4 | 3.0 | 0             | 1               |
            | 4 | 4.0 | 0             | 2               |
            +---+-----+---------------+-----------------+
            "
            );
            }
        } else {
            allow_duplicates! {
            assert_snapshot!(batches_to_sort_string(&result),
            @r"
            +---+-----+---------------+-----------------+
            | a | b   | __grouping_id | COUNT(1)[count] |
            +---+-----+---------------+-----------------+
            |   | 1.0 | 2             | 2               |
            |   | 2.0 | 2             | 2               |
            |   | 3.0 | 2             | 2               |
            |   | 4.0 | 2             | 2               |
            | 2 |     | 1             | 2               |
            | 2 | 1.0 | 0             | 2               |
            | 3 |     | 1             | 3               |
            | 3 | 2.0 | 0             | 2               |
            | 3 | 3.0 | 0             | 1               |
            | 4 |     | 1             | 3               |
            | 4 | 3.0 | 0             | 1               |
            | 4 | 4.0 | 0             | 2               |
            +---+-----+---------------+-----------------+
            "
            );
            }
        };

        let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));

        let final_grouping_set = grouping_set.as_final();

        // Larger memory budget for the final stage when spilling.
        let task_ctx = if spill {
            new_spill_ctx(4, 3160)
        } else {
            task_ctx
        };

        let merged_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Final,
            final_grouping_set,
            aggregates,
            vec![None],
            merge,
            input_schema,
        )?);

        let result = collect(merged_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
        let batch = concat_batches(&result[0].schema(), &result)?;
        assert_eq!(batch.num_columns(), 4);
        assert_eq!(batch.num_rows(), 12);

        allow_duplicates! {
            assert_snapshot!(
                batches_to_sort_string(&result),
                @r"
                +---+-----+---------------+----------+
                | a | b   | __grouping_id | COUNT(1) |
                +---+-----+---------------+----------+
                |   | 1.0 | 2             | 2        |
                |   | 2.0 | 2             | 2        |
                |   | 3.0 | 2             | 2        |
                |   | 4.0 | 2             | 2        |
                | 2 |     | 1             | 2        |
                | 2 | 1.0 | 0             | 2        |
                | 3 |     | 1             | 3        |
                | 3 | 2.0 | 0             | 2        |
                | 3 | 3.0 | 0             | 1        |
                | 4 |     | 1             | 3        |
                | 4 | 3.0 | 0             | 1        |
                | 4 | 4.0 | 0             | 2        |
                +---+-----+---------------+----------+
                "
            );
        }

        let metrics = merged_aggregate.metrics().unwrap();
        let output_rows = metrics.output_rows().unwrap();
        assert_eq!(12, output_rows);

        Ok(())
    }
2284
    /// Runs AVG(b) GROUP BY a on `input` as a Partial stage followed by a
    /// Final stage over the merged partitions, checking the intermediate
    /// state columns, the final averages, pre-execution statistics, and the
    /// spill metrics.
    async fn check_aggregates(input: Arc<dyn ExecutionPlan>, spill: bool) -> Result<()> {
        let input_schema = input.schema();

        // Plain GROUP BY a: a single grouping set with nothing null-masked.
        let grouping_set = PhysicalGroupBy::new(
            vec![(col("a", &input_schema)?, "a".to_string())],
            vec![],
            vec![vec![false]],
            false,
        );

        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
            AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?])
                .schema(Arc::clone(&input_schema))
                .alias("AVG(b)")
                .build()?,
        )];

        let task_ctx = if spill {
            new_spill_ctx(2, 1600)
        } else {
            Arc::new(TaskContext::default())
        };

        let partial_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            grouping_set.clone(),
            aggregates.clone(),
            vec![None],
            input,
            Arc::clone(&input_schema),
        )?);

        let result =
            collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;

        // Partial output carries AVG's state columns (count and sum). Spilled
        // runs re-emit groups, so groups appear split across multiple rows.
        if spill {
            allow_duplicates! {
                assert_snapshot!(batches_to_sort_string(&result), @r"
                +---+---------------+-------------+
                | a | AVG(b)[count] | AVG(b)[sum] |
                +---+---------------+-------------+
                | 2 | 1             | 1.0         |
                | 2 | 1             | 1.0         |
                | 3 | 1             | 2.0         |
                | 3 | 2             | 5.0         |
                | 4 | 3             | 11.0        |
                +---+---------------+-------------+
                ");
            }
        } else {
            allow_duplicates! {
                assert_snapshot!(batches_to_sort_string(&result), @r"
                +---+---------------+-------------+
                | a | AVG(b)[count] | AVG(b)[sum] |
                +---+---------------+-------------+
                | 2 | 2             | 2.0         |
                | 3 | 3             | 7.0         |
                | 4 | 3             | 11.0        |
                +---+---------------+-------------+
                ");
            }
        };

        let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));

        let final_grouping_set = grouping_set.as_final();

        let merged_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Final,
            final_grouping_set,
            aggregates,
            vec![None],
            merge,
            input_schema,
        )?);

        // Statistics should be available without executing the plan.
        let final_stats = merged_aggregate.partition_statistics(None)?;
        assert!(final_stats.total_byte_size.get_value().is_some());

        let task_ctx = if spill {
            new_spill_ctx(2, 2600)
        } else {
            Arc::clone(&task_ctx)
        };
        let result = collect(merged_aggregate.execute(0, task_ctx)?).await?;
        let batch = concat_batches(&result[0].schema(), &result)?;
        assert_eq!(batch.num_columns(), 2);
        assert_eq!(batch.num_rows(), 3);

        allow_duplicates! {
            assert_snapshot!(batches_to_sort_string(&result), @r"
            +---+--------------------+
            | a | AVG(b)             |
            +---+--------------------+
            | 2 | 1.0                |
            | 3 | 2.3333333333333335 |
            | 4 | 3.6666666666666665 |
            +---+--------------------+
            ");
        }

        let metrics = merged_aggregate.metrics().unwrap();
        let output_rows = metrics.output_rows().unwrap();
        let spill_count = metrics.spill_count().unwrap();
        let spilled_bytes = metrics.spilled_bytes().unwrap();
        let spilled_rows = metrics.spilled_rows().unwrap();

        if spill {
            // More output rows when spilling — presumably because spilled runs
            // produce extra intermediate group rows (see the partial snapshot).
            assert_eq!(8, output_rows);

            assert!(spill_count > 0);
            assert!(spilled_bytes > 0);
            assert!(spilled_rows > 0);
        } else {
            assert_eq!(3, output_rows);

            assert_eq!(0, spill_count);
            assert_eq!(0, spilled_bytes);
            assert_eq!(0, spilled_rows);
        }

        Ok(())
    }
2416
    /// Test source plan whose stream can optionally return `Poll::Pending`
    /// once before yielding the two `some_data()` batches.
    #[derive(Debug)]
    struct TestYieldingExec {
        // When true, the stream starts in the `New` state and yields
        // (returns Pending) on its first poll.
        pub yield_first: bool,
        // Cached plan properties (schema, partitioning, emission, boundedness).
        cache: Arc<PlanProperties>,
    }
2425
2426 impl TestYieldingExec {
2427 fn new(yield_first: bool) -> Self {
2428 let schema = some_data().0;
2429 let cache = Self::compute_properties(schema);
2430 Self {
2431 yield_first,
2432 cache: Arc::new(cache),
2433 }
2434 }
2435
2436 fn compute_properties(schema: SchemaRef) -> PlanProperties {
2438 PlanProperties::new(
2439 EquivalenceProperties::new(schema),
2440 Partitioning::UnknownPartitioning(1),
2441 EmissionType::Incremental,
2442 Boundedness::Bounded,
2443 )
2444 }
2445 }
2446
2447 impl DisplayAs for TestYieldingExec {
2448 fn fmt_as(
2449 &self,
2450 t: DisplayFormatType,
2451 f: &mut std::fmt::Formatter,
2452 ) -> std::fmt::Result {
2453 match t {
2454 DisplayFormatType::Default | DisplayFormatType::Verbose => {
2455 write!(f, "TestYieldingExec")
2456 }
2457 DisplayFormatType::TreeRender => {
2458 write!(f, "")
2460 }
2461 }
2462 }
2463 }
2464
2465 impl ExecutionPlan for TestYieldingExec {
2466 fn name(&self) -> &'static str {
2467 "TestYieldingExec"
2468 }
2469
2470 fn as_any(&self) -> &dyn Any {
2471 self
2472 }
2473
2474 fn properties(&self) -> &Arc<PlanProperties> {
2475 &self.cache
2476 }
2477
2478 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
2479 vec![]
2480 }
2481
2482 fn with_new_children(
2483 self: Arc<Self>,
2484 _: Vec<Arc<dyn ExecutionPlan>>,
2485 ) -> Result<Arc<dyn ExecutionPlan>> {
2486 internal_err!("Children cannot be replaced in {self:?}")
2487 }
2488
2489 fn execute(
2490 &self,
2491 _partition: usize,
2492 _context: Arc<TaskContext>,
2493 ) -> Result<SendableRecordBatchStream> {
2494 let stream = if self.yield_first {
2495 TestYieldingStream::New
2496 } else {
2497 TestYieldingStream::Yielded
2498 };
2499
2500 Ok(Box::pin(stream))
2501 }
2502
2503 fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
2504 if partition.is_some() {
2505 return Ok(Statistics::new_unknown(self.schema().as_ref()));
2506 }
2507 let (_, batches) = some_data();
2508 Ok(common::compute_record_batch_statistics(
2509 &[batches],
2510 &self.schema(),
2511 None,
2512 ))
2513 }
2514 }
2515
    /// State machine for the test stream: starting in `New` yields Pending
    /// once; then the two `some_data()` batches are returned in order, and
    /// `ReturnedBatch2` signals end-of-stream.
    enum TestYieldingStream {
        New,
        Yielded,
        ReturnedBatch1,
        ReturnedBatch2,
    }
2523
2524 impl Stream for TestYieldingStream {
2525 type Item = Result<RecordBatch>;
2526
2527 fn poll_next(
2528 mut self: std::pin::Pin<&mut Self>,
2529 cx: &mut Context<'_>,
2530 ) -> Poll<Option<Self::Item>> {
2531 match &*self {
2532 TestYieldingStream::New => {
2533 *(self.as_mut()) = TestYieldingStream::Yielded;
2534 cx.waker().wake_by_ref();
2535 Poll::Pending
2536 }
2537 TestYieldingStream::Yielded => {
2538 *(self.as_mut()) = TestYieldingStream::ReturnedBatch1;
2539 Poll::Ready(Some(Ok(some_data().1[0].clone())))
2540 }
2541 TestYieldingStream::ReturnedBatch1 => {
2542 *(self.as_mut()) = TestYieldingStream::ReturnedBatch2;
2543 Poll::Ready(Some(Ok(some_data().1[1].clone())))
2544 }
2545 TestYieldingStream::ReturnedBatch2 => Poll::Ready(None),
2546 }
2547 }
2548 }
2549
2550 impl RecordBatchStream for TestYieldingStream {
2551 fn schema(&self) -> SchemaRef {
2552 some_data().0
2553 }
2554 }
2555
2556 #[tokio::test]
2559 async fn aggregate_source_not_yielding() -> Result<()> {
2560 let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
2561
2562 check_aggregates(input, false).await
2563 }
2564
2565 #[tokio::test]
2566 async fn aggregate_grouping_sets_source_not_yielding() -> Result<()> {
2567 let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
2568
2569 check_grouping_sets(input, false).await
2570 }
2571
2572 #[tokio::test]
2573 async fn aggregate_source_with_yielding() -> Result<()> {
2574 let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
2575
2576 check_aggregates(input, false).await
2577 }
2578
2579 #[tokio::test]
2580 async fn aggregate_grouping_sets_with_yielding() -> Result<()> {
2581 let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
2582
2583 check_grouping_sets(input, false).await
2584 }
2585
2586 #[tokio::test]
2587 async fn aggregate_source_not_yielding_with_spill() -> Result<()> {
2588 let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
2589
2590 check_aggregates(input, true).await
2591 }
2592
2593 #[tokio::test]
2594 async fn aggregate_grouping_sets_source_not_yielding_with_spill() -> Result<()> {
2595 let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
2596
2597 check_grouping_sets(input, true).await
2598 }
2599
2600 #[tokio::test]
2601 async fn aggregate_source_with_yielding_with_spill() -> Result<()> {
2602 let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
2603
2604 check_aggregates(input, true).await
2605 }
2606
2607 #[tokio::test]
2608 async fn aggregate_grouping_sets_with_yielding_with_spill() -> Result<()> {
2609 let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
2610
2611 check_grouping_sets(input, true).await
2612 }
2613
2614 fn test_median_agg_expr(schema: SchemaRef) -> Result<AggregateFunctionExpr> {
2616 AggregateExprBuilder::new(median_udaf(), vec![col("a", &schema)?])
2617 .schema(schema)
2618 .alias("MEDIAN(a)")
2619 .build()
2620 }
2621
2622 #[tokio::test]
2623 async fn test_oom() -> Result<()> {
2624 let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
2625 let input_schema = input.schema();
2626
2627 let runtime = RuntimeEnvBuilder::new()
2628 .with_memory_limit(1, 1.0)
2629 .build_arc()?;
2630 let task_ctx = TaskContext::default().with_runtime(runtime);
2631 let task_ctx = Arc::new(task_ctx);
2632
2633 let groups_none = PhysicalGroupBy::default();
2634 let groups_some = PhysicalGroupBy::new(
2635 vec![(col("a", &input_schema)?, "a".to_string())],
2636 vec![],
2637 vec![vec![false]],
2638 false,
2639 );
2640
2641 let aggregates_v0: Vec<Arc<AggregateFunctionExpr>> =
2643 vec![Arc::new(test_median_agg_expr(Arc::clone(&input_schema))?)];
2644
2645 let aggregates_v2: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
2647 AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?])
2648 .schema(Arc::clone(&input_schema))
2649 .alias("AVG(b)")
2650 .build()?,
2651 )];
2652
2653 for (version, groups, aggregates) in [
2654 (0, groups_none, aggregates_v0),
2655 (2, groups_some, aggregates_v2),
2656 ] {
2657 let n_aggr = aggregates.len();
2658 let partial_aggregate = Arc::new(AggregateExec::try_new(
2659 AggregateMode::Single,
2660 groups,
2661 aggregates,
2662 vec![None; n_aggr],
2663 Arc::clone(&input),
2664 Arc::clone(&input_schema),
2665 )?);
2666
2667 let stream = partial_aggregate.execute_typed(0, &task_ctx)?;
2668
2669 match version {
2671 0 => {
2672 assert!(matches!(stream, StreamType::AggregateStream(_)));
2673 }
2674 1 => {
2675 assert!(matches!(stream, StreamType::GroupedHash(_)));
2676 }
2677 2 => {
2678 assert!(matches!(stream, StreamType::GroupedHash(_)));
2679 }
2680 _ => panic!("Unknown version: {version}"),
2681 }
2682
2683 let stream: SendableRecordBatchStream = stream.into();
2684 let err = collect(stream).await.unwrap_err();
2685
2686 let err = err.find_root();
2688 assert!(
2689 matches!(err, DataFusionError::ResourcesExhausted(_)),
2690 "Wrong error type: {err}",
2691 );
2692 }
2693
2694 Ok(())
2695 }
2696
2697 #[tokio::test]
2698 async fn test_drop_cancel_without_groups() -> Result<()> {
2699 let task_ctx = Arc::new(TaskContext::default());
2700 let schema =
2701 Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, true)]));
2702
2703 let groups = PhysicalGroupBy::default();
2704
2705 let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
2706 AggregateExprBuilder::new(avg_udaf(), vec![col("a", &schema)?])
2707 .schema(Arc::clone(&schema))
2708 .alias("AVG(a)")
2709 .build()?,
2710 )];
2711
2712 let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
2713 let refs = blocking_exec.refs();
2714 let aggregate_exec = Arc::new(AggregateExec::try_new(
2715 AggregateMode::Partial,
2716 groups.clone(),
2717 aggregates.clone(),
2718 vec![None],
2719 blocking_exec,
2720 schema,
2721 )?);
2722
2723 let fut = crate::collect(aggregate_exec, task_ctx);
2724 let mut fut = fut.boxed();
2725
2726 assert_is_pending(&mut fut);
2727 drop(fut);
2728 assert_strong_count_converges_to_zero(refs).await;
2729
2730 Ok(())
2731 }
2732
2733 #[tokio::test]
2734 async fn test_drop_cancel_with_groups() -> Result<()> {
2735 let task_ctx = Arc::new(TaskContext::default());
2736 let schema = Arc::new(Schema::new(vec![
2737 Field::new("a", DataType::Float64, true),
2738 Field::new("b", DataType::Float64, true),
2739 ]));
2740
2741 let groups =
2742 PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);
2743
2744 let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
2745 AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
2746 .schema(Arc::clone(&schema))
2747 .alias("AVG(b)")
2748 .build()?,
2749 )];
2750
2751 let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
2752 let refs = blocking_exec.refs();
2753 let aggregate_exec = Arc::new(AggregateExec::try_new(
2754 AggregateMode::Partial,
2755 groups,
2756 aggregates.clone(),
2757 vec![None],
2758 blocking_exec,
2759 schema,
2760 )?);
2761
2762 let fut = crate::collect(aggregate_exec, task_ctx);
2763 let mut fut = fut.boxed();
2764
2765 assert_is_pending(&mut fut);
2766 drop(fut);
2767 assert_strong_count_converges_to_zero(refs).await;
2768
2769 Ok(())
2770 }
2771
2772 #[tokio::test]
2773 async fn run_first_last_multi_partitions() -> Result<()> {
2774 for is_first_acc in [false, true] {
2775 for spill in [false, true] {
2776 first_last_multi_partitions(is_first_acc, spill, 4200).await?
2777 }
2778 }
2779 Ok(())
2780 }
2781
2782 fn test_first_value_agg_expr(
2784 schema: &Schema,
2785 sort_options: SortOptions,
2786 ) -> Result<Arc<AggregateFunctionExpr>> {
2787 let order_bys = vec![PhysicalSortExpr {
2788 expr: col("b", schema)?,
2789 options: sort_options,
2790 }];
2791 let args = [col("b", schema)?];
2792
2793 AggregateExprBuilder::new(first_value_udaf(), args.to_vec())
2794 .order_by(order_bys)
2795 .schema(Arc::new(schema.clone()))
2796 .alias(String::from("first_value(b) ORDER BY [b ASC NULLS LAST]"))
2797 .build()
2798 .map(Arc::new)
2799 }
2800
2801 fn test_last_value_agg_expr(
2803 schema: &Schema,
2804 sort_options: SortOptions,
2805 ) -> Result<Arc<AggregateFunctionExpr>> {
2806 let order_bys = vec![PhysicalSortExpr {
2807 expr: col("b", schema)?,
2808 options: sort_options,
2809 }];
2810 let args = [col("b", schema)?];
2811 AggregateExprBuilder::new(last_value_udaf(), args.to_vec())
2812 .order_by(order_bys)
2813 .schema(Arc::new(schema.clone()))
2814 .alias(String::from("last_value(b) ORDER BY [b ASC NULLS LAST]"))
2815 .build()
2816 .map(Arc::new)
2817 }
2818
    /// Runs FIRST_VALUE or LAST_VALUE(b ORDER BY b) GROUP BY a over four
    /// memory partitions: a Partial stage per partition, a coalesce, then a
    /// Final stage — and checks the merged output against a snapshot.
    ///
    /// `max_memory` bounds the memory pool when `spill` is set, forcing the
    /// stages to spill.
    async fn first_last_multi_partitions(
        is_first_acc: bool,
        spill: bool,
        max_memory: usize,
    ) -> Result<()> {
        let task_ctx = if spill {
            new_spill_ctx(2, max_memory)
        } else {
            Arc::new(TaskContext::default())
        };

        let (schema, data) = some_data_v2();
        let partition1 = data[0].clone();
        let partition2 = data[1].clone();
        let partition3 = data[2].clone();
        let partition4 = data[3].clone();

        let groups =
            PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);

        let sort_options = SortOptions {
            descending: false,
            nulls_first: false,
        };
        let aggregates: Vec<Arc<AggregateFunctionExpr>> = if is_first_acc {
            vec![test_first_value_agg_expr(&schema, sort_options)?]
        } else {
            vec![test_last_value_agg_expr(&schema, sort_options)?]
        };

        // One memory partition per input batch.
        let memory_exec = TestMemoryExec::try_new_exec(
            &[
                vec![partition1],
                vec![partition2],
                vec![partition3],
                vec![partition4],
            ],
            Arc::clone(&schema),
            None,
        )?;
        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            groups.clone(),
            aggregates.clone(),
            vec![None],
            memory_exec,
            Arc::clone(&schema),
        )?);
        let coalesce = Arc::new(CoalescePartitionsExec::new(aggregate_exec))
            as Arc<dyn ExecutionPlan>;
        let aggregate_final = Arc::new(AggregateExec::try_new(
            AggregateMode::Final,
            groups,
            aggregates.clone(),
            vec![None],
            coalesce,
            schema,
        )?) as Arc<dyn ExecutionPlan>;

        let result = crate::collect(aggregate_final, task_ctx).await?;
        if is_first_acc {
            allow_duplicates! {
                assert_snapshot!(batches_to_string(&result), @r"
                +---+--------------------------------------------+
                | a | first_value(b) ORDER BY [b ASC NULLS LAST] |
                +---+--------------------------------------------+
                | 2 | 0.0                                        |
                | 3 | 1.0                                        |
                | 4 | 3.0                                        |
                +---+--------------------------------------------+
                ");
            }
        } else {
            allow_duplicates! {
                assert_snapshot!(batches_to_string(&result), @r"
                +---+-------------------------------------------+
                | a | last_value(b) ORDER BY [b ASC NULLS LAST] |
                +---+-------------------------------------------+
                | 2 | 3.0                                       |
                | 3 | 5.0                                       |
                | 4 | 6.0                                       |
                +---+-------------------------------------------+
                ");
            }
        };
        Ok(())
    }
2915
    /// Checks that `get_finer_aggregate_exprs_requirement` computes the finest
    /// common sort requirement across several ARRAY_AGG expressions with
    /// different ORDER BY clauses, given that `a == b` in the equivalence
    /// properties. The expected common requirement is `[a ASC, c ASC]`.
    #[tokio::test]
    async fn test_get_finest_requirements() -> Result<()> {
        let test_schema = create_test_schema()?;

        // ASC NULLS LAST used for every sort expression below.
        let options = SortOptions {
            descending: false,
            nulls_first: false,
        };
        let col_a = &col("a", &test_schema)?;
        let col_b = &col("b", &test_schema)?;
        let col_c = &col("c", &test_schema)?;
        let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema));
        // a == b, so an ordering on `b` is interchangeable with one on `a`.
        eq_properties.add_equal_conditions(Arc::clone(col_a), Arc::clone(col_b))?;
        // Candidate ORDER BY clauses: [], [a], [a, b, c], [a, b].
        let order_by_exprs = vec![
            vec![],
            vec![PhysicalSortExpr {
                expr: Arc::clone(col_a),
                options,
            }],
            vec![
                PhysicalSortExpr {
                    expr: Arc::clone(col_a),
                    options,
                },
                PhysicalSortExpr {
                    expr: Arc::clone(col_b),
                    options,
                },
                PhysicalSortExpr {
                    expr: Arc::clone(col_c),
                    options,
                },
            ],
            vec![
                PhysicalSortExpr {
                    expr: Arc::clone(col_a),
                    options,
                },
                PhysicalSortExpr {
                    expr: Arc::clone(col_b),
                    options,
                },
            ],
        ];

        let common_requirement = vec![
            PhysicalSortRequirement::new(Arc::clone(col_a), Some(options)),
            PhysicalSortRequirement::new(Arc::clone(col_c), Some(options)),
        ];
        // One ARRAY_AGG(a) per candidate ORDER BY.
        let mut aggr_exprs = order_by_exprs
            .into_iter()
            .map(|order_by_expr| {
                AggregateExprBuilder::new(array_agg_udaf(), vec![Arc::clone(col_a)])
                    .alias("a")
                    .order_by(order_by_expr)
                    .schema(Arc::clone(&test_schema))
                    .build()
                    .map(Arc::new)
                    .unwrap()
            })
            .collect::<Vec<_>>();
        let group_by = PhysicalGroupBy::new_single(vec![]);
        let result = get_finer_aggregate_exprs_requirement(
            &mut aggr_exprs,
            &group_by,
            &eq_properties,
            &AggregateMode::Partial,
        )?;
        assert_eq!(result, common_requirement);
        Ok(())
    }
2990
2991 #[test]
2992 fn test_agg_exec_same_schema() -> Result<()> {
2993 let schema = Arc::new(Schema::new(vec![
2994 Field::new("a", DataType::Float32, true),
2995 Field::new("b", DataType::Float32, true),
2996 ]));
2997
2998 let col_a = col("a", &schema)?;
2999 let option_desc = SortOptions {
3000 descending: true,
3001 nulls_first: true,
3002 };
3003 let groups = PhysicalGroupBy::new_single(vec![(col_a, "a".to_string())]);
3004
3005 let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
3006 test_first_value_agg_expr(&schema, option_desc)?,
3007 test_last_value_agg_expr(&schema, option_desc)?,
3008 ];
3009 let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
3010 let aggregate_exec = Arc::new(AggregateExec::try_new(
3011 AggregateMode::Partial,
3012 groups,
3013 aggregates,
3014 vec![None, None],
3015 Arc::clone(&blocking_exec) as Arc<dyn ExecutionPlan>,
3016 schema,
3017 )?);
3018 let new_agg =
3019 Arc::clone(&aggregate_exec).with_new_children(vec![blocking_exec])?;
3020 assert_eq!(new_agg.schema(), aggregate_exec.schema());
3021 Ok(())
3022 }
3023
    /// Runs a grouping-sets aggregation where one group expression is a
    /// constant literal (`const = 1`), and checks the emitted groups and
    /// `__grouping_id` values via a sorted snapshot.
    #[tokio::test]
    async fn test_agg_exec_group_by_const() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Float32, true),
            Field::new("b", DataType::Float32, true),
            Field::new("const", DataType::Int32, false),
        ]));

        let col_a = col("a", &schema)?;
        let col_b = col("b", &schema)?;
        // Constant group expression: the literal 1.
        let const_expr = Arc::new(Literal::new(ScalarValue::Int32(Some(1))));

        let groups = PhysicalGroupBy::new(
            vec![
                (col_a, "a".to_string()),
                (col_b, "b".to_string()),
                (const_expr, "const".to_string()),
            ],
            // Typed NULL literals used when a column is absent from a set.
            vec![
                (
                    Arc::new(Literal::new(ScalarValue::Float32(None))),
                    "a".to_string(),
                ),
                (
                    Arc::new(Literal::new(ScalarValue::Float32(None))),
                    "b".to_string(),
                ),
                (
                    Arc::new(Literal::new(ScalarValue::Int32(None))),
                    "const".to_string(),
                ),
            ],
            // Three grouping sets, each keeping exactly one column
            // (false = column present, true = column nulled out).
            vec![
                vec![false, true, true],
                vec![true, false, true],
                vec![true, true, false],
            ],
            true,
        );

        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
            AggregateExprBuilder::new(count_udaf(), vec![lit(1)])
                .schema(Arc::clone(&schema))
                .alias("1")
                .build()
                .map(Arc::new)?,
        ];

        // 4 batches x 8192 identical rows => each set yields one group of
        // 32768 rows.
        let input_batches = (0..4)
            .map(|_| {
                let a = Arc::new(Float32Array::from(vec![0.; 8192]));
                let b = Arc::new(Float32Array::from(vec![0.; 8192]));
                let c = Arc::new(Int32Array::from(vec![1; 8192]));

                RecordBatch::try_new(Arc::clone(&schema), vec![a, b, c]).unwrap()
            })
            .collect();

        let input =
            TestMemoryExec::try_new_exec(&[input_batches], Arc::clone(&schema), None)?;

        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::Single,
            groups,
            aggregates.clone(),
            vec![None],
            input,
            schema,
        )?);

        let output =
            collect(aggregate_exec.execute(0, Arc::new(TaskContext::default()))?).await?;

        allow_duplicates! {
            assert_snapshot!(batches_to_sort_string(&output), @r"
            +-----+-----+-------+---------------+-------+
            | a   | b   | const | __grouping_id | 1     |
            +-----+-----+-------+---------------+-------+
            |     |     | 1     | 6             | 32768 |
            |     | 0.0 |       | 5             | 32768 |
            | 0.0 |     |       | 3             | 32768 |
            +-----+-----+-------+---------------+-------+
            ");
        }

        Ok(())
    }
3111
    /// Groups on a struct column whose fields are dictionary-encoded strings
    /// and sums a UInt64 column, verifying that struct-of-dictionary group
    /// keys hash and compare correctly.
    #[tokio::test]
    async fn test_agg_exec_struct_of_dicts() -> Result<()> {
        // Schema: labels: Struct<a: Dict<Int32, Utf8>, b: Dict<Int32, Utf8>>,
        //         value: UInt64.
        let batch = RecordBatch::try_new(
            Arc::new(Schema::new(vec![
                Field::new(
                    "labels".to_string(),
                    DataType::Struct(
                        vec![
                            Field::new(
                                "a".to_string(),
                                DataType::Dictionary(
                                    Box::new(DataType::Int32),
                                    Box::new(DataType::Utf8),
                                ),
                                true,
                            ),
                            Field::new(
                                "b".to_string(),
                                DataType::Dictionary(
                                    Box::new(DataType::Int32),
                                    Box::new(DataType::Utf8),
                                ),
                                true,
                            ),
                        ]
                        .into(),
                    ),
                    false,
                ),
                Field::new("value", DataType::UInt64, false),
            ])),
            vec![
                // Rows: {a: "a", b: "b"}, {a: null, b: "c"}, {a: "a", b: "b"}.
                Arc::new(StructArray::from(vec![
                    (
                        Arc::new(Field::new(
                            "a".to_string(),
                            DataType::Dictionary(
                                Box::new(DataType::Int32),
                                Box::new(DataType::Utf8),
                            ),
                            true,
                        )),
                        Arc::new(
                            vec![Some("a"), None, Some("a")]
                                .into_iter()
                                .collect::<DictionaryArray<Int32Type>>(),
                        ) as ArrayRef,
                    ),
                    (
                        Arc::new(Field::new(
                            "b".to_string(),
                            DataType::Dictionary(
                                Box::new(DataType::Int32),
                                Box::new(DataType::Utf8),
                            ),
                            true,
                        )),
                        Arc::new(
                            vec![Some("b"), Some("c"), Some("b")]
                                .into_iter()
                                .collect::<DictionaryArray<Int32Type>>(),
                        ) as ArrayRef,
                    ),
                ])),
                Arc::new(UInt64Array::from(vec![1, 1, 1])),
            ],
        )
        .expect("Failed to create RecordBatch");

        // Group on the entire struct column.
        let group_by = PhysicalGroupBy::new_single(vec![(
            col("labels", &batch.schema())?,
            "labels".to_string(),
        )]);

        let aggr_expr = vec![
            AggregateExprBuilder::new(sum_udaf(), vec![col("value", &batch.schema())?])
                .schema(Arc::clone(&batch.schema()))
                .alias(String::from("SUM(value)"))
                .build()
                .map(Arc::new)?,
        ];

        let input = TestMemoryExec::try_new_exec(
            &[vec![batch.clone()]],
            Arc::<Schema>::clone(&batch.schema()),
            None,
        )?;
        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::FinalPartitioned,
            group_by,
            aggr_expr,
            vec![None],
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            batch.schema(),
        )?);

        let session_config = SessionConfig::default();
        let ctx = TaskContext::default().with_session_config(session_config);
        let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;

        // Two identical rows collapse into one group with SUM = 2.
        allow_duplicates! {
            assert_snapshot!(batches_to_string(&output), @r"
            +--------------+------------+
            | labels       | SUM(value) |
            +--------------+------------+
            | {a: a, b: b} | 2          |
            | {a: , b: c}  | 1          |
            +--------------+------------+
            ");
        }

        Ok(())
    }
3225
    /// With the skip-partial-aggregation probe thresholds set very low
    /// (2 rows, 0.1 ratio), partial aggregation is abandoned after the first
    /// batch: the snapshot shows the second batch emitted without being
    /// merged into the first batch's groups (keys 2 and 3 appear twice).
    #[tokio::test]
    async fn test_skip_aggregation_after_first_batch() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("key", DataType::Int32, true),
            Field::new("val", DataType::Int32, true),
        ]));

        let group_by =
            PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]);

        let aggr_expr = vec![
            AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?])
                .schema(Arc::clone(&schema))
                .alias(String::from("COUNT(val)"))
                .build()
                .map(Arc::new)?,
        ];

        // Two 3-row batches with overlapping keys (2 and 3).
        let input_data = vec![
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int32Array::from(vec![1, 2, 3])),
                    Arc::new(Int32Array::from(vec![0, 0, 0])),
                ],
            )
            .unwrap(),
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int32Array::from(vec![2, 3, 4])),
                    Arc::new(Int32Array::from(vec![0, 0, 0])),
                ],
            )
            .unwrap(),
        ];

        let input =
            TestMemoryExec::try_new_exec(&[input_data], Arc::clone(&schema), None)?;
        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            group_by,
            aggr_expr,
            vec![None],
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            schema,
        )?);

        // Probe after 2 rows; give up if fewer than 10% of rows are duplicates.
        let mut session_config = SessionConfig::default();
        session_config = session_config.set(
            "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
            &ScalarValue::Int64(Some(2)),
        );
        session_config = session_config.set(
            "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
            &ScalarValue::Float64(Some(0.1)),
        );

        let ctx = TaskContext::default().with_session_config(session_config);
        let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;

        allow_duplicates! {
            assert_snapshot!(batches_to_string(&output), @r"
            +-----+-------------------+
            | key | COUNT(val)[count] |
            +-----+-------------------+
            | 1   | 1                 |
            | 2   | 1                 |
            | 3   | 1                 |
            | 2   | 1                 |
            | 3   | 1                 |
            | 4   | 1                 |
            +-----+-------------------+
            ");
        }

        Ok(())
    }
3304
    /// With the probe-rows threshold at 5, the first two batches (6 rows)
    /// are still aggregated together; aggregation is skipped from the third
    /// batch onward, whose rows appear unmerged at the end of the snapshot.
    #[tokio::test]
    async fn test_skip_aggregation_after_threshold() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("key", DataType::Int32, true),
            Field::new("val", DataType::Int32, true),
        ]));

        let group_by =
            PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]);

        let aggr_expr = vec![
            AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?])
                .schema(Arc::clone(&schema))
                .alias(String::from("COUNT(val)"))
                .build()
                .map(Arc::new)?,
        ];

        // Three 3-row batches; keys 2, 3, 4 repeat across batches.
        let input_data = vec![
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int32Array::from(vec![1, 2, 3])),
                    Arc::new(Int32Array::from(vec![0, 0, 0])),
                ],
            )
            .unwrap(),
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int32Array::from(vec![2, 3, 4])),
                    Arc::new(Int32Array::from(vec![0, 0, 0])),
                ],
            )
            .unwrap(),
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int32Array::from(vec![2, 3, 4])),
                    Arc::new(Int32Array::from(vec![0, 0, 0])),
                ],
            )
            .unwrap(),
        ];

        let input =
            TestMemoryExec::try_new_exec(&[input_data], Arc::clone(&schema), None)?;
        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            group_by,
            aggr_expr,
            vec![None],
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            schema,
        )?);

        // Probe after 5 rows; skip aggregation if duplicates < 10%.
        let mut session_config = SessionConfig::default();
        session_config = session_config.set(
            "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
            &ScalarValue::Int64(Some(5)),
        );
        session_config = session_config.set(
            "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
            &ScalarValue::Float64(Some(0.1)),
        );

        let ctx = TaskContext::default().with_session_config(session_config);
        let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;

        allow_duplicates! {
            assert_snapshot!(batches_to_string(&output), @r"
            +-----+-------------------+
            | key | COUNT(val)[count] |
            +-----+-------------------+
            | 1   | 1                 |
            | 2   | 2                 |
            | 3   | 2                 |
            | 4   | 1                 |
            | 2   | 1                 |
            | 3   | 1                 |
            | 4   | 1                 |
            +-----+-------------------+
            ");
        }

        Ok(())
    }
3392
3393 #[test]
3394 fn group_exprs_nullable() -> Result<()> {
3395 let input_schema = Arc::new(Schema::new(vec![
3396 Field::new("a", DataType::Float32, false),
3397 Field::new("b", DataType::Float32, false),
3398 ]));
3399
3400 let aggr_expr = vec![
3401 AggregateExprBuilder::new(count_udaf(), vec![col("a", &input_schema)?])
3402 .schema(Arc::clone(&input_schema))
3403 .alias("COUNT(a)")
3404 .build()
3405 .map(Arc::new)?,
3406 ];
3407
3408 let grouping_set = PhysicalGroupBy::new(
3409 vec![
3410 (col("a", &input_schema)?, "a".to_string()),
3411 (col("b", &input_schema)?, "b".to_string()),
3412 ],
3413 vec![
3414 (lit(ScalarValue::Float32(None)), "a".to_string()),
3415 (lit(ScalarValue::Float32(None)), "b".to_string()),
3416 ],
3417 vec![
3418 vec![false, true], vec![false, false], ],
3421 true,
3422 );
3423 let aggr_schema = create_schema(
3424 &input_schema,
3425 &grouping_set,
3426 &aggr_expr,
3427 AggregateMode::Final,
3428 )?;
3429 let expected_schema = Schema::new(vec![
3430 Field::new("a", DataType::Float32, false),
3431 Field::new("b", DataType::Float32, true),
3432 Field::new("__grouping_id", DataType::UInt8, false),
3433 Field::new("COUNT(a)", DataType::Int64, false),
3434 ]);
3435 assert_eq!(aggr_schema, expected_schema);
3436 Ok(())
3437 }
3438
    /// Runs a single-mode grouped aggregation (MIN(b), AVG(b) grouped by `a`)
    /// under a `FairSpillPool` of `pool_size` bytes with batch size 2, asserts
    /// the `SpillCount` metric matches `expect_spill`, then snapshot-checks
    /// the result.
    async fn run_test_with_spill_pool_if_necessary(
        pool_size: usize,
        expect_spill: bool,
    ) -> Result<()> {
        // Helper: build a two-column (UInt32, Float64) batch from parallel vecs.
        fn create_record_batch(
            schema: &Arc<Schema>,
            data: (Vec<u32>, Vec<f64>),
        ) -> Result<RecordBatch> {
            Ok(RecordBatch::try_new(
                Arc::clone(schema),
                vec![
                    Arc::new(UInt32Array::from(data.0)),
                    Arc::new(Float64Array::from(data.1)),
                ],
            )?)
        }

        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::UInt32, false),
            Field::new("b", DataType::Float64, false),
        ]));

        // Two identical batches so every group key occurs more than once.
        let batches = vec![
            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
        ];
        let plan: Arc<dyn ExecutionPlan> =
            TestMemoryExec::try_new_exec(&[batches], Arc::clone(&schema), None)?;

        let grouping_set = PhysicalGroupBy::new(
            vec![(col("a", &schema)?, "a".to_string())],
            vec![],
            vec![vec![false]],
            false,
        );

        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
            Arc::new(
                AggregateExprBuilder::new(
                    datafusion_functions_aggregate::min_max::min_udaf(),
                    vec![col("b", &schema)?],
                )
                .schema(Arc::clone(&schema))
                .alias("MIN(b)")
                .build()?,
            ),
            Arc::new(
                AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
                    .schema(Arc::clone(&schema))
                    .alias("AVG(b)")
                    .build()?,
            ),
        ];

        let single_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Single,
            grouping_set,
            aggregates,
            vec![None, None],
            plan,
            Arc::clone(&schema),
        )?);

        // Tiny batch size plus the configured pool size drive the spill behavior.
        let batch_size = 2;
        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
        let task_ctx = Arc::new(
            TaskContext::default()
                .with_session_config(SessionConfig::new().with_batch_size(batch_size))
                .with_runtime(Arc::new(
                    RuntimeEnvBuilder::new()
                        .with_memory_pool(memory_pool)
                        .build()?,
                )),
        );

        let result = collect(single_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;

        assert_spill_count_metric(expect_spill, single_aggregate);

        allow_duplicates! {
            assert_snapshot!(batches_to_string(&result), @r"
            +---+--------+--------+
            | a | MIN(b) | AVG(b) |
            +---+--------+--------+
            | 2 | 1.0    | 1.0    |
            | 3 | 2.0    | 2.0    |
            | 4 | 3.0    | 3.5    |
            +---+--------+--------+
            ");
        }

        Ok(())
    }
3534
3535 fn assert_spill_count_metric(
3536 expect_spill: bool,
3537 single_aggregate: Arc<AggregateExec>,
3538 ) {
3539 if let Some(metrics_set) = single_aggregate.metrics() {
3540 let mut spill_count = 0;
3541
3542 for metric in metrics_set.iter() {
3544 if let MetricValue::SpillCount(count) = metric.value() {
3545 spill_count = count.value();
3546 break;
3547 }
3548 }
3549
3550 if expect_spill && spill_count == 0 {
3551 panic!(
3552 "Expected spill but SpillCount metric not found or SpillCount was 0."
3553 );
3554 } else if !expect_spill && spill_count > 0 {
3555 panic!(
3556 "Expected no spill but found SpillCount metric with value greater than 0."
3557 );
3558 }
3559 } else {
3560 panic!("No metrics returned from the operator; cannot verify spilling.");
3561 }
3562 }
3563
3564 #[tokio::test]
3565 async fn test_aggregate_with_spill_if_necessary() -> Result<()> {
3566 run_test_with_spill_pool_if_necessary(2_000, true).await?;
3568 run_test_with_spill_pool_if_necessary(20_000, false).await?;
3570 Ok(())
3571 }
3572
    /// Runs a grouped aggregation (with an extra constant projection column)
    /// under a 2000-byte `FairSpillPool`. The query must either succeed by
    /// spilling (verified via the SpillCount metric and a snapshot) or fail
    /// with `ResourcesExhausted` — it must not exceed the memory limit.
    #[tokio::test]
    async fn test_grouped_aggregation_respects_memory_limit() -> Result<()> {
        // Helper: build a two-column (UInt32, Float64) batch from parallel vecs.
        fn create_record_batch(
            schema: &Arc<Schema>,
            data: (Vec<u32>, Vec<f64>),
        ) -> Result<RecordBatch> {
            Ok(RecordBatch::try_new(
                Arc::clone(schema),
                vec![
                    Arc::new(UInt32Array::from(data.0)),
                    Arc::new(Float64Array::from(data.1)),
                ],
            )?)
        }

        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::UInt32, false),
            Field::new("b", DataType::Float64, false),
        ]));

        let batches = vec![
            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
        ];
        let plan: Arc<dyn ExecutionPlan> =
            TestMemoryExec::try_new_exec(&[batches], Arc::clone(&schema), None)?;
        // Prepend a constant string column "l" so the group key has mixed types.
        let proj = ProjectionExec::try_new(
            vec![
                ProjectionExpr::new(lit("0"), "l".to_string()),
                ProjectionExpr::new_from_expression(col("a", &schema)?, &schema)?,
                ProjectionExpr::new_from_expression(col("b", &schema)?, &schema)?,
            ],
            plan,
        )?;
        let plan: Arc<dyn ExecutionPlan> = Arc::new(proj);
        // Schema now reflects the projection output (l, a, b).
        let schema = plan.schema();

        let grouping_set = PhysicalGroupBy::new(
            vec![
                (col("l", &schema)?, "l".to_string()),
                (col("a", &schema)?, "a".to_string()),
            ],
            vec![],
            vec![vec![false, false]],
            false,
        );

        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
            Arc::new(
                AggregateExprBuilder::new(
                    datafusion_functions_aggregate::min_max::min_udaf(),
                    vec![col("b", &schema)?],
                )
                .schema(Arc::clone(&schema))
                .alias("MIN(b)")
                .build()?,
            ),
            Arc::new(
                AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
                    .schema(Arc::clone(&schema))
                    .alias("AVG(b)")
                    .build()?,
            ),
        ];

        let single_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Single,
            grouping_set,
            aggregates,
            vec![None, None],
            plan,
            Arc::clone(&schema),
        )?);

        // Small batch size + 2000-byte pool force memory pressure.
        let batch_size = 2;
        let memory_pool = Arc::new(FairSpillPool::new(2000));
        let task_ctx = Arc::new(
            TaskContext::default()
                .with_session_config(SessionConfig::new().with_batch_size(batch_size))
                .with_runtime(Arc::new(
                    RuntimeEnvBuilder::new()
                        .with_memory_pool(memory_pool)
                        .build()?,
                )),
        );

        let result = collect(single_aggregate.execute(0, Arc::clone(&task_ctx))?).await;
        match result {
            // Success is only acceptable if the operator actually spilled.
            Ok(result) => {
                assert_spill_count_metric(true, single_aggregate);

                allow_duplicates! {
                    assert_snapshot!(batches_to_string(&result), @r"
                    +---+---+--------+--------+
                    | l | a | MIN(b) | AVG(b) |
                    +---+---+--------+--------+
                    | 0 | 2 | 1.0    | 1.0    |
                    | 0 | 3 | 2.0    | 2.0    |
                    | 0 | 4 | 3.0    | 3.5    |
                    +---+---+--------+--------+
                    ");
                }
            }
            // Otherwise the limit must be enforced with ResourcesExhausted.
            Err(e) => assert!(matches!(e, DataFusionError::ResourcesExhausted(_))),
        }

        Ok(())
    }
3683
    /// `partition_statistics` for a no-group Final COUNT aggregate must report
    /// `total_byte_size` as `Precision::Absent`, both when the input has rows
    /// (100) and when it is empty (0 rows, 0 bytes).
    #[tokio::test]
    async fn test_aggregate_statistics_edge_cases() -> Result<()> {
        use crate::test::exec::StatisticsExec;
        use datafusion_common::ColumnStatistics;

        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Float64, false),
        ]));

        // Input claiming 100 rows but an unknown byte size.
        let input = Arc::new(StatisticsExec::new(
            Statistics {
                num_rows: Precision::Exact(100),
                total_byte_size: Precision::Absent,
                column_statistics: vec![
                    ColumnStatistics::new_unknown(),
                    ColumnStatistics::new_unknown(),
                ],
            },
            (*schema).clone(),
        )) as Arc<dyn ExecutionPlan>;

        let agg = Arc::new(AggregateExec::try_new(
            AggregateMode::Final,
            PhysicalGroupBy::default(),
            vec![Arc::new(
                AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema)?])
                    .schema(Arc::clone(&schema))
                    .alias("COUNT(a)")
                    .build()?,
            )],
            vec![None],
            input,
            Arc::clone(&schema),
        )?);

        let stats = agg.partition_statistics(None)?;
        assert_eq!(stats.total_byte_size, Precision::Absent);

        // Same aggregate over an exactly-empty input.
        let input_zero = Arc::new(StatisticsExec::new(
            Statistics {
                num_rows: Precision::Exact(0),
                total_byte_size: Precision::Exact(0),
                column_statistics: vec![
                    ColumnStatistics::new_unknown(),
                    ColumnStatistics::new_unknown(),
                ],
            },
            (*schema).clone(),
        )) as Arc<dyn ExecutionPlan>;

        let agg_zero = Arc::new(AggregateExec::try_new(
            AggregateMode::Final,
            PhysicalGroupBy::default(),
            vec![Arc::new(
                AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema)?])
                    .schema(Arc::clone(&schema))
                    .alias("COUNT(a)")
                    .build()?,
            )],
            vec![None],
            input_zero,
            Arc::clone(&schema),
        )?);

        // Even with exact zero input bytes, the aggregate's output byte size
        // is not known.
        let stats_zero = agg_zero.partition_statistics(None)?;
        assert_eq!(stats_zero.total_byte_size, Precision::Absent);

        Ok(())
    }
3756
    /// The input is declared sorted descending on `b` (values 2, 1, 0); after
    /// a spilling aggregation with a 600-byte limit, the output must still
    /// appear in that descending order.
    #[tokio::test]
    async fn test_order_is_retained_when_spilling() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Int64, false),
            Field::new("c", DataType::Int64, false),
        ]));

        // Three single-row batches with b = 2, 1, 0 (already descending).
        let batches = vec![vec![
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int64Array::from(vec![2])),
                    Arc::new(Int64Array::from(vec![2])),
                    Arc::new(Int64Array::from(vec![1])),
                ],
            )?,
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int64Array::from(vec![1])),
                    Arc::new(Int64Array::from(vec![1])),
                    Arc::new(Int64Array::from(vec![1])),
                ],
            )?,
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int64Array::from(vec![0])),
                    Arc::new(Int64Array::from(vec![0])),
                    Arc::new(Int64Array::from(vec![1])),
                ],
            )?,
        ]];
        let scan = TestMemoryExec::try_new(&batches, Arc::clone(&schema), None)?;
        // Declare the scan's output as sorted by b DESC.
        let scan = scan.try_with_sort_information(vec![
            LexOrdering::new([PhysicalSortExpr::new(
                col("b", schema.as_ref())?,
                SortOptions::default().desc(),
            )])
            .unwrap(),
        ])?;

        let aggr = Arc::new(AggregateExec::try_new(
            AggregateMode::Single,
            PhysicalGroupBy::new(
                vec![
                    (col("b", schema.as_ref())?, "b".to_string()),
                    (col("c", schema.as_ref())?, "c".to_string()),
                ],
                vec![],
                vec![vec![false, false]],
                false,
            ),
            vec![Arc::new(
                AggregateExprBuilder::new(sum_udaf(), vec![col("c", schema.as_ref())?])
                    .schema(Arc::clone(&schema))
                    .alias("SUM(c)")
                    .build()?,
            )],
            vec![None],
            Arc::new(scan) as Arc<dyn ExecutionPlan>,
            Arc::clone(&schema),
        )?);

        // 600-byte limit forces spilling; the SpillCount metric confirms it.
        let task_ctx = new_spill_ctx(1, 600);
        let result = collect(aggr.execute(0, Arc::clone(&task_ctx))?).await?;
        assert_spill_count_metric(true, aggr);

        // Rows must still come out in b-descending order: 2, 1, 0.
        allow_duplicates! {
            assert_snapshot!(batches_to_string(&result), @r"
            +---+---+--------+
            | b | c | SUM(c) |
            +---+---+--------+
            | 2 | 1 | 1      |
            | 1 | 1 | 1      |
            | 0 | 1 | 1      |
            +---+---+--------+
            ");
        }
        Ok(())
    }
3839
    /// Runs five AVG aggregates under a tiny 500-byte memory limit so that
    /// reserving memory for the sort performed while spilling fails. Asserts
    /// the query errors with `ResourcesExhausted` and the specific
    /// "Failed to reserve memory for sort during spill" message.
    #[tokio::test]
    async fn test_sort_reservation_fails_during_spill() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("g", DataType::Int64, false),
            Field::new("a", DataType::Float64, false),
            Field::new("b", DataType::Float64, false),
            Field::new("c", DataType::Float64, false),
            Field::new("d", DataType::Float64, false),
            Field::new("e", DataType::Float64, false),
        ]));

        // Three single-row batches, each a distinct group key g = 1, 2, 3.
        let batches = vec![vec![
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int64Array::from(vec![1])),
                    Arc::new(Float64Array::from(vec![10.0])),
                    Arc::new(Float64Array::from(vec![20.0])),
                    Arc::new(Float64Array::from(vec![30.0])),
                    Arc::new(Float64Array::from(vec![40.0])),
                    Arc::new(Float64Array::from(vec![50.0])),
                ],
            )?,
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int64Array::from(vec![2])),
                    Arc::new(Float64Array::from(vec![11.0])),
                    Arc::new(Float64Array::from(vec![21.0])),
                    Arc::new(Float64Array::from(vec![31.0])),
                    Arc::new(Float64Array::from(vec![41.0])),
                    Arc::new(Float64Array::from(vec![51.0])),
                ],
            )?,
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int64Array::from(vec![3])),
                    Arc::new(Float64Array::from(vec![12.0])),
                    Arc::new(Float64Array::from(vec![22.0])),
                    Arc::new(Float64Array::from(vec![32.0])),
                    Arc::new(Float64Array::from(vec![42.0])),
                    Arc::new(Float64Array::from(vec![52.0])),
                ],
            )?,
        ]];

        let scan = TestMemoryExec::try_new(&batches, Arc::clone(&schema), None)?;

        // Five AVG aggregates inflate the per-group state so even this tiny
        // input exceeds the memory limit.
        let aggr = Arc::new(AggregateExec::try_new(
            AggregateMode::Single,
            PhysicalGroupBy::new(
                vec![(col("g", schema.as_ref())?, "g".to_string())],
                vec![],
                vec![vec![false]],
                false,
            ),
            vec![
                Arc::new(
                    AggregateExprBuilder::new(
                        avg_udaf(),
                        vec![col("a", schema.as_ref())?],
                    )
                    .schema(Arc::clone(&schema))
                    .alias("AVG(a)")
                    .build()?,
                ),
                Arc::new(
                    AggregateExprBuilder::new(
                        avg_udaf(),
                        vec![col("b", schema.as_ref())?],
                    )
                    .schema(Arc::clone(&schema))
                    .alias("AVG(b)")
                    .build()?,
                ),
                Arc::new(
                    AggregateExprBuilder::new(
                        avg_udaf(),
                        vec![col("c", schema.as_ref())?],
                    )
                    .schema(Arc::clone(&schema))
                    .alias("AVG(c)")
                    .build()?,
                ),
                Arc::new(
                    AggregateExprBuilder::new(
                        avg_udaf(),
                        vec![col("d", schema.as_ref())?],
                    )
                    .schema(Arc::clone(&schema))
                    .alias("AVG(d)")
                    .build()?,
                ),
                Arc::new(
                    AggregateExprBuilder::new(
                        avg_udaf(),
                        vec![col("e", schema.as_ref())?],
                    )
                    .schema(Arc::clone(&schema))
                    .alias("AVG(e)")
                    .build()?,
                ),
            ],
            vec![None, None, None, None, None],
            Arc::new(scan) as Arc<dyn ExecutionPlan>,
            Arc::clone(&schema),
        )?);

        // 500 bytes: too small even for the sort reservation made while
        // spilling.
        let task_ctx = new_spill_ctx(1, 500);
        let result = collect(aggr.execute(0, Arc::clone(&task_ctx))?).await;

        match &result {
            Ok(_) => panic!("Expected ResourcesExhausted error but query succeeded"),
            Err(e) => {
                // The failure may be wrapped; inspect the root cause.
                let root = e.find_root();
                assert!(
                    matches!(root, DataFusionError::ResourcesExhausted(_)),
                    "Expected ResourcesExhausted, got: {root}",
                );
                let msg = root.to_string();
                assert!(
                    msg.contains("Failed to reserve memory for sort during spill"),
                    "Expected sort reservation error, got: {msg}",
                );
            }
        }

        Ok(())
    }
3975
    /// Exercises `AggregateMode::PartialReduce`: two Partial aggregations run
    /// over separate inputs, their outputs are coalesced and reduced by a
    /// PartialReduce stage (whose output schema must equal the partial
    /// schema), and a Final stage produces the fully merged SUM(b) per `a`.
    #[tokio::test]
    async fn test_partial_reduce_mode() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::UInt32, false),
            Field::new("b", DataType::Float64, false),
        ]));

        // Same keys (1, 2, 3) in both batches, different values, so the
        // stages have real merging work to do.
        let batch1 = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(UInt32Array::from(vec![1, 2, 3])),
                Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0])),
            ],
        )?;
        let batch2 = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(UInt32Array::from(vec![1, 2, 3])),
                Arc::new(Float64Array::from(vec![40.0, 50.0, 60.0])),
            ],
        )?;

        let groups =
            PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);
        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
            AggregateExprBuilder::new(sum_udaf(), vec![col("b", &schema)?])
                .schema(Arc::clone(&schema))
                .alias("SUM(b)")
                .build()?,
        )];

        // Stage 1a: Partial aggregation over the first input.
        let input1 =
            TestMemoryExec::try_new_exec(&[vec![batch1]], Arc::clone(&schema), None)?;
        let partial1 = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            groups.clone(),
            aggregates.clone(),
            vec![None],
            input1,
            Arc::clone(&schema),
        )?);

        // Stage 1b: Partial aggregation over the second input.
        let input2 =
            TestMemoryExec::try_new_exec(&[vec![batch2]], Arc::clone(&schema), None)?;
        let partial2 = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            groups.clone(),
            aggregates.clone(),
            vec![None],
            input2,
            Arc::clone(&schema),
        )?);

        let task_ctx = Arc::new(TaskContext::default());
        let partial_result1 =
            crate::collect(Arc::clone(&partial1) as _, Arc::clone(&task_ctx)).await?;
        let partial_result2 =
            crate::collect(Arc::clone(&partial2) as _, Arc::clone(&task_ctx)).await?;

        // Schema carrying partial aggregate state, used by all later stages.
        let partial_schema = partial1.schema();

        // Stage 2: feed both partial outputs into one PartialReduce stage.
        let combined_input = TestMemoryExec::try_new_exec(
            &[partial_result1, partial_result2],
            Arc::clone(&partial_schema),
            None,
        )?;
        let coalesced = Arc::new(CoalescePartitionsExec::new(combined_input));

        let partial_reduce = Arc::new(AggregateExec::try_new(
            AggregateMode::PartialReduce,
            groups.clone(),
            aggregates.clone(),
            vec![None],
            coalesced,
            Arc::clone(&partial_schema),
        )?);

        // PartialReduce consumes and produces the partial-state schema.
        assert_eq!(partial_reduce.schema(), partial_schema);

        let reduce_result =
            crate::collect(Arc::clone(&partial_reduce) as _, Arc::clone(&task_ctx))
                .await?;

        // Stage 3: Final aggregation over the reduced partial states.
        let final_input = TestMemoryExec::try_new_exec(
            &[reduce_result],
            Arc::clone(&partial_schema),
            None,
        )?;
        let final_agg = Arc::new(AggregateExec::try_new(
            AggregateMode::Final,
            groups.clone(),
            aggregates.clone(),
            vec![None],
            final_input,
            Arc::clone(&partial_schema),
        )?);

        let result = crate::collect(final_agg, Arc::clone(&task_ctx)).await?;

        // 10+40, 20+50, 30+60 per group.
        assert_snapshot!(batches_to_sort_string(&result), @r"
        +---+--------+
        | a | SUM(b) |
        +---+--------+
        | 1 | 50.0   |
        | 2 | 70.0   |
        | 3 | 90.0   |
        +---+--------+
        ");

        Ok(())
    }
4106}