Skip to main content

datafusion_physical_plan/sorts/
sort.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Sort that deals with an arbitrary size of the input.
19//! It will do in-memory sorting if it has enough memory budget
20//! but spills to disk if needed.
21
22use std::fmt;
23use std::fmt::{Debug, Formatter};
24use std::sync::Arc;
25
26use parking_lot::RwLock;
27
28use crate::common::spawn_buffered;
29use crate::execution_plan::{
30    Boundedness, CardinalityEffect, EmissionType, has_same_children_properties,
31};
32use crate::expressions::PhysicalSortExpr;
33use crate::filter::FilterExec;
34use crate::filter_pushdown::{
35    ChildFilterDescription, ChildPushdownResult, FilterDescription, FilterPushdownPhase,
36    FilterPushdownPropagation, PushedDown,
37};
38use crate::limit::LimitStream;
39use crate::metrics::{
40    BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, SpillMetrics,
41};
42use crate::projection::{ProjectionExec, make_with_child, update_ordering};
43use crate::sorts::IncrementalSortIterator;
44use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder};
45use crate::spill::get_record_batch_memory_size;
46use crate::spill::in_progress_spill_file::InProgressSpillFile;
47use crate::spill::spill_manager::{GetSlicedSize, SpillManager};
48use crate::stream::RecordBatchStreamAdapter;
49use crate::stream::ReservationStream;
50use crate::topk::TopK;
51use crate::topk::TopKDynamicFilters;
52use crate::{
53    DisplayAs, DisplayFormatType, Distribution, EmptyRecordBatchStream, ExecutionPlan,
54    ExecutionPlanProperties, Partitioning, PlanProperties, SendableRecordBatchStream,
55    Statistics,
56};
57
58use arrow::array::{RecordBatch, RecordBatchOptions};
59use arrow::compute::{concat_batches, lexsort_to_indices, take_arrays};
60use arrow::datatypes::SchemaRef;
61use datafusion_common::config::SpillCompression;
62use datafusion_common::{
63    DataFusionError, Result, assert_or_internal_err, internal_datafusion_err,
64    unwrap_or_internal_err,
65};
66use datafusion_execution::TaskContext;
67use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
68use datafusion_execution::runtime_env::RuntimeEnv;
69use datafusion_physical_expr::LexOrdering;
70use datafusion_physical_expr::PhysicalExpr;
71use datafusion_physical_expr::expressions::{DynamicFilterPhysicalExpr, lit};
72
73use futures::{StreamExt, TryStreamExt};
74use log::{debug, trace};
75
76struct ExternalSorterMetrics {
77    /// metrics
78    baseline: BaselineMetrics,
79
80    spill_metrics: SpillMetrics,
81}
82
83impl ExternalSorterMetrics {
84    fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self {
85        Self {
86            baseline: BaselineMetrics::new(metrics, partition),
87            spill_metrics: SpillMetrics::new(metrics, partition),
88        }
89    }
90}
91
92/// Sorts an arbitrary sized, unsorted, stream of [`RecordBatch`]es to
93/// a total order. Depending on the input size and memory manager
94/// configuration, writes intermediate results to disk ("spills")
95/// using Arrow IPC format.
96///
97/// # Algorithm
98///
99/// 1. get a non-empty new batch from input
100///
101/// 2. check with the memory manager there is sufficient space to
102///    buffer the batch in memory.
103///
104/// 2.1 if memory is sufficient, buffer batch in memory, go to 1.
105///
106/// 2.2 if no more memory is available, sort all buffered batches and
107///     spill to file.  buffer the next batch in memory, go to 1.
108///
109/// 3. when input is exhausted, merge all in memory batches and spills
110///    to get a total order.
111///
112/// # When data fits in available memory
113///
114/// If there is sufficient memory, data is sorted in memory to produce the output
115///
116/// ```text
117///    ┌─────┐
118///    │  2  │
119///    │  3  │
120///    │  1  │─ ─ ─ ─ ─ ─ ─ ─ ─ ┐
121///    │  4  │
122///    │  2  │                  │
123///    └─────┘                  ▼
124///    ┌─────┐
125///    │  1  │              In memory
126///    │  4  │─ ─ ─ ─ ─ ─▶ sort/merge  ─ ─ ─ ─ ─▶  total sorted output
127///    │  1  │
128///    └─────┘                  ▲
129///      ...                    │
130///
131///    ┌─────┐                  │
132///    │  4  │
133///    │  3  │─ ─ ─ ─ ─ ─ ─ ─ ─ ┘
134///    └─────┘
135///
136/// in_mem_batches
137/// ```
138///
139/// # When data does not fit in available memory
140///
141///  When memory is exhausted, data is first sorted and written to one
142///  or more spill files on disk:
143///
144/// ```text
145///    ┌─────┐                               .─────────────────.
146///    │  2  │                              (                   )
147///    │  3  │                              │`─────────────────'│
148///    │  1  │─ ─ ─ ─ ─ ─ ─                 │  ┌────┐           │
149///    │  4  │             │                │  │ 1  │░          │
150///    │  2  │                              │  │... │░          │
151///    └─────┘             ▼                │  │ 4  │░  ┌ ─ ─   │
152///    ┌─────┐                              │  └────┘░    1  │░ │
153///    │  1  │         In memory            │   ░░░░░░  │    ░░ │
154///    │  4  │─ ─ ▶   sort/merge    ─ ─ ─ ─ ┼ ─ ─ ─ ─ ─▶ ... │░ │
155///    │  1  │     and write to file        │           │    ░░ │
156///    └─────┘                              │             4  │░ │
157///      ...               ▲                │           └░─░─░░ │
158///                        │                │            ░░░░░░ │
159///    ┌─────┐                              │.─────────────────.│
160///    │  4  │             │                (                   )
161///    │  3  │─ ─ ─ ─ ─ ─ ─                  `─────────────────'
162///    └─────┘
163///
164/// in_mem_batches                                  spills
165///                                         (file on disk in Arrow
166///                                               IPC format)
167/// ```
168///
169/// Once the input is completely read, the spill files are read and
170/// merged with any in memory batches to produce a single total sorted
171/// output:
172///
173/// ```text
174///   .─────────────────.
175///  (                   )
176///  │`─────────────────'│
177///  │  ┌────┐           │
178///  │  │ 1  │░          │
179///  │  │... │─ ─ ─ ─ ─ ─│─ ─ ─ ─ ─ ─
180///  │  │ 4  │░ ┌────┐   │           │
181///  │  └────┘░ │ 1  │░  │           ▼
182///  │   ░░░░░░ │    │░  │
183///  │          │... │─ ─│─ ─ ─ ▶ merge  ─ ─ ─▶  total sorted output
184///  │          │    │░  │
185///  │          │ 4  │░  │           ▲
186///  │          └────┘░  │           │
187///  │           ░░░░░░  │
188///  │.─────────────────.│           │
189///  (                   )
190///   `─────────────────'            │
191///         spills
192///                                  │
193///
194///                                  │
195///
196///     ┌─────┐                      │
197///     │  1  │
198///     │  4  │─ ─ ─ ─               │
199///     └─────┘       │
200///       ...                   In memory
201///                   └ ─ ─ ─▶  sort/merge
202///     ┌─────┐
203///     │  4  │                      ▲
204///     │  3  │─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘
205///     └─────┘
206///
207///  in_mem_batches
208/// ```
struct ExternalSorter {
    // ========================================================================
    // PROPERTIES:
    // Fields that define the sorter's configuration and remain constant
    // ========================================================================
    /// Schema of the output (and the input)
    schema: SchemaRef,
    /// Sort expressions
    expr: LexOrdering,
    /// The target number of rows for output batches
    batch_size: usize,
    /// If the size of the buffered in-memory batches is below this
    /// threshold, the data will be concatenated and sorted in place rather
    /// than sort/merged.
    sort_in_place_threshold_bytes: usize,

    // ========================================================================
    // STATE BUFFERS:
    // Fields that hold intermediate data during sorting
    // ========================================================================
    /// Unsorted input batches stored in the memory buffer
    in_mem_batches: Vec<RecordBatch>,

    /// During external sorting, in-memory intermediate data will be appended to
    /// this file incrementally. Once finished, this file will be moved to [`Self::finished_spill_files`].
    ///
    /// This is a tuple of:
    /// 1. `InProgressSpillFile` - the file that is being written to
    /// 2. `max_record_batch_memory` - the maximum memory usage of a single batch in this spill file.
    in_progress_spill_file: Option<(InProgressSpillFile, usize)>,
    /// If data has previously been spilled, the locations of the spill files (in
    /// Arrow IPC format).
    /// Within the same spill file, the data might be chunked into multiple batches,
    /// and ordered by sort keys.
    finished_spill_files: Vec<SortedSpillFile>,

    // ========================================================================
    // EXECUTION RESOURCES:
    // Fields related to managing execution resources and monitoring performance.
    // ========================================================================
    /// Runtime metrics
    metrics: ExternalSorterMetrics,
    /// A handle to the runtime environment; provides the memory pool the
    /// reservations are registered with and the disk manager that tells
    /// whether temporary (spill) files are enabled
    runtime: Arc<RuntimeEnv>,
    /// Reservation for in_mem_batches
    reservation: MemoryReservation,
    /// Creates the in-progress spill files and is handed to the streaming
    /// merge that reads the finished spill files
    spill_manager: SpillManager,

