Skip to main content

datafusion_physical_plan/sorts/
partial_sort.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Partial Sort deals with input data that partially
19//! satisfies the required sort order. Such an input data can be
20//! partitioned into segments where each segment already has the
21//! required information for lexicographic sorting so sorting
22//! can be done without loading the entire dataset.
23//!
24//! Consider a sort plan having an input with ordering `a ASC, b ASC`
25//!
26//! ```text
27//! +---+---+---+
28//! | a | b | d |
29//! +---+---+---+
30//! | 0 | 0 | 3 |
31//! | 0 | 0 | 2 |
32//! | 0 | 1 | 1 |
33//! | 0 | 2 | 0 |
34//! +---+---+---+
35//! ```
36//!
37//! and required ordering for the plan is `a ASC, b ASC, d ASC`.
38//! The first 3 rows(segment) can be sorted as the segment already
39//! has the required information for the sort, but the last row
//! requires further information, because the input can continue with a
//! batch whose starting row has the same values of `a` and `b`, as below
42//!
43//! ```text
44//! +---+---+---+
45//! | a | b | d |
46//! +---+---+---+
47//! | 0 | 2 | 4 |
48//! +---+---+---+
49//! ```
50//!
51//! The plan concats incoming data with such last rows of previous input
52//! and continues partial sorting of the segments.
53
54use std::fmt::Debug;
55use std::pin::Pin;
56use std::sync::Arc;
57use std::task::{Context, Poll};
58
59use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
60use crate::sorts::sort::sort_batch;
61use crate::stream::EmptyRecordBatchStream;
62use crate::{
63    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
64    Partitioning, PlanProperties, SendableRecordBatchStream, Statistics,
65    check_if_same_properties,
66};
67
68use arrow::compute::concat_batches;
69use arrow::datatypes::SchemaRef;
70use arrow::record_batch::RecordBatch;
71use datafusion_common::Result;
72use datafusion_common::utils::evaluate_partition_ranges;
73use datafusion_execution::{RecordBatchStream, TaskContext};
74use datafusion_physical_expr::LexOrdering;
75
76use futures::{Stream, StreamExt, ready};
77use log::trace;
78
/// Partial Sort execution plan.
///
/// Sorts an input whose existing ordering already satisfies a prefix of the
/// required sort expressions, so the data can be sorted segment by segment
/// instead of buffering the whole input.
#[derive(Debug, Clone)]
pub struct PartialSortExec {
    /// Input plan whose output is partially sorted
    pub(crate) input: Arc<dyn ExecutionPlan>,
    /// Sort expressions
    expr: LexOrdering,
    /// Length of continuous matching columns of input that satisfy
    /// the required ordering for the sort
    common_prefix_length: usize,
    /// Containing all metrics set created during sort
    metrics_set: ExecutionPlanMetricsSet,
    /// Preserve partitions of input plan. If false, the input partitions
    /// will be sorted and merged into a single output partition.
    preserve_partitioning: bool,
    /// Fetch highest/lowest n results
    fetch: Option<usize>,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: Arc<PlanProperties>,
}
99
100impl PartialSortExec {
101    /// Create a new partial sort execution plan
102    pub fn new(
103        expr: LexOrdering,
104        input: Arc<dyn ExecutionPlan>,
105        common_prefix_length: usize,
106    ) -> Self {
107        debug_assert!(common_prefix_length > 0);
108        let preserve_partitioning = false;
109        let cache = Self::compute_properties(&input, expr.clone(), preserve_partitioning)
110            .unwrap();
111        Self {
112            input,
113            expr,
114            common_prefix_length,
115            metrics_set: ExecutionPlanMetricsSet::new(),
116            preserve_partitioning,
117            fetch: None,
118            cache: Arc::new(cache),
119        }
120    }
121
122    /// Whether this `PartialSortExec` preserves partitioning of the children
123    pub fn preserve_partitioning(&self) -> bool {
124        self.preserve_partitioning
125    }
126
127    /// Specify the partitioning behavior of this partial sort exec
128    ///
129    /// If `preserve_partitioning` is true, sorts each partition
130    /// individually, producing one sorted stream for each input partition.
131    ///
132    /// If `preserve_partitioning` is false, sorts and merges all
133    /// input partitions producing a single, sorted partition.
134    pub fn with_preserve_partitioning(mut self, preserve_partitioning: bool) -> Self {
135        self.preserve_partitioning = preserve_partitioning;
136        Arc::make_mut(&mut self.cache).partitioning =
137            Self::output_partitioning_helper(&self.input, self.preserve_partitioning);
138        self
139    }
140
141    /// Modify how many rows to include in the result
142    ///
143    /// If None, then all rows will be returned, in sorted order.
144    /// If Some, then only the top `fetch` rows will be returned.
145    /// This can reduce the memory pressure required by the sort
146    /// operation since rows that are not going to be included
147    /// can be dropped.
148    pub fn with_fetch(mut self, fetch: Option<usize>) -> Self {
149        self.fetch = fetch;
150        self
151    }
152
153    /// Input schema
154    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
155        &self.input
156    }
157
158    /// Sort expressions
159    pub fn expr(&self) -> &LexOrdering {
160        &self.expr
161    }
162
163    /// If `Some(fetch)`, limits output to only the first "fetch" items
164    pub fn fetch(&self) -> Option<usize> {
165        self.fetch
166    }
167
168    /// Common prefix length
169    pub fn common_prefix_length(&self) -> usize {
170        self.common_prefix_length
171    }
172
173    fn output_partitioning_helper(
174        input: &Arc<dyn ExecutionPlan>,
175        preserve_partitioning: bool,
176    ) -> Partitioning {
177        // Get output partitioning:
178        if preserve_partitioning {
179            input.output_partitioning().clone()
180        } else {
181            Partitioning::UnknownPartitioning(1)
182        }
183    }
184
185    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
186    fn compute_properties(
187        input: &Arc<dyn ExecutionPlan>,
188        sort_exprs: LexOrdering,
189        preserve_partitioning: bool,
190    ) -> Result<PlanProperties> {
191        // Calculate equivalence properties; i.e. reset the ordering equivalence
192        // class with the new ordering:
193        let mut eq_properties = input.equivalence_properties().clone();
194        eq_properties.reorder(sort_exprs)?;
195
196        // Get output partitioning:
197        let output_partitioning =
198            Self::output_partitioning_helper(input, preserve_partitioning);
199
200        Ok(PlanProperties::new(
201            eq_properties,
202            output_partitioning,
203            input.pipeline_behavior(),
204            input.boundedness(),
205        ))
206    }
207
208    fn with_new_children_and_same_properties(
209        &self,
210        mut children: Vec<Arc<dyn ExecutionPlan>>,
211    ) -> Self {
212        Self {
213            input: children.swap_remove(0),
214            metrics_set: ExecutionPlanMetricsSet::new(),
215            ..Self::clone(self)
216        }
217    }
218}
219
220impl DisplayAs for PartialSortExec {
221    fn fmt_as(
222        &self,
223        t: DisplayFormatType,
224        f: &mut std::fmt::Formatter,
225    ) -> std::fmt::Result {
226        match t {
227            DisplayFormatType::Default | DisplayFormatType::Verbose => {
228                let common_prefix_length = self.common_prefix_length;
229                match self.fetch {
230                    Some(fetch) => {
231                        write!(
232                            f,
233                            "PartialSortExec: TopK(fetch={fetch}), expr=[{}], common_prefix_length=[{common_prefix_length}]",
234                            self.expr
235                        )
236                    }
237                    None => write!(
238                        f,
239                        "PartialSortExec: expr=[{}], common_prefix_length=[{common_prefix_length}]",
240                        self.expr
241                    ),
242                }
243            }
244            DisplayFormatType::TreeRender => match self.fetch {
245                Some(fetch) => {
246                    writeln!(f, "{}", self.expr)?;
247                    writeln!(f, "limit={fetch}")
248                }
249                None => {
250                    writeln!(f, "{}", self.expr)
251                }
252            },
253        }
254    }
255}
256
257impl ExecutionPlan for PartialSortExec {
258    fn name(&self) -> &'static str {
259        "PartialSortExec"
260    }
261
262    fn properties(&self) -> &Arc<PlanProperties> {
263        &self.cache
264    }
265
266    fn fetch(&self) -> Option<usize> {
267        self.fetch
268    }
269
270    fn required_input_distribution(&self) -> Vec<Distribution> {
271        if self.preserve_partitioning {
272            vec![Distribution::UnspecifiedDistribution]
273        } else {
274            vec![Distribution::SinglePartition]
275        }
276    }
277
278    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
279        vec![false]
280    }
281
282    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
283        vec![&self.input]
284    }
285
286    fn with_new_children(
287        self: Arc<Self>,
288        children: Vec<Arc<dyn ExecutionPlan>>,
289    ) -> Result<Arc<dyn ExecutionPlan>> {
290        check_if_same_properties!(self, children);
291        let new_partial_sort = PartialSortExec::new(
292            self.expr.clone(),
293            Arc::clone(&children[0]),
294            self.common_prefix_length,
295        )
296        .with_fetch(self.fetch)
297        .with_preserve_partitioning(self.preserve_partitioning);
298
299        Ok(Arc::new(new_partial_sort))
300    }
301
302    fn execute(
303        &self,
304        partition: usize,
305        context: Arc<TaskContext>,
306    ) -> Result<SendableRecordBatchStream> {
307        trace!(
308            "Start PartialSortExec::execute for partition {} of context session_id {} and task_id {:?}",
309            partition,
310            context.session_id(),
311            context.task_id()
312        );
313
314        let input = self.input.execute(partition, Arc::clone(&context))?;
315
316        trace!("End PartialSortExec's input.execute for partition: {partition}");
317
318        // Make sure common prefix length is larger than 0
319        // Otherwise, we should use SortExec.
320        debug_assert!(self.common_prefix_length > 0);
321
322        Ok(Box::pin(PartialSortStream {
323            input,
324            expr: self.expr.clone(),
325            common_prefix_length: self.common_prefix_length,
326            in_mem_batch: RecordBatch::new_empty(Arc::clone(&self.schema())),
327            fetch: self.fetch,
328            is_closed: false,
329            baseline_metrics: BaselineMetrics::new(&self.metrics_set, partition),
330        }))
331    }
332
333    fn metrics(&self) -> Option<MetricsSet> {
334        Some(self.metrics_set.clone_inner())
335    }
336
337    fn partition_statistics(&self, partition: Option<usize>) -> Result<Arc<Statistics>> {
338        self.input.partition_statistics(partition)
339    }
340}
341
/// Stream that consumes a partially sorted input and emits fully sorted
/// batches, one completed segment group at a time.
struct PartialSortStream {
    /// The input stream (replaced with an empty stream once the stream closes)
    input: SendableRecordBatchStream,
    /// Sort expressions
    expr: LexOrdering,
    /// Length of prefix common to input ordering and required ordering of plan
    /// should be more than 0 otherwise PartialSort is not applicable
    common_prefix_length: usize,
    /// Used as a buffer for part of the input not ready for sort
    in_mem_batch: RecordBatch,
    /// Number of rows that may still be emitted; decremented as sorted
    /// batches are produced. `None` means no limit.
    fetch: Option<usize>,
    /// Whether the stream has finished returning all of its data or not
    is_closed: bool,
    /// Execution metrics
    baseline_metrics: BaselineMetrics,
}
359
360impl Stream for PartialSortStream {
361    type Item = Result<RecordBatch>;
362
363    fn poll_next(
364        mut self: Pin<&mut Self>,
365        cx: &mut Context<'_>,
366    ) -> Poll<Option<Self::Item>> {
367        let poll = self.poll_next_inner(cx);
368        self.baseline_metrics.record_poll(poll)
369    }
370
371    fn size_hint(&self) -> (usize, Option<usize>) {
372        // we can't predict the size of incoming batches so re-use the size hint from the input
373        self.input.size_hint()
374    }
375}
376
impl RecordBatchStream for PartialSortStream {
    /// Schema of the emitted batches, taken from the input stream.
    fn schema(&self) -> SchemaRef {
        self.input.schema()
    }
}
382
383impl PartialSortStream {
384    fn poll_next_inner(
385        self: &mut Pin<&mut Self>,
386        cx: &mut Context<'_>,
387    ) -> Poll<Option<Result<RecordBatch>>> {
388        if self.is_closed {
389            return Poll::Ready(None);
390        }
391        loop {
392            // Check if we've already reached the fetch limit
393            if self.fetch == Some(0) {
394                self.is_closed = true;
395                // Release the input pipeline's resources.
396                let input_schema = self.input.schema();
397                self.input = Box::pin(EmptyRecordBatchStream::new(input_schema));
398                return Poll::Ready(None);
399            }
400
401            match ready!(self.input.poll_next_unpin(cx)) {
402                Some(Ok(batch)) => {
403                    // Merge new batch into in_mem_batch
404                    self.in_mem_batch = concat_batches(
405                        &self.schema(),
406                        &[self.in_mem_batch.clone(), batch],
407                    )?;
408
409                    // Check if we have a slice point, otherwise keep accumulating in `self.in_mem_batch`.
410                    if let Some(slice_point) = self
411                        .get_slice_point(self.common_prefix_length, &self.in_mem_batch)?
412                    {
413                        let sorted = self.in_mem_batch.slice(0, slice_point);
414                        self.in_mem_batch = self.in_mem_batch.slice(
415                            slice_point,
416                            self.in_mem_batch.num_rows() - slice_point,
417                        );
418                        let sorted_batch = sort_batch(&sorted, &self.expr, self.fetch)?;
419                        if let Some(fetch) = self.fetch.as_mut() {
420                            *fetch -= sorted_batch.num_rows();
421                        }
422
423                        if sorted_batch.num_rows() > 0 {
424                            return Poll::Ready(Some(Ok(sorted_batch)));
425                        }
426                    }
427                }
428                Some(Err(e)) => return Poll::Ready(Some(Err(e))),
429                None => {
430                    self.is_closed = true;
431                    // Release the input pipeline's resources before sorting.
432                    let input_schema = self.input.schema();
433                    self.input = Box::pin(EmptyRecordBatchStream::new(input_schema));
434                    // Once input is consumed, sort the rest of the inserted batches
435                    let remaining_batch = self.sort_in_mem_batch()?;
436                    return if remaining_batch.num_rows() > 0 {
437                        Poll::Ready(Some(Ok(remaining_batch)))
438                    } else {
439                        Poll::Ready(None)
440                    };
441                }
442            };
443        }
444    }
445
446    /// Returns a sorted RecordBatch from in_mem_batches and clears in_mem_batches
447    ///
448    /// If fetch is specified for PartialSortStream `sort_in_mem_batch` will limit
449    /// the last RecordBatch returned and will mark the stream as closed
450    fn sort_in_mem_batch(self: &mut Pin<&mut Self>) -> Result<RecordBatch> {
451        let input_batch = self.in_mem_batch.clone();
452        self.in_mem_batch = RecordBatch::new_empty(self.schema());
453        let result = sort_batch(&input_batch, &self.expr, self.fetch)?;
454        if let Some(remaining_fetch) = self.fetch {
455            // remaining_fetch - result.num_rows() is always be >= 0
456            // because result length of sort_batch with limit cannot be
457            // more than the requested limit
458            self.fetch = Some(remaining_fetch - result.num_rows());
459            if remaining_fetch == result.num_rows() {
460                self.is_closed = true;
461            }
462        }
463        Ok(result)
464    }
465
466    /// Return the end index of the second last partition if the batch
467    /// can be partitioned based on its already sorted columns
468    ///
469    /// Return None if the batch cannot be partitioned, which means the
470    /// batch does not have the information for a safe sort
471    fn get_slice_point(
472        &self,
473        common_prefix_len: usize,
474        batch: &RecordBatch,
475    ) -> Result<Option<usize>> {
476        let common_prefix_sort_keys = (0..common_prefix_len)
477            .map(|idx| self.expr[idx].evaluate_to_sort_column(batch))
478            .collect::<Result<Vec<_>>>()?;
479        let partition_points =
480            evaluate_partition_ranges(batch.num_rows(), &common_prefix_sort_keys)?;
481        // If partition points are [0..100], [100..200], [200..300]
482        // we should return 200, which is the safest and furthest partition boundary
483        // Please note that we shouldn't return 300 (which is number of rows in the batch),
484        // because this boundary may change with new data.
485        if partition_points.len() >= 2 {
486            Ok(Some(partition_points[partition_points.len() - 2].end))
487        } else {
488            Ok(None)
489        }
490    }
491}
492
493#[cfg(test)]
494mod tests {
495    use std::collections::HashMap;
496
497    use arrow::array::*;
498    use arrow::compute::SortOptions;
499    use arrow::datatypes::*;
500    use datafusion_common::test_util::batches_to_string;
501    use futures::FutureExt;
502    use insta::allow_duplicates;
503    use insta::assert_snapshot;
504    use itertools::Itertools;
505
506    use crate::collect;
507    use crate::expressions::PhysicalSortExpr;
508    use crate::expressions::col;
509    use crate::sorts::sort::SortExec;
510    use crate::test;
511    use crate::test::TestMemoryExec;
512    use crate::test::assert_is_pending;
513    use crate::test::exec::{BlockingExec, assert_strong_count_converges_to_zero};
514
515    use super::*;
516
517    #[tokio::test]
518    async fn test_partial_sort() -> Result<()> {
519        let task_ctx = Arc::new(TaskContext::default());
520        let source = test::build_table_scan_i32(
521            ("a", &vec![0, 0, 0, 1, 1, 1]),
522            ("b", &vec![1, 1, 2, 2, 3, 3]),
523            ("c", &vec![1, 0, 5, 4, 3, 2]),
524        );
525        let schema = Schema::new(vec![
526            Field::new("a", DataType::Int32, false),
527            Field::new("b", DataType::Int32, false),
528            Field::new("c", DataType::Int32, false),
529        ]);
530        let option_asc = SortOptions {
531            descending: false,
532            nulls_first: false,
533        };
534
535        let partial_sort_exec = Arc::new(PartialSortExec::new(
536            [
537                PhysicalSortExpr {
538                    expr: col("a", &schema)?,
539                    options: option_asc,
540                },
541                PhysicalSortExpr {
542                    expr: col("b", &schema)?,
543                    options: option_asc,
544                },
545                PhysicalSortExpr {
546                    expr: col("c", &schema)?,
547                    options: option_asc,
548                },
549            ]
550            .into(),
551            Arc::clone(&source),
552            2,
553        ));
554
555        let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
556
557        assert_eq!(2, result.len());
558        allow_duplicates! {
559            assert_snapshot!(batches_to_string(&result), @r"
560            +---+---+---+
561            | a | b | c |
562            +---+---+---+
563            | 0 | 1 | 0 |
564            | 0 | 1 | 1 |
565            | 0 | 2 | 5 |
566            | 1 | 2 | 4 |
567            | 1 | 3 | 2 |
568            | 1 | 3 | 3 |
569            +---+---+---+
570            ");
571        }
572        assert_eq!(
573            task_ctx.runtime_env().memory_pool.reserved(),
574            0,
575            "The sort should have returned all memory used back to the memory manager"
576        );
577
578        Ok(())
579    }
580
581    #[tokio::test]
582    async fn test_partial_sort_with_fetch() -> Result<()> {
583        let task_ctx = Arc::new(TaskContext::default());
584        let source = test::build_table_scan_i32(
585            ("a", &vec![0, 0, 1, 1, 1]),
586            ("b", &vec![1, 2, 2, 3, 3]),
587            ("c", &vec![4, 3, 2, 1, 0]),
588        );
589        let schema = Schema::new(vec![
590            Field::new("a", DataType::Int32, false),
591            Field::new("b", DataType::Int32, false),
592            Field::new("c", DataType::Int32, false),
593        ]);
594        let option_asc = SortOptions {
595            descending: false,
596            nulls_first: false,
597        };
598
599        for common_prefix_length in [1, 2] {
600            let partial_sort_exec = Arc::new(
601                PartialSortExec::new(
602                    [
603                        PhysicalSortExpr {
604                            expr: col("a", &schema)?,
605                            options: option_asc,
606                        },
607                        PhysicalSortExpr {
608                            expr: col("b", &schema)?,
609                            options: option_asc,
610                        },
611                        PhysicalSortExpr {
612                            expr: col("c", &schema)?,
613                            options: option_asc,
614                        },
615                    ]
616                    .into(),
617                    Arc::clone(&source),
618                    common_prefix_length,
619                )
620                .with_fetch(Some(4)),
621            );
622
623            let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
624
625            assert_eq!(2, result.len());
626            allow_duplicates! {
627                assert_snapshot!(batches_to_string(&result), @r"
628                +---+---+---+
629                | a | b | c |
630                +---+---+---+
631                | 0 | 1 | 4 |
632                | 0 | 2 | 3 |
633                | 1 | 2 | 2 |
634                | 1 | 3 | 0 |
635                +---+---+---+
636                ");
637            }
638            assert_eq!(
639                task_ctx.runtime_env().memory_pool.reserved(),
640                0,
641                "The sort should have returned all memory used back to the memory manager"
642            );
643        }
644
645        Ok(())
646    }
647
648    #[tokio::test]
649    async fn test_partial_sort2() -> Result<()> {
650        let task_ctx = Arc::new(TaskContext::default());
651        let source_tables = [
652            test::build_table_scan_i32(
653                ("a", &vec![0, 0, 0, 0, 1, 1, 1, 1]),
654                ("b", &vec![1, 1, 3, 3, 4, 4, 2, 2]),
655                ("c", &vec![7, 6, 5, 4, 3, 2, 1, 0]),
656            ),
657            test::build_table_scan_i32(
658                ("a", &vec![0, 0, 0, 0, 1, 1, 1, 1]),
659                ("b", &vec![1, 1, 3, 3, 2, 2, 4, 4]),
660                ("c", &vec![7, 6, 5, 4, 1, 0, 3, 2]),
661            ),
662        ];
663        let schema = Schema::new(vec![
664            Field::new("a", DataType::Int32, false),
665            Field::new("b", DataType::Int32, false),
666            Field::new("c", DataType::Int32, false),
667        ]);
668        let option_asc = SortOptions {
669            descending: false,
670            nulls_first: false,
671        };
672        for (common_prefix_length, source) in
673            [(1, &source_tables[0]), (2, &source_tables[1])]
674        {
675            let partial_sort_exec = Arc::new(PartialSortExec::new(
676                [
677                    PhysicalSortExpr {
678                        expr: col("a", &schema)?,
679                        options: option_asc,
680                    },
681                    PhysicalSortExpr {
682                        expr: col("b", &schema)?,
683                        options: option_asc,
684                    },
685                    PhysicalSortExpr {
686                        expr: col("c", &schema)?,
687                        options: option_asc,
688                    },
689                ]
690                .into(),
691                Arc::clone(source),
692                common_prefix_length,
693            ));
694
695            let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
696            assert_eq!(2, result.len());
697            assert_eq!(
698                task_ctx.runtime_env().memory_pool.reserved(),
699                0,
700                "The sort should have returned all memory used back to the memory manager"
701            );
702            allow_duplicates! {
703                assert_snapshot!(batches_to_string(&result), @r"
704                +---+---+---+
705                | a | b | c |
706                +---+---+---+
707                | 0 | 1 | 6 |
708                | 0 | 1 | 7 |
709                | 0 | 3 | 4 |
710                | 0 | 3 | 5 |
711                | 1 | 2 | 0 |
712                | 1 | 2 | 1 |
713                | 1 | 4 | 2 |
714                | 1 | 4 | 3 |
715                +---+---+---+
716                ");
717            }
718        }
719        Ok(())
720    }
721
722    fn prepare_partitioned_input() -> Arc<dyn ExecutionPlan> {
723        let batch1 = test::build_table_i32(
724            ("a", &vec![1; 100]),
725            ("b", &(0..100).rev().collect()),
726            ("c", &(0..100).rev().collect()),
727        );
728        let batch2 = test::build_table_i32(
729            ("a", &[&vec![1; 25][..], &vec![2; 75][..]].concat()),
730            ("b", &(100..200).rev().collect()),
731            ("c", &(0..100).collect()),
732        );
733        let batch3 = test::build_table_i32(
734            ("a", &[&vec![3; 50][..], &vec![4; 50][..]].concat()),
735            ("b", &(150..250).rev().collect()),
736            ("c", &(0..100).rev().collect()),
737        );
738        let batch4 = test::build_table_i32(
739            ("a", &vec![4; 100]),
740            ("b", &(50..150).rev().collect()),
741            ("c", &(0..100).rev().collect()),
742        );
743        let schema = batch1.schema();
744
745        TestMemoryExec::try_new_exec(
746            &[vec![batch1, batch2, batch3, batch4]],
747            Arc::clone(&schema),
748            None,
749        )
750        .unwrap() as Arc<dyn ExecutionPlan>
751    }
752
753    #[tokio::test]
754    async fn test_partitioned_input_partial_sort() -> Result<()> {
755        let task_ctx = Arc::new(TaskContext::default());
756        let mem_exec = prepare_partitioned_input();
757        let option_asc = SortOptions {
758            descending: false,
759            nulls_first: false,
760        };
761        let option_desc = SortOptions {
762            descending: false,
763            nulls_first: false,
764        };
765        let schema = mem_exec.schema();
766        let partial_sort_exec = PartialSortExec::new(
767            [
768                PhysicalSortExpr {
769                    expr: col("a", &schema)?,
770                    options: option_asc,
771                },
772                PhysicalSortExpr {
773                    expr: col("b", &schema)?,
774                    options: option_desc,
775                },
776                PhysicalSortExpr {
777                    expr: col("c", &schema)?,
778                    options: option_asc,
779                },
780            ]
781            .into(),
782            Arc::clone(&mem_exec),
783            1,
784        );
785        let sort_exec = Arc::new(SortExec::new(
786            partial_sort_exec.expr.clone(),
787            Arc::clone(&partial_sort_exec.input),
788        ));
789        let result = collect(Arc::new(partial_sort_exec), Arc::clone(&task_ctx)).await?;
790        assert_eq!(
791            result.iter().map(|r| r.num_rows()).collect_vec(),
792            [125, 125, 150]
793        );
794
795        assert_eq!(
796            task_ctx.runtime_env().memory_pool.reserved(),
797            0,
798            "The sort should have returned all memory used back to the memory manager"
799        );
800        let partial_sort_result = concat_batches(&schema, &result).unwrap();
801        let sort_result = collect(sort_exec, Arc::clone(&task_ctx)).await?;
802        assert_eq!(sort_result[0], partial_sort_result);
803
804        Ok(())
805    }
806
807    #[tokio::test]
808    async fn test_partitioned_input_partial_sort_with_fetch() -> Result<()> {
809        let task_ctx = Arc::new(TaskContext::default());
810        let mem_exec = prepare_partitioned_input();
811        let schema = mem_exec.schema();
812        let option_asc = SortOptions {
813            descending: false,
814            nulls_first: false,
815        };
816        let option_desc = SortOptions {
817            descending: false,
818            nulls_first: false,
819        };
820        for (fetch_size, expected_batch_num_rows) in [
821            (Some(50), vec![50]),
822            (Some(120), vec![120]),
823            (Some(150), vec![125, 25]),
824            (Some(250), vec![125, 125]),
825        ] {
826            let partial_sort_exec = PartialSortExec::new(
827                [
828                    PhysicalSortExpr {
829                        expr: col("a", &schema)?,
830                        options: option_asc,
831                    },
832                    PhysicalSortExpr {
833                        expr: col("b", &schema)?,
834                        options: option_desc,
835                    },
836                    PhysicalSortExpr {
837                        expr: col("c", &schema)?,
838                        options: option_asc,
839                    },
840                ]
841                .into(),
842                Arc::clone(&mem_exec),
843                1,
844            )
845            .with_fetch(fetch_size);
846
847            let sort_exec = Arc::new(
848                SortExec::new(
849                    partial_sort_exec.expr.clone(),
850                    Arc::clone(&partial_sort_exec.input),
851                )
852                .with_fetch(fetch_size),
853            );
854            let result =
855                collect(Arc::new(partial_sort_exec), Arc::clone(&task_ctx)).await?;
856            assert_eq!(
857                result.iter().map(|r| r.num_rows()).collect_vec(),
858                expected_batch_num_rows
859            );
860
861            assert_eq!(
862                task_ctx.runtime_env().memory_pool.reserved(),
863                0,
864                "The sort should have returned all memory used back to the memory manager"
865            );
866            let partial_sort_result = concat_batches(&schema, &result)?;
867            let sort_result = collect(sort_exec, Arc::clone(&task_ctx)).await?;
868            assert_eq!(sort_result[0], partial_sort_result);
869        }
870
871        Ok(())
872    }
873
874    #[tokio::test]
875    async fn test_partial_sort_no_empty_batches() -> Result<()> {
876        let task_ctx = Arc::new(TaskContext::default());
877        let mem_exec = prepare_partitioned_input();
878        let schema = mem_exec.schema();
879        let option_asc = SortOptions {
880            descending: false,
881            nulls_first: false,
882        };
883        let fetch_size = Some(250);
884        let partial_sort_exec = PartialSortExec::new(
885            [
886                PhysicalSortExpr {
887                    expr: col("a", &schema)?,
888                    options: option_asc,
889                },
890                PhysicalSortExpr {
891                    expr: col("c", &schema)?,
892                    options: option_asc,
893                },
894            ]
895            .into(),
896            Arc::clone(&mem_exec),
897            1,
898        )
899        .with_fetch(fetch_size);
900
901        let result = collect(Arc::new(partial_sort_exec), Arc::clone(&task_ctx)).await?;
902        for rb in result {
903            assert!(rb.num_rows() > 0);
904        }
905
906        Ok(())
907    }
908
909    #[tokio::test]
910    async fn test_sort_metadata() -> Result<()> {
911        let task_ctx = Arc::new(TaskContext::default());
912        let field_metadata: HashMap<String, String> =
913            vec![("foo".to_string(), "bar".to_string())]
914                .into_iter()
915                .collect();
916        let schema_metadata: HashMap<String, String> =
917            vec![("baz".to_string(), "barf".to_string())]
918                .into_iter()
919                .collect();
920
921        let mut field = Field::new("field_name", DataType::UInt64, true);
922        field.set_metadata(field_metadata.clone());
923        let schema = Schema::new_with_metadata(vec![field], schema_metadata.clone());
924        let schema = Arc::new(schema);
925
926        let data: ArrayRef =
927            Arc::new(vec![1, 1, 2].into_iter().map(Some).collect::<UInt64Array>());
928
929        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![data])?;
930        let input =
931            TestMemoryExec::try_new_exec(&[vec![batch]], Arc::clone(&schema), None)?;
932
933        let partial_sort_exec = Arc::new(PartialSortExec::new(
934            [PhysicalSortExpr {
935                expr: col("field_name", &schema)?,
936                options: SortOptions::default(),
937            }]
938            .into(),
939            input,
940            1,
941        ));
942
943        let result: Vec<RecordBatch> = collect(partial_sort_exec, task_ctx).await?;
944        let expected_batch = vec![
945            RecordBatch::try_new(
946                Arc::clone(&schema),
947                vec![Arc::new(
948                    vec![1, 1].into_iter().map(Some).collect::<UInt64Array>(),
949                )],
950            )?,
951            RecordBatch::try_new(
952                Arc::clone(&schema),
953                vec![Arc::new(
954                    vec![2].into_iter().map(Some).collect::<UInt64Array>(),
955                )],
956            )?,
957        ];
958
959        // Data is correct
960        assert_eq!(&expected_batch, &result);
961
962        // explicitly ensure the metadata is present
963        assert_eq!(result[0].schema().fields()[0].metadata(), &field_metadata);
964        assert_eq!(result[0].schema().metadata(), &schema_metadata);
965
966        Ok(())
967    }
968
969    #[tokio::test]
970    async fn test_lex_sort_by_float() -> Result<()> {
971        let task_ctx = Arc::new(TaskContext::default());
972        let schema = Arc::new(Schema::new(vec![
973            Field::new("a", DataType::Float32, true),
974            Field::new("b", DataType::Float64, true),
975            Field::new("c", DataType::Float64, true),
976        ]));
977        let option_asc = SortOptions {
978            descending: false,
979            nulls_first: true,
980        };
981        let option_desc = SortOptions {
982            descending: true,
983            nulls_first: true,
984        };
985
986        // define data.
987        let batch = RecordBatch::try_new(
988            Arc::clone(&schema),
989            vec![
990                Arc::new(Float32Array::from(vec![
991                    Some(1.0_f32),
992                    Some(1.0_f32),
993                    Some(1.0_f32),
994                    Some(2.0_f32),
995                    Some(2.0_f32),
996                    Some(3.0_f32),
997                    Some(3.0_f32),
998                    Some(3.0_f32),
999                ])),
1000                Arc::new(Float64Array::from(vec![
1001                    Some(20.0_f64),
1002                    Some(20.0_f64),
1003                    Some(40.0_f64),
1004                    Some(40.0_f64),
1005                    Some(f64::NAN),
1006                    None,
1007                    None,
1008                    Some(f64::NAN),
1009                ])),
1010                Arc::new(Float64Array::from(vec![
1011                    Some(10.0_f64),
1012                    Some(20.0_f64),
1013                    Some(10.0_f64),
1014                    Some(100.0_f64),
1015                    Some(f64::NAN),
1016                    Some(100.0_f64),
1017                    None,
1018                    Some(f64::NAN),
1019                ])),
1020            ],
1021        )?;
1022
1023        let partial_sort_exec = Arc::new(PartialSortExec::new(
1024            [
1025                PhysicalSortExpr {
1026                    expr: col("a", &schema)?,
1027                    options: option_asc,
1028                },
1029                PhysicalSortExpr {
1030                    expr: col("b", &schema)?,
1031                    options: option_asc,
1032                },
1033                PhysicalSortExpr {
1034                    expr: col("c", &schema)?,
1035                    options: option_desc,
1036                },
1037            ]
1038            .into(),
1039            TestMemoryExec::try_new_exec(&[vec![batch]], schema, None)?,
1040            2,
1041        ));
1042
1043        assert_eq!(
1044            DataType::Float32,
1045            *partial_sort_exec.schema().field(0).data_type()
1046        );
1047        assert_eq!(
1048            DataType::Float64,
1049            *partial_sort_exec.schema().field(1).data_type()
1050        );
1051        assert_eq!(
1052            DataType::Float64,
1053            *partial_sort_exec.schema().field(2).data_type()
1054        );
1055
1056        let result: Vec<RecordBatch> = collect(
1057            Arc::clone(&partial_sort_exec) as Arc<dyn ExecutionPlan>,
1058            task_ctx,
1059        )
1060        .await?;
1061        assert_snapshot!(batches_to_string(&result), @r"
1062        +-----+------+-------+
1063        | a   | b    | c     |
1064        +-----+------+-------+
1065        | 1.0 | 20.0 | 20.0  |
1066        | 1.0 | 20.0 | 10.0  |
1067        | 1.0 | 40.0 | 10.0  |
1068        | 2.0 | 40.0 | 100.0 |
1069        | 2.0 | NaN  | NaN   |
1070        | 3.0 |      |       |
1071        | 3.0 |      | 100.0 |
1072        | 3.0 | NaN  | NaN   |
1073        +-----+------+-------+
1074        ");
1075        assert_eq!(result.len(), 2);
1076        let metrics = partial_sort_exec.metrics().unwrap();
1077        assert!(metrics.elapsed_compute().unwrap() > 0);
1078        assert_eq!(metrics.output_rows().unwrap(), 8);
1079
1080        let columns = result[0].columns();
1081
1082        assert_eq!(DataType::Float32, *columns[0].data_type());
1083        assert_eq!(DataType::Float64, *columns[1].data_type());
1084        assert_eq!(DataType::Float64, *columns[2].data_type());
1085
1086        Ok(())
1087    }
1088
1089    #[tokio::test]
1090    async fn test_drop_cancel() -> Result<()> {
1091        let task_ctx = Arc::new(TaskContext::default());
1092        let schema = Arc::new(Schema::new(vec![
1093            Field::new("a", DataType::Float32, true),
1094            Field::new("b", DataType::Float32, true),
1095        ]));
1096
1097        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
1098        let refs = blocking_exec.refs();
1099        let sort_exec = Arc::new(PartialSortExec::new(
1100            [PhysicalSortExpr {
1101                expr: col("a", &schema)?,
1102                options: SortOptions::default(),
1103            }]
1104            .into(),
1105            blocking_exec,
1106            1,
1107        ));
1108
1109        let fut = collect(sort_exec, Arc::clone(&task_ctx));
1110        let mut fut = fut.boxed();
1111
1112        assert_is_pending(&mut fut);
1113        drop(fut);
1114        assert_strong_count_converges_to_zero(refs).await;
1115
1116        assert_eq!(
1117            task_ctx.runtime_env().memory_pool.reserved(),
1118            0,
1119            "The sort should have returned all memory used back to the memory manager"
1120        );
1121
1122        Ok(())
1123    }
1124
1125    #[tokio::test]
1126    async fn test_partial_sort_with_homogeneous_batches() -> Result<()> {
1127        // Test case for the bug where batches with homogeneous sort keys
1128        // (e.g., [1,1,1], [2,2,2]) would not be properly detected as having
1129        // slice points between batches.
1130        let task_ctx = Arc::new(TaskContext::default());
1131
1132        // Create batches where each batch has homogeneous values for sort keys
1133        let batch1 = test::build_table_i32(
1134            ("a", &vec![1; 3]),
1135            ("b", &vec![1; 3]),
1136            ("c", &vec![3, 2, 1]),
1137        );
1138        let batch2 = test::build_table_i32(
1139            ("a", &vec![2; 3]),
1140            ("b", &vec![2; 3]),
1141            ("c", &vec![4, 6, 4]),
1142        );
1143        let batch3 = test::build_table_i32(
1144            ("a", &vec![3; 3]),
1145            ("b", &vec![3; 3]),
1146            ("c", &vec![9, 7, 8]),
1147        );
1148
1149        let schema = batch1.schema();
1150        let mem_exec = TestMemoryExec::try_new_exec(
1151            &[vec![batch1, batch2, batch3]],
1152            Arc::clone(&schema),
1153            None,
1154        )?;
1155
1156        let option_asc = SortOptions {
1157            descending: false,
1158            nulls_first: false,
1159        };
1160
1161        // Partial sort with common prefix of 2 (sorting by a, b, c)
1162        let partial_sort_exec = Arc::new(PartialSortExec::new(
1163            [
1164                PhysicalSortExpr {
1165                    expr: col("a", &schema)?,
1166                    options: option_asc,
1167                },
1168                PhysicalSortExpr {
1169                    expr: col("b", &schema)?,
1170                    options: option_asc,
1171                },
1172                PhysicalSortExpr {
1173                    expr: col("c", &schema)?,
1174                    options: option_asc,
1175                },
1176            ]
1177            .into(),
1178            mem_exec,
1179            2,
1180        ));
1181
1182        let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
1183
1184        assert_eq!(result.len(), 3,);
1185
1186        allow_duplicates! {
1187            assert_snapshot!(batches_to_string(&result), @r"
1188            +---+---+---+
1189            | a | b | c |
1190            +---+---+---+
1191            | 1 | 1 | 1 |
1192            | 1 | 1 | 2 |
1193            | 1 | 1 | 3 |
1194            | 2 | 2 | 4 |
1195            | 2 | 2 | 4 |
1196            | 2 | 2 | 6 |
1197            | 3 | 3 | 7 |
1198            | 3 | 3 | 8 |
1199            | 3 | 3 | 9 |
1200            +---+---+---+
1201            ");
1202        }
1203
1204        assert_eq!(task_ctx.runtime_env().memory_pool.reserved(), 0,);
1205        Ok(())
1206    }
1207}