datafusion_physical_plan/topk/mod.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! TopK: Combination of Sort / LIMIT
19
20use arrow::{
21 array::{Array, AsArray},
22 compute::{FilterBuilder, interleave_record_batch, prep_null_mask_filter},
23 row::{RowConverter, Rows, SortField},
24};
25use datafusion_expr::{ColumnarValue, Operator};
26use std::mem::size_of;
27use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc};
28
29use super::metrics::{
30 BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, RecordOutput,
31};
32use crate::spill::get_record_batch_memory_size;
33use crate::{SendableRecordBatchStream, stream::RecordBatchStreamAdapter};
34
35use arrow::array::{ArrayRef, RecordBatch};
36use arrow::datatypes::SchemaRef;
37use datafusion_common::{
38 HashMap, Result, ScalarValue, internal_datafusion_err, internal_err,
39};
40use datafusion_execution::{
41 memory_pool::{MemoryConsumer, MemoryReservation},
42 runtime_env::RuntimeEnv,
43};
44use datafusion_physical_expr::{
45 PhysicalExpr,
46 expressions::{BinaryExpr, DynamicFilterPhysicalExpr, is_not_null, is_null, lit},
47};
48use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
49use parking_lot::RwLock;
50
51/// Global TopK
52///
53/// # Background
54///
55/// "Top K" is a common query optimization used for queries such as
56/// "find the top 3 customers by revenue". The (simplified) SQL for
57/// such a query might be:
58///
59/// ```sql
60/// SELECT customer_id, revenue FROM 'sales.csv' ORDER BY revenue DESC limit 3;
61/// ```
62///
63/// The simple plan would be:
64///
65/// ```sql
66/// > explain SELECT customer_id, revenue FROM sales ORDER BY revenue DESC limit 3;
67/// +--------------+----------------------------------------+
68/// | plan_type | plan |
69/// +--------------+----------------------------------------+
70/// | logical_plan | Limit: 3 |
71/// | | Sort: revenue DESC NULLS FIRST |
72/// | | Projection: customer_id, revenue |
73/// | | TableScan: sales |
74/// +--------------+----------------------------------------+
75/// ```
76///
77/// While this plan produces the correct answer, it will fully sorts the
78/// input before discarding everything other than the top 3 elements.
79///
80/// The same answer can be produced by simply keeping track of the top
81/// K=3 elements, reducing the total amount of required buffer memory.
82///
83/// # Partial Sort Optimization
84///
85/// This implementation additionally optimizes queries where the input is already
86/// partially sorted by a common prefix of the requested ordering. Once the top K
87/// heap is full, if subsequent rows are guaranteed to be strictly greater (in sort
88/// order) on this prefix than the largest row currently stored, the operator
89/// safely terminates early.
90///
91/// ## Example
92///
93/// For input sorted by `(day DESC)`, but not by `timestamp`, a query such as:
94///
95/// ```sql
96/// SELECT day, timestamp FROM sensor ORDER BY day DESC, timestamp DESC LIMIT 10;
97/// ```
98///
99/// can terminate scanning early once sufficient rows from the latest days have been
100/// collected, skipping older data.
101///
102/// # Structure
103///
104/// This operator tracks the top K items using a `TopKHeap`.
105pub struct TopK {
106 /// schema of the output (and the input)
107 schema: SchemaRef,
108 /// Runtime metrics
109 metrics: TopKMetrics,
110 /// Reservation
111 reservation: MemoryReservation,
112 /// The target number of rows for output batches
113 batch_size: usize,
114 /// sort expressions
115 expr: LexOrdering,
116 /// row converter, for sort keys
117 row_converter: RowConverter,
118 /// scratch space for converting rows
119 scratch_rows: Rows,
120 /// stores the top k values and their sort key values, in order
121 heap: TopKHeap,
122 /// row converter, for common keys between the sort keys and the input ordering
123 common_sort_prefix_converter: Option<RowConverter>,
124 /// Common sort prefix between the input and the sort expressions to allow early exit optimization
125 common_sort_prefix: Arc<[PhysicalSortExpr]>,
126 /// Filter matching the state of the `TopK` heap used for dynamic filter pushdown
127 filter: Arc<RwLock<TopKDynamicFilters>>,
128 /// If true, indicates that all rows of subsequent batches are guaranteed
129 /// to be greater (by byte order, after row conversion) than the top K,
130 /// which means the top K won't change and the computation can be finished early.
131 pub(crate) finished: bool,
132}
133
134/// For more background, please also see the [Dynamic Filters: Passing Information Between Operators During Execution for 25x Faster Queries blog]
135///
136/// [Dynamic Filters: Passing Information Between Operators During Execution for 25x Faster Queries blog]: https://datafusion.apache.org/blog/2025/09/10/dynamic-filters
137#[derive(Debug, Clone)]
138pub struct TopKDynamicFilters {
139 /// The current *global* threshold for the dynamic filter.
140 /// This is shared across all partitions and is updated by any of them.
141 /// Stored as row bytes for efficient comparison.
142 threshold_row: Option<Vec<u8>>,
143 /// The expression used to evaluate the dynamic filter
144 /// Only updated when lock held for the duration of the update
145 expr: Arc<DynamicFilterPhysicalExpr>,
146}
147
148impl TopKDynamicFilters {
149 /// Create a new `TopKDynamicFilters` with the given expression
150 pub fn new(expr: Arc<DynamicFilterPhysicalExpr>) -> Self {
151 Self {
152 threshold_row: None,
153 expr,
154 }
155 }
156
157 pub fn expr(&self) -> Arc<DynamicFilterPhysicalExpr> {
158 Arc::clone(&self.expr)
159 }
160}
161
162// Guesstimate for memory allocation: estimated number of bytes used per row in the RowConverter
163const ESTIMATED_BYTES_PER_ROW: usize = 20;
164
165fn build_sort_fields(
166 ordering: &[PhysicalSortExpr],
167 schema: &SchemaRef,
168) -> Result<Vec<SortField>> {
169 ordering
170 .iter()
171 .map(|e| {
172 Ok(SortField::new_with_options(
173 e.expr.data_type(schema)?,
174 e.options,
175 ))
176 })
177 .collect::<Result<_>>()
178}
179
180impl TopK {
181 /// Create a new [`TopK`] that stores the top `k` values, as
182 /// defined by the sort expressions in `expr`.
183 // TODO: make a builder or some other nicer API
184 #[expect(clippy::too_many_arguments)]
185 #[expect(clippy::needless_pass_by_value)]
186 pub fn try_new(
187 partition_id: usize,
188 schema: SchemaRef,
189 common_sort_prefix: Vec<PhysicalSortExpr>,
190 expr: LexOrdering,
191 k: usize,
192 batch_size: usize,
193 runtime: Arc<RuntimeEnv>,
194 metrics: &ExecutionPlanMetricsSet,
195 filter: Arc<RwLock<TopKDynamicFilters>>,
196 ) -> Result<Self> {
197 let reservation = MemoryConsumer::new(format!("TopK[{partition_id}]"))
198 .register(&runtime.memory_pool);
199
200 let sort_fields = build_sort_fields(&expr, &schema)?;
201
202 // TODO there is potential to add special cases for single column sort fields
203 // to improve performance
204 let row_converter = RowConverter::new(sort_fields)?;
205 let scratch_rows =
206 row_converter.empty_rows(batch_size, ESTIMATED_BYTES_PER_ROW * batch_size);
207
208 let prefix_row_converter = if common_sort_prefix.is_empty() {
209 None
210 } else {
211 let input_sort_fields = build_sort_fields(&common_sort_prefix, &schema)?;
212 Some(RowConverter::new(input_sort_fields)?)
213 };
214
215 Ok(Self {
216 schema: Arc::clone(&schema),
217 metrics: TopKMetrics::new(metrics, partition_id),
218 reservation,
219 batch_size,
220 expr,
221 row_converter,
222 scratch_rows,
223 heap: TopKHeap::new(k, batch_size),
224 common_sort_prefix_converter: prefix_row_converter,
225 common_sort_prefix: Arc::from(common_sort_prefix),
226 finished: false,
227 filter,
228 })
229 }
230
231 /// Insert `batch`, remembering if any of its values are among
232 /// the top k seen so far.
233 #[expect(clippy::needless_pass_by_value)]
234 pub fn insert_batch(&mut self, batch: RecordBatch) -> Result<()> {
235 // Updates on drop
236 let baseline = self.metrics.baseline.clone();
237 let _timer = baseline.elapsed_compute().timer();
238
239 let mut sort_keys: Vec<ArrayRef> = self
240 .expr
241 .iter()
242 .map(|expr| {
243 let value = expr.expr.evaluate(&batch)?;
244 value.into_array(batch.num_rows())
245 })
246 .collect::<Result<Vec<_>>>()?;
247
248 let mut selected_rows = None;
249
250 // If a filter is provided, update it with the new rows
251 let filter = self.filter.read().expr.current()?;
252 let filtered = filter.evaluate(&batch)?;
253 let num_rows = batch.num_rows();
254 let array = filtered.into_array(num_rows)?;
255 let mut filter = array.as_boolean().clone();
256 let true_count = filter.true_count();
257 if true_count == 0 {
258 // nothing to filter, so no need to update
259 return Ok(());
260 }
261 // only update the keys / rows if the filter does not match all rows
262 if true_count < num_rows {
263 // Indices in `set_indices` should be correct if filter contains nulls
264 // So we prepare the filter here. Note this is also done in the `FilterBuilder`
265 // so there is no overhead to do this here.
266 if filter.nulls().is_some() {
267 filter = prep_null_mask_filter(&filter);
268 }
269
270 let filter_predicate = FilterBuilder::new(&filter);
271 let filter_predicate = if sort_keys.len() > 1 {
272 // Optimize filter when it has multiple sort keys
273 filter_predicate.optimize().build()
274 } else {
275 filter_predicate.build()
276 };
277 selected_rows = Some(filter);
278 sort_keys = sort_keys
279 .iter()
280 .map(|key| filter_predicate.filter(key).map_err(|x| x.into()))
281 .collect::<Result<Vec<_>>>()?;
282 }
283 // reuse existing `Rows` to avoid reallocations
284 let rows = &mut self.scratch_rows;
285 rows.clear();
286 self.row_converter.append(rows, &sort_keys)?;
287
288 let mut batch_entry = self.heap.register_batch(batch.clone());
289
290 let replacements = match selected_rows {
291 Some(filter) => {
292 self.find_new_topk_items(filter.values().set_indices(), &mut batch_entry)
293 }
294 None => self.find_new_topk_items(0..sort_keys[0].len(), &mut batch_entry),
295 };
296
297 if replacements > 0 {
298 self.metrics.row_replacements.add(replacements);
299
300 self.heap.insert_batch_entry(batch_entry);
301
302 // conserve memory
303 self.heap.maybe_compact()?;
304
305 // update memory reservation
306 self.reservation.try_resize(self.size())?;
307
308 // flag the topK as finished if we know that all
309 // subsequent batches are guaranteed to be greater (by byte order, after row conversion) than the top K,
310 // which means the top K won't change and the computation can be finished early.
311 self.attempt_early_completion(&batch)?;
312
313 // update the filter representation of our TopK heap
314 self.update_filter()?;
315 }
316
317 Ok(())
318 }
319
320 fn find_new_topk_items(
321 &mut self,
322 items: impl Iterator<Item = usize>,
323 batch_entry: &mut RecordBatchEntry,
324 ) -> usize {
325 let mut replacements = 0;
326 let rows = &mut self.scratch_rows;
327 for (index, row) in items.zip(rows.iter()) {
328 match self.heap.max() {
329 // heap has k items, and the new row is greater than the
330 // current max in the heap ==> it is not a new topk
331 Some(max_row) if row.as_ref() >= max_row.row() => {}
332 // don't yet have k items or new item is lower than the currently k low values
333 None | Some(_) => {
334 self.heap.add(batch_entry, row, index);
335 replacements += 1;
336 }
337 }
338 }
339 replacements
340 }
341
342 /// Update the filter representation of our TopK heap.
343 /// For example, given the sort expression `ORDER BY a DESC, b ASC LIMIT 3`,
344 /// and the current heap values `[(1, 5), (1, 4), (2, 3)]`,
345 /// the filter will be updated to:
346 ///
347 /// ```sql
348 /// (a > 1 OR (a = 1 AND b < 5)) AND
349 /// (a > 1 OR (a = 1 AND b < 4)) AND
350 /// (a > 2 OR (a = 2 AND b < 3))
351 /// ```
352 fn update_filter(&mut self) -> Result<()> {
353 // If the heap doesn't have k elements yet, we can't create thresholds
354 let Some(max_row) = self.heap.max() else {
355 return Ok(());
356 };
357
358 let new_threshold_row = &max_row.row;
359
360 // Fast path: check if the current value in topk is better than what is
361 // currently set in the filter with a read only lock
362 let needs_update = self
363 .filter
364 .read()
365 .threshold_row
366 .as_ref()
367 .map(|current_row| {
368 // new < current means new threshold is more selective
369 new_threshold_row < current_row
370 })
371 .unwrap_or(true); // No current threshold, so we need to set one
372
373 // exit early if the current values are better
374 if !needs_update {
375 return Ok(());
376 }
377
378 // Extract scalar values BEFORE acquiring lock to reduce critical section
379 let thresholds = match self.heap.get_threshold_values(&self.expr)? {
380 Some(t) => t,
381 None => return Ok(()),
382 };
383
384 // Build the filter expression OUTSIDE any synchronization
385 let predicate = Self::build_filter_expression(&self.expr, &thresholds)?;
386 let new_threshold = new_threshold_row.to_vec();
387
388 // update the threshold. Since there was a lock gap, we must check if it is still the best
389 // may have changed while we were building the expression without the lock
390 let mut filter = self.filter.write();
391 let old_threshold = filter.threshold_row.take();
392
393 // Update filter if we successfully updated the threshold
394 // (or if there was no previous threshold and we're the first)
395 match old_threshold {
396 Some(old_threshold) => {
397 // new threshold is still better than the old one
398 if new_threshold.as_slice() < old_threshold.as_slice() {
399 filter.threshold_row = Some(new_threshold);
400 } else {
401 // some other thread updated the threshold to a better
402 // one while we were building so there is no need to
403 // update the filter
404 filter.threshold_row = Some(old_threshold);
405 return Ok(());
406 }
407 }
408 None => {
409 // No previous threshold, so we can set the new one
410 filter.threshold_row = Some(new_threshold);
411 }
412 };
413
414 // Update the filter expression
415 if let Some(pred) = predicate
416 && !pred.eq(&lit(true))
417 {
418 filter.expr.update(pred)?;
419 }
420
421 Ok(())
422 }
423
424 /// Build the filter expression with the given thresholds.
425 /// This is now called outside of any locks to reduce critical section time.
426 fn build_filter_expression(
427 sort_exprs: &[PhysicalSortExpr],
428 thresholds: &[ScalarValue],
429 ) -> Result<Option<Arc<dyn PhysicalExpr>>> {
430 // Create filter expressions for each threshold
431 let mut filters: Vec<Arc<dyn PhysicalExpr>> =
432 Vec::with_capacity(thresholds.len());
433
434 let mut prev_sort_expr: Option<Arc<dyn PhysicalExpr>> = None;
435 for (sort_expr, value) in sort_exprs.iter().zip(thresholds.iter()) {
436 // Create the appropriate operator based on sort order
437 let op = if sort_expr.options.descending {
438 // For descending sort, we want col > threshold (exclude smaller values)
439 Operator::Gt
440 } else {
441 // For ascending sort, we want col < threshold (exclude larger values)
442 Operator::Lt
443 };
444
445 let value_null = value.is_null();
446
447 let comparison = Arc::new(BinaryExpr::new(
448 Arc::clone(&sort_expr.expr),
449 op,
450 lit(value.clone()),
451 ));
452
453 let comparison_with_null = match (sort_expr.options.nulls_first, value_null) {
454 // For nulls first, transform to (threshold.value is not null) and (threshold.expr is null or comparison)
455 (true, true) => lit(false),
456 (true, false) => Arc::new(BinaryExpr::new(
457 is_null(Arc::clone(&sort_expr.expr))?,
458 Operator::Or,
459 comparison,
460 )),
461 // For nulls last, transform to (threshold.value is null and threshold.expr is not null)
462 // or (threshold.value is not null and comparison)
463 (false, true) => is_not_null(Arc::clone(&sort_expr.expr))?,
464 (false, false) => comparison,
465 };
466
467 let mut eq_expr = Arc::new(BinaryExpr::new(
468 Arc::clone(&sort_expr.expr),
469 Operator::Eq,
470 lit(value.clone()),
471 ));
472
473 if value_null {
474 eq_expr = Arc::new(BinaryExpr::new(
475 is_null(Arc::clone(&sort_expr.expr))?,
476 Operator::Or,
477 eq_expr,
478 ));
479 }
480
481 // For a query like order by a, b, the filter for column `b` is only applied if
482 // the condition a = threshold.value (considering null equality) is met.
483 // Therefore, we add equality predicates for all preceding fields to the filter logic of the current field,
484 // and include the current field's equality predicate in `prev_sort_expr` for use with subsequent fields.
485 match prev_sort_expr.take() {
486 None => {
487 prev_sort_expr = Some(eq_expr);
488 filters.push(comparison_with_null);
489 }
490 Some(p) => {
491 filters.push(Arc::new(BinaryExpr::new(
492 Arc::clone(&p),
493 Operator::And,
494 comparison_with_null,
495 )));
496
497 prev_sort_expr =
498 Some(Arc::new(BinaryExpr::new(p, Operator::And, eq_expr)));
499 }
500 }
501 }
502
503 let dynamic_predicate = filters
504 .into_iter()
505 .reduce(|a, b| Arc::new(BinaryExpr::new(a, Operator::Or, b)));
506
507 Ok(dynamic_predicate)
508 }
509
510 /// If input ordering shares a common sort prefix with the TopK, and if the TopK's heap is full,
511 /// check if the computation can be finished early.
512 /// This is the case if the last row of the current batch is strictly greater than the max row in the heap,
513 /// comparing only on the shared prefix columns.
514 fn attempt_early_completion(&mut self, batch: &RecordBatch) -> Result<()> {
515 // Early exit if the batch is empty as there is no last row to extract from it.
516 if batch.num_rows() == 0 {
517 return Ok(());
518 }
519
520 // prefix_row_converter is only `Some` if the input ordering has a common prefix with the TopK,
521 // so early exit if it is `None`.
522 let Some(prefix_converter) = &self.common_sort_prefix_converter else {
523 return Ok(());
524 };
525
526 // Early exit if the heap is not full (`heap.max()` only returns `Some` if the heap is full).
527 let Some(max_topk_row) = self.heap.max() else {
528 return Ok(());
529 };
530
531 // Evaluate the prefix for the last row of the current batch.
532 let last_row_idx = batch.num_rows() - 1;
533 let mut batch_prefix_scratch =
534 prefix_converter.empty_rows(1, ESTIMATED_BYTES_PER_ROW); // 1 row with capacity ESTIMATED_BYTES_PER_ROW
535
536 self.compute_common_sort_prefix(batch, last_row_idx, &mut batch_prefix_scratch)?;
537
538 // Retrieve the max row from the heap.
539 let store_entry = self
540 .heap
541 .store
542 .get(max_topk_row.batch_id)
543 .ok_or(internal_datafusion_err!("Invalid batch id in topK heap"))?;
544 let max_batch = &store_entry.batch;
545 let mut heap_prefix_scratch =
546 prefix_converter.empty_rows(1, ESTIMATED_BYTES_PER_ROW); // 1 row with capacity ESTIMATED_BYTES_PER_ROW
547 self.compute_common_sort_prefix(
548 max_batch,
549 max_topk_row.index,
550 &mut heap_prefix_scratch,
551 )?;
552
553 // If the last row's prefix is strictly greater than the max prefix, mark as finished.
554 if batch_prefix_scratch.row(0).as_ref() > heap_prefix_scratch.row(0).as_ref() {
555 self.finished = true;
556 }
557
558 Ok(())
559 }
560
561 // Helper function to compute the prefix for a given batch and row index, storing the result in scratch.
562 fn compute_common_sort_prefix(
563 &self,
564 batch: &RecordBatch,
565 last_row_idx: usize,
566 scratch: &mut Rows,
567 ) -> Result<()> {
568 let last_row: Vec<ArrayRef> = self
569 .common_sort_prefix
570 .iter()
571 .map(|expr| {
572 expr.expr
573 .evaluate(&batch.slice(last_row_idx, 1))?
574 .into_array(1)
575 })
576 .collect::<Result<_>>()?;
577
578 self.common_sort_prefix_converter
579 .as_ref()
580 .unwrap()
581 .append(scratch, &last_row)?;
582 Ok(())
583 }
584
585 /// Returns the top k results broken into `batch_size` [`RecordBatch`]es, consuming the heap
586 pub fn emit(self) -> Result<SendableRecordBatchStream> {
587 let Self {
588 schema,
589 metrics,
590 reservation: _,
591 batch_size,
592 expr: _,
593 row_converter: _,
594 scratch_rows: _,
595 mut heap,
596 common_sort_prefix_converter: _,
597 common_sort_prefix: _,
598 finished: _,
599 filter,
600 } = self;
601 let _timer = metrics.baseline.elapsed_compute().timer(); // time updated on drop
602
603 // Mark the dynamic filter as complete now that TopK processing is finished.
604 filter.read().expr().mark_complete();
605
606 // break into record batches as needed
607 let mut batches = vec![];
608 if let Some(mut batch) = heap.emit()? {
609 (&batch).record_output(&metrics.baseline);
610
611 loop {
612 if batch.num_rows() <= batch_size {
613 batches.push(Ok(batch));
614 break;
615 } else {
616 batches.push(Ok(batch.slice(0, batch_size)));
617 let remaining_length = batch.num_rows() - batch_size;
618 batch = batch.slice(batch_size, remaining_length);
619 }
620 }
621 };
622 Ok(Box::pin(RecordBatchStreamAdapter::new(
623 schema,
624 futures::stream::iter(batches),
625 )))
626 }
627
628 /// return the size of memory used by this operator, in bytes
629 fn size(&self) -> usize {
630 size_of::<Self>()
631 + self.row_converter.size()
632 + self.scratch_rows.size()
633 + self.heap.size()
634 }
635}
636
637struct TopKMetrics {
638 /// metrics
639 pub baseline: BaselineMetrics,
640
641 /// count of how many rows were replaced in the heap
642 pub row_replacements: Count,
643}
644
645impl TopKMetrics {
646 fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self {
647 Self {
648 baseline: BaselineMetrics::new(metrics, partition),
649 row_replacements: MetricBuilder::new(metrics)
650 .counter("row_replacements", partition),
651 }
652 }
653}
654
655/// This structure keeps at most the *smallest* k items, using the
656/// [arrow::row] format for sort keys. While it is called "topK" for
657/// values like `1, 2, 3, 4, 5` the "top 3" really means the
658/// *smallest* 3 , `1, 2, 3`, not the *largest* 3 `3, 4, 5`.
659///
660/// Using the `Row` format handles things such as ascending vs
661/// descending and nulls first vs nulls last.
662struct TopKHeap {
663 /// The maximum number of elements to store in this heap.
664 k: usize,
665 /// The target number of rows for output batches
666 batch_size: usize,
667 /// Storage for up at most `k` items using a BinaryHeap. Reversed
668 /// so that the smallest k so far is on the top
669 inner: BinaryHeap<TopKRow>,
670 /// Storage the original row values (TopKRow only has the sort key)
671 store: RecordBatchStore,
672 /// The size of all owned data held by this heap
673 owned_bytes: usize,
674}
675
676impl TopKHeap {
677 fn new(k: usize, batch_size: usize) -> Self {
678 assert!(k > 0);
679 Self {
680 k,
681 batch_size,
682 inner: BinaryHeap::new(),
683 store: RecordBatchStore::new(),
684 owned_bytes: 0,
685 }
686 }
687
688 /// Register a [`RecordBatch`] with the heap, returning the
689 /// appropriate entry
690 pub fn register_batch(&mut self, batch: RecordBatch) -> RecordBatchEntry {
691 self.store.register(batch)
692 }
693
694 /// Insert a [`RecordBatchEntry`] created by a previous call to
695 /// [`Self::register_batch`] into storage.
696 pub fn insert_batch_entry(&mut self, entry: RecordBatchEntry) {
697 self.store.insert(entry)
698 }
699
700 /// Returns the largest value stored by the heap if there are k
701 /// items, otherwise returns None. Remember this structure is
702 /// keeping the "smallest" k values
703 fn max(&self) -> Option<&TopKRow> {
704 if self.inner.len() < self.k {
705 None
706 } else {
707 self.inner.peek()
708 }
709 }
710
711 /// Adds `row` to this heap. If inserting this new item would
712 /// increase the size past `k`, removes the previously smallest
713 /// item.
714 fn add(
715 &mut self,
716 batch_entry: &mut RecordBatchEntry,
717 row: impl AsRef<[u8]>,
718 index: usize,
719 ) {
720 let batch_id = batch_entry.id;
721 batch_entry.uses += 1;
722
723 assert!(self.inner.len() <= self.k);
724 let row = row.as_ref();
725
726 // Reuse storage for evicted item if possible
727 if self.inner.len() == self.k {
728 let mut prev_min = self.inner.peek_mut().unwrap();
729
730 // Update batch use
731 if prev_min.batch_id == batch_entry.id {
732 batch_entry.uses -= 1;
733 } else {
734 self.store.unuse(prev_min.batch_id);
735 }
736
737 // update memory accounting
738 self.owned_bytes -= prev_min.owned_size();
739
740 prev_min.replace_with(row, batch_id, index);
741
742 self.owned_bytes += prev_min.owned_size();
743 } else {
744 let new_row = TopKRow::new(row, batch_id, index);
745 self.owned_bytes += new_row.owned_size();
746 // put the new row into the heap
747 self.inner.push(new_row);
748 };
749 }
750
751 /// Returns the values stored in this heap, from values low to
752 /// high, as a single [`RecordBatch`], resetting the inner heap
753 pub fn emit(&mut self) -> Result<Option<RecordBatch>> {
754 Ok(self.emit_with_state()?.0)
755 }
756
757 /// Returns the values stored in this heap, from values low to
758 /// high, as a single [`RecordBatch`], and a sorted vec of the
759 /// current heap's contents
760 pub fn emit_with_state(&mut self) -> Result<(Option<RecordBatch>, Vec<TopKRow>)> {
761 // generate sorted rows
762 let topk_rows = std::mem::take(&mut self.inner).into_sorted_vec();
763
764 if self.store.is_empty() {
765 return Ok((None, topk_rows));
766 }
767
768 // Collect the batches into a vec and store the "batch_id -> array_pos" mapping, to then
769 // build the `indices` vec below. This is needed since the batch ids are not continuous.
770 let mut record_batches = Vec::new();
771 let mut batch_id_array_pos = HashMap::new();
772 for (array_pos, (batch_id, batch)) in self.store.batches.iter().enumerate() {
773 record_batches.push(&batch.batch);
774 batch_id_array_pos.insert(*batch_id, array_pos);
775 }
776
777 let indices: Vec<_> = topk_rows
778 .iter()
779 .map(|k| (batch_id_array_pos[&k.batch_id], k.index))
780 .collect();
781
782 // At this point `indices` contains indexes within the
783 // rows and `input_arrays` contains a reference to the
784 // relevant RecordBatch for that index. `interleave_record_batch` pulls
785 // them together into a single new batch
786 let new_batch = interleave_record_batch(&record_batches, &indices)?;
787
788 Ok((Some(new_batch), topk_rows))
789 }
790
791 /// Compact this heap, rewriting all stored batches into a single
792 /// input batch
793 pub fn maybe_compact(&mut self) -> Result<()> {
794 // we compact if the number of "unused" rows in the store is
795 // past some pre-defined threshold. Target holding up to
796 // around 20 batches, but handle cases of large k where some
797 // batches might be partially full
798 let max_unused_rows = (20 * self.batch_size) + self.k;
799 let unused_rows = self.store.unused_rows();
800
801 // don't compact if the store has one extra batch or
802 // unused rows is under the threshold
803 if self.store.len() <= 2 || unused_rows < max_unused_rows {
804 return Ok(());
805 }
806 // at first, compact the entire thing always into a new batch
807 // (maybe we can get fancier in the future about ignoring
808 // batches that have a high usage ratio already
809
810 // Note: new batch is in the same order as inner
811 let num_rows = self.inner.len();
812 let (new_batch, mut topk_rows) = self.emit_with_state()?;
813 let Some(new_batch) = new_batch else {
814 return Ok(());
815 };
816
817 // clear all old entries in store (this invalidates all
818 // store_ids in `inner`)
819 self.store.clear();
820
821 let mut batch_entry = self.register_batch(new_batch);
822 batch_entry.uses = num_rows;
823
824 // rewrite all existing entries to use the new batch, and
825 // remove old entries. The sortedness and their relative
826 // position do not change
827 for (i, topk_row) in topk_rows.iter_mut().enumerate() {
828 topk_row.batch_id = batch_entry.id;
829 topk_row.index = i;
830 }
831 self.insert_batch_entry(batch_entry);
832 // restore the heap
833 self.inner = BinaryHeap::from(topk_rows);
834
835 Ok(())
836 }
837
838 /// return the size of memory used by this heap, in bytes
839 fn size(&self) -> usize {
840 size_of::<Self>()
841 + (self.inner.capacity() * size_of::<TopKRow>())
842 + self.store.size()
843 + self.owned_bytes
844 }
845
846 fn get_threshold_values(
847 &self,
848 sort_exprs: &[PhysicalSortExpr],
849 ) -> Result<Option<Vec<ScalarValue>>> {
850 // If the heap doesn't have k elements yet, we can't create thresholds
851 let max_row = match self.max() {
852 Some(row) => row,
853 None => return Ok(None),
854 };
855
856 // Get the batch that contains the max row
857 let batch_entry = match self.store.get(max_row.batch_id) {
858 Some(entry) => entry,
859 None => return internal_err!("Invalid batch ID in TopKRow"),
860 };
861
862 // Extract threshold values for each sort expression
863 let mut scalar_values = Vec::with_capacity(sort_exprs.len());
864 for sort_expr in sort_exprs {
865 // Extract the value for this column from the max row
866 let expr = Arc::clone(&sort_expr.expr);
867 let value = expr.evaluate(&batch_entry.batch.slice(max_row.index, 1))?;
868
869 // Convert to scalar value - should be a single value since we're evaluating on a single row batch
870 let scalar = match value {
871 ColumnarValue::Scalar(scalar) => scalar,
872 ColumnarValue::Array(array) if array.len() == 1 => {
873 // Extract the first (and only) value from the array
874 ScalarValue::try_from_array(&array, 0)?
875 }
876 array => {
877 return internal_err!("Expected a scalar value, got {:?}", array);
878 }
879 };
880
881 scalar_values.push(scalar);
882 }
883
884 Ok(Some(scalar_values))
885 }
886}
887
888/// Represents one of the top K rows held in this heap. Orders
889/// according to memcmp of row (e.g. the arrow Row format, but could
890/// also be primitive values)
891///
892/// Reuses allocations to minimize runtime overhead of creating new Vecs
893#[derive(Debug, PartialEq)]
894struct TopKRow {
895 /// the value of the sort key for this row. This contains the
896 /// bytes that could be stored in `OwnedRow` but uses `Vec<u8>` to
897 /// reuse allocations.
898 row: Vec<u8>,
899 /// the RecordBatch this row came from: an id into a [`RecordBatchStore`]
900 batch_id: u32,
901 /// the index in this record batch the row came from
902 index: usize,
903}
904
905impl TopKRow {
906 /// Create a new TopKRow with new allocation
907 fn new(row: impl AsRef<[u8]>, batch_id: u32, index: usize) -> Self {
908 Self {
909 row: row.as_ref().to_vec(),
910 batch_id,
911 index,
912 }
913 }
914
915 // Replace the existing row capacity with new values
916 fn replace_with(&mut self, new_row: impl AsRef<[u8]>, batch_id: u32, index: usize) {
917 self.row.clear();
918 self.row.extend_from_slice(new_row.as_ref());
919
920 self.batch_id = batch_id;
921 self.index = index;
922 }
923
924 /// Returns the number of bytes owned by this row in the heap (not
925 /// including itself)
926 fn owned_size(&self) -> usize {
927 self.row.capacity()
928 }
929
930 /// Returns a slice to the owned row value
931 fn row(&self) -> &[u8] {
932 self.row.as_slice()
933 }
934}
935
936impl Eq for TopKRow {}
937
938impl PartialOrd for TopKRow {
939 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
940 // TODO PartialOrd is not consistent with PartialEq; PartialOrd contract is violated
941 Some(self.cmp(other))
942 }
943}
944
945impl Ord for TopKRow {
946 fn cmp(&self, other: &Self) -> Ordering {
947 self.row.cmp(&other.row)
948 }
949}
950
951#[derive(Debug)]
952struct RecordBatchEntry {
953 id: u32,
954 batch: RecordBatch,
955 // for this batch, how many times has it been used
956 uses: usize,
957}
958
/// This structure tracks [`RecordBatch`]es by an id so that:
///
/// 1. The batches can be tracked via an id that can be copied cheaply
/// 2. The total memory held by all batches is tracked
///
/// Only batches with at least one use are retained. A batch is removed
/// from the store once its use count falls back to zero.
#[derive(Debug)]
struct RecordBatchStore {
    /// id generator: the id the next call to `register` will assign.
    /// It is not reset by `clear`.
    next_id: u32,
    /// storage, keyed by batch id
    batches: HashMap<u32, RecordBatchEntry>,
    /// total size of all record batches tracked by this store
    batches_size: usize,
}
972
973impl RecordBatchStore {
974 fn new() -> Self {
975 Self {
976 next_id: 0,
977 batches: HashMap::new(),
978 batches_size: 0,
979 }
980 }
981
982 /// Register this batch with the store and assign an ID. No
983 /// attempt is made to compare this batch to other batches
984 pub fn register(&mut self, batch: RecordBatch) -> RecordBatchEntry {
985 let id = self.next_id;
986 self.next_id += 1;
987 RecordBatchEntry { id, batch, uses: 0 }
988 }
989
990 /// Insert a record batch entry into this store, tracking its
991 /// memory use, if it has any uses
992 pub fn insert(&mut self, entry: RecordBatchEntry) {
993 // uses of 0 means that none of the rows in the batch were stored in the topk
994 if entry.uses > 0 {
995 self.batches_size += get_record_batch_memory_size(&entry.batch);
996 self.batches.insert(entry.id, entry);
997 }
998 }
999
1000 /// Clear all values in this store, invalidating all previous batch ids
1001 fn clear(&mut self) {
1002 self.batches.clear();
1003 self.batches_size = 0;
1004 }
1005
1006 fn get(&self, id: u32) -> Option<&RecordBatchEntry> {
1007 self.batches.get(&id)
1008 }
1009
1010 /// returns the total number of batches stored in this store
1011 fn len(&self) -> usize {
1012 self.batches.len()
1013 }
1014
1015 /// Returns the total number of rows in batches minus the number
1016 /// which are in use
1017 fn unused_rows(&self) -> usize {
1018 self.batches
1019 .values()
1020 .map(|batch_entry| batch_entry.batch.num_rows() - batch_entry.uses)
1021 .sum()
1022 }
1023
1024 /// returns true if the store has nothing stored
1025 fn is_empty(&self) -> bool {
1026 self.batches.is_empty()
1027 }
1028
1029 /// remove a use from the specified batch id. If the use count
1030 /// reaches zero the batch entry is removed from the store
1031 ///
1032 /// panics if there were no remaining uses of id
1033 pub fn unuse(&mut self, id: u32) {
1034 let remove = if let Some(batch_entry) = self.batches.get_mut(&id) {
1035 batch_entry.uses = batch_entry.uses.checked_sub(1).expect("underflow");
1036 batch_entry.uses == 0
1037 } else {
1038 panic!("No entry for id {id}");
1039 };
1040
1041 if remove {
1042 let old_entry = self.batches.remove(&id).unwrap();
1043 self.batches_size = self
1044 .batches_size
1045 .checked_sub(get_record_batch_memory_size(&old_entry.batch))
1046 .unwrap();
1047 }
1048 }
1049
1050 /// returns the size of memory used by this store, including all
1051 /// referenced `RecordBatch`es, in bytes
1052 pub fn size(&self) -> usize {
1053 size_of::<Self>()
1054 + self.batches.capacity() * (size_of::<u32>() + size_of::<RecordBatchEntry>())
1055 + self.batches_size
1056 }
1057}
1058
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, Int32Array, RecordBatch};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow_schema::SortOptions;
    use datafusion_common::assert_batches_eq;
    use datafusion_physical_expr::expressions::col;
    use futures::TryStreamExt;

    /// This test ensures the size calculation is correct for RecordBatches with multiple columns.
    ///
    /// Inserting an entry with one use must add the data size of both
    /// columns to `batches_size`. Releasing that single use must remove the
    /// entry and bring the total back to zero.
    #[test]
    fn test_record_batch_store_size() {
        // given
        let schema = Arc::new(Schema::new(vec![
            Field::new("ints", DataType::Int32, true),
            Field::new("float64", DataType::Float64, false),
        ]));
        let mut record_batch_store = RecordBatchStore::new();
        let int_array =
            Int32Array::from(vec![Some(1), Some(2), Some(3), Some(4), Some(5)]); // 5 * 4 = 20
        let float64_array = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]); // 5 * 8 = 40

        // The entry is built directly (not via `register`) so that its id
        // and use count are fixed.
        let record_batch_entry = RecordBatchEntry {
            id: 0,
            batch: RecordBatch::try_new(
                schema,
                vec![Arc::new(int_array), Arc::new(float64_array)],
            )
            .unwrap(),
            uses: 1,
        };

        // when insert record batch entry, then 20 + 40 bytes are tracked
        record_batch_store.insert(record_batch_entry);
        assert_eq!(record_batch_store.batches_size, 60);

        // when unuse record batch entry (its only use), then it is evicted
        record_batch_store.unuse(0);
        assert_eq!(record_batch_store.batches_size, 0);
    }

    /// This test validates that the `try_finish` method marks the TopK operator as finished
    /// when the prefix (on column "a") of the last row in the current batch is strictly greater
    /// than the max top-k row.
    /// The full sort expression is defined on both columns ("a", "b"), but the input ordering is only on "a".
    #[tokio::test]
    async fn test_try_finish_marks_finished_with_prefix() -> Result<()> {
        // Create a schema with two columns.
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Float64, false),
        ]));

        // Create sort expressions.
        // Full sort: first by "a", then by "b".
        let sort_expr_a = PhysicalSortExpr {
            expr: col("a", schema.as_ref())?,
            options: SortOptions::default(),
        };
        let sort_expr_b = PhysicalSortExpr {
            expr: col("b", schema.as_ref())?,
            options: SortOptions::default(),
        };

        // Input ordering uses only column "a" (a prefix of the full sort).
        let prefix = vec![sort_expr_a.clone()];
        let full_expr = LexOrdering::from([sort_expr_a, sort_expr_b]);

        // Create a dummy runtime environment and metrics.
        let runtime = Arc::new(RuntimeEnv::default());
        let metrics = ExecutionPlanMetricsSet::new();

        // Create a TopK instance with k = 3 and batch_size = 2, using an
        // always-true dynamic filter that this test does not inspect.
        let mut topk = TopK::try_new(
            0,
            Arc::clone(&schema),
            prefix,
            full_expr,
            3,
            2,
            runtime,
            &metrics,
            Arc::new(RwLock::new(TopKDynamicFilters::new(Arc::new(
                DynamicFilterPhysicalExpr::new(vec![], lit(true)),
            )))),
        )?;

        // Create the first batch with two columns:
        // Column "a": [1, 1, 2], Column "b": [20.0, 15.0, 30.0].
        let array_a1: ArrayRef =
            Arc::new(Int32Array::from(vec![Some(1), Some(1), Some(2)]));
        let array_b1: ArrayRef = Arc::new(Float64Array::from(vec![20.0, 15.0, 30.0]));
        let batch1 = RecordBatch::try_new(Arc::clone(&schema), vec![array_a1, array_b1])?;

        // Insert the first batch.
        // All three rows fit in the heap (k = 3), so the max top-k row is (2, 30.0).
        // The heap is not yet "finished": the prefix of the last row of the batch
        // is not strictly greater than the prefix of the max top-k row (both are `2`).
        topk.insert_batch(batch1)?;
        assert!(
            !topk.finished,
            "Expected 'finished' to be false after the first batch."
        );

        // Create the second batch with two columns:
        // Column "a": [2, 3], Column "b": [10.0, 20.0].
        let array_a2: ArrayRef = Arc::new(Int32Array::from(vec![Some(2), Some(3)]));
        let array_b2: ArrayRef = Arc::new(Float64Array::from(vec![10.0, 20.0]));
        let batch2 = RecordBatch::try_new(Arc::clone(&schema), vec![array_a2, array_b2])?;

        // Insert the second batch.
        // (2, 10.0) displaces (2, 30.0), so the max top-k row becomes (2, 10.0).
        // The last row in this batch has a prefix value of `3`,
        // which is strictly greater than the max top-k row (with value `2`),
        // so try_finish should mark the TopK as finished.
        topk.insert_batch(batch2)?;
        assert!(
            topk.finished,
            "Expected 'finished' to be true after the second batch."
        );

        // Verify the TopK correctly emits the top k rows from both batches
        // (the value 10.0 for b is from the second batch).
        let results: Vec<_> = topk.emit()?.try_collect().await?;
        assert_batches_eq!(
            &[
                "+---+------+",
                "| a | b    |",
                "+---+------+",
                "| 1 | 15.0 |",
                "| 1 | 20.0 |",
                "| 2 | 10.0 |",
                "+---+------+",
            ],
            &results
        );

        Ok(())
    }

    /// This test verifies that the dynamic filter is marked as complete after TopK processing finishes.
    #[tokio::test]
    async fn test_topk_marks_filter_complete() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

        let sort_expr = PhysicalSortExpr {
            expr: col("a", schema.as_ref())?,
            options: SortOptions::default(),
        };

        // Input ordering and full ordering are the same single column.
        let full_expr = LexOrdering::from([sort_expr.clone()]);
        let prefix = vec![sort_expr];

        // Create a dummy runtime environment and metrics
        let runtime = Arc::new(RuntimeEnv::default());
        let metrics = ExecutionPlanMetricsSet::new();

        // Create a dynamic filter that we'll check for completion; keep a
        // second handle since the original is moved into the TopK.
        let dynamic_filter = Arc::new(DynamicFilterPhysicalExpr::new(vec![], lit(true)));
        let dynamic_filter_clone = Arc::clone(&dynamic_filter);

        // Create a TopK instance with k = 2 and batch_size = 10
        let mut topk = TopK::try_new(
            0,
            Arc::clone(&schema),
            prefix,
            full_expr,
            2,
            10,
            runtime,
            &metrics,
            Arc::new(RwLock::new(TopKDynamicFilters::new(dynamic_filter))),
        )?;

        let array: ArrayRef = Arc::new(Int32Array::from(vec![Some(3), Some(1), Some(2)]));
        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![array])?;
        topk.insert_batch(batch)?;

        // Call emit to finish TopK processing
        let _results: Vec<_> = topk.emit()?.try_collect().await?;

        // After emit is called, the dynamic filter should be marked as complete
        // wait_complete() should return immediately
        dynamic_filter_clone.wait_complete().await;

        Ok(())
    }
}
1245}