    /// Reservation for the merging of in-memory batches. If the sort
    /// might spill, `sort_spill_reservation_bytes` will be
    /// pre-reserved to ensure there is some space for this sort/merge.
    merge_reservation: MemoryReservation,
    /// How much memory to reserve for performing in-memory sort/merges
    /// prior to spilling.
    sort_spill_reservation_bytes: usize,
}
265
266impl ExternalSorter {
267    // TODO: make a builder or some other nicer API to avoid the
268    // clippy warning
269    #[expect(clippy::too_many_arguments)]
270    pub fn new(
271        partition_id: usize,
272        schema: SchemaRef,
273        expr: LexOrdering,
274        batch_size: usize,
275        sort_spill_reservation_bytes: usize,
276        sort_in_place_threshold_bytes: usize,
277        // Configured via `datafusion.execution.spill_compression`.
278        spill_compression: SpillCompression,
279        metrics: &ExecutionPlanMetricsSet,
280        runtime: Arc<RuntimeEnv>,
281    ) -> Result<Self> {
282        let metrics = ExternalSorterMetrics::new(metrics, partition_id);
283        let reservation = MemoryConsumer::new(format!("ExternalSorter[{partition_id}]"))
284            .with_can_spill(true)
285            .register(&runtime.memory_pool);
286
287        let merge_reservation =
288            MemoryConsumer::new(format!("ExternalSorterMerge[{partition_id}]"))
289                .register(&runtime.memory_pool);
290
291        let spill_manager = SpillManager::new(
292            Arc::clone(&runtime),
293            metrics.spill_metrics.clone(),
294            Arc::clone(&schema),
295        )
296        .with_compression_type(spill_compression);
297
298        Ok(Self {
299            schema,
300            in_mem_batches: vec![],
301            in_progress_spill_file: None,
302            finished_spill_files: vec![],
303            expr,
304            metrics,
305            reservation,
306            spill_manager,
307            merge_reservation,
308            runtime,
309            batch_size,
310            sort_spill_reservation_bytes,
311            sort_in_place_threshold_bytes,
312        })
313    }
314
315    /// Appends an unsorted [`RecordBatch`] to `in_mem_batches`
316    ///
317    /// Updates memory usage metrics, and possibly triggers spilling to disk
318    async fn insert_batch(&mut self, input: RecordBatch) -> Result<()> {
319        if input.num_rows() == 0 {
320            return Ok(());
321        }
322
323        self.reserve_memory_for_merge()?;
324        self.reserve_memory_for_batch_and_maybe_spill(&input)
325            .await?;
326
327        self.in_mem_batches.push(input);
328        Ok(())
329    }
330
331    fn spilled_before(&self) -> bool {
332        !self.finished_spill_files.is_empty()
333    }
334
335    /// Returns the final sorted output of all batches inserted via
336    /// [`Self::insert_batch`] as a stream of [`RecordBatch`]es.
337    ///
338    /// This process could either be:
339    ///
340    /// 1. An in-memory sort/merge (if the input fit in memory)
341    ///
342    /// 2. A combined streaming merge incorporating both in-memory
343    ///    batches and data from spill files on disk.
344    async fn sort(&mut self) -> Result<SendableRecordBatchStream> {
345        if self.spilled_before() {
346            // Sort `in_mem_batches` and spill it first. If there are many
347            // `in_mem_batches` and the memory limit is almost reached, merging
348            // them with the spilled files at the same time might cause OOM.
349            if !self.in_mem_batches.is_empty() {
350                self.sort_and_spill_in_mem_batches().await?;
351            }
352
353            // Transfer the pre-reserved merge memory to the streaming merge
354            // using `take()` instead of `new_empty()`. This ensures the merge
355            // stream starts with `sort_spill_reservation_bytes` already
356            // allocated, preventing starvation when concurrent sort partitions
357            // compete for pool memory. `take()` moves the bytes atomically
358            // without releasing them back to the pool, so other partitions
359            // cannot race to consume the freed memory.
360            StreamingMergeBuilder::new()
361                .with_sorted_spill_files(std::mem::take(&mut self.finished_spill_files))
362                .with_spill_manager(self.spill_manager.clone())
363                .with_schema(Arc::clone(&self.schema))
364                .with_expressions(&self.expr.clone())
365                .with_metrics(self.metrics.baseline.clone())
366                .with_batch_size(self.batch_size)
367                .with_fetch(None)
368                .with_reservation(self.merge_reservation.take())
369                .build()
370        } else {
371            // Release the memory reserved for merge back to the pool so
372            // there is some left when `in_mem_sort_stream` requests an
373            // allocation. Only needed for the non-spill path; the spill
374            // path transfers the reservation to the merge stream instead.
375            self.merge_reservation.free();
376            self.in_mem_sort_stream(self.metrics.baseline.clone())
377        }
378    }
379
380    /// How much memory is buffered in this `ExternalSorter`?
381    fn used(&self) -> usize {
382        self.reservation.size()
383    }
384
385    /// How much memory is reserved for the merge phase?
386    #[cfg(test)]
387    fn merge_reservation_size(&self) -> usize {
388        self.merge_reservation.size()
389    }
390
391    /// How many bytes have been spilled to disk?
392    fn spilled_bytes(&self) -> usize {
393        self.metrics.spill_metrics.spilled_bytes.value()
394    }
395
396    /// How many rows have been spilled to disk?
397    fn spilled_rows(&self) -> usize {
398        self.metrics.spill_metrics.spilled_rows.value()
399    }
400
401    /// How many spill files have been created?
402    fn spill_count(&self) -> usize {
403        self.metrics.spill_metrics.spill_file_count.value()
404    }
405
406    /// Appending globally sorted batches to the in-progress spill file, and clears
407    /// the `globally_sorted_batches` (also its memory reservation) afterwards.
408    async fn consume_and_spill_append(
409        &mut self,
410        globally_sorted_batches: &mut Vec<RecordBatch>,
411    ) -> Result<()> {
412        if globally_sorted_batches.is_empty() {
413            return Ok(());
414        }
415
416        // Lazily initialize the in-progress spill file
417        if self.in_progress_spill_file.is_none() {
418            self.in_progress_spill_file =
419                Some((self.spill_manager.create_in_progress_file("Sorting")?, 0));
420        }
421
422        debug!("Spilling sort data of ExternalSorter to disk whilst inserting");
423
424        let batches_to_spill = std::mem::take(globally_sorted_batches);
425        self.reservation.free();
426
427        let (in_progress_file, max_record_batch_size) =
428            self.in_progress_spill_file.as_mut().ok_or_else(|| {
429                internal_datafusion_err!("In-progress spill file should be initialized")
430            })?;
431
432        for batch in batches_to_spill {
433            let gc_sliced_size = in_progress_file.append_batch(&batch)?;
434
435            *max_record_batch_size = (*max_record_batch_size).max(gc_sliced_size);
436        }
437
438        assert_or_internal_err!(
439            globally_sorted_batches.is_empty(),
440            "This function consumes globally_sorted_batches, so it should be empty after taking."
441        );
442
443        Ok(())
444    }
445
446    /// Finishes the in-progress spill file and moves it to the finished spill files.
447    async fn spill_finish(&mut self) -> Result<()> {
448        let (mut in_progress_file, max_record_batch_memory) =
449            self.in_progress_spill_file.take().ok_or_else(|| {
450                internal_datafusion_err!("Should be called after `spill_append`")
451            })?;
452        let spill_file = in_progress_file.finish()?;
453
454        if let Some(spill_file) = spill_file {
455            self.finished_spill_files.push(SortedSpillFile {
456                file: spill_file,
457                max_record_batch_memory,
458            });
459        }
460
461        Ok(())
462    }
463
464    /// Sorts the in-memory batches and merges them into a single sorted run, then writes
465    /// the result to spill files.
466    async fn sort_and_spill_in_mem_batches(&mut self) -> Result<()> {
467        assert_or_internal_err!(
468            !self.in_mem_batches.is_empty(),
469            "in_mem_batches must not be empty when attempting to sort and spill"
470        );
471
472        // Release the memory reserved for merge back to the pool so
473        // there is some left when `in_mem_sort_stream` requests an
474        // allocation. At the end of this function, memory will be
475        // reserved again for the next spill.
476        self.merge_reservation.free();
477
478        let mut sorted_stream =
479            self.in_mem_sort_stream(self.metrics.baseline.intermediate())?;
480        // After `in_mem_sort_stream()` is constructed, all `in_mem_batches` is taken
481        // to construct a globally sorted stream.
482        assert_or_internal_err!(
483            self.in_mem_batches.is_empty(),
484            "in_mem_batches should be empty after constructing sorted stream"
485        );
486        // 'global' here refers to all buffered batches when the memory limit is
487        // reached. This variable will buffer the sorted batches after
488        // sort-preserving merge and incrementally append to spill files.
489        let mut globally_sorted_batches: Vec<RecordBatch> = vec![];
490
491        while let Some(batch) = sorted_stream.next().await {
492            let batch = batch?;
493            let sorted_size = get_reserved_bytes_for_record_batch(&batch)?;
494            if self.reservation.try_grow(sorted_size).is_err() {
495                // Although the reservation is not enough, the batch is
496                // already in memory, so it's okay to combine it with previously
497                // sorted batches, and spill together.
498                globally_sorted_batches.push(batch);
499                self.consume_and_spill_append(&mut globally_sorted_batches)
500                    .await?; // reservation is freed in spill()
501            } else {
502                globally_sorted_batches.push(batch);
503            }
504        }
505
506        // Drop early to free up memory reserved by the sorted stream, otherwise the
507        // upcoming `self.reserve_memory_for_merge()` may fail due to insufficient memory.
508        drop(sorted_stream);
509
510        self.consume_and_spill_append(&mut globally_sorted_batches)
511            .await?;
512        self.spill_finish().await?;
513
514        // Sanity check after spilling
515        let buffers_cleared_property =
516            self.in_mem_batches.is_empty() && globally_sorted_batches.is_empty();
517        assert_or_internal_err!(
518            buffers_cleared_property,
519            "in_mem_batches and globally_sorted_batches should be cleared before"
520        );
521
522        // Reserve headroom for next sort/merge
523        self.reserve_memory_for_merge()?;
524
525        Ok(())
526    }
527
528    /// Consumes in_mem_batches returning a sorted stream of
529    /// batches. This proceeds in one of two ways:
530    ///
531    /// # Small Datasets
532    ///
533    /// For "smaller" datasets, the data is first concatenated into a
534    /// single batch and then sorted. This is often faster than
535    /// sorting and then merging.
536    ///
537    /// ```text
538    ///        ┌─────┐
539    ///        │  2  │
540    ///        │  3  │
541    ///        │  1  │─ ─ ─ ─ ┐            ┌─────┐
542    ///        │  4  │                     │  2  │
543    ///        │  2  │        │            │  3  │
544    ///        └─────┘                     │  1  │             sorted output
545    ///        ┌─────┐        ▼            │  4  │                stream
546    ///        │  1  │                     │  2  │
547    ///        │  4  │─ ─▶ concat ─ ─ ─ ─ ▶│  1  │─ ─ ▶  sort  ─ ─ ─ ─ ─▶
548    ///        │  1  │                     │  4  │
549    ///        └─────┘        ▲            │  1  │
550    ///          ...          │            │ ... │
551    ///                                    │  4  │
552    ///        ┌─────┐        │            │  3  │
553    ///        │  4  │                     └─────┘
554    ///        │  3  │─ ─ ─ ─ ┘
555    ///        └─────┘
556    ///     in_mem_batches
557    /// ```
558    ///
559    /// # Larger datasets
560    ///
561    /// For larger datasets, the batches are first sorted individually
562    /// and then merged together.
563    ///
564    /// ```text
565    ///      ┌─────┐                ┌─────┐
566    ///      │  2  │                │  1  │
567    ///      │  3  │                │  2  │
568    ///      │  1  │─ ─▶  sort  ─ ─▶│  2  │─ ─ ─ ─ ─ ┐
569    ///      │  4  │                │  3  │
570    ///      │  2  │                │  4  │          │
571    ///      └─────┘                └─────┘               sorted output
572    ///      ┌─────┐                ┌─────┐          ▼       stream
573    ///      │  1  │                │  1  │
574    ///      │  4  │─ ▶  sort  ─ ─ ▶│  1  ├ ─ ─ ▶ merge  ─ ─ ─ ─▶
575    ///      │  1  │                │  4  │
576    ///      └─────┘                └─────┘          ▲
577    ///        ...       ...         ...             │
578    ///
579    ///      ┌─────┐                ┌─────┐          │
580    ///      │  4  │                │  3  │
581    ///      │  3  │─ ▶  sort  ─ ─ ▶│  4  │─ ─ ─ ─ ─ ┘
582    ///      └─────┘                └─────┘
583    ///
584    ///   in_mem_batches
585    /// ```
586    fn in_mem_sort_stream(
587        &mut self,
588        metrics: BaselineMetrics,
589    ) -> Result<SendableRecordBatchStream> {
590        if self.in_mem_batches.is_empty() {
591            return Ok(Box::pin(EmptyRecordBatchStream::new(Arc::clone(
592                &self.schema,
593            ))));
594        }
595
596        // The elapsed compute timer is updated when the value is dropped.
597        // There is no need for an explicit call to drop.
598        let elapsed_compute = metrics.elapsed_compute().clone();
599        let _timer = elapsed_compute.timer();
600
601        // Please pay attention that any operation inside of `in_mem_sort_stream` will
602        // not perform any memory reservation. This is for avoiding the need of handling
603        // reservation failure and spilling in the middle of the sort/merge. The memory
604        // space for batches produced by the resulting stream will be reserved by the
605        // consumer of the stream.
606
607        if self.in_mem_batches.len() == 1 {
608            let batch = self.in_mem_batches.swap_remove(0);
609            let reservation = self.reservation.take();
610            return self.sort_batch_stream(batch, &metrics, reservation);
611        }
612
613        // If less than sort_in_place_threshold_bytes, concatenate and sort in place
614        if self.reservation.size() < self.sort_in_place_threshold_bytes {
615            // Concatenate memory batches together and sort
616            let batch = concat_batches(&self.schema, &self.in_mem_batches)?;
617            self.in_mem_batches.clear();
618            self.reservation
619                .try_resize(get_reserved_bytes_for_record_batch(&batch)?)
620                .map_err(Self::err_with_oom_context)?;
621            let reservation = self.reservation.take();
622            return self.sort_batch_stream(batch, &metrics, reservation);
623        }
624
625        let streams = std::mem::take(&mut self.in_mem_batches)
626            .into_iter()
627            .map(|batch| {
628                let metrics = self.metrics.baseline.intermediate();
629                let reservation = self
630                    .reservation
631                    .split(get_reserved_bytes_for_record_batch(&batch)?);
632                let input = self.sort_batch_stream(batch, &metrics, reservation)?;
633                Ok(spawn_buffered(input, 1))
634            })
635            .collect::<Result<_>>()?;
636
637        StreamingMergeBuilder::new()
638            .with_streams(streams)
639            .with_schema(Arc::clone(&self.schema))
640            .with_expressions(&self.expr.clone())
641            .with_metrics(metrics)
642            .with_batch_size(self.batch_size)
643            .with_fetch(None)
644            .with_reservation(self.merge_reservation.new_empty())
645            .build()
646    }
647
648    /// Sorts a single `RecordBatch` into a single stream.
649    ///
650    /// This may output multiple batches depending on the size of the
651    /// sorted data and the target batch size.
652    /// For single-batch output cases, `reservation` will be freed immediately after sorting,
653    /// as the batch will be output and is expected to be reserved by the consumer of the stream.
654    /// For multi-batch output cases, `reservation` will be grown to match the actual
655    /// size of sorted output, and as each batch is output, its memory will be freed from the reservation.
656    /// (This leads to the same behaviour, as futures are only evaluated when polled by the consumer.)
657    fn sort_batch_stream(
658        &self,
659        batch: RecordBatch,
660        metrics: &BaselineMetrics,
661        reservation: MemoryReservation,
662    ) -> Result<SendableRecordBatchStream> {
663        assert_eq!(
664            get_reserved_bytes_for_record_batch(&batch)?,
665            reservation.size()
666        );
667
668        let schema = batch.schema();
669        let expressions = self.expr.clone();
670        let batch_size = self.batch_size;
671        let output_row_metrics = metrics.output_rows().clone();
672
673        let stream = futures::stream::once(async move {
674            let schema = batch.schema();
675
676            // Sort the batch immediately and get all output batches
677            let sorted_batches = sort_batch_chunked(&batch, &expressions, batch_size)?;
678
679            // Resize the reservation to match the actual sorted output size.
680            // Using try_resize avoids a release-then-reacquire cycle, which
681            // matters for MemoryPool implementations where grow/shrink have
682            // non-trivial cost (e.g. JNI calls in Comet).
683            let total_sorted_size: usize = sorted_batches
684                .iter()
685                .map(get_record_batch_memory_size)
686                .sum();
687            reservation
688                .try_resize(total_sorted_size)
689                .map_err(Self::err_with_oom_context)?;
690
691            // Wrap in ReservationStream to hold the reservation
692            Result::<_, DataFusionError>::Ok(Box::pin(ReservationStream::new(
693                Arc::clone(&schema),
694                Box::pin(RecordBatchStreamAdapter::new(
695                    Arc::clone(&schema),
696                    futures::stream::iter(sorted_batches.into_iter().map(Ok)),
697                )),
698                reservation,
699            )) as SendableRecordBatchStream)
700        })
701        .try_flatten()
702        .map(move |batch| match batch {
703            Ok(batch) => {
704                output_row_metrics.add(batch.num_rows());
705                Ok(batch)
706            }
707            Err(e) => Err(e),
708        });
709
710        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
711    }
712
713    /// If this sort may spill, pre-allocates
714    /// `sort_spill_reservation_bytes` of memory to guarantee memory
715    /// left for the in memory sort/merge.
716    fn reserve_memory_for_merge(&mut self) -> Result<()> {
717        // Reserve headroom for next merge sort
718        if self.runtime.disk_manager.tmp_files_enabled() {
719            let size = self.sort_spill_reservation_bytes;
720            if self.merge_reservation.size() != size {
721                self.merge_reservation
722                    .try_resize(size)
723                    .map_err(Self::err_with_oom_context)?;
724            }
725        }
726
727        Ok(())
728    }
729
730    /// Reserves memory to be able to accommodate the given batch.
731    /// If memory is scarce, tries to spill current in-memory batches to disk first.
732    async fn reserve_memory_for_batch_and_maybe_spill(
733        &mut self,
734        input: &RecordBatch,
735    ) -> Result<()> {
736        let size = get_reserved_bytes_for_record_batch(input)?;
737
738        match self.reservation.try_grow(size) {
739            Ok(_) => Ok(()),
740            Err(e) => {
741                if self.in_mem_batches.is_empty() {
742                    return Err(Self::err_with_oom_context(e));
743                }
744
745                // Spill and try again.
746                self.sort_and_spill_in_mem_batches().await?;
747                self.reservation
748                    .try_grow(size)
749                    .map_err(Self::err_with_oom_context)
750            }
751        }
752    }
753
754    /// Wraps the error with a context message suggesting settings to tweak.
755    /// This is meant to be used with DataFusionError::ResourcesExhausted only.
756    fn err_with_oom_context(e: DataFusionError) -> DataFusionError {
757        match e {
758            DataFusionError::ResourcesExhausted(_) => e.context(
759                "Not enough memory to continue external sort. \
760                    Consider increasing the memory limit config: 'datafusion.runtime.memory_limit', \
761                    or decreasing the config: 'datafusion.execution.sort_spill_reservation_bytes'."
762            ),
763            // This is not an OOM error, so just return it as is.
764            _ => e,
765        }
766    }
767}
768
/// Estimate how much memory is needed to sort a `RecordBatch`.
///
/// This is used to pre-reserve memory for the sort/merge. The sort/merge process involves
/// creating sorted copies of sorted columns in record batches for speeding up comparison
/// in sorting and merging. The sorted copies are in either row format or array format.
/// Please refer to cursor.rs and stream.rs for more details. No matter what format the
/// sorted copies are, they will use more memory than the original record batch.
///
/// The estimate is the sum of `record_batch_size` (the space the batch actually takes in
/// memory, which would be larger for a sliced batch) and `sliced_size` (the size of the
/// actual data).
///
/// The sum saturates at `usize::MAX` rather than overflowing, so an absurdly large
/// input can never wrap around into a small (under-)reservation.
pub(crate) fn get_reserved_bytes_for_record_batch_size(
    record_batch_size: usize,
    sliced_size: usize,
) -> usize {
    // Even 2x may not be enough for some cases, but it's a good enough estimation as a baseline.
    // If 2x is not enough, user can set a larger value for `sort_spill_reservation_bytes`
    // to compensate for the extra memory needed.
    record_batch_size.saturating_add(sliced_size)
}
788
789/// Estimate how much memory is needed to sort a `RecordBatch`.
790/// This will just call `get_reserved_bytes_for_record_batch_size` with the
791/// memory size of the record batch and its sliced size.
792pub(crate) fn get_reserved_bytes_for_record_batch(batch: &RecordBatch) -> Result<usize> {
793    batch.get_sliced_size().map(|sliced_size| {
794        get_reserved_bytes_for_record_batch_size(
795            get_record_batch_memory_size(batch),
796            sliced_size,
797        )
798    })
799}
800
801impl Debug for ExternalSorter {
802    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
803        f.debug_struct("ExternalSorter")
804            .field("memory_used", &self.used())
805            .field("spilled_bytes", &self.spilled_bytes())
806            .field("spilled_rows", &self.spilled_rows())
807            .field("spill_count", &self.spill_count())
808            .finish()
809    }
810}
811
812pub fn sort_batch(
813    batch: &RecordBatch,
814    expressions: &LexOrdering,
815    fetch: Option<usize>,
816) -> Result<RecordBatch> {
817    let sort_columns = expressions
818        .iter()
819        .map(|expr| expr.evaluate_to_sort_column(batch))
820        .collect::<Result<Vec<_>>>()?;
821
822    let indices = lexsort_to_indices(&sort_columns, fetch)?;
823    let columns = take_arrays(batch.columns(), &indices, None)?;
824
825    let options = RecordBatchOptions::new().with_row_count(Some(indices.len()));
826    Ok(RecordBatch::try_new_with_options(
827        batch.schema(),
828        columns,
829        &options,
830    )?)
831}
832
833/// Sort a batch and return the result as multiple batches of size `batch_size`.
834/// This is useful when you want to avoid creating one large sorted batch in memory,
835/// and instead want to process the sorted data in smaller chunks.
836pub fn sort_batch_chunked(
837    batch: &RecordBatch,
838    expressions: &LexOrdering,
839    batch_size: usize,
840) -> Result<Vec<RecordBatch>> {
841    IncrementalSortIterator::new(batch.clone(), expressions.clone(), batch_size).collect()
842}
843
/// Sort execution plan.
///
/// Supports sorting datasets that are larger than the memory allotted
/// by the memory manager, by spilling to disk.
#[derive(Debug, Clone)]
pub struct SortExec {
    /// Input plan whose output is sorted
    pub(crate) input: Arc<dyn ExecutionPlan>,
    /// Sort expressions
    expr: LexOrdering,
    /// Containing all metrics set created during sort
    metrics_set: ExecutionPlanMetricsSet,
    /// Preserve partitions of input plan. If false, the input partitions
    /// will be sorted and merged into a single output partition.
    preserve_partitioning: bool,
    /// Fetch highest/lowest n results
    fetch: Option<usize>,
    /// Normalized common sort prefix between the input and the sort expressions (only used with fetch)
    common_sort_prefix: Vec<PhysicalSortExpr>,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: Arc<PlanProperties>,
    /// Filter matching the state of the sort for dynamic filter pushdown.
    /// If `fetch` is `Some`, this will also be set and a TopK operator may be used.
    /// If `fetch` is `None`, this will be `None`.
    filter: Option<Arc<RwLock<TopKDynamicFilters>>>,
}
870
871impl SortExec {
872    /// Create a new sort execution plan that produces a single,
873    /// sorted output partition.
874    pub fn new(expr: LexOrdering, input: Arc<dyn ExecutionPlan>) -> Self {
875        let preserve_partitioning = false;
876        let (cache, sort_prefix) =
877            Self::compute_properties(&input, expr.clone(), preserve_partitioning)
878                .unwrap();
879        Self {
880            expr,
881            input,
882            metrics_set: ExecutionPlanMetricsSet::new(),
883            preserve_partitioning,
884            fetch: None,
885            common_sort_prefix: sort_prefix,
886            cache: Arc::new(cache),
887            filter: None,
888        }
889    }
890
    /// Whether this `SortExec` preserves partitioning of the children.
    ///
    /// `true` means each input partition is sorted on its own; `false` means
    /// the input is sorted and merged into a single output partition.
    pub fn preserve_partitioning(&self) -> bool {
        self.preserve_partitioning
    }
895
896    /// Specify the partitioning behavior of this sort exec
897    ///
898    /// If `preserve_partitioning` is true, sorts each partition
899    /// individually, producing one sorted stream for each input partition.
900    ///
901    /// If `preserve_partitioning` is false, sorts and merges all
902    /// input partitions producing a single, sorted partition.
903    pub fn with_preserve_partitioning(mut self, preserve_partitioning: bool) -> Self {
904        self.preserve_partitioning = preserve_partitioning;
905        Arc::make_mut(&mut self.cache).partitioning =
906            Self::output_partitioning_helper(&self.input, self.preserve_partitioning);
907        self
908    }
909
910    /// Add or reset `self.filter` to a new `TopKDynamicFilters`.
911    fn create_filter(&self) -> Arc<RwLock<TopKDynamicFilters>> {
912        let children = self
913            .expr
914            .iter()
915            .map(|sort_expr| Arc::clone(&sort_expr.expr))
916            .collect::<Vec<_>>();
917        Arc::new(RwLock::new(TopKDynamicFilters::new(Arc::new(
918            DynamicFilterPhysicalExpr::new(children, lit(true)),
919        ))))
920    }
921
922    fn cloned(&self) -> Self {
923        SortExec {
924            input: Arc::clone(&self.input),
925            expr: self.expr.clone(),
926            metrics_set: self.metrics_set.clone(),
927            preserve_partitioning: self.preserve_partitioning,
928            common_sort_prefix: self.common_sort_prefix.clone(),
929            fetch: self.fetch,
930            cache: Arc::clone(&self.cache),
931            filter: self.filter.clone(),
932        }
933    }
934
935    /// Modify how many rows to include in the result
936    ///
937    /// If None, then all rows will be returned, in sorted order.
938    /// If Some, then only the top `fetch` rows will be returned.
939    /// This can reduce the memory pressure required by the sort
940    /// operation since rows that are not going to be included
941    /// can be dropped.
942    pub fn with_fetch(&self, fetch: Option<usize>) -> Self {
943        let mut cache = PlanProperties::clone(&self.cache);
944        // If the SortExec can emit incrementally (that means the sort requirements
945        // and properties of the input match), the SortExec can generate its result
946        // without scanning the entire input when a fetch value exists.
947        let is_pipeline_friendly = matches!(
948            cache.emission_type,
949            EmissionType::Incremental | EmissionType::Both
950        );
951        if fetch.is_some() && is_pipeline_friendly {
952            cache = cache.with_boundedness(Boundedness::Bounded);
953        }
954        let filter = fetch.is_some().then(|| {
955            // If we already have a filter, keep it. Otherwise, create a new one.
956            self.filter.clone().unwrap_or_else(|| self.create_filter())
957        });
958        let mut new_sort = self.cloned();
959        new_sort.fetch = fetch;
960        new_sort.cache = cache.into();
961        new_sort.filter = filter;
962        new_sort
963    }
964
    /// The input plan whose output this node sorts.
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }
969
    /// The sort expressions (the ordering this node produces).
    pub fn expr(&self) -> &LexOrdering {
        &self.expr
    }
