Skip to main content

datafusion_physical_optimizer/
topk_aggregation.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! An optimizer rule that detects aggregate operations that could use a limited bucket count
19
20use std::sync::Arc;
21
22use crate::PhysicalOptimizerRule;
23use datafusion_common::Result;
24use datafusion_common::config::ConfigOptions;
25use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
26use datafusion_physical_expr::expressions::Column;
27use datafusion_physical_plan::ExecutionPlan;
28use datafusion_physical_plan::aggregates::LimitOptions;
29use datafusion_physical_plan::aggregates::{AggregateExec, topk_types_supported};
30use datafusion_physical_plan::execution_plan::CardinalityEffect;
31use datafusion_physical_plan::projection::ProjectionExec;
32use datafusion_physical_plan::sorts::sort::SortExec;
33use itertools::Itertools;
34
/// Physical optimizer rule that forwards a `limit` hint into aggregations
/// whenever only part of their output is actually required.
///
/// The rule carries no state; all of its logic lives in its associated functions.
#[derive(Debug)]
pub struct TopKAggregation {}
38
39impl TopKAggregation {
40    /// Create a new `LimitAggregation`
41    pub fn new() -> Self {
42        Self {}
43    }
44
45    fn transform_agg(
46        aggr: &AggregateExec,
47        order_by: &str,
48        order_desc: bool,
49        limit: usize,
50    ) -> Option<Arc<dyn ExecutionPlan>> {
51        // Current only support single group key
52        let (group_key, group_key_alias) =
53            aggr.group_expr().expr().iter().exactly_one().ok()?;
54        let kt = group_key.data_type(&aggr.input().schema()).ok()?;
55        let vt = if let Some((field, _)) = aggr.get_minmax_desc() {
56            field.data_type().clone()
57        } else {
58            kt.clone()
59        };
60        if !topk_types_supported(&kt, &vt) {
61            return None;
62        }
63        if aggr.filter_expr().iter().any(|e| e.is_some()) {
64            return None;
65        }
66
67        // Check if this is ordering by an aggregate function (MIN/MAX)
68        if let Some((field, desc)) = aggr.get_minmax_desc() {
69            // ensure the sort direction matches aggregate function
70            if desc != order_desc {
71                return None;
72            }
73            // ensure the sort is on the same field as the aggregate output
74            if order_by != field.name() {
75                return None;
76            }
77        } else if aggr.aggr_expr().is_empty() {
78            // This is a GROUP BY without aggregates, check if ordering is on the group key itself
79            if order_by != group_key_alias {
80                return None;
81            }
82        } else {
83            // Has aggregates but not MIN/MAX, or doesn't DISTINCT
84            return None;
85        }
86
87        // We found what we want: clone, copy the limit down, and return modified node
88        let new_aggr = AggregateExec::with_new_limit_options(
89            aggr,
90            Some(LimitOptions::new_with_order(limit, order_desc)),
91        );
92        Some(Arc::new(new_aggr))
93    }
94
95    fn transform_sort(plan: &Arc<dyn ExecutionPlan>) -> Option<Arc<dyn ExecutionPlan>> {
96        let sort = plan.as_any().downcast_ref::<SortExec>()?;
97
98        let children = sort.children();
99        let child = children.into_iter().exactly_one().ok()?;
100        let order = sort.properties().output_ordering()?;
101        let order = order.iter().exactly_one().ok()?;
102        let order_desc = order.options.descending;
103        let order = order.expr.as_any().downcast_ref::<Column>()?;
104        let mut cur_col_name = order.name().to_string();
105        let limit = sort.fetch()?;
106
107        let mut cardinality_preserved = true;
108        let closure = |plan: Arc<dyn ExecutionPlan>| {
109            if !cardinality_preserved {
110                return Ok(Transformed::no(plan));
111            }
112            if let Some(aggr) = plan.as_any().downcast_ref::<AggregateExec>() {
113                // either we run into an Aggregate and transform it
114                match Self::transform_agg(aggr, &cur_col_name, order_desc, limit) {
115                    None => cardinality_preserved = false,
116                    Some(plan) => return Ok(Transformed::yes(plan)),
117                }
118            } else if let Some(proj) = plan.as_any().downcast_ref::<ProjectionExec>() {
119                // track renames due to successive projections
120                for proj_expr in proj.expr() {
121                    let Some(src_col) = proj_expr.expr.as_any().downcast_ref::<Column>()
122                    else {
123                        continue;
124                    };
125                    if proj_expr.alias == cur_col_name {
126                        cur_col_name = src_col.name().to_string();
127                    }
128                }
129            } else {
130                // or we continue down through types that don't reduce cardinality
131                match plan.cardinality_effect() {
132                    CardinalityEffect::Equal | CardinalityEffect::GreaterEqual => {}
133                    CardinalityEffect::Unknown | CardinalityEffect::LowerEqual => {
134                        cardinality_preserved = false;
135                    }
136                }
137            }
138            Ok(Transformed::no(plan))
139        };
140        let child = Arc::clone(child).transform_down(closure).data().ok()?;
141        let sort = SortExec::new(sort.expr().clone(), child)
142            .with_fetch(sort.fetch())
143            .with_preserve_partitioning(sort.preserve_partitioning());
144        Some(Arc::new(sort))
145    }
146}
147
148impl Default for TopKAggregation {
149    fn default() -> Self {
150        Self::new()
151    }
152}
153
154impl PhysicalOptimizerRule for TopKAggregation {
155    fn optimize(
156        &self,
157        plan: Arc<dyn ExecutionPlan>,
158        config: &ConfigOptions,
159    ) -> Result<Arc<dyn ExecutionPlan>> {
160        if config.optimizer.enable_topk_aggregation {
161            plan.transform_down(|plan| {
162                Ok(if let Some(plan) = TopKAggregation::transform_sort(&plan) {
163                    Transformed::yes(plan)
164                } else {
165                    Transformed::no(plan)
166                })
167            })
168            .data()
169        } else {
170            Ok(plan)
171        }
172    }
173
174    fn name(&self) -> &str {
175        "LimitAggregation"
176    }
177
178    fn schema_check(&self) -> bool {
179        true
180    }
181}
182
183// see `aggregate.slt` for tests