Skip to main content

datafusion_physical_plan/sorts/
partitioned_topk.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! [`PartitionedTopKExec`]: Top-K per partition operator
//!
//! For queries like:
//! ```sql
//! SELECT * FROM (
//!     SELECT *, ROW_NUMBER() OVER (PARTITION BY pk ORDER BY val) as rn
//!     FROM t
//! ) WHERE rn <= N
//! ```
//!
//! Instead of sorting the entire dataset, this operator maintains a
//! [`TopK`] heap per partition (reusing the existing TopK implementation)
//! and emits only the top-K rows per partition in sorted order
//! `(partition_keys, order_keys)`.
30
31use std::fmt::{self, Formatter};
32use std::sync::Arc;
33
34use arrow::array::{RecordBatch, UInt32Array};
35use arrow::compute::{BatchCoalescer, take_record_batch};
36use arrow::datatypes::SchemaRef;
37use arrow::row::{OwnedRow, RowConverter};
38use datafusion_common::{HashMap, Result};
39use datafusion_execution::TaskContext;
40use datafusion_physical_expr::PhysicalExpr;
41use datafusion_physical_expr::expressions::{DynamicFilterPhysicalExpr, lit};
42use datafusion_physical_expr_common::sort_expr::LexOrdering;
43use futures::StreamExt;
44use futures::TryStreamExt;
45use parking_lot::RwLock;
46
47use crate::execution_plan::{Boundedness, EmissionType};
48use crate::metrics::ExecutionPlanMetricsSet;
49use crate::topk::{TopK, TopKDynamicFilters, build_sort_fields};
50use crate::{
51    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
52    PlanProperties, SendableRecordBatchStream, stream::RecordBatchStreamAdapter,
53};
54
55/// Per-partition Top-K operator for window function queries.
56///
57/// # Background
58///
59/// "Top K per partition" is a common analytics pattern used for queries such as
60/// "find the top 3 products by revenue for each store". The (simplified) SQL
61/// for such a query might be:
62///
63/// ```sql
64/// SELECT * FROM (
65///     SELECT *, ROW_NUMBER() OVER (PARTITION BY store ORDER BY revenue DESC) as rn
66///     FROM sales
67/// ) WHERE rn <= 3;
68/// ```
69///
70/// The unoptimized physical plan would be:
71///
72/// ```text
73/// FilterExec: rn <= 3
74///   BoundedWindowAggExec: ROW_NUMBER() PARTITION BY [store] ORDER BY [revenue DESC]
75///     SortExec: expr=[store ASC, revenue DESC]
76///       DataSourceExec
77/// ```
78///
79/// This plan sorts the **entire** dataset (O(N log N)), computes `ROW_NUMBER`
80/// for **all** rows, and then filters to keep only the top K per partition.
81/// With 10M rows, 1K partitions, and K=3, it sorts all 10M rows but only
82/// keeps 3K.
83///
84/// # Optimization
85///
86/// `PartitionedTopKExec` replaces the `SortExec` and the `FilterExec` is
87/// removed. The optimized plan becomes:
88///
89/// ```text
90/// BoundedWindowAggExec: ROW_NUMBER() PARTITION BY [store] ORDER BY [revenue DESC]
91///   PartitionedTopKExec: fetch=3, partition=[store], order=[revenue DESC]
92///     DataSourceExec
93/// ```
94///
95/// Instead of sorting the entire dataset, this operator reads unsorted input,
96/// maintains a [`TopK`] heap per distinct partition key, and emits only the
97/// top-K rows per partition in sorted order `(partition_keys, order_keys)`.
98///
99/// Cost: O(N log K) time instead of O(N log N), and O(K × P × row_size)
100/// memory where K = fetch, P = number of distinct partitions.
101/// ## Why maintaining partition key order in output
102/// Window functions do not require partition keys to be globally sorted, and
103/// enforcing such ordering in the output can introduce unnecessary overhead.
104/// However, the physical optimizer framework currently cannot express an
105/// ordering that is only grouped by some keys while ordered by others. For
106/// example:
107///
108///
109/// # Example
110///
111/// For the query above with `fetch=3` and input:
112///
113/// ```text
114/// store | revenue
115/// ------|--------
116///   A   |  100
117///   B   |   50
118///   A   |  200
119///   B   |  150
120///   A   |  300
121///   A   |  400
122/// ```
123///
124/// The operator maintains two heaps:
125/// - **store=A**: keeps top-3 by revenue DESC → {400, 300, 200}, evicts 100
126/// - **store=B**: keeps top-3 by revenue DESC → {150, 50} (only 2 rows)
127///
128/// Output (sorted by store ASC, revenue DESC):
129///
130/// ```text
131/// store | revenue
132/// ------|--------
133///   A   |  400
134///   A   |  300
135///   A   |  200
136///   B   |  150
137///   B   |   50
138/// ```
139///
140/// This is then passed to `BoundedWindowAggExec` which assigns
141/// `ROW_NUMBER` 1, 2, 3 to each partition — all of which satisfy `rn <= 3`.
142///
143/// # Limitations
144///
145/// - Only activated when the window function is `ROW_NUMBER` with a
146///   `PARTITION BY` clause. Global top-K (no `PARTITION BY`) is already
147///   handled efficiently by `SortExec` with `fetch`.
148/// - For very high cardinality partition keys (millions of distinct values),
149///   both memory usage and runtime overhead can become significant. In such
150///   cases, the sort-based plan is more robust. Therefore, this optimization
151///   is currently controlled by a configuration flag.
152#[derive(Debug, Clone)]
153pub struct PartitionedTopKExec {
154    /// Input execution plan (reads unsorted data)
155    input: Arc<dyn ExecutionPlan>,
156    /// Full sort expressions: `[partition_keys..., order_keys...]`.
157    ///
158    /// For `PARTITION BY store ORDER BY revenue DESC` with sort
159    /// `[store ASC, revenue DESC]`, the first `partition_prefix_len`
160    /// expressions are the partition keys (`[store ASC]`) and the
161    /// remaining are the order-by keys (`[revenue DESC]`).
162    expr: LexOrdering,
163    /// Number of leading expressions in `expr` that define the partition
164    /// key. For example, `PARTITION BY a, b` → `partition_prefix_len = 2`.
165    partition_prefix_len: usize,
166    /// Maximum number of rows to keep per partition (the K in "top-K").
167    /// Derived from the filter predicate: `rn <= 3` → `fetch = 3`,
168    /// `rn < 3` → `fetch = 2`.
169    fetch: usize,
170    /// Execution metrics
171    metrics_set: ExecutionPlanMetricsSet,
172    /// Cached plan properties (output ordering, partitioning, etc.)
173    cache: Arc<PlanProperties>,
174}
175
176impl PartitionedTopKExec {
177    /// Create a new `PartitionedTopKExec`.
178    ///
179    /// # Arguments
180    ///
181    /// * `input` - The child execution plan providing unsorted input rows.
182    /// * `expr` - Full sort ordering `[partition_keys..., order_keys...]`.
183    ///   For `PARTITION BY pk ORDER BY val ASC`, this would be `[pk ASC, val ASC]`.
184    /// * `partition_prefix_len` - Number of leading expressions in `expr`
185    ///   that form the partition key. Must be >= 1.
186    /// * `fetch` - Maximum rows to retain per partition (the K in "top-K").
187    ///
188    /// # Example
189    ///
190    /// ```text
191    /// // For: ROW_NUMBER() OVER (PARTITION BY store ORDER BY revenue DESC) ... WHERE rn <= 5
192    /// PartitionedTopKExec::try_new(
193    ///     data_source,
194    ///     LexOrdering([store ASC, revenue DESC]),
195    ///     1,    // partition_prefix_len: 1 partition column (store)
196    ///     5,    // fetch: keep top 5 per partition
197    /// )
198    /// ```
199    pub fn try_new(
200        input: Arc<dyn ExecutionPlan>,
201        expr: LexOrdering,
202        partition_prefix_len: usize,
203        fetch: usize,
204    ) -> Result<Self> {
205        let cache = Self::compute_properties(&input, expr.clone())?;
206        Ok(Self {
207            input,
208            expr,
209            partition_prefix_len,
210            fetch,
211            metrics_set: ExecutionPlanMetricsSet::new(),
212            cache: Arc::new(cache),
213        })
214    }
215
216    /// Returns the child execution plan.
217    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
218        &self.input
219    }
220
221    /// Returns the full sort ordering `[partition_keys..., order_keys...]`.
222    pub fn expr(&self) -> &LexOrdering {
223        &self.expr
224    }
225
226    /// Returns the number of leading expressions in [`Self::expr`] that
227    /// define the partition key.
228    pub fn partition_prefix_len(&self) -> usize {
229        self.partition_prefix_len
230    }
231
232    /// Returns the maximum number of rows retained per partition.
233    pub fn fetch(&self) -> usize {
234        self.fetch
235    }
236
237    /// Compute [`PlanProperties`] for this operator.
238    ///
239    /// The output is sorted by `sort_exprs` (partition keys then order keys),
240    /// uses the same partitioning as the input, emits all output at once
241    /// (`EmissionType::Final`), and is bounded.
242    fn compute_properties(
243        input: &Arc<dyn ExecutionPlan>,
244        sort_exprs: LexOrdering,
245    ) -> Result<PlanProperties> {
246        let mut eq_properties = input.equivalence_properties().clone();
247        eq_properties.reorder(sort_exprs)?;
248
249        Ok(PlanProperties::new(
250            eq_properties,
251            input.output_partitioning().clone(),
252            EmissionType::Final,
253            Boundedness::Bounded,
254        ))
255    }
256}
257
258impl DisplayAs for PartitionedTopKExec {
259    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> fmt::Result {
260        match t {
261            DisplayFormatType::Default | DisplayFormatType::Verbose => {
262                let partition_exprs: Vec<String> = self.expr[..self.partition_prefix_len]
263                    .iter()
264                    .map(|e| format!("{}", e.expr))
265                    .collect();
266                let order_exprs: Vec<String> = self.expr[self.partition_prefix_len..]
267                    .iter()
268                    .map(|e| format!("{e}"))
269                    .collect();
270                write!(
271                    f,
272                    "PartitionedTopKExec: fetch={}, partition=[{}], order=[{}]",
273                    self.fetch,
274                    partition_exprs.join(", "),
275                    order_exprs.join(", "),
276                )
277            }
278            DisplayFormatType::TreeRender => {
279                let partition_exprs: Vec<String> = self.expr[..self.partition_prefix_len]
280                    .iter()
281                    .map(|e| format!("{}", e.expr))
282                    .collect();
283                let order_exprs: Vec<String> = self.expr[self.partition_prefix_len..]
284                    .iter()
285                    .map(|e| format!("{e}"))
286                    .collect();
287                writeln!(f, "fetch={}", self.fetch)?;
288                writeln!(f, "partition=[{}]", partition_exprs.join(", "))?;
289                writeln!(f, "order=[{}]", order_exprs.join(", "))
290            }
291        }
292    }
293}
294
295impl ExecutionPlan for PartitionedTopKExec {
296    fn name(&self) -> &'static str {
297        "PartitionedTopKExec"
298    }
299
300    fn properties(&self) -> &Arc<PlanProperties> {
301        &self.cache
302    }
303
304    fn required_input_distribution(&self) -> Vec<Distribution> {
305        let partition_exprs: Vec<Arc<dyn PhysicalExpr>> = self.expr
306            [..self.partition_prefix_len]
307            .iter()
308            .map(|e| Arc::clone(&e.expr))
309            .collect();
310        vec![Distribution::HashPartitioned(partition_exprs)]
311    }
312
313    fn maintains_input_order(&self) -> Vec<bool> {
314        vec![false]
315    }
316
317    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
318        vec![&self.input]
319    }
320
321    fn with_new_children(
322        self: Arc<Self>,
323        children: Vec<Arc<dyn ExecutionPlan>>,
324    ) -> Result<Arc<dyn ExecutionPlan>> {
325        assert_eq!(children.len(), 1);
326        Ok(Arc::new(PartitionedTopKExec::try_new(
327            Arc::clone(&children[0]),
328            self.expr.clone(),
329            self.partition_prefix_len,
330            self.fetch,
331        )?))
332    }
333
334    fn execute(
335        &self,
336        partition: usize,
337        context: Arc<TaskContext>,
338    ) -> Result<SendableRecordBatchStream> {
339        let input = self.input.execute(partition, Arc::clone(&context))?;
340        let schema = input.schema();
341
342        let partition_sort_fields =
343            build_sort_fields(&self.expr[..self.partition_prefix_len], &schema)?;
344
345        let partition_converter = RowConverter::new(partition_sort_fields)?;
346
347        let partition_exprs: Vec<Arc<dyn PhysicalExpr>> = self.expr
348            [..self.partition_prefix_len]
349            .iter()
350            .map(|e| Arc::clone(&e.expr))
351            .collect();
352        let order_expr: LexOrdering =
353            LexOrdering::new(self.expr[self.partition_prefix_len..].iter().cloned())
354                .expect("PartitionedTopKExec requires at least one order-by expression");
355        let fetch = self.fetch;
356        let batch_size = context.session_config().batch_size();
357        let runtime = Arc::clone(&context.runtime_env());
358        let metrics_set = self.metrics_set.clone();
359
360        let stream = futures::stream::once(async move {
361            do_partitioned_topk(
362                input,
363                schema,
364                partition_converter,
365                partition_exprs,
366                order_expr,
367                fetch,
368                batch_size,
369                runtime,
370                metrics_set,
371            )
372            .await
373        })
374        .try_flatten();
375
376        Ok(Box::pin(RecordBatchStreamAdapter::new(
377            self.input.schema(),
378            stream,
379        )))
380    }
381}
382
383/// Create a no-op [`TopKDynamicFilters`] for a per-partition [`TopK`].
384///
385/// In normal `SortExec` top-K mode, dynamic filters push predicates down to
386/// the data source (e.g., telling Parquet to skip rows worse than the current
387/// K-th best). For per-partition heaps the data is already in memory and split
388/// by partition key, so there is no data source to push filters to. We pass
389/// `lit(true)` (accept everything) so the filter never rejects any row.
390fn create_noop_dynamic_filter() -> Arc<RwLock<TopKDynamicFilters>> {
391    Arc::new(RwLock::new(TopKDynamicFilters::new(Arc::new(
392        DynamicFilterPhysicalExpr::new(vec![], lit(true)),
393    ))))
394}
395
396/// Read all input, split batches by partition key, feed each sub-batch
397/// to a per-partition [`TopK`], then emit results in partition-key order.
398///
399/// # Phases
400///
401/// 1. **Accumulation** — For each input batch:
402///    - Evaluate partition key expressions to get partition column arrays
403///    - Convert partition columns to binary [`arrow::row::Row`] format
404///    - Group row indices by partition key
405///    - Extract sub-batches via [`take_record_batch`] and insert into
406///      the partition's [`TopK`] heap
407///
408/// 2. **Emission** — After all input is consumed:
409///    - Sort partition keys so output is ordered by partition key
410///    - For each partition in sorted order, call [`TopK::emit`] to get
411///      rows sorted by order-by key
412///    - Return all batches as a single stream
413///
414/// # Cost
415///
416/// - Time: O(N log K) where N = total rows, K = fetch
417/// - Memory: O(K × P × row_size) where P = number of distinct partitions
418#[expect(clippy::too_many_arguments)]
419async fn do_partitioned_topk(
420    mut input: SendableRecordBatchStream,
421    schema: SchemaRef,
422    partition_converter: RowConverter,
423    partition_exprs: Vec<Arc<dyn PhysicalExpr>>,
424    order_expr: LexOrdering,
425    fetch: usize,
426    batch_size: usize,
427    runtime: Arc<datafusion_execution::runtime_env::RuntimeEnv>,
428    metrics_set: ExecutionPlanMetricsSet,
429) -> Result<SendableRecordBatchStream> {
430    let mut partitions: HashMap<OwnedRow, TopK> = HashMap::new();
431    let mut partition_counter: usize = 0;
432
433    // Macro-like helper: create a new TopK for a partition
434    macro_rules! new_topk {
435        () => {{
436            let id = partition_counter;
437            partition_counter += 1;
438            TopK::try_new(
439                id,
440                Arc::clone(&schema),
441                vec![],
442                order_expr.clone(),
443                fetch,
444                batch_size,
445                Arc::clone(&runtime),
446                &metrics_set,
447                create_noop_dynamic_filter(),
448            )
449        }};
450    }
451
452    // ---------- Accumulation phase ----------
453    while let Some(batch) = input.next().await {
454        let batch = batch?;
455        let num_rows = batch.num_rows();
456        if num_rows == 0 {
457            continue;
458        }
459
460        // Evaluate partition key columns
461        let pk_arrays: Vec<_> = partition_exprs
462            .iter()
463            .map(|e| e.evaluate(&batch).and_then(|v| v.into_array(num_rows)))
464            .collect::<Result<Vec<_>>>()?;
465
466        let pk_rows = partition_converter.convert_columns(&pk_arrays)?;
467
468        // Group row indices by partition key
469        let mut groups: HashMap<OwnedRow, Vec<u32>> = HashMap::new();
470        for row_idx in 0..num_rows {
471            let pk = pk_rows.row(row_idx).owned();
472            groups.entry(pk).or_default().push(row_idx as u32);
473        }
474
475        // For each partition group, create a sub-batch and feed to TopK
476        for (pk, indices) in groups {
477            if !partitions.contains_key(&pk) {
478                partitions.insert(pk.clone(), new_topk!()?);
479            }
480            let topk = partitions.get_mut(&pk).unwrap();
481            let indices_array = UInt32Array::from(indices);
482            let sub_batch = take_record_batch(&batch, &indices_array)?;
483            topk.insert_batch(sub_batch)?;
484        }
485    }
486    // Release the input pipeline now that accumulation is complete.
487    drop(input);
488
489    // ---------- Emit phase ----------
490    // Sort partition keys so output is ordered by (partition_keys, order_keys).
491    let mut sorted_pks: Vec<OwnedRow> = partitions.keys().cloned().collect();
492    sorted_pks.sort();
493
494    let mut coalescer = BatchCoalescer::new(Arc::clone(&schema), batch_size);
495
496    for pk in sorted_pks {
497        if let Some(topk) = partitions.remove(&pk) {
498            // TopK::emit() returns a stream of sorted batches
499            let mut stream = topk.emit()?;
500            while let Some(batch) = stream.next().await {
501                coalescer.push_batch(batch?)?;
502            }
503        }
504    }
505    coalescer.finish_buffered_batch()?;
506    let mut output_batches: Vec<RecordBatch> = Vec::new();
507    while let Some(batch) = coalescer.next_completed_batch() {
508        output_batches.push(batch);
509    }
510
511    Ok(Box::pin(RecordBatchStreamAdapter::new(
512        schema,
513        futures::stream::iter(output_batches.into_iter().map(Ok)),
514    )))
515}