1use std::borrow::Cow;
21use std::sync::Arc;
22
23use super::{DisplayAs, ExecutionPlanProperties, PlanProperties};
24use crate::aggregates::{
25 no_grouping::AggregateStream, row_hash::GroupedHashAggregateStream,
26 topk_stream::GroupedTopKAggregateStream,
27};
28use crate::execution_plan::{CardinalityEffect, EmissionType};
29use crate::filter_pushdown::{
30 ChildFilterDescription, ChildPushdownResult, FilterDescription, FilterPushdownPhase,
31 FilterPushdownPropagation, PushedDownPredicate,
32};
33use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
34use crate::{
35 DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode,
36 SendableRecordBatchStream, Statistics, check_if_same_properties,
37};
38use datafusion_common::config::ConfigOptions;
39use datafusion_physical_expr::utils::collect_columns;
40use parking_lot::Mutex;
41use std::collections::{HashMap, HashSet};
42
43use arrow::array::{ArrayRef, UInt8Array, UInt16Array, UInt32Array, UInt64Array};
44use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
45use arrow::record_batch::RecordBatch;
46use arrow_schema::FieldRef;
47use datafusion_common::stats::Precision;
48use datafusion_common::{
49 Constraint, Constraints, Result, ScalarValue, assert_eq_or_internal_err,
50 internal_err, not_impl_err,
51};
52use datafusion_execution::TaskContext;
53use datafusion_expr::{Accumulator, Aggregate};
54use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
55use datafusion_physical_expr::equivalence::ProjectionMapping;
56use datafusion_physical_expr::expressions::{Column, DynamicFilterPhysicalExpr, lit};
57use datafusion_physical_expr::{
58 ConstExpr, EquivalenceProperties, physical_exprs_contains,
59};
60use datafusion_physical_expr_common::physical_expr::{PhysicalExpr, fmt_sql};
61use datafusion_physical_expr_common::sort_expr::{
62 LexOrdering, LexRequirement, OrderingRequirements, PhysicalSortRequirement,
63};
64
65use datafusion_expr::utils::AggregateOrderSensitivity;
66use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays;
67use itertools::Itertools;
68use topk::hash_table::is_supported_hash_key_type;
69use topk::heap::is_supported_heap_type;
70
71pub mod group_values;
72mod no_grouping;
73pub mod order;
74mod row_hash;
75mod topk;
76mod topk_stream;
77
78pub fn topk_types_supported(key_type: &DataType, value_type: &DataType) -> bool {
86 is_supported_hash_key_type(key_type) && is_supported_heap_type(value_type)
87}
88
/// Fixed seed for the hash state used by aggregation.
///
/// The seed is a compile-time constant, so the resulting `RandomState` is the same
/// on every construction. Presumably this keeps aggregation hashing reproducible
/// and independent of other hash consumers.
/// NOTE(review): the motivation is not visible in this file; confirm.
const AGGREGATION_HASH_SEED: datafusion_common::hash_utils::RandomState =
    datafusion_common::hash_utils::RandomState::with_seed(15395726432021054657);
93
/// Kind of rows an aggregation stage consumes.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum AggregateInputMode {
    /// Un-aggregated input rows (used by `Partial`, `Single`, `SinglePartitioned`).
    Raw,
    /// Intermediate (partial) aggregate state (used by `Final`,
    /// `FinalPartitioned`, `PartialReduce`).
    Partial,
}

/// Kind of rows an aggregation stage produces.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum AggregateOutputMode {
    /// Intermediate (partial) aggregate state (produced by `Partial`, `PartialReduce`).
    Partial,
    /// Final aggregate values (produced by `Final`, `FinalPartitioned`,
    /// `Single`, `SinglePartitioned`).
    Final,
}

/// Stage of a (possibly multi-stage) aggregation.
///
/// Each mode is characterised by its [`AggregateInputMode`] /
/// [`AggregateOutputMode`] pair and by the input distribution the operator asks
/// for.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum AggregateMode {
    /// Raw input → partial state; no particular input distribution required.
    Partial,
    /// Partial state → final values; requires a single input partition.
    Final,
    /// Partial state → final values; requires input hash-partitioned on the
    /// group-by expressions.
    FinalPartitioned,
    /// Raw input → final values in one stage; requires a single input partition.
    Single,
    /// Raw input → final values in one stage; requires input hash-partitioned on
    /// the group-by expressions.
    SinglePartitioned,
    /// Partial state → partial state; no particular input distribution required.
    PartialReduce,
}

impl AggregateMode {
    /// Returns what kind of rows this mode consumes.
    pub fn input_mode(&self) -> AggregateInputMode {
        match *self {
            Self::Final | Self::FinalPartitioned | Self::PartialReduce => {
                AggregateInputMode::Partial
            }
            Self::Partial | Self::Single | Self::SinglePartitioned => {
                AggregateInputMode::Raw
            }
        }
    }

    /// Returns what kind of rows this mode produces.
    pub fn output_mode(&self) -> AggregateOutputMode {
        match *self {
            Self::Partial | Self::PartialReduce => AggregateOutputMode::Partial,
            Self::Final
            | Self::FinalPartitioned
            | Self::Single
            | Self::SinglePartitioned => AggregateOutputMode::Final,
        }
    }
}
246
/// Grouping specification of an aggregation: the group-by expressions plus
/// optional grouping sets.
#[derive(Clone, Debug, Default)]
pub struct PhysicalGroupBy {
    /// Group-by expressions paired with their output column names.
    expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
    /// Per-position replacement expressions, used (e.g. when displaying the
    /// plan) wherever a grouping-set mask is `true` for that position.
    null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
    /// One boolean mask per grouping set. `true` at index `i` marks `expr[i]` as
    /// nulled out in that set, which also makes its output field nullable.
    groups: Vec<Vec<bool>>,
    /// Whether grouping sets are in use. When `true`, an extra internal
    /// grouping-id column is appended after the group-by columns.
    has_grouping_set: bool,
}
280
281impl PhysicalGroupBy {
282 pub fn new(
284 expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
285 null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
286 groups: Vec<Vec<bool>>,
287 has_grouping_set: bool,
288 ) -> Self {
289 Self {
290 expr,
291 null_expr,
292 groups,
293 has_grouping_set,
294 }
295 }
296
297 pub fn new_single(expr: Vec<(Arc<dyn PhysicalExpr>, String)>) -> Self {
300 let num_exprs = expr.len();
301 Self {
302 expr,
303 null_expr: vec![],
304 groups: vec![vec![false; num_exprs]],
305 has_grouping_set: false,
306 }
307 }
308
309 pub fn exprs_nullable(&self) -> Vec<bool> {
311 let mut exprs_nullable = vec![false; self.expr.len()];
312 for group in self.groups.iter() {
313 group.iter().enumerate().for_each(|(index, is_null)| {
314 if *is_null {
315 exprs_nullable[index] = true;
316 }
317 })
318 }
319 exprs_nullable
320 }
321
322 pub fn is_true_no_grouping(&self) -> bool {
324 self.is_empty() && !self.has_grouping_set
325 }
326
327 pub fn expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
329 &self.expr
330 }
331
332 pub fn null_expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
334 &self.null_expr
335 }
336
337 pub fn groups(&self) -> &[Vec<bool>] {
339 &self.groups
340 }
341
342 pub fn has_grouping_set(&self) -> bool {
344 self.has_grouping_set
345 }
346
347 pub fn is_empty(&self) -> bool {
349 self.expr.is_empty()
350 }
351
352 pub fn is_single(&self) -> bool {
355 !self.has_grouping_set
356 }
357
358 pub fn input_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
360 self.expr
361 .iter()
362 .map(|(expr, _alias)| Arc::clone(expr))
363 .collect()
364 }
365
366 fn num_output_exprs(&self) -> usize {
368 let mut num_exprs = self.expr.len();
369 if self.has_grouping_set {
370 num_exprs += 1
371 }
372 num_exprs
373 }
374
375 pub fn output_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
377 let num_output_exprs = self.num_output_exprs();
378 let mut output_exprs = Vec::with_capacity(num_output_exprs);
379 output_exprs.extend(
380 self.expr
381 .iter()
382 .enumerate()
383 .take(num_output_exprs)
384 .map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _),
385 );
386 if self.has_grouping_set {
387 output_exprs.push(Arc::new(Column::new(
388 Aggregate::INTERNAL_GROUPING_ID,
389 self.expr.len(),
390 )) as _);
391 }
392 output_exprs
393 }
394
395 pub fn num_group_exprs(&self) -> usize {
397 self.expr.len() + usize::from(self.has_grouping_set)
398 }
399
400 fn grouping_id_data_type(&self) -> DataType {
406 Aggregate::grouping_id_type(self.expr.len(), max_duplicate_ordinal(&self.groups))
407 }
408
409 pub fn group_schema(&self, schema: &Schema) -> Result<SchemaRef> {
410 Ok(Arc::new(Schema::new(self.group_fields(schema)?)))
411 }
412
413 fn group_fields(&self, input_schema: &Schema) -> Result<Vec<FieldRef>> {
415 let mut fields = Vec::with_capacity(self.num_group_exprs());
416 for ((expr, name), group_expr_nullable) in
417 self.expr.iter().zip(self.exprs_nullable())
418 {
419 fields.push(
420 Field::new(
421 name,
422 expr.data_type(input_schema)?,
423 group_expr_nullable || expr.nullable(input_schema)?,
424 )
425 .with_metadata(expr.return_field(input_schema)?.metadata().clone())
426 .into(),
427 );
428 }
429 if self.has_grouping_set {
430 fields.push(
431 Field::new(
432 Aggregate::INTERNAL_GROUPING_ID,
433 self.grouping_id_data_type(),
434 false,
435 )
436 .into(),
437 );
438 }
439 Ok(fields)
440 }
441
442 fn output_fields(&self, input_schema: &Schema) -> Result<Vec<FieldRef>> {
447 let mut fields = self.group_fields(input_schema)?;
448 fields.truncate(self.num_output_exprs());
449 Ok(fields)
450 }
451
452 pub fn as_final(&self) -> PhysicalGroupBy {
455 let expr: Vec<_> =
456 self.output_exprs()
457 .into_iter()
458 .zip(
459 self.expr.iter().map(|t| t.1.clone()).chain(std::iter::once(
460 Aggregate::INTERNAL_GROUPING_ID.to_owned(),
461 )),
462 )
463 .collect();
464 let num_exprs = expr.len();
465 let groups = if self.expr.is_empty() && !self.has_grouping_set {
466 vec![]
468 } else {
469 vec![vec![false; num_exprs]]
470 };
471 Self {
472 expr,
473 null_expr: vec![],
474 groups,
475 has_grouping_set: false,
476 }
477 }
478}
479
480impl PartialEq for PhysicalGroupBy {
481 fn eq(&self, other: &PhysicalGroupBy) -> bool {
482 self.expr.len() == other.expr.len()
483 && self
484 .expr
485 .iter()
486 .zip(other.expr.iter())
487 .all(|((expr1, name1), (expr2, name2))| expr1.eq(expr2) && name1 == name2)
488 && self.null_expr.len() == other.null_expr.len()
489 && self
490 .null_expr
491 .iter()
492 .zip(other.null_expr.iter())
493 .all(|((expr1, name1), (expr2, name2))| expr1.eq(expr2) && name1 == name2)
494 && self.groups == other.groups
495 && self.has_grouping_set == other.has_grouping_set
496 }
497}
498
/// The concrete stream implementations an `AggregateExec` partition can run.
/// The variant is chosen in `AggregateExec::execute_typed`.
#[expect(clippy::large_enum_variant)]
enum StreamType {
    /// Used when there is no grouping at all.
    AggregateStream(AggregateStream),
    /// Hash-based grouped aggregation.
    GroupedHash(GroupedHashAggregateStream),
    /// Grouped top-k aggregation, used when a limit applies.
    GroupedPriorityQueue(GroupedTopKAggregateStream),
}
505
506impl From<StreamType> for SendableRecordBatchStream {
507 fn from(stream: StreamType) -> Self {
508 match stream {
509 StreamType::AggregateStream(stream) => Box::pin(stream),
510 StreamType::GroupedHash(stream) => Box::pin(stream),
511 StreamType::GroupedPriorityQueue(stream) => Box::pin(stream),
512 }
513 }
514}
515
/// Dynamic-filter state of an `AggregateExec`.
///
/// Created by `AggregateExec::init_dynamic_filter` and replaceable through
/// `AggregateExec::with_dynamic_filter_expr`.
#[derive(Debug, Clone)]
struct AggrDynFilter {
    /// Filter expression. When created in `init_dynamic_filter` it covers the
    /// input columns of the supported aggregates and starts as `lit(true)`.
    filter: Arc<DynamicFilterPhysicalExpr>,
    /// One entry per `min`/`max` aggregate whose single argument is a plain
    /// column, in aggregate order.
    supported_accumulators_info: Vec<PerAccumulatorDynFilter>,
}
573
/// Dynamic-filter bookkeeping for one supported `min`/`max` aggregate.
#[derive(Debug, Clone)]
struct PerAccumulatorDynFilter {
    /// Whether the aggregate is a `min` or a `max`.
    aggr_type: DynamicFilterAggregateType,
    /// Index of the aggregate within the owning node's aggregate expressions.
    aggr_index: usize,
    /// Mutex-protected bound, initialised to `ScalarValue::Null`.
    /// NOTE(review): presumably updated with the running min/max during
    /// execution; the writers are not visible here, so confirm.
    shared_bound: Arc<Mutex<ScalarValue>>,
}
590
/// Which aggregate function a dynamic-filter entry tracks.
#[derive(Debug, Clone)]
enum DynamicFilterAggregateType {
    /// Aggregate named `min` (matched case-insensitively).
    Min,
    /// Aggregate named `max` (matched case-insensitively).
    Max,
}
597
/// Row limit pushed into an aggregation, optionally with a direction flag.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LimitOptions {
    /// The limit value. Output row estimates are capped at this value.
    pub limit: usize,
    /// `Some(flag)` when built with an explicit direction (`true` = descending,
    /// per the name); `None` when no direction was given.
    pub descending: Option<bool>,
}

impl LimitOptions {
    /// Limit without a direction.
    pub fn new(limit: usize) -> Self {
        Self {
            descending: None,
            limit,
        }
    }

    /// Limit with an explicit direction flag.
    pub fn new_with_order(limit: usize, descending: bool) -> Self {
        Self {
            descending: Some(descending),
            ..Self::new(limit)
        }
    }

    /// The limit value.
    pub fn limit(&self) -> usize {
        self.limit
    }

