use std::any::Any;
use std::collections::HashSet;
use std::sync::Arc;

use super::{DisplayAs, ExecutionPlanProperties, PlanProperties};
use crate::aggregates::{
    no_grouping::AggregateStream, row_hash::GroupedHashAggregateStream,
    topk_stream::GroupedTopKAggregateStream,
};
use crate::execution_plan::{CardinalityEffect, EmissionType};
use crate::filter_pushdown::{
    ChildFilterDescription, FilterDescription, FilterPushdownPhase, PushedDownPredicate,
};
use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
use crate::windows::get_ordered_partition_by_indices;
use crate::{
    DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode,
    SendableRecordBatchStream, Statistics,
};

use arrow::array::{ArrayRef, UInt16Array, UInt32Array, UInt64Array, UInt8Array};
use arrow::datatypes::{Field, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use arrow_schema::FieldRef;
use datafusion_common::config::ConfigOptions;
use datafusion_common::stats::Precision;
use datafusion_common::{internal_err, not_impl_err, Constraint, Constraints, Result};
use datafusion_execution::TaskContext;
use datafusion_expr::utils::AggregateOrderSensitivity;
use datafusion_expr::{Accumulator, Aggregate};
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
use datafusion_physical_expr::equivalence::ProjectionMapping;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::utils::collect_columns;
use datafusion_physical_expr::{
    physical_exprs_contains, ConstExpr, EquivalenceProperties,
};
use datafusion_physical_expr_common::physical_expr::{fmt_sql, PhysicalExpr};
use datafusion_physical_expr_common::sort_expr::{
    LexOrdering, LexRequirement, OrderingRequirements, PhysicalSortRequirement,
};
use itertools::Itertools;

pub mod group_values;
mod no_grouping;
pub mod order;
mod row_hash;
mod topk;
mod topk_stream;

/// Fixed seed for the hashing used by aggregation, so that group hashes are
/// reproducible across runs and partitions (the seed bytes spell "AGGR").
const AGGREGATION_HASH_SEED: ahash::RandomState =
    ahash::RandomState::with_seeds('A' as u64, 'G' as u64, 'G' as u64, 'R' as u64);

/// Hash aggregate modes.
///
/// DataFusion often splits one logical aggregation into several physical
/// stages (e.g. `Partial` followed by `Final`); this enum identifies which
/// stage an [`AggregateExec`] implements.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum AggregateMode {
    /// First stage of a multi-stage aggregation; runs in parallel on each
    /// input partition and emits intermediate accumulator state.
    Partial,
    /// Final stage that merges `Partial` outputs into finished aggregate
    /// values; requires a single input partition.
    Final,
    /// Final stage operating on inputs that are hash-partitioned by the group
    /// keys, so each partition sees all rows of its groups.
    FinalPartitioned,
    /// Entire aggregation in a single operator; requires a single input
    /// partition (like `Final`).
    Single,
    /// Entire aggregation in a single operator on inputs hash-partitioned by
    /// the group keys (like `FinalPartitioned`).
    SinglePartitioned,
}

impl AggregateMode {
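    /// Checks whether this aggregation step reads raw input rows (first
    /// stage) or intermediate accumulator state produced by an earlier stage.
    ///
    /// A minimal sketch of the staging rule (not compiled as a doctest):
    ///
    /// ```ignore
    /// assert!(AggregateMode::Partial.is_first_stage());
    /// assert!(AggregateMode::Single.is_first_stage());
    /// assert!(!AggregateMode::Final.is_first_stage());
    /// assert!(!AggregateMode::FinalPartitioned.is_first_stage());
    /// ```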
    pub fn is_first_stage(&self) -> bool {
        match self {
            AggregateMode::Partial
            | AggregateMode::Single
            | AggregateMode::SinglePartitioned => true,
            AggregateMode::Final | AggregateMode::FinalPartitioned => false,
        }
    }
}

/// Represents the `GROUP BY` clause in the plan (including the more general
/// `GROUPING SET`s).
///
/// In a plain `GROUP BY a, b` there is a single group and `null_expr` is
/// empty. With `GROUPING SETS` / `ROLLUP` / `CUBE`, each input row can
/// contribute to several groups, and `groups` records which expressions are
/// replaced by null in each grouping set.
#[derive(Clone, Debug, Default)]
pub struct PhysicalGroupBy {
    /// Distinct (physical expr, alias) in the grouping set
    expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
    /// Corresponding NULL expressions for `expr`
    null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
    /// Null mask for each group in this grouping set. Each group is composed
    /// of either one of the group expressions in `expr` or a null expression
    /// in `null_expr`. If `groups[i][j]` is true, then the grouping
    /// expression `expr[j]` is null in group `i`.
    groups: Vec<Vec<bool>>,
}

impl PhysicalGroupBy {
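    /// Create a new `PhysicalGroupBy`.
    ///
    /// A sketch for `GROUP BY GROUPING SETS ((a), (b))`; the `col_*` and
    /// `null_*` expressions are assumed to exist (not compiled as a doctest):
    ///
    /// ```ignore
    /// let group_by = PhysicalGroupBy::new(
    ///     vec![(col_a, "a".to_string()), (col_b, "b".to_string())],
    ///     vec![(null_a, "a".to_string()), (null_b, "b".to_string())],
    ///     vec![
    ///         vec![false, true], // (a, NULL)
    ///         vec![true, false], // (NULL, b)
    ///     ],
    /// );
    /// assert!(!group_by.is_single());
    /// ```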
    pub fn new(
        expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
        null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
        groups: Vec<Vec<bool>>,
    ) -> Self {
        Self {
            expr,
            null_expr,
            groups,
        }
    }

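    /// Create a grouping set with only the given group by expressions, i.e. a
    /// plain `GROUP BY a, b, ...` with no null masks.
    ///
    /// A minimal sketch (column index `0` is an assumption; not compiled as a
    /// doctest):
    ///
    /// ```ignore
    /// use datafusion_physical_expr::expressions::Column;
    /// let group_by = PhysicalGroupBy::new_single(vec![
    ///     (Arc::new(Column::new("a", 0)) as _, "a".to_string()),
    /// ]);
    /// assert!(group_by.is_single());
    /// ```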
    pub fn new_single(expr: Vec<(Arc<dyn PhysicalExpr>, String)>) -> Self {
        let num_exprs = expr.len();
        Self {
            expr,
            null_expr: vec![],
            groups: vec![vec![false; num_exprs]],
        }
    }

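    /// Returns, for each group expression, whether it is masked to null in at
    /// least one grouping set (and must therefore be nullable in the output).
    ///
    /// A sketch for `GROUPING SETS ((a), (b))` (not compiled as a doctest):
    ///
    /// ```ignore
    /// // groups = [[false, true], [true, false]]
    /// assert_eq!(group_by.exprs_nullable(), vec![true, true]);
    /// ```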
    pub fn exprs_nullable(&self) -> Vec<bool> {
        let mut exprs_nullable = vec![false; self.expr.len()];
        for group in self.groups.iter() {
            group.iter().enumerate().for_each(|(index, is_null)| {
                if *is_null {
                    exprs_nullable[index] = true;
                }
            })
        }
        exprs_nullable
    }

    /// Returns the group expressions
    pub fn expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
        &self.expr
    }

    /// Returns the null expressions
    pub fn null_expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
        &self.null_expr
    }

    /// Returns the null masks of the grouping sets
    pub fn groups(&self) -> &[Vec<bool>] {
        &self.groups
    }

    /// Returns true if this `PhysicalGroupBy` has no group expressions
    pub fn is_empty(&self) -> bool {
        self.expr.is_empty()
    }

    /// Check whether this grouping set consists of a single group
    pub fn is_single(&self) -> bool {
        self.null_expr.is_empty()
    }

    /// Calculate GROUP BY expressions as evaluated against the input schema
    pub fn input_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        self.expr
            .iter()
            .map(|(expr, _alias)| Arc::clone(expr))
            .collect()
    }

    /// The number of expressions in the output schema
    fn num_output_exprs(&self) -> usize {
        let mut num_exprs = self.expr.len();
        if !self.is_single() {
            num_exprs += 1
        }
        num_exprs
    }

    /// Return grouping expressions as they occur in the output schema
    pub fn output_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        let num_output_exprs = self.num_output_exprs();
        let mut output_exprs = Vec::with_capacity(num_output_exprs);
        output_exprs.extend(
            self.expr
                .iter()
                .enumerate()
                .take(num_output_exprs)
                .map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _),
        );
        if !self.is_single() {
            output_exprs.push(Arc::new(Column::new(
                Aggregate::INTERNAL_GROUPING_ID,
                self.expr.len(),
            )) as _);
        }
        output_exprs
    }

    /// Returns the number of group expressions in the group schema, including
    /// the synthetic `__grouping_id` column for multi-set group bys
    pub fn num_group_exprs(&self) -> usize {
        if self.is_single() {
            self.expr.len()
        } else {
            self.expr.len() + 1
        }
    }

    pub fn group_schema(&self, schema: &Schema) -> Result<SchemaRef> {
        Ok(Arc::new(Schema::new(self.group_fields(schema)?)))
    }

    /// Returns the fields that are used as the grouping keys
    fn group_fields(&self, input_schema: &Schema) -> Result<Vec<FieldRef>> {
        let mut fields = Vec::with_capacity(self.num_group_exprs());
        for ((expr, name), group_expr_nullable) in
            self.expr.iter().zip(self.exprs_nullable().into_iter())
        {
            fields.push(
                Field::new(
                    name,
                    expr.data_type(input_schema)?,
                    group_expr_nullable || expr.nullable(input_schema)?,
                )
                .with_metadata(expr.return_field(input_schema)?.metadata().clone())
                .into(),
            );
        }
        if !self.is_single() {
            fields.push(
                Field::new(
                    Aggregate::INTERNAL_GROUPING_ID,
                    Aggregate::grouping_id_type(self.expr.len()),
                    false,
                )
                .into(),
            );
        }
        Ok(fields)
    }

    /// Returns the group by fields as they appear in the output schema
    fn output_fields(&self, input_schema: &Schema) -> Result<Vec<FieldRef>> {
        let mut fields = self.group_fields(input_schema)?;
        fields.truncate(self.num_output_exprs());
        Ok(fields)
    }

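    /// Returns the `PhysicalGroupBy` for a final aggregation when `self` is
    /// used for the corresponding partial aggregation: each group expression
    /// (plus the internal grouping id, if any) is re-read as a column of the
    /// partial output.
    ///
    /// A sketch of the two-stage wiring (hypothetical `partial_group_by`; not
    /// compiled as a doctest):
    ///
    /// ```ignore
    /// let final_group_by = partial_group_by.as_final();
    /// // Null masks were already applied in the partial stage:
    /// assert!(final_group_by.is_single());
    /// ```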
    pub fn as_final(&self) -> PhysicalGroupBy {
        let expr: Vec<_> =
            self.output_exprs()
                .into_iter()
                .zip(
                    self.expr.iter().map(|t| t.1.clone()).chain(std::iter::once(
                        Aggregate::INTERNAL_GROUPING_ID.to_owned(),
                    )),
                )
                .collect();
        let num_exprs = expr.len();
        let groups = if self.expr.is_empty() {
            vec![]
        } else {
            vec![vec![false; num_exprs]]
        };
        Self {
            expr,
            null_expr: vec![],
            groups,
        }
    }
}

impl PartialEq for PhysicalGroupBy {
    fn eq(&self, other: &PhysicalGroupBy) -> bool {
        self.expr.len() == other.expr.len()
            && self
                .expr
                .iter()
                .zip(other.expr.iter())
                .all(|((expr1, name1), (expr2, name2))| expr1.eq(expr2) && name1 == name2)
            && self.null_expr.len() == other.null_expr.len()
            && self
                .null_expr
                .iter()
                .zip(other.null_expr.iter())
                .all(|((expr1, name1), (expr2, name2))| expr1.eq(expr2) && name1 == name2)
            && self.groups == other.groups
    }
}

/// The stream variants an [`AggregateExec`] can produce, depending on its
/// group by and limit; see `AggregateExec::execute_typed`.
#[allow(clippy::large_enum_variant)]
enum StreamType {
    AggregateStream(AggregateStream),
    GroupedHash(GroupedHashAggregateStream),
    GroupedPriorityQueue(GroupedTopKAggregateStream),
}

impl From<StreamType> for SendableRecordBatchStream {
    fn from(stream: StreamType) -> Self {
        match stream {
            StreamType::AggregateStream(stream) => Box::pin(stream),
            StreamType::GroupedHash(stream) => Box::pin(stream),
            StreamType::GroupedPriorityQueue(stream) => Box::pin(stream),
        }
    }
}

/// Hash aggregate execution plan
#[derive(Debug, Clone)]
pub struct AggregateExec {
    /// Aggregation mode (full, partial)
    mode: AggregateMode,
    /// Group by expressions
    group_by: PhysicalGroupBy,
    /// Aggregate expressions
    aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
    /// FILTER (WHERE clause) expression for each aggregate expression
    filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
    /// Set if the output of this aggregation is truncated by an upstream
    /// sort/limit clause
    limit: Option<usize>,
    /// Input plan, could be a partial aggregate or the input to the aggregate
    pub input: Arc<dyn ExecutionPlan>,
    /// Schema after the aggregate is applied
    schema: SchemaRef,
    /// Input schema before any aggregation is applied. For partial aggregates
    /// this is the same as `input.schema()`, but for the final aggregate it is
    /// the input schema of the corresponding partial aggregate, i.e. partial
    /// and final aggregates share the same `input_schema`.
    pub input_schema: SchemaRef,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// The ordering requirement (if any) this aggregation places on its input
    required_input_ordering: Option<OrderingRequirements>,
    /// Describes how the input is ordered relative to the group by columns
    input_order_mode: InputOrderMode,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
}

impl AggregateExec {
    /// Returns a copy of this `AggregateExec` with the given aggregate
    /// expressions substituted in; all other fields are cloned and the
    /// metrics are reset. Used by optimizer rules that swap in equivalent
    /// (e.g. reversed) aggregate expressions.
    pub fn with_new_aggr_exprs(
        &self,
        aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
    ) -> Self {
        Self {
            aggr_expr,
            // Clone the rest of the fields:
            required_input_ordering: self.required_input_ordering.clone(),
            metrics: ExecutionPlanMetricsSet::new(),
            input_order_mode: self.input_order_mode.clone(),
            cache: self.cache.clone(),
            mode: self.mode,
            group_by: self.group_by.clone(),
            filter_expr: self.filter_expr.clone(),
            limit: self.limit,
            input: Arc::clone(&self.input),
            schema: Arc::clone(&self.schema),
            input_schema: Arc::clone(&self.input_schema),
        }
    }

    /// Returns the cached plan properties
    pub fn cache(&self) -> &PlanProperties {
        &self.cache
    }

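    /// Create a new hash aggregate execution plan.
    ///
    /// A sketch of the usual two-stage construction, mirroring the tests
    /// below (`input`, `group_by` and `aggr_expr` are assumed to exist; not
    /// compiled as a doctest):
    ///
    /// ```ignore
    /// let partial = Arc::new(AggregateExec::try_new(
    ///     AggregateMode::Partial,
    ///     group_by.clone(),
    ///     aggr_expr.clone(),
    ///     vec![None], // one optional FILTER expression per aggregate
    ///     input,
    ///     Arc::clone(&input_schema),
    /// )?);
    /// let final_agg = AggregateExec::try_new(
    ///     AggregateMode::Final,
    ///     group_by.as_final(),
    ///     aggr_expr,
    ///     vec![None],
    ///     partial,
    ///     input_schema,
    /// )?;
    /// ```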
    pub fn try_new(
        mode: AggregateMode,
        group_by: PhysicalGroupBy,
        aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
        filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
        input: Arc<dyn ExecutionPlan>,
        input_schema: SchemaRef,
    ) -> Result<Self> {
        let schema = create_schema(&input.schema(), &group_by, &aggr_expr, mode)?;

        let schema = Arc::new(schema);
        AggregateExec::try_new_with_schema(
            mode,
            group_by,
            aggr_expr,
            filter_expr,
            input,
            input_schema,
            schema,
        )
    }

    /// Create a new hash aggregate execution plan with a predetermined schema.
    ///
    /// This is not part of the public API; it lets optimizer rules rebuild an
    /// `AggregateExec` without recomputing (and possibly changing) its schema.
    #[allow(clippy::too_many_arguments)]
    fn try_new_with_schema(
        mode: AggregateMode,
        group_by: PhysicalGroupBy,
        mut aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
        filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
        input: Arc<dyn ExecutionPlan>,
        input_schema: SchemaRef,
        schema: SchemaRef,
    ) -> Result<Self> {
        // Make sure arguments are consistent in size:
        if aggr_expr.len() != filter_expr.len() {
            return internal_err!("Inconsistent aggregate expr: {:?} and filter expr: {:?} for AggregateExec, their size should match", aggr_expr, filter_expr);
        }

        let input_eq_properties = input.equivalence_properties();
        let groupby_exprs = group_by.input_exprs();
        // If the existing ordering satisfies a prefix of the GROUP BY
        // expressions, prefix the requirements with this section. In this
        // case, the aggregation will work more efficiently.
        let indices = get_ordered_partition_by_indices(&groupby_exprs, &input)?;
        let mut new_requirements = indices
            .iter()
            .map(|&idx| {
                PhysicalSortRequirement::new(Arc::clone(&groupby_exprs[idx]), None)
            })
            .collect::<Vec<_>>();

        let req = get_finer_aggregate_exprs_requirement(
            &mut aggr_expr,
            &group_by,
            input_eq_properties,
            &mode,
        )?;
        new_requirements.extend(req);

        let required_input_ordering =
            LexRequirement::new(new_requirements).map(OrderingRequirements::new_soft);

        // If the aggregation has grouping sets, the group expressions are
        // expanded based on the flags in `group_by.groups`, where we swap a
        // grouping expression for `null` when its flag is `true`. An ordered
        // index is therefore only usable if it is never masked by a null
        // expression in any group:
        let indices: Vec<usize> = indices
            .into_iter()
            .filter(|idx| group_by.groups.iter().all(|group| !group[*idx]))
            .collect();

        let input_order_mode = if indices.len() == groupby_exprs.len()
            && !indices.is_empty()
            && group_by.groups.len() == 1
        {
            InputOrderMode::Sorted
        } else if !indices.is_empty() {
            InputOrderMode::PartiallySorted(indices)
        } else {
            InputOrderMode::Linear
        };

        // Construct a map from the input expressions to the output expressions
        // of the aggregation's GROUP BY:
        let group_expr_mapping =
            ProjectionMapping::try_new(group_by.expr.clone(), &input.schema())?;

        let cache = Self::compute_properties(
            &input,
            Arc::clone(&schema),
            &group_expr_mapping,
            &mode,
            &input_order_mode,
            aggr_expr.as_slice(),
        )?;

        Ok(AggregateExec {
            mode,
            group_by,
            aggr_expr,
            filter_expr,
            input,
            schema,
            input_schema,
            metrics: ExecutionPlanMetricsSet::new(),
            required_input_ordering,
            limit: None,
            input_order_mode,
            cache,
        })
    }

    /// Aggregation mode (full, partial)
    pub fn mode(&self) -> &AggregateMode {
        &self.mode
    }

    /// Set the `limit` of this AggExec
    pub fn with_limit(mut self, limit: Option<usize>) -> Self {
        self.limit = limit;
        self
    }

    /// Grouping expressions
    pub fn group_expr(&self) -> &PhysicalGroupBy {
        &self.group_by
    }

    /// Grouping expressions as they occur in the output schema
    pub fn output_group_expr(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        self.group_by.output_exprs()
    }

    /// Aggregate expressions
    pub fn aggr_expr(&self) -> &[Arc<AggregateFunctionExpr>] {
        &self.aggr_expr
    }

    /// FILTER (WHERE clause) expression for each aggregate expression
    pub fn filter_expr(&self) -> &[Option<Arc<dyn PhysicalExpr>>] {
        &self.filter_expr
    }

    /// Input plan
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

    /// Get the input schema before any aggregates are applied
    pub fn input_schema(&self) -> SchemaRef {
        Arc::clone(&self.input_schema)
    }

    /// The output `LIMIT`, if any, pushed down into this aggregation
    pub fn limit(&self) -> Option<usize> {
        self.limit
    }

    fn execute_typed(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<StreamType> {
        // No group by at all:
        if self.group_by.expr.is_empty() {
            return Ok(StreamType::AggregateStream(AggregateStream::new(
                self, context, partition,
            )?));
        }

        // Grouping by an expression that has a sort/limit upstream:
        if let Some(limit) = self.limit {
            if !self.is_unordered_unfiltered_group_by_distinct() {
                return Ok(StreamType::GroupedPriorityQueue(
                    GroupedTopKAggregateStream::new(self, context, partition, limit)?,
                ));
            }
        }

        // Grouping by something else; we need to materialize all results:
        Ok(StreamType::GroupedHash(GroupedHashAggregateStream::new(
            self, context, partition,
        )?))
    }

    /// Finds the DataType and SortDirection for this Aggregate, if there is one
    pub fn get_minmax_desc(&self) -> Option<(FieldRef, bool)> {
        let agg_expr = self.aggr_expr.iter().exactly_one().ok()?;
        agg_expr.get_minmax_desc()
    }

    /// true, if this Aggregate has a group-by with no required or explicit
    /// ordering, no filtering and no aggregate expressions. This qualifies the
    /// plan for the limited distinct aggregation optimization.
    pub fn is_unordered_unfiltered_group_by_distinct(&self) -> bool {
        // Ensure there is a group by:
        if self.group_expr().is_empty() {
            return false;
        }
        // Ensure there are no aggregate expressions:
        if !self.aggr_expr().is_empty() {
            return false;
        }
        // Ensure there are no filters on aggregate expressions; the above
        // check may preclude this case:
        if self.filter_expr().iter().any(|e| e.is_some()) {
            return false;
        }
        // Ensure there are no order by expressions:
        if !self.aggr_expr().iter().all(|e| e.order_bys().is_empty()) {
            return false;
        }
        // Ensure there is no output ordering:
        if self.properties().output_ordering().is_some() {
            return false;
        }
        // Ensure no hard ordering is required on the input; a soft
        // (best-effort) requirement does not disqualify the plan:
        if let Some(requirement) = self.required_input_ordering().swap_remove(0) {
            return !matches!(requirement, OrderingRequirements::Hard(_));
        }
        true
    }

    /// This function creates the cache object that stores the plan properties
    /// such as equivalence properties, output partitioning, emission type and
    /// boundedness.
    pub fn compute_properties(
        input: &Arc<dyn ExecutionPlan>,
        schema: SchemaRef,
        group_expr_mapping: &ProjectionMapping,
        mode: &AggregateMode,
        input_order_mode: &InputOrderMode,
        aggr_exprs: &[Arc<AggregateFunctionExpr>],
    ) -> Result<PlanProperties> {
        // Construct equivalence properties:
        let mut eq_properties = input
            .equivalence_properties()
            .project(group_expr_mapping, schema);

        // If the group by is empty, then we ensure that the operator will
        // produce only one row, and mark the generated result as a constant:
        if group_expr_mapping.is_empty() {
            let new_constants = aggr_exprs.iter().enumerate().map(|(idx, func)| {
                let column = Arc::new(Column::new(func.name(), idx));
                ConstExpr::from(column as Arc<dyn PhysicalExpr>)
            });
            eq_properties.add_constants(new_constants)?;
        }

        // Group by expressions will have distinct values after the
        // aggregation; add this fact to the constraint set:
        let mut constraints = eq_properties.constraints().to_vec();
        let new_constraint = Constraint::Unique(
            group_expr_mapping
                .iter()
                .flat_map(|(_, target_cols)| {
                    target_cols.iter().flat_map(|(expr, _)| {
                        expr.as_any().downcast_ref::<Column>().map(|c| c.index())
                    })
                })
                .collect(),
        );
        constraints.push(new_constraint);
        eq_properties =
            eq_properties.with_constraints(Constraints::new_unverified(constraints));

        // Get output partitioning:
        let input_partitioning = input.output_partitioning().clone();
        let output_partitioning = if mode.is_first_stage() {
            // First stage aggregation will not change the output partitioning,
            // but needs to respect aliases (e.g. mapping in the GROUP BY
            // expression):
            let input_eq_properties = input.equivalence_properties();
            input_partitioning.project(group_expr_mapping, input_eq_properties)
        } else {
            input_partitioning
        };

        // A `Linear` input order means the input is unsorted w.r.t. the group
        // keys, so all groups stay open until the input is exhausted and
        // results can only be emitted at the end:
        let emission_type = if *input_order_mode == InputOrderMode::Linear {
            EmissionType::Final
        } else {
            input.pipeline_behavior()
        };

        Ok(PlanProperties::new(
            eq_properties,
            output_partitioning,
            emission_type,
            input.boundedness(),
        ))
    }

    /// The ordering of the input relative to the GROUP BY columns
    pub fn input_order_mode(&self) -> &InputOrderMode {
        &self.input_order_mode
    }

    fn statistics_inner(&self, child_statistics: Statistics) -> Result<Statistics> {
        // TODO stats: group expressions:
        // - once expressions can compute their own stats, use them here
        // - case where we group by a column for which we have a `distinct` stat
        // TODO stats: aggr expressions:
        // - aggregations sometimes also preserve invariants such as min, max...
        let column_statistics = {
            let mut column_statistics = Statistics::unknown_column(&self.schema());

            // Copy the min/max statistics of the group by columns from the
            // input, since grouping cannot widen their value range:
            for (idx, (expr, _)) in self.group_by.expr.iter().enumerate() {
                if let Some(col) = expr.as_any().downcast_ref::<Column>() {
                    column_statistics[idx].max_value = child_statistics.column_statistics
                        [col.index()]
                    .max_value
                    .clone();

                    column_statistics[idx].min_value = child_statistics.column_statistics
                        [col.index()]
                    .min_value
                    .clone();
                }
            }

            column_statistics
        };
        match self.mode {
            AggregateMode::Final | AggregateMode::FinalPartitioned
                if self.group_by.expr.is_empty() =>
            {
                Ok(Statistics {
                    num_rows: Precision::Exact(1),
                    column_statistics,
                    total_byte_size: Precision::Absent,
                })
            }
            _ => {
                // When the input row count is 0 or 1, we can adopt that
                // statistic and keep its reliability. When it is larger than
                // 1, we degrade the precision since the count may decrease
                // after aggregation:
                let num_rows = if let Some(value) = child_statistics.num_rows.get_value()
                {
                    if *value > 1 {
                        child_statistics.num_rows.to_inexact()
                    } else if *value == 0 {
                        child_statistics.num_rows
                    } else {
                        // The input has one row; each grouping set emits one group:
                        let grouping_set_num = self.group_by.groups.len();
                        child_statistics.num_rows.map(|x| x * grouping_set_num)
                    }
                } else {
                    Precision::Absent
                };
                Ok(Statistics {
                    num_rows,
                    column_statistics,
                    total_byte_size: Precision::Absent,
                })
            }
        }
    }
}

impl DisplayAs for AggregateExec {
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                let format_expr_with_alias =
                    |(e, alias): &(Arc<dyn PhysicalExpr>, String)| -> String {
                        let e = e.to_string();
                        if &e != alias {
                            format!("{e} as {alias}")
                        } else {
                            e
                        }
                    };

                write!(f, "AggregateExec: mode={:?}", self.mode)?;
                let g: Vec<String> = if self.group_by.is_single() {
                    self.group_by
                        .expr
                        .iter()
                        .map(format_expr_with_alias)
                        .collect()
                } else {
                    self.group_by
                        .groups
                        .iter()
                        .map(|group| {
                            let terms = group
                                .iter()
                                .enumerate()
                                .map(|(idx, is_null)| {
                                    if *is_null {
                                        format_expr_with_alias(
                                            &self.group_by.null_expr[idx],
                                        )
                                    } else {
                                        format_expr_with_alias(&self.group_by.expr[idx])
                                    }
                                })
                                .collect::<Vec<String>>()
                                .join(", ");
                            format!("({terms})")
                        })
                        .collect()
                };

                write!(f, ", gby=[{}]", g.join(", "))?;

                let a: Vec<String> = self
                    .aggr_expr
                    .iter()
                    .map(|agg| agg.name().to_string())
                    .collect();
                write!(f, ", aggr=[{}]", a.join(", "))?;
                if let Some(limit) = self.limit {
                    write!(f, ", lim=[{limit}]")?;
                }

                if self.input_order_mode != InputOrderMode::Linear {
                    write!(f, ", ordering_mode={:?}", self.input_order_mode)?;
                }
            }
            DisplayFormatType::TreeRender => {
                let format_expr_with_alias =
                    |(e, alias): &(Arc<dyn PhysicalExpr>, String)| -> String {
                        let expr_sql = fmt_sql(e.as_ref()).to_string();
                        if &expr_sql != alias {
                            format!("{expr_sql} as {alias}")
                        } else {
                            expr_sql
                        }
                    };

                let g: Vec<String> = if self.group_by.is_single() {
                    self.group_by
                        .expr
                        .iter()
                        .map(format_expr_with_alias)
                        .collect()
                } else {
                    self.group_by
                        .groups
                        .iter()
                        .map(|group| {
                            let terms = group
                                .iter()
                                .enumerate()
                                .map(|(idx, is_null)| {
                                    if *is_null {
                                        format_expr_with_alias(
                                            &self.group_by.null_expr[idx],
                                        )
                                    } else {
                                        format_expr_with_alias(&self.group_by.expr[idx])
                                    }
                                })
                                .collect::<Vec<String>>()
                                .join(", ");
                            format!("({terms})")
                        })
                        .collect()
                };
                let a: Vec<String> = self
                    .aggr_expr
                    .iter()
                    .map(|agg| agg.human_display().to_string())
                    .collect();
                writeln!(f, "mode={:?}", self.mode)?;
                if !g.is_empty() {
                    writeln!(f, "group_by={}", g.join(", "))?;
                }
                if !a.is_empty() {
                    writeln!(f, "aggr={}", a.join(", "))?;
                }
            }
        }
        Ok(())
    }
}

impl ExecutionPlan for AggregateExec {
    fn name(&self) -> &'static str {
        "AggregateExec"
    }

    /// Return a reference to Any that can be used for down-casting
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        match &self.mode {
            AggregateMode::Partial => {
                vec![Distribution::UnspecifiedDistribution]
            }
            AggregateMode::FinalPartitioned | AggregateMode::SinglePartitioned => {
                vec![Distribution::HashPartitioned(self.group_by.input_exprs())]
            }
            AggregateMode::Final | AggregateMode::Single => {
                vec![Distribution::SinglePartition]
            }
        }
    }

    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
        vec![self.required_input_ordering.clone()]
    }

    /// The output keeps the relative order of its input only when the input
    /// is (at least partially) sorted on the group keys; in `Linear` mode
    /// groups are emitted in arbitrary order.
    fn maintains_input_order(&self) -> Vec<bool> {
        vec![self.input_order_mode != InputOrderMode::Linear]
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let mut me = AggregateExec::try_new_with_schema(
            self.mode,
            self.group_by.clone(),
            self.aggr_expr.clone(),
            self.filter_expr.clone(),
            Arc::clone(&children[0]),
            Arc::clone(&self.input_schema),
            Arc::clone(&self.schema),
        )?;
        me.limit = self.limit;

        Ok(Arc::new(me))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        self.execute_typed(partition, context)
            .map(|stream| stream.into())
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        self.partition_statistics(None)
    }

    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
        self.statistics_inner(self.input().partition_statistics(partition)?)
    }

    /// An aggregation never emits more rows than its input
    fn cardinality_effect(&self) -> CardinalityEffect {
        CardinalityEffect::LowerEqual
    }

    fn gather_filters_for_pushdown(
        &self,
        _phase: FilterPushdownPhase,
        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
        _config: &ConfigOptions,
    ) -> Result<FilterDescription> {
        // A filter can only be pushed below the aggregation if it references
        // grouping columns exclusively: such a filter removes entire groups
        // without changing any aggregate value. Collect the columns used by
        // the GROUP BY expressions first:
        let grouping_columns: HashSet<_> = self
            .group_by
            .expr()
            .iter()
            .flat_map(|(expr, _)| collect_columns(expr))
            .collect();

        // Partition the parent filters into those safe to push down and those
        // that must stay above this aggregation:
        let mut safe_filters = Vec::new();
        let mut unsafe_filters = Vec::new();

        for filter in parent_filters {
            let filter_columns: HashSet<_> =
                collect_columns(&filter).into_iter().collect();

            // A filter that references non-grouping columns (e.g. aggregate
            // results) cannot be pushed below the aggregation:
            let references_non_grouping = !grouping_columns.is_empty()
                && !filter_columns.is_subset(&grouping_columns);

            if references_non_grouping {
                unsafe_filters.push(filter);
                continue;
            }

            // With multiple grouping sets, a grouping column can be replaced
            // by NULL in some sets; pushing a filter on such a column below
            // the aggregation would incorrectly drop those null-masked rows:
            if self.group_by.groups().len() > 1 {
                // Map each filter column to the index of the group expression
                // that produces it:
                let filter_column_indices: Vec<usize> = filter_columns
                    .iter()
                    .filter_map(|filter_col| {
                        self.group_by.expr().iter().position(|(expr, _)| {
                            collect_columns(expr).contains(filter_col)
                        })
                    })
                    .collect();

                // Is any referenced column null-masked in some grouping set?
                let has_missing_column = self.group_by.groups().iter().any(|null_mask| {
                    filter_column_indices
                        .iter()
                        .any(|&idx| null_mask.get(idx) == Some(&true))
                });

                if has_missing_column {
                    unsafe_filters.push(filter);
                    continue;
                }
            }

            safe_filters.push(filter);
        }

        let child = self.children()[0];
        let mut child_desc = ChildFilterDescription::from_child(&safe_filters, child)?;

        // Report the filters we keep as unsupported so that they remain
        // evaluated above this node:
        child_desc.parent_filters.extend(
            unsafe_filters
                .into_iter()
                .map(PushedDownPredicate::unsupported),
        );

        Ok(FilterDescription::new().with_child(child_desc))
    }
}

/// Creates the output schema of an [`AggregateExec`] for the given mode.
fn create_schema(
    input_schema: &Schema,
    group_by: &PhysicalGroupBy,
    aggr_expr: &[Arc<AggregateFunctionExpr>],
    mode: AggregateMode,
) -> Result<Schema> {
    let mut fields = Vec::with_capacity(group_by.num_output_exprs() + aggr_expr.len());
    fields.extend(group_by.output_fields(input_schema)?);

    match mode {
        AggregateMode::Partial => {
            // In partial mode, the output is the intermediate accumulator state:
            for expr in aggr_expr {
                fields.extend(expr.state_fields()?.iter().cloned());
            }
        }
        AggregateMode::Final
        | AggregateMode::FinalPartitioned
        | AggregateMode::Single
        | AggregateMode::SinglePartitioned => {
            // In all other modes, the output is the final aggregate value:
            for expr in aggr_expr {
                fields.push(expr.field())
            }
        }
    }

    Ok(Schema::new_with_metadata(
        fields,
        input_schema.metadata().clone(),
    ))
}

/// Determines the lexical sort requirement for an aggregate expression.
///
/// Returns `None` if the aggregate has no ordering requirement in the given
/// mode, or if its requirement is fully implied by the GROUP BY.
fn get_aggregate_expr_req(
    aggr_expr: &AggregateFunctionExpr,
    group_by: &PhysicalGroupBy,
    agg_mode: &AggregateMode,
    include_soft_requirement: bool,
) -> Option<LexOrdering> {
    // If the aggregation is performing a "second stage" calculation, ignore
    // the ordering requirement: orderings only matter while accumulating raw
    // input rows in the first stage.
    if !agg_mode.is_first_stage() {
        return None;
    }

    match aggr_expr.order_sensitivity() {
        AggregateOrderSensitivity::Insensitive => return None,
        AggregateOrderSensitivity::HardRequirement => {}
        AggregateOrderSensitivity::SoftRequirement => {
            if !include_soft_requirement {
                return None;
            }
        }
        AggregateOrderSensitivity::Beneficial => return None,
    }

    let mut sort_exprs = aggr_expr.order_bys().to_vec();
    if group_by.is_single() {
        // Remove all orderings that occur in the GROUP BY. These requirements
        // are satisfied trivially: each GROUP BY expression has a single
        // distinct value within every group.
        let physical_exprs = group_by.input_exprs();
        sort_exprs.retain(|sort_expr| {
            !physical_exprs_contains(&physical_exprs, &sort_expr.expr)
        });
    }
    LexOrdering::new(sort_exprs)
}

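/// Concatenates the given slices into a new `Vec`.
///
/// A trivial usage sketch (not compiled as a doctest):
///
/// ```ignore
/// assert_eq!(concat_slices(&[1, 2], &[3]), vec![1, 2, 3]);
/// ```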
pub fn concat_slices<T: Clone>(lhs: &[T], rhs: &[T]) -> Vec<T> {
    [lhs, rhs].concat()
}

/// Returns `Some(true)` if `candidate` is finer than `current`, `Some(false)`
/// if `current` is at least as fine, and `None` if the two orderings are
/// incomparable. An absent `current` is refined by any candidate.
fn determine_finer(
    current: &Option<LexOrdering>,
    candidate: &LexOrdering,
) -> Option<bool> {
    if let Some(ordering) = current {
        candidate.partial_cmp(ordering).map(|cmp| cmp.is_gt())
    } else {
        Some(true)
    }
}

/// Gets the finest common sort requirement that satisfies all the given
/// aggregate expressions, reversing order-sensitive aggregates in place when
/// the reverse direction lets a single ordering serve them all.
///
/// Returns an error if hard requirements conflict.
pub fn get_finer_aggregate_exprs_requirement(
    aggr_exprs: &mut [Arc<AggregateFunctionExpr>],
    group_by: &PhysicalGroupBy,
    eq_properties: &EquivalenceProperties,
    agg_mode: &AggregateMode,
) -> Result<Vec<PhysicalSortRequirement>> {
    let mut requirement = None;

    // Collect hard requirements first; soft requirements are only folded in
    // afterwards, so they can never override a hard one:
    for include_soft_requirement in [false, true] {
        for aggr_expr in aggr_exprs.iter_mut() {
            let Some(aggr_req) = get_aggregate_expr_req(
                aggr_expr,
                group_by,
                agg_mode,
                include_soft_requirement,
            )
            .and_then(|o| eq_properties.normalize_sort_exprs(o)) else {
                // No requirement (or it is already implied); nothing to do.
                continue;
            };
            // Check whether the forward requirement refines the current one
            // and whether the input ordering already satisfies it:
            let forward_finer = determine_finer(&requirement, &aggr_req);
            if let Some(finer) = forward_finer {
                if !finer {
                    // The current requirement is already at least as fine.
                    continue;
                } else if eq_properties.ordering_satisfy(aggr_req.clone())? {
                    requirement = Some(aggr_req);
                    continue;
                }
            }
            // Try the reverse direction of this aggregate expression:
            if let Some(reverse_aggr_expr) = aggr_expr.reverse_expr() {
                let Some(rev_aggr_req) = get_aggregate_expr_req(
                    &reverse_aggr_expr,
                    group_by,
                    agg_mode,
                    include_soft_requirement,
                )
                .and_then(|o| eq_properties.normalize_sort_exprs(o)) else {
                    // The reversed expression has no requirement; adopt it.
                    *aggr_expr = Arc::new(reverse_aggr_expr);
                    continue;
                };
                if let Some(finer) = determine_finer(&requirement, &rev_aggr_req) {
                    if !finer {
                        // The current requirement already covers the reversed
                        // variant, so prefer the reversed expression:
                        *aggr_expr = Arc::new(reverse_aggr_expr);
                    } else if eq_properties.ordering_satisfy(rev_aggr_req.clone())? {
                        // The reversed requirement is finer and satisfied:
                        *aggr_expr = Arc::new(reverse_aggr_expr);
                        requirement = Some(rev_aggr_req);
                    } else {
                        requirement = Some(aggr_req);
                    }
                } else if forward_finer.is_some() {
                    requirement = Some(aggr_req);
                } else {
                    // Neither direction is comparable with the current
                    // requirement; conflicting hard requirements are not
                    // supported:
                    if !include_soft_requirement {
                        return not_impl_err!(
                            "Conflicting ordering requirements in aggregate functions is not supported"
                        );
                    }
                }
            }
        }
    }

    Ok(requirement.map_or_else(Vec::new, |o| o.into_iter().map(Into::into).collect()))
}

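/// Returns physical expressions for arguments to evaluate against a batch.
///
/// The expressions differ by `mode`:
/// * `Partial` / `Single` / `SinglePartitioned`: the aggregates' raw input
///   arguments
/// * `Final` / `FinalPartitioned`: columns of the partial aggregate's output
///
/// A sketch for `AVG(b)`; the state column names follow the test snapshots
/// below, e.g. `AVG(b)[count]` / `AVG(b)[sum]` (not compiled as a doctest):
///
/// ```ignore
/// // First stage: evaluate the raw argument `b`.
/// let partial = aggregate_expressions(&aggr, &AggregateMode::Partial, 0)?;
/// // Final stage: read the partial state columns starting at `col_idx_base`.
/// let final_ = aggregate_expressions(&aggr, &AggregateMode::Final, 1)?;
/// ```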
pub fn aggregate_expressions(
    aggr_expr: &[Arc<AggregateFunctionExpr>],
    mode: &AggregateMode,
    col_idx_base: usize,
) -> Result<Vec<Vec<Arc<dyn PhysicalExpr>>>> {
    match mode {
        AggregateMode::Partial
        | AggregateMode::Single
        | AggregateMode::SinglePartitioned => Ok(aggr_expr
            .iter()
            .map(|agg| {
                let mut result = agg.expressions();
                // Append ordering requirements to the expressions' results,
                // so that order-sensitive aggregators can satisfy their
                // requirements themselves:
                result.extend(agg.order_bys().iter().map(|item| Arc::clone(&item.expr)));
                result
            })
            .collect()),
        AggregateMode::Final | AggregateMode::FinalPartitioned => {
            // In these modes, we build the merge expressions of the aggregation:
            let mut col_idx_base = col_idx_base;
            aggr_expr
                .iter()
                .map(|agg| {
                    let exprs = merge_expressions(col_idx_base, agg)?;
                    col_idx_base += exprs.len();
                    Ok(exprs)
                })
                .collect()
        }
    }
}

/// Returns one `Column` expression per state field of `expr`, starting at
/// `index_base`; these are the inputs merged by the final aggregation.
fn merge_expressions(
    index_base: usize,
    expr: &AggregateFunctionExpr,
) -> Result<Vec<Arc<dyn PhysicalExpr>>> {
    expr.state_fields().map(|fields| {
        fields
            .iter()
            .enumerate()
            .map(|(idx, f)| Arc::new(Column::new(f.name(), index_base + idx)) as _)
            .collect()
    })
}

/// A boxed accumulator
pub type AccumulatorItem = Box<dyn Accumulator>;

/// Creates one accumulator per aggregate expression.
pub fn create_accumulators(
    aggr_expr: &[Arc<AggregateFunctionExpr>],
) -> Result<Vec<AccumulatorItem>> {
    aggr_expr
        .iter()
        .map(|expr| expr.create_accumulator())
        .collect()
}

/// Returns a vector of [`ArrayRef`]s, where each entry corresponds to either
/// the final value (mode = `Final`, `FinalPartitioned` and `Single`) or the
/// intermediate state (mode = `Partial`).
pub fn finalize_aggregation(
    accumulators: &mut [AccumulatorItem],
    mode: &AggregateMode,
) -> Result<Vec<ArrayRef>> {
    match mode {
        AggregateMode::Partial => {
            // Build the vector of states:
            accumulators
                .iter_mut()
                .map(|accumulator| {
                    accumulator.state().and_then(|e| {
                        e.iter()
                            .map(|v| v.to_array())
                            .collect::<Result<Vec<ArrayRef>>>()
                    })
                })
                .flatten_ok()
                .collect()
        }
        AggregateMode::Final
        | AggregateMode::FinalPartitioned
        | AggregateMode::Single
        | AggregateMode::SinglePartitioned => {
            // Merge the state to the final value:
            accumulators
                .iter_mut()
                .map(|accumulator| accumulator.evaluate().and_then(|v| v.to_array()))
                .collect()
        }
    }
}

/// Evaluates expressions against a record batch.
fn evaluate(
    expr: &[Arc<dyn PhysicalExpr>],
    batch: &RecordBatch,
) -> Result<Vec<ArrayRef>> {
    expr.iter()
        .map(|expr| {
            expr.evaluate(batch)
                .and_then(|v| v.into_array(batch.num_rows()))
        })
        .collect()
}

/// Evaluates one set of expressions per aggregate against a record batch.
pub fn evaluate_many(
    expr: &[Vec<Arc<dyn PhysicalExpr>>],
    batch: &RecordBatch,
) -> Result<Vec<Vec<ArrayRef>>> {
    expr.iter().map(|expr| evaluate(expr, batch)).collect()
}

/// Evaluates optional filter expressions against a record batch, keeping
/// `None` for aggregates that have no FILTER clause.
fn evaluate_optional(
    expr: &[Option<Arc<dyn PhysicalExpr>>],
    batch: &RecordBatch,
) -> Result<Vec<Option<ArrayRef>>> {
    expr.iter()
        .map(|expr| {
            expr.as_ref()
                .map(|expr| {
                    expr.evaluate(batch)
                        .and_then(|v| v.into_array(batch.num_rows()))
                })
                .transpose()
        })
        .collect()
}

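/// Builds a constant array of the grouping id for one grouping set: bit `j`
/// (counting from the most significant group expression) is 1 when expression
/// `j` is null-masked. The narrowest unsigned type that fits is used.
///
/// A sketch of the encoding (not compiled as a doctest):
///
/// ```ignore
/// // GROUPING SETS ((a), (b), (a, b)) over (a, b):
/// // [false, true]  -> a kept, b null -> 0b01 = 1
/// // [true, false]  -> a null, b kept -> 0b10 = 2
/// // [false, false] -> both kept      -> 0b00 = 0
/// let ids = group_id_array(&[false, true], &batch)?;
/// ```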
fn group_id_array(group: &[bool], batch: &RecordBatch) -> Result<ArrayRef> {
    if group.len() > 64 {
        return not_impl_err!(
            "Grouping sets with more than 64 columns are not supported"
        );
    }
    let group_id = group.iter().fold(0u64, |acc, &is_null| {
        (acc << 1) | if is_null { 1 } else { 0 }
    });
    let num_rows = batch.num_rows();
    if group.len() <= 8 {
        Ok(Arc::new(UInt8Array::from(vec![group_id as u8; num_rows])))
    } else if group.len() <= 16 {
        Ok(Arc::new(UInt16Array::from(vec![group_id as u16; num_rows])))
    } else if group.len() <= 32 {
        Ok(Arc::new(UInt32Array::from(vec![group_id as u32; num_rows])))
    } else {
        Ok(Arc::new(UInt64Array::from(vec![group_id; num_rows])))
    }
}

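/// Evaluate a group by expression against a `RecordBatch`.
///
/// The result is one `Vec<ArrayRef>` per grouping set: the group values (with
/// null expressions substituted where the set's mask says so) plus, for
/// multi-set group bys, the grouping id column.
///
/// A shape sketch for `GROUPING SETS ((a), (b))` (hypothetical names; not
/// compiled as a doctest):
///
/// ```ignore
/// let groups = evaluate_group_by(&group_by, &batch)?;
/// assert_eq!(groups.len(), 2);    // one entry per grouping set
/// assert_eq!(groups[0].len(), 3); // a, b (nulled), __grouping_id
/// ```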
pub fn evaluate_group_by(
    group_by: &PhysicalGroupBy,
    batch: &RecordBatch,
) -> Result<Vec<Vec<ArrayRef>>> {
    let exprs: Vec<ArrayRef> = group_by
        .expr
        .iter()
        .map(|(expr, _)| {
            let value = expr.evaluate(batch)?;
            value.into_array(batch.num_rows())
        })
        .collect::<Result<Vec<_>>>()?;

    let null_exprs: Vec<ArrayRef> = group_by
        .null_expr
        .iter()
        .map(|(expr, _)| {
            let value = expr.evaluate(batch)?;
            value.into_array(batch.num_rows())
        })
        .collect::<Result<Vec<_>>>()?;

    group_by
        .groups
        .iter()
        .map(|group| {
            let mut group_values = Vec::with_capacity(group_by.num_group_exprs());
            group_values.extend(group.iter().enumerate().map(|(idx, is_null)| {
                if *is_null {
                    Arc::clone(&null_exprs[idx])
                } else {
                    Arc::clone(&exprs[idx])
                }
            }));
            if !group_by.is_single() {
                group_values.push(group_id_array(group, batch)?);
            }
            Ok(group_values)
        })
        .collect()
}

#[cfg(test)]
mod tests {
    use std::task::{Context, Poll};

    use super::*;
    use crate::coalesce_batches::CoalesceBatchesExec;
    use crate::coalesce_partitions::CoalescePartitionsExec;
    use crate::common;
    use crate::common::collect;
    use crate::execution_plan::Boundedness;
    use crate::expressions::col;
    use crate::metrics::MetricValue;
    use crate::test::assert_is_pending;
    use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
    use crate::test::TestMemoryExec;
    use crate::RecordBatchStream;

    use arrow::array::{
        DictionaryArray, Float32Array, Float64Array, Int32Array, StructArray,
        UInt32Array, UInt64Array,
    };
    use arrow::compute::{concat_batches, SortOptions};
    use arrow::datatypes::{DataType, Int32Type};
    use datafusion_common::test_util::{batches_to_sort_string, batches_to_string};
    use datafusion_common::{internal_err, DataFusionError, ScalarValue};
    use datafusion_execution::config::SessionConfig;
    use datafusion_execution::memory_pool::FairSpillPool;
    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
    use datafusion_functions_aggregate::array_agg::array_agg_udaf;
    use datafusion_functions_aggregate::average::avg_udaf;
    use datafusion_functions_aggregate::count::count_udaf;
    use datafusion_functions_aggregate::first_last::{first_value_udaf, last_value_udaf};
    use datafusion_functions_aggregate::median::median_udaf;
    use datafusion_functions_aggregate::sum::sum_udaf;
    use datafusion_physical_expr::aggregate::AggregateExprBuilder;
    use datafusion_physical_expr::expressions::lit;
    use datafusion_physical_expr::expressions::Literal;
    use datafusion_physical_expr::Partitioning;
    use datafusion_physical_expr::PhysicalSortExpr;

    use futures::{FutureExt, Stream};
    use insta::{allow_duplicates, assert_snapshot};

    fn create_test_schema() -> Result<SchemaRef> {
        let a = Field::new("a", DataType::Int32, true);
        let b = Field::new("b", DataType::Int32, true);
        let c = Field::new("c", DataType::Int32, true);
        let d = Field::new("d", DataType::Int32, true);
        let e = Field::new("e", DataType::Int32, true);
        let schema = Arc::new(Schema::new(vec![a, b, c, d, e]));

        Ok(schema)
    }

    /// Mock data to test aggregates: two batches over columns (a: u32, b: f64)
    fn some_data() -> (Arc<Schema>, Vec<RecordBatch>) {
        // Define a schema:
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::UInt32, false),
            Field::new("b", DataType::Float64, false),
        ]));

        // Define data:
        (
            Arc::clone(&schema),
            vec![
                RecordBatch::try_new(
                    Arc::clone(&schema),
                    vec![
                        Arc::new(UInt32Array::from(vec![2, 3, 4, 4])),
                        Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
                    ],
                )
                .unwrap(),
                RecordBatch::try_new(
                    schema,
                    vec![
                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
                        Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
                    ],
                )
                .unwrap(),
            ],
        )
    }

    /// Mock data to test aggregates: four batches over columns (a: u32, b: f64)
    fn some_data_v2() -> (Arc<Schema>, Vec<RecordBatch>) {
        // Define a schema:
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::UInt32, false),
            Field::new("b", DataType::Float64, false),
        ]));

        // Define data:
        (
            Arc::clone(&schema),
            vec![
                RecordBatch::try_new(
                    Arc::clone(&schema),
                    vec![
                        Arc::new(UInt32Array::from(vec![2, 3, 4, 4])),
                        Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
                    ],
                )
                .unwrap(),
                RecordBatch::try_new(
                    Arc::clone(&schema),
                    vec![
                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
                        Arc::new(Float64Array::from(vec![0.0, 1.0, 2.0, 3.0])),
                    ],
                )
                .unwrap(),
                RecordBatch::try_new(
                    Arc::clone(&schema),
                    vec![
                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
                        Arc::new(Float64Array::from(vec![3.0, 4.0, 5.0, 6.0])),
                    ],
                )
                .unwrap(),
                RecordBatch::try_new(
                    schema,
                    vec![
                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
                        Arc::new(Float64Array::from(vec![2.0, 3.0, 4.0, 5.0])),
                    ],
                )
                .unwrap(),
            ],
        )
    }

    /// Creates a `TaskContext` with the given batch size and a fair spill
    /// pool limited to `max_memory` bytes
    fn new_spill_ctx(batch_size: usize, max_memory: usize) -> Arc<TaskContext> {
        let session_config = SessionConfig::new().with_batch_size(batch_size);
        let runtime = RuntimeEnvBuilder::new()
            .with_memory_pool(Arc::new(FairSpillPool::new(max_memory)))
            .build_arc()
            .unwrap();
        let task_ctx = TaskContext::default()
            .with_session_config(session_config)
            .with_runtime(runtime);
        Arc::new(task_ctx)
    }

    async fn check_grouping_sets(
        input: Arc<dyn ExecutionPlan>,
        spill: bool,
    ) -> Result<()> {
        let input_schema = input.schema();

        let grouping_set = PhysicalGroupBy::new(
            vec![
                (col("a", &input_schema)?, "a".to_string()),
                (col("b", &input_schema)?, "b".to_string()),
            ],
            vec![
                (lit(ScalarValue::UInt32(None)), "a".to_string()),
                (lit(ScalarValue::Float64(None)), "b".to_string()),
            ],
            vec![
                vec![false, true],  // (a, NULL)
                vec![true, false],  // (NULL, b)
                vec![false, false], // (a, b)
            ],
        );

        let aggregates = vec![Arc::new(
            AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)])
                .schema(Arc::clone(&input_schema))
                .alias("COUNT(1)")
                .build()?,
        )];

        let task_ctx = if spill {
            // Set a small memory limit so the partial aggregate has to spill:
            new_spill_ctx(4, 500)
        } else {
            Arc::new(TaskContext::default())
        };

        let partial_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            grouping_set.clone(),
            aggregates.clone(),
            vec![None],
            input,
            Arc::clone(&input_schema),
        )?);

        let result =
            collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;

        if spill {
            // Under memory pressure the partial aggregate emits groups early,
            // so the same group can appear more than once:
            allow_duplicates! {
                assert_snapshot!(batches_to_sort_string(&result),
                @r"
                +---+-----+---------------+-----------------+
                | a | b   | __grouping_id | COUNT(1)[count] |
                +---+-----+---------------+-----------------+
                |   | 1.0 | 2             | 1               |
                |   | 1.0 | 2             | 1               |
                |   | 2.0 | 2             | 1               |
                |   | 2.0 | 2             | 1               |
                |   | 3.0 | 2             | 1               |
                |   | 3.0 | 2             | 1               |
                |   | 4.0 | 2             | 1               |
                |   | 4.0 | 2             | 1               |
                | 2 |     | 1             | 1               |
                | 2 |     | 1             | 1               |
                | 2 | 1.0 | 0             | 1               |
                | 2 | 1.0 | 0             | 1               |
                | 3 |     | 1             | 1               |
                | 3 |     | 1             | 2               |
                | 3 | 2.0 | 0             | 2               |
                | 3 | 3.0 | 0             | 1               |
                | 4 |     | 1             | 1               |
                | 4 |     | 1             | 2               |
                | 4 | 3.0 | 0             | 1               |
                | 4 | 4.0 | 0             | 2               |
                +---+-----+---------------+-----------------+
                "
                );
            }
        } else {
            allow_duplicates! {
                assert_snapshot!(batches_to_sort_string(&result),
                @r"
                +---+-----+---------------+-----------------+
                | a | b   | __grouping_id | COUNT(1)[count] |
                +---+-----+---------------+-----------------+
                |   | 1.0 | 2             | 2               |
                |   | 2.0 | 2             | 2               |
                |   | 3.0 | 2             | 2               |
                |   | 4.0 | 2             | 2               |
                | 2 |     | 1             | 2               |
                | 2 | 1.0 | 0             | 2               |
                | 3 |     | 1             | 3               |
                | 3 | 2.0 | 0             | 2               |
                | 3 | 3.0 | 0             | 1               |
                | 4 |     | 1             | 3               |
                | 4 | 3.0 | 0             | 1               |
                | 4 | 4.0 | 0             | 2               |
                +---+-----+---------------+-----------------+
                "
                );
            }
        };

        let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));

        let final_grouping_set = grouping_set.as_final();

        let task_ctx = if spill {
            new_spill_ctx(4, 3160)
        } else {
            task_ctx
        };

        let merged_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Final,
            final_grouping_set,
            aggregates,
            vec![None],
            merge,
            input_schema,
        )?);

        let result = collect(merged_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
        let batch = concat_batches(&result[0].schema(), &result)?;
        assert_eq!(batch.num_columns(), 4);
        assert_eq!(batch.num_rows(), 12);

        allow_duplicates! {
            assert_snapshot!(
                batches_to_sort_string(&result),
                @r"
                +---+-----+---------------+----------+
                | a | b   | __grouping_id | COUNT(1) |
                +---+-----+---------------+----------+
                |   | 1.0 | 2             | 2        |
                |   | 2.0 | 2             | 2        |
                |   | 3.0 | 2             | 2        |
                |   | 4.0 | 2             | 2        |
                | 2 |     | 1             | 2        |
                | 2 | 1.0 | 0             | 2        |
                | 3 |     | 1             | 3        |
                | 3 | 2.0 | 0             | 2        |
                | 3 | 3.0 | 0             | 1        |
                | 4 |     | 1             | 3        |
                | 4 | 3.0 | 0             | 1        |
                | 4 | 4.0 | 0             | 2        |
                +---+-----+---------------+----------+
                "
            );
        }

        let metrics = merged_aggregate.metrics().unwrap();
        let output_rows = metrics.output_rows().unwrap();
        assert_eq!(12, output_rows);

        Ok(())
    }

    async fn check_aggregates(input: Arc<dyn ExecutionPlan>, spill: bool) -> Result<()> {
        let input_schema = input.schema();

        let grouping_set = PhysicalGroupBy::new(
            vec![(col("a", &input_schema)?, "a".to_string())],
            vec![],
            vec![vec![false]],
        );

        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
            AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?])
                .schema(Arc::clone(&input_schema))
                .alias("AVG(b)")
                .build()?,
        )];

        let task_ctx = if spill {
            // Set a small memory limit so the partial aggregate has to spill:
            new_spill_ctx(2, 1600)
        } else {
            Arc::new(TaskContext::default())
        };

        let partial_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            grouping_set.clone(),
            aggregates.clone(),
            vec![None],
            input,
            Arc::clone(&input_schema),
        )?);

        let result =
            collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;

        if spill {
            allow_duplicates! {
                assert_snapshot!(batches_to_sort_string(&result), @r"
                +---+---------------+-------------+
                | a | AVG(b)[count] | AVG(b)[sum] |
                +---+---------------+-------------+
                | 2 | 1             | 1.0         |
                | 2 | 1             | 1.0         |
                | 3 | 1             | 2.0         |
                | 3 | 2             | 5.0         |
                | 4 | 3             | 11.0        |
                +---+---------------+-------------+
                ");
            }
        } else {
            allow_duplicates! {
                assert_snapshot!(batches_to_sort_string(&result), @r"
                +---+---------------+-------------+
                | a | AVG(b)[count] | AVG(b)[sum] |
                +---+---------------+-------------+
                | 2 | 2             | 2.0         |
                | 3 | 3             | 7.0         |
                | 4 | 3             | 11.0        |
                +---+---------------+-------------+
                ");
            }
        };

        let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));

        let final_grouping_set = grouping_set.as_final();

        let merged_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Final,
            final_grouping_set,
            aggregates,
            vec![None],
            merge,
            input_schema,
        )?);

        let task_ctx = if spill {
            // Enlarge the memory limit so the final aggregation can finish:
            new_spill_ctx(2, 2600)
        } else {
            Arc::clone(&task_ctx)
        };
        let result = collect(merged_aggregate.execute(0, task_ctx)?).await?;
        let batch = concat_batches(&result[0].schema(), &result)?;
        assert_eq!(batch.num_columns(), 2);
        assert_eq!(batch.num_rows(), 3);

        allow_duplicates! {
            assert_snapshot!(batches_to_sort_string(&result), @r"
            +---+--------------------+
            | a | AVG(b)             |
            +---+--------------------+
            | 2 | 1.0                |
            | 3 | 2.3333333333333335 |
            | 4 | 3.6666666666666665 |
            +---+--------------------+
            ");
        }

        let metrics = merged_aggregate.metrics().unwrap();
        let output_rows = metrics.output_rows().unwrap();
        let spill_count = metrics.spill_count().unwrap();
        let spilled_bytes = metrics.spilled_bytes().unwrap();
        let spilled_rows = metrics.spilled_rows().unwrap();

        if spill {
            // When spilling, the output rows metric is the partial output size
            // plus the final output size:
            assert_eq!(8, output_rows);

            assert!(spill_count > 0);
            assert!(spilled_bytes > 0);
            assert!(spilled_rows > 0);
        } else {
            assert_eq!(3, output_rows);

            assert_eq!(0, spill_count);
            assert_eq!(0, spilled_bytes);
            assert_eq!(0, spilled_rows);
        }

        Ok(())
    }

    /// A test source that, when `yield_first` is set, returns `Pending` once
    /// before yielding its first batch
    #[derive(Debug)]
    struct TestYieldingExec {
        /// True if this exec should yield back to the runtime the first time
        /// it is polled
        pub yield_first: bool,
        /// Cached plan properties
        cache: PlanProperties,
    }

    impl TestYieldingExec {
        fn new(yield_first: bool) -> Self {
            let schema = some_data().0;
            let cache = Self::compute_properties(schema);
            Self { yield_first, cache }
        }

        /// This function creates the cache object that stores the plan properties
        fn compute_properties(schema: SchemaRef) -> PlanProperties {
            PlanProperties::new(
                EquivalenceProperties::new(schema),
                Partitioning::UnknownPartitioning(1),
                EmissionType::Incremental,
                Boundedness::Bounded,
            )
        }
    }

    impl DisplayAs for TestYieldingExec {
        fn fmt_as(
            &self,
            t: DisplayFormatType,
            f: &mut std::fmt::Formatter,
        ) -> std::fmt::Result {
            match t {
                DisplayFormatType::Default | DisplayFormatType::Verbose => {
                    write!(f, "TestYieldingExec")
                }
                DisplayFormatType::TreeRender => {
                    // TreeRender output is intentionally empty for this source
                    write!(f, "")
                }
            }
        }
    }

    impl ExecutionPlan for TestYieldingExec {
        fn name(&self) -> &'static str {
            "TestYieldingExec"
        }

        fn as_any(&self) -> &dyn Any {
            self
        }

        fn properties(&self) -> &PlanProperties {
            &self.cache
        }

        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
            vec![]
        }

        fn with_new_children(
            self: Arc<Self>,
            _: Vec<Arc<dyn ExecutionPlan>>,
        ) -> Result<Arc<dyn ExecutionPlan>> {
            internal_err!("Children cannot be replaced in {self:?}")
        }

        fn execute(
            &self,
            _partition: usize,
            _context: Arc<TaskContext>,
        ) -> Result<SendableRecordBatchStream> {
            let stream = if self.yield_first {
                TestYieldingStream::New
            } else {
                TestYieldingStream::Yielded
            };

            Ok(Box::pin(stream))
        }

        fn statistics(&self) -> Result<Statistics> {
            self.partition_statistics(None)
        }

        fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
            if partition.is_some() {
                return Ok(Statistics::new_unknown(self.schema().as_ref()));
            }
            let (_, batches) = some_data();
            Ok(common::compute_record_batch_statistics(
                &[batches],
                &self.schema(),
                None,
            ))
        }
    }

    /// A stream that optionally yields (returns `Pending`) once before
    /// returning the two batches from [`some_data`]
    enum TestYieldingStream {
        New,
        Yielded,
        ReturnedBatch1,
        ReturnedBatch2,
    }

    impl Stream for TestYieldingStream {
        type Item = Result<RecordBatch>;

        fn poll_next(
            mut self: std::pin::Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Self::Item>> {
            match &*self {
                TestYieldingStream::New => {
                    *(self.as_mut()) = TestYieldingStream::Yielded;
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
                TestYieldingStream::Yielded => {
                    *(self.as_mut()) = TestYieldingStream::ReturnedBatch1;
                    Poll::Ready(Some(Ok(some_data().1[0].clone())))
                }
                TestYieldingStream::ReturnedBatch1 => {
                    *(self.as_mut()) = TestYieldingStream::ReturnedBatch2;
                    Poll::Ready(Some(Ok(some_data().1[1].clone())))
                }
                TestYieldingStream::ReturnedBatch2 => Poll::Ready(None),
            }
        }
    }

    impl RecordBatchStream for TestYieldingStream {
        fn schema(&self) -> SchemaRef {
            some_data().0
        }
    }

    #[tokio::test]
    async fn aggregate_source_not_yielding() -> Result<()> {
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));

        check_aggregates(input, false).await
    }

    #[tokio::test]
    async fn aggregate_grouping_sets_source_not_yielding() -> Result<()> {
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));

        check_grouping_sets(input, false).await
    }

    #[tokio::test]
    async fn aggregate_source_with_yielding() -> Result<()> {
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));

        check_aggregates(input, false).await
    }

    #[tokio::test]
    async fn aggregate_grouping_sets_with_yielding() -> Result<()> {
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));

        check_grouping_sets(input, false).await
    }

    #[tokio::test]
    async fn aggregate_source_not_yielding_with_spill() -> Result<()> {
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));

        check_aggregates(input, true).await
    }

    #[tokio::test]
    async fn aggregate_grouping_sets_source_not_yielding_with_spill() -> Result<()> {
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));

        check_grouping_sets(input, true).await
    }

    #[tokio::test]
    async fn aggregate_source_with_yielding_with_spill() -> Result<()> {
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));

        check_aggregates(input, true).await
    }

    #[tokio::test]
    async fn aggregate_grouping_sets_with_yielding_with_spill() -> Result<()> {
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));

        check_grouping_sets(input, true).await
    }

    // MEDIAN(a)
    fn test_median_agg_expr(schema: SchemaRef) -> Result<AggregateFunctionExpr> {
        AggregateExprBuilder::new(median_udaf(), vec![col("a", &schema)?])
            .schema(schema)
            .alias("MEDIAN(a)")
            .build()
    }

    #[tokio::test]
    async fn test_oom() -> Result<()> {
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
        let input_schema = input.schema();

        let runtime = RuntimeEnvBuilder::new()
            .with_memory_limit(1, 1.0)
            .build_arc()?;
        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        let groups_none = PhysicalGroupBy::default();
        let groups_some = PhysicalGroupBy::new(
            vec![(col("a", &input_schema)?, "a".to_string())],
            vec![],
            vec![vec![false]],
        );

        // Version 0: aggregation without grouping (median buffers its input):
        let aggregates_v0: Vec<Arc<AggregateFunctionExpr>> =
            vec![Arc::new(test_median_agg_expr(Arc::clone(&input_schema))?)];

        // Version 2: aggregation with grouping:
        let aggregates_v2: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
            AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?])
                .schema(Arc::clone(&input_schema))
                .alias("AVG(b)")
                .build()?,
        )];

        for (version, groups, aggregates) in [
            (0, groups_none, aggregates_v0),
            (2, groups_some, aggregates_v2),
        ] {
            let n_aggr = aggregates.len();
            let partial_aggregate = Arc::new(AggregateExec::try_new(
                AggregateMode::Partial,
                groups,
                aggregates,
                vec![None; n_aggr],
                Arc::clone(&input),
                Arc::clone(&input_schema),
            )?);

            let stream = partial_aggregate.execute_typed(0, Arc::clone(&task_ctx))?;

            // Ensure that we really got the stream variant we intended:
            match version {
                0 => {
                    assert!(matches!(stream, StreamType::AggregateStream(_)));
                }
                1 => {
                    assert!(matches!(stream, StreamType::GroupedHash(_)));
                }
                2 => {
                    assert!(matches!(stream, StreamType::GroupedHash(_)));
                }
                _ => panic!("Unknown version: {version}"),
            }

            let stream: SendableRecordBatchStream = stream.into();
            let err = collect(stream).await.unwrap_err();

            // Walk to the root cause of the error before matching on it:
            let err = err.find_root();
            assert!(
                matches!(err, DataFusionError::ResourcesExhausted(_)),
                "Wrong error type: {err}",
            );
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_drop_cancel_without_groups() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema =
            Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, true)]));

        let groups = PhysicalGroupBy::default();

        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
            AggregateExprBuilder::new(avg_udaf(), vec![col("a", &schema)?])
                .schema(Arc::clone(&schema))
                .alias("AVG(a)")
                .build()?,
        )];

        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
        let refs = blocking_exec.refs();
        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            groups.clone(),
            aggregates.clone(),
            vec![None],
            blocking_exec,
            schema,
        )?);

        let fut = crate::collect(aggregate_exec, task_ctx);
        let mut fut = fut.boxed();

        assert_is_pending(&mut fut);
        drop(fut);
        assert_strong_count_converges_to_zero(refs).await;

        Ok(())
    }

    #[tokio::test]
    async fn test_drop_cancel_with_groups() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Float64, true),
            Field::new("b", DataType::Float64, true),
        ]));

        let groups =
            PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);

        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
            AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
                .schema(Arc::clone(&schema))
                .alias("AVG(b)")
                .build()?,
        )];

        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
        let refs = blocking_exec.refs();
        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            groups,
            aggregates.clone(),
            vec![None],
            blocking_exec,
            schema,
        )?);

        let fut = crate::collect(aggregate_exec, task_ctx);
        let mut fut = fut.boxed();

        assert_is_pending(&mut fut);
        drop(fut);
        assert_strong_count_converges_to_zero(refs).await;

        Ok(())
    }

    #[tokio::test]
    async fn run_first_last_multi_partitions() -> Result<()> {
        for use_coalesce_batches in [false, true] {
            for is_first_acc in [false, true] {
                for spill in [false, true] {
                    first_last_multi_partitions(
                        use_coalesce_batches,
                        is_first_acc,
                        spill,
                        4200,
                    )
                    .await?
                }
            }
        }
        Ok(())
    }

    // FIRST_VALUE(b ORDER BY b <SortOptions>)
    fn test_first_value_agg_expr(
        schema: &Schema,
        sort_options: SortOptions,
    ) -> Result<Arc<AggregateFunctionExpr>> {
        let order_bys = vec![PhysicalSortExpr {
            expr: col("b", schema)?,
            options: sort_options,
        }];
        let args = [col("b", schema)?];

        AggregateExprBuilder::new(first_value_udaf(), args.to_vec())
            .order_by(order_bys)
            .schema(Arc::new(schema.clone()))
            .alias(String::from("first_value(b) ORDER BY [b ASC NULLS LAST]"))
            .build()
            .map(Arc::new)
    }

    // LAST_VALUE(b ORDER BY b <SortOptions>)
    fn test_last_value_agg_expr(
        schema: &Schema,
        sort_options: SortOptions,
    ) -> Result<Arc<AggregateFunctionExpr>> {
        let order_bys = vec![PhysicalSortExpr {
            expr: col("b", schema)?,
            options: sort_options,
        }];
        let args = [col("b", schema)?];
        AggregateExprBuilder::new(last_value_udaf(), args.to_vec())
            .order_by(order_bys)
            .schema(Arc::new(schema.clone()))
            .alias(String::from("last_value(b) ORDER BY [b ASC NULLS LAST]"))
            .build()
            .map(Arc::new)
    }

    async fn first_last_multi_partitions(
        use_coalesce_batches: bool,
        is_first_acc: bool,
        spill: bool,
        max_memory: usize,
    ) -> Result<()> {
        let task_ctx = if spill {
            new_spill_ctx(2, max_memory)
        } else {
            Arc::new(TaskContext::default())
        };

        let (schema, data) = some_data_v2();
        let partition1 = data[0].clone();
        let partition2 = data[1].clone();
        let partition3 = data[2].clone();
        let partition4 = data[3].clone();

        let groups =
            PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);

        let sort_options = SortOptions {
            descending: false,
            nulls_first: false,
        };
        let aggregates: Vec<Arc<AggregateFunctionExpr>> = if is_first_acc {
            vec![test_first_value_agg_expr(&schema, sort_options)?]
        } else {
            vec![test_last_value_agg_expr(&schema, sort_options)?]
        };

        let memory_exec = TestMemoryExec::try_new_exec(
            &[
                vec![partition1],
                vec![partition2],
                vec![partition3],
                vec![partition4],
            ],
            Arc::clone(&schema),
            None,
        )?;
        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            groups.clone(),
            aggregates.clone(),
            vec![None],
            memory_exec,
            Arc::clone(&schema),
        )?);
        let coalesce = if use_coalesce_batches {
            let coalesce = Arc::new(CoalescePartitionsExec::new(aggregate_exec));
            Arc::new(CoalesceBatchesExec::new(coalesce, 1024)) as Arc<dyn ExecutionPlan>
        } else {
            Arc::new(CoalescePartitionsExec::new(aggregate_exec))
                as Arc<dyn ExecutionPlan>
        };
        let aggregate_final = Arc::new(AggregateExec::try_new(
            AggregateMode::Final,
            groups,
            aggregates.clone(),
            vec![None],
            coalesce,
            schema,
        )?) as Arc<dyn ExecutionPlan>;

        let result = crate::collect(aggregate_final, task_ctx).await?;
        if is_first_acc {
            allow_duplicates! {
                assert_snapshot!(batches_to_string(&result), @r"
                +---+--------------------------------------------+
                | a | first_value(b) ORDER BY [b ASC NULLS LAST] |
                +---+--------------------------------------------+
                | 2 | 0.0                                        |
                | 3 | 1.0                                        |
                | 4 | 3.0                                        |
                +---+--------------------------------------------+
                ");
            }
        } else {
            allow_duplicates! {
                assert_snapshot!(batches_to_string(&result), @r"
                +---+-------------------------------------------+
                | a | last_value(b) ORDER BY [b ASC NULLS LAST] |
                +---+-------------------------------------------+
                | 2 | 3.0                                       |
                | 3 | 5.0                                       |
                | 4 | 6.0                                       |
                +---+-------------------------------------------+
                ");
            }
        };
        Ok(())
    }

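    // `get_finer_aggregate_exprs_requirement` must combine the differing
    // ORDER BY requirements of the aggregate expressions below into a single
    // finest requirement, exploiting the equivalence between columns a and b.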
    #[tokio::test]
    async fn test_get_finest_requirements() -> Result<()> {
        let test_schema = create_test_schema()?;

        let options = SortOptions {
            descending: false,
            nulls_first: false,
        };
        let col_a = &col("a", &test_schema)?;
        let col_b = &col("b", &test_schema)?;
        let col_c = &col("c", &test_schema)?;
        let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema));
        // Declare columns a and b equal, so orderings on a and b are
        // interchangeable.
        eq_properties.add_equal_conditions(Arc::clone(col_a), Arc::clone(col_b))?;
        let order_by_exprs = vec![
            vec![],
            vec![PhysicalSortExpr {
                expr: Arc::clone(col_a),
                options,
            }],
            vec![
                PhysicalSortExpr {
                    expr: Arc::clone(col_a),
                    options,
                },
                PhysicalSortExpr {
                    expr: Arc::clone(col_b),
                    options,
                },
                PhysicalSortExpr {
                    expr: Arc::clone(col_c),
                    options,
                },
            ],
            vec![
                PhysicalSortExpr {
                    expr: Arc::clone(col_a),
                    options,
                },
                PhysicalSortExpr {
                    expr: Arc::clone(col_b),
                    options,
                },
            ],
        ];

        let common_requirement = vec![
            PhysicalSortRequirement::new(Arc::clone(col_a), Some(options)),
            PhysicalSortRequirement::new(Arc::clone(col_c), Some(options)),
        ];
        let mut aggr_exprs = order_by_exprs
            .into_iter()
            .map(|order_by_expr| {
                AggregateExprBuilder::new(array_agg_udaf(), vec![Arc::clone(col_a)])
                    .alias("a")
                    .order_by(order_by_expr)
                    .schema(Arc::clone(&test_schema))
                    .build()
                    .map(Arc::new)
                    .unwrap()
            })
            .collect::<Vec<_>>();
        let group_by = PhysicalGroupBy::new_single(vec![]);
        let result = get_finer_aggregate_exprs_requirement(
            &mut aggr_exprs,
            &group_by,
            &eq_properties,
            &AggregateMode::Partial,
        )?;
        assert_eq!(result, common_requirement);
        Ok(())
    }

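    // `with_new_children` must produce a plan with the same schema as the
    // original AggregateExec.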
    #[test]
    fn test_agg_exec_same_schema() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Float32, true),
            Field::new("b", DataType::Float32, true),
        ]));

        let col_a = col("a", &schema)?;
        let option_desc = SortOptions {
            descending: true,
            nulls_first: true,
        };
        let groups = PhysicalGroupBy::new_single(vec![(col_a, "a".to_string())]);

        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
            test_first_value_agg_expr(&schema, option_desc)?,
            test_last_value_agg_expr(&schema, option_desc)?,
        ];
        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            groups,
            aggregates,
            vec![None, None],
            Arc::clone(&blocking_exec) as Arc<dyn ExecutionPlan>,
            schema,
        )?);
        let new_agg =
            Arc::clone(&aggregate_exec).with_new_children(vec![blocking_exec])?;
        assert_eq!(new_agg.schema(), aggregate_exec.schema());
        Ok(())
    }

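    // Grouping sets over (a, b, const) where each set nulls out two of the
    // three columns. As the snapshot below shows, the output's __grouping_id
    // sets one bit per group column (most significant first) when that
    // column is nulled out: e.g. 6 = 0b110 means a and b are null and only
    // const is grouped.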
    #[tokio::test]
    async fn test_agg_exec_group_by_const() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Float32, true),
            Field::new("b", DataType::Float32, true),
            Field::new("const", DataType::Int32, false),
        ]));

        let col_a = col("a", &schema)?;
        let col_b = col("b", &schema)?;
        let const_expr = Arc::new(Literal::new(ScalarValue::Int32(Some(1))));

        let groups = PhysicalGroupBy::new(
            vec![
                (col_a, "a".to_string()),
                (col_b, "b".to_string()),
                (const_expr, "const".to_string()),
            ],
            vec![
                (
                    Arc::new(Literal::new(ScalarValue::Float32(None))),
                    "a".to_string(),
                ),
                (
                    Arc::new(Literal::new(ScalarValue::Float32(None))),
                    "b".to_string(),
                ),
                (
                    Arc::new(Literal::new(ScalarValue::Int32(None))),
                    "const".to_string(),
                ),
            ],
            vec![
                vec![false, true, true],
                vec![true, false, true],
                vec![true, true, false],
            ],
        );

        let aggregates: Vec<Arc<AggregateFunctionExpr>> =
            vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1)])
                .schema(Arc::clone(&schema))
                .alias("1")
                .build()
                .map(Arc::new)?];

        let input_batches = (0..4)
            .map(|_| {
                let a = Arc::new(Float32Array::from(vec![0.; 8192]));
                let b = Arc::new(Float32Array::from(vec![0.; 8192]));
                let c = Arc::new(Int32Array::from(vec![1; 8192]));

                RecordBatch::try_new(Arc::clone(&schema), vec![a, b, c]).unwrap()
            })
            .collect();

        let input =
            TestMemoryExec::try_new_exec(&[input_batches], Arc::clone(&schema), None)?;

        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::Single,
            groups,
            aggregates.clone(),
            vec![None],
            input,
            schema,
        )?);

        let output =
            collect(aggregate_exec.execute(0, Arc::new(TaskContext::default()))?).await?;

        allow_duplicates! {
            assert_snapshot!(batches_to_sort_string(&output), @r"
            +-----+-----+-------+---------------+-------+
            | a   | b   | const | __grouping_id | 1     |
            +-----+-----+-------+---------------+-------+
            |     |     | 1     | 6             | 32768 |
            |     | 0.0 |       | 5             | 32768 |
            | 0.0 |     |       | 3             | 32768 |
            +-----+-----+-------+---------------+-------+
            ");
        }

        Ok(())
    }

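    // Groups on a struct column whose fields are dictionary-encoded strings;
    // rows with equal struct values (rows 0 and 2 below) must land in the
    // same group.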
    #[tokio::test]
    async fn test_agg_exec_struct_of_dicts() -> Result<()> {
        let batch = RecordBatch::try_new(
            Arc::new(Schema::new(vec![
                Field::new(
                    "labels".to_string(),
                    DataType::Struct(
                        vec![
                            Field::new(
                                "a".to_string(),
                                DataType::Dictionary(
                                    Box::new(DataType::Int32),
                                    Box::new(DataType::Utf8),
                                ),
                                true,
                            ),
                            Field::new(
                                "b".to_string(),
                                DataType::Dictionary(
                                    Box::new(DataType::Int32),
                                    Box::new(DataType::Utf8),
                                ),
                                true,
                            ),
                        ]
                        .into(),
                    ),
                    false,
                ),
                Field::new("value", DataType::UInt64, false),
            ])),
            vec![
                Arc::new(StructArray::from(vec![
                    (
                        Arc::new(Field::new(
                            "a".to_string(),
                            DataType::Dictionary(
                                Box::new(DataType::Int32),
                                Box::new(DataType::Utf8),
                            ),
                            true,
                        )),
                        Arc::new(
                            vec![Some("a"), None, Some("a")]
                                .into_iter()
                                .collect::<DictionaryArray<Int32Type>>(),
                        ) as ArrayRef,
                    ),
                    (
                        Arc::new(Field::new(
                            "b".to_string(),
                            DataType::Dictionary(
                                Box::new(DataType::Int32),
                                Box::new(DataType::Utf8),
                            ),
                            true,
                        )),
                        Arc::new(
                            vec![Some("b"), Some("c"), Some("b")]
                                .into_iter()
                                .collect::<DictionaryArray<Int32Type>>(),
                        ) as ArrayRef,
                    ),
                ])),
                Arc::new(UInt64Array::from(vec![1, 1, 1])),
            ],
        )
        .expect("Failed to create RecordBatch");

        let group_by = PhysicalGroupBy::new_single(vec![(
            col("labels", &batch.schema())?,
            "labels".to_string(),
        )]);

        let aggr_expr = vec![AggregateExprBuilder::new(
            sum_udaf(),
            vec![col("value", &batch.schema())?],
        )
        .schema(Arc::clone(&batch.schema()))
        .alias(String::from("SUM(value)"))
        .build()
        .map(Arc::new)?];

        let input = TestMemoryExec::try_new_exec(
            &[vec![batch.clone()]],
            Arc::<Schema>::clone(&batch.schema()),
            None,
        )?;
        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::FinalPartitioned,
            group_by,
            aggr_expr,
            vec![None],
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            batch.schema(),
        )?);

        let session_config = SessionConfig::default();
        let ctx = TaskContext::default().with_session_config(session_config);
        let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;

        allow_duplicates! {
            assert_snapshot!(batches_to_string(&output), @r"
            +--------------+------------+
            | labels       | SUM(value) |
            +--------------+------------+
            | {a: a, b: b} | 2          |
            | {a: , b: c}  | 1          |
            +--------------+------------+
            ");
        }

        Ok(())
    }

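    // With the probe thresholds set very low, partial aggregation should be
    // abandoned after the first batch: the second batch is emitted in
    // intermediate form without being grouped, so keys 2 and 3 appear twice
    // in the output.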
    #[tokio::test]
    async fn test_skip_aggregation_after_first_batch() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("key", DataType::Int32, true),
            Field::new("val", DataType::Int32, true),
        ]));

        let group_by =
            PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]);

        let aggr_expr =
            vec![
                AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?])
                    .schema(Arc::clone(&schema))
                    .alias(String::from("COUNT(val)"))
                    .build()
                    .map(Arc::new)?,
            ];

        let input_data = vec![
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int32Array::from(vec![1, 2, 3])),
                    Arc::new(Int32Array::from(vec![0, 0, 0])),
                ],
            )
            .unwrap(),
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int32Array::from(vec![2, 3, 4])),
                    Arc::new(Int32Array::from(vec![0, 0, 0])),
                ],
            )
            .unwrap(),
        ];

        let input =
            TestMemoryExec::try_new_exec(&[input_data], Arc::clone(&schema), None)?;
        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            group_by,
            aggr_expr,
            vec![None],
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            schema,
        )?);

        let mut session_config = SessionConfig::default();
        session_config = session_config.set(
            "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
            &ScalarValue::Int64(Some(2)),
        );
        session_config = session_config.set(
            "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
            &ScalarValue::Float64(Some(0.1)),
        );

        let ctx = TaskContext::default().with_session_config(session_config);
        let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;

        allow_duplicates! {
            assert_snapshot!(batches_to_string(&output), @r"
            +-----+-------------------+
            | key | COUNT(val)[count] |
            +-----+-------------------+
            | 1   | 1                 |
            | 2   | 1                 |
            | 3   | 1                 |
            | 2   | 1                 |
            | 3   | 1                 |
            | 4   | 1                 |
            +-----+-------------------+
            ");
        }

        Ok(())
    }

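    // With a rows threshold of 5, the first two batches (6 rows) are
    // aggregated normally before the probe gives up; only the third batch is
    // passed through ungrouped.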
    #[tokio::test]
    async fn test_skip_aggregation_after_threshold() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("key", DataType::Int32, true),
            Field::new("val", DataType::Int32, true),
        ]));

        let group_by =
            PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]);

        let aggr_expr =
            vec![
                AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?])
                    .schema(Arc::clone(&schema))
                    .alias(String::from("COUNT(val)"))
                    .build()
                    .map(Arc::new)?,
            ];

        let input_data = vec![
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int32Array::from(vec![1, 2, 3])),
                    Arc::new(Int32Array::from(vec![0, 0, 0])),
                ],
            )
            .unwrap(),
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int32Array::from(vec![2, 3, 4])),
                    Arc::new(Int32Array::from(vec![0, 0, 0])),
                ],
            )
            .unwrap(),
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int32Array::from(vec![2, 3, 4])),
                    Arc::new(Int32Array::from(vec![0, 0, 0])),
                ],
            )
            .unwrap(),
        ];

        let input =
            TestMemoryExec::try_new_exec(&[input_data], Arc::clone(&schema), None)?;
        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            group_by,
            aggr_expr,
            vec![None],
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            schema,
        )?);

        let mut session_config = SessionConfig::default();
        session_config = session_config.set(
            "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
            &ScalarValue::Int64(Some(5)),
        );
        session_config = session_config.set(
            "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
            &ScalarValue::Float64(Some(0.1)),
        );

        let ctx = TaskContext::default().with_session_config(session_config);
        let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;

        allow_duplicates! {
            assert_snapshot!(batches_to_string(&output), @r"
            +-----+-------------------+
            | key | COUNT(val)[count] |
            +-----+-------------------+
            | 1   | 1                 |
            | 2   | 2                 |
            | 3   | 2                 |
            | 4   | 1                 |
            | 2   | 1                 |
            | 3   | 1                 |
            | 4   | 1                 |
            +-----+-------------------+
            ");
        }

        Ok(())
    }

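    // A group column that is nulled out by some grouping set (here b) must
    // become nullable in the output schema even though the input field is
    // non-nullable.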
    #[test]
    fn group_exprs_nullable() -> Result<()> {
        let input_schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Float32, false),
            Field::new("b", DataType::Float32, false),
        ]));

        let aggr_expr =
            vec![
                AggregateExprBuilder::new(count_udaf(), vec![col("a", &input_schema)?])
                    .schema(Arc::clone(&input_schema))
                    .alias("COUNT(a)")
                    .build()
                    .map(Arc::new)?,
            ];

        let grouping_set = PhysicalGroupBy::new(
            vec![
                (col("a", &input_schema)?, "a".to_string()),
                (col("b", &input_schema)?, "b".to_string()),
            ],
            vec![
                (lit(ScalarValue::Float32(None)), "a".to_string()),
                (lit(ScalarValue::Float32(None)), "b".to_string()),
            ],
            vec![
                vec![false, true],
                vec![false, false],
            ],
        );
        let aggr_schema = create_schema(
            &input_schema,
            &grouping_set,
            &aggr_expr,
            AggregateMode::Final,
        )?;
        let expected_schema = Schema::new(vec![
            Field::new("a", DataType::Float32, false),
            Field::new("b", DataType::Float32, true),
            Field::new("__grouping_id", DataType::UInt8, false),
            Field::new("COUNT(a)", DataType::Int64, false),
        ]);
        assert_eq!(aggr_schema, expected_schema);
        Ok(())
    }

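    // Runs a Single-mode aggregation against a FairSpillPool of `pool_size`
    // bytes, asserts via the SpillCount metric whether spilling happened,
    // and checks that the results are correct either way.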
    async fn run_test_with_spill_pool_if_necessary(
        pool_size: usize,
        expect_spill: bool,
    ) -> Result<()> {
        fn create_record_batch(
            schema: &Arc<Schema>,
            data: (Vec<u32>, Vec<f64>),
        ) -> Result<RecordBatch> {
            Ok(RecordBatch::try_new(
                Arc::clone(schema),
                vec![
                    Arc::new(UInt32Array::from(data.0)),
                    Arc::new(Float64Array::from(data.1)),
                ],
            )?)
        }

        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::UInt32, false),
            Field::new("b", DataType::Float64, false),
        ]));

        let batches = vec![
            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
        ];
        let plan: Arc<dyn ExecutionPlan> =
            TestMemoryExec::try_new_exec(&[batches], Arc::clone(&schema), None)?;

        let grouping_set = PhysicalGroupBy::new(
            vec![(col("a", &schema)?, "a".to_string())],
            vec![],
            vec![vec![false]],
        );

        // MIN has a single intermediate state, while AVG keeps two (sum and
        // count), so both accumulator layouts exercise the spill path.
        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
            Arc::new(
                AggregateExprBuilder::new(
                    datafusion_functions_aggregate::min_max::min_udaf(),
                    vec![col("b", &schema)?],
                )
                .schema(Arc::clone(&schema))
                .alias("MIN(b)")
                .build()?,
            ),
            Arc::new(
                AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
                    .schema(Arc::clone(&schema))
                    .alias("AVG(b)")
                    .build()?,
            ),
        ];

        let single_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Single,
            grouping_set,
            aggregates,
            vec![None, None],
            plan,
            Arc::clone(&schema),
        )?);

        let batch_size = 2;
        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
        let task_ctx = Arc::new(
            TaskContext::default()
                .with_session_config(SessionConfig::new().with_batch_size(batch_size))
                .with_runtime(Arc::new(
                    RuntimeEnvBuilder::new()
                        .with_memory_pool(memory_pool)
                        .build()?,
                )),
        );

        let result = collect(single_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;

        assert_spill_count_metric(expect_spill, single_aggregate);

        allow_duplicates! {
            assert_snapshot!(batches_to_string(&result), @r"
            +---+--------+--------+
            | a | MIN(b) | AVG(b) |
            +---+--------+--------+
            | 2 | 1.0    | 1.0    |
            | 3 | 2.0    | 2.0    |
            | 4 | 3.0    | 3.5    |
            +---+--------+--------+
            ");
        }

        Ok(())
    }

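    // Scans the operator's metrics for SpillCount and panics if the observed
    // value contradicts `expect_spill`.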
    fn assert_spill_count_metric(
        expect_spill: bool,
        single_aggregate: Arc<AggregateExec>,
    ) {
        if let Some(metrics_set) = single_aggregate.metrics() {
            let mut spill_count = 0;

            // Find the SpillCount metric, if it is present.
            for metric in metrics_set.iter() {
                if let MetricValue::SpillCount(count) = metric.value() {
                    spill_count = count.value();
                    break;
                }
            }

            if expect_spill && spill_count == 0 {
                panic!(
                    "Expected spill but SpillCount metric not found or SpillCount was 0."
                );
            } else if !expect_spill && spill_count > 0 {
                panic!("Expected no spill but found SpillCount metric with value greater than 0.");
            }
        } else {
            panic!("No metrics returned from the operator; cannot verify spilling.");
        }
    }

    #[tokio::test]
    async fn test_aggregate_with_spill_if_necessary() -> Result<()> {
        // Small pool: the aggregation must spill.
        run_test_with_spill_pool_if_necessary(2_000, true).await?;
        // Large pool: the aggregation must complete without spilling.
        run_test_with_spill_pool_if_necessary(20_000, false).await?;
        Ok(())
    }
}