Skip to main content

datafusion_physical_optimizer/
topk_repartition.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Push TopK (Sort with fetch) past Hash Repartition
19//!
20//! When a `SortExec` with a fetch limit (TopK) sits above a
21//! `RepartitionExec(Hash)`, and the hash partition expressions are a prefix
22//! of the sort expressions, this rule inserts a copy of the TopK below
23//! the repartition to reduce the volume of data flowing through the shuffle.
24//!
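//! This is safe because a TopK only drops rows that cannot be among its first
//! `fetch` rows. A row in the overall first `fetch` rows has fewer than
//! `fetch` rows ordered before it within its own input partition, so it also
//! survives the per-input-partition TopK inserted below the repartition. The
//! hash partition key being a prefix of the sort key additionally guarantees
//! that all rows with the same partition key end up in the same output
//! partition.
//!
//! Note that the per-output-partition results of the upper TopK may differ
//! from those of the un-rewritten plan. Rows for one key can be crowded out
//! of the lower TopK by rows for another key in the same input partition.
//! Only the overall first `fetch` rows are guaranteed to be preserved.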
29//!
30//! ## Example
31//!
32//! Before:
33//! ```text
34//! SortExec: TopK(fetch=3), expr=[a ASC, b ASC]
35//!   RepartitionExec: Hash([a], 4)
36//!     DataSourceExec
37//! ```
38//!
39//! After:
40//! ```text
41//! SortExec: TopK(fetch=3), expr=[a ASC, b ASC]
42//!   RepartitionExec: Hash([a], 4)
43//!     SortExec: TopK(fetch=3), expr=[a ASC, b ASC]
44//!       DataSourceExec
45//! ```
46
47use crate::PhysicalOptimizerRule;
48use datafusion_common::Result;
49use datafusion_common::config::ConfigOptions;
50use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
51use std::sync::Arc;
52// CoalesceBatchesExec is deprecated on main (replaced by arrow-rs BatchCoalescer),
53// but older DataFusion versions may still insert it between SortExec and RepartitionExec.
54#[expect(deprecated)]
55use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec;
56use datafusion_physical_plan::repartition::RepartitionExec;
57use datafusion_physical_plan::sorts::sort::SortExec;
58use datafusion_physical_plan::{ExecutionPlan, Partitioning};
59
/// Physical optimizer rule that inserts a copy of a TopK (a `SortExec` with a
/// fetch limit) below a hash `RepartitionExec` whose partition expressions are
/// a prefix of the sort expressions, reducing the volume of data that flows
/// through the repartition.
///
/// The module-level documentation describes the matching conditions and shows
/// an example plan before and after the rewrite.
#[derive(Clone, Debug, Default)]
pub struct TopKRepartition;
66
67impl TopKRepartition {
68    pub fn new() -> Self {
69        Self {}
70    }
71}
72
73impl PhysicalOptimizerRule for TopKRepartition {
74    #[expect(deprecated)] // CoalesceBatchesExec: kept for older DataFusion versions
75    fn optimize(
76        &self,
77        plan: Arc<dyn ExecutionPlan>,
78        config: &ConfigOptions,
79    ) -> Result<Arc<dyn ExecutionPlan>> {
80        if !config.optimizer.enable_topk_repartition {
81            return Ok(plan);
82        }
83        plan.transform_down(|node| {
84            // Match SortExec with fetch (TopK)
85            let Some(sort_exec) = node.downcast_ref::<SortExec>() else {
86                return Ok(Transformed::no(node));
87            };
88            let Some(fetch) = sort_exec.fetch() else {
89                return Ok(Transformed::no(node));
90            };
91
92            // The child might be a CoalesceBatchesExec; look through it
93            let sort_input = sort_exec.input();
94            let (repart_parent, repart_exec) = if let Some(rp) =
95                sort_input.downcast_ref::<RepartitionExec>()
96            {
97                // found a RepartitionExec, use it
98                (None, rp)
99            } else if let Some(cb_exec) = sort_input.downcast_ref::<CoalesceBatchesExec>()
100            {
101                // There's a CoalesceBatchesExec between TopK & RepartitionExec
102                // in this case we will need to reconstruct both nodes
103                let cb_input = cb_exec.input();
104                let Some(rp) = cb_input.downcast_ref::<RepartitionExec>() else {
105                    return Ok(Transformed::no(node));
106                };
107                (Some(Arc::clone(sort_input)), rp)
108            } else {
109                return Ok(Transformed::no(node));
110            };
111
112            // Only handle Hash partitioning
113            let Partitioning::Hash(hash_exprs, num_partitions) =
114                repart_exec.partitioning()
115            else {
116                return Ok(Transformed::no(node));
117            };
118
119            let sort_exprs = sort_exec.expr();
120
121            // Check that hash expressions are a prefix of the sort expressions.
122            // Each hash expression must match the corresponding sort expression
123            // (ignoring sort options like ASC/DESC since hash doesn't care about order).
124            if hash_exprs.len() > sort_exprs.len() {
125                return Ok(Transformed::no(node));
126            }
127            for (hash_expr, sort_expr) in hash_exprs.iter().zip(sort_exprs.iter()) {
128                if !hash_expr.eq(&sort_expr.expr) {
129                    return Ok(Transformed::no(node));
130                }
131            }
132
133            // Don't push if the input to the repartition is already bounded
134            // (e.g., another TopK), as it would be redundant.
135            let repart_input = repart_exec.input();
136            if repart_input.is::<SortExec>() {
137                return Ok(Transformed::no(node));
138            }
139
140            // Insert a copy of the TopK below the repartition
141            let new_sort: Arc<dyn ExecutionPlan> = Arc::new(
142                SortExec::new(sort_exprs.clone(), Arc::clone(repart_input))
143                    .with_fetch(Some(fetch))
144                    .with_preserve_partitioning(sort_exec.preserve_partitioning()),
145            );
146
147            let new_partitioning =
148                Partitioning::Hash(hash_exprs.clone(), *num_partitions);
149            let new_repartition: Arc<dyn ExecutionPlan> =
150                Arc::new(RepartitionExec::try_new(new_sort, new_partitioning)?);
151
152            // Rebuild the tree above the repartition
153            let new_sort_input = if let Some(parent) = repart_parent {
154                parent.with_new_children(vec![new_repartition])?
155            } else {
156                new_repartition
157            };
158
159            let new_top_sort: Arc<dyn ExecutionPlan> = Arc::new(
160                SortExec::new(sort_exprs.clone(), new_sort_input)
161                    .with_fetch(Some(fetch))
162                    .with_preserve_partitioning(sort_exec.preserve_partitioning()),
163            );
164
165            Ok(Transformed::yes(new_top_sort))
166        })
167        .data()
168    }
169
170    fn name(&self) -> &str {
171        "TopKRepartition"
172    }
173
174    fn schema_check(&self) -> bool {
175        true
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_physical_expr::expressions::col;
    use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
    use datafusion_physical_plan::displayable;
    use datafusion_physical_plan::test::scan_partitioned;
    use insta::assert_snapshot;

    /// Schema `(a: Utf8, b: Int64)` used to resolve the column references in
    /// these tests.
    ///
    /// NOTE(review): the plan input comes from `scan_partitioned`, which builds
    /// its own batches; this file does not show that their schema matches this
    /// one. These tests only render plans and never execute them. Confirm the
    /// schemas agree before adding execution-based tests.
    fn schema() -> Arc<Schema> {
        let fields = vec![
            Field::new("a", DataType::Utf8, false),
            Field::new("b", DataType::Int64, false),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Ordering `[a ASC, b ASC]` resolved against `schema`.
    fn sort_exprs(schema: &Schema) -> LexOrdering {
        let column = |name: &str| col(name, schema).unwrap();
        LexOrdering::new(vec![
            PhysicalSortExpr::new_default(column("a")).asc(),
            PhysicalSortExpr::new_default(column("b")).asc(),
        ])
        .unwrap()
    }

    /// `Hash([column], 4)` partitioning on the named column of `schema`.
    fn hash_on(column: &str, schema: &Schema) -> Partitioning {
        Partitioning::Hash(vec![col(column, schema).unwrap()], 4)
    }

    /// A single-partition scan wrapped in a `RepartitionExec` using `partitioning`.
    fn repartitioned_scan(partitioning: Partitioning) -> Arc<RepartitionExec> {
        Arc::new(RepartitionExec::try_new(scan_partitioned(1), partitioning).unwrap())
    }

    /// Runs `TopKRepartition` with the default configuration and renders the
    /// resulting plan as indented text for snapshot comparison.
    fn optimized_display(plan: Arc<dyn ExecutionPlan>) -> String {
        let config = ConfigOptions::new();
        let optimized = TopKRepartition::new().optimize(plan, &config).unwrap();
        displayable(optimized.as_ref()).indent(true).to_string()
    }

    /// TopK above Hash(a) repartition should get pushed below it,
    /// because `a` is a prefix of the sort key `(a, b)`.
    #[test]
    fn topk_pushed_below_hash_repartition() {
        let s = schema();
        let repartition = repartitioned_scan(hash_on("a", &s));
        let topk = SortExec::new(sort_exprs(&s), repartition)
            .with_fetch(Some(3))
            .with_preserve_partitioning(true);

        assert_snapshot!(optimized_display(Arc::new(topk)), @r"
        SortExec: TopK(fetch=3), expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[true], sort_prefix=[a@0 ASC]
          RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1, maintains_sort_order=true
            SortExec: TopK(fetch=3), expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[true]
              DataSourceExec: partitions=1, partition_sizes=[1]
        ");
    }

    /// TopK with no fetch (unbounded sort) should NOT be pushed.
    #[test]
    fn unbounded_sort_not_pushed() {
        let s = schema();
        let repartition = repartitioned_scan(hash_on("a", &s));
        let full_sort =
            SortExec::new(sort_exprs(&s), repartition).with_preserve_partitioning(true);

        assert_snapshot!(optimized_display(Arc::new(full_sort)), @r"
        SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[true]
          RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1
            DataSourceExec: partitions=1, partition_sizes=[1]
        ");
    }

    /// Hash key NOT a prefix of sort key should NOT be pushed.
    #[test]
    fn non_prefix_hash_key_not_pushed() {
        let s = schema();
        // Hash by `b`, but sort by `(a, b)` - b is not a prefix
        let repartition = repartitioned_scan(hash_on("b", &s));
        let topk = SortExec::new(sort_exprs(&s), repartition)
            .with_fetch(Some(3))
            .with_preserve_partitioning(true);

        assert_snapshot!(optimized_display(Arc::new(topk)), @r"
        SortExec: TopK(fetch=3), expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[true]
          RepartitionExec: partitioning=Hash([b@1], 4), input_partitions=1
            DataSourceExec: partitions=1, partition_sizes=[1]
        ");
    }

    /// TopK above CoalesceBatchesExec above Hash(a) repartition should
    /// push through both, inserting a new TopK below the repartition.
    #[expect(deprecated)]
    #[test]
    fn topk_pushed_through_coalesce_batches() {
        let s = schema();
        let repartition = repartitioned_scan(hash_on("a", &s));
        let coalesce: Arc<dyn ExecutionPlan> =
            Arc::new(CoalesceBatchesExec::new(repartition, 8192));
        let topk = SortExec::new(sort_exprs(&s), coalesce)
            .with_fetch(Some(3))
            .with_preserve_partitioning(true);

        assert_snapshot!(optimized_display(Arc::new(topk)), @r"
        SortExec: TopK(fetch=3), expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[true], sort_prefix=[a@0 ASC]
          CoalesceBatchesExec: target_batch_size=8192
            RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1, maintains_sort_order=true
              SortExec: TopK(fetch=3), expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[true]
                DataSourceExec: partitions=1, partition_sizes=[1]
        ");
    }

    /// RoundRobin repartition should NOT be pushed.
    #[test]
    fn round_robin_not_pushed() {
        let s = schema();
        let repartition = repartitioned_scan(Partitioning::RoundRobinBatch(4));
        let topk = SortExec::new(sort_exprs(&s), repartition)
            .with_fetch(Some(3))
            .with_preserve_partitioning(true);

        assert_snapshot!(optimized_display(Arc::new(topk)), @r"
        SortExec: TopK(fetch=3), expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[true]
          RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
            DataSourceExec: partitions=1, partition_sizes=[1]
        ");
    }
}