datafusion_physical_optimizer/
combine_partial_final_agg.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! CombinePartialFinalAggregate optimizer rule checks adjacent Partial and Final AggregateExecs
//! and tries to combine them into a single AggregateExec when their expressions match.
20
21use std::sync::Arc;
22
23use datafusion_common::error::Result;
24use datafusion_physical_plan::aggregates::{
25    AggregateExec, AggregateMode, PhysicalGroupBy,
26};
27use datafusion_physical_plan::ExecutionPlan;
28
29use crate::PhysicalOptimizerRule;
30use datafusion_common::config::ConfigOptions;
31use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
32use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
33use datafusion_physical_expr::{physical_exprs_equal, PhysicalExpr};
34
/// Physical optimizer rule that fuses a `Final` / `FinalPartitioned`
/// [`AggregateExec`] with the `Partial` [`AggregateExec`] directly beneath it
/// into a single `Single` / `SinglePartitioned` aggregate, provided both
/// stages agree on their grouping, aggregate and filter expressions.
///
/// Run this rule after the `EnforceDistribution` and `EnforceSorting` rules.
#[derive(Debug, Default)]
pub struct CombinePartialFinalAggregate {}
42
43impl CombinePartialFinalAggregate {
44    #[allow(missing_docs)]
45    pub fn new() -> Self {
46        Self {}
47    }
48}
49
50impl PhysicalOptimizerRule for CombinePartialFinalAggregate {
51    fn optimize(
52        &self,
53        plan: Arc<dyn ExecutionPlan>,
54        _config: &ConfigOptions,
55    ) -> Result<Arc<dyn ExecutionPlan>> {
56        plan.transform_down(|plan| {
57            // Check if the plan is AggregateExec
58            let Some(agg_exec) = plan.as_any().downcast_ref::<AggregateExec>() else {
59                return Ok(Transformed::no(plan));
60            };
61
62            if !matches!(
63                agg_exec.mode(),
64                AggregateMode::Final | AggregateMode::FinalPartitioned
65            ) {
66                return Ok(Transformed::no(plan));
67            }
68
69            // Check if the input is AggregateExec
70            let Some(input_agg_exec) =
71                agg_exec.input().as_any().downcast_ref::<AggregateExec>()
72            else {
73                return Ok(Transformed::no(plan));
74            };
75
76            let transformed = if matches!(input_agg_exec.mode(), AggregateMode::Partial)
77                && can_combine(
78                    (
79                        agg_exec.group_expr(),
80                        agg_exec.aggr_expr(),
81                        agg_exec.filter_expr(),
82                    ),
83                    (
84                        input_agg_exec.group_expr(),
85                        input_agg_exec.aggr_expr(),
86                        input_agg_exec.filter_expr(),
87                    ),
88                ) {
89                let mode = if agg_exec.mode() == &AggregateMode::Final {
90                    AggregateMode::Single
91                } else {
92                    AggregateMode::SinglePartitioned
93                };
94                AggregateExec::try_new(
95                    mode,
96                    input_agg_exec.group_expr().clone(),
97                    input_agg_exec.aggr_expr().to_vec(),
98                    input_agg_exec.filter_expr().to_vec(),
99                    Arc::clone(input_agg_exec.input()),
100                    input_agg_exec.input_schema(),
101                )
102                .map(|combined_agg| combined_agg.with_limit(agg_exec.limit()))
103                .ok()
104                .map(Arc::new)
105            } else {
106                None
107            };
108            Ok(if let Some(transformed) = transformed {
109                Transformed::yes(transformed)
110            } else {
111                Transformed::no(plan)
112            })
113        })
114        .data()
115    }
116
117    fn name(&self) -> &str {
118        "CombinePartialFinalAggregate"
119    }
120
121    fn schema_check(&self) -> bool {
122        true
123    }
124}
125
126type GroupExprsRef<'a> = (
127    &'a PhysicalGroupBy,
128    &'a [Arc<AggregateFunctionExpr>],
129    &'a [Option<Arc<dyn PhysicalExpr>>],
130);
131
132fn can_combine(final_agg: GroupExprsRef, partial_agg: GroupExprsRef) -> bool {
133    let (final_group_by, final_aggr_expr, final_filter_expr) = final_agg;
134    let (input_group_by, input_aggr_expr, input_filter_expr) = partial_agg;
135
136    // Compare output expressions of the partial, and input expressions of the final operator.
137    physical_exprs_equal(
138        &input_group_by.output_exprs(),
139        &final_group_by.input_exprs(),
140    ) && input_group_by.groups() == final_group_by.groups()
141        && input_group_by.null_expr().len() == final_group_by.null_expr().len()
142        && input_group_by
143            .null_expr()
144            .iter()
145            .zip(final_group_by.null_expr().iter())
146            .all(|((lhs_expr, lhs_str), (rhs_expr, rhs_str))| {
147                lhs_expr.eq(rhs_expr) && lhs_str == rhs_str
148            })
149        && final_aggr_expr.len() == input_aggr_expr.len()
150        && final_aggr_expr
151            .iter()
152            .zip(input_aggr_expr.iter())
153            .all(|(final_expr, partial_expr)| final_expr.eq(partial_expr))
154        && final_filter_expr.len() == input_filter_expr.len()
155        && final_filter_expr.iter().zip(input_filter_expr.iter()).all(
156            |(final_expr, partial_expr)| match (final_expr, partial_expr) {
157                (Some(l), Some(r)) => l.eq(r),
158                (None, None) => true,
159                _ => false,
160            },
161        )
162}
163
164// See tests in datafusion/core/tests/physical_optimizer