    /// The direction flag, if one was given.
    pub fn descending(&self) -> Option<bool> {
        self.descending
    }
}
633
/// Physical operator that evaluates grouped or ungrouped aggregations over its
/// input.
#[derive(Debug, Clone)]
pub struct AggregateExec {
    /// Aggregation stage. Determines input/output kinds and required input distribution.
    mode: AggregateMode,
    /// Grouping specification (group-by expressions and grouping sets).
    group_by: Arc<PhysicalGroupBy>,
    /// Aggregate function expressions.
    aggr_expr: Arc<[Arc<AggregateFunctionExpr>]>,
    /// Optional per-aggregate filter predicates. Same length as `aggr_expr`
    /// (checked at construction).
    filter_expr: Arc<[Option<Arc<dyn PhysicalExpr>>]>,
    /// Optional limit. When set, execution may use the grouped top-k stream.
    limit_options: Option<LimitOptions>,
    /// Input plan.
    pub input: Arc<dyn ExecutionPlan>,
    /// Output schema of this node.
    schema: SchemaRef,
    /// Schema supplied separately by the caller at construction; it is not
    /// derived from `input.schema()` here, so the two may differ.
    /// NOTE(review): presumably the schema before any aggregation stage; confirm.
    pub input_schema: SchemaRef,
    /// Execution metrics for this node.
    metrics: ExecutionPlanMetricsSet,
    /// Soft input ordering requirement derived from the group-by expressions
    /// and the aggregates' own ordering needs.
    required_input_ordering: Option<OrderingRequirements>,
    /// How the input is ordered relative to the group-by expressions.
    input_order_mode: InputOrderMode,
    /// Cached plan properties (equivalences, partitioning, emission, boundedness).
    cache: Arc<PlanProperties>,
    /// Dynamic-filter state. `init_dynamic_filter` only creates it for ungrouped
    /// `Partial` aggregations over `min`/`max`.
    dynamic_filter: Option<Arc<AggrDynFilter>>,
}
675
676impl AggregateExec {
677 pub fn with_new_aggr_exprs(
681 &self,
682 aggr_expr: impl Into<Arc<[Arc<AggregateFunctionExpr>]>>,
683 ) -> Self {
684 Self {
685 aggr_expr: aggr_expr.into(),
686 required_input_ordering: self.required_input_ordering.clone(),
688 metrics: ExecutionPlanMetricsSet::new(),
689 input_order_mode: self.input_order_mode.clone(),
690 cache: Arc::clone(&self.cache),
691 mode: self.mode,
692 group_by: Arc::clone(&self.group_by),
693 filter_expr: Arc::clone(&self.filter_expr),
694 limit_options: self.limit_options,
695 input: Arc::clone(&self.input),
696 schema: Arc::clone(&self.schema),
697 input_schema: Arc::clone(&self.input_schema),
698 dynamic_filter: self.dynamic_filter.clone(),
699 }
700 }
701
702 pub fn with_new_limit_options(&self, limit_options: Option<LimitOptions>) -> Self {
704 Self {
705 limit_options,
706 required_input_ordering: self.required_input_ordering.clone(),
708 metrics: ExecutionPlanMetricsSet::new(),
709 input_order_mode: self.input_order_mode.clone(),
710 cache: Arc::clone(&self.cache),
711 mode: self.mode,
712 group_by: Arc::clone(&self.group_by),
713 aggr_expr: Arc::clone(&self.aggr_expr),
714 filter_expr: Arc::clone(&self.filter_expr),
715 input: Arc::clone(&self.input),
716 schema: Arc::clone(&self.schema),
717 input_schema: Arc::clone(&self.input_schema),
718 dynamic_filter: self.dynamic_filter.clone(),
719 }
720 }
721
722 pub fn cache(&self) -> &PlanProperties {
723 &self.cache
724 }
725
726 pub fn try_new(
728 mode: AggregateMode,
729 group_by: impl Into<Arc<PhysicalGroupBy>>,
730 aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
731 filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
732 input: Arc<dyn ExecutionPlan>,
733 input_schema: SchemaRef,
734 ) -> Result<Self> {
735 let group_by = group_by.into();
736 let schema = create_schema(&input.schema(), &group_by, &aggr_expr, mode)?;
737
738 let schema = Arc::new(schema);
739 AggregateExec::try_new_with_schema(
740 mode,
741 group_by,
742 aggr_expr,
743 filter_expr,
744 input,
745 input_schema,
746 schema,
747 )
748 }
749
750 fn try_new_with_schema(
759 mode: AggregateMode,
760 group_by: impl Into<Arc<PhysicalGroupBy>>,
761 mut aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
762 filter_expr: impl Into<Arc<[Option<Arc<dyn PhysicalExpr>>]>>,
763 input: Arc<dyn ExecutionPlan>,
764 input_schema: SchemaRef,
765 schema: SchemaRef,
766 ) -> Result<Self> {
767 let group_by = group_by.into();
768 let filter_expr = filter_expr.into();
769
770 assert_eq_or_internal_err!(
772 aggr_expr.len(),
773 filter_expr.len(),
774 "Inconsistent aggregate expr: {:?} and filter expr: {:?} for AggregateExec, their size should match",
775 aggr_expr,
776 filter_expr
777 );
778
779 let input_eq_properties = input.equivalence_properties();
780 let groupby_exprs = group_by.input_exprs();
782 let (new_sort_exprs, indices) =
787 input_eq_properties.find_longest_permutation(&groupby_exprs)?;
788
789 let mut new_requirements = new_sort_exprs
790 .into_iter()
791 .map(PhysicalSortRequirement::from)
792 .collect::<Vec<_>>();
793
794 let req = get_finer_aggregate_exprs_requirement(
795 &mut aggr_expr,
796 &group_by,
797 input_eq_properties,
798 &mode,
799 )?;
800 new_requirements.extend(req);
801
802 let required_input_ordering =
803 LexRequirement::new(new_requirements).map(OrderingRequirements::new_soft);
804
805 let indices: Vec<usize> = indices
811 .into_iter()
812 .filter(|idx| group_by.groups.iter().all(|group| !group[*idx]))
813 .collect();
814
815 let input_order_mode = if indices.len() == groupby_exprs.len()
816 && !indices.is_empty()
817 && group_by.groups.len() == 1
818 {
819 InputOrderMode::Sorted
820 } else if !indices.is_empty() {
821 InputOrderMode::PartiallySorted(indices)
822 } else {
823 InputOrderMode::Linear
824 };
825
826 let group_expr_mapping =
828 ProjectionMapping::try_new(group_by.expr.clone(), &input.schema())?;
829
830 let cache = Self::compute_properties(
831 &input,
832 Arc::clone(&schema),
833 &group_expr_mapping,
834 &mode,
835 &input_order_mode,
836 aggr_expr.as_ref(),
837 )?;
838
839 let mut exec = AggregateExec {
840 mode,
841 group_by,
842 aggr_expr: aggr_expr.into(),
843 filter_expr,
844 input,
845 schema,
846 input_schema,
847 metrics: ExecutionPlanMetricsSet::new(),
848 required_input_ordering,
849 limit_options: None,
850 input_order_mode,
851 cache: Arc::new(cache),
852 dynamic_filter: None,
853 };
854
855 exec.init_dynamic_filter();
856
857 Ok(exec)
858 }
859
860 pub fn mode(&self) -> &AggregateMode {
862 &self.mode
863 }
864
865 pub fn with_limit_options(mut self, limit_options: Option<LimitOptions>) -> Self {
867 self.limit_options = limit_options;
868 self
869 }
870
871 pub fn limit_options(&self) -> Option<LimitOptions> {
873 self.limit_options
874 }
875
876 pub fn group_expr(&self) -> &PhysicalGroupBy {
878 &self.group_by
879 }
880
881 pub fn output_group_expr(&self) -> Vec<Arc<dyn PhysicalExpr>> {
883 self.group_by.output_exprs()
884 }
885
886 pub fn aggr_expr(&self) -> &[Arc<AggregateFunctionExpr>] {
888 &self.aggr_expr
889 }
890
891 pub fn filter_expr(&self) -> &[Option<Arc<dyn PhysicalExpr>>] {
893 &self.filter_expr
894 }
895
896 pub fn dynamic_filter_expr(&self) -> Option<&Arc<DynamicFilterPhysicalExpr>> {
898 self.dynamic_filter.as_ref().map(|df| &df.filter)
899 }
900
901 pub fn with_dynamic_filter_expr(
905 mut self,
906 filter: Arc<DynamicFilterPhysicalExpr>,
907 ) -> Result<Self> {
908 let Some(dyn_filter) = self.dynamic_filter.as_ref() else {
911 return internal_err!("Aggregate does not support dynamic filtering");
912 };
913
914 let cols = self.cols_for_dynamic_filter(&dyn_filter.supported_accumulators_info);
916 if cols.len() != filter.children().len() {
917 return internal_err!(
918 "Dynamic filter expression is incompatible with aggregate due to mismatched number of columns"
919 );
920 }
921 for (col, child) in cols.iter().zip(filter.children()) {
922 if !col.eq(child) {
923 return internal_err!(
924 "Dynamic filter expression is incompatible with aggregate due to mismatched column references {col} != {child}"
925 );
926 }
927 }
928
929 self.dynamic_filter = Some(Arc::new(AggrDynFilter {
931 filter,
932 supported_accumulators_info: dyn_filter.supported_accumulators_info.clone(),
933 }));
934 Ok(self)
935 }
936
937 pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
939 &self.input
940 }
941
942 pub fn input_schema(&self) -> SchemaRef {
944 Arc::clone(&self.input_schema)
945 }
946
947 fn execute_typed(
948 &self,
949 partition: usize,
950 context: &Arc<TaskContext>,
951 ) -> Result<StreamType> {
952 if self.group_by.is_true_no_grouping() {
953 return Ok(StreamType::AggregateStream(AggregateStream::new(
954 self, context, partition,
955 )?));
956 }
957
958 if let Some(config) = self.limit_options
960 && !self.is_unordered_unfiltered_group_by_distinct()
961 {
962 return Ok(StreamType::GroupedPriorityQueue(
963 GroupedTopKAggregateStream::new(self, context, partition, config.limit)?,
964 ));
965 }
966
967 Ok(StreamType::GroupedHash(GroupedHashAggregateStream::new(
969 self, context, partition,
970 )?))
971 }
972
973 pub fn get_minmax_desc(&self) -> Option<(FieldRef, bool)> {
975 let agg_expr = self.aggr_expr.iter().exactly_one().ok()?;
976 agg_expr.get_minmax_desc()
977 }
978
979 pub fn is_unordered_unfiltered_group_by_distinct(&self) -> bool {
984 if self
985 .limit_options()
986 .and_then(|config| config.descending)
987 .is_some()
988 {
989 return false;
990 }
991 if self.group_expr().is_empty() && !self.group_expr().has_grouping_set() {
993 return false;
994 }
995 if !self.aggr_expr().is_empty() {
997 return false;
998 }
999 if self.filter_expr().iter().any(|e| e.is_some()) {
1002 return false;
1003 }
1004 if !self.aggr_expr().iter().all(|e| e.order_bys().is_empty()) {
1006 return false;
1007 }
1008 if self.properties().output_ordering().is_some() {
1010 return false;
1011 }
1012 if let Some(requirement) = self.required_input_ordering().swap_remove(0) {
1014 return matches!(requirement, OrderingRequirements::Hard(_));
1015 }
1016 true
1017 }
1018
1019 pub fn compute_properties(
1021 input: &Arc<dyn ExecutionPlan>,
1022 schema: SchemaRef,
1023 group_expr_mapping: &ProjectionMapping,
1024 mode: &AggregateMode,
1025 input_order_mode: &InputOrderMode,
1026 aggr_exprs: &[Arc<AggregateFunctionExpr>],
1027 ) -> Result<PlanProperties> {
1028 let mut eq_properties = input
1030 .equivalence_properties()
1031 .project(group_expr_mapping, schema);
1032
1033 if group_expr_mapping.is_empty() {
1036 let new_constants = aggr_exprs.iter().enumerate().map(|(idx, func)| {
1037 let column = Arc::new(Column::new(func.name(), idx));
1038 ConstExpr::from(column as Arc<dyn PhysicalExpr>)
1039 });
1040 eq_properties.add_constants(new_constants)?;
1041 }
1042
1043 let mut constraints = eq_properties.constraints().to_vec();
1046 let new_constraint = Constraint::Unique(
1047 group_expr_mapping
1048 .iter()
1049 .flat_map(|(_, target_cols)| {
1050 target_cols.iter().flat_map(|(expr, _)| {
1051 expr.downcast_ref::<Column>().map(|c| c.index())
1052 })
1053 })
1054 .collect(),
1055 );
1056 constraints.push(new_constraint);
1057 eq_properties =
1058 eq_properties.with_constraints(Constraints::new_unverified(constraints));
1059
1060 let input_partitioning = input.output_partitioning().clone();
1062 let output_partitioning = match mode.input_mode() {
1063 AggregateInputMode::Raw => {
1064 let input_eq_properties = input.equivalence_properties();
1068 input_partitioning.project(group_expr_mapping, input_eq_properties)
1069 }
1070 AggregateInputMode::Partial => input_partitioning.clone(),
1071 };
1072
1073 let emission_type = if *input_order_mode == InputOrderMode::Linear {
1075 EmissionType::Final
1076 } else {
1077 input.pipeline_behavior()
1078 };
1079
1080 Ok(PlanProperties::new(
1081 eq_properties,
1082 output_partitioning,
1083 emission_type,
1084 input.boundedness(),
1085 ))
1086 }
1087
1088 pub fn input_order_mode(&self) -> &InputOrderMode {
1089 &self.input_order_mode
1090 }
1091
1092 fn statistics_inner(&self, child_statistics: &Statistics) -> Result<Statistics> {
1132 let column_statistics = {
1139 let mut column_statistics = Statistics::unknown_column(&self.schema());
1141
1142 for (idx, (expr, _)) in self.group_by.expr.iter().enumerate() {
1143 if let Some(col) = expr.downcast_ref::<Column>() {
1144 let child_col_stats =
1145 &child_statistics.column_statistics[col.index()];
1146 column_statistics[idx].max_value = child_col_stats.max_value.clone();
1147 column_statistics[idx].min_value = child_col_stats.min_value.clone();
1148 column_statistics[idx].distinct_count =
1149 child_col_stats.distinct_count;
1150 }
1151 }
1152
1153 column_statistics
1154 };
1155 match self.mode {
1156 AggregateMode::Final | AggregateMode::FinalPartitioned
1157 if self.group_by.expr.is_empty() =>
1158 {
1159 let total_byte_size =
1160 Self::calculate_scaled_byte_size(child_statistics, 1);
1161
1162 Ok(Statistics {
1163 num_rows: Precision::Exact(1),
1164 column_statistics,
1165 total_byte_size,
1166 })
1167 }
1168 _ => {
1169 let num_rows = self.estimate_num_rows(child_statistics);
1170
1171 let total_byte_size = num_rows
1172 .get_value()
1173 .and_then(|&output_rows| {
1174 Self::calculate_scaled_byte_size(child_statistics, output_rows)
1175 .get_value()
1176 .map(|&bytes| Precision::Inexact(bytes))
1177 })
1178 .unwrap_or(Precision::Absent);
1179
1180 Ok(Statistics {
1181 num_rows,
1182 column_statistics,
1183 total_byte_size,
1184 })
1185 }
1186 }
1187 }
1188
1189 fn estimate_num_rows(&self, child_statistics: &Statistics) -> Precision<usize> {
1192 let ndv = if !self.group_by.expr.is_empty() {
1193 self.compute_group_ndv(child_statistics)
1194 } else {
1195 None
1196 };
1197 let limit = self.limit_options.as_ref().map(|lo| lo.limit);
1198
1199 if let Some(&value) = child_statistics.num_rows.get_value() {
1200 if value > 1 {
1201 let mut num_rows = child_statistics.num_rows.to_inexact();
1202 if let Some(ndv) = ndv {
1203 num_rows = num_rows.map(|n| n.min(ndv));
1204 }
1205 if let Some(limit) = limit {
1206 num_rows = num_rows.map(|n| n.min(limit));
1207 }
1208 num_rows
1209 } else if value == 0 {
1210 child_statistics.num_rows
1211 } else {
1212 let grouping_set_num = self.group_by.groups.len();
1213 let mut num_rows =
1214 child_statistics.num_rows.map(|x| x * grouping_set_num);
1215 if let Some(limit) = limit {
1216 num_rows = num_rows.map(|n| n.min(limit));
1217 }
1218 num_rows
1219 }
1220 } else {
1221 match (ndv, limit) {
1222 (Some(n), Some(l)) => Precision::Inexact(n.min(l)),
1223 (Some(n), None) => Precision::Inexact(n),
1224 (None, Some(l)) => Precision::Inexact(l),
1225 (None, None) => Precision::Absent,
1226 }
1227 }
1228 }
1229
1230 fn compute_group_ndv(&self, child_statistics: &Statistics) -> Option<usize> {
1248 let mut total: usize = 0;
1249 for group_mask in &self.group_by.groups {
1250 let mut set_product: usize = 1;
1251 for (j, (expr, _)) in self.group_by.expr.iter().enumerate() {
1252 if group_mask[j] {
1253 continue;
1254 }
1255 let col = expr.downcast_ref::<Column>()?;
1256 let col_stats = &child_statistics.column_statistics[col.index()];
1257 let ndv = *col_stats.distinct_count.get_value()?;
1258 let null_adjustment = match col_stats.null_count.get_value() {
1259 Some(&n) if n > 0 => 1usize,
1260 _ => 0,
1261 };
1262 set_product = set_product
1263 .saturating_mul(ndv.saturating_add(null_adjustment).max(1));
1264 }
1265 total = total.saturating_add(set_product);
1266 }
1267 Some(total)
1268 }
1269
1270 fn init_dynamic_filter(&mut self) {
1274 if (!self.group_by.is_empty()) || (self.mode != AggregateMode::Partial) {
1275 debug_assert!(
1276 self.dynamic_filter.is_none(),
1277 "The current operator node does not support dynamic filter"
1278 );
1279 return;
1280 }
1281
1282 if self.dynamic_filter.is_some() {
1284 return;
1285 }
1286
1287 let mut aggr_dyn_filters = Vec::new();
1291 let mut all_cols: Vec<Arc<dyn PhysicalExpr>> = Vec::new();
1295 for (i, aggr_expr) in self.aggr_expr.iter().enumerate() {
1296 let fun_name = aggr_expr.fun().name();
1298 let aggr_type = if fun_name.eq_ignore_ascii_case("min") {
1301 DynamicFilterAggregateType::Min
1302 } else if fun_name.eq_ignore_ascii_case("max") {
1303 DynamicFilterAggregateType::Max
1304 } else {
1305 return;
1306 };
1307
1308 if let [arg] = aggr_expr.expressions().as_slice()
1310 && arg.is::<Column>()
1311 {
1312 all_cols.push(Arc::clone(arg));
1313 aggr_dyn_filters.push(PerAccumulatorDynFilter {
1314 aggr_type,
1315 aggr_index: i,
1316 shared_bound: Arc::new(Mutex::new(ScalarValue::Null)),
1317 });
1318 }
1319 }
1320
1321 if !aggr_dyn_filters.is_empty() {
1322 self.dynamic_filter = Some(Arc::new(AggrDynFilter {
1323 filter: Arc::new(DynamicFilterPhysicalExpr::new(all_cols, lit(true))),
1324 supported_accumulators_info: aggr_dyn_filters,
1325 }))
1326 }
1327 }
1328
1329 fn cols_for_dynamic_filter(
1331 &self,
1332 supported_accumulators_info: &[PerAccumulatorDynFilter],
1333 ) -> Vec<Arc<dyn PhysicalExpr>> {
1334 let all_cols: Vec<Arc<dyn PhysicalExpr>> = supported_accumulators_info
1335 .iter()
1336 .filter_map(|info| {
1337 if let [arg] = &self.aggr_expr[info.aggr_index].expressions().as_slice()
1340 && arg.is::<Column>()
1341 {
1342 return Some(Arc::clone(arg));
1343 }
1344 None
1345 })
1346 .collect();
1347 debug_assert!(all_cols.len() == supported_accumulators_info.len());
1348 all_cols
1349 }
1350
1351 #[inline]
1357 fn calculate_scaled_byte_size(
1358 input_stats: &Statistics,
1359 target_row_count: usize,
1360 ) -> Precision<usize> {
1361 match (
1362 input_stats.num_rows.get_value(),
1363 input_stats.total_byte_size.get_value(),
1364 ) {
1365 (Some(&input_rows), Some(&input_bytes)) if input_rows > 0 => {
1366 let bytes_per_row = input_bytes as f64 / input_rows as f64;
1367 let scaled_bytes =
1368 (bytes_per_row * target_row_count as f64).ceil() as usize;
1369 Precision::Inexact(scaled_bytes)
1370 }
1371 _ => Precision::Absent,
1372 }
1373 }
1374
1375 fn with_new_children_and_same_properties(
1376 &self,
1377 mut children: Vec<Arc<dyn ExecutionPlan>>,
1378 ) -> Self {
1379 Self {
1380 input: children.swap_remove(0),
1381 metrics: ExecutionPlanMetricsSet::new(),
1382 ..Self::clone(self)
1383 }
1384 }
1385}
1386
1387impl DisplayAs for AggregateExec {
1388 fn fmt_as(
1389 &self,
1390 t: DisplayFormatType,
1391 f: &mut std::fmt::Formatter,
1392 ) -> std::fmt::Result {
1393 match t {
1394 DisplayFormatType::Default | DisplayFormatType::Verbose => {
1395 let format_expr_with_alias =
1396 |(e, alias): &(Arc<dyn PhysicalExpr>, String)| -> String {
1397 let e = e.to_string();
1398 if &e != alias {
1399 format!("{e} as {alias}")
1400 } else {
1401 e
1402 }
1403 };
1404
1405 write!(f, "AggregateExec: mode={:?}", self.mode)?;
1406 let g: Vec<String> = if self.group_by.is_single() {
1407 self.group_by
1408 .expr
1409 .iter()
1410 .map(format_expr_with_alias)
1411 .collect()
1412 } else {
1413 self.group_by
1414 .groups
1415 .iter()
1416 .map(|group| {
1417 let terms = group
1418 .iter()
1419 .enumerate()
1420 .map(|(idx, is_null)| {
1421 if *is_null {
1422 format_expr_with_alias(
1423 &self.group_by.null_expr[idx],
1424 )
1425 } else {
1426 format_expr_with_alias(&self.group_by.expr[idx])
1427 }
1428 })
1429 .collect::<Vec<String>>()
1430 .join(", ");
1431 format!("({terms})")
1432 })
1433 .collect()
1434 };
1435
1436 write!(f, ", gby=[{}]", g.join(", "))?;
1437
1438 let a: Vec<String> = self
1439 .aggr_expr
1440 .iter()
1441 .map(|agg| format_aggregate_exec_expr(agg).to_string())
1442 .collect();
1443 write!(f, ", aggr=[{}]", a.join(", "))?;
1444 if let Some(config) = self.limit_options {
1445 write!(f, ", lim=[{}]", config.limit)?;
1446 }
1447
1448 if self.input_order_mode != InputOrderMode::Linear {
1449 write!(f, ", ordering_mode={:?}", self.input_order_mode)?;
1450 }
1451 }
1452 DisplayFormatType::TreeRender => {
1453 let format_expr_with_alias =
1454 |(e, alias): &(Arc<dyn PhysicalExpr>, String)| -> String {
1455 let expr_sql = fmt_sql(e.as_ref()).to_string();
1456 if &expr_sql != alias {
1457 format!("{expr_sql} as {alias}")
1458 } else {
1459 expr_sql
1460 }
1461 };
1462
1463 let g: Vec<String> = if self.group_by.is_single() {
1464 self.group_by
1465 .expr
1466 .iter()
1467 .map(format_expr_with_alias)
1468 .collect()
1469 } else {
1470 self.group_by
1471 .groups
1472 .iter()
1473 .map(|group| {
1474 let terms = group
1475 .iter()
1476 .enumerate()
1477 .map(|(idx, is_null)| {
1478 if *is_null {
1479 format_expr_with_alias(
1480 &self.group_by.null_expr[idx],
1481 )
1482 } else {
1483 format_expr_with_alias(&self.group_by.expr[idx])
1484 }
1485 })
1486 .collect::<Vec<String>>()
1487 .join(", ");
1488 format!("({terms})")
1489 })
1490 .collect()
1491 };
1492 let a: Vec<String> = self
1493 .aggr_expr
1494 .iter()
1495 .map(|agg| format_tree_aggregate_expr(agg).to_string())
1496 .collect();
1497 writeln!(f, "mode={:?}", self.mode)?;
1498 if !g.is_empty() {
1499 writeln!(f, "group_by={}", g.join(", "))?;
1500 }
1501 if !a.is_empty() {
1502 writeln!(f, "aggr={}", a.join(", "))?;
1503 }
1504 if let Some(config) = self.limit_options {
1505 writeln!(f, "limit={}", config.limit)?;
1506 }
1507 }
1508 }
1509 Ok(())
1510 }
1511}
1512
1513fn format_aggregate_exec_expr(agg: &AggregateFunctionExpr) -> Cow<'_, str> {
1514 match agg.human_display_alias() {
1515 Some(_) => format_human_display(agg.human_display(), agg.human_display_alias())
1516 .unwrap_or_else(|| Cow::Borrowed(agg.name())),
1517 None => Cow::Borrowed(agg.name()),
1518 }
1519}
1520
1521fn format_tree_aggregate_expr(agg: &AggregateFunctionExpr) -> Cow<'_, str> {
1522 format_human_display(agg.human_display(), agg.human_display_alias())
1523 .unwrap_or_else(|| Cow::Borrowed(agg.name()))
1524}
1525
/// Combines a human-readable expression and an optional alias into one label.
///
/// Returns `None` without a human display. With an alias it returns the owned
/// string `"<display> as <alias>"`; without one, the display is borrowed as-is.
fn format_human_display<'a>(
    human_display: Option<&'a str>,
    alias: Option<&'a str>,
) -> Option<Cow<'a, str>> {
    let human_display = human_display?;
    let label = if let Some(alias) = alias {
        Cow::Owned(format!("{human_display} as {alias}"))
    } else {
        Cow::Borrowed(human_display)
    };
    Some(label)
}
1535
1536impl ExecutionPlan for AggregateExec {
    /// Operator name shown in plans: `"AggregateExec"`.
    fn name(&self) -> &'static str {
        "AggregateExec"
    }