974
    /// If `Some(fetch)`, limits output to only the first "fetch" items;
    /// `None` means no limit.
    pub fn fetch(&self) -> Option<usize> {
        self.fetch
    }
979
980    /// Returns the dynamic filter expression for this sort (TopK), if set.
981    pub fn dynamic_filter_expr(&self) -> Option<Arc<DynamicFilterPhysicalExpr>> {
982        self.filter.as_ref().map(|f| f.read().expr())
983    }
984
985    /// Replace the dynamic filter expression for this sort.
986    ///
987    ///
988    /// Resets any internal state which may depend on the previous dynamic filter.
989    ///
990    /// Validates that the filter's children reference valid columns in
991    /// the sort's input schema.
992    pub fn with_dynamic_filter_expr(
993        mut self,
994        filter: Arc<DynamicFilterPhysicalExpr>,
995    ) -> Result<Self> {
996        let input_schema = self.input.schema();
997        for child in filter.children() {
998            child.data_type(&input_schema)?;
999        }
1000        self.filter = Some(Arc::new(RwLock::new(TopKDynamicFilters::new(filter))));
1001        Ok(self)
1002    }
1003
1004    fn output_partitioning_helper(
1005        input: &Arc<dyn ExecutionPlan>,
1006        preserve_partitioning: bool,
1007    ) -> Partitioning {
1008        // Get output partitioning:
1009        if preserve_partitioning {
1010            input.output_partitioning().clone()
1011        } else {
1012            Partitioning::UnknownPartitioning(1)
1013        }
1014    }
1015
1016    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
1017    /// It also returns the common sort prefix between the input and the sort expressions.
1018    fn compute_properties(
1019        input: &Arc<dyn ExecutionPlan>,
1020        sort_exprs: LexOrdering,
1021        preserve_partitioning: bool,
1022    ) -> Result<(PlanProperties, Vec<PhysicalSortExpr>)> {
1023        let (sort_prefix, sort_satisfied) = input
1024            .equivalence_properties()
1025            .extract_common_sort_prefix(sort_exprs.clone())?;
1026
1027        // The emission type depends on whether the input is already sorted:
1028        // - If already fully sorted, we can emit results in the same way as the input
1029        // - If not sorted, we must wait until all data is processed to emit results (Final)
1030        let emission_type = if sort_satisfied {
1031            input.pipeline_behavior()
1032        } else {
1033            EmissionType::Final
1034        };
1035
1036        // The boundedness depends on whether the input is already sorted:
1037        // - If already sorted, we have the same property as the input
1038        // - If not sorted and input is unbounded, we require infinite memory and generates
1039        //   unbounded data (not practical).
1040        // - If not sorted and input is bounded, then the SortExec is bounded, too.
1041        let boundedness = if sort_satisfied {
1042            input.boundedness()
1043        } else {
1044            match input.boundedness() {
1045                Boundedness::Unbounded { .. } => Boundedness::Unbounded {
1046                    requires_infinite_memory: true,
1047                },
1048                bounded => bounded,
1049            }
1050        };
1051
1052        // Calculate equivalence properties; i.e. reset the ordering equivalence
1053        // class with the new ordering:
1054        let mut eq_properties = input.equivalence_properties().clone();
1055        eq_properties.reorder(sort_exprs)?;
1056
1057        // Get output partitioning:
1058        let output_partitioning =
1059            Self::output_partitioning_helper(input, preserve_partitioning);
1060
1061        Ok((
1062            PlanProperties::new(
1063                eq_properties,
1064                output_partitioning,
1065                emission_type,
1066                boundedness,
1067            ),
1068            sort_prefix,
1069        ))
1070    }
1071}
1072
1073impl DisplayAs for SortExec {
1074    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> fmt::Result {
1075        match t {
1076            DisplayFormatType::Default | DisplayFormatType::Verbose => {
1077                let preserve_partitioning = self.preserve_partitioning;
1078                match self.fetch {
1079                    Some(fetch) => {
1080                        write!(
1081                            f,
1082                            "SortExec: TopK(fetch={fetch}), expr=[{}], preserve_partitioning=[{preserve_partitioning}]",
1083                            self.expr
1084                        )?;
1085                        if let Some(filter) = &self.filter
1086                            && let Ok(current) = filter.read().expr().current()
1087                            && !current.eq(&lit(true))
1088                        {
1089                            write!(f, ", filter=[{current}]")?;
1090                        }
1091                        if !self.common_sort_prefix.is_empty() {
1092                            write!(f, ", sort_prefix=[")?;
1093                            let mut first = true;
1094                            for sort_expr in &self.common_sort_prefix {
1095                                if first {
1096                                    first = false;
1097                                } else {
1098                                    write!(f, ", ")?;
1099                                }
1100                                write!(f, "{sort_expr}")?;
1101                            }
1102                            write!(f, "]")
1103                        } else {
1104                            Ok(())
1105                        }
1106                    }
1107                    None => write!(
1108                        f,
1109                        "SortExec: expr=[{}], preserve_partitioning=[{preserve_partitioning}]",
1110                        self.expr
1111                    ),
1112                }
1113            }
1114            DisplayFormatType::TreeRender => match self.fetch {
1115                Some(fetch) => {
1116                    writeln!(f, "{}", self.expr)?;
1117                    writeln!(f, "limit={fetch}")
1118                }
1119                None => {
1120                    writeln!(f, "{}", self.expr)
1121                }
1122            },
1123        }
1124    }
1125}
1126
1127impl ExecutionPlan for SortExec {
1128    fn name(&self) -> &'static str {
1129        match self.fetch {
1130            Some(_) => "SortExec(TopK)",
1131            None => "SortExec",
1132        }
1133    }
1134
    /// Returns the cached plan properties stored in `self.cache`.
    fn properties(&self) -> &Arc<PlanProperties> {
        &self.cache
    }
