datafusion_physical_plan/sorts/partitioned_topk.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`PartitionedTopKExec`]: Top-K per partition operator
19//!
20//! For queries like:
21//! ```sql
//! SELECT * FROM (
//!     SELECT *, ROW_NUMBER() OVER (PARTITION BY pk ORDER BY val) AS rn
//!     FROM t
//! ) WHERE rn <= N
24//! ```
25//!
26//! Instead of sorting the entire dataset, this operator maintains a
27//! [`TopK`] heap per partition (reusing the existing TopK implementation)
28//! and emits only the top-K rows per partition in sorted order
29//! `(partition_keys, order_keys)`.
30
31use std::fmt::{self, Formatter};
32use std::sync::Arc;
33
34use arrow::array::{RecordBatch, UInt32Array};
35use arrow::compute::{BatchCoalescer, take_record_batch};
36use arrow::datatypes::SchemaRef;
37use arrow::row::{OwnedRow, RowConverter};
38use datafusion_common::{HashMap, Result};
39use datafusion_execution::TaskContext;
40use datafusion_physical_expr::PhysicalExpr;
41use datafusion_physical_expr::expressions::{DynamicFilterPhysicalExpr, lit};
42use datafusion_physical_expr_common::sort_expr::LexOrdering;
43use futures::StreamExt;
44use futures::TryStreamExt;
45use parking_lot::RwLock;
46
47use crate::execution_plan::{Boundedness, EmissionType};
48use crate::metrics::ExecutionPlanMetricsSet;
49use crate::topk::{TopK, TopKDynamicFilters, build_sort_fields};
50use crate::{
51 DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
52 PlanProperties, SendableRecordBatchStream, stream::RecordBatchStreamAdapter,
53};
54
55/// Per-partition Top-K operator for window function queries.
56///
57/// # Background
58///
59/// "Top K per partition" is a common analytics pattern used for queries such as
60/// "find the top 3 products by revenue for each store". The (simplified) SQL
61/// for such a query might be:
62///
63/// ```sql
64/// SELECT * FROM (
65/// SELECT *, ROW_NUMBER() OVER (PARTITION BY store ORDER BY revenue DESC) as rn
66/// FROM sales
67/// ) WHERE rn <= 3;
68/// ```
69///
70/// The unoptimized physical plan would be:
71///
72/// ```text
73/// FilterExec: rn <= 3
74/// BoundedWindowAggExec: ROW_NUMBER() PARTITION BY [store] ORDER BY [revenue DESC]
75/// SortExec: expr=[store ASC, revenue DESC]
76/// DataSourceExec
77/// ```
78///
79/// This plan sorts the **entire** dataset (O(N log N)), computes `ROW_NUMBER`
80/// for **all** rows, and then filters to keep only the top K per partition.
81/// With 10M rows, 1K partitions, and K=3, it sorts all 10M rows but only
82/// keeps 3K.
83///
84/// # Optimization
85///
86/// `PartitionedTopKExec` replaces the `SortExec` and the `FilterExec` is
87/// removed. The optimized plan becomes:
88///
89/// ```text
90/// BoundedWindowAggExec: ROW_NUMBER() PARTITION BY [store] ORDER BY [revenue DESC]
91/// PartitionedTopKExec: fetch=3, partition=[store], order=[revenue DESC]
92/// DataSourceExec
93/// ```
94///
95/// Instead of sorting the entire dataset, this operator reads unsorted input,
96/// maintains a [`TopK`] heap per distinct partition key, and emits only the
97/// top-K rows per partition in sorted order `(partition_keys, order_keys)`.
98///
99/// Cost: O(N log K) time instead of O(N log N), and O(K × P × row_size)
100/// memory where K = fetch, P = number of distinct partitions.
101/// ## Why maintaining partition key order in output
102/// Window functions do not require partition keys to be globally sorted, and
103/// enforcing such ordering in the output can introduce unnecessary overhead.
104/// However, the physical optimizer framework currently cannot express an
105/// ordering that is only grouped by some keys while ordered by others. For
106/// example:
107///
108///
109/// # Example
110///
111/// For the query above with `fetch=3` and input:
112///
113/// ```text
114/// store | revenue
115/// ------|--------
116/// A | 100
117/// B | 50
118/// A | 200
119/// B | 150
120/// A | 300
121/// A | 400
122/// ```
123///
124/// The operator maintains two heaps:
125/// - **store=A**: keeps top-3 by revenue DESC → {400, 300, 200}, evicts 100
126/// - **store=B**: keeps top-3 by revenue DESC → {150, 50} (only 2 rows)
127///
128/// Output (sorted by store ASC, revenue DESC):
129///
130/// ```text
131/// store | revenue
132/// ------|--------
133/// A | 400
134/// A | 300
135/// A | 200
136/// B | 150
137/// B | 50
138/// ```
139///
140/// This is then passed to `BoundedWindowAggExec` which assigns
141/// `ROW_NUMBER` 1, 2, 3 to each partition — all of which satisfy `rn <= 3`.
142///
143/// # Limitations
144///
145/// - Only activated when the window function is `ROW_NUMBER` with a
146/// `PARTITION BY` clause. Global top-K (no `PARTITION BY`) is already
147/// handled efficiently by `SortExec` with `fetch`.
148/// - For very high cardinality partition keys (millions of distinct values),
149/// both memory usage and runtime overhead can become significant. In such
150/// cases, the sort-based plan is more robust. Therefore, this optimization
151/// is currently controlled by a configuration flag.
152#[derive(Debug, Clone)]
153pub struct PartitionedTopKExec {
154 /// Input execution plan (reads unsorted data)
155 input: Arc<dyn ExecutionPlan>,
156 /// Full sort expressions: `[partition_keys..., order_keys...]`.
157 ///
158 /// For `PARTITION BY store ORDER BY revenue DESC` with sort
159 /// `[store ASC, revenue DESC]`, the first `partition_prefix_len`
160 /// expressions are the partition keys (`[store ASC]`) and the
161 /// remaining are the order-by keys (`[revenue DESC]`).
162 expr: LexOrdering,
163 /// Number of leading expressions in `expr` that define the partition
164 /// key. For example, `PARTITION BY a, b` → `partition_prefix_len = 2`.
165 partition_prefix_len: usize,
166 /// Maximum number of rows to keep per partition (the K in "top-K").
167 /// Derived from the filter predicate: `rn <= 3` → `fetch = 3`,
168 /// `rn < 3` → `fetch = 2`.
169 fetch: usize,
170 /// Execution metrics
171 metrics_set: ExecutionPlanMetricsSet,
172 /// Cached plan properties (output ordering, partitioning, etc.)
173 cache: Arc<PlanProperties>,
174}
175
176impl PartitionedTopKExec {
177 /// Create a new `PartitionedTopKExec`.
178 ///
179 /// # Arguments
180 ///
181 /// * `input` - The child execution plan providing unsorted input rows.
182 /// * `expr` - Full sort ordering `[partition_keys..., order_keys...]`.
183 /// For `PARTITION BY pk ORDER BY val ASC`, this would be `[pk ASC, val ASC]`.
184 /// * `partition_prefix_len` - Number of leading expressions in `expr`
185 /// that form the partition key. Must be >= 1.
186 /// * `fetch` - Maximum rows to retain per partition (the K in "top-K").
187 ///
188 /// # Example
189 ///
190 /// ```text
191 /// // For: ROW_NUMBER() OVER (PARTITION BY store ORDER BY revenue DESC) ... WHERE rn <= 5
192 /// PartitionedTopKExec::try_new(
193 /// data_source,
194 /// LexOrdering([store ASC, revenue DESC]),
195 /// 1, // partition_prefix_len: 1 partition column (store)
196 /// 5, // fetch: keep top 5 per partition
197 /// )
198 /// ```
199 pub fn try_new(
200 input: Arc<dyn ExecutionPlan>,
201 expr: LexOrdering,
202 partition_prefix_len: usize,
203 fetch: usize,
204 ) -> Result<Self> {
205 let cache = Self::compute_properties(&input, expr.clone())?;
206 Ok(Self {
207 input,
208 expr,
209 partition_prefix_len,
210 fetch,
211 metrics_set: ExecutionPlanMetricsSet::new(),
212 cache: Arc::new(cache),
213 })
214 }
215
216 /// Returns the child execution plan.
217 pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
218 &self.input
219 }
220
221 /// Returns the full sort ordering `[partition_keys..., order_keys...]`.
222 pub fn expr(&self) -> &LexOrdering {
223 &self.expr
224 }
225
226 /// Returns the number of leading expressions in [`Self::expr`] that
227 /// define the partition key.
228 pub fn partition_prefix_len(&self) -> usize {
229 self.partition_prefix_len
230 }
231
232 /// Returns the maximum number of rows retained per partition.
233 pub fn fetch(&self) -> usize {
234 self.fetch
235 }
236
237 /// Compute [`PlanProperties`] for this operator.
238 ///
239 /// The output is sorted by `sort_exprs` (partition keys then order keys),
240 /// uses the same partitioning as the input, emits all output at once
241 /// (`EmissionType::Final`), and is bounded.
242 fn compute_properties(
243 input: &Arc<dyn ExecutionPlan>,
244 sort_exprs: LexOrdering,
245 ) -> Result<PlanProperties> {
246 let mut eq_properties = input.equivalence_properties().clone();
247 eq_properties.reorder(sort_exprs)?;
248
249 Ok(PlanProperties::new(
250 eq_properties,
251 input.output_partitioning().clone(),
252 EmissionType::Final,
253 Boundedness::Bounded,
254 ))
255 }
256}
257
258impl DisplayAs for PartitionedTopKExec {
259 fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> fmt::Result {
260 match t {
261 DisplayFormatType::Default | DisplayFormatType::Verbose => {
262 let partition_exprs: Vec<String> = self.expr[..self.partition_prefix_len]
263 .iter()
264 .map(|e| format!("{}", e.expr))
265 .collect();
266 let order_exprs: Vec<String> = self.expr[self.partition_prefix_len..]
267 .iter()
268 .map(|e| format!("{e}"))
269 .collect();
270 write!(
271 f,
272 "PartitionedTopKExec: fetch={}, partition=[{}], order=[{}]",
273 self.fetch,
274 partition_exprs.join(", "),
275 order_exprs.join(", "),
276 )
277 }
278 DisplayFormatType::TreeRender => {
279 let partition_exprs: Vec<String> = self.expr[..self.partition_prefix_len]
280 .iter()
281 .map(|e| format!("{}", e.expr))
282 .collect();
283 let order_exprs: Vec<String> = self.expr[self.partition_prefix_len..]
284 .iter()
285 .map(|e| format!("{e}"))
286 .collect();
287 writeln!(f, "fetch={}", self.fetch)?;
288 writeln!(f, "partition=[{}]", partition_exprs.join(", "))?;
289 writeln!(f, "order=[{}]", order_exprs.join(", "))
290 }
291 }
292 }
293}
294
295impl ExecutionPlan for PartitionedTopKExec {
296 fn name(&self) -> &'static str {
297 "PartitionedTopKExec"
298 }
299
300 fn properties(&self) -> &Arc<PlanProperties> {
301 &self.cache
302 }
303
304 fn required_input_distribution(&self) -> Vec<Distribution> {
305 let partition_exprs: Vec<Arc<dyn PhysicalExpr>> = self.expr
306 [..self.partition_prefix_len]
307 .iter()
308 .map(|e| Arc::clone(&e.expr))
309 .collect();
310 vec![Distribution::HashPartitioned(partition_exprs)]
311 }
312
313 fn maintains_input_order(&self) -> Vec<bool> {
314 vec![false]
315 }
316
317 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
318 vec![&self.input]
319 }
320
321 fn with_new_children(
322 self: Arc<Self>,
323 children: Vec<Arc<dyn ExecutionPlan>>,
324 ) -> Result<Arc<dyn ExecutionPlan>> {
325 assert_eq!(children.len(), 1);
326 Ok(Arc::new(PartitionedTopKExec::try_new(
327 Arc::clone(&children[0]),
328 self.expr.clone(),
329 self.partition_prefix_len,
330 self.fetch,
331 )?))
332 }
333
334 fn execute(
335 &self,
336 partition: usize,
337 context: Arc<TaskContext>,
338 ) -> Result<SendableRecordBatchStream> {
339 let input = self.input.execute(partition, Arc::clone(&context))?;
340 let schema = input.schema();
341
342 let partition_sort_fields =
343 build_sort_fields(&self.expr[..self.partition_prefix_len], &schema)?;
344
345 let partition_converter = RowConverter::new(partition_sort_fields)?;
346
347 let partition_exprs: Vec<Arc<dyn PhysicalExpr>> = self.expr
348 [..self.partition_prefix_len]
349 .iter()
350 .map(|e| Arc::clone(&e.expr))
351 .collect();
352 let order_expr: LexOrdering =
353 LexOrdering::new(self.expr[self.partition_prefix_len..].iter().cloned())
354 .expect("PartitionedTopKExec requires at least one order-by expression");
355 let fetch = self.fetch;
356 let batch_size = context.session_config().batch_size();
357 let runtime = Arc::clone(&context.runtime_env());
358 let metrics_set = self.metrics_set.clone();
359
360 let stream = futures::stream::once(async move {
361 do_partitioned_topk(
362 input,
363 schema,
364 partition_converter,
365 partition_exprs,
366 order_expr,
367 fetch,
368 batch_size,
369 runtime,
370 metrics_set,
371 )
372 .await
373 })
374 .try_flatten();
375
376 Ok(Box::pin(RecordBatchStreamAdapter::new(
377 self.input.schema(),
378 stream,
379 )))
380 }
381}
382
383/// Create a no-op [`TopKDynamicFilters`] for a per-partition [`TopK`].
384///
385/// In normal `SortExec` top-K mode, dynamic filters push predicates down to
386/// the data source (e.g., telling Parquet to skip rows worse than the current
387/// K-th best). For per-partition heaps the data is already in memory and split
388/// by partition key, so there is no data source to push filters to. We pass
389/// `lit(true)` (accept everything) so the filter never rejects any row.
390fn create_noop_dynamic_filter() -> Arc<RwLock<TopKDynamicFilters>> {
391 Arc::new(RwLock::new(TopKDynamicFilters::new(Arc::new(
392 DynamicFilterPhysicalExpr::new(vec![], lit(true)),
393 ))))
394}
395
396/// Read all input, split batches by partition key, feed each sub-batch
397/// to a per-partition [`TopK`], then emit results in partition-key order.
398///
399/// # Phases
400///
401/// 1. **Accumulation** — For each input batch:
402/// - Evaluate partition key expressions to get partition column arrays
403/// - Convert partition columns to binary [`arrow::row::Row`] format
404/// - Group row indices by partition key
405/// - Extract sub-batches via [`take_record_batch`] and insert into
406/// the partition's [`TopK`] heap
407///
408/// 2. **Emission** — After all input is consumed:
409/// - Sort partition keys so output is ordered by partition key
410/// - For each partition in sorted order, call [`TopK::emit`] to get
411/// rows sorted by order-by key
412/// - Return all batches as a single stream
413///
414/// # Cost
415///
416/// - Time: O(N log K) where N = total rows, K = fetch
417/// - Memory: O(K × P × row_size) where P = number of distinct partitions
418#[expect(clippy::too_many_arguments)]
419async fn do_partitioned_topk(
420 mut input: SendableRecordBatchStream,
421 schema: SchemaRef,
422 partition_converter: RowConverter,
423 partition_exprs: Vec<Arc<dyn PhysicalExpr>>,
424 order_expr: LexOrdering,
425 fetch: usize,
426 batch_size: usize,
427 runtime: Arc<datafusion_execution::runtime_env::RuntimeEnv>,
428 metrics_set: ExecutionPlanMetricsSet,
429) -> Result<SendableRecordBatchStream> {
430 let mut partitions: HashMap<OwnedRow, TopK> = HashMap::new();
431 let mut partition_counter: usize = 0;
432
433 // Macro-like helper: create a new TopK for a partition
434 macro_rules! new_topk {
435 () => {{
436 let id = partition_counter;
437 partition_counter += 1;
438 TopK::try_new(
439 id,
440 Arc::clone(&schema),
441 vec![],
442 order_expr.clone(),
443 fetch,
444 batch_size,
445 Arc::clone(&runtime),
446 &metrics_set,
447 create_noop_dynamic_filter(),
448 )
449 }};
450 }
451
452 // ---------- Accumulation phase ----------
453 while let Some(batch) = input.next().await {
454 let batch = batch?;
455 let num_rows = batch.num_rows();
456 if num_rows == 0 {
457 continue;
458 }
459
460 // Evaluate partition key columns
461 let pk_arrays: Vec<_> = partition_exprs
462 .iter()
463 .map(|e| e.evaluate(&batch).and_then(|v| v.into_array(num_rows)))
464 .collect::<Result<Vec<_>>>()?;
465
466 let pk_rows = partition_converter.convert_columns(&pk_arrays)?;
467
468 // Group row indices by partition key
469 let mut groups: HashMap<OwnedRow, Vec<u32>> = HashMap::new();
470 for row_idx in 0..num_rows {
471 let pk = pk_rows.row(row_idx).owned();
472 groups.entry(pk).or_default().push(row_idx as u32);
473 }
474
475 // For each partition group, create a sub-batch and feed to TopK
476 for (pk, indices) in groups {
477 if !partitions.contains_key(&pk) {
478 partitions.insert(pk.clone(), new_topk!()?);
479 }
480 let topk = partitions.get_mut(&pk).unwrap();
481 let indices_array = UInt32Array::from(indices);
482 let sub_batch = take_record_batch(&batch, &indices_array)?;
483 topk.insert_batch(sub_batch)?;
484 }
485 }
486 // Release the input pipeline now that accumulation is complete.
487 drop(input);
488
489 // ---------- Emit phase ----------
490 // Sort partition keys so output is ordered by (partition_keys, order_keys).
491 let mut sorted_pks: Vec<OwnedRow> = partitions.keys().cloned().collect();
492 sorted_pks.sort();
493
494 let mut coalescer = BatchCoalescer::new(Arc::clone(&schema), batch_size);
495
496 for pk in sorted_pks {
497 if let Some(topk) = partitions.remove(&pk) {
498 // TopK::emit() returns a stream of sorted batches
499 let mut stream = topk.emit()?;
500 while let Some(batch) = stream.next().await {
501 coalescer.push_batch(batch?)?;
502 }
503 }
504 }
505 coalescer.finish_buffered_batch()?;
506 let mut output_batches: Vec<RecordBatch> = Vec::new();
507 while let Some(batch) = coalescer.next_completed_batch() {
508 output_batches.push(batch);
509 }
510
511 Ok(Box::pin(RecordBatchStreamAdapter::new(
512 schema,
513 futures::stream::iter(output_batches.into_iter().map(Ok)),
514 )))
515}