1540
    /// Returns the plan properties cached when the node was constructed.
    fn properties(&self) -> &Arc<PlanProperties> {
        &self.cache
    }
1545
1546 fn required_input_distribution(&self) -> Vec<Distribution> {
1547 match &self.mode {
1548 AggregateMode::Partial | AggregateMode::PartialReduce => {
1549 vec![Distribution::UnspecifiedDistribution]
1550 }
1551 AggregateMode::FinalPartitioned | AggregateMode::SinglePartitioned => {
1552 vec![Distribution::HashPartitioned(self.group_by.input_exprs())]
1553 }
1554 AggregateMode::Final | AggregateMode::Single => {
1555 vec![Distribution::SinglePartition]
1556 }
1557 }
1558 }
1559
1560 fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
1561 vec![self.required_input_ordering.clone()]
1562 }
1563
1564 fn maintains_input_order(&self) -> Vec<bool> {
1574 vec![self.input_order_mode != InputOrderMode::Linear]
1575 }
1576
1577 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
1578 vec![&self.input]
1579 }
1580
1581 fn with_new_children(
1582 self: Arc<Self>,
1583 children: Vec<Arc<dyn ExecutionPlan>>,
1584 ) -> Result<Arc<dyn ExecutionPlan>> {
1585 check_if_same_properties!(self, children);
1586
1587 let mut me = AggregateExec::try_new_with_schema(
1588 self.mode,
1589 Arc::clone(&self.group_by),
1590 self.aggr_expr.to_vec(),
1591 Arc::clone(&self.filter_expr),
1592 Arc::clone(&children[0]),
1593 Arc::clone(&self.input_schema),
1594 Arc::clone(&self.schema),
1595 )?;
1596 me.limit_options = self.limit_options;
1597 me.dynamic_filter.clone_from(&self.dynamic_filter);
1598
1599 Ok(Arc::new(me))
1600 }
1601
1602 fn execute(
1603 &self,
1604 partition: usize,
1605 context: Arc<TaskContext>,
1606 ) -> Result<SendableRecordBatchStream> {
1607 self.execute_typed(partition, &context)
1608 .map(|stream| stream.into())
1609 }
1610
1611 fn metrics(&self) -> Option<MetricsSet> {
1612 Some(self.metrics.clone_inner())
1613 }
1614
1615 fn partition_statistics(&self, partition: Option<usize>) -> Result<Arc<Statistics>> {
1616 let child_statistics = self.input().partition_statistics(partition)?;
1617 Ok(Arc::new(self.statistics_inner(&child_statistics)?))
1618 }
1619
1620 fn cardinality_effect(&self) -> CardinalityEffect {
1621 CardinalityEffect::LowerEqual
1622 }
1623
1624 fn gather_filters_for_pushdown(
1627 &self,
1628 phase: FilterPushdownPhase,
1629 parent_filters: Vec<Arc<dyn PhysicalExpr>>,
1630 config: &ConfigOptions,
1631 ) -> Result<FilterDescription> {
1632 let output_schema = self.schema();
1647 let grouping_columns: HashSet<_> = (0..self.group_by.expr().len())
1648 .map(|i| Column::new(output_schema.field(i).name(), i))
1649 .collect();
1650
1651 let mut safe_filters = Vec::new();
1653 let mut unsafe_filters = Vec::new();
1654
1655 for filter in parent_filters {
1656 let filter_columns: HashSet<_> =
1657 collect_columns(&filter).into_iter().collect();
1658
1659 let references_non_grouping = !grouping_columns.is_empty()
1661 && !filter_columns.is_subset(&grouping_columns);
1662
1663 if references_non_grouping {
1664 unsafe_filters.push(filter);
1665 continue;
1666 }
1667
1668 if self.group_by.groups().len() > 1 {
1670 let filter_column_indices: Vec<usize> = filter_columns
1671 .iter()
1672 .filter_map(|filter_col| {
1673 grouping_columns.get(filter_col).map(|col| col.index())
1674 })
1675 .collect();
1676
1677 let has_missing_column = self.group_by.groups().iter().any(|null_mask| {
1679 filter_column_indices
1680 .iter()
1681 .any(|&idx| null_mask.get(idx) == Some(&true))
1682 });
1683
1684 if has_missing_column {
1685 unsafe_filters.push(filter);
1686 continue;
1687 }
1688 }
1689
1690 safe_filters.push(filter);
1692 }
1693
1694 let child = self.children()[0];
1696 let mut child_desc = ChildFilterDescription::from_child(&safe_filters, child)?;
1697
1698 child_desc.parent_filters.extend(
1700 unsafe_filters
1701 .into_iter()
1702 .map(PushedDownPredicate::unsupported),
1703 );
1704
1705 if phase == FilterPushdownPhase::Post
1707 && config.optimizer.enable_aggregate_dynamic_filter_pushdown
1708 && let Some(self_dyn_filter) = &self.dynamic_filter
1709 {
1710 let dyn_filter = Arc::clone(&self_dyn_filter.filter);
1711 child_desc = child_desc.with_self_filter(dyn_filter);
1712 }
1713
1714 Ok(FilterDescription::new().with_child(child_desc))
1715 }
1716
1717 fn handle_child_pushdown_result(
1720 &self,
1721 phase: FilterPushdownPhase,
1722 child_pushdown_result: ChildPushdownResult,
1723 _config: &ConfigOptions,
1724 ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
1725 let mut result = FilterPushdownPropagation::if_any(child_pushdown_result.clone());
1726
1727 if phase == FilterPushdownPhase::Post
1730 && let Some(dyn_filter) = &self.dynamic_filter
1731 {
1732 let child_accepts_dyn_filter = Arc::strong_count(dyn_filter) > 1;
1753
1754 if !child_accepts_dyn_filter {
1755 let mut new_node = self.clone();
1758 new_node.dynamic_filter = None;
1759
1760 result = result
1761 .with_updated_node(Arc::new(new_node) as Arc<dyn ExecutionPlan>);
1762 }
1763 }
1764
1765 Ok(result)
1766 }
1767}
1768
1769fn create_schema(
1772 input_schema: &Schema,
1773 group_by: &PhysicalGroupBy,
1774 aggr_expr: &[Arc<AggregateFunctionExpr>],
1775 mode: AggregateMode,
1776) -> Result<Schema> {
1777 let mut fields = Vec::with_capacity(group_by.num_output_exprs() + aggr_expr.len());
1778 fields.extend(group_by.output_fields(input_schema)?);
1779
1780 match mode.output_mode() {
1781 AggregateOutputMode::Final => {
1782 for expr in aggr_expr {
1784 fields.push(expr.field())
1785 }
1786 }
1787 AggregateOutputMode::Partial => {
1788 for expr in aggr_expr {
1790 fields.extend(expr.state_fields()?.iter().cloned());
1791 }
1792 }
1793 }
1794
1795 Ok(Schema::new_with_metadata(
1796 fields,
1797 input_schema.metadata().clone(),
1798 ))
1799}
1800
1801fn get_aggregate_expr_req(
1822 aggr_expr: &AggregateFunctionExpr,
1823 group_by: &PhysicalGroupBy,
1824 agg_mode: &AggregateMode,
1825 include_soft_requirement: bool,
1826) -> Option<LexOrdering> {
1827 if agg_mode.input_mode() == AggregateInputMode::Partial {
1831 return None;
1832 }
1833
1834 match aggr_expr.order_sensitivity() {
1835 AggregateOrderSensitivity::Insensitive => return None,
1836 AggregateOrderSensitivity::HardRequirement => {}
1837 AggregateOrderSensitivity::SoftRequirement => {
1838 if !include_soft_requirement {
1839 return None;
1840 }
1841 }
1842 AggregateOrderSensitivity::Beneficial => return None,
1843 }
1844
1845 let mut sort_exprs = aggr_expr.order_bys().to_vec();
1846 if group_by.is_single() {
1852 let physical_exprs = group_by.input_exprs();
1856 sort_exprs.retain(|sort_expr| {
1857 !physical_exprs_contains(&physical_exprs, &sort_expr.expr)
1858 });
1859 }
1860 LexOrdering::new(sort_exprs)
1861}
1862
/// Returns a new vector holding clones of `lhs` followed by clones of `rhs`.
pub fn concat_slices<T: Clone>(lhs: &[T], rhs: &[T]) -> Vec<T> {
    let mut joined = Vec::with_capacity(lhs.len() + rhs.len());
    joined.extend_from_slice(lhs);
    joined.extend_from_slice(rhs);
    joined
}
1867
1868fn determine_finer(
1872 current: &Option<LexOrdering>,
1873 candidate: &LexOrdering,
1874) -> Option<bool> {
1875 if let Some(ordering) = current {
1876 candidate.partial_cmp(ordering).map(|cmp| cmp.is_gt())
1877 } else {
1878 Some(true)
1879 }
1880}
1881
/// Chooses one input ordering that serves the ordering requirements of the
/// aggregate expressions in `aggr_exprs`.
///
/// Two passes are made. The first considers only hard requirements
/// (`include_soft_requirement == false`); the second also considers soft
/// ones. Each expression's normalized requirement is compared with the
/// accumulated requirement via `determine_finer`:
/// - if the accumulated requirement is not coarser, the expression is skipped;
/// - if the expression's requirement is finer and `eq_properties` already
///   satisfies it, it becomes the accumulated requirement;
/// - otherwise the reversed expression (`reverse_expr`) is tried. When the
///   reverse is preferred, the entry in `aggr_exprs` is replaced in place.
///
/// # Errors
/// Returns a not-implemented error during the hard-requirement pass when an
/// expression that has a reverse form is incomparable with the accumulated
/// requirement in both directions.
///
/// Returns the chosen requirement as `PhysicalSortRequirement`s, or an empty
/// vector when none was chosen.
pub fn get_finer_aggregate_exprs_requirement(
    aggr_exprs: &mut [Arc<AggregateFunctionExpr>],
    group_by: &PhysicalGroupBy,
    eq_properties: &EquivalenceProperties,
    agg_mode: &AggregateMode,
) -> Result<Vec<PhysicalSortRequirement>> {
    let mut requirement = None;

    // Pass 1: hard requirements only. Pass 2: hard and soft requirements.
    for include_soft_requirement in [false, true] {
        for aggr_expr in aggr_exprs.iter_mut() {
            let Some(aggr_req) = get_aggregate_expr_req(
                aggr_expr,
                group_by,
                agg_mode,
                include_soft_requirement,
            )
            .and_then(|o| eq_properties.normalize_sort_exprs(o)) else {
                // No requirement for this expression (or, presumably, nothing
                // left after normalization), so there is nothing to reconcile.
                continue;
            };
            let forward_finer = determine_finer(&requirement, &aggr_req);
            if let Some(finer) = forward_finer {
                if !finer {
                    // The accumulated requirement is not coarser; keep it.
                    continue;
                } else if eq_properties.ordering_satisfy(aggr_req.clone())? {
                    // Finer and already satisfied by the input: adopt it.
                    requirement = Some(aggr_req);
                    continue;
                }
            }
            // The forward requirement is incomparable, or finer but not
            // satisfied, so consider the reversed aggregate instead.
            if let Some(reverse_aggr_expr) = aggr_expr.reverse_expr() {
                let Some(rev_aggr_req) = get_aggregate_expr_req(
                    &reverse_aggr_expr,
                    group_by,
                    agg_mode,
                    include_soft_requirement,
                )
                .and_then(|o| eq_properties.normalize_sort_exprs(o)) else {
                    // The reverse has no requirement: switch to it.
                    *aggr_expr = Arc::new(reverse_aggr_expr);
                    continue;
                };
                if let Some(finer) = determine_finer(&requirement, &rev_aggr_req) {
                    if !finer {
                        // The accumulated requirement already covers the reverse.
                        *aggr_expr = Arc::new(reverse_aggr_expr);
                    } else if eq_properties.ordering_satisfy(rev_aggr_req.clone())? {
                        // The reverse is finer and satisfied: switch and adopt it.
                        *aggr_expr = Arc::new(reverse_aggr_expr);
                        requirement = Some(rev_aggr_req);
                    } else {
                        // Neither direction is satisfied: keep the forward one.
                        requirement = Some(aggr_req);
                    }
                } else if forward_finer.is_some() {
                    // The reverse is incomparable but the forward one is finer.
                    requirement = Some(aggr_req);
                } else {
                    // Both directions are incomparable with the accumulated
                    // requirement. That is a conflict, and it is fatal only
                    // for hard requirements.
                    if !include_soft_requirement {
                        return not_impl_err!(
                            "Conflicting ordering requirements in aggregate functions is not supported"
                        );
                    }
                }
            }
            // NOTE(review): when there is no reverse expression, a finer but
            // unsatisfied (or incomparable) forward requirement leaves
            // `requirement` unchanged. Confirm this is intended.
        }
    }

    Ok(requirement.map_or_else(Vec::new, |o| o.into_iter().map(Into::into).collect()))
}
1985
1986pub fn aggregate_expressions(
1992 aggr_expr: &[Arc<AggregateFunctionExpr>],
1993 mode: &AggregateMode,
1994 col_idx_base: usize,
1995) -> Result<Vec<Vec<Arc<dyn PhysicalExpr>>>> {
1996 match mode.input_mode() {
1997 AggregateInputMode::Raw => Ok(aggr_expr
1998 .iter()
1999 .map(|agg| {
2000 let mut result = agg.expressions();
2001 result.extend(agg.order_bys().iter().map(|item| Arc::clone(&item.expr)));
2005 result
2006 })
2007 .collect()),
2008 AggregateInputMode::Partial => {
2009 let mut col_idx_base = col_idx_base;
2011 aggr_expr
2012 .iter()
2013 .map(|agg| {
2014 let exprs = merge_expressions(col_idx_base, agg)?;
2015 col_idx_base += exprs.len();
2016 Ok(exprs)
2017 })
2018 .collect()
2019 }
2020 }
2021}
2022
2023fn merge_expressions(
2028 index_base: usize,
2029 expr: &AggregateFunctionExpr,
2030) -> Result<Vec<Arc<dyn PhysicalExpr>>> {
2031 expr.state_fields().map(|fields| {
2032 fields
2033 .iter()
2034 .enumerate()
2035 .map(|(idx, f)| Arc::new(Column::new(f.name(), index_base + idx)) as _)
2036 .collect()
2037 })
2038}
2039
/// A boxed, dynamically dispatched accumulator, as produced by
/// `AggregateFunctionExpr::create_accumulator` (see `create_accumulators`).
pub type AccumulatorItem = Box<dyn Accumulator>;
2041
2042pub fn create_accumulators(
2043 aggr_expr: &[Arc<AggregateFunctionExpr>],
2044) -> Result<Vec<AccumulatorItem>> {
2045 aggr_expr
2046 .iter()
2047 .map(|expr| expr.create_accumulator())
2048 .collect()
2049}
2050
2051pub fn finalize_aggregation(
2054 accumulators: &mut [AccumulatorItem],
2055 mode: &AggregateMode,
2056) -> Result<Vec<ArrayRef>> {
2057 match mode.output_mode() {
2058 AggregateOutputMode::Final => {
2059 accumulators
2061 .iter_mut()
2062 .map(|accumulator| accumulator.evaluate().and_then(|v| v.to_array()))
2063 .collect()
2064 }
2065 AggregateOutputMode::Partial => {
2066 accumulators
2068 .iter_mut()
2069 .map(|accumulator| {
2070 accumulator.state().and_then(|e| {
2071 e.iter()
2072 .map(|v| v.to_array())
2073 .collect::<Result<Vec<ArrayRef>>>()
2074 })
2075 })
2076 .flatten_ok()
2077 .collect()
2078 }
2079 }
2080}
2081
2082pub fn evaluate_many(
2084 expr: &[Vec<Arc<dyn PhysicalExpr>>],
2085 batch: &RecordBatch,
2086) -> Result<Vec<Vec<ArrayRef>>> {
2087 expr.iter()
2088 .map(|expr| evaluate_expressions_to_arrays(expr, batch))
2089 .collect()
2090}
2091
2092fn evaluate_optional(
2093 expr: &[Option<Arc<dyn PhysicalExpr>>],
2094 batch: &RecordBatch,
2095) -> Result<Vec<Option<ArrayRef>>> {
2096 expr.iter()
2097 .map(|expr| {
2098 expr.as_ref()
2099 .map(|expr| {
2100 expr.evaluate(batch)
2101 .and_then(|v| v.into_array(batch.num_rows()))
2102 })
2103 .transpose()
2104 })
2105 .collect()
2106}
2107
2108pub(crate) fn group_id_array(
2124 group: &[bool],
2125 ordinal: usize,
2126 max_ordinal: usize,
2127 num_rows: usize,
2128) -> Result<ArrayRef> {
2129 let n = group.len();
2130 if n > 64 {
2131 return not_impl_err!(
2132 "Grouping sets with more than 64 columns are not supported"
2133 );
2134 }
2135 let ordinal_bits = usize::BITS as usize - max_ordinal.leading_zeros() as usize;
2136 let total_bits = n + ordinal_bits;
2137 if total_bits > 64 {
2138 return not_impl_err!(
2139 "Grouping sets with {n} columns and a maximum duplicate ordinal of \
2140 {max_ordinal} require {total_bits} bits, which exceeds 64"
2141 );
2142 }
2143 let semantic_id = group.iter().fold(0u64, |acc, &is_null| {
2144 (acc << 1) | if is_null { 1 } else { 0 }
2145 });
2146 let full_id = semantic_id | ((ordinal as u64) << n);
2147 if total_bits <= 8 {
2148 Ok(Arc::new(UInt8Array::from(vec![full_id as u8; num_rows])))
2149 } else if total_bits <= 16 {
2150 Ok(Arc::new(UInt16Array::from(vec![full_id as u16; num_rows])))
2151 } else if total_bits <= 32 {
2152 Ok(Arc::new(UInt32Array::from(vec![full_id as u32; num_rows])))
2153 } else {
2154 Ok(Arc::new(UInt64Array::from(vec![full_id; num_rows])))
2155 }
2156}
2157
/// Returns the highest zero-based duplicate ordinal among `groups`.
///
/// This is the occurrence count of the most frequent null mask minus one.
/// It is 0 when every mask is distinct or when `groups` is empty.
pub(crate) fn max_duplicate_ordinal(groups: &[Vec<bool>]) -> usize {
    let mut occurrences: HashMap<&[bool], usize> = HashMap::new();
    let most_frequent = groups.iter().fold(0, |best, group| {
        let count = occurrences.entry(group.as_slice()).or_insert(0);
        *count += 1;
        best.max(*count)
    });
    most_frequent.saturating_sub(1)
}
2172
2173pub fn evaluate_group_by(
2184 group_by: &PhysicalGroupBy,
2185 batch: &RecordBatch,
2186) -> Result<Vec<Vec<ArrayRef>>> {
2187 let max_ordinal = max_duplicate_ordinal(&group_by.groups);
2188 let mut ordinal_per_pattern: HashMap<&[bool], usize> = HashMap::new();
2189 let exprs = evaluate_expressions_to_arrays(
2190 group_by.expr.iter().map(|(expr, _)| expr),
2191 batch,
2192 )?;
2193 let null_exprs = evaluate_expressions_to_arrays(
2194 group_by.null_expr.iter().map(|(expr, _)| expr),
2195 batch,
2196 )?;
2197
2198 group_by
2199 .groups
2200 .iter()
2201 .map(|group| {
2202 let ordinal = ordinal_per_pattern.entry(group).or_insert(0);
2203 let current_ordinal = *ordinal;
2204 *ordinal += 1;
2205
2206 let mut group_values = Vec::with_capacity(group_by.num_group_exprs());
2207 group_values.extend(group.iter().enumerate().map(|(idx, is_null)| {
2208 if *is_null {
2209 Arc::clone(&null_exprs[idx])
2210 } else {
2211 Arc::clone(&exprs[idx])
2212 }
2213 }));
2214 if !group_by.is_single() {
2215 group_values.push(group_id_array(
2216 group,
2217 current_ordinal,
2218 max_ordinal,
2219 batch.num_rows(),
2220 )?);
2221 }
2222 Ok(group_values)
2223 })
2224 .collect()
2225}
2226
2227#[cfg(test)]
2228mod tests {
2229 use std::task::{Context, Poll};
2230
2231 use super::*;
2232 use crate::RecordBatchStream;
2233 use crate::coalesce_partitions::CoalescePartitionsExec;
2234 use crate::common;
2235 use crate::common::collect;
2236 use crate::empty::EmptyExec;
2237 use crate::execution_plan::Boundedness;
2238 use crate::expressions::col;
2239 use crate::metrics::MetricValue;
2240 use crate::test::TestMemoryExec;
2241 use crate::test::assert_is_pending;
2242 use crate::test::exec::{
2243 BlockingExec, StatisticsExec, assert_strong_count_converges_to_zero,
2244 };
2245
2246 use arrow::array::{
2247 DictionaryArray, Float32Array, Float64Array, Int32Array, Int64Array, StructArray,
2248 UInt32Array, UInt64Array,
2249 };
2250 use arrow::compute::{SortOptions, concat_batches};
2251 use arrow::datatypes::Int32Type;
2252 use datafusion_common::test_util::{batches_to_sort_string, batches_to_string};
2253 use datafusion_common::{DataFusionError, internal_err};
2254 use datafusion_execution::config::SessionConfig;
2255 use datafusion_execution::memory_pool::FairSpillPool;
2256 use datafusion_execution::runtime_env::RuntimeEnvBuilder;
2257 use datafusion_functions_aggregate::array_agg::array_agg_udaf;
2258 use datafusion_functions_aggregate::average::avg_udaf;
2259 use datafusion_functions_aggregate::count::count_udaf;
2260 use datafusion_functions_aggregate::first_last::{first_value_udaf, last_value_udaf};
2261 use datafusion_functions_aggregate::median::median_udaf;
2262 use datafusion_functions_aggregate::min_max::min_udaf;
2263 use datafusion_functions_aggregate::sum::sum_udaf;
2264 use datafusion_physical_expr::Partitioning;
2265 use datafusion_physical_expr::PhysicalSortExpr;
2266 use datafusion_physical_expr::aggregate::AggregateExprBuilder;
2267 use datafusion_physical_expr::expressions::Literal;
2268
2269 use crate::projection::ProjectionExec;
2270 use datafusion_physical_expr::projection::ProjectionExpr;
2271 use futures::{FutureExt, Stream};
2272 use insta::{allow_duplicates, assert_snapshot};
2273
2274 fn create_test_schema() -> Result<SchemaRef> {
2276 let a = Field::new("a", DataType::Int32, true);
2277 let b = Field::new("b", DataType::Int32, true);
2278 let c = Field::new("c", DataType::Int32, true);
2279 let d = Field::new("d", DataType::Int32, true);
2280 let e = Field::new("e", DataType::Int32, true);
2281 let schema = Arc::new(Schema::new(vec![a, b, c, d, e]));
2282
2283 Ok(schema)
2284 }
2285
2286 fn some_data() -> (Arc<Schema>, Vec<RecordBatch>) {
2288 let schema = Arc::new(Schema::new(vec![
2290 Field::new("a", DataType::UInt32, false),
2291 Field::new("b", DataType::Float64, false),
2292 ]));
2293
2294 (
2296 Arc::clone(&schema),
2297 vec![
2298 RecordBatch::try_new(
2299 Arc::clone(&schema),
2300 vec![
2301 Arc::new(UInt32Array::from(vec![2, 3, 4, 4])),
2302 Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
2303 ],
2304 )
2305 .unwrap(),
2306 RecordBatch::try_new(
2307 schema,
2308 vec![
2309 Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
2310 Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
2311 ],
2312 )
2313 .unwrap(),
2314 ],
2315 )
2316 }
2317
2318 fn some_data_v2() -> (Arc<Schema>, Vec<RecordBatch>) {
2320 let schema = Arc::new(Schema::new(vec![
2322 Field::new("a", DataType::UInt32, false),
2323 Field::new("b", DataType::Float64, false),
2324 ]));
2325
2326 (
2331 Arc::clone(&schema),
2332 vec![
2333 RecordBatch::try_new(
2334 Arc::clone(&schema),
2335 vec![
2336 Arc::new(UInt32Array::from(vec![2, 3, 4, 4])),
2337 Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
2338 ],
2339 )
2340 .unwrap(),
2341 RecordBatch::try_new(
2342 Arc::clone(&schema),
2343 vec![
2344 Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
2345 Arc::new(Float64Array::from(vec![0.0, 1.0, 2.0, 3.0])),
2346 ],
2347 )
2348 .unwrap(),
2349 RecordBatch::try_new(
2350 Arc::clone(&schema),
2351 vec![
2352 Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
2353 Arc::new(Float64Array::from(vec![3.0, 4.0, 5.0, 6.0])),
2354 ],
2355 )
2356 .unwrap(),
2357 RecordBatch::try_new(
2358 schema,
2359 vec![
2360 Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
2361 Arc::new(Float64Array::from(vec![2.0, 3.0, 4.0, 5.0])),
2362 ],
2363 )
2364 .unwrap(),
2365 ],
2366 )
2367 }
2368
2369 fn new_spill_ctx(batch_size: usize, max_memory: usize) -> Arc<TaskContext> {
2370 let session_config = SessionConfig::new().with_batch_size(batch_size);
2371 let runtime = RuntimeEnvBuilder::new()
2372 .with_memory_pool(Arc::new(FairSpillPool::new(max_memory)))
2373 .build_arc()
2374 .unwrap();
2375 let task_ctx = TaskContext::default()
2376 .with_session_config(session_config)
2377 .with_runtime(runtime);
2378 Arc::new(task_ctx)
2379 }
2380
    /// Runs a two-stage (Partial, then Final) `COUNT(1)` aggregation over the
    /// grouping sets `(a)`, `(b)` and `(a, b)` on `input`, and snapshots both
    /// stages.
    ///
    /// When `spill` is true, small batch sizes and memory pools are used. The
    /// partial stage then emits unmerged partial rows (see the spill
    /// snapshot), which the final stage must merge into the same result as
    /// the non-spilling run.
    async fn check_grouping_sets(
        input: Arc<dyn ExecutionPlan>,
        spill: bool,
    ) -> Result<()> {
        let input_schema = input.schema();

        // In each grouping-set mask, `true` means the column is replaced by
        // its NULL expression for that set.
        let grouping_set = PhysicalGroupBy::new(
            vec![
                (col("a", &input_schema)?, "a".to_string()),
                (col("b", &input_schema)?, "b".to_string()),
            ],
            vec![
                (lit(ScalarValue::UInt32(None)), "a".to_string()),
                (lit(ScalarValue::Float64(None)), "b".to_string()),
            ],
            vec![
                vec![false, true],  // (a, NULL)
                vec![true, false],  // (NULL, b)
                vec![false, false], // (a, b)
            ],
            true,
        );

        let aggregates = vec![Arc::new(
            AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)])
                .schema(Arc::clone(&input_schema))
                .alias("COUNT(1)")
                .build()?,
        )];

        let task_ctx = if spill {
            // Batch size 4 and a 500-byte pool for the partial stage.
            new_spill_ctx(4, 500)
        } else {
            Arc::new(TaskContext::default())
        };

        let partial_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            grouping_set.clone(),
            aggregates.clone(),
            vec![None],
            input,
            Arc::clone(&input_schema),
        )?);

        let result =
            collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;

        if spill {
            // Partial states are not fully merged under memory pressure, so
            // the same group can appear several times.
            allow_duplicates! {
                assert_snapshot!(batches_to_sort_string(&result),
                @r"
                +---+-----+---------------+-----------------+
                | a | b   | __grouping_id | COUNT(1)[count] |
                +---+-----+---------------+-----------------+
                |   | 1.0 | 2             | 1               |
                |   | 1.0 | 2             | 1               |
                |   | 2.0 | 2             | 1               |
                |   | 2.0 | 2             | 1               |
                |   | 3.0 | 2             | 1               |
                |   | 3.0 | 2             | 1               |
                |   | 4.0 | 2             | 1               |
                |   | 4.0 | 2             | 1               |
                | 2 |     | 1             | 1               |
                | 2 |     | 1             | 1               |
                | 2 | 1.0 | 0             | 1               |
                | 2 | 1.0 | 0             | 1               |
                | 3 |     | 1             | 1               |
                | 3 |     | 1             | 2               |
                | 3 | 2.0 | 0             | 2               |
                | 3 | 3.0 | 0             | 1               |
                | 4 |     | 1             | 1               |
                | 4 |     | 1             | 2               |
                | 4 | 3.0 | 0             | 1               |
                | 4 | 4.0 | 0             | 2               |
                +---+-----+---------------+-----------------+
                "
                );
            }
        } else {
            allow_duplicates! {
                assert_snapshot!(batches_to_sort_string(&result),
                @r"
                +---+-----+---------------+-----------------+
                | a | b   | __grouping_id | COUNT(1)[count] |
                +---+-----+---------------+-----------------+
                |   | 1.0 | 2             | 2               |
                |   | 2.0 | 2             | 2               |
                |   | 3.0 | 2             | 2               |
                |   | 4.0 | 2             | 2               |
                | 2 |     | 1             | 2               |
                | 2 | 1.0 | 0             | 2               |
                | 3 |     | 1             | 3               |
                | 3 | 2.0 | 0             | 2               |
                | 3 | 3.0 | 0             | 1               |
                | 4 |     | 1             | 3               |
                | 4 | 3.0 | 0             | 1               |
                | 4 | 4.0 | 0             | 2               |
                +---+-----+---------------+-----------------+
                "
                );
            }
        };

        let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));

        let final_grouping_set = grouping_set.as_final();

        let task_ctx = if spill {
            // Larger pool for the final (merging) stage.
            new_spill_ctx(4, 3160)
        } else {
            task_ctx
        };

        let merged_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Final,
            final_grouping_set,
            aggregates,
            vec![None],
            merge,
            input_schema,
        )?);

        let result = collect(merged_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
        let batch = concat_batches(&result[0].schema(), &result)?;
        assert_eq!(batch.num_columns(), 4);
        assert_eq!(batch.num_rows(), 12);

        // The final result is identical with and without spilling.
        allow_duplicates! {
            assert_snapshot!(
                batches_to_sort_string(&result),
                @r"
            +---+-----+---------------+----------+
            | a | b   | __grouping_id | COUNT(1) |
            +---+-----+---------------+----------+
            |   | 1.0 | 2             | 2        |
            |   | 2.0 | 2             | 2        |
            |   | 3.0 | 2             | 2        |
            |   | 4.0 | 2             | 2        |
            | 2 |     | 1             | 2        |
            | 2 | 1.0 | 0             | 2        |
            | 3 |     | 1             | 3        |
            | 3 | 2.0 | 0             | 2        |
            | 3 | 3.0 | 0             | 1        |
            | 4 |     | 1             | 3        |
            | 4 | 3.0 | 0             | 1        |
            | 4 | 4.0 | 0             | 2        |
            +---+-----+---------------+----------+
            "
            );
        }

        let metrics = merged_aggregate.metrics().unwrap();
        let output_rows = metrics.output_rows().unwrap();
        assert_eq!(12, output_rows);

        Ok(())
    }