1138
1139    fn required_input_distribution(&self) -> Vec<Distribution> {
1140        if self.preserve_partitioning {
1141            vec![Distribution::UnspecifiedDistribution]
1142        } else {
1143            // global sort
1144            // TODO support RangePartition and OrderedDistribution
1145            vec![Distribution::SinglePartition]
1146        }
1147    }
1148
    /// A sort has exactly one child: its input plan.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }
1152
    /// Reports `false` for the single child.
    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![false]
    }
1156
1157    fn with_new_children(
1158        self: Arc<Self>,
1159        children: Vec<Arc<dyn ExecutionPlan>>,
1160    ) -> Result<Arc<dyn ExecutionPlan>> {
1161        let mut new_sort = self.cloned();
1162        assert_eq!(children.len(), 1, "SortExec should have exactly one child");
1163        new_sort.input = Arc::clone(&children[0]);
1164
1165        if !has_same_children_properties(self.as_ref(), &children)? {
1166            // Recompute the properties based on the new input since they may have changed
1167            let (cache, sort_prefix) = Self::compute_properties(
1168                &new_sort.input,
1169                new_sort.expr.clone(),
1170                new_sort.preserve_partitioning,
1171            )?;
1172            new_sort.cache = Arc::new(cache);
1173            new_sort.common_sort_prefix = sort_prefix;
1174        }
1175
1176        Ok(Arc::new(new_sort))
1177    }
1178
1179    fn reset_state(self: Arc<Self>) -> Result<Arc<dyn ExecutionPlan>> {
1180        let children = self.children().into_iter().cloned().collect();
1181        let new_sort = self.with_new_children(children)?;
1182        let mut new_sort = new_sort
1183            .downcast_ref::<SortExec>()
1184            .expect("cloned 1 lines above this line, we know the type")
1185            .clone();
1186        // Our dynamic filter and execution metrics are the state we need to reset.
1187        new_sort.filter = Some(new_sort.create_filter());
1188        new_sort.metrics_set = ExecutionPlanMetricsSet::new();
1189
1190        Ok(Arc::new(new_sort))
1191    }
1192
1193    fn execute(
1194        &self,
1195        partition: usize,
1196        context: Arc<TaskContext>,
1197    ) -> Result<SendableRecordBatchStream> {
1198        trace!(
1199            "Start SortExec::execute for partition {} of context session_id {} and task_id {:?}",
1200            partition,
1201            context.session_id(),
1202            context.task_id()
1203        );
1204
1205        let mut input = self.input.execute(partition, Arc::clone(&context))?;
1206
1207        let execution_options = &context.session_config().options().execution;
1208
1209        trace!("End SortExec's input.execute for partition: {partition}");
1210
1211        let sort_satisfied = self
1212            .input
1213            .equivalence_properties()
1214            .ordering_satisfy(self.expr.clone())?;
1215
1216        match (sort_satisfied, self.fetch.as_ref()) {
1217            (true, Some(fetch)) => Ok(Box::pin(LimitStream::new(
1218                input,
1219                0,
1220                Some(*fetch),
1221                BaselineMetrics::new(&self.metrics_set, partition),
1222            ))),
1223            (true, None) => Ok(input),
1224            (false, Some(fetch)) => {
1225                let filter = self.filter.clone();
1226                let mut topk = TopK::try_new(
1227                    partition,
1228                    input.schema(),
1229                    self.common_sort_prefix.clone(),
1230                    self.expr.clone(),
1231                    *fetch,
1232                    context.session_config().batch_size(),
1233                    context.runtime_env(),
1234                    &self.metrics_set,
1235                    Arc::clone(&unwrap_or_internal_err!(filter)),
1236                )?;
1237                Ok(Box::pin(RecordBatchStreamAdapter::new(
1238                    self.schema(),
1239                    futures::stream::once(async move {
1240                        while let Some(batch) = input.next().await {
1241                            let batch = batch?;
1242                            topk.insert_batch(batch)?;
1243                            if topk.finished {
1244                                break;
1245                            }
1246                        }
1247                        drop(input);
1248                        topk.emit()
1249                    })
1250                    .try_flatten(),
1251                )))
1252            }
1253            (false, None) => {
1254                let mut sorter = ExternalSorter::new(
1255                    partition,
1256                    input.schema(),
1257                    self.expr.clone(),
1258                    context.session_config().batch_size(),
1259                    execution_options.sort_spill_reservation_bytes,
1260                    execution_options.sort_in_place_threshold_bytes,
1261                    context.session_config().spill_compression(),
1262                    &self.metrics_set,
1263                    context.runtime_env(),
1264                )?;
1265                Ok(Box::pin(RecordBatchStreamAdapter::new(
1266                    self.schema(),
1267                    futures::stream::once(async move {
1268                        while let Some(batch) = input.next().await {
1269                            let batch = batch?;
1270                            sorter.insert_batch(batch).await?;
1271                        }
1272                        drop(input);
1273                        sorter.sort().await
1274                    })
1275                    .try_flatten(),
1276                )))
1277            }
1278        }
1279    }
1280
    /// Returns the result of `clone_inner` on this node's metrics set;
    /// always `Some`.
    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics_set.clone_inner())
    }
1284
1285    fn partition_statistics(&self, partition: Option<usize>) -> Result<Arc<Statistics>> {
1286        let p = if !self.preserve_partitioning() {
1287            None
1288        } else {
1289            partition
1290        };
1291        let stats = Arc::unwrap_or_clone(self.input.partition_statistics(p)?);
1292        Ok(Arc::new(stats.with_fetch(self.fetch, 0, 1)?))
1293    }
1294
    /// Trait-level entry point for setting the fetch; delegates to the
    /// inherent `SortExec::with_fetch` and always returns `Some`.
    fn with_fetch(&self, limit: Option<usize>) -> Option<Arc<dyn ExecutionPlan>> {
        Some(Arc::new(SortExec::with_fetch(self, limit)))
    }
1298
    /// The fetch limit of this sort, if any.
    fn fetch(&self) -> Option<usize> {
        self.fetch
    }
1302
1303    fn cardinality_effect(&self) -> CardinalityEffect {
1304        if self.fetch.is_none() {
1305            CardinalityEffect::Equal
1306        } else {
1307            CardinalityEffect::LowerEqual
1308        }
1309    }
1310
1311    /// Tries to swap the projection with its input [`SortExec`]. If it can be done,
1312    /// it returns the new swapped version having the [`SortExec`] as the top plan.
1313    /// Otherwise, it returns None.
1314    fn try_swapping_with_projection(
1315        &self,
1316        projection: &ProjectionExec,
1317    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
1318        // If the projection does not narrow the schema, we should not try to push it down.
1319        if projection.expr().len() >= projection.input().schema().fields().len() {
1320            return Ok(None);
1321        }
1322
1323        let Some(updated_exprs) = update_ordering(self.expr.clone(), projection.expr())?
1324        else {
1325            return Ok(None);
1326        };
1327
1328        Ok(Some(Arc::new(
1329            SortExec::new(updated_exprs, make_with_child(projection, self.input())?)
1330                .with_fetch(self.fetch())
1331                .with_preserve_partitioning(self.preserve_partitioning()),
1332        )))
1333    }
1334
1335    fn gather_filters_for_pushdown(
1336        &self,
1337        phase: FilterPushdownPhase,
1338        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
1339        config: &datafusion_common::config::ConfigOptions,
1340    ) -> Result<FilterDescription> {
1341        if phase != FilterPushdownPhase::Post {
1342            if self.fetch.is_some() {
1343                return Ok(FilterDescription::all_unsupported(
1344                    &parent_filters,
1345                    &self.children(),
1346                ));
1347            }
1348            return FilterDescription::from_children(parent_filters, &self.children());
1349        }
1350
1351        // In Post phase: block parent filters when fetch is set,
1352        // but still push the TopK dynamic filter (self-filter).
1353        let mut child = if self.fetch.is_some() {
1354            ChildFilterDescription::all_unsupported(&parent_filters)
1355        } else {
1356            ChildFilterDescription::from_child(&parent_filters, self.input())?
1357        };
1358
1359        if let Some(filter) = &self.filter
1360            && config.optimizer.enable_topk_dynamic_filter_pushdown
1361        {
1362            child = child.with_self_filter(filter.read().expr());
1363        }
1364
1365        Ok(FilterDescription::new().with_child(child))
1366    }
1367
1368    fn handle_child_pushdown_result(
1369        &self,
1370        _phase: FilterPushdownPhase,
1371        child_pushdown_result: ChildPushdownResult,
1372        _config: &datafusion_common::config::ConfigOptions,
1373    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
1374        // For a plain sort (no fetch) we intercept any unsupported filters
1375        // by inserting a FilterExec below this Sort. Moving the filter below
1376        // Sort is safe because Sort preserves all rows.
1377        //
1378        // Why not fetch (TopK)?
1379        // A sort with fetch limits the number of output rows.  Inserting a
1380        // FilterExec *below* the TopK would change semantics.  A filter *above*
1381        // the TopK is supposed to post-filter its output (e.g. "take the top 10
1382        // rows, then keep only those with a > 5").  Pushing the filter below
1383        // Sort changes the meaning to "filter first, then take top 10", which
1384        // produces a different result.
1385        if self.fetch.is_some() {
1386            return Ok(FilterPushdownPropagation::if_all(child_pushdown_result));
1387        }
1388
1389        // Collect parent filters that were NOT successfully pushed to our child.
1390        let unsupported_filters: Vec<Arc<dyn PhysicalExpr>> = child_pushdown_result
1391            .parent_filters
1392            .iter()
1393            .filter(|&f| matches!(f.all(), PushedDown::No))
1394            .map(|f| Arc::clone(&f.filter))
1395            .collect();
1396
1397        if unsupported_filters.is_empty() {
1398            // All filters were pushed — nothing extra to do.
1399            return Ok(FilterPushdownPropagation::if_all(child_pushdown_result));
1400        }
1401
1402        // Build a single conjunctive predicate from the unsupported filters
1403        // and insert a FilterExec between this SortExec and its child.
1404        let predicate = datafusion_physical_expr::conjunction(unsupported_filters);
1405        let new_child =
1406            Arc::new(FilterExec::try_new(predicate, Arc::clone(self.input()))?)
1407                as Arc<dyn ExecutionPlan>;
1408        let new_sort = Arc::new(
1409            SortExec::new(self.expr.clone(), new_child)
1410                .with_fetch(self.fetch())
1411                .with_preserve_partitioning(self.preserve_partitioning()),
1412        ) as Arc<dyn ExecutionPlan>;
1413
1414        Ok(FilterPushdownPropagation {
1415            filters: vec![PushedDown::Yes; child_pushdown_result.parent_filters.len()],
1416            updated_node: Some(new_sort),
1417        })
1418    }
1419}
1420
1421#[cfg(test)]
1422mod tests {
1423    use std::collections::HashMap;
1424    use std::pin::Pin;
1425    use std::task::{Context, Poll};
1426
1427    use super::*;
1428    use crate::coalesce_partitions::CoalescePartitionsExec;
1429    use crate::collect;
1430    use crate::empty::EmptyExec;
1431    use crate::execution_plan::Boundedness;
1432    use crate::expressions::col;
1433    use crate::filter_pushdown::{FilterPushdownPhase, PushedDown};
1434    use crate::test;
1435    use crate::test::TestMemoryExec;
1436    use crate::test::exec::{BlockingExec, assert_strong_count_converges_to_zero};
1437    use crate::test::{assert_is_pending, make_partition};
1438
1439    use arrow::array::*;
1440    use arrow::compute::SortOptions;
1441    use arrow::datatypes::*;
1442    use datafusion_common::ScalarValue;
1443    use datafusion_common::cast::as_primitive_array;
1444    use datafusion_common::config::ConfigOptions;
1445    use datafusion_common::test_util::batches_to_string;
1446    use datafusion_execution::RecordBatchStream;
1447    use datafusion_execution::config::SessionConfig;
1448    use datafusion_execution::memory_pool::{
1449        GreedyMemoryPool, MemoryConsumer, MemoryPool,
1450    };
1451    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
1452    use datafusion_physical_expr::EquivalenceProperties;
1453    use datafusion_physical_expr::expressions::{Column, Literal};
1454
1455    use futures::{FutureExt, Stream, TryStreamExt};
1456    use insta::assert_snapshot;
1457
    /// Test-only leaf plan whose properties declare an unbounded output
    /// ordered on column `c1`.
    #[derive(Debug, Clone)]
    pub struct SortedUnboundedExec {
        /// Schema of the produced batches
        schema: Schema,
        /// Number of rows in every produced batch
        batch_size: u64,
        /// Cached plan properties returned by `properties()`
        cache: Arc<PlanProperties>,
    }
1464
1465    impl DisplayAs for SortedUnboundedExec {
1466        fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> fmt::Result {
1467            match t {
1468                DisplayFormatType::Default
1469                | DisplayFormatType::Verbose
1470                | DisplayFormatType::TreeRender => write!(f, "UnboundableExec",).unwrap(),
1471            }
1472            Ok(())
1473        }
1474    }
1475
1476    impl SortedUnboundedExec {
1477        fn compute_properties(schema: SchemaRef) -> PlanProperties {
1478            let mut eq_properties = EquivalenceProperties::new(schema);
1479            eq_properties.add_ordering([PhysicalSortExpr::new_default(Arc::new(
1480                Column::new("c1", 0),
1481            ))]);
1482            PlanProperties::new(
1483                eq_properties,
1484                Partitioning::UnknownPartitioning(1),
1485                EmissionType::Final,
1486                Boundedness::Unbounded {
1487                    requires_infinite_memory: false,
1488                },
1489            )
1490        }
1491    }
1492
1493    impl ExecutionPlan for SortedUnboundedExec {
1494        fn name(&self) -> &'static str {
1495            Self::static_name()
1496        }
1497
1498        fn properties(&self) -> &Arc<PlanProperties> {
1499            &self.cache
1500        }
1501
1502        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
1503            vec![]
1504        }
1505
1506        fn with_new_children(
1507            self: Arc<Self>,
1508            _: Vec<Arc<dyn ExecutionPlan>>,
1509        ) -> Result<Arc<dyn ExecutionPlan>> {
1510            Ok(self)
1511        }
1512
1513        fn execute(
1514            &self,
1515            _partition: usize,
1516            _context: Arc<TaskContext>,
1517        ) -> Result<SendableRecordBatchStream> {
1518            Ok(Box::pin(SortedUnboundedStream {
1519                schema: Arc::new(self.schema.clone()),
1520                batch_size: self.batch_size,
1521                offset: 0,
1522            }))
1523        }
1524    }
1525
    /// Never-ending stream of single-column `u64` batches with consecutive,
    /// ascending values, created by `SortedUnboundedExec::execute`.
    #[derive(Debug)]
    pub struct SortedUnboundedStream {
        /// Schema of the produced batches
        schema: SchemaRef,
        /// Number of rows in every produced batch
        batch_size: u64,
        /// First value of the next batch
        offset: u64,
    }
