datafusion_physical_optimizer/
topk_aggregation.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! An optimizer rule that detects aggregate operations that could use a limited bucket count
19
20use std::sync::Arc;
21
22use crate::PhysicalOptimizerRule;
23use arrow::datatypes::DataType;
24use datafusion_common::config::ConfigOptions;
25use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
26use datafusion_common::Result;
27use datafusion_physical_expr::expressions::Column;
28use datafusion_physical_expr::LexOrdering;
29use datafusion_physical_plan::aggregates::AggregateExec;
30use datafusion_physical_plan::execution_plan::CardinalityEffect;
31use datafusion_physical_plan::projection::ProjectionExec;
32use datafusion_physical_plan::sorts::sort::SortExec;
33use datafusion_physical_plan::ExecutionPlan;
34use itertools::Itertools;
35
/// An optimizer rule that passes a `limit` hint to aggregations if the whole result is not needed.
///
/// The struct has no fields, so it holds no state of its own; whether the rule
/// does anything is decided in `optimize` by the
/// `optimizer.enable_topk_aggregation` configuration flag.
#[derive(Debug)]
pub struct TopKAggregation {}
39
40impl TopKAggregation {
41    /// Create a new `LimitAggregation`
42    pub fn new() -> Self {
43        Self {}
44    }
45
46    fn transform_agg(
47        aggr: &AggregateExec,
48        order_by: &str,
49        order_desc: bool,
50        limit: usize,
51    ) -> Option<Arc<dyn ExecutionPlan>> {
52        // ensure the sort direction matches aggregate function
53        let (field, desc) = aggr.get_minmax_desc()?;
54        if desc != order_desc {
55            return None;
56        }
57        let group_key = aggr.group_expr().expr().iter().exactly_one().ok()?;
58        let kt = group_key.0.data_type(&aggr.input().schema()).ok()?;
59        if !kt.is_primitive()
60            && kt != DataType::Utf8
61            && kt != DataType::Utf8View
62            && kt != DataType::LargeUtf8
63        {
64            return None;
65        }
66        if aggr.filter_expr().iter().any(|e| e.is_some()) {
67            return None;
68        }
69
70        // ensure the sort is on the same field as the aggregate output
71        if order_by != field.name() {
72            return None;
73        }
74
75        // We found what we want: clone, copy the limit down, and return modified node
76        let new_aggr = AggregateExec::try_new(
77            *aggr.mode(),
78            aggr.group_expr().clone(),
79            aggr.aggr_expr().to_vec(),
80            aggr.filter_expr().to_vec(),
81            Arc::clone(aggr.input()),
82            aggr.input_schema(),
83        )
84        .expect("Unable to copy Aggregate!")
85        .with_limit(Some(limit));
86        Some(Arc::new(new_aggr))
87    }
88
89    fn transform_sort(plan: &Arc<dyn ExecutionPlan>) -> Option<Arc<dyn ExecutionPlan>> {
90        let sort = plan.as_any().downcast_ref::<SortExec>()?;
91
92        let children = sort.children();
93        let child = children.into_iter().exactly_one().ok()?;
94        let order = sort.properties().output_ordering()?;
95        let order = order.iter().exactly_one().ok()?;
96        let order_desc = order.options.descending;
97        let order = order.expr.as_any().downcast_ref::<Column>()?;
98        let mut cur_col_name = order.name().to_string();
99        let limit = sort.fetch()?;
100
101        let mut cardinality_preserved = true;
102        let closure = |plan: Arc<dyn ExecutionPlan>| {
103            if !cardinality_preserved {
104                return Ok(Transformed::no(plan));
105            }
106            if let Some(aggr) = plan.as_any().downcast_ref::<AggregateExec>() {
107                // either we run into an Aggregate and transform it
108                match Self::transform_agg(aggr, &cur_col_name, order_desc, limit) {
109                    None => cardinality_preserved = false,
110                    Some(plan) => return Ok(Transformed::yes(plan)),
111                }
112            } else if let Some(proj) = plan.as_any().downcast_ref::<ProjectionExec>() {
113                // track renames due to successive projections
114                for (src_expr, proj_name) in proj.expr() {
115                    let Some(src_col) = src_expr.as_any().downcast_ref::<Column>() else {
116                        continue;
117                    };
118                    if *proj_name == cur_col_name {
119                        cur_col_name = src_col.name().to_string();
120                    }
121                }
122            } else {
123                // or we continue down through types that don't reduce cardinality
124                match plan.cardinality_effect() {
125                    CardinalityEffect::Equal | CardinalityEffect::GreaterEqual => {}
126                    CardinalityEffect::Unknown | CardinalityEffect::LowerEqual => {
127                        cardinality_preserved = false;
128                    }
129                }
130            }
131            Ok(Transformed::no(plan))
132        };
133        let child = Arc::clone(child).transform_down(closure).data().ok()?;
134        let sort = SortExec::new(LexOrdering::new(sort.expr().to_vec()), child)
135            .with_fetch(sort.fetch())
136            .with_preserve_partitioning(sort.preserve_partitioning());
137        Some(Arc::new(sort))
138    }
139}
140
141impl Default for TopKAggregation {
142    fn default() -> Self {
143        Self::new()
144    }
145}
146
147impl PhysicalOptimizerRule for TopKAggregation {
148    fn optimize(
149        &self,
150        plan: Arc<dyn ExecutionPlan>,
151        config: &ConfigOptions,
152    ) -> Result<Arc<dyn ExecutionPlan>> {
153        if config.optimizer.enable_topk_aggregation {
154            plan.transform_down(|plan| {
155                Ok(if let Some(plan) = TopKAggregation::transform_sort(&plan) {
156                    Transformed::yes(plan)
157                } else {
158                    Transformed::no(plan)
159                })
160            })
161            .data()
162        } else {
163            Ok(plan)
164        }
165    }
166
167    fn name(&self) -> &str {
168        "LimitAggregation"
169    }
170
171    fn schema_check(&self) -> bool {
172        true
173    }
174}
175
176// see `aggregate.slt` for tests