2542
    /// Runs a two-stage (Partial, then Final) `AVG(b) GROUP BY a` aggregation
    /// over `input`. It snapshots both stages and checks the statistics and
    /// the output and spill metrics of the final stage.
    ///
    /// When `spill` is true, small batch sizes and memory pools are used, and
    /// the spill metrics are asserted to be non-zero.
    async fn check_aggregates(input: Arc<dyn ExecutionPlan>, spill: bool) -> Result<()> {
        let input_schema = input.schema();

        let grouping_set = PhysicalGroupBy::new(
            vec![(col("a", &input_schema)?, "a".to_string())],
            vec![],
            vec![vec![false]],
            false,
        );

        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
            AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?])
                .schema(Arc::clone(&input_schema))
                .alias("AVG(b)")
                .build()?,
        )];

        let task_ctx = if spill {
            // Batch size 2 and a 1600-byte pool for the partial stage.
            new_spill_ctx(2, 1600)
        } else {
            Arc::new(TaskContext::default())
        };

        let partial_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            grouping_set.clone(),
            aggregates.clone(),
            vec![None],
            input,
            Arc::clone(&input_schema),
        )?);

        let result =
            collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;

        if spill {
            // Under memory pressure the partial stage emits states that are
            // not fully merged (e.g. two rows for a = 2).
            allow_duplicates! {
                assert_snapshot!(batches_to_sort_string(&result), @r"
                +---+---------------+-------------+
                | a | AVG(b)[count] | AVG(b)[sum] |
                +---+---------------+-------------+
                | 2 | 1             | 1.0         |
                | 2 | 1             | 1.0         |
                | 3 | 1             | 2.0         |
                | 3 | 2             | 5.0         |
                | 4 | 3             | 11.0        |
                +---+---------------+-------------+
                ");
            }
        } else {
            allow_duplicates! {
                assert_snapshot!(batches_to_sort_string(&result), @r"
                +---+---------------+-------------+
                | a | AVG(b)[count] | AVG(b)[sum] |
                +---+---------------+-------------+
                | 2 | 2             | 2.0         |
                | 3 | 3             | 7.0         |
                | 4 | 3             | 11.0        |
                +---+---------------+-------------+
                ");
            }
        };

        let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));

        let final_grouping_set = grouping_set.as_final();

        let merged_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Final,
            final_grouping_set,
            aggregates,
            vec![None],
            merge,
            input_schema,
        )?);

        // The final aggregate must report some value for total_byte_size.
        let final_stats = merged_aggregate.partition_statistics(None)?;
        assert!(final_stats.total_byte_size.get_value().is_some());

        let task_ctx = if spill {
            // Batch size 2 and a 2600-byte pool for the final stage.
            new_spill_ctx(2, 2600)
        } else {
            Arc::clone(&task_ctx)
        };
        let result = collect(merged_aggregate.execute(0, task_ctx)?).await?;
        let batch = concat_batches(&result[0].schema(), &result)?;
        assert_eq!(batch.num_columns(), 2);
        assert_eq!(batch.num_rows(), 3);

        // The final result is identical with and without spilling.
        allow_duplicates! {
            assert_snapshot!(batches_to_sort_string(&result), @r"
            +---+--------------------+
            | a | AVG(b)             |
            +---+--------------------+
            | 2 | 1.0                |
            | 3 | 2.3333333333333335 |
            | 4 | 3.6666666666666665 |
            +---+--------------------+
            ");
        }

        let metrics = merged_aggregate.metrics().unwrap();
        let output_rows = metrics.output_rows().unwrap();
        let spill_count = metrics.spill_count().unwrap();
        let spilled_bytes = metrics.spilled_bytes().unwrap();
        let spilled_rows = metrics.spilled_rows().unwrap();

        if spill {
            // With spilling, the output_rows metric differs from the 3 result
            // rows; the value observed for this input is pinned here.
            assert_eq!(8, output_rows);

            assert!(spill_count > 0);
            assert!(spilled_bytes > 0);
            assert!(spilled_rows > 0);
        } else {
            assert_eq!(3, output_rows);

            assert_eq!(0, spill_count);
            assert_eq!(0, spilled_bytes);
            assert_eq!(0, spilled_rows);
        }

        Ok(())
    }
2674
    /// Mock leaf plan whose stream produces the two batches from `some_data()`.
    #[derive(Debug)]
    struct TestYieldingExec {
        /// If true, the stream returns `Poll::Pending` once (after waking
        /// itself) before producing its first batch.
        pub yield_first: bool,
        /// Cached plan properties returned by `properties()`.
        cache: Arc<PlanProperties>,
    }
