datafusion_physical_optimizer/
combine_partial_final_agg.rs1use std::sync::Arc;
22
23use datafusion_common::error::Result;
24use datafusion_physical_plan::ExecutionPlan;
25use datafusion_physical_plan::aggregates::{
26 AggregateExec, AggregateMode, PhysicalGroupBy,
27};
28
29use crate::PhysicalOptimizerRule;
30use datafusion_common::config::ConfigOptions;
31use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
32use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
33use datafusion_physical_expr::{PhysicalExpr, physical_exprs_equal};
34
/// Physical optimizer rule that collapses a `Final` / `FinalPartitioned`
/// aggregate sitting directly on top of a `Partial` aggregate into one
/// `Single` / `SinglePartitioned` aggregate, provided both stages use
/// matching group-by, aggregate and filter expressions.
///
/// The rule carries no state.
#[derive(Debug, Default)]
pub struct CombinePartialFinalAggregate {}
41
42impl CombinePartialFinalAggregate {
43 #[expect(missing_docs)]
44 pub fn new() -> Self {
45 Self {}
46 }
47}
48
49impl PhysicalOptimizerRule for CombinePartialFinalAggregate {
50 fn optimize(
51 &self,
52 plan: Arc<dyn ExecutionPlan>,
53 _config: &ConfigOptions,
54 ) -> Result<Arc<dyn ExecutionPlan>> {
55 plan.transform_down(|plan| {
56 let Some(agg_exec) = plan.downcast_ref::<AggregateExec>() else {
58 return Ok(Transformed::no(plan));
59 };
60
61 if !matches!(
62 agg_exec.mode(),
63 AggregateMode::Final | AggregateMode::FinalPartitioned
64 ) {
65 return Ok(Transformed::no(plan));
66 }
67
68 let Some(input_agg_exec) = agg_exec.input().downcast_ref::<AggregateExec>()
70 else {
71 return Ok(Transformed::no(plan));
72 };
73
74 let transformed = if *input_agg_exec.mode() == AggregateMode::Partial
75 && can_combine(
76 (
77 agg_exec.group_expr(),
78 agg_exec.aggr_expr(),
79 agg_exec.filter_expr(),
80 ),
81 (
82 input_agg_exec.group_expr(),
83 input_agg_exec.aggr_expr(),
84 input_agg_exec.filter_expr(),
85 ),
86 ) {
87 let mode = if agg_exec.mode() == &AggregateMode::Final {
88 AggregateMode::Single
89 } else {
90 AggregateMode::SinglePartitioned
91 };
92 AggregateExec::try_new(
93 mode,
94 input_agg_exec.group_expr().clone(),
95 input_agg_exec.aggr_expr().to_vec(),
96 input_agg_exec.filter_expr().to_vec(),
97 Arc::clone(input_agg_exec.input()),
98 input_agg_exec.input_schema(),
99 )
100 .map(|combined_agg| {
101 combined_agg.with_limit_options(agg_exec.limit_options())
102 })
103 .ok()
104 .map(Arc::new)
105 } else {
106 None
107 };
108 Ok(if let Some(transformed) = transformed {
109 Transformed::yes(transformed)
110 } else {
111 Transformed::no(plan)
112 })
113 })
114 .data()
115 }
116
117 fn name(&self) -> &str {
118 "CombinePartialFinalAggregate"
119 }
120
121 fn schema_check(&self) -> bool {
122 true
123 }
124}
125
126type GroupExprsRef<'a> = (
127 &'a PhysicalGroupBy,
128 &'a [Arc<AggregateFunctionExpr>],
129 &'a [Option<Arc<dyn PhysicalExpr>>],
130);
131
132fn can_combine(final_agg: GroupExprsRef, partial_agg: GroupExprsRef) -> bool {
133 let (final_group_by, final_aggr_expr, final_filter_expr) = final_agg;
134 let (input_group_by, input_aggr_expr, input_filter_expr) = partial_agg;
135
136 physical_exprs_equal(
138 &input_group_by.output_exprs(),
139 &final_group_by.input_exprs(),
140 ) && input_group_by.groups() == final_group_by.groups()
141 && input_group_by.null_expr().len() == final_group_by.null_expr().len()
142 && input_group_by
143 .null_expr()
144 .iter()
145 .zip(final_group_by.null_expr().iter())
146 .all(|((lhs_expr, lhs_str), (rhs_expr, rhs_str))| {
147 lhs_expr.eq(rhs_expr) && lhs_str == rhs_str
148 })
149 && final_aggr_expr.len() == input_aggr_expr.len()
150 && final_aggr_expr
151 .iter()
152 .zip(input_aggr_expr.iter())
153 .all(|(final_expr, partial_expr)| final_expr.eq(partial_expr))
154 && final_filter_expr.len() == input_filter_expr.len()
155 && final_filter_expr.iter().zip(input_filter_expr.iter()).all(
156 |(final_expr, partial_expr)| match (final_expr, partial_expr) {
157 (Some(l), Some(r)) => l.eq(r),
158 (None, None) => true,
159 _ => false,
160 },
161 )
162}
163
164