datafusion_physical_plan/topk/
mod.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! TopK: Combination of Sort / LIMIT

use arrow::{
    array::{Array, AsArray},
    compute::{FilterBuilder, interleave_record_batch, prep_null_mask_filter},
    row::{RowConverter, Rows, SortField},
};
use datafusion_expr::{ColumnarValue, Operator};
use std::mem::size_of;
use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc};

use super::metrics::{
    BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, RecordOutput,
};
use crate::spill::get_record_batch_memory_size;
use crate::{SendableRecordBatchStream, stream::RecordBatchStreamAdapter};

use arrow::array::{ArrayRef, RecordBatch};
use arrow::datatypes::SchemaRef;
use datafusion_common::{
    HashMap, Result, ScalarValue, internal_datafusion_err, internal_err,
};
use datafusion_execution::{
    memory_pool::{MemoryConsumer, MemoryReservation},
    runtime_env::RuntimeEnv,
};
use datafusion_physical_expr::{
    PhysicalExpr,
    expressions::{BinaryExpr, DynamicFilterPhysicalExpr, is_not_null, is_null, lit},
};
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
use parking_lot::RwLock;

/// Global TopK
///
/// # Background
///
/// "Top K" is a common query optimization used for queries such as
/// "find the top 3 customers by revenue". The (simplified) SQL for
/// such a query might be:
///
/// ```sql
/// SELECT customer_id, revenue FROM 'sales.csv' ORDER BY revenue DESC limit 3;
/// ```
///
/// The simple plan would be:
///
/// ```sql
/// > explain SELECT customer_id, revenue FROM sales ORDER BY revenue DESC limit 3;
/// +--------------+----------------------------------------+
/// | plan_type    | plan                                   |
/// +--------------+----------------------------------------+
/// | logical_plan | Limit: 3                               |
/// |              |   Sort: revenue DESC NULLS FIRST       |
/// |              |     Projection: customer_id, revenue   |
/// |              |       TableScan: sales                 |
/// +--------------+----------------------------------------+
/// ```
///
/// While this plan produces the correct answer, it fully sorts the
/// input before discarding everything other than the top 3 elements.
///
/// The same answer can be produced by simply keeping track of the top
/// K=3 elements, reducing the total amount of required buffer memory.
///
/// # Partial Sort Optimization
///
/// This implementation additionally optimizes queries where the input is already
/// partially sorted by a common prefix of the requested ordering. Once the top K
/// heap is full, if subsequent rows are guaranteed to be strictly greater (in sort
/// order) on this prefix than the largest row currently stored, the operator
/// safely terminates early.
///
/// ## Example
///
/// For input sorted by `(day DESC)`, but not by `timestamp`, a query such as:
///
/// ```sql
/// SELECT day, timestamp FROM sensor ORDER BY day DESC, timestamp DESC LIMIT 10;
/// ```
///
/// can terminate scanning early once sufficient rows from the latest days have been
/// collected, skipping older data.
///
/// # Structure
///
/// This operator tracks the top K items using a `TopKHeap`.
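///
/// # Usage sketch
///
/// A minimal sketch of the operator lifecycle (illustrative only: it assumes a
/// `schema`, sort `ordering`, `runtime`, metrics set, dynamic `filter`, and an
/// `input` stream are already in scope; see the tests in this module for a
/// complete construction):
///
/// ```rust,ignore
/// let mut topk = TopK::try_new(
///     partition_id, schema, common_sort_prefix, ordering,
///     k, batch_size, runtime, &metrics, filter,
/// )?;
/// while let Some(batch) = input.next().await {
///     topk.insert_batch(batch?)?;
///     if topk.finished {
///         break; // early exit via the partial sort optimization
///     }
/// }
/// let stream = topk.emit()?; // top k rows, as a stream of RecordBatches
/// ```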
pub struct TopK {
    /// schema of the output (and the input)
    schema: SchemaRef,
    /// Runtime metrics
    metrics: TopKMetrics,
    /// Reservation
    reservation: MemoryReservation,
    /// The target number of rows for output batches
    batch_size: usize,
    /// sort expressions
    expr: LexOrdering,
    /// row converter, for sort keys
    row_converter: RowConverter,
    /// scratch space for converting rows
    scratch_rows: Rows,
    /// stores the top k values and their sort key values, in order
    heap: TopKHeap,
    /// row converter for the common prefix shared between the sort keys and the input ordering
    common_sort_prefix_converter: Option<RowConverter>,
    /// Common sort prefix between the input and the sort expressions to allow early exit optimization
    common_sort_prefix: Arc<[PhysicalSortExpr]>,
    /// Filter matching the state of the `TopK` heap used for dynamic filter pushdown
    filter: Arc<RwLock<TopKDynamicFilters>>,
    /// If true, indicates that all rows of subsequent batches are guaranteed
    /// to be greater (by byte order, after row conversion) than the top K,
    /// which means the top K won't change and the computation can be finished early.
    pub(crate) finished: bool,
}

#[derive(Debug, Clone)]
pub struct TopKDynamicFilters {
    /// The current *global* threshold for the dynamic filter.
    /// This is shared across all partitions and is updated by any of them.
    /// Stored as row bytes for efficient comparison.
    threshold_row: Option<Vec<u8>>,
    /// The expression used to evaluate the dynamic filter.
    /// Only updated while the write lock is held for the duration of the update.
    expr: Arc<DynamicFilterPhysicalExpr>,
}

impl TopKDynamicFilters {
    /// Create a new `TopKDynamicFilters` with the given expression
    pub fn new(expr: Arc<DynamicFilterPhysicalExpr>) -> Self {
        Self {
            threshold_row: None,
            expr,
        }
    }

    pub fn expr(&self) -> Arc<DynamicFilterPhysicalExpr> {
        Arc::clone(&self.expr)
    }
}

// Guesstimate for memory allocation: estimated number of bytes used per row in the RowConverter
const ESTIMATED_BYTES_PER_ROW: usize = 20;

fn build_sort_fields(
    ordering: &[PhysicalSortExpr],
    schema: &SchemaRef,
) -> Result<Vec<SortField>> {
    ordering
        .iter()
        .map(|e| {
            Ok(SortField::new_with_options(
                e.expr.data_type(schema)?,
                e.options,
            ))
        })
        .collect::<Result<_>>()
}

impl TopK {
    /// Create a new [`TopK`] that stores the top `k` values, as
    /// defined by the sort expressions in `expr`.
    // TODO: make a builder or some other nicer API
    #[expect(clippy::too_many_arguments)]
    #[expect(clippy::needless_pass_by_value)]
    pub fn try_new(
        partition_id: usize,
        schema: SchemaRef,
        common_sort_prefix: Vec<PhysicalSortExpr>,
        expr: LexOrdering,
        k: usize,
        batch_size: usize,
        runtime: Arc<RuntimeEnv>,
        metrics: &ExecutionPlanMetricsSet,
        filter: Arc<RwLock<TopKDynamicFilters>>,
    ) -> Result<Self> {
        let reservation = MemoryConsumer::new(format!("TopK[{partition_id}]"))
            .register(&runtime.memory_pool);

        let sort_fields = build_sort_fields(&expr, &schema)?;

        // TODO there is potential to add special cases for single column sort fields
        // to improve performance
        let row_converter = RowConverter::new(sort_fields)?;
        let scratch_rows =
            row_converter.empty_rows(batch_size, ESTIMATED_BYTES_PER_ROW * batch_size);

        let prefix_row_converter = if common_sort_prefix.is_empty() {
            None
        } else {
            let input_sort_fields = build_sort_fields(&common_sort_prefix, &schema)?;
            Some(RowConverter::new(input_sort_fields)?)
        };

        Ok(Self {
            schema: Arc::clone(&schema),
            metrics: TopKMetrics::new(metrics, partition_id),
            reservation,
            batch_size,
            expr,
            row_converter,
            scratch_rows,
            heap: TopKHeap::new(k, batch_size),
            common_sort_prefix_converter: prefix_row_converter,
            common_sort_prefix: Arc::from(common_sort_prefix),
            finished: false,
            filter,
        })
    }

    /// Insert `batch`, remembering if any of its values are among
    /// the top k seen so far.
    #[expect(clippy::needless_pass_by_value)]
    pub fn insert_batch(&mut self, batch: RecordBatch) -> Result<()> {
        // Updates on drop
        let baseline = self.metrics.baseline.clone();
        let _timer = baseline.elapsed_compute().timer();

        let mut sort_keys: Vec<ArrayRef> = self
            .expr
            .iter()
            .map(|expr| {
                let value = expr.expr.evaluate(&batch)?;
                value.into_array(batch.num_rows())
            })
            .collect::<Result<Vec<_>>>()?;

        let mut selected_rows = None;

        // Apply the current dynamic filter, pruning rows that cannot enter the top k
        let filter = self.filter.read().expr.current()?;
        let filtered = filter.evaluate(&batch)?;
        let num_rows = batch.num_rows();
        let array = filtered.into_array(num_rows)?;
        let mut filter = array.as_boolean().clone();
        let true_count = filter.true_count();
        if true_count == 0 {
            // no rows pass the filter, so nothing can enter the top k
            return Ok(());
        }
        // only filter the keys / rows if the filter does not match all rows
        if true_count < num_rows {
            // For the indices from `set_indices` to be correct when the filter
            // contains nulls, the filter must be prepared (null-masked) first.
            // Note the `FilterBuilder` does this as well, so there is no extra
            // overhead in doing it here.
            if filter.nulls().is_some() {
                filter = prep_null_mask_filter(&filter);
            }

            let filter_predicate = FilterBuilder::new(&filter);
            let filter_predicate = if sort_keys.len() > 1 {
                // Optimize the filter predicate, since it is applied to multiple sort keys
                filter_predicate.optimize().build()
            } else {
                filter_predicate.build()
            };
            selected_rows = Some(filter);
            sort_keys = sort_keys
                .iter()
                .map(|key| filter_predicate.filter(key).map_err(|x| x.into()))
                .collect::<Result<Vec<_>>>()?;
        }
        // reuse existing `Rows` to avoid reallocations
        let rows = &mut self.scratch_rows;
        rows.clear();
        self.row_converter.append(rows, &sort_keys)?;

        let mut batch_entry = self.heap.register_batch(batch.clone());

        let replacements = match selected_rows {
            Some(filter) => {
                self.find_new_topk_items(filter.values().set_indices(), &mut batch_entry)
            }
            None => self.find_new_topk_items(0..sort_keys[0].len(), &mut batch_entry),
        };

        if replacements > 0 {
            self.metrics.row_replacements.add(replacements);

            self.heap.insert_batch_entry(batch_entry);

            // conserve memory
            self.heap.maybe_compact()?;

            // update memory reservation
            self.reservation.try_resize(self.size())?;

            // flag the topK as finished if we know that all
            // subsequent batches are guaranteed to be greater (by byte order, after row conversion) than the top K,
            // which means the top K won't change and the computation can be finished early.
            self.attempt_early_completion(&batch)?;

            // update the filter representation of our TopK heap
            self.update_filter()?;
        }

        Ok(())
    }

    fn find_new_topk_items(
        &mut self,
        items: impl Iterator<Item = usize>,
        batch_entry: &mut RecordBatchEntry,
    ) -> usize {
        let mut replacements = 0;
        let rows = &mut self.scratch_rows;
        for (index, row) in items.zip(rows.iter()) {
            match self.heap.max() {
                // heap has k items, and the new row is greater than or equal to
                // the current max in the heap ==> it is not a new topk
                Some(max_row) if row.as_ref() >= max_row.row() => {}
                // don't yet have k items, or the new item is smaller than the
                // current max of the k smallest values seen so far
                None | Some(_) => {
                    self.heap.add(batch_entry, row, index);
                    replacements += 1;
                }
            }
        }
        replacements
    }

    /// Update the filter representation of our TopK heap.
    /// For example, given the sort expression `ORDER BY a DESC, b ASC LIMIT 3`,
    /// and the current heap values `[(1, 5), (1, 4), (2, 3)]`, the worst
    /// (largest, in sort order) row is `(1, 5)`, so the filter is updated to:
    ///
    /// ```sql
    /// a > 1 OR (a = 1 AND b < 5)
    /// ```
    fn update_filter(&mut self) -> Result<()> {
        // If the heap doesn't have k elements yet, we can't create thresholds
        let Some(max_row) = self.heap.max() else {
            return Ok(());
        };

        let new_threshold_row = &max_row.row;

        // Fast path: check, under a read-only lock, whether the current top k value
        // is better than what is currently set in the filter
        let needs_update = self
            .filter
            .read()
            .threshold_row
            .as_ref()
            .map(|current_row| {
                // new < current means the new threshold is more selective
                new_threshold_row < current_row
            })
            .unwrap_or(true); // No current threshold, so we need to set one

        // exit early if the current values are better
        if !needs_update {
            return Ok(());
        }

        // Extract scalar values BEFORE acquiring the lock to reduce the critical section
        let thresholds = match self.heap.get_threshold_values(&self.expr)? {
            Some(t) => t,
            None => return Ok(()),
        };

        // Build the filter expression OUTSIDE any synchronization
        let predicate = Self::build_filter_expression(&self.expr, &thresholds)?;
        let new_threshold = new_threshold_row.to_vec();

        // Update the threshold. Since there was a gap without the lock held, the
        // stored threshold may have changed while we were building the expression,
        // so we must re-check that ours is still the best.
        let mut filter = self.filter.write();
        let old_threshold = filter.threshold_row.take();

        // Install the new threshold if it is still the best
        // (or if there was no previous threshold and we're the first)
        match old_threshold {
            Some(old_threshold) => {
                // new threshold is still better than the old one
                if new_threshold.as_slice() < old_threshold.as_slice() {
                    filter.threshold_row = Some(new_threshold);
                } else {
                    // some other thread updated the threshold to a better
                    // one while we were building so there is no need to
                    // update the filter
                    filter.threshold_row = Some(old_threshold);
                    return Ok(());
                }
            }
            None => {
                // No previous threshold, so we can set the new one
                filter.threshold_row = Some(new_threshold);
            }
        };

        // Update the filter expression
        if let Some(pred) = predicate
            && !pred.eq(&lit(true))
        {
            filter.expr.update(pred)?;
        }

        Ok(())
    }

    /// Build the filter expression with the given thresholds.
    /// This is called outside of any locks to reduce critical section time.
    fn build_filter_expression(
        sort_exprs: &[PhysicalSortExpr],
        thresholds: &[ScalarValue],
    ) -> Result<Option<Arc<dyn PhysicalExpr>>> {
        // Create filter expressions for each threshold
        let mut filters: Vec<Arc<dyn PhysicalExpr>> =
            Vec::with_capacity(thresholds.len());

        let mut prev_sort_expr: Option<Arc<dyn PhysicalExpr>> = None;
        for (sort_expr, value) in sort_exprs.iter().zip(thresholds.iter()) {
            // Create the appropriate operator based on sort order
            let op = if sort_expr.options.descending {
                // For descending sort, we want col > threshold (exclude smaller values)
                Operator::Gt
            } else {
                // For ascending sort, we want col < threshold (exclude larger values)
                Operator::Lt
            };

            let value_null = value.is_null();

            let comparison = Arc::new(BinaryExpr::new(
                Arc::clone(&sort_expr.expr),
                op,
                lit(value.clone()),
            ));

            let comparison_with_null = match (sort_expr.options.nulls_first, value_null) {
                // For nulls first, transform to (threshold.value is not null) and (threshold.expr is null or comparison)
                (true, true) => lit(false),
                (true, false) => Arc::new(BinaryExpr::new(
                    is_null(Arc::clone(&sort_expr.expr))?,
                    Operator::Or,
                    comparison,
                )),
                // For nulls last, transform to (threshold.value is null and threshold.expr is not null)
                // or (threshold.value is not null and comparison)
                (false, true) => is_not_null(Arc::clone(&sort_expr.expr))?,
                (false, false) => comparison,
            };
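            // For example (illustrative): with `ORDER BY a ASC NULLS LAST` and a
            // non-null threshold of 5, `comparison_with_null` is `a < 5`; if the
            // threshold itself were NULL, every non-null value sorts before it,
            // so it becomes `a IS NOT NULL`.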

            let mut eq_expr = Arc::new(BinaryExpr::new(
                Arc::clone(&sort_expr.expr),
                Operator::Eq,
                lit(value.clone()),
            ));

            if value_null {
                eq_expr = Arc::new(BinaryExpr::new(
                    is_null(Arc::clone(&sort_expr.expr))?,
                    Operator::Or,
                    eq_expr,
                ));
            }

            // For a query like `ORDER BY a, b`, the filter for column `b` is only applied if
            // the condition a = threshold.value (considering null equality) is met.
            // Therefore, we add equality predicates for all preceding fields to the filter logic of the current field,
            // and include the current field's equality predicate in `prev_sort_expr` for use with subsequent fields.
            match prev_sort_expr.take() {
                None => {
                    prev_sort_expr = Some(eq_expr);
                    filters.push(comparison_with_null);
                }
                Some(p) => {
                    filters.push(Arc::new(BinaryExpr::new(
                        Arc::clone(&p),
                        Operator::And,
                        comparison_with_null,
                    )));

                    prev_sort_expr =
                        Some(Arc::new(BinaryExpr::new(p, Operator::And, eq_expr)));
                }
            }
        }

        let dynamic_predicate = filters
            .into_iter()
            .reduce(|a, b| Arc::new(BinaryExpr::new(a, Operator::Or, b)));

        Ok(dynamic_predicate)
    }

    /// If input ordering shares a common sort prefix with the TopK, and if the TopK's heap is full,
    /// check if the computation can be finished early.
    /// This is the case if the last row of the current batch is strictly greater than the max row in the heap,
    /// comparing only on the shared prefix columns.
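    ///
    /// For example (illustrative): with `ORDER BY day DESC, ts DESC LIMIT 3`, input
    /// sorted on `day DESC` only, and a full heap whose max row has `day = 5`, a
    /// batch whose last row has `day = 4` proves that every later row sorts strictly
    /// after the heap's max on the prefix, so no future row can enter the top k.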
    fn attempt_early_completion(&mut self, batch: &RecordBatch) -> Result<()> {
        // Early exit if the batch is empty as there is no last row to extract from it.
        if batch.num_rows() == 0 {
            return Ok(());
        }

        // prefix_row_converter is only `Some` if the input ordering has a common prefix with the TopK,
        // so early exit if it is `None`.
        let Some(prefix_converter) = &self.common_sort_prefix_converter else {
            return Ok(());
        };

        // Early exit if the heap is not full (`heap.max()` only returns `Some` if the heap is full).
        let Some(max_topk_row) = self.heap.max() else {
            return Ok(());
        };

        // Evaluate the prefix for the last row of the current batch.
        let last_row_idx = batch.num_rows() - 1;
        let mut batch_prefix_scratch =
            prefix_converter.empty_rows(1, ESTIMATED_BYTES_PER_ROW); // 1 row with capacity ESTIMATED_BYTES_PER_ROW

        self.compute_common_sort_prefix(batch, last_row_idx, &mut batch_prefix_scratch)?;

        // Retrieve the max row from the heap.
        let store_entry = self
            .heap
            .store
            .get(max_topk_row.batch_id)
            .ok_or(internal_datafusion_err!("Invalid batch id in topK heap"))?;
        let max_batch = &store_entry.batch;
        let mut heap_prefix_scratch =
            prefix_converter.empty_rows(1, ESTIMATED_BYTES_PER_ROW); // 1 row with capacity ESTIMATED_BYTES_PER_ROW
        self.compute_common_sort_prefix(
            max_batch,
            max_topk_row.index,
            &mut heap_prefix_scratch,
        )?;

        // If the last row's prefix is strictly greater than the max prefix, mark as finished.
        if batch_prefix_scratch.row(0).as_ref() > heap_prefix_scratch.row(0).as_ref() {
            self.finished = true;
        }

        Ok(())
    }

    // Helper function to compute the prefix for a given batch and row index, storing the result in scratch.
    fn compute_common_sort_prefix(
        &self,
        batch: &RecordBatch,
        last_row_idx: usize,
        scratch: &mut Rows,
    ) -> Result<()> {
        let last_row: Vec<ArrayRef> = self
            .common_sort_prefix
            .iter()
            .map(|expr| {
                expr.expr
                    .evaluate(&batch.slice(last_row_idx, 1))?
                    .into_array(1)
            })
            .collect::<Result<_>>()?;

        self.common_sort_prefix_converter
            .as_ref()
            .unwrap()
            .append(scratch, &last_row)?;
        Ok(())
    }

    /// Returns the top k results broken into `batch_size` [`RecordBatch`]es, consuming the heap
    pub fn emit(self) -> Result<SendableRecordBatchStream> {
        let Self {
            schema,
            metrics,
            reservation: _,
            batch_size,
            expr: _,
            row_converter: _,
            scratch_rows: _,
            mut heap,
            common_sort_prefix_converter: _,
            common_sort_prefix: _,
            finished: _,
            filter,
        } = self;
        let _timer = metrics.baseline.elapsed_compute().timer(); // time updated on drop

        // Mark the dynamic filter as complete now that TopK processing is finished.
        filter.read().expr().mark_complete();

        // break into record batches as needed
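        // (e.g. 5 result rows with a batch_size of 2 are emitted as batches of 2, 2 and 1 rows)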
        let mut batches = vec![];
        if let Some(mut batch) = heap.emit()? {
            (&batch).record_output(&metrics.baseline);

            loop {
                if batch.num_rows() <= batch_size {
                    batches.push(Ok(batch));
                    break;
                } else {
                    batches.push(Ok(batch.slice(0, batch_size)));
                    let remaining_length = batch.num_rows() - batch_size;
                    batch = batch.slice(batch_size, remaining_length);
                }
            }
        };
        Ok(Box::pin(RecordBatchStreamAdapter::new(
            schema,
            futures::stream::iter(batches),
        )))
    }

    /// return the size of memory used by this operator, in bytes
    fn size(&self) -> usize {
        size_of::<Self>()
            + self.row_converter.size()
            + self.scratch_rows.size()
            + self.heap.size()
    }
}

struct TopKMetrics {
    /// metrics
    pub baseline: BaselineMetrics,

    /// count of how many rows were replaced in the heap
    pub row_replacements: Count,
}

impl TopKMetrics {
    fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self {
        Self {
            baseline: BaselineMetrics::new(metrics, partition),
            row_replacements: MetricBuilder::new(metrics)
                .counter("row_replacements", partition),
        }
    }
}

/// This structure keeps at most the *smallest* k items, using the
/// [arrow::row] format for sort keys. While it is called "topK", for
/// values like `1, 2, 3, 4, 5` the "top 3" really means the
/// *smallest* 3, `1, 2, 3`, not the *largest* 3, `3, 4, 5`.
///
/// Using the `Row` format handles things such as ascending vs
/// descending and nulls first vs nulls last.
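///
/// For example (illustrative), with k = 3 and inputs `5, 1, 4, 2, 3` the heap
/// evolves as `[5]`, `[1, 5]`, `[1, 4, 5]`, `[1, 2, 4]`, `[1, 2, 3]`: once
/// full, a new value can only enter by evicting the current maximum.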
struct TopKHeap {
    /// The maximum number of elements to store in this heap.
    k: usize,
    /// The target number of rows for output batches
    batch_size: usize,
    /// Storage for at most `k` items in a max `BinaryHeap`: the largest of the
    /// smallest k seen so far sits at the top, ready to be compared and evicted
    inner: BinaryHeap<TopKRow>,
    /// Stores the original row values (`TopKRow` only has the sort key)
    store: RecordBatchStore,
    /// The size of all owned data held by this heap
    owned_bytes: usize,
}

impl TopKHeap {
    fn new(k: usize, batch_size: usize) -> Self {
        assert!(k > 0);
        Self {
            k,
            batch_size,
            inner: BinaryHeap::new(),
            store: RecordBatchStore::new(),
            owned_bytes: 0,
        }
    }

    /// Register a [`RecordBatch`] with the heap, returning the
    /// appropriate entry
    pub fn register_batch(&mut self, batch: RecordBatch) -> RecordBatchEntry {
        self.store.register(batch)
    }

    /// Insert a [`RecordBatchEntry`] created by a previous call to
    /// [`Self::register_batch`] into storage.
    pub fn insert_batch_entry(&mut self, entry: RecordBatchEntry) {
        self.store.insert(entry)
    }

    /// Returns the largest value stored by the heap if there are k
    /// items, otherwise returns None. Remember this structure is
    /// keeping the "smallest" k values
    fn max(&self) -> Option<&TopKRow> {
        if self.inner.len() < self.k {
            None
        } else {
            self.inner.peek()
        }
    }

    /// Adds `row` to this heap. If inserting this new item would
    /// increase the size past `k`, removes the current largest item
    /// (the max of the k smallest) first.
    fn add(
        &mut self,
        batch_entry: &mut RecordBatchEntry,
        row: impl AsRef<[u8]>,
        index: usize,
    ) {
        let batch_id = batch_entry.id;
        batch_entry.uses += 1;

        assert!(self.inner.len() <= self.k);
        let row = row.as_ref();

        // Reuse storage for the evicted item if possible
        let new_top_k = if self.inner.len() == self.k {
            let prev_max = self.inner.pop().unwrap();

            // Update batch use
            if prev_max.batch_id == batch_entry.id {
                batch_entry.uses -= 1;
            } else {
                self.store.unuse(prev_max.batch_id);
            }

            // update memory accounting
            self.owned_bytes -= prev_max.owned_size();
            prev_max.with_new_row(row, batch_id, index)
        } else {
            TopKRow::new(row, batch_id, index)
        };

        self.owned_bytes += new_top_k.owned_size();

        // put the new row into the heap
        self.inner.push(new_top_k)
    }

    /// Returns the values stored in this heap, from values low to
    /// high, as a single [`RecordBatch`], resetting the inner heap
    pub fn emit(&mut self) -> Result<Option<RecordBatch>> {
        Ok(self.emit_with_state()?.0)
    }

    /// Returns the values stored in this heap, from values low to
    /// high, as a single [`RecordBatch`], and a sorted vec of the
    /// current heap's contents
    pub fn emit_with_state(&mut self) -> Result<(Option<RecordBatch>, Vec<TopKRow>)> {
        // generate sorted rows
        let topk_rows = std::mem::take(&mut self.inner).into_sorted_vec();

        if self.store.is_empty() {
            return Ok((None, topk_rows));
        }

        // Collect the batches into a vec and store the "batch_id -> array_pos" mapping, to then
        // build the `indices` vec below. This is needed since the batch ids are not contiguous.
        let mut record_batches = Vec::new();
        let mut batch_id_array_pos = HashMap::new();
        for (array_pos, (batch_id, batch)) in self.store.batches.iter().enumerate() {
            record_batches.push(&batch.batch);
            batch_id_array_pos.insert(*batch_id, array_pos);
        }

        let indices: Vec<_> = topk_rows
            .iter()
            .map(|k| (batch_id_array_pos[&k.batch_id], k.index))
            .collect();

        // At this point `indices` contains (batch position, row index) pairs and
        // `record_batches` holds a reference to each relevant RecordBatch.
        // `interleave_record_batch` pulls them together into a single new batch
        let new_batch = interleave_record_batch(&record_batches, &indices)?;

        Ok((Some(new_batch), topk_rows))
    }

    /// Compact this heap, rewriting all stored batches into a single
    /// input batch
    pub fn maybe_compact(&mut self) -> Result<()> {
        // we compact if the number of "unused" rows in the store is
        // past some pre-defined threshold. Target holding up to
        // around 20 batches, but handle cases of large k where some
        // batches might be partially full
        let max_unused_rows = (20 * self.batch_size) + self.k;
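        // (e.g. with a batch_size of 8192 and k = 100, compaction is considered
        // once more than 20 * 8192 + 100 = 163,940 unused rows accumulate)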
        let unused_rows = self.store.unused_rows();

        // don't compact if the store holds at most two batches, or the
        // number of unused rows is under the threshold
        if self.store.len() <= 2 || unused_rows < max_unused_rows {
            return Ok(());
        }
        // For now, always compact everything into a single new batch
        // (maybe we can get fancier in the future about ignoring
        // batches that already have a high usage ratio)

        // Note: new batch is in the same order as inner
        let num_rows = self.inner.len();
        let (new_batch, mut topk_rows) = self.emit_with_state()?;
        let Some(new_batch) = new_batch else {
            return Ok(());
        };

        // clear all old entries in store (this invalidates all
        // store_ids in `inner`)
        self.store.clear();

        let mut batch_entry = self.register_batch(new_batch);
        batch_entry.uses = num_rows;

        // rewrite all existing entries to use the new batch, and
        // remove old entries. The sortedness and their relative
        // position do not change
        for (i, topk_row) in topk_rows.iter_mut().enumerate() {
            topk_row.batch_id = batch_entry.id;
            topk_row.index = i;
        }
        self.insert_batch_entry(batch_entry);
        // restore the heap
        self.inner = BinaryHeap::from(topk_rows);

        Ok(())
    }

    /// return the size of memory used by this heap, in bytes
    fn size(&self) -> usize {
        size_of::<Self>()
            + (self.inner.capacity() * size_of::<TopKRow>())
            + self.store.size()
            + self.owned_bytes
    }

    fn get_threshold_values(
        &self,
        sort_exprs: &[PhysicalSortExpr],
    ) -> Result<Option<Vec<ScalarValue>>> {
        // If the heap doesn't have k elements yet, we can't create thresholds
        let max_row = match self.max() {
            Some(row) => row,
            None => return Ok(None),
        };

        // Get the batch that contains the max row
        let batch_entry = match self.store.get(max_row.batch_id) {
            Some(entry) => entry,
            None => return internal_err!("Invalid batch ID in TopKRow"),
        };

        // Extract threshold values for each sort expression
        let mut scalar_values = Vec::with_capacity(sort_exprs.len());
        for sort_expr in sort_exprs {
            // Extract the value for this column from the max row
            let expr = Arc::clone(&sort_expr.expr);
            let value = expr.evaluate(&batch_entry.batch.slice(max_row.index, 1))?;

            // Convert to scalar value - should be a single value since we're evaluating on a single row batch
            let scalar = match value {
                ColumnarValue::Scalar(scalar) => scalar,
                ColumnarValue::Array(array) if array.len() == 1 => {
                    // Extract the first (and only) value from the array
                    ScalarValue::try_from_array(&array, 0)?
                }
                array => {
                    return internal_err!("Expected a scalar value, got {:?}", array);
                }
            };

            scalar_values.push(scalar);
        }

        Ok(Some(scalar_values))
    }
}

/// Represents one of the top K rows held in this heap. Ordered by a
/// memcmp of `row` (e.g. the arrow Row format, but could
/// also be primitive values)
///
/// Reuses allocations to minimize runtime overhead of creating new Vecs
#[derive(Debug, PartialEq)]
struct TopKRow {
    /// the value of the sort key for this row. This contains the
    /// bytes that could be stored in `OwnedRow` but uses `Vec<u8>` to
    /// reuse allocations.
    row: Vec<u8>,
    /// the RecordBatch this row came from: an id into a [`RecordBatchStore`]
    batch_id: u32,
    /// the index in this record batch the row came from
    index: usize,
}

impl TopKRow {
    /// Create a new TopKRow with new allocation
    fn new(row: impl AsRef<[u8]>, batch_id: u32, index: usize) -> Self {
        Self {
            row: row.as_ref().to_vec(),
            batch_id,
            index,
        }
    }

    /// Create a new TopKRow reusing the existing allocation
    fn with_new_row(
        self,
        new_row: impl AsRef<[u8]>,
        batch_id: u32,
        index: usize,
    ) -> Self {
        let Self {
            mut row,
            batch_id: _,
            index: _,
        } = self;
        row.clear();
        row.extend_from_slice(new_row.as_ref());

        Self {
            row,
            batch_id,
            index,
        }
    }

    /// Returns the number of bytes owned by this row in the heap (not
    /// including itself)
    fn owned_size(&self) -> usize {
        self.row.capacity()
    }

    /// Returns a slice to the owned row value
    fn row(&self) -> &[u8] {
        self.row.as_slice()
    }
}

impl Eq for TopKRow {}

impl PartialOrd for TopKRow {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // TODO PartialOrd is not consistent with PartialEq; PartialOrd contract is violated
        Some(self.cmp(other))
    }
}

impl Ord for TopKRow {
    fn cmp(&self, other: &Self) -> Ordering {
        self.row.cmp(&other.row)
    }
}

#[derive(Debug)]
struct RecordBatchEntry {
    id: u32,
    batch: RecordBatch,
    // for this batch, how many times has it been used
    uses: usize,
}

/// This structure tracks [`RecordBatch`] by an id so that:
///
/// 1. The batches can be tracked via an id that can be copied cheaply
/// 2. The total memory held by all batches is tracked
#[derive(Debug)]
struct RecordBatchStore {
    /// id generator
    next_id: u32,
    /// storage
    batches: HashMap<u32, RecordBatchEntry>,
    /// total size of all record batches tracked by this store
    batches_size: usize,
}

impl RecordBatchStore {
    fn new() -> Self {
        Self {
            next_id: 0,
            batches: HashMap::new(),
            batches_size: 0,
        }
    }

    /// Register this batch with the store and assign an ID. No
    /// attempt is made to compare this batch to other batches
    pub fn register(&mut self, batch: RecordBatch) -> RecordBatchEntry {
        let id = self.next_id;
        self.next_id += 1;
        RecordBatchEntry { id, batch, uses: 0 }
    }

    /// Insert a record batch entry into this store, tracking its
    /// memory use, if it has any uses
    pub fn insert(&mut self, entry: RecordBatchEntry) {
        // uses of 0 means that none of the rows in the batch were stored in the topk
        if entry.uses > 0 {
            self.batches_size += get_record_batch_memory_size(&entry.batch);
            self.batches.insert(entry.id, entry);
        }
    }

    /// Clear all values in this store, invalidating all previous batch ids
    fn clear(&mut self) {
        self.batches.clear();
        self.batches_size = 0;
    }

    fn get(&self, id: u32) -> Option<&RecordBatchEntry> {
        self.batches.get(&id)
    }

    /// returns the total number of batches stored in this store
    fn len(&self) -> usize {
        self.batches.len()
    }

    /// Returns the total number of rows in batches minus the number
    /// which are in use
    fn unused_rows(&self) -> usize {
        self.batches
            .values()
            .map(|batch_entry| batch_entry.batch.num_rows() - batch_entry.uses)
            .sum()
    }

    /// returns true if the store has nothing stored
    fn is_empty(&self) -> bool {
        self.batches.is_empty()
    }

    /// remove a use from the specified batch id. If the use count
    /// reaches zero the batch entry is removed from the store
    ///
    /// panics if there were no remaining uses of id
    pub fn unuse(&mut self, id: u32) {
        let remove = if let Some(batch_entry) = self.batches.get_mut(&id) {
            batch_entry.uses = batch_entry.uses.checked_sub(1).expect("underflow");
            batch_entry.uses == 0
        } else {
            panic!("No entry for id {id}");
        };

        if remove {
            let old_entry = self.batches.remove(&id).unwrap();
            self.batches_size = self
                .batches_size
                .checked_sub(get_record_batch_memory_size(&old_entry.batch))
                .unwrap();
        }
    }

    /// returns the size of memory used by this store, including all
    /// referenced `RecordBatch`es, in bytes
    pub fn size(&self) -> usize {
        size_of::<Self>()
            + self.batches.capacity() * (size_of::<u32>() + size_of::<RecordBatchEntry>())
            + self.batches_size
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, Int32Array, RecordBatch};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow_schema::SortOptions;
    use datafusion_common::assert_batches_eq;
    use datafusion_physical_expr::expressions::col;
    use futures::TryStreamExt;

    /// This test ensures the size calculation is correct for RecordBatches with multiple columns.
    #[test]
    fn test_record_batch_store_size() {
        // given
        let schema = Arc::new(Schema::new(vec![
            Field::new("ints", DataType::Int32, true),
            Field::new("float64", DataType::Float64, false),
        ]));
        let mut record_batch_store = RecordBatchStore::new();
        let int_array =
            Int32Array::from(vec![Some(1), Some(2), Some(3), Some(4), Some(5)]); // 5 * 4 = 20
        let float64_array = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]); // 5 * 8 = 40

        let record_batch_entry = RecordBatchEntry {
            id: 0,
            batch: RecordBatch::try_new(
                schema,
                vec![Arc::new(int_array), Arc::new(float64_array)],
            )
            .unwrap(),
            uses: 1,
        };

        // when insert record batch entry
        record_batch_store.insert(record_batch_entry);
        assert_eq!(record_batch_store.batches_size, 60);

        // when unuse record batch entry
        record_batch_store.unuse(0);
        assert_eq!(record_batch_store.batches_size, 0);
    }

    /// This test validates that `attempt_early_completion` marks the TopK operator as
    /// finished when the prefix (on column "a") of the last row in the current batch
    /// is strictly greater than the max top‑k row.
    /// The full sort expression is defined on both columns ("a", "b"), but the input ordering is only on "a".
    #[tokio::test]
    async fn test_try_finish_marks_finished_with_prefix() -> Result<()> {
        // Create a schema with two columns.
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Float64, false),
        ]));

        // Create sort expressions.
        // Full sort: first by "a", then by "b".
        let sort_expr_a = PhysicalSortExpr {
            expr: col("a", schema.as_ref())?,
            options: SortOptions::default(),
        };
        let sort_expr_b = PhysicalSortExpr {
            expr: col("b", schema.as_ref())?,
            options: SortOptions::default(),
        };

        // Input ordering uses only column "a" (a prefix of the full sort).
        let prefix = vec![sort_expr_a.clone()];
        let full_expr = LexOrdering::from([sort_expr_a, sort_expr_b]);

        // Create a dummy runtime environment and metrics.
        let runtime = Arc::new(RuntimeEnv::default());
        let metrics = ExecutionPlanMetricsSet::new();

        // Create a TopK instance with k = 3 and batch_size = 2.
        let mut topk = TopK::try_new(
            0,
            Arc::clone(&schema),
            prefix,
            full_expr,
            3,
            2,
            runtime,
            &metrics,
            Arc::new(RwLock::new(TopKDynamicFilters::new(Arc::new(
                DynamicFilterPhysicalExpr::new(vec![], lit(true)),
            )))),
        )?;

        // Create the first batch with two columns:
        // Column "a": [1, 1, 2], Column "b": [20.0, 15.0, 30.0].
        let array_a1: ArrayRef =
            Arc::new(Int32Array::from(vec![Some(1), Some(1), Some(2)]));
        let array_b1: ArrayRef = Arc::new(Float64Array::from(vec![20.0, 15.0, 30.0]));
        let batch1 = RecordBatch::try_new(Arc::clone(&schema), vec![array_a1, array_b1])?;

        // Insert the first batch.
        // At this point the heap is not yet “finished” because the prefix of the last row of the batch
        // is not strictly greater than the prefix of the max top‑k row (both being `2`).
        topk.insert_batch(batch1)?;
        assert!(
            !topk.finished,
            "Expected 'finished' to be false after the first batch."
        );

        // Create the second batch with two columns:
        // Column "a": [2, 3], Column "b": [10.0, 20.0].
        let array_a2: ArrayRef = Arc::new(Int32Array::from(vec![Some(2), Some(3)]));
        let array_b2: ArrayRef = Arc::new(Float64Array::from(vec![10.0, 20.0]));
        let batch2 = RecordBatch::try_new(Arc::clone(&schema), vec![array_a2, array_b2])?;

        // Insert the second batch.
        // The last row in this batch has a prefix value of `3`,
        // which is strictly greater than the max top‑k row (with value `2`),
        // so `attempt_early_completion` should mark the TopK as finished.
        topk.insert_batch(batch2)?;
        assert!(
            topk.finished,
            "Expected 'finished' to be true after the second batch."
        );

        // Verify the TopK correctly emits the top k rows from both batches
        // (the value 10.0 for b is from the second batch).
        let results: Vec<_> = topk.emit()?.try_collect().await?;
        assert_batches_eq!(
            &[
                "+---+------+",
                "| a | b    |",
                "+---+------+",
                "| 1 | 15.0 |",
                "| 1 | 20.0 |",
                "| 2 | 10.0 |",
                "+---+------+",
            ],
            &results
        );

        Ok(())
    }

    /// This test verifies that the dynamic filter is marked as complete after TopK processing finishes.
    #[tokio::test]
    async fn test_topk_marks_filter_complete() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

        let sort_expr = PhysicalSortExpr {
            expr: col("a", schema.as_ref())?,
            options: SortOptions::default(),
        };

        let full_expr = LexOrdering::from([sort_expr.clone()]);
        let prefix = vec![sort_expr];

        // Create a dummy runtime environment and metrics
        let runtime = Arc::new(RuntimeEnv::default());
        let metrics = ExecutionPlanMetricsSet::new();

        // Create a dynamic filter that we'll check for completion
        let dynamic_filter = Arc::new(DynamicFilterPhysicalExpr::new(vec![], lit(true)));
        let dynamic_filter_clone = Arc::clone(&dynamic_filter);

        // Create a TopK instance
        let mut topk = TopK::try_new(
            0,
            Arc::clone(&schema),
            prefix,
            full_expr,
            2,
            10,
            runtime,
            &metrics,
            Arc::new(RwLock::new(TopKDynamicFilters::new(dynamic_filter))),
        )?;

        let array: ArrayRef = Arc::new(Int32Array::from(vec![Some(3), Some(1), Some(2)]));
        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![array])?;
        topk.insert_batch(batch)?;

        // Call emit to finish TopK processing
        let _results: Vec<_> = topk.emit()?.try_collect().await?;

        // After emit is called, the dynamic filter should be marked as complete:
        // wait_complete() should return immediately
        dynamic_filter_clone.wait_complete().await;

        Ok(())
    }
}