1532
1533    impl Stream for SortedUnboundedStream {
1534        type Item = Result<RecordBatch>;
1535
1536        fn poll_next(
1537            mut self: Pin<&mut Self>,
1538            _cx: &mut Context<'_>,
1539        ) -> Poll<Option<Self::Item>> {
1540            let batch = SortedUnboundedStream::create_record_batch(
1541                Arc::clone(&self.schema),
1542                self.offset,
1543                self.batch_size,
1544            );
1545            self.offset += self.batch_size;
1546            Poll::Ready(Some(Ok(batch)))
1547        }
1548    }
1549
1550    impl RecordBatchStream for SortedUnboundedStream {
1551        fn schema(&self) -> SchemaRef {
1552            Arc::clone(&self.schema)
1553        }
1554    }
1555
1556    impl SortedUnboundedStream {
1557        fn create_record_batch(
1558            schema: SchemaRef,
1559            offset: u64,
1560            batch_size: u64,
1561        ) -> RecordBatch {
1562            let values = (0..batch_size).map(|i| offset + i).collect::<Vec<_>>();
1563            let array = UInt64Array::from(values);
1564            let array_ref: ArrayRef = Arc::new(array);
1565            RecordBatch::try_new(schema, vec![array_ref]).unwrap()
1566        }
1567    }
1568
1569    #[tokio::test]
1570    async fn test_in_mem_sort() -> Result<()> {
1571        let task_ctx = Arc::new(TaskContext::default());
1572        let partitions = 4;
1573        let csv = test::scan_partitioned(partitions);
1574        let schema = csv.schema();
1575
1576        let sort_exec = Arc::new(SortExec::new(
1577            [PhysicalSortExpr {
1578                expr: col("i", &schema)?,
1579                options: SortOptions::default(),
1580            }]
1581            .into(),
1582            Arc::new(CoalescePartitionsExec::new(csv)),
1583        ));
1584
1585        let result = collect(sort_exec, Arc::clone(&task_ctx)).await?;
1586
1587        assert_eq!(result.len(), 1);
1588        assert_eq!(result[0].num_rows(), 400);
1589        assert_eq!(
1590            task_ctx.runtime_env().memory_pool.reserved(),
1591            0,
1592            "The sort should have returned all memory used back to the memory manager"
1593        );
1594
1595        Ok(())
1596    }
1597
1598    #[tokio::test]
1599    async fn test_sort_spill() -> Result<()> {
1600        // trigger spill w/ 100 batches
1601        let session_config = SessionConfig::new();
1602        let sort_spill_reservation_bytes = session_config
1603            .options()
1604            .execution
1605            .sort_spill_reservation_bytes;
1606        let runtime = RuntimeEnvBuilder::new()
1607            .with_memory_limit(sort_spill_reservation_bytes + 12288, 1.0)
1608            .build_arc()?;
1609        let task_ctx = Arc::new(
1610            TaskContext::default()
1611                .with_session_config(session_config)
1612                .with_runtime(runtime),
1613        );
1614
1615        // The input has 100 partitions, each partition has a batch containing 100 rows.
1616        // Each row has a single Int32 column with values 0..100. The total size of the
1617        // input is roughly 40000 bytes.
1618        let partitions = 100;
1619        let input = test::scan_partitioned(partitions);
1620        let schema = input.schema();
1621
1622        let sort_exec = Arc::new(SortExec::new(
1623            [PhysicalSortExpr {
1624                expr: col("i", &schema)?,
1625                options: SortOptions::default(),
1626            }]
1627            .into(),
1628            Arc::new(CoalescePartitionsExec::new(input)),
1629        ));
1630
1631        let result = collect(
1632            Arc::clone(&sort_exec) as Arc<dyn ExecutionPlan>,
1633            Arc::clone(&task_ctx),
1634        )
1635        .await?;
1636
1637        assert_eq!(result.len(), 2);
1638
1639        // Now, validate metrics
1640        let metrics = sort_exec.metrics().unwrap();
1641
1642        assert_eq!(metrics.output_rows().unwrap(), 10000);
1643        assert!(metrics.elapsed_compute().unwrap() > 0);
1644
1645        let spill_count = metrics.spill_count().unwrap();
1646        let spilled_rows = metrics.spilled_rows().unwrap();
1647        let spilled_bytes = metrics.spilled_bytes().unwrap();
1648        // Processing 40000 bytes of data using 12288 bytes of memory requires 3 spills
1649        // unless we do something really clever. It will spill roughly 9000+ rows and 36000
1650        // bytes. We leave a little wiggle room for the actual numbers.
1651        assert!((3..=10).contains(&spill_count));
1652        assert!((9000..=10000).contains(&spilled_rows));
1653        assert!((38000..=44000).contains(&spilled_bytes));
1654
1655        let columns = result[0].columns();
1656
1657        let i = as_primitive_array::<Int32Type>(&columns[0])?;
1658        assert_eq!(i.value(0), 0);
1659        assert_eq!(i.value(i.len() - 1), 81);
1660        assert_eq!(
1661            task_ctx.runtime_env().memory_pool.reserved(),
1662            0,
1663            "The sort should have returned all memory used back to the memory manager"
1664        );
1665
1666        Ok(())
1667    }
1668
1669    #[tokio::test]
1670    async fn test_batch_reservation_error() -> Result<()> {
1671        // Pick a memory limit and sort_spill_reservation that make the first batch reservation fail.
1672        let merge_reservation: usize = 0; // Set to 0 for simplicity
1673
1674        let session_config =
1675            SessionConfig::new().with_sort_spill_reservation_bytes(merge_reservation);
1676
1677        let plan = test::scan_partitioned(1);
1678
1679        // Read the first record batch to determine the actual memory requirement
1680        let expected_batch_reservation = {
1681            let temp_ctx = Arc::new(TaskContext::default());
1682            let mut stream = plan.execute(0, Arc::clone(&temp_ctx))?;
1683            let first_batch = stream.next().await.unwrap()?;
1684            get_reserved_bytes_for_record_batch(&first_batch)?
1685        };
1686
1687        // Set memory limit just short of what we need
1688        let memory_limit: usize = expected_batch_reservation + merge_reservation - 1;
1689
1690        let runtime = RuntimeEnvBuilder::new()
1691            .with_memory_limit(memory_limit, 1.0)
1692            .build_arc()?;
1693        let task_ctx = Arc::new(
1694            TaskContext::default()
1695                .with_session_config(session_config)
1696                .with_runtime(runtime),
1697        );
1698
1699        // Verify that our memory limit is insufficient
1700        {
1701            let mut stream = plan.execute(0, Arc::clone(&task_ctx))?;
1702            let first_batch = stream.next().await.unwrap()?;
1703            let batch_reservation = get_reserved_bytes_for_record_batch(&first_batch)?;
1704
1705            assert_eq!(batch_reservation, expected_batch_reservation);
1706            assert!(memory_limit < (merge_reservation + batch_reservation));
1707        }
1708
1709        let sort_exec = Arc::new(SortExec::new(
1710            [PhysicalSortExpr::new_default(col("i", &plan.schema())?)].into(),
1711            plan,
1712        ));
1713
1714        let result = collect(Arc::clone(&sort_exec) as _, Arc::clone(&task_ctx)).await;
1715
1716        let err = result.unwrap_err();
1717        assert!(
1718            matches!(err, DataFusionError::Context(..)),
1719            "Assertion failed: expected a Context error, but got: {err:?}"
1720        );
1721
1722        // Assert that the context error is wrapping a resources exhausted error.
1723        assert!(
1724            matches!(err.find_root(), DataFusionError::ResourcesExhausted(_)),
1725            "Assertion failed: expected a ResourcesExhausted error, but got: {err:?}"
1726        );
1727
1728        // Verify external sorter error message when resource is exhausted
1729        let config_vector = vec![
1730            "datafusion.runtime.memory_limit",
1731            "datafusion.execution.sort_spill_reservation_bytes",
1732        ];
1733        let error_message = err.message().to_string();
1734        for config in config_vector.into_iter() {
1735            assert!(
1736                error_message.as_str().contains(config),
1737                "Config: '{}' should be contained in error message: {}.",
1738                config,
1739                error_message.as_str()
1740            );
1741        }
1742
1743        Ok(())
1744    }
1745
1746    #[tokio::test]
1747    async fn test_sort_spill_utf8_strings() -> Result<()> {
1748        let session_config = SessionConfig::new()
1749            .with_batch_size(100)
1750            .with_sort_in_place_threshold_bytes(20 * 1024)
1751            .with_sort_spill_reservation_bytes(100 * 1024);
1752        let runtime = RuntimeEnvBuilder::new()
1753            .with_memory_limit(500 * 1024, 1.0)
1754            .build_arc()?;
1755        let task_ctx = Arc::new(
1756            TaskContext::default()
1757                .with_session_config(session_config)
1758                .with_runtime(runtime),
1759        );
1760
1761        // The input has 200 partitions, each partition has a batch containing 100 rows.
1762        // Each row has a single Utf8 column, the Utf8 string values are roughly 42 bytes.
1763        // The total size of the input is roughly 820 KB.
1764        let input = test::scan_partitioned_utf8(200);
1765        let schema = input.schema();
1766
1767        let sort_exec = Arc::new(SortExec::new(
1768            [PhysicalSortExpr {
1769                expr: col("i", &schema)?,
1770                options: SortOptions::default(),
1771            }]
1772            .into(),
1773            Arc::new(CoalescePartitionsExec::new(input)),
1774        ));
1775
1776        let result = collect(Arc::clone(&sort_exec) as _, Arc::clone(&task_ctx)).await?;
1777
1778        let num_rows = result.iter().map(|batch| batch.num_rows()).sum::<usize>();
1779        assert_eq!(num_rows, 20000);
1780
1781        // Now, validate metrics
1782        let metrics = sort_exec.metrics().unwrap();
1783
1784        assert_eq!(metrics.output_rows().unwrap(), 20000);
1785        assert!(metrics.elapsed_compute().unwrap() > 0);
1786
1787        let spill_count = metrics.spill_count().unwrap();
1788        let spilled_rows = metrics.spilled_rows().unwrap();
1789        let spilled_bytes = metrics.spilled_bytes().unwrap();
1790
1791        // This test case is processing 840KB of data using 400KB of memory. Note
1792        // that buffered batches can't be dropped until all sorted batches are
1793        // generated, so we can only buffer `sort_spill_reservation_bytes` of sorted
1794        // batches.
1795        // The number of spills is roughly calculated as:
1796        //  `number_of_batches / (sort_spill_reservation_bytes / batch_size)`
1797
1798        // If this assertion fail with large spill count, make sure the following
1799        // case does not happen:
1800        // During external sorting, one sorted run should be spilled to disk in a
1801        // single file, due to memory limit we might need to append to the file
1802        // multiple times to spill all the data. Make sure we're not writing each
1803        // appending as a separate file.
1804        assert!((4..=8).contains(&spill_count));
1805        assert!((15000..=20000).contains(&spilled_rows));
1806        assert!((900000..=1000000).contains(&spilled_bytes));
1807
1808        // Verify that the result is sorted
1809        let concated_result = concat_batches(&schema, &result)?;
1810        let columns = concated_result.columns();
1811        let string_array = as_string_array(&columns[0]);
1812        for i in 0..string_array.len() - 1 {
1813            assert!(string_array.value(i) <= string_array.value(i + 1));
1814        }
1815
1816        assert_eq!(
1817            task_ctx.runtime_env().memory_pool.reserved(),
1818            0,
1819            "The sort should have returned all memory used back to the memory manager"
1820        );
1821
1822        Ok(())
1823    }
1824
1825    #[tokio::test]
1826    async fn test_sort_fetch_memory_calculation() -> Result<()> {
1827        // This test mirrors down the size from the example above.
1828        let avg_batch_size = 400;
1829        let partitions = 4;
1830
1831        // A tuple of (fetch, expect_spillage)
1832        let test_options = vec![
1833            // Since we don't have a limit (and the memory is less than the total size of
1834            // all the batches we are processing, we expect it to spill.
1835            (None, true),
1836            // When we have a limit however, the buffered size of batches should fit in memory
1837            // since it is much lower than the total size of the input batch.
1838            (Some(1), false),
1839        ];
1840
1841        for (fetch, expect_spillage) in test_options {
1842            let session_config = SessionConfig::new();
1843            let sort_spill_reservation_bytes = session_config
1844                .options()
1845                .execution
1846                .sort_spill_reservation_bytes;
1847
1848            let runtime = RuntimeEnvBuilder::new()
1849                .with_memory_limit(
1850                    sort_spill_reservation_bytes + avg_batch_size * (partitions - 1),
1851                    1.0,
1852                )
1853                .build_arc()?;
1854            let task_ctx = Arc::new(
1855                TaskContext::default()
1856                    .with_runtime(runtime)
1857                    .with_session_config(session_config),
1858            );
1859
1860            let csv = test::scan_partitioned(partitions);
1861            let schema = csv.schema();
1862
1863            let sort_exec = Arc::new(
1864                SortExec::new(
1865                    [PhysicalSortExpr {
1866                        expr: col("i", &schema)?,
1867                        options: SortOptions::default(),
1868                    }]
1869                    .into(),
1870                    Arc::new(CoalescePartitionsExec::new(csv)),
1871                )
1872                .with_fetch(fetch),
1873            );
1874
1875            let result =
1876                collect(Arc::clone(&sort_exec) as _, Arc::clone(&task_ctx)).await?;
1877            assert_eq!(result.len(), 1);
1878
1879            let metrics = sort_exec.metrics().unwrap();
1880            let did_it_spill = metrics.spill_count().unwrap_or(0) > 0;
1881            assert_eq!(did_it_spill, expect_spillage, "with fetch: {fetch:?}");
1882        }
1883        Ok(())
1884    }
1885
1886    #[tokio::test]
1887    async fn test_sort_memory_reduction_per_batch() -> Result<()> {
1888        // This test verifies that memory reservation is reduced for every batch emitted
1889        // during the sort process. This is important to ensure we don't hold onto
1890        // memory longer than necessary.
1891
1892        // Create a large enough batch that will be split into multiple output batches
1893        let batch_size = 50; // Small batch size to force multiple output batches
1894        let num_rows = 1000; // Create enough data for multiple batches
1895
1896        let task_ctx = Arc::new(
1897            TaskContext::default().with_session_config(
1898                SessionConfig::new()
1899                    .with_batch_size(batch_size)
1900                    .with_sort_in_place_threshold_bytes(usize::MAX), // Ensure we don't concat batches
1901            ),
1902        );
1903
1904        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
1905
1906        // Create unsorted data
1907        let mut values: Vec<i32> = (0..num_rows).collect();
1908        values.reverse();
1909
1910        let input_batch = RecordBatch::try_new(
1911            Arc::clone(&schema),
1912            vec![Arc::new(Int32Array::from(values))],
1913        )?;
1914
1915        let batches = vec![input_batch];
1916
1917        let sort_exec = Arc::new(SortExec::new(
1918            [PhysicalSortExpr {
1919                expr: Arc::new(Column::new("a", 0)),
1920                options: SortOptions::default(),
1921            }]
1922            .into(),
1923            TestMemoryExec::try_new_exec(
1924                std::slice::from_ref(&batches),
1925                Arc::clone(&schema),
1926                None,
1927            )?,
1928        ));
1929
1930        let mut stream = sort_exec.execute(0, Arc::clone(&task_ctx))?;
1931
1932        let mut previous_reserved = task_ctx.runtime_env().memory_pool.reserved();
1933        let mut batch_count = 0;
1934
1935        // Collect batches and verify memory is reduced with each batch
1936        while let Some(result) = stream.next().await {
1937            let batch = result?;
1938            batch_count += 1;
1939
1940            // Verify we got a non-empty batch
1941            assert!(batch.num_rows() > 0, "Batch should not be empty");
1942
1943            let current_reserved = task_ctx.runtime_env().memory_pool.reserved();
1944
1945            // After the first batch, memory should be reducing or staying the same
1946            // (it should not increase as we emit batches)
1947            if batch_count > 1 {
1948                assert!(
1949                    current_reserved <= previous_reserved,
1950                    "Memory reservation should decrease or stay same as batches are emitted. \
1951                     Batch {batch_count}: previous={previous_reserved}, current={current_reserved}"
1952                );
1953            }
1954
1955            previous_reserved = current_reserved;
1956        }
1957
1958        assert!(
1959            batch_count > 1,
1960            "Expected multiple batches to be emitted, got {batch_count}"
1961        );
1962
1963        // Verify all memory is returned at the end
1964        assert_eq!(
1965            task_ctx.runtime_env().memory_pool.reserved(),
1966            0,
1967            "All memory should be returned after consuming all batches"
1968        );
1969
1970        Ok(())
1971    }
1972
1973    #[tokio::test]
1974    async fn test_sort_metadata() -> Result<()> {
1975        let task_ctx = Arc::new(TaskContext::default());
1976        let field_metadata: HashMap<String, String> =
1977            vec![("foo".to_string(), "bar".to_string())]
1978                .into_iter()
1979                .collect();
1980        let schema_metadata: HashMap<String, String> =
1981            vec![("baz".to_string(), "barf".to_string())]
1982                .into_iter()
1983                .collect();
1984
1985        let mut field = Field::new("field_name", DataType::UInt64, true);
1986        field.set_metadata(field_metadata.clone());
1987        let schema = Schema::new_with_metadata(vec![field], schema_metadata.clone());
1988        let schema = Arc::new(schema);
1989
1990        let data: ArrayRef =
1991            Arc::new(vec![3, 2, 1].into_iter().map(Some).collect::<UInt64Array>());
1992
1993        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![data])?;
1994        let input =
1995            TestMemoryExec::try_new_exec(&[vec![batch]], Arc::clone(&schema), None)?;
1996
1997        let sort_exec = Arc::new(SortExec::new(
1998            [PhysicalSortExpr {
1999                expr: col("field_name", &schema)?,
2000                options: SortOptions::default(),
2001            }]
2002            .into(),
2003            input,
2004        ));
2005
2006        let result: Vec<RecordBatch> = collect(sort_exec, task_ctx).await?;
2007
2008        let expected_data: ArrayRef =
2009            Arc::new(vec![1, 2, 3].into_iter().map(Some).collect::<UInt64Array>());
2010        let expected_batch =
2011            RecordBatch::try_new(Arc::clone(&schema), vec![expected_data])?;
2012
2013        // Data is correct
2014        assert_eq!(&vec![expected_batch], &result);
2015
2016        // explicitly ensure the metadata is present
2017        assert_eq!(result[0].schema().fields()[0].metadata(), &field_metadata);
2018        assert_eq!(result[0].schema().metadata(), &schema_metadata);
2019
2020        Ok(())
2021    }
2022
2023    #[tokio::test]
2024    async fn test_lex_sort_by_mixed_types() -> Result<()> {
2025        let task_ctx = Arc::new(TaskContext::default());
2026        let schema = Arc::new(Schema::new(vec![
2027            Field::new("a", DataType::Int32, true),
2028            Field::new(
2029                "b",
2030                DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))),
2031                true,
2032            ),
2033        ]));
2034
2035        // define data.
2036        let batch = RecordBatch::try_new(
2037            Arc::clone(&schema),
2038            vec![
2039                Arc::new(Int32Array::from(vec![Some(2), None, Some(1), Some(2)])),
2040                Arc::new(ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
2041                    Some(vec![Some(3)]),
2042                    Some(vec![Some(1)]),
2043                    Some(vec![Some(6), None]),
2044                    Some(vec![Some(5)]),
2045                ])),
2046            ],
2047        )?;
2048
2049        let sort_exec = Arc::new(SortExec::new(
2050            [
2051                PhysicalSortExpr {
2052                    expr: col("a", &schema)?,
2053                    options: SortOptions {
2054                        descending: false,
2055                        nulls_first: true,
2056                    },
2057                },
2058                PhysicalSortExpr {
2059                    expr: col("b", &schema)?,
2060                    options: SortOptions {
2061                        descending: true,
2062                        nulls_first: false,
2063                    },
2064                },
2065            ]
2066            .into(),
2067            TestMemoryExec::try_new_exec(&[vec![batch]], Arc::clone(&schema), None)?,
2068        ));
2069
2070        assert_eq!(DataType::Int32, *sort_exec.schema().field(0).data_type());
2071        assert_eq!(
2072            DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))),
2073            *sort_exec.schema().field(1).data_type()
2074        );
2075
2076        let result: Vec<RecordBatch> =
2077            collect(Arc::clone(&sort_exec) as Arc<dyn ExecutionPlan>, task_ctx).await?;
2078        let metrics = sort_exec.metrics().unwrap();
2079        assert!(metrics.elapsed_compute().unwrap() > 0);
2080        assert_eq!(metrics.output_rows().unwrap(), 4);
2081        assert_eq!(result.len(), 1);
2082
2083        let expected = RecordBatch::try_new(
2084            schema,
2085            vec![
2086                Arc::new(Int32Array::from(vec![None, Some(1), Some(2), Some(2)])),
2087                Arc::new(ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
2088                    Some(vec![Some(1)]),
2089                    Some(vec![Some(6), None]),
2090                    Some(vec![Some(5)]),
2091                    Some(vec![Some(3)]),
2092                ])),
2093            ],
2094        )?;
2095
2096        assert_eq!(expected, result[0]);
2097
2098        Ok(())
2099    }
2100
2101    #[tokio::test]
2102    async fn test_lex_sort_by_float() -> Result<()> {
2103        let task_ctx = Arc::new(TaskContext::default());
2104        let schema = Arc::new(Schema::new(vec![
2105            Field::new("a", DataType::Float32, true),
2106            Field::new("b", DataType::Float64, true),
2107        ]));
2108
2109        // define data.
2110        let batch = RecordBatch::try_new(
2111            Arc::clone(&schema),
2112            vec![
2113                Arc::new(Float32Array::from(vec![
2114                    Some(f32::NAN),
2115                    None,
2116                    None,
2117                    Some(f32::NAN),
2118                    Some(1.0_f32),
2119                    Some(1.0_f32),
2120                    Some(2.0_f32),
2121                    Some(3.0_f32),
2122                ])),
2123                Arc::new(Float64Array::from(vec![
2124                    Some(200.0_f64),
2125                    Some(20.0_f64),
2126                    Some(10.0_f64),
2127                    Some(100.0_f64),
2128                    Some(f64::NAN),
2129                    None,
2130                    None,
2131                    Some(f64::NAN),
2132                ])),
2133            ],
2134        )?;
2135
2136        let sort_exec = Arc::new(SortExec::new(
2137            [
2138                PhysicalSortExpr {
2139                    expr: col("a", &schema)?,
2140                    options: SortOptions {
2141                        descending: true,
2142                        nulls_first: true,
2143                    },
2144                },
2145                PhysicalSortExpr {
2146                    expr: col("b", &schema)?,
2147                    options: SortOptions {
2148                        descending: false,
2149                        nulls_first: false,
2150                    },
2151                },
2152            ]
2153            .into(),
2154            TestMemoryExec::try_new_exec(&[vec![batch]], schema, None)?,
2155        ));
2156
2157        assert_eq!(DataType::Float32, *sort_exec.schema().field(0).data_type());
2158        assert_eq!(DataType::Float64, *sort_exec.schema().field(1).data_type());
2159
2160        let result: Vec<RecordBatch> =
2161            collect(Arc::clone(&sort_exec) as Arc<dyn ExecutionPlan>, task_ctx).await?;
2162        let metrics = sort_exec.metrics().unwrap();
2163        assert!(metrics.elapsed_compute().unwrap() > 0);
2164        assert_eq!(metrics.output_rows().unwrap(), 8);
2165        assert_eq!(result.len(), 1);
2166
2167        let columns = result[0].columns();
2168
2169        assert_eq!(DataType::Float32, *columns[0].data_type());
2170        assert_eq!(DataType::Float64, *columns[1].data_type());
2171
2172        let a = as_primitive_array::<Float32Type>(&columns[0])?;
2173        let b = as_primitive_array::<Float64Type>(&columns[1])?;
2174
2175        // convert result to strings to allow comparing to expected result containing NaN
2176        let result: Vec<(Option<String>, Option<String>)> = (0..result[0].num_rows())
2177            .map(|i| {
2178                let aval = if a.is_valid(i) {
2179                    Some(a.value(i).to_string())
2180                } else {
2181                    None
2182                };
2183                let bval = if b.is_valid(i) {
2184                    Some(b.value(i).to_string())
2185                } else {
2186                    None
2187                };
2188                (aval, bval)
2189            })
2190            .collect();
2191
2192        let expected: Vec<(Option<String>, Option<String>)> = vec![
2193            (None, Some("10".to_owned())),
2194            (None, Some("20".to_owned())),
2195            (Some("NaN".to_owned()), Some("100".to_owned())),
2196            (Some("NaN".to_owned()), Some("200".to_owned())),
2197            (Some("3".to_owned()), Some("NaN".to_owned())),
2198            (Some("2".to_owned()), None),
2199            (Some("1".to_owned()), Some("NaN".to_owned())),
2200            (Some("1".to_owned()), None),
2201        ];
2202
2203        assert_eq!(expected, result);
2204
2205        Ok(())
2206    }
2207
2208    #[tokio::test]
2209    async fn test_drop_cancel() -> Result<()> {
2210        let task_ctx = Arc::new(TaskContext::default());
2211        let schema =
2212            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));
2213
2214        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
2215        let refs = blocking_exec.refs();
2216        let sort_exec = Arc::new(SortExec::new(
2217            [PhysicalSortExpr {
2218                expr: col("a", &schema)?,
2219                options: SortOptions::default(),
2220            }]
2221            .into(),
2222            blocking_exec,
2223        ));
2224
2225        let fut = collect(sort_exec, Arc::clone(&task_ctx));
2226        let mut fut = fut.boxed();
2227
2228        assert_is_pending(&mut fut);
2229        drop(fut);
2230        assert_strong_count_converges_to_zero(refs).await;
2231
2232        assert_eq!(
2233            task_ctx.runtime_env().memory_pool.reserved(),
2234            0,
2235            "The sort should have returned all memory used back to the memory manager"
2236        );
2237
2238        Ok(())
2239    }
2240
/// `sort_batch` on a batch with no columns but a row count of 1, ordered by
/// a literal expression, must keep the row count.
#[test]
fn test_empty_sort_batch() {
    // A column-less batch needs its row count supplied through the options.
    let options = RecordBatchOptions::new().with_row_count(Some(1));
    let empty_schema = Arc::new(Schema::empty());
    let batch =
        RecordBatch::try_new_with_options(Arc::clone(&empty_schema), vec![], &options)
            .unwrap();

    let literal_sort = PhysicalSortExpr {
        expr: Arc::new(Literal::new(ScalarValue::Int64(Some(1)))),
        options: SortOptions::default(),
    };
    let expressions = [literal_sort].into();

    let sorted = sort_batch(&batch, &expressions, None).unwrap();
    assert_eq!(sorted.num_rows(), 1);
}
2258
2259    #[tokio::test]
2260    async fn topk_unbounded_source() -> Result<()> {
2261        let task_ctx = Arc::new(TaskContext::default());
2262        let schema = Schema::new(vec![Field::new("c1", DataType::UInt64, false)]);
2263        let source = SortedUnboundedExec {
2264            schema: schema.clone(),
2265            batch_size: 2,
2266            cache: Arc::new(SortedUnboundedExec::compute_properties(Arc::new(
2267                schema.clone(),
2268            ))),
2269        };
2270        let mut plan = SortExec::new(
2271            [PhysicalSortExpr::new_default(Arc::new(Column::new(
2272                "c1", 0,
2273            )))]
2274            .into(),
2275            Arc::new(source),
2276        );
2277        plan = plan.with_fetch(Some(9));
2278
2279        let batches = collect(Arc::new(plan), task_ctx).await?;
2280        assert_snapshot!(batches_to_string(&batches), @r"
2281        +----+
2282        | c1 |
2283        +----+
2284        | 0  |
2285        | 1  |
2286        | 2  |
2287        | 3  |
2288        | 4  |
2289        | 5  |
2290        | 6  |
2291        | 7  |
2292        | 8  |
2293        +----+
2294        ");
2295        Ok(())
2296    }
2297
2298    #[tokio::test]
2299    async fn should_return_stream_with_batches_in_the_requested_size() -> Result<()> {
2300        let batch_size = 100;
2301
2302        let create_task_ctx = |_: &[RecordBatch]| {
2303            TaskContext::default().with_session_config(
2304                SessionConfig::new()
2305                    .with_batch_size(batch_size)
2306                    .with_sort_in_place_threshold_bytes(usize::MAX),
2307            )
2308        };
2309
2310        // Smaller than batch size and require more than a single batch to get the requested batch size
2311        test_sort_output_batch_size(10, batch_size / 4, create_task_ctx).await?;
2312
2313        // Not evenly divisible by batch size
2314        test_sort_output_batch_size(10, batch_size + 7, create_task_ctx).await?;
2315
2316        // Evenly divisible by batch size and is larger than 2 output batches
2317        test_sort_output_batch_size(10, batch_size * 3, create_task_ctx).await?;
2318
2319        Ok(())
2320    }
2321
2322    #[tokio::test]
2323    async fn should_return_stream_with_batches_in_the_requested_size_when_sorting_in_place()
2324    -> Result<()> {
2325        let batch_size = 100;
2326
2327        let create_task_ctx = |_: &[RecordBatch]| {
2328            TaskContext::default().with_session_config(
2329                SessionConfig::new()
2330                    .with_batch_size(batch_size)
2331                    .with_sort_in_place_threshold_bytes(usize::MAX - 1),
2332            )
2333        };
2334
2335        // Smaller than batch size and require more than a single batch to get the requested batch size
2336        {
2337            let metrics =
2338                test_sort_output_batch_size(10, batch_size / 4, create_task_ctx).await?;
2339
2340            assert_eq!(
2341                metrics.spill_count(),
2342                Some(0),
2343                "Expected no spills when sorting in place"
2344            );
2345        }
2346
2347        // Not evenly divisible by batch size
2348        {
2349            let metrics =
2350                test_sort_output_batch_size(10, batch_size + 7, create_task_ctx).await?;
2351
2352            assert_eq!(
2353                metrics.spill_count(),
2354                Some(0),
2355                "Expected no spills when sorting in place"
2356            );
2357        }
2358
2359        // Evenly divisible by batch size and is larger than 2 output batches
2360        {
2361            let metrics =
2362                test_sort_output_batch_size(10, batch_size * 3, create_task_ctx).await?;
2363
2364            assert_eq!(
2365                metrics.spill_count(),
2366                Some(0),
2367                "Expected no spills when sorting in place"
2368            );
2369        }
2370
2371        Ok(())
2372    }
2373
2374    #[tokio::test]
2375    async fn should_return_stream_with_batches_in_the_requested_size_when_having_a_single_batch()
2376    -> Result<()> {
2377        let batch_size = 100;
2378
2379        let create_task_ctx = |_: &[RecordBatch]| {
2380            TaskContext::default()
2381                .with_session_config(SessionConfig::new().with_batch_size(batch_size))
2382        };
2383
2384        // Smaller than batch size and require more than a single batch to get the requested batch size
2385        {
2386            let metrics = test_sort_output_batch_size(
2387                // Single batch
2388                1,
2389                batch_size / 4,
2390                create_task_ctx,
2391            )
2392            .await?;
2393
2394            assert_eq!(
2395                metrics.spill_count(),
2396                Some(0),
2397                "Expected no spills when sorting in place"
2398            );
2399        }
2400
2401        // Not evenly divisible by batch size
2402        {
2403            let metrics = test_sort_output_batch_size(
2404                // Single batch
2405                1,
2406                batch_size + 7,
2407                create_task_ctx,
2408            )
2409            .await?;
2410
2411            assert_eq!(
2412                metrics.spill_count(),
2413                Some(0),
2414                "Expected no spills when sorting in place"
2415            );
2416        }
2417
2418        // Evenly divisible by batch size and is larger than 2 output batches
2419        {
2420            let metrics = test_sort_output_batch_size(
2421                // Single batch
2422                1,
2423                batch_size * 3,
2424                create_task_ctx,
2425            )
2426            .await?;
2427
2428            assert_eq!(
2429                metrics.spill_count(),
2430                Some(0),
2431                "Expected no spills when sorting in place"
2432            );
2433        }
2434
2435        Ok(())
2436    }
2437
2438    #[tokio::test]
2439    async fn should_return_stream_with_batches_in_the_requested_size_when_having_to_spill()
2440    -> Result<()> {
2441        let batch_size = 100;
2442
2443        let create_task_ctx = |generated_batches: &[RecordBatch]| {
2444            let batches_memory = generated_batches
2445                .iter()
2446                .map(|b| b.get_array_memory_size())
2447                .sum::<usize>();
2448
2449            TaskContext::default()
2450                .with_session_config(
2451                    SessionConfig::new()
2452                        .with_batch_size(batch_size)
2453                        // To make sure there is no in place sorting
2454                        .with_sort_in_place_threshold_bytes(1)
2455                        .with_sort_spill_reservation_bytes(1),
2456                )
2457                .with_runtime(
2458                    RuntimeEnvBuilder::default()
2459                        .with_memory_limit(batches_memory, 1.0)
2460                        .build_arc()
2461                        .unwrap(),
2462                )
2463        };
2464
2465        // Smaller than batch size and require more than a single batch to get the requested batch size
2466        {
2467            let metrics =
2468                test_sort_output_batch_size(10, batch_size / 4, create_task_ctx).await?;
2469
2470            assert_ne!(metrics.spill_count().unwrap(), 0, "expected to spill");
2471        }
2472
2473        // Not evenly divisible by batch size
2474        {
2475            let metrics =
2476                test_sort_output_batch_size(10, batch_size + 7, create_task_ctx).await?;
2477
2478            assert_ne!(metrics.spill_count().unwrap(), 0, "expected to spill");
2479        }
2480
2481        // Evenly divisible by batch size and is larger than 2 batches
2482        {
2483            let metrics =
2484                test_sort_output_batch_size(10, batch_size * 3, create_task_ctx).await?;
2485
2486            assert_ne!(metrics.spill_count().unwrap(), 0, "expected to spill");
2487        }
2488
2489        Ok(())
2490    }
2491
2492    async fn test_sort_output_batch_size(
2493        number_of_batches: usize,
2494        batch_size_to_generate: usize,
2495        create_task_ctx: impl Fn(&[RecordBatch]) -> TaskContext,
2496    ) -> Result<MetricsSet> {
2497        let batches = (0..number_of_batches)
2498            .map(|_| make_partition(batch_size_to_generate as i32))
2499            .collect::<Vec<_>>();
2500        let task_ctx = create_task_ctx(batches.as_slice());
2501
2502        let expected_batch_size = task_ctx.session_config().batch_size();
2503
2504        let (mut output_batches, metrics) =
2505            run_sort_on_input(task_ctx, "i", batches).await?;
2506
2507        let last_batch = output_batches.pop().unwrap();
2508
2509        for batch in output_batches {
2510            assert_eq!(batch.num_rows(), expected_batch_size);
2511        }
2512
2513        let mut last_expected_batch_size =
2514            (batch_size_to_generate * number_of_batches) % expected_batch_size;
2515        if last_expected_batch_size == 0 {
2516            last_expected_batch_size = expected_batch_size;
2517        }
2518        assert_eq!(last_batch.num_rows(), last_expected_batch_size);
2519
2520        Ok(metrics)
2521    }
2522
2523    async fn run_sort_on_input(
2524        task_ctx: TaskContext,
2525        order_by_col: &str,
2526        batches: Vec<RecordBatch>,
2527    ) -> Result<(Vec<RecordBatch>, MetricsSet)> {
2528        let task_ctx = Arc::new(task_ctx);
2529
2530        // let task_ctx = env.
2531        let schema = batches[0].schema();
2532        let ordering: LexOrdering = [PhysicalSortExpr {
2533            expr: col(order_by_col, &schema)?,
2534            options: SortOptions {
2535                descending: false,
2536                nulls_first: true,
2537            },
2538        }]
2539        .into();
2540        let sort_exec: Arc<dyn ExecutionPlan> = Arc::new(SortExec::new(
2541            ordering.clone(),
2542            TestMemoryExec::try_new_exec(std::slice::from_ref(&batches), schema, None)?,
2543        ));
2544
2545        let sorted_batches =
2546            collect(Arc::clone(&sort_exec), Arc::clone(&task_ctx)).await?;
2547
2548        let metrics = sort_exec.metrics().expect("sort have metrics");
2549
2550        // assert output
2551        {
2552            let input_batches_concat = concat_batches(batches[0].schema_ref(), &batches)?;
2553            let sorted_input_batch = sort_batch(&input_batches_concat, &ordering, None)?;
2554
2555            let sorted_batches_concat =
2556                concat_batches(sorted_batches[0].schema_ref(), &sorted_batches)?;
2557
2558            assert_eq!(sorted_input_batch, sorted_batches_concat);
2559        }
2560
2561        Ok((sorted_batches, metrics))
2562    }
2563
2564    #[tokio::test]
2565    async fn test_sort_batch_chunked_basic() -> Result<()> {
2566        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
2567
2568        // Create a batch with 1000 rows
2569        let mut values: Vec<i32> = (0..1000).collect();
2570        // Shuffle to make it unsorted
2571        values.reverse();
2572
2573        let batch = RecordBatch::try_new(
2574            Arc::clone(&schema),
2575            vec![Arc::new(Int32Array::from(values))],
2576        )?;
2577
2578        let expressions: LexOrdering =
2579            [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into();
2580
2581        // Sort with batch_size = 250
2582        let result_batches = sort_batch_chunked(&batch, &expressions, 250)?;
2583
2584        // Verify 4 batches are returned
2585        assert_eq!(result_batches.len(), 4);
2586
2587        // Verify each batch has <= 250 rows
2588        let mut total_rows = 0;
2589        for (i, batch) in result_batches.iter().enumerate() {
2590            assert!(
2591                batch.num_rows() <= 250,
2592                "Batch {} has {} rows, expected <= 250",
2593                i,
2594                batch.num_rows()
2595            );
2596            total_rows += batch.num_rows();
2597        }
2598
2599        // Verify total row count matches input
2600        assert_eq!(total_rows, 1000);
2601
2602        // Verify data is correctly sorted across all chunks
2603        let concatenated = concat_batches(&schema, &result_batches)?;
2604        let array = as_primitive_array::<Int32Type>(concatenated.column(0))?;
2605        for i in 0..array.len() - 1 {
2606            assert!(
2607                array.value(i) <= array.value(i + 1),
2608                "Array not sorted at position {}: {} > {}",
2609                i,
2610                array.value(i),
2611                array.value(i + 1)
2612            );
2613        }
2614        assert_eq!(array.value(0), 0);
2615        assert_eq!(array.value(array.len() - 1), 999);
2616
2617        Ok(())
2618    }
2619
2620    #[tokio::test]
2621    async fn test_sort_batch_chunked_smaller_than_batch_size() -> Result<()> {
2622        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
2623
2624        // Create a batch with 50 rows
2625        let values: Vec<i32> = (0..50).rev().collect();
2626        let batch = RecordBatch::try_new(
2627            Arc::clone(&schema),
2628            vec![Arc::new(Int32Array::from(values))],
2629        )?;
2630
2631        let expressions: LexOrdering =
2632            [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into();
2633
2634        // Sort with batch_size = 100
2635        let result_batches = sort_batch_chunked(&batch, &expressions, 100)?;
2636
2637        // Should return exactly 1 batch
2638        assert_eq!(result_batches.len(), 1);
2639        assert_eq!(result_batches[0].num_rows(), 50);
2640
2641        // Verify it's correctly sorted
2642        let array = as_primitive_array::<Int32Type>(result_batches[0].column(0))?;
2643        for i in 0..array.len() - 1 {
2644            assert!(array.value(i) <= array.value(i + 1));
2645        }
2646        assert_eq!(array.value(0), 0);
2647        assert_eq!(array.value(49), 49);
2648
2649        Ok(())
2650    }
2651
2652    #[tokio::test]
2653    async fn test_sort_batch_chunked_exact_multiple() -> Result<()> {
2654        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
2655
2656        // Create a batch with 1000 rows
2657        let values: Vec<i32> = (0..1000).rev().collect();
2658        let batch = RecordBatch::try_new(
2659            Arc::clone(&schema),
2660            vec![Arc::new(Int32Array::from(values))],
2661        )?;
2662
2663        let expressions: LexOrdering =
2664            [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into();
2665
2666        // Sort with batch_size = 100
2667        let result_batches = sort_batch_chunked(&batch, &expressions, 100)?;
2668
2669        // Should return exactly 10 batches of 100 rows each
2670        assert_eq!(result_batches.len(), 10);
2671        for batch in &result_batches {
2672            assert_eq!(batch.num_rows(), 100);
2673        }
2674
2675        // Verify sorted correctly across all batches
2676        let concatenated = concat_batches(&schema, &result_batches)?;
2677        let array = as_primitive_array::<Int32Type>(concatenated.column(0))?;
2678        for i in 0..array.len() - 1 {
2679            assert!(array.value(i) <= array.value(i + 1));
2680        }
2681
2682        Ok(())
2683    }
2684
2685    #[tokio::test]
2686    async fn test_sort_batch_chunked_empty_batch() -> Result<()> {
2687        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
2688
2689        let batch = RecordBatch::new_empty(Arc::clone(&schema));
2690
2691        let expressions: LexOrdering =
2692            [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into();
2693
2694        let result_batches = sort_batch_chunked(&batch, &expressions, 100)?;
2695
2696        // Empty input produces no output batches (0 chunks)
2697        assert_eq!(result_batches.len(), 0);
2698
2699        Ok(())
2700    }
2701
2702    #[tokio::test]
2703    async fn test_get_reserved_bytes_for_record_batch_with_sliced_batches() -> Result<()>
2704    {
2705        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
2706
2707        // Create a larger batch then slice it
2708        let large_array = Int32Array::from((0..1000).collect::<Vec<i32>>());
2709        let sliced_array = large_array.slice(100, 50); // Take 50 elements starting at 100
2710
2711        let sliced_batch =
2712            RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(sliced_array)])?;
2713        let batch =
2714            RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(large_array)])?;
2715
2716        let sliced_reserved = get_reserved_bytes_for_record_batch(&sliced_batch)?;
2717        let reserved = get_reserved_bytes_for_record_batch(&batch)?;
2718
2719        // The reserved memory for the sliced batch should be less than that of the full batch
2720        assert!(reserved > sliced_reserved);
2721
2722        Ok(())
2723    }
2724
2725    #[test]
2726    fn test_with_dynamic_filter() -> Result<()> {
2727        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
2728        let child = Arc::new(EmptyExec::new(Arc::clone(&schema)));
2729
2730        let sort = SortExec::new(
2731            LexOrdering::new(vec![PhysicalSortExpr {
2732                expr: Arc::new(Column::new("a", 0)),
2733                options: SortOptions::default(),
2734            }])
2735            .unwrap(),
2736            child,
2737        )
2738        .with_fetch(Some(10));
2739
2740        // SortExec with fetch creates a dynamic filter automatically.
2741        let original_id = sort
2742            .dynamic_filter_expr()
2743            .expect("should have dynamic filter with fetch")
2744            .expression_id()
2745            .expect("DynamicFilterPhysicalExpr always has an expression_id");
2746
2747        // with_dynamic_filter replaces it with a new TopKDynamicFilters.
2748        let new_df = Arc::new(DynamicFilterPhysicalExpr::new(
2749            vec![Arc::new(Column::new("a", 0)) as _],
2750            lit(true),
2751        ));
2752        let new_id = new_df
2753            .expression_id()
2754            .expect("DynamicFilterPhysicalExpr always has an expression_id");
2755        let sort = sort.with_dynamic_filter_expr(Arc::clone(&new_df))?;
2756        let restored_id = sort
2757            .dynamic_filter_expr()
2758            .expect("should still have dynamic filter")
2759            .expression_id()
2760            .expect("DynamicFilterPhysicalExpr always has an expression_id");
2761        assert_eq!(restored_id, new_id);
2762        assert_ne!(restored_id, original_id);
2763        Ok(())
2764    }
2765
2766    #[test]
2767    fn test_with_dynamic_filter_rejects_invalid_columns() -> Result<()> {
2768        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
2769        let child = Arc::new(EmptyExec::new(Arc::clone(&schema)));
2770
2771        let sort = SortExec::new(
2772            LexOrdering::new(vec![PhysicalSortExpr {
2773                expr: Arc::new(Column::new("a", 0)),
2774                options: SortOptions::default(),
2775            }])
2776            .unwrap(),
2777            child,
2778        )
2779        .with_fetch(Some(10));
2780
2781        // Column index 99 is out of bounds for the input schema.
2782        let df = Arc::new(DynamicFilterPhysicalExpr::new(
2783            vec![Arc::new(Column::new("bad", 99)) as _],
2784            lit(true),
2785        ));
2786        assert!(sort.with_dynamic_filter_expr(df).is_err());
2787        Ok(())
2788    }
2789
2790    /// Verifies that `ExternalSorter::sort()` transfers the pre-reserved
2791    /// merge bytes to the merge stream via `take()`, rather than leaving
2792    /// them in the sorter (via `new_empty()`).
2793    ///
2794    /// 1. Create a sorter with a tight memory pool and insert enough data
2795    ///    to force spilling
2796    /// 2. Verify `merge_reservation` holds the pre-reserved bytes before sort
2797    /// 3. Call `sort()` to get the merge stream
2798    /// 4. Verify `merge_reservation` is now 0 (bytes transferred to merge stream)
2799    /// 5. Simulate contention: a competing consumer grabs all available pool memory
2800    /// 6. Verify the merge stream still works (it uses its pre-reserved bytes
2801    ///    as initial budget, not requesting from pool starting at 0)
2802    ///
2803    /// With `new_empty()` (before fix), step 4 fails: `merge_reservation`
2804    /// still holds the bytes, the merge stream starts with 0 budget, and
2805    /// those bytes become unaccounted-for reserved memory that nobody uses.
2806    #[tokio::test]
2807    async fn test_sort_merge_reservation_transferred_not_freed() -> Result<()> {
2808        let sort_spill_reservation_bytes: usize = 10 * 1024; // 10 KB
2809
2810        // Pool: merge reservation (10KB) + enough room for sort to work.
2811        // The room must accommodate batch data accumulation before spilling.
2812        let sort_working_memory: usize = 40 * 1024; // 40 KB for sort operations
2813        let pool_size = sort_spill_reservation_bytes + sort_working_memory;
2814        let pool: Arc<dyn MemoryPool> = Arc::new(GreedyMemoryPool::new(pool_size));
2815
2816        let runtime = RuntimeEnvBuilder::new()
2817            .with_memory_pool(Arc::clone(&pool))
2818            .build_arc()?;
2819
2820        let metrics_set = ExecutionPlanMetricsSet::new();
2821        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)]));
2822
2823        let mut sorter = ExternalSorter::new(
2824            0,
2825            Arc::clone(&schema),
2826            [PhysicalSortExpr::new_default(Arc::new(Column::new("x", 0)))].into(),
2827            128, // batch_size
2828            sort_spill_reservation_bytes,
2829            usize::MAX, // sort_in_place_threshold_bytes (high to avoid concat path)
2830            SpillCompression::Uncompressed,
2831            &metrics_set,
2832            Arc::clone(&runtime),
2833        )?;
2834
2835        // Insert enough data to force spilling.
2836        let num_batches = 200;
2837        for i in 0..num_batches {
2838            let values: Vec<i32> = ((i * 100)..((i + 1) * 100)).rev().collect();
2839            let batch = RecordBatch::try_new(
2840                Arc::clone(&schema),
2841                vec![Arc::new(Int32Array::from(values))],
2842            )?;
2843            sorter.insert_batch(batch).await?;
2844        }
2845
2846        assert!(
2847            sorter.spilled_before(),
2848            "Test requires spilling to exercise the merge path"
2849        );
2850
2851        // Before sort(), merge_reservation holds sort_spill_reservation_bytes.
2852        assert!(
2853            sorter.merge_reservation_size() >= sort_spill_reservation_bytes,
2854            "merge_reservation should hold the pre-reserved bytes before sort()"
2855        );
2856
2857        // Call sort() to get the merge stream. With the fix (take()),
2858        // the pre-reserved merge bytes are transferred to the merge
2859        // stream. Without the fix (free() + new_empty()), the bytes
2860        // are released back to the pool and the merge stream starts
2861        // with 0 bytes.
2862        let merge_stream = sorter.sort().await?;
2863
2864        // THE KEY ASSERTION: after sort(), merge_reservation must be 0.
2865        // This proves take() transferred the bytes to the merge stream,
2866        // rather than them being freed back to the pool where other
2867        // partitions could steal them.
2868        assert_eq!(
2869            sorter.merge_reservation_size(),
2870            0,
2871            "After sort(), merge_reservation should be 0 (bytes transferred \
2872             to merge stream via take()). If non-zero, the bytes are still \
2873             held by the sorter and will be freed on drop, allowing other \
2874             partitions to steal them."
2875        );
2876
2877        // Drop the sorter to free its reservations back to the pool.
2878        drop(sorter);
2879
2880        // Simulate contention: another partition grabs ALL available
2881        // pool memory. If the merge stream didn't receive the
2882        // pre-reserved bytes via take(), it will fail when it tries
2883        // to allocate memory for reading spill files.
2884        let contender = MemoryConsumer::new("CompetingPartition").register(&pool);
2885        let available = pool_size.saturating_sub(pool.reserved());
2886        if available > 0 {
2887            contender.try_grow(available).unwrap();
2888        }
2889
2890        // The merge stream must still produce correct results despite
2891        // the pool being fully consumed by the contender. This only
2892        // works if sort() transferred the pre-reserved bytes to the
2893        // merge stream (via take()) rather than freeing them.
2894        let batches: Vec<RecordBatch> = merge_stream.try_collect().await?;
2895        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
2896        assert_eq!(
2897            total_rows,
2898            (num_batches * 100) as usize,
2899            "Merge stream should produce all rows even under memory contention"
2900        );
2901
2902        // Verify data is sorted
2903        let merged = concat_batches(&schema, &batches)?;
2904        let col = merged.column(0).as_primitive::<Int32Type>();
2905        for i in 1..col.len() {
2906            assert!(
2907                col.value(i - 1) <= col.value(i),
2908                "Output should be sorted, but found {} > {} at index {}",
2909                col.value(i - 1),
2910                col.value(i),
2911                i
2912            );
2913        }
2914
2915        drop(contender);
2916        Ok(())
2917    }
2918
2919    fn make_sort_exec_with_fetch(fetch: Option<usize>) -> SortExec {
2920        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
2921        let input = Arc::new(EmptyExec::new(schema));
2922        SortExec::new(
2923            [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into(),
2924            input,
2925        )
2926        .with_fetch(fetch)
2927    }
2928
2929    #[test]
2930    fn test_sort_with_fetch_blocks_filter_pushdown() -> Result<()> {
2931        let sort = make_sort_exec_with_fetch(Some(10));
2932        let desc = sort.gather_filters_for_pushdown(
2933            FilterPushdownPhase::Pre,
2934            vec![Arc::new(Column::new("a", 0))],
2935            &ConfigOptions::new(),
2936        )?;
2937        // Sort with fetch (TopK) must not allow filters to be pushed below it.
2938        assert!(matches!(
2939            desc.parent_filters()[0][0].discriminant,
2940            PushedDown::No
2941        ));
2942        Ok(())
2943    }
2944
2945    #[test]
2946    fn test_sort_without_fetch_allows_filter_pushdown() -> Result<()> {
2947        let sort = make_sort_exec_with_fetch(None);
2948        let desc = sort.gather_filters_for_pushdown(
2949            FilterPushdownPhase::Pre,
2950            vec![Arc::new(Column::new("a", 0))],
2951            &ConfigOptions::new(),
2952        )?;
2953        // Plain sort (no fetch) is filter-commutative.
2954        assert!(matches!(
2955            desc.parent_filters()[0][0].discriminant,
2956            PushedDown::Yes
2957        ));
2958        Ok(())
2959    }
2960
2961    #[test]
2962    fn test_sort_with_fetch_allows_topk_self_filter_in_post_phase() -> Result<()> {
2963        let sort = make_sort_exec_with_fetch(Some(10));
2964        assert!(sort.filter.is_some(), "TopK filter should be created");
2965
2966        let mut config = ConfigOptions::new();
2967        config.optimizer.enable_topk_dynamic_filter_pushdown = true;
2968        let desc = sort.gather_filters_for_pushdown(
2969            FilterPushdownPhase::Post,
2970            vec![Arc::new(Column::new("a", 0))],
2971            &config,
2972        )?;
2973        // Parent filters are still blocked in the Post phase.
2974        assert!(matches!(
2975            desc.parent_filters()[0][0].discriminant,
2976            PushedDown::No
2977        ));
2978        // But the TopK self-filter should be pushed down.
2979        assert_eq!(desc.self_filters()[0].len(), 1);
2980        Ok(())
2981    }
2982}