datafusion_physical_optimizer/
topk_aggregation.rs1use std::sync::Arc;
21
22use crate::PhysicalOptimizerRule;
23use arrow::datatypes::DataType;
24use datafusion_common::config::ConfigOptions;
25use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
26use datafusion_common::Result;
27use datafusion_physical_expr::expressions::Column;
28use datafusion_physical_expr::LexOrdering;
29use datafusion_physical_plan::aggregates::AggregateExec;
30use datafusion_physical_plan::execution_plan::CardinalityEffect;
31use datafusion_physical_plan::projection::ProjectionExec;
32use datafusion_physical_plan::sorts::sort::SortExec;
33use datafusion_physical_plan::ExecutionPlan;
34use itertools::Itertools;
35
/// Physical optimizer rule that pushes the `fetch` limit of a sort down into
/// an aggregation feeding that sort.
///
/// The rule carries no state; all of its behaviour lives in its
/// `PhysicalOptimizerRule` implementation.
#[derive(Debug)]
pub struct TopKAggregation {}
39
40impl TopKAggregation {
41 pub fn new() -> Self {
43 Self {}
44 }
45
46 fn transform_agg(
47 aggr: &AggregateExec,
48 order_by: &str,
49 order_desc: bool,
50 limit: usize,
51 ) -> Option<Arc<dyn ExecutionPlan>> {
52 let (field, desc) = aggr.get_minmax_desc()?;
54 if desc != order_desc {
55 return None;
56 }
57 let group_key = aggr.group_expr().expr().iter().exactly_one().ok()?;
58 let kt = group_key.0.data_type(&aggr.input().schema()).ok()?;
59 if !kt.is_primitive()
60 && kt != DataType::Utf8
61 && kt != DataType::Utf8View
62 && kt != DataType::LargeUtf8
63 {
64 return None;
65 }
66 if aggr.filter_expr().iter().any(|e| e.is_some()) {
67 return None;
68 }
69
70 if order_by != field.name() {
72 return None;
73 }
74
75 let new_aggr = AggregateExec::try_new(
77 *aggr.mode(),
78 aggr.group_expr().clone(),
79 aggr.aggr_expr().to_vec(),
80 aggr.filter_expr().to_vec(),
81 Arc::clone(aggr.input()),
82 aggr.input_schema(),
83 )
84 .expect("Unable to copy Aggregate!")
85 .with_limit(Some(limit));
86 Some(Arc::new(new_aggr))
87 }
88
89 fn transform_sort(plan: &Arc<dyn ExecutionPlan>) -> Option<Arc<dyn ExecutionPlan>> {
90 let sort = plan.as_any().downcast_ref::<SortExec>()?;
91
92 let children = sort.children();
93 let child = children.into_iter().exactly_one().ok()?;
94 let order = sort.properties().output_ordering()?;
95 let order = order.iter().exactly_one().ok()?;
96 let order_desc = order.options.descending;
97 let order = order.expr.as_any().downcast_ref::<Column>()?;
98 let mut cur_col_name = order.name().to_string();
99 let limit = sort.fetch()?;
100
101 let mut cardinality_preserved = true;
102 let closure = |plan: Arc<dyn ExecutionPlan>| {
103 if !cardinality_preserved {
104 return Ok(Transformed::no(plan));
105 }
106 if let Some(aggr) = plan.as_any().downcast_ref::<AggregateExec>() {
107 match Self::transform_agg(aggr, &cur_col_name, order_desc, limit) {
109 None => cardinality_preserved = false,
110 Some(plan) => return Ok(Transformed::yes(plan)),
111 }
112 } else if let Some(proj) = plan.as_any().downcast_ref::<ProjectionExec>() {
113 for (src_expr, proj_name) in proj.expr() {
115 let Some(src_col) = src_expr.as_any().downcast_ref::<Column>() else {
116 continue;
117 };
118 if *proj_name == cur_col_name {
119 cur_col_name = src_col.name().to_string();
120 }
121 }
122 } else {
123 match plan.cardinality_effect() {
125 CardinalityEffect::Equal | CardinalityEffect::GreaterEqual => {}
126 CardinalityEffect::Unknown | CardinalityEffect::LowerEqual => {
127 cardinality_preserved = false;
128 }
129 }
130 }
131 Ok(Transformed::no(plan))
132 };
133 let child = Arc::clone(child).transform_down(closure).data().ok()?;
134 let sort = SortExec::new(LexOrdering::new(sort.expr().to_vec()), child)
135 .with_fetch(sort.fetch())
136 .with_preserve_partitioning(sort.preserve_partitioning());
137 Some(Arc::new(sort))
138 }
139}
140
141impl Default for TopKAggregation {
142 fn default() -> Self {
143 Self::new()
144 }
145}
146
147impl PhysicalOptimizerRule for TopKAggregation {
148 fn optimize(
149 &self,
150 plan: Arc<dyn ExecutionPlan>,
151 config: &ConfigOptions,
152 ) -> Result<Arc<dyn ExecutionPlan>> {
153 if config.optimizer.enable_topk_aggregation {
154 plan.transform_down(|plan| {
155 Ok(if let Some(plan) = TopKAggregation::transform_sort(&plan) {
156 Transformed::yes(plan)
157 } else {
158 Transformed::no(plan)
159 })
160 })
161 .data()
162 } else {
163 Ok(plan)
164 }
165 }
166
167 fn name(&self) -> &str {
168 "LimitAggregation"
169 }
170
171 fn schema_check(&self) -> bool {
172 true
173 }
174}
175
176