Skip to main content

datafusion_physical_optimizer/
combine_partial_final_agg.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! CombinePartialFinalAggregate optimizer rule checks adjacent Partial and Final AggregateExecs
//! and tries to combine them when possible.
20
21use std::sync::Arc;
22
23use datafusion_common::error::Result;
24use datafusion_physical_plan::ExecutionPlan;
25use datafusion_physical_plan::aggregates::{
26    AggregateExec, AggregateMode, PhysicalGroupBy,
27};
28
29use crate::PhysicalOptimizerRule;
30use datafusion_common::config::ConfigOptions;
31use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
32use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
33use datafusion_physical_expr::{PhysicalExpr, physical_exprs_equal};
34
/// Optimizer rule that combines an adjacent `Partial` and `Final` (or
/// `FinalPartitioned`) pair of [`AggregateExec`]s into a single `Single` (or
/// `SinglePartitioned`) [`AggregateExec`].
///
/// The pair is only combined when the partial stage's group-by outputs match the
/// final stage's group-by inputs and their grouping sets, null expressions,
/// aggregate expressions and filter expressions are equal.
///
/// This rule should be applied after the `EnforceDistribution` and
/// `EnforceSorting` rules.
#[derive(Default, Debug)]
pub struct CombinePartialFinalAggregate {}

impl CombinePartialFinalAggregate {
    /// Creates a new `CombinePartialFinalAggregate` rule.
    ///
    /// Equivalent to [`CombinePartialFinalAggregate::default`]; the rule holds no state.
    pub fn new() -> Self {
        Self {}
    }
}
48
49impl PhysicalOptimizerRule for CombinePartialFinalAggregate {
50    fn optimize(
51        &self,
52        plan: Arc<dyn ExecutionPlan>,
53        _config: &ConfigOptions,
54    ) -> Result<Arc<dyn ExecutionPlan>> {
55        plan.transform_down(|plan| {
56            // Check if the plan is AggregateExec
57            let Some(agg_exec) = plan.downcast_ref::<AggregateExec>() else {
58                return Ok(Transformed::no(plan));
59            };
60
61            if !matches!(
62                agg_exec.mode(),
63                AggregateMode::Final | AggregateMode::FinalPartitioned
64            ) {
65                return Ok(Transformed::no(plan));
66            }
67
68            // Check if the input is AggregateExec
69            let Some(input_agg_exec) = agg_exec.input().downcast_ref::<AggregateExec>()
70            else {
71                return Ok(Transformed::no(plan));
72            };
73
74            let transformed = if *input_agg_exec.mode() == AggregateMode::Partial
75                && can_combine(
76                    (
77                        agg_exec.group_expr(),
78                        agg_exec.aggr_expr(),
79                        agg_exec.filter_expr(),
80                    ),
81                    (
82                        input_agg_exec.group_expr(),
83                        input_agg_exec.aggr_expr(),
84                        input_agg_exec.filter_expr(),
85                    ),
86                ) {
87                let mode = if agg_exec.mode() == &AggregateMode::Final {
88                    AggregateMode::Single
89                } else {
90                    AggregateMode::SinglePartitioned
91                };
92                AggregateExec::try_new(
93                    mode,
94                    input_agg_exec.group_expr().clone(),
95                    input_agg_exec.aggr_expr().to_vec(),
96                    input_agg_exec.filter_expr().to_vec(),
97                    Arc::clone(input_agg_exec.input()),
98                    input_agg_exec.input_schema(),
99                )
100                .map(|combined_agg| {
101                    combined_agg.with_limit_options(agg_exec.limit_options())
102                })
103                .ok()
104                .map(Arc::new)
105            } else {
106                None
107            };
108            Ok(if let Some(transformed) = transformed {
109                Transformed::yes(transformed)
110            } else {
111                Transformed::no(plan)
112            })
113        })
114        .data()
115    }
116
117    fn name(&self) -> &str {
118        "CombinePartialFinalAggregate"
119    }
120
121    fn schema_check(&self) -> bool {
122        true
123    }
124}
125
126type GroupExprsRef<'a> = (
127    &'a PhysicalGroupBy,
128    &'a [Arc<AggregateFunctionExpr>],
129    &'a [Option<Arc<dyn PhysicalExpr>>],
130);
131
132fn can_combine(final_agg: GroupExprsRef, partial_agg: GroupExprsRef) -> bool {
133    let (final_group_by, final_aggr_expr, final_filter_expr) = final_agg;
134    let (input_group_by, input_aggr_expr, input_filter_expr) = partial_agg;
135
136    // Compare output expressions of the partial, and input expressions of the final operator.
137    physical_exprs_equal(
138        &input_group_by.output_exprs(),
139        &final_group_by.input_exprs(),
140    ) && input_group_by.groups() == final_group_by.groups()
141        && input_group_by.null_expr().len() == final_group_by.null_expr().len()
142        && input_group_by
143            .null_expr()
144            .iter()
145            .zip(final_group_by.null_expr().iter())
146            .all(|((lhs_expr, lhs_str), (rhs_expr, rhs_str))| {
147                lhs_expr.eq(rhs_expr) && lhs_str == rhs_str
148            })
149        && final_aggr_expr.len() == input_aggr_expr.len()
150        && final_aggr_expr
151            .iter()
152            .zip(input_aggr_expr.iter())
153            .all(|(final_expr, partial_expr)| final_expr.eq(partial_expr))
154        && final_filter_expr.len() == input_filter_expr.len()
155        && final_filter_expr.iter().zip(input_filter_expr.iter()).all(
156            |(final_expr, partial_expr)| match (final_expr, partial_expr) {
157                (Some(l), Some(r)) => l.eq(r),
158                (None, None) => true,
159                _ => false,
160            },
161        )
162}
163
164// See tests in datafusion/core/tests/physical_optimizer