datafusion_physical_optimizer/
combine_partial_final_agg.rs1use std::sync::Arc;
22
23use datafusion_common::error::Result;
24use datafusion_physical_plan::ExecutionPlan;
25use datafusion_physical_plan::aggregates::{
26 AggregateExec, AggregateMode, PhysicalGroupBy,
27};
28
29use crate::PhysicalOptimizerRule;
30use datafusion_common::config::ConfigOptions;
31use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
32use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
33use datafusion_physical_expr::{PhysicalExpr, physical_exprs_equal};
34
/// Physical optimizer rule that fuses a two-stage aggregation back into one node.
///
/// When a `Final` / `FinalPartitioned` [`AggregateExec`] sits directly on top of a
/// `Partial` [`AggregateExec`], and both stages describe the same grouping,
/// aggregate and filter expressions, the pair is replaced by one `Single` /
/// `SinglePartitioned` [`AggregateExec`] that reads from the partial stage's input.
///
/// The rule carries no state, so every instance behaves the same.
#[derive(Debug, Default)]
pub struct CombinePartialFinalAggregate {}
41
42impl CombinePartialFinalAggregate {
43 #[expect(missing_docs)]
44 pub fn new() -> Self {
45 Self {}
46 }
47}
48
49impl PhysicalOptimizerRule for CombinePartialFinalAggregate {
50 fn optimize(
51 &self,
52 plan: Arc<dyn ExecutionPlan>,
53 _config: &ConfigOptions,
54 ) -> Result<Arc<dyn ExecutionPlan>> {
55 plan.transform_down(|plan| {
56 let Some(agg_exec) = plan.as_any().downcast_ref::<AggregateExec>() else {
58 return Ok(Transformed::no(plan));
59 };
60
61 if !matches!(
62 agg_exec.mode(),
63 AggregateMode::Final | AggregateMode::FinalPartitioned
64 ) {
65 return Ok(Transformed::no(plan));
66 }
67
68 let Some(input_agg_exec) =
70 agg_exec.input().as_any().downcast_ref::<AggregateExec>()
71 else {
72 return Ok(Transformed::no(plan));
73 };
74
75 let transformed = if *input_agg_exec.mode() == AggregateMode::Partial
76 && can_combine(
77 (
78 agg_exec.group_expr(),
79 agg_exec.aggr_expr(),
80 agg_exec.filter_expr(),
81 ),
82 (
83 input_agg_exec.group_expr(),
84 input_agg_exec.aggr_expr(),
85 input_agg_exec.filter_expr(),
86 ),
87 ) {
88 let mode = if agg_exec.mode() == &AggregateMode::Final {
89 AggregateMode::Single
90 } else {
91 AggregateMode::SinglePartitioned
92 };
93 AggregateExec::try_new(
94 mode,
95 input_agg_exec.group_expr().clone(),
96 input_agg_exec.aggr_expr().to_vec(),
97 input_agg_exec.filter_expr().to_vec(),
98 Arc::clone(input_agg_exec.input()),
99 input_agg_exec.input_schema(),
100 )
101 .map(|combined_agg| {
102 combined_agg.with_limit_options(agg_exec.limit_options())
103 })
104 .ok()
105 .map(Arc::new)
106 } else {
107 None
108 };
109 Ok(if let Some(transformed) = transformed {
110 Transformed::yes(transformed)
111 } else {
112 Transformed::no(plan)
113 })
114 })
115 .data()
116 }
117
118 fn name(&self) -> &str {
119 "CombinePartialFinalAggregate"
120 }
121
122 fn schema_check(&self) -> bool {
123 true
124 }
125}
126
127type GroupExprsRef<'a> = (
128 &'a PhysicalGroupBy,
129 &'a [Arc<AggregateFunctionExpr>],
130 &'a [Option<Arc<dyn PhysicalExpr>>],
131);
132
133fn can_combine(final_agg: GroupExprsRef, partial_agg: GroupExprsRef) -> bool {
134 let (final_group_by, final_aggr_expr, final_filter_expr) = final_agg;
135 let (input_group_by, input_aggr_expr, input_filter_expr) = partial_agg;
136
137 physical_exprs_equal(
139 &input_group_by.output_exprs(),
140 &final_group_by.input_exprs(),
141 ) && input_group_by.groups() == final_group_by.groups()
142 && input_group_by.null_expr().len() == final_group_by.null_expr().len()
143 && input_group_by
144 .null_expr()
145 .iter()
146 .zip(final_group_by.null_expr().iter())
147 .all(|((lhs_expr, lhs_str), (rhs_expr, rhs_str))| {
148 lhs_expr.eq(rhs_expr) && lhs_str == rhs_str
149 })
150 && final_aggr_expr.len() == input_aggr_expr.len()
151 && final_aggr_expr
152 .iter()
153 .zip(input_aggr_expr.iter())
154 .all(|(final_expr, partial_expr)| final_expr.eq(partial_expr))
155 && final_filter_expr.len() == input_filter_expr.len()
156 && final_filter_expr.iter().zip(input_filter_expr.iter()).all(
157 |(final_expr, partial_expr)| match (final_expr, partial_expr) {
158 (Some(l), Some(r)) => l.eq(r),
159 (None, None) => true,
160 _ => false,
161 },
162 )
163}
164
165