datafusion_physical_optimizer/
topk_repartition.rs1use crate::PhysicalOptimizerRule;
48use datafusion_common::Result;
49use datafusion_common::config::ConfigOptions;
50use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
51use std::sync::Arc;
52#[expect(deprecated)]
55use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec;
56use datafusion_physical_plan::repartition::RepartitionExec;
57use datafusion_physical_plan::sorts::sort::SortExec;
58use datafusion_physical_plan::{ExecutionPlan, Partitioning};
59
/// Physical optimizer rule type; it carries no state of its own.
///
/// The rewrite it performs is implemented in its `PhysicalOptimizerRule`
/// implementation.
#[derive(Clone, Debug, Default)]
pub struct TopKRepartition;
66
67impl TopKRepartition {
68 pub fn new() -> Self {
69 Self {}
70 }
71}
72
73impl PhysicalOptimizerRule for TopKRepartition {
74 #[expect(deprecated)] fn optimize(
76 &self,
77 plan: Arc<dyn ExecutionPlan>,
78 config: &ConfigOptions,
79 ) -> Result<Arc<dyn ExecutionPlan>> {
80 if !config.optimizer.enable_topk_repartition {
81 return Ok(plan);
82 }
83 plan.transform_down(|node| {
84 let Some(sort_exec) = node.downcast_ref::<SortExec>() else {
86 return Ok(Transformed::no(node));
87 };
88 let Some(fetch) = sort_exec.fetch() else {
89 return Ok(Transformed::no(node));
90 };
91
92 let sort_input = sort_exec.input();
94 let (repart_parent, repart_exec) = if let Some(rp) =
95 sort_input.downcast_ref::<RepartitionExec>()
96 {
97 (None, rp)
99 } else if let Some(cb_exec) = sort_input.downcast_ref::<CoalesceBatchesExec>()
100 {
101 let cb_input = cb_exec.input();
104 let Some(rp) = cb_input.downcast_ref::<RepartitionExec>() else {
105 return Ok(Transformed::no(node));
106 };
107 (Some(Arc::clone(sort_input)), rp)
108 } else {
109 return Ok(Transformed::no(node));
110 };
111
112 let Partitioning::Hash(hash_exprs, num_partitions) =
114 repart_exec.partitioning()
115 else {
116 return Ok(Transformed::no(node));
117 };
118
119 let sort_exprs = sort_exec.expr();
120
121 if hash_exprs.len() > sort_exprs.len() {
125 return Ok(Transformed::no(node));
126 }
127 for (hash_expr, sort_expr) in hash_exprs.iter().zip(sort_exprs.iter()) {
128 if !hash_expr.eq(&sort_expr.expr) {
129 return Ok(Transformed::no(node));
130 }
131 }
132
133 let repart_input = repart_exec.input();
136 if repart_input.is::<SortExec>() {
137 return Ok(Transformed::no(node));
138 }
139
140 let new_sort: Arc<dyn ExecutionPlan> = Arc::new(
142 SortExec::new(sort_exprs.clone(), Arc::clone(repart_input))
143 .with_fetch(Some(fetch))
144 .with_preserve_partitioning(sort_exec.preserve_partitioning()),
145 );
146
147 let new_partitioning =
148 Partitioning::Hash(hash_exprs.clone(), *num_partitions);
149 let new_repartition: Arc<dyn ExecutionPlan> =
150 Arc::new(RepartitionExec::try_new(new_sort, new_partitioning)?);
151
152 let new_sort_input = if let Some(parent) = repart_parent {
154 parent.with_new_children(vec![new_repartition])?
155 } else {
156 new_repartition
157 };
158
159 let new_top_sort: Arc<dyn ExecutionPlan> = Arc::new(
160 SortExec::new(sort_exprs.clone(), new_sort_input)
161 .with_fetch(Some(fetch))
162 .with_preserve_partitioning(sort_exec.preserve_partitioning()),
163 );
164
165 Ok(Transformed::yes(new_top_sort))
166 })
167 .data()
168 }
169
170 fn name(&self) -> &str {
171 "TopKRepartition"
172 }
173
174 fn schema_check(&self) -> bool {
175 true
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_physical_expr::expressions::col;
    use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
    use datafusion_physical_plan::displayable;
    use datafusion_physical_plan::test::scan_partitioned;
    use insta::assert_snapshot;

    /// Two-column schema (`a: Utf8`, `b: Int64`) used to resolve the column
    /// references in the tests below.
    fn schema() -> Arc<Schema> {
        let fields = vec![
            Field::new("a", DataType::Utf8, false),
            Field::new("b", DataType::Int64, false),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Ordering `a ASC, b ASC` over `schema`.
    fn sort_exprs(schema: &Schema) -> LexOrdering {
        let keys: Vec<PhysicalSortExpr> = ["a", "b"]
            .into_iter()
            .map(|name| PhysicalSortExpr::new_default(col(name, schema).unwrap()).asc())
            .collect();
        LexOrdering::new(keys).unwrap()
    }

    /// Single-partition scan hash-repartitioned into 4 partitions on `key`.
    fn hash_repartition_on(key: &str) -> Arc<dyn ExecutionPlan> {
        let partitioning = Partitioning::Hash(vec![col(key, &schema()).unwrap()], 4);
        Arc::new(RepartitionExec::try_new(scan_partitioned(1), partitioning).unwrap())
    }

    /// Partition-preserving sort on `a ASC, b ASC` over `input`. The fetch
    /// limit is applied only when `fetch` is `Some`.
    fn partitioned_sort(
        input: Arc<dyn ExecutionPlan>,
        fetch: Option<usize>,
    ) -> Arc<dyn ExecutionPlan> {
        let sort = SortExec::new(sort_exprs(&schema()), input);
        let sort = match fetch {
            Some(_) => sort.with_fetch(fetch),
            None => sort,
        };
        Arc::new(sort.with_preserve_partitioning(true))
    }

    /// Runs the rule with default options and renders the resulting plan.
    fn optimize_and_display(plan: Arc<dyn ExecutionPlan>) -> String {
        let config = ConfigOptions::new();
        let optimized = TopKRepartition::new().optimize(plan, &config).unwrap();
        displayable(optimized.as_ref()).indent(true).to_string()
    }

    #[test]
    fn topk_pushed_below_hash_repartition() {
        let plan = partitioned_sort(hash_repartition_on("a"), Some(3));

        let display = optimize_and_display(plan);
        assert_snapshot!(display, @r"
        SortExec: TopK(fetch=3), expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[true], sort_prefix=[a@0 ASC]
          RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1, maintains_sort_order=true
            SortExec: TopK(fetch=3), expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[true]
              DataSourceExec: partitions=1, partition_sizes=[1]
        ");
    }

    #[test]
    fn unbounded_sort_not_pushed() {
        // No fetch limit: the sort is not a TopK.
        let plan = partitioned_sort(hash_repartition_on("a"), None);

        let display = optimize_and_display(plan);
        assert_snapshot!(display, @r"
        SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[true]
          RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1
            DataSourceExec: partitions=1, partition_sizes=[1]
        ");
    }

    #[test]
    fn non_prefix_hash_key_not_pushed() {
        // Hash key `b` is not the leading sort key (`a`).
        let plan = partitioned_sort(hash_repartition_on("b"), Some(3));

        let display = optimize_and_display(plan);
        assert_snapshot!(display, @r"
        SortExec: TopK(fetch=3), expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[true]
          RepartitionExec: partitioning=Hash([b@1], 4), input_partitions=1
            DataSourceExec: partitions=1, partition_sizes=[1]
        ");
    }

    #[expect(deprecated)]
    #[test]
    fn topk_pushed_through_coalesce_batches() {
        let coalesce: Arc<dyn ExecutionPlan> =
            Arc::new(CoalesceBatchesExec::new(hash_repartition_on("a"), 8192));
        let plan = partitioned_sort(coalesce, Some(3));

        let display = optimize_and_display(plan);
        assert_snapshot!(display, @r"
        SortExec: TopK(fetch=3), expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[true], sort_prefix=[a@0 ASC]
          CoalesceBatchesExec: target_batch_size=8192
            RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1, maintains_sort_order=true
              SortExec: TopK(fetch=3), expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[true]
                DataSourceExec: partitions=1, partition_sizes=[1]
        ");
    }

    #[test]
    fn round_robin_not_pushed() {
        let repartition: Arc<dyn ExecutionPlan> = Arc::new(
            RepartitionExec::try_new(scan_partitioned(1), Partitioning::RoundRobinBatch(4))
                .unwrap(),
        );
        let plan = partitioned_sort(repartition, Some(3));

        let display = optimize_and_display(plan);
        assert_snapshot!(display, @r"
        SortExec: TopK(fetch=3), expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[true]
          RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
            DataSourceExec: partitions=1, partition_sizes=[1]
        ");
    }
}