Skip to main content

datafusion_physical_plan/topk/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! TopK: Combination of Sort / LIMIT
19
20use arrow::{
21    array::{Array, AsArray},
22    compute::{FilterBuilder, interleave_record_batch, prep_null_mask_filter},
23    row::{RowConverter, Rows, SortField},
24};
25use datafusion_expr::{ColumnarValue, Operator};
26use std::mem::size_of;
27use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc};
28
29use super::metrics::{
30    BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, MetricCategory,
31    RecordOutput,
32};
33use crate::spill::get_record_batch_memory_size;
34use crate::{SendableRecordBatchStream, stream::RecordBatchStreamAdapter};
35
36use arrow::array::{ArrayRef, RecordBatch};
37use arrow::datatypes::SchemaRef;
38use datafusion_common::{
39    HashMap, Result, ScalarValue, internal_datafusion_err, internal_err,
40};
41use datafusion_execution::{
42    memory_pool::{MemoryConsumer, MemoryReservation},
43    runtime_env::RuntimeEnv,
44};
45use datafusion_physical_expr::{
46    PhysicalExpr,
47    expressions::{BinaryExpr, DynamicFilterPhysicalExpr, is_not_null, is_null, lit},
48};
49use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
50use parking_lot::RwLock;
51
52/// Global TopK
53///
54/// # Background
55///
56/// "Top K" is a common query optimization used for queries such as
57/// "find the top 3 customers by revenue". The (simplified) SQL for
58/// such a query might be:
59///
60/// ```sql
61/// SELECT customer_id, revenue FROM 'sales.csv' ORDER BY revenue DESC limit 3;
62/// ```
63///
64/// The simple plan would be:
65///
66/// ```sql
67/// > explain SELECT customer_id, revenue FROM sales ORDER BY revenue DESC limit 3;
68/// +--------------+----------------------------------------+
69/// | plan_type    | plan                                   |
70/// +--------------+----------------------------------------+
71/// | logical_plan | Limit: 3                               |
72/// |              |   Sort: revenue DESC NULLS FIRST       |
73/// |              |     Projection: customer_id, revenue   |
74/// |              |       TableScan: sales                 |
75/// +--------------+----------------------------------------+
76/// ```
77///
/// While this plan produces the correct answer, it fully sorts the
/// input before discarding everything other than the top 3 elements.
80///
81/// The same answer can be produced by simply keeping track of the top
82/// K=3 elements, reducing the total amount of required buffer memory.
83///
84/// # Partial Sort Optimization
85///
86/// This implementation additionally optimizes queries where the input is already
87/// partially sorted by a common prefix of the requested ordering. Once the top K
88/// heap is full, if subsequent rows are guaranteed to be strictly greater (in sort
89/// order) on this prefix than the largest row currently stored, the operator
90/// safely terminates early.
91///
92/// ## Example
93///
94/// For input sorted by `(day DESC)`, but not by `timestamp`, a query such as:
95///
96/// ```sql
97/// SELECT day, timestamp FROM sensor ORDER BY day DESC, timestamp DESC LIMIT 10;
98/// ```
99///
100/// can terminate scanning early once sufficient rows from the latest days have been
101/// collected, skipping older data.
102///
103/// # Structure
104///
105/// This operator tracks the top K items using a `TopKHeap`.
pub struct TopK {
    /// schema of the output (and the input)
    schema: SchemaRef,
    /// Runtime metrics
    metrics: TopKMetrics,
    /// Memory reservation; resized to `size()` whenever an inserted batch
    /// changes the contents of the heap
    reservation: MemoryReservation,
    /// The target number of rows for output batches
    batch_size: usize,
    /// sort expressions
    expr: LexOrdering,
    /// row converter, for sort keys
    row_converter: RowConverter,
    /// scratch space for converting rows; cleared and refilled for every
    /// inserted batch so the allocation is reused
    scratch_rows: Rows,
    /// stores the top k values and their sort key values, in order
    heap: TopKHeap,
    /// row converter, for common keys between the sort keys and the input ordering.
    /// `None` when `common_sort_prefix` is empty.
    common_sort_prefix_converter: Option<RowConverter>,
    /// Common sort prefix between the input and the sort expressions to allow early exit optimization
    common_sort_prefix: Arc<[PhysicalSortExpr]>,
    /// Filter matching the state of the `TopK` heap used for dynamic filter pushdown
    filter: Arc<RwLock<TopKDynamicFilters>>,
    /// If true, indicates that all rows of subsequent batches are guaranteed
    /// to be greater (by byte order, after row conversion) than the top K,
    /// which means the top K won't change and the computation can be finished early.
    pub(crate) finished: bool,
}
134
/// Dynamic filter state that [`TopK`] keeps in sync with its heap: the most
/// selective threshold row seen so far and the filter expression derived from it.
///
/// For more background, please also see the [Dynamic Filters: Passing Information Between Operators During Execution for 25x Faster Queries blog]
///
/// [Dynamic Filters: Passing Information Between Operators During Execution for 25x Faster Queries blog]: https://datafusion.apache.org/blog/2025/09/10/dynamic-filters
#[derive(Debug, Clone)]
pub struct TopKDynamicFilters {
    /// The current *global* threshold for the dynamic filter.
    /// This is shared across all partitions and is updated by any of them.
    /// Stored as row bytes for efficient comparison.
    /// `None` until a threshold has been recorded.
    threshold_row: Option<Vec<u8>>,
    /// The expression used to evaluate the dynamic filter
    /// Only updated when lock held for the duration of the update
    expr: Arc<DynamicFilterPhysicalExpr>,
}
148
149impl TopKDynamicFilters {
150    /// Create a new `TopKDynamicFilters` with the given expression
151    pub fn new(expr: Arc<DynamicFilterPhysicalExpr>) -> Self {
152        Self {
153            threshold_row: None,
154            expr,
155        }
156    }
157
158    pub fn expr(&self) -> Arc<DynamicFilterPhysicalExpr> {
159        Arc::clone(&self.expr)
160    }
161}
162
/// Guesstimate for memory allocation: estimated number of bytes used per row in the
/// `RowConverter`. Used as the per-row data capacity hint when creating `Rows` buffers.
const ESTIMATED_BYTES_PER_ROW: usize = 20;
165
166pub(crate) fn build_sort_fields(
167    ordering: &[PhysicalSortExpr],
168    schema: &SchemaRef,
169) -> Result<Vec<SortField>> {
170    ordering
171        .iter()
172        .map(|e| {
173            Ok(SortField::new_with_options(
174                e.expr.data_type(schema)?,
175                e.options,
176            ))
177        })
178        .collect::<Result<_>>()
179}
180
181impl TopK {
182    /// Create a new [`TopK`] that stores the top `k` values, as
183    /// defined by the sort expressions in `expr`.
184    // TODO: make a builder or some other nicer API
185    #[expect(clippy::too_many_arguments)]
186    #[expect(clippy::needless_pass_by_value)]
187    pub fn try_new(
188        partition_id: usize,
189        schema: SchemaRef,
190        common_sort_prefix: Vec<PhysicalSortExpr>,
191        expr: LexOrdering,
192        k: usize,
193        batch_size: usize,
194        runtime: Arc<RuntimeEnv>,
195        metrics: &ExecutionPlanMetricsSet,
196        filter: Arc<RwLock<TopKDynamicFilters>>,
197    ) -> Result<Self> {
198        let reservation = MemoryConsumer::new(format!("TopK[{partition_id}]"))
199            .register(&runtime.memory_pool);
200
201        let sort_fields = build_sort_fields(&expr, &schema)?;
202
203        // TODO there is potential to add special cases for single column sort fields
204        // to improve performance
205        let row_converter = RowConverter::new(sort_fields)?;
206        let scratch_rows =
207            row_converter.empty_rows(batch_size, ESTIMATED_BYTES_PER_ROW * batch_size);
208
209        let prefix_row_converter = if common_sort_prefix.is_empty() {
210            None
211        } else {
212            let input_sort_fields = build_sort_fields(&common_sort_prefix, &schema)?;
213            Some(RowConverter::new(input_sort_fields)?)
214        };
215
216        Ok(Self {
217            schema: Arc::clone(&schema),
218            metrics: TopKMetrics::new(metrics, partition_id),
219            reservation,
220            batch_size,
221            expr,
222            row_converter,
223            scratch_rows,
224            heap: TopKHeap::new(k),
225            common_sort_prefix_converter: prefix_row_converter,
226            common_sort_prefix: Arc::from(common_sort_prefix),
227            finished: false,
228            filter,
229        })
230    }
231
    /// Insert `batch`, remembering if any of its values are among
    /// the top k seen so far.
    ///
    /// Rows rejected by the current dynamic filter are skipped before the sort
    /// keys are converted to the row format. When at least one row enters the
    /// heap, the batch is retained in the heap's store, the store is compacted
    /// if worthwhile, the memory reservation is resized, early completion is
    /// attempted and the shared dynamic filter is refreshed.
    ///
    /// # Errors
    /// Returns an error if expression evaluation, filtering, row conversion,
    /// compaction, the memory reservation resize or the filter update fails.
    #[expect(clippy::needless_pass_by_value)]
    pub fn insert_batch(&mut self, batch: RecordBatch) -> Result<()> {
        // Updates on drop
        let baseline = self.metrics.baseline.clone();
        let _timer = baseline.elapsed_compute().timer();

        // Evaluate every sort expression against the whole batch
        let mut sort_keys: Vec<ArrayRef> = self
            .expr
            .iter()
            .map(|expr| {
                let value = expr.expr.evaluate(&batch)?;
                value.into_array(batch.num_rows())
            })
            .collect::<Result<Vec<_>>>()?;

        // `Some(mask)` only when the dynamic filter rejects part of the batch
        let mut selected_rows = None;

        // Evaluate the current dynamic filter expression against the batch to
        // find out which rows can still qualify for the top k
        let filter = self.filter.read().expr.current()?;
        let filtered = filter.evaluate(&batch)?;
        let num_rows = batch.num_rows();
        let array = filtered.into_array(num_rows)?;
        let mut filter = array.as_boolean().clone();
        if !filter.has_true() {
            // no row passes the filter, so there is nothing to insert
            return Ok(());
        }
        // only update the keys / rows if the filter does not match all rows
        if filter.null_count() > 0 || filter.has_false() {
            // Indices in `set_indices` should be correct if filter contains nulls
            // So we prepare the filter here. Note this is also done in the `FilterBuilder`
            // so there is no overhead to do this here.
            if filter.nulls().is_some() {
                filter = prep_null_mask_filter(&filter);
            }

            let filter_predicate = FilterBuilder::new(&filter);
            let filter_predicate = if sort_keys.len() > 1 {
                // Optimize filter when it has multiple sort keys
                filter_predicate.optimize().build()
            } else {
                filter_predicate.build()
            };
            selected_rows = Some(filter);
            sort_keys = sort_keys
                .iter()
                .map(|key| filter_predicate.filter(key).map_err(|x| x.into()))
                .collect::<Result<Vec<_>>>()?;
        }
        // reuse existing `Rows` to avoid reallocations
        let rows = &mut self.scratch_rows;
        rows.clear();
        self.row_converter.append(rows, &sort_keys)?;

        // The entry is only handed to the store below if some row was actually used
        let mut batch_entry = self.heap.register_batch(batch.clone());

        // Pair each converted row with its index in the *original* batch:
        // the set bits of the mask when rows were filtered, every index otherwise
        let replacements = match selected_rows {
            Some(filter) => {
                self.find_new_topk_items(filter.values().set_indices(), &mut batch_entry)
            }
            None => self.find_new_topk_items(0..sort_keys[0].len(), &mut batch_entry),
        };

        if replacements > 0 {
            self.metrics.row_replacements.add(replacements);

            self.heap.insert_batch_entry(batch_entry);

            // conserve memory
            self.heap.maybe_compact()?;

            // update memory reservation
            self.reservation.try_resize(self.size())?;

            // flag the topK as finished if we know that all
            // subsequent batches are guaranteed to be greater (by byte order, after row conversion) than the top K,
            // which means the top K won't change and the computation can be finished early.
            self.attempt_early_completion(&batch)?;

            // update the filter representation of our TopK heap
            self.update_filter()?;
        }

        Ok(())
    }
319
320    fn find_new_topk_items(
321        &mut self,
322        items: impl Iterator<Item = usize>,
323        batch_entry: &mut RecordBatchEntry,
324    ) -> usize {
325        let mut replacements = 0;
326        let rows = &mut self.scratch_rows;
327        for (index, row) in items.zip(rows.iter()) {
328            match self.heap.max() {
329                // heap has k items, and the new row is greater than the
330                // current max in the heap ==> it is not a new topk
331                Some(max_row) if row.as_ref() >= max_row.row() => {}
332                // don't yet have k items or new item is lower than the currently k low values
333                None | Some(_) => {
334                    self.heap.add(batch_entry, row, index);
335                    replacements += 1;
336                }
337            }
338        }
339        replacements
340    }
341
    /// Update the shared dynamic filter so that it reflects the current state
    /// of this TopK heap.
    ///
    /// Nothing happens until the heap holds `k` rows. After that, the row bytes
    /// of the heap's current max row are compared with the shared
    /// `threshold_row`; the shared threshold and filter expression are only
    /// replaced when the new row compares strictly smaller, i.e. is more
    /// selective.
    ///
    /// The predicate is built by `build_filter_expression` from the values
    /// returned by `get_threshold_values`, which are paired with the sort
    /// expressions. For example, for `ORDER BY a DESC, b ASC LIMIT 3` with
    /// threshold values `(1, 5)`, and ignoring NULL handling, the filter becomes:
    ///
    /// ```sql
    /// a > 1 OR (a = 1 AND b < 5)
    /// ```
    fn update_filter(&mut self) -> Result<()> {
        // If the heap doesn't have k elements yet, we can't create thresholds
        let Some(max_row) = self.heap.max() else {
            return Ok(());
        };

        let new_threshold_row = &max_row.row;

        // Fast path: check if the current value in topk is better than what is
        // currently set in the filter with a read only lock
        let needs_update = self
            .filter
            .read()
            .threshold_row
            .as_ref()
            .map(|current_row| {
                // new < current means new threshold is more selective
                new_threshold_row < current_row
            })
            .unwrap_or(true); // No current threshold, so we need to set one

        // exit early if the current values are better
        if !needs_update {
            return Ok(());
        }

        // Extract scalar values BEFORE acquiring lock to reduce critical section
        let thresholds = match self.heap.get_threshold_values(&self.expr)? {
            Some(t) => t,
            None => return Ok(()),
        };

        // Build the filter expression OUTSIDE any synchronization
        let predicate = Self::build_filter_expression(&self.expr, &thresholds)?;
        let new_threshold = new_threshold_row.to_vec();

        // Update the threshold. The read lock was released above, so the shared
        // threshold may have changed while the expression was being built and
        // must be re-checked under the write lock.
        let mut filter = self.filter.write();
        let old_threshold = filter.threshold_row.take();

        // Update filter if we successfully updated the threshold
        // (or if there was no previous threshold and we're the first)
        match old_threshold {
            Some(old_threshold) => {
                // new threshold is still better than the old one
                if new_threshold.as_slice() < old_threshold.as_slice() {
                    filter.threshold_row = Some(new_threshold);
                } else {
                    // some other thread updated the threshold to a better
                    // one while we were building so there is no need to
                    // update the filter
                    filter.threshold_row = Some(old_threshold);
                    return Ok(());
                }
            }
            None => {
                // No previous threshold, so we can set the new one
                filter.threshold_row = Some(new_threshold);
            }
        };

        // Update the filter expression, skipping a missing predicate or one
        // that is the literal `true`
        if let Some(pred) = predicate
            && !pred.eq(&lit(true))
        {
            filter.expr.update(pred)?;
        }

        Ok(())
    }
423
    /// Build the filter expression with the given thresholds.
    /// This is now called outside of any locks to reduce critical section time.
    ///
    /// `sort_exprs` and `thresholds` are paired positionally. For the i-th pair
    /// a term of the form "all preceding columns equal their thresholds AND
    /// column i sorts strictly before its threshold" is produced, and all terms
    /// are combined with `OR`. Returns `Ok(None)` when there are no pairs.
    fn build_filter_expression(
        sort_exprs: &[PhysicalSortExpr],
        thresholds: &[ScalarValue],
    ) -> Result<Option<Arc<dyn PhysicalExpr>>> {
        // Create filter expressions for each threshold
        let mut filters: Vec<Arc<dyn PhysicalExpr>> =
            Vec::with_capacity(thresholds.len());

        // Conjunction of the equality predicates of all columns processed so far
        let mut prev_sort_expr: Option<Arc<dyn PhysicalExpr>> = None;
        for (sort_expr, value) in sort_exprs.iter().zip(thresholds.iter()) {
            // Create the appropriate operator based on sort order
            let op = if sort_expr.options.descending {
                // For descending sort, we want col > threshold (exclude smaller values)
                Operator::Gt
            } else {
                // For ascending sort, we want col < threshold (exclude larger values)
                Operator::Lt
            };

            let value_null = value.is_null();

            let comparison = Arc::new(BinaryExpr::new(
                Arc::clone(&sort_expr.expr),
                op,
                lit(value.clone()),
            ));

            // "Column sorts strictly before the threshold", taking NULL placement into account
            let comparison_with_null = match (sort_expr.options.nulls_first, value_null) {
                // Nulls first, NULL threshold: nothing sorts before NULL, so no row qualifies
                (true, true) => lit(false),
                // Nulls first, non-NULL threshold: NULLs sort before it, as do values passing the comparison
                (true, false) => Arc::new(BinaryExpr::new(
                    is_null(Arc::clone(&sort_expr.expr))?,
                    Operator::Or,
                    comparison,
                )),
                // Nulls last, NULL threshold: every non-NULL value sorts before it
                (false, true) => is_not_null(Arc::clone(&sort_expr.expr))?,
                // Nulls last, non-NULL threshold: the plain comparison is enough
                (false, false) => comparison,
            };

            let mut eq_expr = Arc::new(BinaryExpr::new(
                Arc::clone(&sort_expr.expr),
                Operator::Eq,
                lit(value.clone()),
            ));

            // For a NULL threshold, a NULL column value also counts as "equal"
            if value_null {
                eq_expr = Arc::new(BinaryExpr::new(
                    is_null(Arc::clone(&sort_expr.expr))?,
                    Operator::Or,
                    eq_expr,
                ));
            }

            // For a query like order by a, b, the filter for column `b` is only applied if
            // the condition a = threshold.value (considering null equality) is met.
            // Therefore, we add equality predicates for all preceding fields to the filter logic of the current field,
            // and include the current field's equality predicate in `prev_sort_expr` for use with subsequent fields.
            match prev_sort_expr.take() {
                None => {
                    prev_sort_expr = Some(eq_expr);
                    filters.push(comparison_with_null);
                }
                Some(p) => {
                    filters.push(Arc::new(BinaryExpr::new(
                        Arc::clone(&p),
                        Operator::And,
                        comparison_with_null,
                    )));

                    prev_sort_expr =
                        Some(Arc::new(BinaryExpr::new(p, Operator::And, eq_expr)));
                }
            }
        }

        // A row passes if any of the per-column terms holds
        let dynamic_predicate = filters
            .into_iter()
            .reduce(|a, b| Arc::new(BinaryExpr::new(a, Operator::Or, b)));

        Ok(dynamic_predicate)
    }
509
510    /// If input ordering shares a common sort prefix with the TopK, and if the TopK's heap is full,
511    /// check if the computation can be finished early.
512    /// This is the case if the last row of the current batch is strictly greater than the max row in the heap,
513    /// comparing only on the shared prefix columns.
514    fn attempt_early_completion(&mut self, batch: &RecordBatch) -> Result<()> {
515        // Early exit if the batch is empty as there is no last row to extract from it.
516        if batch.num_rows() == 0 {
517            return Ok(());
518        }
519
520        // prefix_row_converter is only `Some` if the input ordering has a common prefix with the TopK,
521        // so early exit if it is `None`.
522        let Some(prefix_converter) = &self.common_sort_prefix_converter else {
523            return Ok(());
524        };
525
526        // Early exit if the heap is not full (`heap.max()` only returns `Some` if the heap is full).
527        let Some(max_topk_row) = self.heap.max() else {
528            return Ok(());
529        };
530
531        // Evaluate the prefix for the last row of the current batch.
532        let last_row_idx = batch.num_rows() - 1;
533        let mut batch_prefix_scratch =
534            prefix_converter.empty_rows(1, ESTIMATED_BYTES_PER_ROW); // 1 row with capacity ESTIMATED_BYTES_PER_ROW
535
536        self.compute_common_sort_prefix(batch, last_row_idx, &mut batch_prefix_scratch)?;
537
538        // Retrieve the max row from the heap.
539        let store_entry = self
540            .heap
541            .store
542            .get(max_topk_row.batch_id)
543            .ok_or(internal_datafusion_err!("Invalid batch id in topK heap"))?;
544        let max_batch = &store_entry.batch;
545        let mut heap_prefix_scratch =
546            prefix_converter.empty_rows(1, ESTIMATED_BYTES_PER_ROW); // 1 row with capacity ESTIMATED_BYTES_PER_ROW
547        self.compute_common_sort_prefix(
548            max_batch,
549            max_topk_row.index,
550            &mut heap_prefix_scratch,
551        )?;
552
553        // If the last row's prefix is strictly greater than the max prefix, mark as finished.
554        if batch_prefix_scratch.row(0).as_ref() > heap_prefix_scratch.row(0).as_ref() {
555            self.finished = true;
556        }
557
558        Ok(())
559    }
560
561    // Helper function to compute the prefix for a given batch and row index, storing the result in scratch.
562    fn compute_common_sort_prefix(
563        &self,
564        batch: &RecordBatch,
565        last_row_idx: usize,
566        scratch: &mut Rows,
567    ) -> Result<()> {
568        let last_row: Vec<ArrayRef> = self
569            .common_sort_prefix
570            .iter()
571            .map(|expr| {
572                expr.expr
573                    .evaluate(&batch.slice(last_row_idx, 1))?
574                    .into_array(1)
575            })
576            .collect::<Result<_>>()?;
577
578        self.common_sort_prefix_converter
579            .as_ref()
580            .unwrap()
581            .append(scratch, &last_row)?;
582        Ok(())
583    }
584
585    /// Returns the top k results broken into `batch_size` [`RecordBatch`]es, consuming the heap
586    pub fn emit(self) -> Result<SendableRecordBatchStream> {
587        let Self {
588            schema,
589            metrics,
590            reservation: _,
591            batch_size,
592            expr: _,
593            row_converter: _,
594            scratch_rows: _,
595            mut heap,
596            common_sort_prefix_converter: _,
597            common_sort_prefix: _,
598            finished: _,
599            filter,
600        } = self;
601        let _timer = metrics.baseline.elapsed_compute().timer(); // time updated on drop
602
603        // Mark the dynamic filter as complete now that TopK processing is finished.
604        filter.read().expr().mark_complete();
605
606        // break into record batches as needed
607        let mut batches = vec![];
608        if let Some(mut batch) = heap.emit()? {
609            (&batch).record_output(&metrics.baseline);
610
611            loop {
612                if batch.num_rows() <= batch_size {
613                    batches.push(Ok(batch));
614                    break;
615                } else {
616                    batches.push(Ok(batch.slice(0, batch_size)));
617                    let remaining_length = batch.num_rows() - batch_size;
618                    batch = batch.slice(batch_size, remaining_length);
619                }
620            }
621        };
622        Ok(Box::pin(RecordBatchStreamAdapter::new(
623            schema,
624            futures::stream::iter(batches),
625        )))
626    }
627
628    /// return the size of memory used by this operator, in bytes
629    fn size(&self) -> usize {
630        size_of::<Self>()
631            + self.row_converter.size()
632            + self.scratch_rows.size()
633            + self.heap.size()
634    }
635}
636
/// Metrics recorded by a [`TopK`] operator.
struct TopKMetrics {
    /// baseline metrics; used for elapsed compute time and output recording
    pub baseline: BaselineMetrics,

