datafusion_physical_optimizer/
combine_partial_final_agg.rs1use std::sync::Arc;
22
23use datafusion_common::error::Result;
24use datafusion_physical_plan::aggregates::{
25 AggregateExec, AggregateMode, PhysicalGroupBy,
26};
27use datafusion_physical_plan::ExecutionPlan;
28
29use crate::PhysicalOptimizerRule;
30use datafusion_common::config::ConfigOptions;
31use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
32use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
33use datafusion_physical_expr::{physical_exprs_equal, PhysicalExpr};
34
/// Physical optimizer rule that collapses a two-stage aggregation into one.
///
/// When a `Final` or `FinalPartitioned` [`AggregateExec`] reads directly from a
/// `Partial` [`AggregateExec`] whose grouping, aggregate, and filter
/// expressions match its own, the pair is replaced by a single `Single` or
/// `SinglePartitioned` [`AggregateExec`] built over the partial stage's input.
#[derive(Default, Debug)]
pub struct CombinePartialFinalAggregate {}

impl CombinePartialFinalAggregate {
    /// Creates a new instance of the rule; equivalent to [`Default::default`].
    pub fn new() -> Self {
        Self {}
    }
}
49
50impl PhysicalOptimizerRule for CombinePartialFinalAggregate {
51 fn optimize(
52 &self,
53 plan: Arc<dyn ExecutionPlan>,
54 _config: &ConfigOptions,
55 ) -> Result<Arc<dyn ExecutionPlan>> {
56 plan.transform_down(|plan| {
57 let Some(agg_exec) = plan.as_any().downcast_ref::<AggregateExec>() else {
59 return Ok(Transformed::no(plan));
60 };
61
62 if !matches!(
63 agg_exec.mode(),
64 AggregateMode::Final | AggregateMode::FinalPartitioned
65 ) {
66 return Ok(Transformed::no(plan));
67 }
68
69 let Some(input_agg_exec) =
71 agg_exec.input().as_any().downcast_ref::<AggregateExec>()
72 else {
73 return Ok(Transformed::no(plan));
74 };
75
76 let transformed = if matches!(input_agg_exec.mode(), AggregateMode::Partial)
77 && can_combine(
78 (
79 agg_exec.group_expr(),
80 agg_exec.aggr_expr(),
81 agg_exec.filter_expr(),
82 ),
83 (
84 input_agg_exec.group_expr(),
85 input_agg_exec.aggr_expr(),
86 input_agg_exec.filter_expr(),
87 ),
88 ) {
89 let mode = if agg_exec.mode() == &AggregateMode::Final {
90 AggregateMode::Single
91 } else {
92 AggregateMode::SinglePartitioned
93 };
94 AggregateExec::try_new(
95 mode,
96 input_agg_exec.group_expr().clone(),
97 input_agg_exec.aggr_expr().to_vec(),
98 input_agg_exec.filter_expr().to_vec(),
99 Arc::clone(input_agg_exec.input()),
100 input_agg_exec.input_schema(),
101 )
102 .map(|combined_agg| combined_agg.with_limit(agg_exec.limit()))
103 .ok()
104 .map(Arc::new)
105 } else {
106 None
107 };
108 Ok(if let Some(transformed) = transformed {
109 Transformed::yes(transformed)
110 } else {
111 Transformed::no(plan)
112 })
113 })
114 .data()
115 }
116
117 fn name(&self) -> &str {
118 "CombinePartialFinalAggregate"
119 }
120
121 fn schema_check(&self) -> bool {
122 true
123 }
124}
125
126type GroupExprsRef<'a> = (
127 &'a PhysicalGroupBy,
128 &'a [Arc<AggregateFunctionExpr>],
129 &'a [Option<Arc<dyn PhysicalExpr>>],
130);
131
132fn can_combine(final_agg: GroupExprsRef, partial_agg: GroupExprsRef) -> bool {
133 let (final_group_by, final_aggr_expr, final_filter_expr) = final_agg;
134 let (input_group_by, input_aggr_expr, input_filter_expr) = partial_agg;
135
136 physical_exprs_equal(
138 &input_group_by.output_exprs(),
139 &final_group_by.input_exprs(),
140 ) && input_group_by.groups() == final_group_by.groups()
141 && input_group_by.null_expr().len() == final_group_by.null_expr().len()
142 && input_group_by
143 .null_expr()
144 .iter()
145 .zip(final_group_by.null_expr().iter())
146 .all(|((lhs_expr, lhs_str), (rhs_expr, rhs_str))| {
147 lhs_expr.eq(rhs_expr) && lhs_str == rhs_str
148 })
149 && final_aggr_expr.len() == input_aggr_expr.len()
150 && final_aggr_expr
151 .iter()
152 .zip(input_aggr_expr.iter())
153 .all(|(final_expr, partial_expr)| final_expr.eq(partial_expr))
154 && final_filter_expr.len() == input_filter_expr.len()
155 && final_filter_expr.iter().zip(input_filter_expr.iter()).all(
156 |(final_expr, partial_expr)| match (final_expr, partial_expr) {
157 (Some(l), Some(r)) => l.eq(r),
158 (None, None) => true,
159 _ => false,
160 },
161 )
162}
163
164