2683
2684 impl TestYieldingExec {
2685 fn new(yield_first: bool) -> Self {
2686 let schema = some_data().0;
2687 let cache = Self::compute_properties(schema);
2688 Self {
2689 yield_first,
2690 cache: Arc::new(cache),
2691 }
2692 }
2693
2694 fn compute_properties(schema: SchemaRef) -> PlanProperties {
2696 PlanProperties::new(
2697 EquivalenceProperties::new(schema),
2698 Partitioning::UnknownPartitioning(1),
2699 EmissionType::Incremental,
2700 Boundedness::Bounded,
2701 )
2702 }
2703 }
2704
2705 impl DisplayAs for TestYieldingExec {
2706 fn fmt_as(
2707 &self,
2708 t: DisplayFormatType,
2709 f: &mut std::fmt::Formatter,
2710 ) -> std::fmt::Result {
2711 match t {
2712 DisplayFormatType::Default | DisplayFormatType::Verbose => {
2713 write!(f, "TestYieldingExec")
2714 }
2715 DisplayFormatType::TreeRender => {
2716 write!(f, "")
2718 }
2719 }
2720 }
2721 }
2722
2723 impl ExecutionPlan for TestYieldingExec {
2724 fn name(&self) -> &'static str {
2725 "TestYieldingExec"
2726 }
2727
2728 fn properties(&self) -> &Arc<PlanProperties> {
2729 &self.cache
2730 }
2731
2732 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
2733 vec![]
2734 }
2735
2736 fn with_new_children(
2737 self: Arc<Self>,
2738 _: Vec<Arc<dyn ExecutionPlan>>,
2739 ) -> Result<Arc<dyn ExecutionPlan>> {
2740 internal_err!("Children cannot be replaced in {self:?}")
2741 }
2742
2743 fn execute(
2744 &self,
2745 _partition: usize,
2746 _context: Arc<TaskContext>,
2747 ) -> Result<SendableRecordBatchStream> {
2748 let stream = if self.yield_first {
2749 TestYieldingStream::New
2750 } else {
2751 TestYieldingStream::Yielded
2752 };
2753
2754 Ok(Box::pin(stream))
2755 }
2756
2757 fn partition_statistics(
2758 &self,
2759 partition: Option<usize>,
2760 ) -> Result<Arc<Statistics>> {
2761 if partition.is_some() {
2762 return Ok(Arc::new(Statistics::new_unknown(self.schema().as_ref())));
2763 }
2764 let (_, batches) = some_data();
2765 Ok(Arc::new(common::compute_record_batch_statistics(
2766 &[batches],
2767 &self.schema(),
2768 None,
2769 )))
2770 }
2771 }
2772
    /// States of the stream returned by `TestYieldingExec::execute`.
    enum TestYieldingStream {
        /// Not yet polled; the next poll wakes itself and returns `Pending`.
        New,
        /// The next poll returns the first batch of `some_data()`.
        Yielded,
        /// The first batch was returned; the next poll returns the second.
        ReturnedBatch1,
        /// Both batches were returned; the stream is exhausted.
        ReturnedBatch2,
    }
2780
2781 impl Stream for TestYieldingStream {
2782 type Item = Result<RecordBatch>;
2783
2784 fn poll_next(
2785 mut self: std::pin::Pin<&mut Self>,
2786 cx: &mut Context<'_>,
2787 ) -> Poll<Option<Self::Item>> {
2788 match &*self {
2789 TestYieldingStream::New => {
2790 *(self.as_mut()) = TestYieldingStream::Yielded;
2791 cx.waker().wake_by_ref();
2792 Poll::Pending
2793 }
2794 TestYieldingStream::Yielded => {
2795 *(self.as_mut()) = TestYieldingStream::ReturnedBatch1;
2796 Poll::Ready(Some(Ok(some_data().1[0].clone())))
2797 }
2798 TestYieldingStream::ReturnedBatch1 => {
2799 *(self.as_mut()) = TestYieldingStream::ReturnedBatch2;
2800 Poll::Ready(Some(Ok(some_data().1[1].clone())))
2801 }
2802 TestYieldingStream::ReturnedBatch2 => Poll::Ready(None),
2803 }
2804 }
2805 }
2806
2807 impl RecordBatchStream for TestYieldingStream {
2808 fn schema(&self) -> SchemaRef {
2809 some_data().0
2810 }
2811 }
2812
2813 #[tokio::test]
2816 async fn aggregate_source_not_yielding() -> Result<()> {
2817 let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
2818
2819 check_aggregates(input, false).await
2820 }
2821
2822 #[tokio::test]
2823 async fn aggregate_grouping_sets_source_not_yielding() -> Result<()> {
2824 let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
2825
2826 check_grouping_sets(input, false).await
2827 }
2828
2829 #[tokio::test]
2830 async fn aggregate_source_with_yielding() -> Result<()> {
2831 let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
2832
2833 check_aggregates(input, false).await
2834 }
2835
2836 #[tokio::test]
2837 async fn aggregate_grouping_sets_with_yielding() -> Result<()> {
2838 let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
2839
2840 check_grouping_sets(input, false).await
2841 }
2842
2843 #[tokio::test]
2844 async fn aggregate_source_not_yielding_with_spill() -> Result<()> {
2845 let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
2846
2847 check_aggregates(input, true).await
2848 }
2849
2850 #[tokio::test]
2851 async fn aggregate_grouping_sets_source_not_yielding_with_spill() -> Result<()> {
2852 let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
2853
2854 check_grouping_sets(input, true).await
2855 }
2856
2857 #[tokio::test]
2858 async fn aggregate_source_with_yielding_with_spill() -> Result<()> {
2859 let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
2860
2861 check_aggregates(input, true).await
2862 }
2863
2864 #[tokio::test]
2865 async fn aggregate_grouping_sets_with_yielding_with_spill() -> Result<()> {
2866 let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
2867
2868 check_grouping_sets(input, true).await
2869 }
2870
2871 fn test_median_agg_expr(schema: SchemaRef) -> Result<AggregateFunctionExpr> {
2873 AggregateExprBuilder::new(median_udaf(), vec![col("a", &schema)?])
2874 .schema(schema)
2875 .alias("MEDIAN(a)")
2876 .build()
2877 }
2878
    /// Checks that aggregation under a tiny memory limit fails with
    /// `ResourcesExhausted`. Two cases are run:
    /// - an ungrouped `MEDIAN(a)`, expected to run via `AggregateStream`;
    /// - `AVG(b) GROUP BY a`, expected to run via `GroupedHash`.
    #[tokio::test]
    async fn test_oom() -> Result<()> {
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
        let input_schema = input.schema();

        let runtime = RuntimeEnvBuilder::new()
            .with_memory_limit(1, 1.0)
            .build_arc()?;
        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        let groups_none = PhysicalGroupBy::default();
        let groups_some = PhysicalGroupBy::new(
            vec![(col("a", &input_schema)?, "a".to_string())],
            vec![],
            vec![vec![false]],
            false,
        );

        // Ungrouped case.
        let aggregates_v0: Vec<Arc<AggregateFunctionExpr>> =
            vec![Arc::new(test_median_agg_expr(Arc::clone(&input_schema))?)];

        // Grouped case.
        let aggregates_v2: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
            AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?])
                .schema(Arc::clone(&input_schema))
                .alias("AVG(b)")
                .build()?,
        )];

        for (version, groups, aggregates) in [
            (0, groups_none, aggregates_v0),
            (2, groups_some, aggregates_v2),
        ] {
            let n_aggr = aggregates.len();
            let partial_aggregate = Arc::new(AggregateExec::try_new(
                AggregateMode::Single,
                groups,
                aggregates,
                vec![None; n_aggr],
                Arc::clone(&input),
                Arc::clone(&input_schema),
            )?);

            let stream = partial_aggregate.execute_typed(0, &task_ctx)?;

            // Each case must select the expected stream implementation.
            // NOTE(review): arm `1` is unreachable with the cases listed above.
            match version {
                0 => {
                    assert!(matches!(stream, StreamType::AggregateStream(_)));
                }
                1 => {
                    assert!(matches!(stream, StreamType::GroupedHash(_)));
                }
                2 => {
                    assert!(matches!(stream, StreamType::GroupedHash(_)));
                }
                _ => panic!("Unknown version: {version}"),
            }

            let stream: SendableRecordBatchStream = stream.into();
            let err = collect(stream).await.unwrap_err();

            // Inspect the root cause; the error presumably may arrive wrapped.
            let err = err.find_root();
            assert!(
                matches!(err, DataFusionError::ResourcesExhausted(_)),
                "Wrong error type: {err}",
            );
        }

        Ok(())
    }