    /// count of how many rows were replaced in the heap
    pub row_replacements: Count,
}
644
645impl TopKMetrics {
646    fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self {
647        Self {
648            baseline: BaselineMetrics::new(metrics, partition),
649            row_replacements: MetricBuilder::new(metrics)
650                .with_category(MetricCategory::Rows)
651                .counter("row_replacements", partition),
652        }
653    }
654}
655
/// This structure keeps at most the *smallest* k items, using the
/// [arrow::row] format for sort keys. While it is called "topK" for
/// values like `1, 2, 3, 4, 5` the "top 3" really means the
/// *smallest* 3, `1, 2, 3`, not the *largest* 3 `3, 4, 5`.
///
/// Using the `Row` format handles things such as ascending vs
/// descending and nulls first vs nulls last.
struct TopKHeap {
    /// The maximum number of elements to store in this heap.
    k: usize,
    /// Storage for at most `k` items using a BinaryHeap. The item at the
    /// top of the heap is the one `max` reports as the largest kept row,
    /// i.e. the next candidate for replacement.
    inner: BinaryHeap<TopKRow>,
    /// Stores the original row values (TopKRow only has the sort key)
    store: RecordBatchStore,
    /// The size of all owned data held by this heap
    owned_bytes: usize,
}
674
675impl TopKHeap {
676    fn new(k: usize) -> Self {
677        assert!(k > 0);
678        Self {
679            k,
680            inner: BinaryHeap::new(),
681            store: RecordBatchStore::new(),
682            owned_bytes: 0,
683        }
684    }
685
    /// Register a [`RecordBatch`] with the heap, returning the
    /// appropriate entry.
    ///
    /// This delegates to the store's `register`. The returned entry still has
    /// to be handed to [`Self::insert_batch_entry`] to be placed into storage.
    pub fn register_batch(&mut self, batch: RecordBatch) -> RecordBatchEntry {
        self.store.register(batch)
    }
691
    /// Insert a [`RecordBatchEntry`] created by a previous call to
    /// [`Self::register_batch`] into storage.
    ///
    /// This delegates to the store's `insert`.
    pub fn insert_batch_entry(&mut self, entry: RecordBatchEntry) {
        self.store.insert(entry)
    }
697
698    /// Returns the largest value stored by the heap if there are k
699    /// items, otherwise returns None. Remember this structure is
700    /// keeping the "smallest" k values
701    fn max(&self) -> Option<&TopKRow> {
702        if self.inner.len() < self.k {
703            None
704        } else {
705            self.inner.peek()
706        }
707    }
708
    /// Adds `row` to this heap. If the heap already holds `k` items, the item
    /// at the top of the heap (the one [`Self::max`] reports as the largest) is
    /// overwritten in place with the new row instead of growing the heap.
    ///
    /// `index` is the position of the row within the batch tracked by
    /// `batch_entry`, whose use count is updated accordingly.
    ///
    /// # Panics
    /// Panics if the heap holds more than `k` items.
    fn add(
        &mut self,
        batch_entry: &mut RecordBatchEntry,
        row: impl AsRef<[u8]>,
        index: usize,
    ) {
        let batch_id = batch_entry.id;
        // the new row references `batch_entry`
        batch_entry.uses += 1;

        assert!(self.inner.len() <= self.k);
        let row = row.as_ref();

        // Reuse storage for evicted item if possible
        if self.inner.len() == self.k {
            // NOTE: despite the name, this is the item at the top of the heap,
            // i.e. the one being evicted
            let mut prev_min = self.inner.peek_mut().unwrap();

            // Update batch use: the evicted row no longer references its batch
            if prev_min.batch_id == batch_entry.id {
                // the batch is not in the store yet, so adjust the entry directly
                batch_entry.uses -= 1;
            } else {
                self.store.unuse(prev_min.batch_id);
            }

            // update memory accounting
            self.owned_bytes -= prev_min.owned_size();

            prev_min.replace_with(row, batch_id, index);

            self.owned_bytes += prev_min.owned_size();
        } else {
            let new_row = TopKRow::new(row, batch_id, index);
            self.owned_bytes += new_row.owned_size();
            // put the new row into the heap
            self.inner.push(new_row);
        };
    }
748
749    /// Returns the values stored in this heap, from values low to
750    /// high, as a single [`RecordBatch`], resetting the inner heap
751    pub fn emit(&mut self) -> Result<Option<RecordBatch>> {
752        Ok(self.emit_with_state()?.0)
753    }
754
755    /// Returns the values stored in this heap, from values low to
756    /// high, as a single [`RecordBatch`], and a sorted vec of the
757    /// current heap's contents
758    pub fn emit_with_state(&mut self) -> Result<(Option<RecordBatch>, Vec<TopKRow>)> {
759        // generate sorted rows
760        let topk_rows = std::mem::take(&mut self.inner).into_sorted_vec();
761
762        if self.store.is_empty() {
763            return Ok((None, topk_rows));
764        }
765
766        // Collect the batches into a vec and store the "batch_id -> array_pos" mapping, to then
767        // build the `indices` vec below. This is needed since the batch ids are not continuous.
768        let mut record_batches = Vec::new();
769        let mut batch_id_array_pos = HashMap::new();
770        for (array_pos, (batch_id, batch)) in self.store.batches.iter().enumerate() {
771            record_batches.push(&batch.batch);
772            batch_id_array_pos.insert(*batch_id, array_pos);
773        }
774
775        let indices: Vec<_> = topk_rows
776            .iter()
777            .map(|k| (batch_id_array_pos[&k.batch_id], k.index))
778            .collect();
779
780        // At this point `indices` contains indexes within the
781        // rows and `input_arrays` contains a reference to the
782        // relevant RecordBatch for that index. `interleave_record_batch` pulls
783        // them together into a single new batch
784        let new_batch = interleave_record_batch(&record_batches, &indices)?;
785
786        Ok((Some(new_batch), topk_rows))
787    }
788
    /// Compact this heap, rewriting all stored batches into a single
    /// input batch.
    ///
    /// Compaction is skipped when the store holds at most one batch, or when
    /// the store's total row count is at most twice the number of rows kept in
    /// the heap.
    pub fn maybe_compact(&mut self) -> Result<()> {
        // Don't compact if there's only one batch (compacting into itself is pointless)
        if self.store.len() <= 1 {
            return Ok(());
        }

        let total_rows = self.store.total_rows;
        let num_rows = self.inner.len();

        // Compact only when the store holds more than 2x the rows that the
        // compacted result would need. The multiplier avoids compacting when
        // the savings would be marginal.
        if total_rows <= num_rows * 2 {
            return Ok(());
        }

        // at first, compact the entire thing always into a new batch
        // (maybe we can get fancier in the future about ignoring
        // batches that have a high usage ratio already)

        // Note: new batch is in the same order as inner
        let (new_batch, mut topk_rows) = self.emit_with_state()?;
        let Some(new_batch) = new_batch else {
            return Ok(());
        };

        // clear all old entries in store (this invalidates all
        // store_ids in `inner`)
        self.store.clear();

        // every kept row now lives in the single compacted batch
        let mut batch_entry = self.register_batch(new_batch);
        batch_entry.uses = num_rows;

        // rewrite all existing entries to use the new batch, and
        // remove old entries. The sortedness and their relative
        // position do not change
        for (i, topk_row) in topk_rows.iter_mut().enumerate() {
            topk_row.batch_id = batch_entry.id;
            topk_row.index = i;
        }
        self.insert_batch_entry(batch_entry);
        // restore the heap
        self.inner = BinaryHeap::from(topk_rows);

        Ok(())
    }
837
838    /// return the size of memory used by this heap, in bytes
839    fn size(&self) -> usize {
840        size_of::<Self>()
841            + (self.inner.capacity() * size_of::<TopKRow>())
842            + self.store.size()
843            + self.owned_bytes
844    }
845
846    fn get_threshold_values(
847        &self,
848        sort_exprs: &[PhysicalSortExpr],
849    ) -> Result<Option<Vec<ScalarValue>>> {
850        // If the heap doesn't have k elements yet, we can't create thresholds
851        let max_row = match self.max() {
852            Some(row) => row,
853            None => return Ok(None),
854        };
855
856        // Get the batch that contains the max row
857        let batch_entry = match self.store.get(max_row.batch_id) {
858            Some(entry) => entry,
859            None => return internal_err!("Invalid batch ID in TopKRow"),
860        };
861
862        // Extract threshold values for each sort expression
863        let mut scalar_values = Vec::with_capacity(sort_exprs.len());
864        for sort_expr in sort_exprs {
865            // Extract the value for this column from the max row
866            let expr = Arc::clone(&sort_expr.expr);
867            let value = expr.evaluate(&batch_entry.batch.slice(max_row.index, 1))?;
868
869            // Convert to scalar value - should be a single value since we're evaluating on a single row batch
870            let scalar = match value {
871                ColumnarValue::Scalar(scalar) => scalar,
872                ColumnarValue::Array(array) if array.len() == 1 => {
873                    // Extract the first (and only) value from the array
874                    ScalarValue::try_from_array(&array, 0)?
875                }
876                array => {
877                    return internal_err!("Expected a scalar value, got {:?}", array);
878                }
879            };
880
881            scalar_values.push(scalar);
882        }
883
884        Ok(Some(scalar_values))
885    }
886}
887
/// Represents one of the top K rows held in this heap. Orders
/// according to memcmp of row (e.g. the arrow Row format, but could
/// also be primitive values)
///
/// Equality and ordering both look only at the sort key bytes in `row`:
/// two rows with the same key compare equal even when they come from
/// different batches or positions. This keeps `PartialEq`, `PartialOrd`
/// and `Ord` consistent with each other, as their contracts require.
///
/// Reuses allocations to minimize runtime overhead of creating new Vecs
#[derive(Debug)]
struct TopKRow {
    /// the value of the sort key for this row. This contains the
    /// bytes that could be stored in `OwnedRow` but uses `Vec<u8>` to
    /// reuse allocations.
    row: Vec<u8>,
    /// the RecordBatch this row came from: an id into a [`RecordBatchStore`]
    batch_id: u32,
    /// the index in this record batch the row came from
    index: usize,
}

