datafusion_physical_optimizer/
topk_aggregation.rs1use std::sync::Arc;
21
22use crate::PhysicalOptimizerRule;
23use datafusion_common::Result;
24use datafusion_common::config::ConfigOptions;
25use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
26use datafusion_physical_expr::expressions::Column;
27use datafusion_physical_plan::ExecutionPlan;
28use datafusion_physical_plan::aggregates::LimitOptions;
29use datafusion_physical_plan::aggregates::{AggregateExec, topk_types_supported};
30use datafusion_physical_plan::execution_plan::CardinalityEffect;
31use datafusion_physical_plan::projection::ProjectionExec;
32use datafusion_physical_plan::sorts::sort::SortExec;
33use itertools::Itertools;
34
/// Physical optimizer rule that looks for `SortExec` nodes carrying a fetch
/// limit and, where the aggregation beneath them allows it, attaches limit
/// options to that `AggregateExec` so the whole aggregate result need not be
/// produced.
///
/// The type carries no state; construct it with `TopKAggregation::new()` or
/// `TopKAggregation::default()`.
#[derive(Debug)]
pub struct TopKAggregation {}
38
39impl TopKAggregation {
40 pub fn new() -> Self {
42 Self {}
43 }
44
45 fn transform_agg(
46 aggr: &AggregateExec,
47 order_by: &str,
48 order_desc: bool,
49 limit: usize,
50 ) -> Option<Arc<dyn ExecutionPlan>> {
51 let (group_key, group_key_alias) =
53 aggr.group_expr().expr().iter().exactly_one().ok()?;
54 let kt = group_key.data_type(&aggr.input().schema()).ok()?;
55 let vt = if let Some((field, _)) = aggr.get_minmax_desc() {
56 field.data_type().clone()
57 } else {
58 kt.clone()
59 };
60 if !topk_types_supported(&kt, &vt) {
61 return None;
62 }
63 if aggr.filter_expr().iter().any(|e| e.is_some()) {
64 return None;
65 }
66
67 if let Some((field, desc)) = aggr.get_minmax_desc() {
69 if desc != order_desc {
71 return None;
72 }
73 if order_by != field.name() {
75 return None;
76 }
77 } else if aggr.aggr_expr().is_empty() {
78 if order_by != group_key_alias {
80 return None;
81 }
82 } else {
83 return None;
85 }
86
87 let new_aggr = AggregateExec::with_new_limit_options(
89 aggr,
90 Some(LimitOptions::new_with_order(limit, order_desc)),
91 );
92 Some(Arc::new(new_aggr))
93 }
94
95 fn transform_sort(plan: &Arc<dyn ExecutionPlan>) -> Option<Arc<dyn ExecutionPlan>> {
96 let sort = plan.as_any().downcast_ref::<SortExec>()?;
97
98 let children = sort.children();
99 let child = children.into_iter().exactly_one().ok()?;
100 let order = sort.properties().output_ordering()?;
101 let order = order.iter().exactly_one().ok()?;
102 let order_desc = order.options.descending;
103 let order = order.expr.as_any().downcast_ref::<Column>()?;
104 let mut cur_col_name = order.name().to_string();
105 let limit = sort.fetch()?;
106
107 let mut cardinality_preserved = true;
108 let closure = |plan: Arc<dyn ExecutionPlan>| {
109 if !cardinality_preserved {
110 return Ok(Transformed::no(plan));
111 }
112 if let Some(aggr) = plan.as_any().downcast_ref::<AggregateExec>() {
113 match Self::transform_agg(aggr, &cur_col_name, order_desc, limit) {
115 None => cardinality_preserved = false,
116 Some(plan) => return Ok(Transformed::yes(plan)),
117 }
118 } else if let Some(proj) = plan.as_any().downcast_ref::<ProjectionExec>() {
119 for proj_expr in proj.expr() {
121 let Some(src_col) = proj_expr.expr.as_any().downcast_ref::<Column>()
122 else {
123 continue;
124 };
125 if proj_expr.alias == cur_col_name {
126 cur_col_name = src_col.name().to_string();
127 }
128 }
129 } else {
130 match plan.cardinality_effect() {
132 CardinalityEffect::Equal | CardinalityEffect::GreaterEqual => {}
133 CardinalityEffect::Unknown | CardinalityEffect::LowerEqual => {
134 cardinality_preserved = false;
135 }
136 }
137 }
138 Ok(Transformed::no(plan))
139 };
140 let child = Arc::clone(child).transform_down(closure).data().ok()?;
141 let sort = SortExec::new(sort.expr().clone(), child)
142 .with_fetch(sort.fetch())
143 .with_preserve_partitioning(sort.preserve_partitioning());
144 Some(Arc::new(sort))
145 }
146}
147
148impl Default for TopKAggregation {
149 fn default() -> Self {
150 Self::new()
151 }
152}
153
154impl PhysicalOptimizerRule for TopKAggregation {
155 fn optimize(
156 &self,
157 plan: Arc<dyn ExecutionPlan>,
158 config: &ConfigOptions,
159 ) -> Result<Arc<dyn ExecutionPlan>> {
160 if config.optimizer.enable_topk_aggregation {
161 plan.transform_down(|plan| {
162 Ok(if let Some(plan) = TopKAggregation::transform_sort(&plan) {
163 Transformed::yes(plan)
164 } else {
165 Transformed::no(plan)
166 })
167 })
168 .data()
169 } else {
170 Ok(plan)
171 }
172 }
173
174 fn name(&self) -> &str {
175 "LimitAggregation"
176 }
177
178 fn schema_check(&self) -> bool {
179 true
180 }
181}
182
183