2953
2954 #[tokio::test]
2955 async fn test_drop_cancel_without_groups() -> Result<()> {
2956 let task_ctx = Arc::new(TaskContext::default());
2957 let schema =
2958 Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, true)]));
2959
2960 let groups = PhysicalGroupBy::default();
2961
2962 let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
2963 AggregateExprBuilder::new(avg_udaf(), vec![col("a", &schema)?])
2964 .schema(Arc::clone(&schema))
2965 .alias("AVG(a)")
2966 .build()?,
2967 )];
2968
2969 let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
2970 let refs = blocking_exec.refs();
2971 let aggregate_exec = Arc::new(AggregateExec::try_new(
2972 AggregateMode::Partial,
2973 groups.clone(),
2974 aggregates.clone(),
2975 vec![None],
2976 blocking_exec,
2977 schema,
2978 )?);
2979
2980 let fut = crate::collect(aggregate_exec, task_ctx);
2981 let mut fut = fut.boxed();
2982
2983 assert_is_pending(&mut fut);
2984 drop(fut);
2985 assert_strong_count_converges_to_zero(refs).await;
2986
2987 Ok(())
2988 }
2989
2990 #[tokio::test]
2991 async fn test_drop_cancel_with_groups() -> Result<()> {
2992 let task_ctx = Arc::new(TaskContext::default());
2993 let schema = Arc::new(Schema::new(vec![
2994 Field::new("a", DataType::Float64, true),
2995 Field::new("b", DataType::Float64, true),
2996 ]));
2997
2998 let groups =
2999 PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);
3000
3001 let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
3002 AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
3003 .schema(Arc::clone(&schema))
3004 .alias("AVG(b)")
3005 .build()?,
3006 )];
3007
3008 let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
3009 let refs = blocking_exec.refs();
3010 let aggregate_exec = Arc::new(AggregateExec::try_new(
3011 AggregateMode::Partial,
3012 groups,
3013 aggregates.clone(),
3014 vec![None],
3015 blocking_exec,
3016 schema,
3017 )?);
3018
3019 let fut = crate::collect(aggregate_exec, task_ctx);
3020 let mut fut = fut.boxed();
3021
3022 assert_is_pending(&mut fut);
3023 drop(fut);
3024 assert_strong_count_converges_to_zero(refs).await;
3025
3026 Ok(())
3027 }
3028
3029 #[tokio::test]
3030 async fn run_first_last_multi_partitions() -> Result<()> {
3031 for is_first_acc in [false, true] {
3032 for spill in [false, true] {
3033 first_last_multi_partitions(is_first_acc, spill, 4200).await?
3034 }
3035 }
3036 Ok(())
3037 }
3038
3039 fn test_first_value_agg_expr(
3041 schema: &Schema,
3042 sort_options: SortOptions,
3043 ) -> Result<Arc<AggregateFunctionExpr>> {
3044 let order_bys = vec![PhysicalSortExpr {
3045 expr: col("b", schema)?,
3046 options: sort_options,
3047 }];
3048 let args = [col("b", schema)?];
3049
3050 AggregateExprBuilder::new(first_value_udaf(), args.to_vec())
3051 .order_by(order_bys)
3052 .schema(Arc::new(schema.clone()))
3053 .alias(String::from("first_value(b) ORDER BY [b ASC NULLS LAST]"))
3054 .build()
3055 .map(Arc::new)
3056 }
3057
3058 fn test_last_value_agg_expr(
3060 schema: &Schema,
3061 sort_options: SortOptions,
3062 ) -> Result<Arc<AggregateFunctionExpr>> {
3063 let order_bys = vec![PhysicalSortExpr {
3064 expr: col("b", schema)?,
3065 options: sort_options,
3066 }];
3067 let args = [col("b", schema)?];
3068 AggregateExprBuilder::new(last_value_udaf(), args.to_vec())
3069 .order_by(order_bys)
3070 .schema(Arc::new(schema.clone()))
3071 .alias(String::from("last_value(b) ORDER BY [b ASC NULLS LAST]"))
3072 .build()
3073 .map(Arc::new)
3074 }
3075
3076 fn first_value_agg_expr(
3077 schema: &SchemaRef,
3078 column: &str,
3079 alias: &str,
3080 human_display: Option<&str>,
3081 human_display_alias: Option<&str>,
3082 ) -> Result<AggregateFunctionExpr> {
3083 let mut builder =
3084 AggregateExprBuilder::new(first_value_udaf(), vec![col(column, schema)?])
3085 .order_by(vec![PhysicalSortExpr {
3086 expr: col(column, schema)?,
3087 options: SortOptions::new(false, false),
3088 }])
3089 .schema(Arc::clone(schema))
3090 .alias(alias);
3091
3092 if let Some(human_display) = human_display {
3093 builder = builder.human_display(human_display);
3094 }
3095 if let Some(human_display_alias) = human_display_alias {
3096 builder = builder.human_display_alias(human_display_alias);
3097 }
3098
3099 builder.build()
3100 }
3101
/// Reversing an aliased `first_value` keeps the alias (`agg`) and flips both
/// the function name and the sort direction in the human display.
#[test]
fn test_reverse_expr_preserves_aliased_human_display() -> Result<()> {
    let schema = create_test_schema()?;
    let original = first_value_agg_expr(
        &schema,
        "b",
        "agg",
        Some("first_value(b) ORDER BY [b ASC NULLS LAST]"),
        Some("agg"),
    )?;

    let reversed = original.reverse_expr().expect("expected reverse expr");

    let expected_display = "last_value(b) ORDER BY [b DESC NULLS FIRST]";
    assert_eq!(reversed.name(), "agg");
    assert_eq!(reversed.human_display_alias(), Some("agg"));
    assert_eq!(
        format_tree_aggregate_expr(&reversed),
        "last_value(b) ORDER BY [b DESC NULLS FIRST] as agg"
    );
    assert_eq!(reversed.human_display(), Some(expected_display));

    Ok(())
}
3128
/// Reversal must rewrite only the function name, not column names that happen
/// to contain `first_value` (here the column is `first_value_col`).
#[test]
fn test_reverse_expr_does_not_rewrite_column_names_in_human_display() -> Result<()> {
    let column = "first_value_col";
    let schema = Arc::new(Schema::new(vec![Field::new(
        column,
        DataType::Int32,
        true,
    )]));
    let original = first_value_agg_expr(
        &schema,
        column,
        "agg",
        Some("first_value(first_value_col) ORDER BY [first_value_col ASC NULLS LAST]"),
        Some("agg"),
    )?;

    let reversed = original.reverse_expr().expect("expected reverse expr");

    assert_eq!(reversed.name(), "agg");
    assert_eq!(
        reversed.human_display(),
        Some("last_value(first_value_col) ORDER BY [first_value_col DESC NULLS FIRST]")
    );
    assert_eq!(
        format_tree_aggregate_expr(&reversed),
        "last_value(first_value_col) ORDER BY [first_value_col DESC NULLS FIRST] as agg"
    );

    Ok(())
}
3162
/// An empty `human_display` string behaves as if none was provided: the
/// getter yields `None` and the tree format falls back to the plain alias.
#[test]
fn test_empty_human_display_is_treated_as_absent() -> Result<()> {
    let expr = first_value_agg_expr(&create_test_schema()?, "b", "agg", Some(""), None)?;

    assert_eq!(expr.human_display(), None);
    assert_eq!(format_tree_aggregate_expr(&expr), "agg");

    Ok(())
}
3173
/// Building an aggregate whose `human_display_alias` differs from its alias
/// must fail with a descriptive error.
#[test]
fn test_human_display_alias_must_match_name() -> Result<()> {
    let schema = create_test_schema()?;
    let build_result = first_value_agg_expr(
        &schema,
        "b",
        "agg",
        Some("first_value(b) ORDER BY [b ASC NULLS LAST]"),
        Some("other_alias"),
    );

    let message = build_result.unwrap_err().to_string();
    assert!(message.contains("aggregate human_display_alias must match"));

    Ok(())
}
3194
/// Without a human display, reversal renames the expression itself
/// (`first_value` becomes `last_value`, with the ordering flipped) and still
/// reports no human display.
#[test]
fn test_reverse_expr_preserves_non_aliased_display_path() -> Result<()> {
    let schema = create_test_schema()?;
    let original_name = "first_value(b) ORDER BY [b ASC NULLS LAST]";
    let original = first_value_agg_expr(&schema, "b", original_name, None, None)?;

    let reversed = original.reverse_expr().expect("expected reverse expr");

    assert_eq!(reversed.name(), "last_value(b) ORDER BY [b DESC NULLS FIRST]");
    assert_eq!(reversed.human_display(), None);

    Ok(())
}
3216
3217 async fn first_last_multi_partitions(
3227 is_first_acc: bool,
3228 spill: bool,
3229 max_memory: usize,
3230 ) -> Result<()> {
3231 let task_ctx = if spill {
3232 new_spill_ctx(2, max_memory)
3233 } else {
3234 Arc::new(TaskContext::default())
3235 };
3236
3237 let (schema, data) = some_data_v2();
3238 let partition1 = data[0].clone();
3239 let partition2 = data[1].clone();
3240 let partition3 = data[2].clone();
3241 let partition4 = data[3].clone();
3242
3243 let groups =
3244 PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);
3245
3246 let sort_options = SortOptions {
3247 descending: false,
3248 nulls_first: false,
3249 };
3250 let aggregates: Vec<Arc<AggregateFunctionExpr>> = if is_first_acc {
3251 vec![test_first_value_agg_expr(&schema, sort_options)?]
3252 } else {
3253 vec![test_last_value_agg_expr(&schema, sort_options)?]
3254 };
3255
3256 let memory_exec = TestMemoryExec::try_new_exec(
3257 &[
3258 vec![partition1],
3259 vec![partition2],
3260 vec![partition3],
3261 vec![partition4],
3262 ],
3263 Arc::clone(&schema),
3264 None,
3265 )?;
3266 let aggregate_exec = Arc::new(AggregateExec::try_new(
3267 AggregateMode::Partial,
3268 groups.clone(),
3269 aggregates.clone(),
3270 vec![None],
3271 memory_exec,
3272 Arc::clone(&schema),
3273 )?);
3274 let coalesce = Arc::new(CoalescePartitionsExec::new(aggregate_exec))
3275 as Arc<dyn ExecutionPlan>;
3276 let aggregate_final = Arc::new(AggregateExec::try_new(
3277 AggregateMode::Final,
3278 groups,
3279 aggregates.clone(),
3280 vec![None],
3281 coalesce,
3282 schema,
3283 )?) as Arc<dyn ExecutionPlan>;
3284
3285 let result = crate::collect(aggregate_final, task_ctx).await?;
3286 if is_first_acc {
3287 allow_duplicates! {
3288 assert_snapshot!(batches_to_string(&result), @r"
3289 +---+--------------------------------------------+
3290 | a | first_value(b) ORDER BY [b ASC NULLS LAST] |
3291 +---+--------------------------------------------+
3292 | 2 | 0.0 |
3293 | 3 | 1.0 |
3294 | 4 | 3.0 |
3295 +---+--------------------------------------------+
3296 ");
3297 }
3298 } else {
3299 allow_duplicates! {
3300 assert_snapshot!(batches_to_string(&result), @r"
3301 +---+-------------------------------------------+
3302 | a | last_value(b) ORDER BY [b ASC NULLS LAST] |
3303 +---+-------------------------------------------+
3304 | 2 | 3.0 |
3305 | 3 | 5.0 |
3306 | 4 | 6.0 |
3307 +---+-------------------------------------------+
3308 ");
3309 }
3310 };
3311 Ok(())
3312 }
3313
3314 #[tokio::test]
3315 async fn test_get_finest_requirements() -> Result<()> {
3316 let test_schema = create_test_schema()?;
3317
3318 let options = SortOptions {
3319 descending: false,
3320 nulls_first: false,
3321 };
3322 let col_a = &col("a", &test_schema)?;
3323 let col_b = &col("b", &test_schema)?;
3324 let col_c = &col("c", &test_schema)?;
3325 let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema));
3326 eq_properties.add_equal_conditions(Arc::clone(col_a), Arc::clone(col_b))?;
3328 let order_by_exprs = vec![
3331 vec![],
3332 vec![PhysicalSortExpr {
3333 expr: Arc::clone(col_a),
3334 options,
3335 }],
3336 vec![
3337 PhysicalSortExpr {
3338 expr: Arc::clone(col_a),
3339 options,
3340 },
3341 PhysicalSortExpr {
3342 expr: Arc::clone(col_b),
3343 options,
3344 },
3345 PhysicalSortExpr {
3346 expr: Arc::clone(col_c),
3347 options,
3348 },
3349 ],
3350 vec![
3351 PhysicalSortExpr {
3352 expr: Arc::clone(col_a),
3353 options,
3354 },
3355 PhysicalSortExpr {
3356 expr: Arc::clone(col_b),
3357 options,
3358 },
3359 ],
3360 ];
3361
3362 let common_requirement = vec![
3363 PhysicalSortRequirement::new(Arc::clone(col_a), Some(options)),
3364 PhysicalSortRequirement::new(Arc::clone(col_c), Some(options)),
3365 ];
3366 let mut aggr_exprs = order_by_exprs
3367 .into_iter()
3368 .map(|order_by_expr| {
3369 AggregateExprBuilder::new(array_agg_udaf(), vec![Arc::clone(col_a)])
3370 .alias("a")
3371 .order_by(order_by_expr)
3372 .schema(Arc::clone(&test_schema))
3373 .build()
3374 .map(Arc::new)
3375 .unwrap()
3376 })
3377 .collect::<Vec<_>>();
3378 let group_by = PhysicalGroupBy::new_single(vec![]);
3379 let result = get_finer_aggregate_exprs_requirement(
3380 &mut aggr_exprs,
3381 &group_by,
3382 &eq_properties,
3383 &AggregateMode::Partial,
3384 )?;
3385 assert_eq!(result, common_requirement);
3386 Ok(())
3387 }
3388
/// Rebuilding an `AggregateExec` through `with_new_children` (with the same
/// child) must leave its output schema unchanged.
#[test]
fn test_agg_exec_same_schema() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Float32, true),
        Field::new("b", DataType::Float32, true),
    ]));

    let descending_nulls_first = SortOptions {
        descending: true,
        nulls_first: true,
    };
    let group_by_a =
        PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);
    let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
        test_first_value_agg_expr(&schema, descending_nulls_first)?,
        test_last_value_agg_expr(&schema, descending_nulls_first)?,
    ];

    let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
    let aggregate_exec = Arc::new(AggregateExec::try_new(
        AggregateMode::Partial,
        group_by_a,
        aggregates,
        vec![None, None],
        Arc::clone(&blocking_exec) as Arc<dyn ExecutionPlan>,
        schema,
    )?);

    let rebuilt = Arc::clone(&aggregate_exec).with_new_children(vec![blocking_exec])?;
    assert_eq!(rebuilt.schema(), aggregate_exec.schema());
    Ok(())
}
3421
3422 #[tokio::test]
3423 async fn test_agg_exec_group_by_const() -> Result<()> {
3424 let schema = Arc::new(Schema::new(vec![
3425 Field::new("a", DataType::Float32, true),
3426 Field::new("b", DataType::Float32, true),
3427 Field::new("const", DataType::Int32, false),
3428 ]));
3429
3430 let col_a = col("a", &schema)?;
3431 let col_b = col("b", &schema)?;
3432 let const_expr = Arc::new(Literal::new(ScalarValue::Int32(Some(1))));
3433
3434 let groups = PhysicalGroupBy::new(
3435 vec![
3436 (col_a, "a".to_string()),
3437 (col_b, "b".to_string()),
3438 (const_expr, "const".to_string()),
3439 ],
3440 vec![
3441 (
3442 Arc::new(Literal::new(ScalarValue::Float32(None))),
3443 "a".to_string(),
3444 ),
3445 (
3446 Arc::new(Literal::new(ScalarValue::Float32(None))),
3447 "b".to_string(),
3448 ),
3449 (
3450 Arc::new(Literal::new(ScalarValue::Int32(None))),
3451 "const".to_string(),
3452 ),
3453 ],
3454 vec![
3455 vec![false, true, true],
3456 vec![true, false, true],
3457 vec![true, true, false],
3458 ],
3459 true,
3460 );
3461
3462 let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
3463 AggregateExprBuilder::new(count_udaf(), vec![lit(1)])
3464 .schema(Arc::clone(&schema))
3465 .alias("1")
3466 .build()
3467 .map(Arc::new)?,
3468 ];
3469
3470 let input_batches = (0..4)
3471 .map(|_| {
3472 let a = Arc::new(Float32Array::from(vec![0.; 8192]));
3473 let b = Arc::new(Float32Array::from(vec![0.; 8192]));
3474 let c = Arc::new(Int32Array::from(vec![1; 8192]));
3475
3476 RecordBatch::try_new(Arc::clone(&schema), vec![a, b, c]).unwrap()
3477 })
3478 .collect();
3479
3480 let input =
3481 TestMemoryExec::try_new_exec(&[input_batches], Arc::clone(&schema), None)?;
3482
3483 let aggregate_exec = Arc::new(AggregateExec::try_new(
3484 AggregateMode::Single,
3485 groups,
3486 aggregates.clone(),
3487 vec![None],
3488 input,
3489 schema,
3490 )?);
3491
3492 let output =
3493 collect(aggregate_exec.execute(0, Arc::new(TaskContext::default()))?).await?;
3494
3495 allow_duplicates! {
3496 assert_snapshot!(batches_to_sort_string(&output), @r"
3497 +-----+-----+-------+---------------+-------+
3498 | a | b | const | __grouping_id | 1 |
3499 +-----+-----+-------+---------------+-------+
3500 | | | 1 | 6 | 32768 |
3501 | | 0.0 | | 5 | 32768 |
3502 | 0.0 | | | 3 | 32768 |
3503 +-----+-----+-------+---------------+-------+
3504 ");
3505 }
3506
3507 Ok(())
3508 }
3509
3510 #[tokio::test]
3511 async fn test_agg_exec_struct_of_dicts() -> Result<()> {
3512 let batch = RecordBatch::try_new(
3513 Arc::new(Schema::new(vec![
3514 Field::new(
3515 "labels".to_string(),
3516 DataType::Struct(
3517 vec![
3518 Field::new(
3519 "a".to_string(),
3520 DataType::Dictionary(
3521 Box::new(DataType::Int32),
3522 Box::new(DataType::Utf8),
3523 ),
3524 true,
3525 ),
3526 Field::new(
3527 "b".to_string(),
3528 DataType::Dictionary(
3529 Box::new(DataType::Int32),
3530 Box::new(DataType::Utf8),
3531 ),
3532 true,
3533 ),
3534 ]
3535 .into(),
3536 ),
3537 false,
3538 ),
3539 Field::new("value", DataType::UInt64, false),
3540 ])),
3541 vec![
3542 Arc::new(StructArray::from(vec![
3543 (
3544 Arc::new(Field::new(
3545 "a".to_string(),
3546 DataType::Dictionary(
3547 Box::new(DataType::Int32),
3548 Box::new(DataType::Utf8),
3549 ),
3550 true,
3551 )),
3552 Arc::new(
3553 vec![Some("a"), None, Some("a")]
3554 .into_iter()
3555 .collect::<DictionaryArray<Int32Type>>(),
3556 ) as ArrayRef,
3557 ),
3558 (
3559 Arc::new(Field::new(
3560 "b".to_string(),
3561 DataType::Dictionary(
3562 Box::new(DataType::Int32),
3563 Box::new(DataType::Utf8),
3564 ),
3565 true,
3566 )),
3567 Arc::new(
3568 vec![Some("b"), Some("c"), Some("b")]
3569 .into_iter()
3570 .collect::<DictionaryArray<Int32Type>>(),
3571 ) as ArrayRef,
3572 ),
3573 ])),
3574 Arc::new(UInt64Array::from(vec![1, 1, 1])),
3575 ],
3576 )
3577 .expect("Failed to create RecordBatch");
3578
3579 let group_by = PhysicalGroupBy::new_single(vec![(
3580 col("labels", &batch.schema())?,
3581 "labels".to_string(),
3582 )]);
3583
3584 let aggr_expr = vec![
3585 AggregateExprBuilder::new(sum_udaf(), vec![col("value", &batch.schema())?])
3586 .schema(Arc::clone(&batch.schema()))
3587 .alias(String::from("SUM(value)"))
3588 .build()
3589 .map(Arc::new)?,
3590 ];
3591
3592 let input = TestMemoryExec::try_new_exec(
3593 &[vec![batch.clone()]],
3594 Arc::<Schema>::clone(&batch.schema()),
3595 None,
3596 )?;
3597 let aggregate_exec = Arc::new(AggregateExec::try_new(
3598 AggregateMode::FinalPartitioned,
3599 group_by,
3600 aggr_expr,
3601 vec![None],
3602 Arc::clone(&input) as Arc<dyn ExecutionPlan>,
3603 batch.schema(),
3604 )?);
3605
3606 let session_config = SessionConfig::default();
3607 let ctx = TaskContext::default().with_session_config(session_config);
3608 let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;
3609
3610 allow_duplicates! {
3611 assert_snapshot!(batches_to_string(&output), @r"
3612 +--------------+------------+
3613 | labels | SUM(value) |
3614 +--------------+------------+
3615 | {a: a, b: b} | 2 |
3616 | {a: , b: c} | 1 |
3617 +--------------+------------+
3618 ");
3619 }
3620
3621 Ok(())
3622 }
3623
3624 #[tokio::test]
3625 async fn test_skip_aggregation_after_first_batch() -> Result<()> {
3626 let schema = Arc::new(Schema::new(vec![
3627 Field::new("key", DataType::Int32, true),
3628 Field::new("val", DataType::Int32, true),
3629 ]));
3630
3631 let group_by =
3632 PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]);
3633
3634 let aggr_expr = vec![
3635 AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?])
3636 .schema(Arc::clone(&schema))
3637 .alias(String::from("COUNT(val)"))
3638 .build()
3639 .map(Arc::new)?,
3640 ];
3641
3642 let input_data = vec![
3643 RecordBatch::try_new(
3644 Arc::clone(&schema),
3645 vec![
3646 Arc::new(Int32Array::from(vec![1, 2, 3])),
3647 Arc::new(Int32Array::from(vec![0, 0, 0])),
3648 ],
3649 )
3650 .unwrap(),
3651 RecordBatch::try_new(
3652 Arc::clone(&schema),
3653 vec![
3654 Arc::new(Int32Array::from(vec![2, 3, 4])),
3655 Arc::new(Int32Array::from(vec![0, 0, 0])),
3656 ],
3657 )
3658 .unwrap(),
3659 ];
3660
3661 let input =
3662 TestMemoryExec::try_new_exec(&[input_data], Arc::clone(&schema), None)?;
3663 let aggregate_exec = Arc::new(AggregateExec::try_new(
3664 AggregateMode::Partial,
3665 group_by,
3666 aggr_expr,
3667 vec![None],
3668 Arc::clone(&input) as Arc<dyn ExecutionPlan>,
3669 schema,
3670 )?);
3671
3672 let mut session_config = SessionConfig::default();
3673 session_config = session_config.set(
3674 "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
3675 &ScalarValue::Int64(Some(2)),
3676 );
3677 session_config = session_config.set(
3678 "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
3679 &ScalarValue::Float64(Some(0.1)),
3680 );
3681
3682 let ctx = TaskContext::default().with_session_config(session_config);
3683 let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;
3684
3685 allow_duplicates! {
3686 assert_snapshot!(batches_to_string(&output), @r"
3687 +-----+-------------------+
3688 | key | COUNT(val)[count] |
3689 +-----+-------------------+
3690 | 1 | 1 |
3691 | 2 | 1 |
3692 | 3 | 1 |
3693 | 2 | 1 |
3694 | 3 | 1 |
3695 | 4 | 1 |
3696 +-----+-------------------+
3697 ");
3698 }
3699
3700 Ok(())
3701 }
3702
3703 #[tokio::test]
3704 async fn test_skip_aggregation_after_threshold() -> Result<()> {
3705 let schema = Arc::new(Schema::new(vec![
3706 Field::new("key", DataType::Int32, true),
3707 Field::new("val", DataType::Int32, true),
3708 ]));
3709
3710 let group_by =
3711 PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]);
3712
3713 let aggr_expr = vec![
3714 AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?])
3715 .schema(Arc::clone(&schema))
3716 .alias(String::from("COUNT(val)"))
3717 .build()
3718 .map(Arc::new)?,
3719 ];
3720
3721 let input_data = vec![
3722 RecordBatch::try_new(
3723 Arc::clone(&schema),
3724 vec![
3725 Arc::new(Int32Array::from(vec![1, 2, 3])),
3726 Arc::new(Int32Array::from(vec![0, 0, 0])),
3727 ],
3728 )
3729 .unwrap(),
3730 RecordBatch::try_new(
3731 Arc::clone(&schema),
3732 vec![
3733 Arc::new(Int32Array::from(vec![2, 3, 4])),
3734 Arc::new(Int32Array::from(vec![0, 0, 0])),
3735 ],
3736 )
3737 .unwrap(),
3738 RecordBatch::try_new(
3739 Arc::clone(&schema),
3740 vec![
3741 Arc::new(Int32Array::from(vec![2, 3, 4])),
3742 Arc::new(Int32Array::from(vec![0, 0, 0])),
3743 ],
3744 )
3745 .unwrap(),
3746 ];
3747
3748 let input =
3749 TestMemoryExec::try_new_exec(&[input_data], Arc::clone(&schema), None)?;
3750 let aggregate_exec = Arc::new(AggregateExec::try_new(
3751 AggregateMode::Partial,
3752 group_by,
3753 aggr_expr,
3754 vec![None],
3755 Arc::clone(&input) as Arc<dyn ExecutionPlan>,
3756 schema,
3757 )?);
3758
3759 let mut session_config = SessionConfig::default();
3760 session_config = session_config.set(
3761 "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
3762 &ScalarValue::Int64(Some(5)),
3763 );
3764 session_config = session_config.set(
3765 "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
3766 &ScalarValue::Float64(Some(0.1)),
3767 );
3768
3769 let ctx = TaskContext::default().with_session_config(session_config);
3770 let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;
3771
3772 allow_duplicates! {
3773 assert_snapshot!(batches_to_string(&output), @r"
3774 +-----+-------------------+
3775 | key | COUNT(val)[count] |
3776 +-----+-------------------+
3777 | 1 | 1 |
3778 | 2 | 2 |
3779 | 3 | 2 |
3780 | 4 | 1 |
3781 | 2 | 1 |
3782 | 3 | 1 |
3783 | 4 | 1 |
3784 +-----+-------------------+
3785 ");
3786 }
3787
3788 Ok(())
3789 }
3790
/// With grouping sets `(a)` and `(a, b)`, `b` is nulled out in one set, so
/// the output schema must mark `b` nullable. `a` is present in every set and
/// keeps its non-null type.
#[test]
fn group_exprs_nullable() -> Result<()> {
    let input_schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Float32, false),
        Field::new("b", DataType::Float32, false),
    ]));

    let count_a = AggregateExprBuilder::new(count_udaf(), vec![col("a", &input_schema)?])
        .schema(Arc::clone(&input_schema))
        .alias("COUNT(a)")
        .build()
        .map(Arc::new)?;
    let aggr_expr = vec![count_a];

    let input_exprs = vec![
        (col("a", &input_schema)?, "a".to_string()),
        (col("b", &input_schema)?, "b".to_string()),
    ];
    let null_exprs = vec![
        (lit(ScalarValue::Float32(None)), "a".to_string()),
        (lit(ScalarValue::Float32(None)), "b".to_string()),
    ];
    // `true` means the column is replaced by its null expression in that set.
    let null_masks = vec![vec![false, true], vec![false, false]];
    let grouping_set = PhysicalGroupBy::new(input_exprs, null_exprs, null_masks, true);

    let actual = create_schema(
        &input_schema,
        &grouping_set,
        &aggr_expr,
        AggregateMode::Final,
    )?;
    let expected = Schema::new(vec![
        Field::new("a", DataType::Float32, false),
        Field::new("b", DataType::Float32, true),
        Field::new("__grouping_id", DataType::UInt8, false),
        Field::new("COUNT(a)", DataType::Int64, false),
    ]);
    assert_eq!(actual, expected);
    Ok(())
}
3836
3837 async fn run_test_with_spill_pool_if_necessary(
3839 pool_size: usize,
3840 expect_spill: bool,
3841 ) -> Result<()> {
3842 fn create_record_batch(
3843 schema: &Arc<Schema>,
3844 data: (Vec<u32>, Vec<f64>),
3845 ) -> Result<RecordBatch> {
3846 Ok(RecordBatch::try_new(
3847 Arc::clone(schema),
3848 vec![
3849 Arc::new(UInt32Array::from(data.0)),
3850 Arc::new(Float64Array::from(data.1)),
3851 ],
3852 )?)
3853 }
3854
3855 let schema = Arc::new(Schema::new(vec![
3856 Field::new("a", DataType::UInt32, false),
3857 Field::new("b", DataType::Float64, false),
3858 ]));
3859
3860 let batches = vec![
3861 create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
3862 create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
3863 ];
3864 let plan: Arc<dyn ExecutionPlan> =
3865 TestMemoryExec::try_new_exec(&[batches], Arc::clone(&schema), None)?;
3866
3867 let grouping_set = PhysicalGroupBy::new(
3868 vec![(col("a", &schema)?, "a".to_string())],
3869 vec![],
3870 vec![vec![false]],
3871 false,
3872 );
3873
3874 let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
3876 Arc::new(
3877 AggregateExprBuilder::new(min_udaf(), vec![col("b", &schema)?])
3878 .schema(Arc::clone(&schema))
3879 .alias("MIN(b)")
3880 .build()?,
3881 ),
3882 Arc::new(
3883 AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
3884 .schema(Arc::clone(&schema))
3885 .alias("AVG(b)")
3886 .build()?,
3887 ),
3888 ];
3889
3890 let single_aggregate = Arc::new(AggregateExec::try_new(
3891 AggregateMode::Single,
3892 grouping_set,
3893 aggregates,
3894 vec![None, None],
3895 plan,
3896 Arc::clone(&schema),
3897 )?);
3898
3899 let batch_size = 2;
3900 let memory_pool = Arc::new(FairSpillPool::new(pool_size));
3901 let task_ctx = Arc::new(
3902 TaskContext::default()
3903 .with_session_config(SessionConfig::new().with_batch_size(batch_size))
3904 .with_runtime(Arc::new(
3905 RuntimeEnvBuilder::new()
3906 .with_memory_pool(memory_pool)
3907 .build()?,
3908 )),
3909 );
3910
3911 let result = collect(single_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
3912
3913 assert_spill_count_metric(expect_spill, single_aggregate);
3914
3915 allow_duplicates! {
3916 assert_snapshot!(batches_to_string(&result), @r"
3917 +---+--------+--------+
3918 | a | MIN(b) | AVG(b) |
3919 +---+--------+--------+
3920 | 2 | 1.0 | 1.0 |
3921 | 3 | 2.0 | 2.0 |
3922 | 4 | 3.0 | 3.5 |
3923 +---+--------+--------+
3924 ");
3925 }
3926
3927 Ok(())
3928 }
3929
3930 fn assert_spill_count_metric(
3931 expect_spill: bool,
3932 single_aggregate: Arc<AggregateExec>,
3933 ) {
3934 if let Some(metrics_set) = single_aggregate.metrics() {
3935 let mut spill_count = 0;
3936
3937 for metric in metrics_set.iter() {
3939 if let MetricValue::SpillCount(count) = metric.value() {
3940 spill_count = count.value();
3941 break;
3942 }
3943 }
3944
3945 if expect_spill && spill_count == 0 {
3946 panic!(
3947 "Expected spill but SpillCount metric not found or SpillCount was 0."
3948 );
3949 } else if !expect_spill && spill_count > 0 {
3950 panic!(
3951 "Expected no spill but found SpillCount metric with value greater than 0."
3952 );
3953 }
3954 } else {
3955 panic!("No metrics returned from the operator; cannot verify spilling.");
3956 }
3957 }
3958
3959 #[tokio::test]
3960 async fn test_aggregate_with_spill_if_necessary() -> Result<()> {
3961 run_test_with_spill_pool_if_necessary(2_000, true).await?;
3963 run_test_with_spill_pool_if_necessary(20_000, false).await?;
3965 Ok(())
3966 }
3967
3968 #[tokio::test]
3969 async fn test_grouped_aggregation_respects_memory_limit() -> Result<()> {
3970 fn create_record_batch(
3972 schema: &Arc<Schema>,
3973 data: (Vec<u32>, Vec<f64>),
3974 ) -> Result<RecordBatch> {
3975 Ok(RecordBatch::try_new(
3976 Arc::clone(schema),
3977 vec![
3978 Arc::new(UInt32Array::from(data.0)),
3979 Arc::new(Float64Array::from(data.1)),
3980 ],
3981 )?)
3982 }
3983
3984 let schema = Arc::new(Schema::new(vec![
3985 Field::new("a", DataType::UInt32, false),
3986 Field::new("b", DataType::Float64, false),
3987 ]));
3988
3989 let batches = vec![
3990 create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
3991 create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
3992 ];
3993 let plan: Arc<dyn ExecutionPlan> =
3994 TestMemoryExec::try_new_exec(&[batches], Arc::clone(&schema), None)?;
3995 let proj = ProjectionExec::try_new(
3996 vec![
3997 ProjectionExpr::new(lit("0"), "l".to_string()),
3998 ProjectionExpr::new_from_expression(col("a", &schema)?, &schema)?,
3999 ProjectionExpr::new_from_expression(col("b", &schema)?, &schema)?,
4000 ],
4001 plan,
4002 )?;
4003 let plan: Arc<dyn ExecutionPlan> = Arc::new(proj);
4004 let schema = plan.schema();
4005
4006 let grouping_set = PhysicalGroupBy::new(
4007 vec![
4008 (col("l", &schema)?, "l".to_string()),
4009 (col("a", &schema)?, "a".to_string()),
4010 ],
4011 vec![],
4012 vec![vec![false, false]],
4013 false,
4014 );
4015
4016 let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
4018 Arc::new(
4019 AggregateExprBuilder::new(min_udaf(), vec![col("b", &schema)?])
4020 .schema(Arc::clone(&schema))
4021 .alias("MIN(b)")
4022 .build()?,
4023 ),
4024 Arc::new(
4025 AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
4026 .schema(Arc::clone(&schema))
4027 .alias("AVG(b)")
4028 .build()?,
4029 ),
4030 ];
4031
4032 let single_aggregate = Arc::new(AggregateExec::try_new(
4033 AggregateMode::Single,
4034 grouping_set,
4035 aggregates,
4036 vec![None, None],
4037 plan,
4038 Arc::clone(&schema),
4039 )?);
4040
4041 let batch_size = 2;
4042 let memory_pool = Arc::new(FairSpillPool::new(2000));
4043 let task_ctx = Arc::new(
4044 TaskContext::default()
4045 .with_session_config(SessionConfig::new().with_batch_size(batch_size))
4046 .with_runtime(Arc::new(
4047 RuntimeEnvBuilder::new()
4048 .with_memory_pool(memory_pool)
4049 .build()?,
4050 )),
4051 );
4052
4053 let result = collect(single_aggregate.execute(0, Arc::clone(&task_ctx))?).await;
4054 match result {
4055 Ok(result) => {
4056 assert_spill_count_metric(true, single_aggregate);
4057
4058 allow_duplicates! {
4059 assert_snapshot!(batches_to_string(&result), @r"
4060 +---+---+--------+--------+
4061 | l | a | MIN(b) | AVG(b) |
4062 +---+---+--------+--------+
4063 | 0 | 2 | 1.0 | 1.0 |
4064 | 0 | 3 | 2.0 | 2.0 |
4065 | 0 | 4 | 3.0 | 3.5 |
4066 +---+---+--------+--------+
4067 ");
4068 }
4069 }
4070 Err(e) => assert!(matches!(e, DataFusionError::ResourcesExhausted(_))),
4071 }
4072
4073 Ok(())
4074 }
4075
/// An ungrouped `Final` aggregate must report `total_byte_size` as
/// `Precision::Absent`, both when the input byte size is already absent and
/// when the input is an exact zero-row, zero-byte relation.
///
/// Only synchronous statistics APIs are exercised, so this is a plain
/// `#[test]`; a `#[tokio::test]` would start an async runtime for nothing.
#[test]
fn test_aggregate_statistics_edge_cases() -> Result<()> {
    use datafusion_common::ColumnStatistics;

    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Float64, false),
    ]));

    // (label, input num_rows, input total_byte_size)
    let cases = [
        ("absent input byte size", Precision::Exact(100), Precision::Absent),
        ("zero input rows", Precision::Exact(0), Precision::Exact(0)),
    ];

    for (label, num_rows, total_byte_size) in cases {
        let input_stats = Statistics {
            num_rows,
            total_byte_size,
            column_statistics: vec![
                ColumnStatistics::new_unknown(),
                ColumnStatistics::new_unknown(),
            ],
        };
        let agg = build_test_aggregate(
            &schema,
            input_stats,
            PhysicalGroupBy::default(),
            None,
        )?;
        let stats = agg.partition_statistics(None)?;
        assert_eq!(
            stats.total_byte_size,
            Precision::Absent,
            "case '{label}': total_byte_size should be Absent"
        );
    }

    Ok(())
}
4121
4122 fn build_test_aggregate(
4123 schema: &SchemaRef,
4124 stats: Statistics,
4125 group_by: PhysicalGroupBy,
4126 limit: Option<LimitOptions>,
4127 ) -> Result<AggregateExec> {
4128 let input = Arc::new(StatisticsExec::new(stats, (**schema).clone()))
4129 as Arc<dyn ExecutionPlan>;
4130
4131 let mut agg = AggregateExec::try_new(
4132 AggregateMode::Final,
4133 group_by,
4134 vec![Arc::new(
4135 AggregateExprBuilder::new(count_udaf(), vec![col("a", schema)?])
4136 .schema(Arc::clone(schema))
4137 .alias("COUNT(a)")
4138 .build()?,
4139 )],
4140 vec![None],
4141 input,
4142 Arc::clone(schema),
4143 )?;
4144
4145 if let Some(limit) = limit {
4146 agg = agg.with_limit_options(Some(limit));
4147 }
4148
4149 Ok(agg)
4150 }
4151
4152 fn simple_group_by(schema: &SchemaRef, cols: &[&str]) -> PhysicalGroupBy {
4153 if cols.is_empty() {
4154 PhysicalGroupBy::default()
4155 } else {
4156 PhysicalGroupBy::new_single(
4157 cols.iter()
4158 .map(|name| {
4159 (
4160 col(name, schema).unwrap() as Arc<dyn PhysicalExpr>,
4161 name.to_string(),
4162 )
4163 })
4164 .collect(),
4165 )
4166 }
4167 }
4168
/// Table-driven check of the output row-count estimate of a `Final`
/// `COUNT(a)` aggregate. The rules exercised below:
///
/// * grouped output is estimated from the product of group-by column NDVs
///   (each column with a non-zero `null_count` contributes NDV + 1), capped by
///   the input row count;
/// * if any group-by column lacks an NDV, the estimate falls back to the input
///   row count;
/// * a TopK limit caps the estimate at `min(estimate, limit)`;
/// * no group-by columns, or 0 / 1 exact input rows, give exact results.
#[test]
fn test_aggregate_cardinality_estimation() -> Result<()> {
    use datafusion_common::ColumnStatistics;

    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Int32, true),
    ]));

    struct TestCase {
        name: &'static str,
        input_rows: Precision<usize>,
        col_a_stats: ColumnStatistics,
        col_b_stats: ColumnStatistics,
        group_by_cols: Vec<&'static str>,
        limit_options: Option<LimitOptions>,
        expected_num_rows: Precision<usize>,
    }

    let cases = vec![
        // ---- NDV-based estimation ----
        TestCase {
            name: "single group-by col with NDV tightens estimate",
            input_rows: Precision::Exact(1_000_000),
            col_a_stats: ColumnStatistics {
                distinct_count: Precision::Exact(500),
                ..ColumnStatistics::new_unknown()
            },
            col_b_stats: ColumnStatistics::new_unknown(),
            group_by_cols: vec!["a"],
            limit_options: None,
            expected_num_rows: Precision::Inexact(500),
        },
        TestCase {
            name: "multi-col group-by multiplies NDVs",
            input_rows: Precision::Exact(1_000_000),
            col_a_stats: ColumnStatistics {
                distinct_count: Precision::Exact(100),
                ..ColumnStatistics::new_unknown()
            },
            col_b_stats: ColumnStatistics {
                distinct_count: Precision::Exact(50),
                ..ColumnStatistics::new_unknown()
            },
            group_by_cols: vec!["a", "b"],
            limit_options: None,
            expected_num_rows: Precision::Inexact(5_000),
        },
        TestCase {
            name: "NDV product capped by input rows",
            input_rows: Precision::Exact(200),
            col_a_stats: ColumnStatistics {
                distinct_count: Precision::Exact(100),
                ..ColumnStatistics::new_unknown()
            },
            col_b_stats: ColumnStatistics {
                distinct_count: Precision::Exact(50),
                ..ColumnStatistics::new_unknown()
            },
            group_by_cols: vec!["a", "b"],
            limit_options: None,
            expected_num_rows: Precision::Inexact(200),
        },
        // ---- NULL adjustment ----
        TestCase {
            name: "null adjustment adds +1 per column",
            input_rows: Precision::Exact(1_000_000),
            col_a_stats: ColumnStatistics {
                distinct_count: Precision::Exact(99),
                null_count: Precision::Exact(10),
                ..ColumnStatistics::new_unknown()
            },
            col_b_stats: ColumnStatistics::new_unknown(),
            group_by_cols: vec!["a"],
            limit_options: None,
            // 99 distinct values + 1 NULL group
            expected_num_rows: Precision::Inexact(100),
        },
        TestCase {
            name: "null adjustment on multiple columns",
            input_rows: Precision::Exact(1_000_000),
            col_a_stats: ColumnStatistics {
                distinct_count: Precision::Exact(99),
                null_count: Precision::Exact(5),
                ..ColumnStatistics::new_unknown()
            },
            col_b_stats: ColumnStatistics {
                distinct_count: Precision::Exact(49),
                null_count: Precision::Exact(3),
                ..ColumnStatistics::new_unknown()
            },
            group_by_cols: vec!["a", "b"],
            limit_options: None,
            // (99 + 1) * (49 + 1)
            expected_num_rows: Precision::Inexact(5_000),
        },
        TestCase {
            name: "zero null_count means no adjustment",
            input_rows: Precision::Exact(1_000_000),
            col_a_stats: ColumnStatistics {
                distinct_count: Precision::Exact(100),
                null_count: Precision::Exact(0),
                ..ColumnStatistics::new_unknown()
            },
            col_b_stats: ColumnStatistics::new_unknown(),
            group_by_cols: vec!["a"],
            limit_options: None,
            expected_num_rows: Precision::Inexact(100),
        },
        // ---- Missing NDV: fall back to input rows ----
        TestCase {
            name: "bail out when one group-by col lacks NDV",
            input_rows: Precision::Exact(1_000_000),
            col_a_stats: ColumnStatistics {
                distinct_count: Precision::Exact(100),
                ..ColumnStatistics::new_unknown()
            },
            col_b_stats: ColumnStatistics::new_unknown(),
            group_by_cols: vec!["a", "b"],
            limit_options: None,
            expected_num_rows: Precision::Inexact(1_000_000),
        },
        TestCase {
            name: "bail out when all group-by cols lack NDV",
            input_rows: Precision::Exact(1_000_000),
            col_a_stats: ColumnStatistics::new_unknown(),
            col_b_stats: ColumnStatistics::new_unknown(),
            // Group on both columns so "all" really means more than one.
            group_by_cols: vec!["a", "b"],
            limit_options: None,
            expected_num_rows: Precision::Inexact(1_000_000),
        },
        // ---- TopK limit ----
        TestCase {
            name: "TopK limit caps output rows",
            input_rows: Precision::Exact(1_000_000),
            col_a_stats: ColumnStatistics::new_unknown(),
            col_b_stats: ColumnStatistics::new_unknown(),
            group_by_cols: vec!["a"],
            limit_options: Some(LimitOptions::new(10)),
            expected_num_rows: Precision::Inexact(10),
        },
        TestCase {
            name: "NDV + TopK limit: min(NDV, limit) when NDV < limit",
            input_rows: Precision::Exact(1_000_000),
            col_a_stats: ColumnStatistics {
                distinct_count: Precision::Exact(5),
                ..ColumnStatistics::new_unknown()
            },
            col_b_stats: ColumnStatistics::new_unknown(),
            group_by_cols: vec!["a"],
            limit_options: Some(LimitOptions::new(10)),
            expected_num_rows: Precision::Inexact(5),
        },
        TestCase {
            name: "NDV + TopK limit: min(NDV, limit) when limit < NDV",
            input_rows: Precision::Exact(1_000_000),
            col_a_stats: ColumnStatistics {
                distinct_count: Precision::Exact(500),
                ..ColumnStatistics::new_unknown()
            },
            col_b_stats: ColumnStatistics::new_unknown(),
            group_by_cols: vec!["a"],
            limit_options: Some(LimitOptions::new(10)),
            expected_num_rows: Precision::Inexact(10),
        },
        // ---- Absent input rows ----
        TestCase {
            name: "absent input rows without limit stays absent",
            input_rows: Precision::Absent,
            col_a_stats: ColumnStatistics::new_unknown(),
            col_b_stats: ColumnStatistics::new_unknown(),
            group_by_cols: vec!["a"],
            limit_options: None,
            expected_num_rows: Precision::Absent,
        },
        TestCase {
            name: "absent input rows with TopK limit gives inexact(limit)",
            input_rows: Precision::Absent,
            col_a_stats: ColumnStatistics::new_unknown(),
            col_b_stats: ColumnStatistics::new_unknown(),
            group_by_cols: vec!["a"],
            limit_options: Some(LimitOptions::new(10)),
            expected_num_rows: Precision::Inexact(10),
        },
        // ---- Exact special cases ----
        TestCase {
            name: "no group-by cols (Final mode) returns Exact(1)",
            input_rows: Precision::Exact(1_000_000),
            col_a_stats: ColumnStatistics::new_unknown(),
            col_b_stats: ColumnStatistics::new_unknown(),
            group_by_cols: vec![],
            limit_options: None,
            expected_num_rows: Precision::Exact(1),
        },
        TestCase {
            name: "one input row returns Exact(1)",
            input_rows: Precision::Exact(1),
            col_a_stats: ColumnStatistics {
                distinct_count: Precision::Exact(1),
                ..ColumnStatistics::new_unknown()
            },
            col_b_stats: ColumnStatistics::new_unknown(),
            group_by_cols: vec!["a"],
            limit_options: None,
            expected_num_rows: Precision::Exact(1),
        },
        TestCase {
            name: "zero input rows returns Exact(0)",
            input_rows: Precision::Exact(0),
            col_a_stats: ColumnStatistics::new_unknown(),
            col_b_stats: ColumnStatistics::new_unknown(),
            group_by_cols: vec!["a"],
            limit_options: None,
            expected_num_rows: Precision::Exact(0),
        },
        // ---- Inexact NDV ----
        TestCase {
            name: "inexact NDV still used for estimation",
            input_rows: Precision::Exact(1_000_000),
            col_a_stats: ColumnStatistics {
                distinct_count: Precision::Inexact(200),
                ..ColumnStatistics::new_unknown()
            },
            col_b_stats: ColumnStatistics::new_unknown(),
            group_by_cols: vec!["a"],
            limit_options: None,
            expected_num_rows: Precision::Inexact(200),
        },
        TestCase {
            name: "inexact NDV combined with limit",
            input_rows: Precision::Exact(1_000_000),
            col_a_stats: ColumnStatistics {
                distinct_count: Precision::Inexact(200),
                ..ColumnStatistics::new_unknown()
            },
            col_b_stats: ColumnStatistics::new_unknown(),
            group_by_cols: vec!["a"],
            limit_options: Some(LimitOptions::new(10)),
            expected_num_rows: Precision::Inexact(10),
        },
        // ---- All-NULL column ----
        TestCase {
            name: "all-null column contributes 1 to the product, not 0",
            input_rows: Precision::Exact(1_000),
            col_a_stats: ColumnStatistics {
                distinct_count: Precision::Exact(0),
                null_count: Precision::Exact(1_000),
                ..ColumnStatistics::new_unknown()
            },
            col_b_stats: ColumnStatistics {
                distinct_count: Precision::Exact(50),
                ..ColumnStatistics::new_unknown()
            },
            group_by_cols: vec!["a", "b"],
            limit_options: None,
            // (0 + 1) * 50
            expected_num_rows: Precision::Inexact(50),
        },
        // ---- Absent num_rows with NDV ----
        TestCase {
            name: "absent num_rows falls back to NDV estimate",
            input_rows: Precision::Absent,
            col_a_stats: ColumnStatistics {
                distinct_count: Precision::Exact(100),
                ..ColumnStatistics::new_unknown()
            },
            col_b_stats: ColumnStatistics::new_unknown(),
            group_by_cols: vec!["a"],
            limit_options: None,
            expected_num_rows: Precision::Inexact(100),
        },
        TestCase {
            name: "absent num_rows with NDV and limit returns min(ndv, limit)",
            input_rows: Precision::Absent,
            col_a_stats: ColumnStatistics {
                distinct_count: Precision::Exact(100),
                ..ColumnStatistics::new_unknown()
            },
            col_b_stats: ColumnStatistics::new_unknown(),
            group_by_cols: vec!["a"],
            limit_options: Some(LimitOptions::new(10)),
            expected_num_rows: Precision::Inexact(10),
        },
    ];

    for TestCase {
        name,
        input_rows,
        col_a_stats,
        col_b_stats,
        group_by_cols,
        limit_options,
        expected_num_rows,
    } in cases
    {
        // Each case is owned, so its column statistics are moved in, not cloned.
        let input_stats = Statistics {
            num_rows: input_rows,
            total_byte_size: Precision::Inexact(1_000_000),
            column_statistics: vec![col_a_stats, col_b_stats],
        };

        let group_by = simple_group_by(&schema, &group_by_cols);
        let agg = build_test_aggregate(&schema, input_stats, group_by, limit_options)?;

        let stats = agg.partition_statistics(None)?;
        assert_eq!(
            stats.num_rows, expected_num_rows,
            "FAILED: '{}' — expected {:?}, got {:?}",
            name, expected_num_rows, stats.num_rows
        );
    }

    Ok(())
}
4479
/// Grouping on `a` must carry `a`'s input `distinct_count` through to output
/// column 0 of the aggregate's statistics.
#[test]
fn test_aggregate_stats_distinct_count_propagation() -> Result<()> {
    use datafusion_common::ColumnStatistics;

    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Int32, true),
    ]));

    let a_stats = ColumnStatistics {
        distinct_count: Precision::Exact(100),
        null_count: Precision::Exact(5),
        ..ColumnStatistics::new_unknown()
    };
    let agg = build_test_aggregate(
        &schema,
        Statistics {
            num_rows: Precision::Exact(1000),
            total_byte_size: Precision::Inexact(10000),
            column_statistics: vec![a_stats, ColumnStatistics::new_unknown()],
        },
        simple_group_by(&schema, &["a"]),
        None,
    )?;

    let output_stats = agg.partition_statistics(None)?;
    assert_eq!(
        output_stats.column_statistics[0].distinct_count,
        Precision::Exact(100),
        "distinct_count should be propagated from child for group-by columns"
    );

    Ok(())
}
4517
/// With grouping sets `(a)`, `(b)` and `(a, b)`, the row estimate is the sum of
/// each set's NDV product: 100 + 50 + 100 * 50 = 5_150.
#[test]
fn test_aggregate_stats_grouping_sets() -> Result<()> {
    use datafusion_common::ColumnStatistics;

    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Int32, true),
    ]));

    let with_ndv = |ndv: usize| ColumnStatistics {
        distinct_count: Precision::Exact(ndv),
        ..ColumnStatistics::new_unknown()
    };
    let input_stats = Statistics {
        num_rows: Precision::Exact(1_000_000),
        total_byte_size: Precision::Inexact(1_000_000),
        column_statistics: vec![with_ndv(100), with_ndv(50)],
    };

    let names = ["a", "b"];
    let group_exprs = names
        .iter()
        .map(|&name| -> Result<(Arc<dyn PhysicalExpr>, String)> {
            Ok((col(name, &schema)?, name.to_string()))
        })
        .collect::<Result<Vec<_>>>()?;
    let null_exprs = names
        .iter()
        .map(|&name| (lit(ScalarValue::Int32(None)), name.to_string()))
        .collect::<Vec<_>>();

    // In each mask, `true` swaps that expression for its NULL counterpart, so
    // the three sets are (a), (b) and (a, b).
    let grouping_set = PhysicalGroupBy::new(
        group_exprs,
        null_exprs,
        vec![vec![false, true], vec![true, false], vec![false, false]],
        true,
    );

    let agg = build_test_aggregate(&schema, input_stats, grouping_set, None)?;

    let stats = agg.partition_statistics(None)?;
    assert_eq!(
        stats.num_rows,
        Precision::Inexact(5_150),
        "grouping sets should sum per-set NDV products"
    );

    Ok(())
}
4573
/// Grouping on a computed expression (`a + b`) instead of a bare column must
/// make the estimate fall back to the input row count, even though both
/// referenced columns have known NDVs.
#[test]
fn test_aggregate_stats_non_column_expr_bails_out() -> Result<()> {
    use datafusion_common::ColumnStatistics;
    use datafusion_expr::Operator;
    use datafusion_physical_expr::expressions::BinaryExpr;

    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Int32, true),
    ]));

    let with_ndv = |ndv: usize| ColumnStatistics {
        distinct_count: Precision::Exact(ndv),
        ..ColumnStatistics::new_unknown()
    };
    let input_stats = Statistics {
        num_rows: Precision::Exact(1_000_000),
        total_byte_size: Precision::Inexact(1_000_000),
        column_statistics: vec![with_ndv(100), with_ndv(50)],
    };

    let lhs = col("a", &schema)?;
    let rhs = col("b", &schema)?;
    let a_plus_b: Arc<dyn PhysicalExpr> =
        Arc::new(BinaryExpr::new(lhs, Operator::Plus, rhs));

    let agg = build_test_aggregate(
        &schema,
        input_stats,
        PhysicalGroupBy::new_single(vec![(a_plus_b, "a+b".to_string())]),
        None,
    )?;

    let stats = agg.partition_statistics(None)?;
    assert_eq!(
        stats.num_rows,
        Precision::Inexact(1_000_000),
        "non-column group-by expression should bail out to input_rows"
    );

    Ok(())
}
4620
4621 #[tokio::test]
4622 async fn test_order_is_retained_when_spilling() -> Result<()> {
4623 let schema = Arc::new(Schema::new(vec![
4624 Field::new("a", DataType::Int64, false),
4625 Field::new("b", DataType::Int64, false),
4626 Field::new("c", DataType::Int64, false),
4627 ]));
4628
4629 let batches = vec![vec![
4630 RecordBatch::try_new(
4631 Arc::clone(&schema),
4632 vec![
4633 Arc::new(Int64Array::from(vec![2])),
4634 Arc::new(Int64Array::from(vec![2])),
4635 Arc::new(Int64Array::from(vec![1])),
4636 ],
4637 )?,
4638 RecordBatch::try_new(
4639 Arc::clone(&schema),
4640 vec![
4641 Arc::new(Int64Array::from(vec![1])),
4642 Arc::new(Int64Array::from(vec![1])),
4643 Arc::new(Int64Array::from(vec![1])),
4644 ],
4645 )?,
4646 RecordBatch::try_new(
4647 Arc::clone(&schema),
4648 vec![
4649 Arc::new(Int64Array::from(vec![0])),
4650 Arc::new(Int64Array::from(vec![0])),
4651 Arc::new(Int64Array::from(vec![1])),
4652 ],
4653 )?,
4654 ]];
4655 let scan = TestMemoryExec::try_new(&batches, Arc::clone(&schema), None)?;
4656 let scan = scan.try_with_sort_information(vec![
4657 LexOrdering::new([PhysicalSortExpr::new(
4658 col("b", schema.as_ref())?,
4659 SortOptions::default().desc(),
4660 )])
4661 .unwrap(),
4662 ])?;
4663
4664 let aggr = Arc::new(AggregateExec::try_new(
4665 AggregateMode::Single,
4666 PhysicalGroupBy::new(
4667 vec![
4668 (col("b", schema.as_ref())?, "b".to_string()),
4669 (col("c", schema.as_ref())?, "c".to_string()),
4670 ],
4671 vec![],
4672 vec![vec![false, false]],
4673 false,
4674 ),
4675 vec![Arc::new(
4676 AggregateExprBuilder::new(sum_udaf(), vec![col("c", schema.as_ref())?])
4677 .schema(Arc::clone(&schema))
4678 .alias("SUM(c)")
4679 .build()?,
4680 )],
4681 vec![None],
4682 Arc::new(scan) as Arc<dyn ExecutionPlan>,
4683 Arc::clone(&schema),
4684 )?);
4685
4686 let task_ctx = new_spill_ctx(1, 600);
4687 let result = collect(aggr.execute(0, Arc::clone(&task_ctx))?).await?;
4688 assert_spill_count_metric(true, aggr);
4689
4690 allow_duplicates! {
4691 assert_snapshot!(batches_to_string(&result), @r"
4692 +---+---+--------+
4693 | b | c | SUM(c) |
4694 +---+---+--------+
4695 | 2 | 1 | 1 |
4696 | 1 | 1 | 1 |
4697 | 0 | 1 | 1 |
4698 +---+---+--------+
4699 ");
4700 }
4701 Ok(())
4702 }
4703
4704 #[tokio::test]
4708 async fn test_sort_reservation_fails_during_spill() -> Result<()> {
4709 let schema = Arc::new(Schema::new(vec![
4710 Field::new("g", DataType::Int64, false),
4711 Field::new("a", DataType::Float64, false),
4712 Field::new("b", DataType::Float64, false),
4713 Field::new("c", DataType::Float64, false),
4714 Field::new("d", DataType::Float64, false),
4715 Field::new("e", DataType::Float64, false),
4716 ]));
4717
4718 let batches = vec![vec![
4719 RecordBatch::try_new(
4720 Arc::clone(&schema),
4721 vec![
4722 Arc::new(Int64Array::from(vec![1])),
4723 Arc::new(Float64Array::from(vec![10.0])),
4724 Arc::new(Float64Array::from(vec![20.0])),
4725 Arc::new(Float64Array::from(vec![30.0])),
4726 Arc::new(Float64Array::from(vec![40.0])),
4727 Arc::new(Float64Array::from(vec![50.0])),
4728 ],
4729 )?,
4730 RecordBatch::try_new(
4731 Arc::clone(&schema),
4732 vec![
4733 Arc::new(Int64Array::from(vec![2])),
4734 Arc::new(Float64Array::from(vec![11.0])),
4735 Arc::new(Float64Array::from(vec![21.0])),
4736 Arc::new(Float64Array::from(vec![31.0])),
4737 Arc::new(Float64Array::from(vec![41.0])),
4738 Arc::new(Float64Array::from(vec![51.0])),
4739 ],
4740 )?,
4741 RecordBatch::try_new(
4742 Arc::clone(&schema),
4743 vec![
4744 Arc::new(Int64Array::from(vec![3])),
4745 Arc::new(Float64Array::from(vec![12.0])),
4746 Arc::new(Float64Array::from(vec![22.0])),
4747 Arc::new(Float64Array::from(vec![32.0])),
4748 Arc::new(Float64Array::from(vec![42.0])),
4749 Arc::new(Float64Array::from(vec![52.0])),
4750 ],
4751 )?,
4752 ]];
4753
4754 let scan = TestMemoryExec::try_new(&batches, Arc::clone(&schema), None)?;
4755
4756 let aggr = Arc::new(AggregateExec::try_new(
4757 AggregateMode::Single,
4758 PhysicalGroupBy::new(
4759 vec![(col("g", schema.as_ref())?, "g".to_string())],
4760 vec![],
4761 vec![vec![false]],
4762 false,
4763 ),
4764 vec![
4765 Arc::new(
4766 AggregateExprBuilder::new(
4767 avg_udaf(),
4768 vec![col("a", schema.as_ref())?],
4769 )
4770 .schema(Arc::clone(&schema))
4771 .alias("AVG(a)")
4772 .build()?,
4773 ),
4774 Arc::new(
4775 AggregateExprBuilder::new(
4776 avg_udaf(),
4777 vec![col("b", schema.as_ref())?],
4778 )
4779 .schema(Arc::clone(&schema))
4780 .alias("AVG(b)")
4781 .build()?,
4782 ),
4783 Arc::new(
4784 AggregateExprBuilder::new(
4785 avg_udaf(),
4786 vec![col("c", schema.as_ref())?],
4787 )
4788 .schema(Arc::clone(&schema))
4789 .alias("AVG(c)")
4790 .build()?,
4791 ),
4792 Arc::new(
4793 AggregateExprBuilder::new(
4794 avg_udaf(),
4795 vec![col("d", schema.as_ref())?],
4796 )
4797 .schema(Arc::clone(&schema))
4798 .alias("AVG(d)")
4799 .build()?,
4800 ),
4801 Arc::new(
4802 AggregateExprBuilder::new(
4803 avg_udaf(),
4804 vec![col("e", schema.as_ref())?],
4805 )
4806 .schema(Arc::clone(&schema))
4807 .alias("AVG(e)")
4808 .build()?,
4809 ),
4810 ],
4811 vec![None, None, None, None, None],
4812 Arc::new(scan) as Arc<dyn ExecutionPlan>,
4813 Arc::clone(&schema),
4814 )?);
4815
4816 let task_ctx = new_spill_ctx(1, 500);
4819 let result = collect(aggr.execute(0, Arc::clone(&task_ctx))?).await;
4820
4821 match &result {
4822 Ok(_) => panic!("Expected ResourcesExhausted error but query succeeded"),
4823 Err(e) => {
4824 let root = e.find_root();
4825 assert!(
4826 matches!(root, DataFusionError::ResourcesExhausted(_)),
4827 "Expected ResourcesExhausted, got: {root}",
4828 );
4829 let msg = root.to_string();
4830 assert!(
4831 msg.contains("Failed to reserve memory for sort during spill"),
4832 "Expected sort reservation error, got: {msg}",
4833 );
4834 }
4835 }
4836
4837 Ok(())
4838 }
4839
4840 #[tokio::test]
4848 async fn test_partial_reduce_mode() -> Result<()> {
4849 let schema = Arc::new(Schema::new(vec![
4850 Field::new("a", DataType::UInt32, false),
4851 Field::new("b", DataType::Float64, false),
4852 ]));
4853
4854 let batch1 = RecordBatch::try_new(
4856 Arc::clone(&schema),
4857 vec![
4858 Arc::new(UInt32Array::from(vec![1, 2, 3])),
4859 Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0])),
4860 ],
4861 )?;
4862 let batch2 = RecordBatch::try_new(
4863 Arc::clone(&schema),
4864 vec![
4865 Arc::new(UInt32Array::from(vec![1, 2, 3])),
4866 Arc::new(Float64Array::from(vec![40.0, 50.0, 60.0])),
4867 ],
4868 )?;
4869
4870 let groups =
4871 PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);
4872 let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
4873 AggregateExprBuilder::new(sum_udaf(), vec![col("b", &schema)?])
4874 .schema(Arc::clone(&schema))
4875 .alias("SUM(b)")
4876 .build()?,
4877 )];
4878
4879 let input1 =
4881 TestMemoryExec::try_new_exec(&[vec![batch1]], Arc::clone(&schema), None)?;
4882 let partial1 = Arc::new(AggregateExec::try_new(
4883 AggregateMode::Partial,
4884 groups.clone(),
4885 aggregates.clone(),
4886 vec![None],
4887 input1,
4888 Arc::clone(&schema),
4889 )?);
4890
4891 let input2 =
4893 TestMemoryExec::try_new_exec(&[vec![batch2]], Arc::clone(&schema), None)?;
4894 let partial2 = Arc::new(AggregateExec::try_new(
4895 AggregateMode::Partial,
4896 groups.clone(),
4897 aggregates.clone(),
4898 vec![None],
4899 input2,
4900 Arc::clone(&schema),
4901 )?);
4902
4903 let task_ctx = Arc::new(TaskContext::default());
4905 let partial_result1 =
4906 crate::collect(Arc::clone(&partial1) as _, Arc::clone(&task_ctx)).await?;
4907 let partial_result2 =
4908 crate::collect(Arc::clone(&partial2) as _, Arc::clone(&task_ctx)).await?;
4909
4910 let partial_schema = partial1.schema();
4912
4913 let combined_input = TestMemoryExec::try_new_exec(
4915 &[partial_result1, partial_result2],
4916 Arc::clone(&partial_schema),
4917 None,
4918 )?;
4919 let coalesced = Arc::new(CoalescePartitionsExec::new(combined_input));
4921
4922 let partial_reduce = Arc::new(AggregateExec::try_new(
4923 AggregateMode::PartialReduce,
4924 groups.clone(),
4925 aggregates.clone(),
4926 vec![None],
4927 coalesced,
4928 Arc::clone(&partial_schema),
4929 )?);
4930
4931 assert_eq!(partial_reduce.schema(), partial_schema);
4934
4935 let reduce_result =
4937 crate::collect(Arc::clone(&partial_reduce) as _, Arc::clone(&task_ctx))
4938 .await?;
4939
4940 let final_input = TestMemoryExec::try_new_exec(
4942 &[reduce_result],
4943 Arc::clone(&partial_schema),
4944 None,
4945 )?;
4946 let final_agg = Arc::new(AggregateExec::try_new(
4947 AggregateMode::Final,
4948 groups.clone(),
4949 aggregates.clone(),
4950 vec![None],
4951 final_input,
4952 Arc::clone(&partial_schema),
4953 )?);
4954
4955 let result = crate::collect(final_agg, Arc::clone(&task_ctx)).await?;
4956
4957 assert_snapshot!(batches_to_sort_string(&result), @r"
4959 +---+--------+
4960 | a | SUM(b) |
4961 +---+--------+
4962 | 1 | 50.0 |
4963 | 2 | 70.0 |
4964 | 3 | 90.0 |
4965 +---+--------+
4966 ");
4967
4968 Ok(())
4969 }
4970
/// An ungrouped partial `MIN(a)` aggregate accepts a replacement dynamic
/// filter. The installed filter's current expression must match the
/// replacement, and a filter rebuilt via `with_new_children` must also be
/// accepted.
#[test]
fn test_with_dynamic_filter() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));

    let min_a = AggregateExprBuilder::new(min_udaf(), vec![col("a", &schema)?])
        .schema(Arc::clone(&schema))
        .alias("min_a")
        .build()?;
    let agg = AggregateExec::try_new(
        AggregateMode::Partial,
        PhysicalGroupBy::new_single(vec![]),
        vec![Arc::new(min_a)],
        vec![None],
        Arc::new(EmptyExec::new(Arc::clone(&schema))),
        Arc::clone(&schema),
    )?;

    let replacement = Arc::new(DynamicFilterPhysicalExpr::new(
        vec![col("a", &schema)?],
        lit(false),
    ));
    let agg = agg.with_dynamic_filter_expr(Arc::clone(&replacement))?;

    let installed = agg
        .dynamic_filter_expr()
        .expect("should still have dynamic filter")
        .current()?;
    assert_eq!(installed.to_string(), lit(false).to_string());

    // Re-bind the filter's children and make sure the result is still usable.
    let as_physical =
        Arc::<DynamicFilterPhysicalExpr>::clone(&replacement) as Arc<dyn PhysicalExpr>;
    let rebound = as_physical.with_new_children(vec![col("a", &schema)?])?;
    let rebound_df = match (rebound as Arc<dyn std::any::Any + Send + Sync>)
        .downcast::<DynamicFilterPhysicalExpr>()
    {
        Ok(df) => df,
        Err(_) => panic!("should be DynamicFilterPhysicalExpr after with_new_children"),
    };
    let _agg = agg.with_dynamic_filter_expr(rebound_df)?;

    Ok(())
}
5023
/// A `Final`-mode `SUM(b)` aggregate grouped on `a` has no dynamic filter, and
/// installing one must return an error.
#[test]
fn test_with_dynamic_filter_error_unsupported() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int64, false),
        Field::new("b", DataType::Int64, false),
    ]));

    let group_by =
        PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);
    let sum_b = AggregateExprBuilder::new(sum_udaf(), vec![col("b", &schema)?])
        .schema(Arc::clone(&schema))
        .alias("sum_b")
        .build()?;
    let agg = AggregateExec::try_new(
        AggregateMode::Final,
        group_by,
        vec![Arc::new(sum_b)],
        vec![None],
        Arc::new(EmptyExec::new(Arc::clone(&schema))),
        Arc::clone(&schema),
    )?;
    assert!(agg.dynamic_filter_expr().is_none());

    let filter = Arc::new(DynamicFilterPhysicalExpr::new(
        vec![col("a", &schema)?],
        lit(true),
    ));
    let outcome = agg.with_dynamic_filter_expr(filter);
    assert!(outcome.is_err());
    Ok(())
}
5056
/// Installing a dynamic filter whose child column (`bad` at index 99) does not
/// exist in the input schema must be rejected with an error.
#[test]
fn test_with_dynamic_filter_error_column_mismatch() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));

    let min_a = AggregateExprBuilder::new(min_udaf(), vec![col("a", &schema)?])
        .schema(Arc::clone(&schema))
        .alias("min_a")
        .build()?;
    let agg = AggregateExec::try_new(
        AggregateMode::Partial,
        PhysicalGroupBy::new_single(vec![]),
        vec![Arc::new(min_a)],
        vec![None],
        Arc::new(EmptyExec::new(Arc::clone(&schema))),
        Arc::clone(&schema),
    )?;

    let bogus_child: Arc<dyn PhysicalExpr> = Arc::new(Column::new("bad", 99));
    let filter = Arc::new(DynamicFilterPhysicalExpr::new(vec![bogus_child], lit(true)));
    let outcome = agg.with_dynamic_filter_expr(filter);
    assert!(outcome.is_err());
    Ok(())
}
5084}