impl TopKRow {
    /// Create a new TopKRow with new allocation
    fn new(row: impl AsRef<[u8]>, batch_id: u32, index: usize) -> Self {
        Self {
            row: row.as_ref().to_vec(),
            batch_id,
            index,
        }
    }

    /// Overwrite this row in place with a new key, batch id and index,
    /// reusing the existing key buffer's capacity where possible
    fn replace_with(&mut self, new_row: impl AsRef<[u8]>, batch_id: u32, index: usize) {
        self.row.clear();
        self.row.extend_from_slice(new_row.as_ref());

        self.batch_id = batch_id;
        self.index = index;
    }

    /// Returns the number of bytes owned by this row in the heap (not
    /// including itself)
    fn owned_size(&self) -> usize {
        self.row.capacity()
    }

    /// Returns a slice to the owned row value
    fn row(&self) -> &[u8] {
        self.row.as_slice()
    }
}

impl PartialEq for TopKRow {
    /// Two rows are equal when their sort keys are equal; `batch_id` and
    /// `index` are deliberately ignored so that `a == b` holds exactly
    /// when `a.cmp(&b)` is `Ordering::Equal`.
    fn eq(&self, other: &Self) -> bool {
        self.row == other.row
    }
}

impl Eq for TopKRow {}

impl PartialOrd for TopKRow {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for TopKRow {
    /// Orders rows by a bytewise comparison of their sort keys
    fn cmp(&self, other: &Self) -> Ordering {
        self.row.cmp(&other.row)
    }
}
950
/// A [`RecordBatch`] together with the id it is stored under in a
/// [`RecordBatchStore`] and its current use count.
#[derive(Debug)]
struct RecordBatchEntry {
    /// id this entry is keyed by when inserted into a store
    id: u32,
    /// the batch itself
    batch: RecordBatch,
    /// for this batch, how many times has it been used
    uses: usize,
}
958
959/// This structure tracks [`RecordBatch`] by an id so that:
960///
961/// 1. The baches can be tracked via an id that can be copied cheaply
962/// 2. The total memory held by all batches is tracked
963#[derive(Debug)]
964struct RecordBatchStore {
965    /// id generator
966    next_id: u32,
967    /// storage
968    batches: HashMap<u32, RecordBatchEntry>,
969    /// total size of all record batches tracked by this store
970    batches_size: usize,
971    /// row count of all the batches
972    total_rows: usize,
973}
974
975impl RecordBatchStore {
976    fn new() -> Self {
977        Self {
978            next_id: 0,
979            batches: HashMap::new(),
980            batches_size: 0,
981            total_rows: 0,
982        }
983    }
984
985    /// Register this batch with the store and assign an ID. No
986    /// attempt is made to compare this batch to other batches
987    pub fn register(&mut self, batch: RecordBatch) -> RecordBatchEntry {
988        let id = self.next_id;
989        self.next_id += 1;
990        RecordBatchEntry { id, batch, uses: 0 }
991    }
992
993    /// Insert a record batch entry into this store, tracking its
994    /// memory use, if it has any uses
995    pub fn insert(&mut self, entry: RecordBatchEntry) {
996        // uses of 0 means that none of the rows in the batch were stored in the topk
997        if entry.uses > 0 {
998            self.batches_size += get_record_batch_memory_size(&entry.batch);
999            self.total_rows += entry.batch.num_rows();
1000            self.batches.insert(entry.id, entry);
1001        }
1002    }
1003
1004    /// Clear all values in this store, invalidating all previous batch ids
1005    fn clear(&mut self) {
1006        self.batches.clear();
1007        self.batches_size = 0;
1008        self.total_rows = 0;
1009    }
1010
1011    fn get(&self, id: u32) -> Option<&RecordBatchEntry> {
1012        self.batches.get(&id)
1013    }
1014
1015    /// returns the total number of batches stored in this store
1016    fn len(&self) -> usize {
1017        self.batches.len()
1018    }
1019
1020    /// returns true if the store has nothing stored
1021    fn is_empty(&self) -> bool {
1022        self.batches.is_empty()
1023    }
1024
1025    /// remove a use from the specified batch id. If the use count
1026    /// reaches zero the batch entry is removed from the store
1027    ///
1028    /// panics if there were no remaining uses of id
1029    pub fn unuse(&mut self, id: u32) {
1030        let remove = if let Some(batch_entry) = self.batches.get_mut(&id) {
1031            batch_entry.uses = batch_entry.uses.checked_sub(1).expect("underflow");
1032            batch_entry.uses == 0
1033        } else {
1034            panic!("No entry for id {id}");
1035        };
1036
1037        if remove {
1038            let old_entry = self.batches.remove(&id).unwrap();
1039            self.batches_size = self
1040                .batches_size
1041                .checked_sub(get_record_batch_memory_size(&old_entry.batch))
1042                .unwrap();
1043
1044            self.total_rows = self
1045                .total_rows
1046                .checked_sub(old_entry.batch.num_rows())
1047                .unwrap();
1048        }
1049    }
1050
1051    /// returns the size of memory used by this store, including all
1052    /// referenced `RecordBatch`es, in bytes
1053    pub fn size(&self) -> usize {
1054        size_of::<Self>()
1055            + self.batches.capacity() * (size_of::<u32>() + size_of::<RecordBatchEntry>())
1056            + self.batches_size
1057    }
1058}
1059
1060#[cfg(test)]
1061mod tests {
1062    use super::*;
1063    use arrow::array::{BooleanArray, Float64Array, Int32Array};
1064    use arrow::datatypes::{DataType, Field, Schema};
1065    use arrow_schema::SortOptions;
1066    use datafusion_common::assert_batches_eq;
1067    use datafusion_physical_expr::expressions::col;
1068    use futures::TryStreamExt;
1069
1070    /// This test ensures the size calculation is correct for RecordBatches with multiple columns.
1071    #[test]
1072    fn test_record_batch_store_size() {
1073        // given
1074        let schema = Arc::new(Schema::new(vec![
1075            Field::new("ints", DataType::Int32, true),
1076            Field::new("float64", DataType::Float64, false),
1077        ]));
1078        let mut record_batch_store = RecordBatchStore::new();
1079        let int_array =
1080            Int32Array::from(vec![Some(1), Some(2), Some(3), Some(4), Some(5)]); // 5 * 4 = 20
1081        let float64_array = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]); // 5 * 8 = 40
1082
1083        let record_batch_entry = RecordBatchEntry {
1084            id: 0,
1085            batch: RecordBatch::try_new(
1086                schema,
1087                vec![Arc::new(int_array), Arc::new(float64_array)],
1088            )
1089            .unwrap(),
1090            uses: 1,
1091        };
1092
1093        // when insert record batch entry
1094        record_batch_store.insert(record_batch_entry);
1095        assert_eq!(record_batch_store.batches_size, 60);
1096
1097        // when unuse record batch entry
1098        record_batch_store.unuse(0);
1099        assert_eq!(record_batch_store.batches_size, 0);
1100    }
1101
1102    /// This test validates that the `try_finish` method marks the TopK operator as finished
1103    /// when the prefix (on column "a") of the last row in the current batch is strictly greater
1104    /// than the max top‑k row.
1105    /// The full sort expression is defined on both columns ("a", "b"), but the input ordering is only on "a".
1106    #[tokio::test]
1107    async fn test_try_finish_marks_finished_with_prefix() -> Result<()> {
1108        // Create a schema with two columns.
1109        let schema = Arc::new(Schema::new(vec![
1110            Field::new("a", DataType::Int32, false),
1111            Field::new("b", DataType::Float64, false),
1112        ]));
1113
1114        // Create sort expressions.
1115        // Full sort: first by "a", then by "b".
1116        let sort_expr_a = PhysicalSortExpr {
1117            expr: col("a", schema.as_ref())?,
1118            options: SortOptions::default(),
1119        };
1120        let sort_expr_b = PhysicalSortExpr {
1121            expr: col("b", schema.as_ref())?,
1122            options: SortOptions::default(),
1123        };
1124
1125        // Input ordering uses only column "a" (a prefix of the full sort).
1126        let prefix = vec![sort_expr_a.clone()];
1127        let full_expr = LexOrdering::from([sort_expr_a, sort_expr_b]);
1128
1129        // Create a dummy runtime environment and metrics.
1130        let runtime = Arc::new(RuntimeEnv::default());
1131        let metrics = ExecutionPlanMetricsSet::new();
1132
1133        // Create a TopK instance with k = 3 and batch_size = 2.
1134        let mut topk = TopK::try_new(
1135            0,
1136            Arc::clone(&schema),
1137            prefix,
1138            full_expr,
1139            3,
1140            2,
1141            runtime,
1142            &metrics,
1143            Arc::new(RwLock::new(TopKDynamicFilters::new(Arc::new(
1144                DynamicFilterPhysicalExpr::new(vec![], lit(true)),
1145            )))),
1146        )?;
1147
1148        // Create the first batch with two columns:
1149        // Column "a": [1, 1, 2], Column "b": [20.0, 15.0, 30.0].
1150        let array_a1: ArrayRef =
1151            Arc::new(Int32Array::from(vec![Some(1), Some(1), Some(2)]));
1152        let array_b1: ArrayRef = Arc::new(Float64Array::from(vec![20.0, 15.0, 30.0]));
1153        let batch1 = RecordBatch::try_new(Arc::clone(&schema), vec![array_a1, array_b1])?;
1154
1155        // Insert the first batch.
1156        // At this point the heap is not yet “finished” because the prefix of the last row of the batch
1157        // is not strictly greater than the prefix of the max top‑k row (both being `2`).
1158        topk.insert_batch(batch1)?;
1159        assert!(
1160            !topk.finished,
1161            "Expected 'finished' to be false after the first batch."
1162        );
1163
1164        // Create the second batch with two columns:
1165        // Column "a": [2, 3], Column "b": [10.0, 20.0].
1166        let array_a2: ArrayRef = Arc::new(Int32Array::from(vec![Some(2), Some(3)]));
1167        let array_b2: ArrayRef = Arc::new(Float64Array::from(vec![10.0, 20.0]));
1168        let batch2 = RecordBatch::try_new(Arc::clone(&schema), vec![array_a2, array_b2])?;
1169
1170        // Insert the second batch.
1171        // The last row in this batch has a prefix value of `3`,
1172        // which is strictly greater than the max top‑k row (with value `2`),
1173        // so try_finish should mark the TopK as finished.
1174        topk.insert_batch(batch2)?;
1175        assert!(
1176            topk.finished,
1177            "Expected 'finished' to be true after the second batch."
1178        );
1179
1180        // Verify the TopK correctly emits the top k rows from both batches
1181        // (the value 10.0 for b is from the second batch).
1182        let results: Vec<_> = topk.emit()?.try_collect().await?;
1183        assert_batches_eq!(
1184            &[
1185                "+---+------+",
1186                "| a | b    |",
1187                "+---+------+",
1188                "| 1 | 15.0 |",
1189                "| 1 | 20.0 |",
1190                "| 2 | 10.0 |",
1191                "+---+------+",
1192            ],
1193            &results
1194        );
1195
1196        Ok(())
1197    }
1198
1199    /// This test verifies that the dynamic filter is marked as complete after TopK processing finishes.
1200    #[tokio::test]
1201    async fn test_topk_marks_filter_complete() -> Result<()> {
1202        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
1203
1204        let sort_expr = PhysicalSortExpr {
1205            expr: col("a", schema.as_ref())?,
1206            options: SortOptions::default(),
1207        };
1208
1209        let full_expr = LexOrdering::from([sort_expr.clone()]);
1210        let prefix = vec![sort_expr];
1211
1212        // Create a dummy runtime environment and metrics
1213        let runtime = Arc::new(RuntimeEnv::default());
1214        let metrics = ExecutionPlanMetricsSet::new();
1215
1216        // Create a dynamic filter that we'll check for completion
1217        let dynamic_filter = Arc::new(DynamicFilterPhysicalExpr::new(vec![], lit(true)));
1218        let dynamic_filter_clone = Arc::clone(&dynamic_filter);
1219
1220        // Create a TopK instance
1221        let mut topk = TopK::try_new(
1222            0,
1223            Arc::clone(&schema),
1224            prefix,
1225            full_expr,
1226            2,
1227            10,
1228            runtime,
1229            &metrics,
1230            Arc::new(RwLock::new(TopKDynamicFilters::new(dynamic_filter))),
1231        )?;
1232
1233        let array: ArrayRef = Arc::new(Int32Array::from(vec![Some(3), Some(1), Some(2)]));
1234        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![array])?;
1235        topk.insert_batch(batch)?;
1236
1237        // Call emit to finish TopK processing
1238        let _results: Vec<_> = topk.emit()?.try_collect().await?;
1239
1240        // After emit is called, the dynamic filter should be marked as complete
1241        // wait_complete() should return immediately
1242        dynamic_filter_clone.wait_complete().await;
1243
1244        Ok(())
1245    }
1246
1247    /// Tests that memory-based compaction triggers when a large batch
1248    /// has very few rows referenced by the top-k heap.
1249    #[tokio::test]
1250    async fn test_topk_memory_compaction() -> Result<()> {
1251        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
1252
1253        let sort_expr = PhysicalSortExpr {
1254            expr: col("a", schema.as_ref())?,
1255            options: SortOptions::default(),
1256        };
1257
1258        let full_expr = LexOrdering::from([sort_expr.clone()]);
1259        let prefix = vec![sort_expr];
1260
1261        let runtime = Arc::new(RuntimeEnv::default());
1262        let metrics = ExecutionPlanMetricsSet::new();
1263
1264        let k = 5;
1265        let mut topk = TopK::try_new(
1266            0,
1267            Arc::clone(&schema),
1268            prefix,
1269            full_expr,
1270            k,
1271            8192,
1272            runtime,
1273            &metrics,
1274            Arc::new(RwLock::new(TopKDynamicFilters::new(Arc::new(
1275                DynamicFilterPhysicalExpr::new(vec![], lit(true)),
1276            )))),
1277        )?;
1278
1279        // Insert a large batch (100,000 rows) with values 1..=100_000.
1280        // Only the smallest 5 values (1..=5) will end up in the heap.
1281        let large_values: Vec<i32> = (1..=100_000).collect();
1282        let array1: ArrayRef = Arc::new(Int32Array::from(large_values));
1283        let batch1 = RecordBatch::try_new(Arc::clone(&schema), vec![array1])?;
1284        topk.insert_batch(batch1)?;
1285
1286        // After the first batch, store has 1 batch — compaction should
1287        // not trigger (guard: store.len() <= 1).
1288        assert_eq!(
1289            topk.heap.store.len(),
1290            1,
1291            "should have 1 batch before second insert"
1292        );
1293
1294        // Insert a second batch whose values displace entries in the heap.
1295        // -1 and 0 are smaller than the current top-5 (1..=5), so they
1296        // produce 2 replacements. With replacements > 0, `insert_batch`
1297        // calls `insert_batch_entry` (briefly making store.len() == 2)
1298        // and then `maybe_compact`, which should collapse it back to 1.
1299        let array2: ArrayRef = Arc::new(Int32Array::from(vec![-1, 0]));
1300        let batch2 = RecordBatch::try_new(Arc::clone(&schema), vec![array2])?;
1301        let replacements_before = topk.metrics.row_replacements.value();
1302        topk.insert_batch(batch2)?;
1303
1304        // Sanity check: batch2 was actually integrated. Without
1305        // replacements, `maybe_compact` is never called and the
1306        // store-length assertion below would pass vacuously.
1307        assert!(
1308            topk.metrics.row_replacements.value() > replacements_before,
1309            "batch2 must produce replacements so compaction is exercised"
1310        );
1311
1312        // The compacted-estimate guard is `total_rows <= num_rows * 2`,
1313        // i.e. 100_002 <= 10, which is false, so compaction fires and
1314        // collapses the two stored batches back into one.
1315        assert_eq!(
1316            topk.heap.store.len(),
1317            1,
1318            "store should be compacted to 1 batch"
1319        );
1320
1321        // Verify the emitted results are correct (top 5 ascending).
1322        let results: Vec<_> = topk.emit()?.try_collect().await?;
1323        assert_batches_eq!(
1324            &[
1325                "+----+", "| a  |", "+----+", "| -1 |", "| 0  |", "| 1  |", "| 2  |",
1326                "| 3  |", "+----+",
1327            ],
1328            &results
1329        );
1330
1331        Ok(())
1332    }
1333
1334    /// Negative path: when stored rows are close to the heap size,
1335    /// compaction must NOT fire even with multiple batches present,
1336    /// because the savings would be marginal
1337    /// (guard: `total_rows <= num_rows * 2`).
1338    ///
1339    /// Uses a bit-packed `BooleanArray` so that future changes to the
1340    /// compaction heuristic that reintroduce a per-byte estimate
1341    /// (where integer truncation could misbehave on sub-byte types)
1342    /// are caught here.
1343    #[tokio::test]
1344    async fn test_topk_memory_compaction_skipped_when_marginal() -> Result<()> {
1345        let schema =
1346            Arc::new(Schema::new(vec![Field::new("a", DataType::Boolean, false)]));
1347
1348        let sort_expr = PhysicalSortExpr {
1349            expr: col("a", schema.as_ref())?,
1350            options: SortOptions::default(),
1351        };
1352        let full_expr = LexOrdering::from([sort_expr.clone()]);
1353        let prefix = vec![sort_expr];
1354
1355        let runtime = Arc::new(RuntimeEnv::default());
1356        let metrics = ExecutionPlanMetricsSet::new();
1357
1358        let k = 10;
1359        let mut topk = TopK::try_new(
1360            0,
1361            Arc::clone(&schema),
1362            prefix,
1363            full_expr,
1364            k,
1365            8192,
1366            runtime,
1367            &metrics,
1368            Arc::new(RwLock::new(TopKDynamicFilters::new(Arc::new(
1369                DynamicFilterPhysicalExpr::new(vec![], lit(true)),
1370            )))),
1371        )?;
1372
1373        // Two small batches; every row from both batches ends up referenced
1374        // by the heap, so total_rows == num_rows == 10.
1375        let batch1 = RecordBatch::try_new(
1376            Arc::clone(&schema),
1377            vec![
1378                Arc::new(BooleanArray::from(vec![false, false, true, true, true]))
1379                    as ArrayRef,
1380            ],
1381        )?;
1382        topk.insert_batch(batch1)?;
1383
1384        let batch2 = RecordBatch::try_new(
1385            Arc::clone(&schema),
1386            vec![
1387                Arc::new(BooleanArray::from(vec![false, false, false, true, true]))
1388                    as ArrayRef,
1389            ],
1390        )?;
1391        topk.insert_batch(batch2)?;
1392
1393        // Guard `total_rows <= num_rows * 2` should hold (10 <= 20),
1394        // so compaction is skipped and BOTH batches remain in the store.
1395        assert_eq!(
1396            topk.heap.store.len(),
1397            2,
1398            "store must keep 2 batches when savings would be marginal"
1399        );
1400        assert_eq!(topk.heap.inner.len(), 10, "heap should hold all 10 rows");
1401
1402        // Output is still correct (5 falses then 5 trues ascending).
1403        let results: Vec<_> = topk.emit()?.try_collect().await?;
1404        assert_batches_eq!(
1405            &[
1406                "+-------+",
1407                "| a     |",
1408                "+-------+",
1409                "| false |",
1410                "| false |",
1411                "| false |",
1412                "| false |",
1413                "| false |",
1414                "| true  |",
1415                "| true  |",
1416                "| true  |",
1417                "| true  |",
1418                "| true  |",
1419                "+-------+",
1420            ],
1421            &results
1422        );
1423
1424        Ok(())
1425    }
1426}