datafusion_physical_plan/sorts/
partial_sort.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Partial Sort deals with input data that partially
19//! satisfies the required sort order. Such input data can be
20//! partitioned into segments where each segment already has the
21//! required information for lexicographic sorting so sorting
22//! can be done without loading the entire dataset.
23//!
24//! Consider a sort plan having an input with ordering `a ASC, b ASC`
25//!
26//! ```text
27//! +---+---+---+
28//! | a | b | d |
29//! +---+---+---+
30//! | 0 | 0 | 3 |
31//! | 0 | 0 | 2 |
32//! | 0 | 1 | 1 |
33//! | 0 | 2 | 0 |
34//! +---+---+---+
35//! ```
36//!
37//! and required ordering for the plan is `a ASC, b ASC, d ASC`.
38//! The first 3 rows (one segment) can be sorted as the segment already
39//! has the required information for the sort, but the last row
40//! requires further information as the input can continue with a
41//! batch whose starting row has the same `a` and `b` values, as below
42//!
43//! ```text
44//! +---+---+---+
45//! | a | b | d |
46//! +---+---+---+
47//! | 0 | 2 | 4 |
48//! +---+---+---+
49//! ```
50//!
51//! The plan concatenates incoming data with such trailing rows of the
52//! previous input and continues partial sorting of the segments.
53
54use std::any::Any;
55use std::fmt::Debug;
56use std::pin::Pin;
57use std::sync::Arc;
58use std::task::{Context, Poll};
59
60use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
61use crate::sorts::sort::sort_batch;
62use crate::{
63    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
64    Partitioning, PlanProperties, SendableRecordBatchStream, Statistics,
65};
66
67use arrow::compute::concat_batches;
68use arrow::datatypes::SchemaRef;
69use arrow::record_batch::RecordBatch;
70use datafusion_common::utils::evaluate_partition_ranges;
71use datafusion_common::Result;
72use datafusion_execution::{RecordBatchStream, TaskContext};
73use datafusion_physical_expr::LexOrdering;
74
75use futures::{ready, Stream, StreamExt};
76use log::trace;
77
78/// Partial Sort execution plan.
///
/// Sorts its input by `expr`, relying on the input already being ordered on
/// the first `common_prefix_length` sort expressions so that sorting can be
/// done segment by segment instead of over the whole input.
79#[derive(Debug, Clone)]
80pub struct PartialSortExec {
81    /// Input plan whose output is sorted
82    pub(crate) input: Arc<dyn ExecutionPlan>,
83    /// Sort expressions
84    expr: LexOrdering,
85    /// Length of continuous matching columns of input that satisfy
86    /// the required ordering for the sort
87    common_prefix_length: usize,
88    /// Containing all metrics set created during sort
89    metrics_set: ExecutionPlanMetricsSet,
90    /// Preserve partitions of input plan. If false, the input partitions
91    /// will be sorted and merged into a single output partition.
92    preserve_partitioning: bool,
93    /// Fetch highest/lowest n results
94    fetch: Option<usize>,
95    /// Cache holding plan properties like equivalences, output partitioning etc.
96    cache: PlanProperties,
97}
98
99impl PartialSortExec {
100    /// Create a new partial sort execution plan
101    pub fn new(
102        expr: LexOrdering,
103        input: Arc<dyn ExecutionPlan>,
104        common_prefix_length: usize,
105    ) -> Self {
106        debug_assert!(common_prefix_length > 0);
107        let preserve_partitioning = false;
108        let cache = Self::compute_properties(&input, expr.clone(), preserve_partitioning)
109            .unwrap();
110        Self {
111            input,
112            expr,
113            common_prefix_length,
114            metrics_set: ExecutionPlanMetricsSet::new(),
115            preserve_partitioning,
116            fetch: None,
117            cache,
118        }
119    }
120
121    /// Whether this `PartialSortExec` preserves partitioning of the children
122    pub fn preserve_partitioning(&self) -> bool {
123        self.preserve_partitioning
124    }
125
126    /// Specify the partitioning behavior of this partial sort exec
127    ///
128    /// If `preserve_partitioning` is true, sorts each partition
129    /// individually, producing one sorted stream for each input partition.
130    ///
131    /// If `preserve_partitioning` is false, sorts and merges all
132    /// input partitions producing a single, sorted partition.
133    pub fn with_preserve_partitioning(mut self, preserve_partitioning: bool) -> Self {
134        self.preserve_partitioning = preserve_partitioning;
135        self.cache = self
136            .cache
137            .with_partitioning(Self::output_partitioning_helper(
138                &self.input,
139                self.preserve_partitioning,
140            ));
141        self
142    }
143
144    /// Modify how many rows to include in the result
145    ///
146    /// If None, then all rows will be returned, in sorted order.
147    /// If Some, then only the top `fetch` rows will be returned.
148    /// This can reduce the memory pressure required by the sort
149    /// operation since rows that are not going to be included
150    /// can be dropped.
151    pub fn with_fetch(mut self, fetch: Option<usize>) -> Self {
152        self.fetch = fetch;
153        self
154    }
155
156    /// Input schema
157    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
158        &self.input
159    }
160
161    /// Sort expressions
162    pub fn expr(&self) -> &LexOrdering {
163        &self.expr
164    }
165
166    /// If `Some(fetch)`, limits output to only the first "fetch" items
167    pub fn fetch(&self) -> Option<usize> {
168        self.fetch
169    }
170
171    /// Common prefix length
172    pub fn common_prefix_length(&self) -> usize {
173        self.common_prefix_length
174    }
175
176    fn output_partitioning_helper(
177        input: &Arc<dyn ExecutionPlan>,
178        preserve_partitioning: bool,
179    ) -> Partitioning {
180        // Get output partitioning:
181        if preserve_partitioning {
182            input.output_partitioning().clone()
183        } else {
184            Partitioning::UnknownPartitioning(1)
185        }
186    }
187
188    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
189    fn compute_properties(
190        input: &Arc<dyn ExecutionPlan>,
191        sort_exprs: LexOrdering,
192        preserve_partitioning: bool,
193    ) -> Result<PlanProperties> {
194        // Calculate equivalence properties; i.e. reset the ordering equivalence
195        // class with the new ordering:
196        let mut eq_properties = input.equivalence_properties().clone();
197        eq_properties.reorder(sort_exprs)?;
198
199        // Get output partitioning:
200        let output_partitioning =
201            Self::output_partitioning_helper(input, preserve_partitioning);
202
203        Ok(PlanProperties::new(
204            eq_properties,
205            output_partitioning,
206            input.pipeline_behavior(),
207            input.boundedness(),
208        ))
209    }
210}
211
212impl DisplayAs for PartialSortExec {
213    fn fmt_as(
214        &self,
215        t: DisplayFormatType,
216        f: &mut std::fmt::Formatter,
217    ) -> std::fmt::Result {
218        match t {
219            DisplayFormatType::Default | DisplayFormatType::Verbose => {
220                let common_prefix_length = self.common_prefix_length;
221                match self.fetch {
222                    Some(fetch) => {
223                        write!(f, "PartialSortExec: TopK(fetch={fetch}), expr=[{}], common_prefix_length=[{common_prefix_length}]", self.expr)
224                    }
225                    None => write!(f, "PartialSortExec: expr=[{}], common_prefix_length=[{common_prefix_length}]", self.expr),
226                }
227            }
228            DisplayFormatType::TreeRender => match self.fetch {
229                Some(fetch) => {
230                    writeln!(f, "{}", self.expr)?;
231                    writeln!(f, "limit={fetch}")
232                }
233                None => {
234                    writeln!(f, "{}", self.expr)
235                }
236            },
237        }
238    }
239}
240
241impl ExecutionPlan for PartialSortExec {
242    fn name(&self) -> &'static str {
243        "PartialSortExec"
244    }
245
246    fn as_any(&self) -> &dyn Any {
247        self
248    }
249
250    fn properties(&self) -> &PlanProperties {
251        &self.cache
252    }
253
254    fn fetch(&self) -> Option<usize> {
255        self.fetch
256    }
257
258    fn required_input_distribution(&self) -> Vec<Distribution> {
259        if self.preserve_partitioning {
260            vec![Distribution::UnspecifiedDistribution]
261        } else {
262            vec![Distribution::SinglePartition]
263        }
264    }
265
266    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
267        vec![false]
268    }
269
270    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
271        vec![&self.input]
272    }
273
274    fn with_new_children(
275        self: Arc<Self>,
276        children: Vec<Arc<dyn ExecutionPlan>>,
277    ) -> Result<Arc<dyn ExecutionPlan>> {
278        let new_partial_sort = PartialSortExec::new(
279            self.expr.clone(),
280            Arc::clone(&children[0]),
281            self.common_prefix_length,
282        )
283        .with_fetch(self.fetch)
284        .with_preserve_partitioning(self.preserve_partitioning);
285
286        Ok(Arc::new(new_partial_sort))
287    }
288
289    fn execute(
290        &self,
291        partition: usize,
292        context: Arc<TaskContext>,
293    ) -> Result<SendableRecordBatchStream> {
294        trace!("Start PartialSortExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id());
295
296        let input = self.input.execute(partition, Arc::clone(&context))?;
297
298        trace!("End PartialSortExec's input.execute for partition: {partition}");
299
300        // Make sure common prefix length is larger than 0
301        // Otherwise, we should use SortExec.
302        debug_assert!(self.common_prefix_length > 0);
303
304        Ok(Box::pin(PartialSortStream {
305            input,
306            expr: self.expr.clone(),
307            common_prefix_length: self.common_prefix_length,
308            in_mem_batch: RecordBatch::new_empty(Arc::clone(&self.schema())),
309            fetch: self.fetch,
310            is_closed: false,
311            baseline_metrics: BaselineMetrics::new(&self.metrics_set, partition),
312        }))
313    }
314
315    fn metrics(&self) -> Option<MetricsSet> {
316        Some(self.metrics_set.clone_inner())
317    }
318
319    fn statistics(&self) -> Result<Statistics> {
320        self.input.partition_statistics(None)
321    }
322
323    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
324        self.input.partition_statistics(partition)
325    }
326}
327
/// Stream created by `PartialSortExec::execute`. It buffers input rows in
/// `in_mem_batch` and emits them sorted by `expr`.
328struct PartialSortStream {
329    /// Stream of record batches produced by executing the input plan
330    input: SendableRecordBatchStream,
331    /// Sort expressions
332    expr: LexOrdering,
333    /// Length of prefix common to input ordering and required ordering of plan
334    /// should be more than 0 otherwise PartialSort is not applicable
335    common_prefix_length: usize,
336    /// Used as a buffer for part of the input not ready for sort
337    in_mem_batch: RecordBatch,
338    /// Number of rows that may still be emitted; reduced by the row count of
    /// every sorted batch that is produced. `None` means no limit.
339    fetch: Option<usize>,
340    /// Whether the stream has finished returning all of its data or not
341    is_closed: bool,
342    /// Execution metrics
343    baseline_metrics: BaselineMetrics,
344}
345
346impl Stream for PartialSortStream {
347    type Item = Result<RecordBatch>;
348
349    fn poll_next(
350        mut self: Pin<&mut Self>,
351        cx: &mut Context<'_>,
352    ) -> Poll<Option<Self::Item>> {
353        let poll = self.poll_next_inner(cx);
354        self.baseline_metrics.record_poll(poll)
355    }
356
357    fn size_hint(&self) -> (usize, Option<usize>) {
358        // we can't predict the size of incoming batches so re-use the size hint from the input
359        self.input.size_hint()
360    }
361}
362
363impl RecordBatchStream for PartialSortStream {
    /// Returns the schema reported by the input stream.
364    fn schema(&self) -> SchemaRef {
365        self.input.schema()
366    }
367}
368
369impl PartialSortStream {
370    fn poll_next_inner(
371        self: &mut Pin<&mut Self>,
372        cx: &mut Context<'_>,
373    ) -> Poll<Option<Result<RecordBatch>>> {
374        if self.is_closed {
375            return Poll::Ready(None);
376        }
377        loop {
378            // Check if we've already reached the fetch limit
379            if self.fetch == Some(0) {
380                self.is_closed = true;
381                return Poll::Ready(None);
382            }
383
384            match ready!(self.input.poll_next_unpin(cx)) {
385                Some(Ok(batch)) => {
386                    // Merge new batch into in_mem_batch
387                    self.in_mem_batch = concat_batches(
388                        &self.schema(),
389                        &[self.in_mem_batch.clone(), batch],
390                    )?;
391
392                    // Check if we have a slice point, otherwise keep accumulating in `self.in_mem_batch`.
393                    if let Some(slice_point) = self
394                        .get_slice_point(self.common_prefix_length, &self.in_mem_batch)?
395                    {
396                        let sorted = self.in_mem_batch.slice(0, slice_point);
397                        self.in_mem_batch = self.in_mem_batch.slice(
398                            slice_point,
399                            self.in_mem_batch.num_rows() - slice_point,
400                        );
401                        let sorted_batch = sort_batch(&sorted, &self.expr, self.fetch)?;
402                        if let Some(fetch) = self.fetch.as_mut() {
403                            *fetch -= sorted_batch.num_rows();
404                        }
405
406                        if sorted_batch.num_rows() > 0 {
407                            return Poll::Ready(Some(Ok(sorted_batch)));
408                        }
409                    }
410                }
411                Some(Err(e)) => return Poll::Ready(Some(Err(e))),
412                None => {
413                    self.is_closed = true;
414                    // Once input is consumed, sort the rest of the inserted batches
415                    let remaining_batch = self.sort_in_mem_batch()?;
416                    return if remaining_batch.num_rows() > 0 {
417                        Poll::Ready(Some(Ok(remaining_batch)))
418                    } else {
419                        Poll::Ready(None)
420                    };
421                }
422            };
423        }
424    }
425
426    /// Returns a sorted RecordBatch from in_mem_batches and clears in_mem_batches
427    ///
428    /// If fetch is specified for PartialSortStream `sort_in_mem_batch` will limit
429    /// the last RecordBatch returned and will mark the stream as closed
430    fn sort_in_mem_batch(self: &mut Pin<&mut Self>) -> Result<RecordBatch> {
431        let input_batch = self.in_mem_batch.clone();
432        self.in_mem_batch = RecordBatch::new_empty(self.schema());
433        let result = sort_batch(&input_batch, &self.expr, self.fetch)?;
434        if let Some(remaining_fetch) = self.fetch {
435            // remaining_fetch - result.num_rows() is always be >= 0
436            // because result length of sort_batch with limit cannot be
437            // more than the requested limit
438            self.fetch = Some(remaining_fetch - result.num_rows());
439            if remaining_fetch == result.num_rows() {
440                self.is_closed = true;
441            }
442        }
443        Ok(result)
444    }
445
446    /// Return the end index of the second last partition if the batch
447    /// can be partitioned based on its already sorted columns
448    ///
449    /// Return None if the batch cannot be partitioned, which means the
450    /// batch does not have the information for a safe sort
451    fn get_slice_point(
452        &self,
453        common_prefix_len: usize,
454        batch: &RecordBatch,
455    ) -> Result<Option<usize>> {
456        let common_prefix_sort_keys = (0..common_prefix_len)
457            .map(|idx| self.expr[idx].evaluate_to_sort_column(batch))
458            .collect::<Result<Vec<_>>>()?;
459        let partition_points =
460            evaluate_partition_ranges(batch.num_rows(), &common_prefix_sort_keys)?;
461        // If partition points are [0..100], [100..200], [200..300]
462        // we should return 200, which is the safest and furthest partition boundary
463        // Please note that we shouldn't return 300 (which is number of rows in the batch),
464        // because this boundary may change with new data.
465        if partition_points.len() >= 2 {
466            Ok(Some(partition_points[partition_points.len() - 2].end))
467        } else {
468            Ok(None)
469        }
470    }
471}
472
473#[cfg(test)]
474mod tests {
475    use std::collections::HashMap;
476
477    use arrow::array::*;
478    use arrow::compute::SortOptions;
479    use arrow::datatypes::*;
480    use datafusion_common::test_util::batches_to_string;
481    use futures::FutureExt;
482    use insta::allow_duplicates;
483    use insta::assert_snapshot;
484    use itertools::Itertools;
485
486    use crate::collect;
487    use crate::expressions::col;
488    use crate::expressions::PhysicalSortExpr;
489    use crate::sorts::sort::SortExec;
490    use crate::test;
491    use crate::test::assert_is_pending;
492    use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
493    use crate::test::TestMemoryExec;
494
495    use super::*;
496
497    #[tokio::test]
498    async fn test_partial_sort() -> Result<()> {
499        let task_ctx = Arc::new(TaskContext::default());
500        let source = test::build_table_scan_i32(
501            ("a", &vec![0, 0, 0, 1, 1, 1]),
502            ("b", &vec![1, 1, 2, 2, 3, 3]),
503            ("c", &vec![1, 0, 5, 4, 3, 2]),
504        );
505        let schema = Schema::new(vec![
506            Field::new("a", DataType::Int32, false),
507            Field::new("b", DataType::Int32, false),
508            Field::new("c", DataType::Int32, false),
509        ]);
510        let option_asc = SortOptions {
511            descending: false,
512            nulls_first: false,
513        };
514
515        let partial_sort_exec = Arc::new(PartialSortExec::new(
516            [
517                PhysicalSortExpr {
518                    expr: col("a", &schema)?,
519                    options: option_asc,
520                },
521                PhysicalSortExpr {
522                    expr: col("b", &schema)?,
523                    options: option_asc,
524                },
525                PhysicalSortExpr {
526                    expr: col("c", &schema)?,
527                    options: option_asc,
528                },
529            ]
530            .into(),
531            Arc::clone(&source),
532            2,
533        ));
534
535        let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
536
537        assert_eq!(2, result.len());
538        allow_duplicates! {
539            assert_snapshot!(batches_to_string(&result), @r#"
540                +---+---+---+
541                | a | b | c |
542                +---+---+---+
543                | 0 | 1 | 0 |
544                | 0 | 1 | 1 |
545                | 0 | 2 | 5 |
546                | 1 | 2 | 4 |
547                | 1 | 3 | 2 |
548                | 1 | 3 | 3 |
549                +---+---+---+
550                "#);
551        }
552        assert_eq!(
553            task_ctx.runtime_env().memory_pool.reserved(),
554            0,
555            "The sort should have returned all memory used back to the memory manager"
556        );
557
558        Ok(())
559    }
560
561    #[tokio::test]
562    async fn test_partial_sort_with_fetch() -> Result<()> {
563        let task_ctx = Arc::new(TaskContext::default());
564        let source = test::build_table_scan_i32(
565            ("a", &vec![0, 0, 1, 1, 1]),
566            ("b", &vec![1, 2, 2, 3, 3]),
567            ("c", &vec![4, 3, 2, 1, 0]),
568        );
569        let schema = Schema::new(vec![
570            Field::new("a", DataType::Int32, false),
571            Field::new("b", DataType::Int32, false),
572            Field::new("c", DataType::Int32, false),
573        ]);
574        let option_asc = SortOptions {
575            descending: false,
576            nulls_first: false,
577        };
578
579        for common_prefix_length in [1, 2] {
580            let partial_sort_exec = Arc::new(
581                PartialSortExec::new(
582                    [
583                        PhysicalSortExpr {
584                            expr: col("a", &schema)?,
585                            options: option_asc,
586                        },
587                        PhysicalSortExpr {
588                            expr: col("b", &schema)?,
589                            options: option_asc,
590                        },
591                        PhysicalSortExpr {
592                            expr: col("c", &schema)?,
593                            options: option_asc,
594                        },
595                    ]
596                    .into(),
597                    Arc::clone(&source),
598                    common_prefix_length,
599                )
600                .with_fetch(Some(4)),
601            );
602
603            let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
604
605            assert_eq!(2, result.len());
606            allow_duplicates! {
607                assert_snapshot!(batches_to_string(&result), @r#"
608                    +---+---+---+
609                    | a | b | c |
610                    +---+---+---+
611                    | 0 | 1 | 4 |
612                    | 0 | 2 | 3 |
613                    | 1 | 2 | 2 |
614                    | 1 | 3 | 0 |
615                    +---+---+---+
616                    "#);
617            }
618            assert_eq!(
619                task_ctx.runtime_env().memory_pool.reserved(),
620                0,
621                "The sort should have returned all memory used back to the memory manager"
622            );
623        }
624
625        Ok(())
626    }
627
628    #[tokio::test]
629    async fn test_partial_sort2() -> Result<()> {
630        let task_ctx = Arc::new(TaskContext::default());
631        let source_tables = [
632            test::build_table_scan_i32(
633                ("a", &vec![0, 0, 0, 0, 1, 1, 1, 1]),
634                ("b", &vec![1, 1, 3, 3, 4, 4, 2, 2]),
635                ("c", &vec![7, 6, 5, 4, 3, 2, 1, 0]),
636            ),
637            test::build_table_scan_i32(
638                ("a", &vec![0, 0, 0, 0, 1, 1, 1, 1]),
639                ("b", &vec![1, 1, 3, 3, 2, 2, 4, 4]),
640                ("c", &vec![7, 6, 5, 4, 1, 0, 3, 2]),
641            ),
642        ];
643        let schema = Schema::new(vec![
644            Field::new("a", DataType::Int32, false),
645            Field::new("b", DataType::Int32, false),
646            Field::new("c", DataType::Int32, false),
647        ]);
648        let option_asc = SortOptions {
649            descending: false,
650            nulls_first: false,
651        };
652        for (common_prefix_length, source) in
653            [(1, &source_tables[0]), (2, &source_tables[1])]
654        {
655            let partial_sort_exec = Arc::new(PartialSortExec::new(
656                [
657                    PhysicalSortExpr {
658                        expr: col("a", &schema)?,
659                        options: option_asc,
660                    },
661                    PhysicalSortExpr {
662                        expr: col("b", &schema)?,
663                        options: option_asc,
664                    },
665                    PhysicalSortExpr {
666                        expr: col("c", &schema)?,
667                        options: option_asc,
668                    },
669                ]
670                .into(),
671                Arc::clone(source),
672                common_prefix_length,
673            ));
674
675            let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
676            assert_eq!(2, result.len());
677            assert_eq!(
678                task_ctx.runtime_env().memory_pool.reserved(),
679                0,
680                "The sort should have returned all memory used back to the memory manager"
681            );
682            allow_duplicates! {
683                assert_snapshot!(batches_to_string(&result), @r#"
684                    +---+---+---+
685                    | a | b | c |
686                    +---+---+---+
687                    | 0 | 1 | 6 |
688                    | 0 | 1 | 7 |
689                    | 0 | 3 | 4 |
690                    | 0 | 3 | 5 |
691                    | 1 | 2 | 0 |
692                    | 1 | 2 | 1 |
693                    | 1 | 4 | 2 |
694                    | 1 | 4 | 3 |
695                    +---+---+---+
696                    "#);
697            }
698        }
699        Ok(())
700    }
701
702    fn prepare_partitioned_input() -> Arc<dyn ExecutionPlan> {
703        let batch1 = test::build_table_i32(
704            ("a", &vec![1; 100]),
705            ("b", &(0..100).rev().collect()),
706            ("c", &(0..100).rev().collect()),
707        );
708        let batch2 = test::build_table_i32(
709            ("a", &[&vec![1; 25][..], &vec![2; 75][..]].concat()),
710            ("b", &(100..200).rev().collect()),
711            ("c", &(0..100).collect()),
712        );
713        let batch3 = test::build_table_i32(
714            ("a", &[&vec![3; 50][..], &vec![4; 50][..]].concat()),
715            ("b", &(150..250).rev().collect()),
716            ("c", &(0..100).rev().collect()),
717        );
718        let batch4 = test::build_table_i32(
719            ("a", &vec![4; 100]),
720            ("b", &(50..150).rev().collect()),
721            ("c", &(0..100).rev().collect()),
722        );
723        let schema = batch1.schema();
724
725        TestMemoryExec::try_new_exec(
726            &[vec![batch1, batch2, batch3, batch4]],
727            Arc::clone(&schema),
728            None,
729        )
730        .unwrap() as Arc<dyn ExecutionPlan>
731    }
732
733    #[tokio::test]
734    async fn test_partitioned_input_partial_sort() -> Result<()> {
735        let task_ctx = Arc::new(TaskContext::default());
736        let mem_exec = prepare_partitioned_input();
737        let option_asc = SortOptions {
738            descending: false,
739            nulls_first: false,
740        };
741        let option_desc = SortOptions {
742            descending: false,
743            nulls_first: false,
744        };
745        let schema = mem_exec.schema();
746        let partial_sort_exec = PartialSortExec::new(
747            [
748                PhysicalSortExpr {
749                    expr: col("a", &schema)?,
750                    options: option_asc,
751                },
752                PhysicalSortExpr {
753                    expr: col("b", &schema)?,
754                    options: option_desc,
755                },
756                PhysicalSortExpr {
757                    expr: col("c", &schema)?,
758                    options: option_asc,
759                },
760            ]
761            .into(),
762            Arc::clone(&mem_exec),
763            1,
764        );
765        let sort_exec = Arc::new(SortExec::new(
766            partial_sort_exec.expr.clone(),
767            Arc::clone(&partial_sort_exec.input),
768        ));
769        let result = collect(Arc::new(partial_sort_exec), Arc::clone(&task_ctx)).await?;
770        assert_eq!(
771            result.iter().map(|r| r.num_rows()).collect_vec(),
772            [125, 125, 150]
773        );
774
775        assert_eq!(
776            task_ctx.runtime_env().memory_pool.reserved(),
777            0,
778            "The sort should have returned all memory used back to the memory manager"
779        );
780        let partial_sort_result = concat_batches(&schema, &result).unwrap();
781        let sort_result = collect(sort_exec, Arc::clone(&task_ctx)).await?;
782        assert_eq!(sort_result[0], partial_sort_result);
783
784        Ok(())
785    }
786
787    #[tokio::test]
788    async fn test_partitioned_input_partial_sort_with_fetch() -> Result<()> {
789        let task_ctx = Arc::new(TaskContext::default());
790        let mem_exec = prepare_partitioned_input();
791        let schema = mem_exec.schema();
792        let option_asc = SortOptions {
793            descending: false,
794            nulls_first: false,
795        };
796        let option_desc = SortOptions {
797            descending: false,
798            nulls_first: false,
799        };
800        for (fetch_size, expected_batch_num_rows) in [
801            (Some(50), vec![50]),
802            (Some(120), vec![120]),
803            (Some(150), vec![125, 25]),
804            (Some(250), vec![125, 125]),
805        ] {
806            let partial_sort_exec = PartialSortExec::new(
807                [
808                    PhysicalSortExpr {
809                        expr: col("a", &schema)?,
810                        options: option_asc,
811                    },
812                    PhysicalSortExpr {
813                        expr: col("b", &schema)?,
814                        options: option_desc,
815                    },
816                    PhysicalSortExpr {
817                        expr: col("c", &schema)?,
818                        options: option_asc,
819                    },
820                ]
821                .into(),
822                Arc::clone(&mem_exec),
823                1,
824            )
825            .with_fetch(fetch_size);
826
827            let sort_exec = Arc::new(
828                SortExec::new(
829                    partial_sort_exec.expr.clone(),
830                    Arc::clone(&partial_sort_exec.input),
831                )
832                .with_fetch(fetch_size),
833            );
834            let result =
835                collect(Arc::new(partial_sort_exec), Arc::clone(&task_ctx)).await?;
836            assert_eq!(
837                result.iter().map(|r| r.num_rows()).collect_vec(),
838                expected_batch_num_rows
839            );
840
841            assert_eq!(
842                task_ctx.runtime_env().memory_pool.reserved(),
843                0,
844                "The sort should have returned all memory used back to the memory manager"
845            );
846            let partial_sort_result = concat_batches(&schema, &result)?;
847            let sort_result = collect(sort_exec, Arc::clone(&task_ctx)).await?;
848            assert_eq!(sort_result[0], partial_sort_result);
849        }
850
851        Ok(())
852    }
853
854    #[tokio::test]
855    async fn test_partial_sort_no_empty_batches() -> Result<()> {
856        let task_ctx = Arc::new(TaskContext::default());
857        let mem_exec = prepare_partitioned_input();
858        let schema = mem_exec.schema();
859        let option_asc = SortOptions {
860            descending: false,
861            nulls_first: false,
862        };
863        let fetch_size = Some(250);
864        let partial_sort_exec = PartialSortExec::new(
865            [
866                PhysicalSortExpr {
867                    expr: col("a", &schema)?,
868                    options: option_asc,
869                },
870                PhysicalSortExpr {
871                    expr: col("c", &schema)?,
872                    options: option_asc,
873                },
874            ]
875            .into(),
876            Arc::clone(&mem_exec),
877            1,
878        )
879        .with_fetch(fetch_size);
880
881        let result = collect(Arc::new(partial_sort_exec), Arc::clone(&task_ctx)).await?;
882        for rb in result {
883            assert!(rb.num_rows() > 0);
884        }
885
886        Ok(())
887    }
888
889    #[tokio::test]
890    async fn test_sort_metadata() -> Result<()> {
891        let task_ctx = Arc::new(TaskContext::default());
892        let field_metadata: HashMap<String, String> =
893            vec![("foo".to_string(), "bar".to_string())]
894                .into_iter()
895                .collect();
896        let schema_metadata: HashMap<String, String> =
897            vec![("baz".to_string(), "barf".to_string())]
898                .into_iter()
899                .collect();
900
901        let mut field = Field::new("field_name", DataType::UInt64, true);
902        field.set_metadata(field_metadata.clone());
903        let schema = Schema::new_with_metadata(vec![field], schema_metadata.clone());
904        let schema = Arc::new(schema);
905
906        let data: ArrayRef =
907            Arc::new(vec![1, 1, 2].into_iter().map(Some).collect::<UInt64Array>());
908
909        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![data])?;
910        let input =
911            TestMemoryExec::try_new_exec(&[vec![batch]], Arc::clone(&schema), None)?;
912
913        let partial_sort_exec = Arc::new(PartialSortExec::new(
914            [PhysicalSortExpr {
915                expr: col("field_name", &schema)?,
916                options: SortOptions::default(),
917            }]
918            .into(),
919            input,
920            1,
921        ));
922
923        let result: Vec<RecordBatch> = collect(partial_sort_exec, task_ctx).await?;
924        let expected_batch = vec![
925            RecordBatch::try_new(
926                Arc::clone(&schema),
927                vec![Arc::new(
928                    vec![1, 1].into_iter().map(Some).collect::<UInt64Array>(),
929                )],
930            )?,
931            RecordBatch::try_new(
932                Arc::clone(&schema),
933                vec![Arc::new(
934                    vec![2].into_iter().map(Some).collect::<UInt64Array>(),
935                )],
936            )?,
937        ];
938
939        // Data is correct
940        assert_eq!(&expected_batch, &result);
941
942        // explicitly ensure the metadata is present
943        assert_eq!(result[0].schema().fields()[0].metadata(), &field_metadata);
944        assert_eq!(result[0].schema().metadata(), &schema_metadata);
945
946        Ok(())
947    }
948
949    #[tokio::test]
950    async fn test_lex_sort_by_float() -> Result<()> {
951        let task_ctx = Arc::new(TaskContext::default());
952        let schema = Arc::new(Schema::new(vec![
953            Field::new("a", DataType::Float32, true),
954            Field::new("b", DataType::Float64, true),
955            Field::new("c", DataType::Float64, true),
956        ]));
957        let option_asc = SortOptions {
958            descending: false,
959            nulls_first: true,
960        };
961        let option_desc = SortOptions {
962            descending: true,
963            nulls_first: true,
964        };
965
966        // define data.
967        let batch = RecordBatch::try_new(
968            Arc::clone(&schema),
969            vec![
970                Arc::new(Float32Array::from(vec![
971                    Some(1.0_f32),
972                    Some(1.0_f32),
973                    Some(1.0_f32),
974                    Some(2.0_f32),
975                    Some(2.0_f32),
976                    Some(3.0_f32),
977                    Some(3.0_f32),
978                    Some(3.0_f32),
979                ])),
980                Arc::new(Float64Array::from(vec![
981                    Some(20.0_f64),
982                    Some(20.0_f64),
983                    Some(40.0_f64),
984                    Some(40.0_f64),
985                    Some(f64::NAN),
986                    None,
987                    None,
988                    Some(f64::NAN),
989                ])),
990                Arc::new(Float64Array::from(vec![
991                    Some(10.0_f64),
992                    Some(20.0_f64),
993                    Some(10.0_f64),
994                    Some(100.0_f64),
995                    Some(f64::NAN),
996                    Some(100.0_f64),
997                    None,
998                    Some(f64::NAN),
999                ])),
1000            ],
1001        )?;
1002
1003        let partial_sort_exec = Arc::new(PartialSortExec::new(
1004            [
1005                PhysicalSortExpr {
1006                    expr: col("a", &schema)?,
1007                    options: option_asc,
1008                },
1009                PhysicalSortExpr {
1010                    expr: col("b", &schema)?,
1011                    options: option_asc,
1012                },
1013                PhysicalSortExpr {
1014                    expr: col("c", &schema)?,
1015                    options: option_desc,
1016                },
1017            ]
1018            .into(),
1019            TestMemoryExec::try_new_exec(&[vec![batch]], schema, None)?,
1020            2,
1021        ));
1022
1023        assert_eq!(
1024            DataType::Float32,
1025            *partial_sort_exec.schema().field(0).data_type()
1026        );
1027        assert_eq!(
1028            DataType::Float64,
1029            *partial_sort_exec.schema().field(1).data_type()
1030        );
1031        assert_eq!(
1032            DataType::Float64,
1033            *partial_sort_exec.schema().field(2).data_type()
1034        );
1035
1036        let result: Vec<RecordBatch> = collect(
1037            Arc::clone(&partial_sort_exec) as Arc<dyn ExecutionPlan>,
1038            task_ctx,
1039        )
1040        .await?;
1041        assert_snapshot!(batches_to_string(&result), @r#"
1042            +-----+------+-------+
1043            | a   | b    | c     |
1044            +-----+------+-------+
1045            | 1.0 | 20.0 | 20.0  |
1046            | 1.0 | 20.0 | 10.0  |
1047            | 1.0 | 40.0 | 10.0  |
1048            | 2.0 | 40.0 | 100.0 |
1049            | 2.0 | NaN  | NaN   |
1050            | 3.0 |      |       |
1051            | 3.0 |      | 100.0 |
1052            | 3.0 | NaN  | NaN   |
1053            +-----+------+-------+
1054            "#);
1055        assert_eq!(result.len(), 2);
1056        let metrics = partial_sort_exec.metrics().unwrap();
1057        assert!(metrics.elapsed_compute().unwrap() > 0);
1058        assert_eq!(metrics.output_rows().unwrap(), 8);
1059
1060        let columns = result[0].columns();
1061
1062        assert_eq!(DataType::Float32, *columns[0].data_type());
1063        assert_eq!(DataType::Float64, *columns[1].data_type());
1064        assert_eq!(DataType::Float64, *columns[2].data_type());
1065
1066        Ok(())
1067    }
1068
1069    #[tokio::test]
1070    async fn test_drop_cancel() -> Result<()> {
1071        let task_ctx = Arc::new(TaskContext::default());
1072        let schema = Arc::new(Schema::new(vec![
1073            Field::new("a", DataType::Float32, true),
1074            Field::new("b", DataType::Float32, true),
1075        ]));
1076
1077        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
1078        let refs = blocking_exec.refs();
1079        let sort_exec = Arc::new(PartialSortExec::new(
1080            [PhysicalSortExpr {
1081                expr: col("a", &schema)?,
1082                options: SortOptions::default(),
1083            }]
1084            .into(),
1085            blocking_exec,
1086            1,
1087        ));
1088
1089        let fut = collect(sort_exec, Arc::clone(&task_ctx));
1090        let mut fut = fut.boxed();
1091
1092        assert_is_pending(&mut fut);
1093        drop(fut);
1094        assert_strong_count_converges_to_zero(refs).await;
1095
1096        assert_eq!(
1097            task_ctx.runtime_env().memory_pool.reserved(),
1098            0,
1099            "The sort should have returned all memory used back to the memory manager"
1100        );
1101
1102        Ok(())
1103    }
1104
1105    #[tokio::test]
1106    async fn test_partial_sort_with_homogeneous_batches() -> Result<()> {
1107        // Test case for the bug where batches with homogeneous sort keys
1108        // (e.g., [1,1,1], [2,2,2]) would not be properly detected as having
1109        // slice points between batches.
1110        let task_ctx = Arc::new(TaskContext::default());
1111
1112        // Create batches where each batch has homogeneous values for sort keys
1113        let batch1 = test::build_table_i32(
1114            ("a", &vec![1; 3]),
1115            ("b", &vec![1; 3]),
1116            ("c", &vec![3, 2, 1]),
1117        );
1118        let batch2 = test::build_table_i32(
1119            ("a", &vec![2; 3]),
1120            ("b", &vec![2; 3]),
1121            ("c", &vec![4, 6, 4]),
1122        );
1123        let batch3 = test::build_table_i32(
1124            ("a", &vec![3; 3]),
1125            ("b", &vec![3; 3]),
1126            ("c", &vec![9, 7, 8]),
1127        );
1128
1129        let schema = batch1.schema();
1130        let mem_exec = TestMemoryExec::try_new_exec(
1131            &[vec![batch1, batch2, batch3]],
1132            Arc::clone(&schema),
1133            None,
1134        )?;
1135
1136        let option_asc = SortOptions {
1137            descending: false,
1138            nulls_first: false,
1139        };
1140
1141        // Partial sort with common prefix of 2 (sorting by a, b, c)
1142        let partial_sort_exec = Arc::new(PartialSortExec::new(
1143            [
1144                PhysicalSortExpr {
1145                    expr: col("a", &schema)?,
1146                    options: option_asc,
1147                },
1148                PhysicalSortExpr {
1149                    expr: col("b", &schema)?,
1150                    options: option_asc,
1151                },
1152                PhysicalSortExpr {
1153                    expr: col("c", &schema)?,
1154                    options: option_asc,
1155                },
1156            ]
1157            .into(),
1158            mem_exec,
1159            2,
1160        ));
1161
1162        let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
1163
1164        assert_eq!(result.len(), 3,);
1165
1166        allow_duplicates! {
1167            assert_snapshot!(batches_to_string(&result), @r#"
1168                +---+---+---+
1169                | a | b | c |
1170                +---+---+---+
1171                | 1 | 1 | 1 |
1172                | 1 | 1 | 2 |
1173                | 1 | 1 | 3 |
1174                | 2 | 2 | 4 |
1175                | 2 | 2 | 4 |
1176                | 2 | 2 | 6 |
1177                | 3 | 3 | 7 |
1178                | 3 | 3 | 8 |
1179                | 3 | 3 | 9 |
1180                +---+---+---+
1181                "#);
1182        }
1183
1184        assert_eq!(task_ctx.runtime_env().memory_pool.reserved(), 0,);
1185        Ok(())
1186    }
1187}