Skip to main content

datafusion_physical_plan/sorts/
partial_sort.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Partial Sort deals with input data that partially
19//! satisfies the required sort order. Such an input data can be
20//! partitioned into segments where each segment already has the
21//! required information for lexicographic sorting so sorting
22//! can be done without loading the entire dataset.
23//!
24//! Consider a sort plan having an input with ordering `a ASC, b ASC`
25//!
26//! ```text
27//! +---+---+---+
28//! | a | b | d |
29//! +---+---+---+
30//! | 0 | 0 | 3 |
31//! | 0 | 0 | 2 |
32//! | 0 | 1 | 1 |
33//! | 0 | 2 | 0 |
34//! +---+---+---+
35//! ```
36//!
37//! and required ordering for the plan is `a ASC, b ASC, d ASC`.
//! The first 3 rows (one segment) can be sorted as the segment already
//! has the required information for the sort, but the last row
//! requires further information, as the input can continue with a
//! batch with a starting row where `a` and `b` do not change, as below
42//!
43//! ```text
44//! +---+---+---+
45//! | a | b | d |
46//! +---+---+---+
47//! | 0 | 2 | 4 |
48//! +---+---+---+
49//! ```
50//!
//! The plan concatenates incoming data with such trailing rows of the
//! previous input and continues partial sorting of the segments.
53
54use std::any::Any;
55use std::fmt::Debug;
56use std::pin::Pin;
57use std::sync::Arc;
58use std::task::{Context, Poll};
59
60use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
61use crate::sorts::sort::sort_batch;
62use crate::{
63    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
64    Partitioning, PlanProperties, SendableRecordBatchStream, Statistics,
65    check_if_same_properties,
66};
67
68use arrow::compute::concat_batches;
69use arrow::datatypes::SchemaRef;
70use arrow::record_batch::RecordBatch;
71use datafusion_common::Result;
72use datafusion_common::utils::evaluate_partition_ranges;
73use datafusion_execution::{RecordBatchStream, TaskContext};
74use datafusion_physical_expr::LexOrdering;
75
76use futures::{Stream, StreamExt, ready};
77use log::trace;
78
/// Partial Sort execution plan.
///
/// Sorts input that already satisfies a leading prefix of the required
/// ordering, emitting sorted output segment-by-segment without buffering the
/// whole input (see the module-level docs).
#[derive(Debug, Clone)]
pub struct PartialSortExec {
    /// Input plan; its output ordering must match the first
    /// `common_prefix_length` expressions of `expr`
    pub(crate) input: Arc<dyn ExecutionPlan>,
    /// Sort expressions
    expr: LexOrdering,
    /// Length of continuous matching columns of input that satisfy
    /// the required ordering for the sort; must be > 0
    common_prefix_length: usize,
    /// Containing all metrics set created during sort
    metrics_set: ExecutionPlanMetricsSet,
    /// Preserve partitions of input plan. If false, the input partitions
    /// will be sorted and merged into a single output partition.
    preserve_partitioning: bool,
    /// Fetch highest/lowest n results
    fetch: Option<usize>,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: Arc<PlanProperties>,
}
99
100impl PartialSortExec {
101    /// Create a new partial sort execution plan
102    pub fn new(
103        expr: LexOrdering,
104        input: Arc<dyn ExecutionPlan>,
105        common_prefix_length: usize,
106    ) -> Self {
107        debug_assert!(common_prefix_length > 0);
108        let preserve_partitioning = false;
109        let cache = Self::compute_properties(&input, expr.clone(), preserve_partitioning)
110            .unwrap();
111        Self {
112            input,
113            expr,
114            common_prefix_length,
115            metrics_set: ExecutionPlanMetricsSet::new(),
116            preserve_partitioning,
117            fetch: None,
118            cache: Arc::new(cache),
119        }
120    }
121
122    /// Whether this `PartialSortExec` preserves partitioning of the children
123    pub fn preserve_partitioning(&self) -> bool {
124        self.preserve_partitioning
125    }
126
127    /// Specify the partitioning behavior of this partial sort exec
128    ///
129    /// If `preserve_partitioning` is true, sorts each partition
130    /// individually, producing one sorted stream for each input partition.
131    ///
132    /// If `preserve_partitioning` is false, sorts and merges all
133    /// input partitions producing a single, sorted partition.
134    pub fn with_preserve_partitioning(mut self, preserve_partitioning: bool) -> Self {
135        self.preserve_partitioning = preserve_partitioning;
136        Arc::make_mut(&mut self.cache).partitioning =
137            Self::output_partitioning_helper(&self.input, self.preserve_partitioning);
138        self
139    }
140
141    /// Modify how many rows to include in the result
142    ///
143    /// If None, then all rows will be returned, in sorted order.
144    /// If Some, then only the top `fetch` rows will be returned.
145    /// This can reduce the memory pressure required by the sort
146    /// operation since rows that are not going to be included
147    /// can be dropped.
148    pub fn with_fetch(mut self, fetch: Option<usize>) -> Self {
149        self.fetch = fetch;
150        self
151    }
152
153    /// Input schema
154    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
155        &self.input
156    }
157
158    /// Sort expressions
159    pub fn expr(&self) -> &LexOrdering {
160        &self.expr
161    }
162
163    /// If `Some(fetch)`, limits output to only the first "fetch" items
164    pub fn fetch(&self) -> Option<usize> {
165        self.fetch
166    }
167
168    /// Common prefix length
169    pub fn common_prefix_length(&self) -> usize {
170        self.common_prefix_length
171    }
172
173    fn output_partitioning_helper(
174        input: &Arc<dyn ExecutionPlan>,
175        preserve_partitioning: bool,
176    ) -> Partitioning {
177        // Get output partitioning:
178        if preserve_partitioning {
179            input.output_partitioning().clone()
180        } else {
181            Partitioning::UnknownPartitioning(1)
182        }
183    }
184
185    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
186    fn compute_properties(
187        input: &Arc<dyn ExecutionPlan>,
188        sort_exprs: LexOrdering,
189        preserve_partitioning: bool,
190    ) -> Result<PlanProperties> {
191        // Calculate equivalence properties; i.e. reset the ordering equivalence
192        // class with the new ordering:
193        let mut eq_properties = input.equivalence_properties().clone();
194        eq_properties.reorder(sort_exprs)?;
195
196        // Get output partitioning:
197        let output_partitioning =
198            Self::output_partitioning_helper(input, preserve_partitioning);
199
200        Ok(PlanProperties::new(
201            eq_properties,
202            output_partitioning,
203            input.pipeline_behavior(),
204            input.boundedness(),
205        ))
206    }
207
208    fn with_new_children_and_same_properties(
209        &self,
210        mut children: Vec<Arc<dyn ExecutionPlan>>,
211    ) -> Self {
212        Self {
213            input: children.swap_remove(0),
214            metrics_set: ExecutionPlanMetricsSet::new(),
215            ..Self::clone(self)
216        }
217    }
218}
219
220impl DisplayAs for PartialSortExec {
221    fn fmt_as(
222        &self,
223        t: DisplayFormatType,
224        f: &mut std::fmt::Formatter,
225    ) -> std::fmt::Result {
226        match t {
227            DisplayFormatType::Default | DisplayFormatType::Verbose => {
228                let common_prefix_length = self.common_prefix_length;
229                match self.fetch {
230                    Some(fetch) => {
231                        write!(
232                            f,
233                            "PartialSortExec: TopK(fetch={fetch}), expr=[{}], common_prefix_length=[{common_prefix_length}]",
234                            self.expr
235                        )
236                    }
237                    None => write!(
238                        f,
239                        "PartialSortExec: expr=[{}], common_prefix_length=[{common_prefix_length}]",
240                        self.expr
241                    ),
242                }
243            }
244            DisplayFormatType::TreeRender => match self.fetch {
245                Some(fetch) => {
246                    writeln!(f, "{}", self.expr)?;
247                    writeln!(f, "limit={fetch}")
248                }
249                None => {
250                    writeln!(f, "{}", self.expr)
251                }
252            },
253        }
254    }
255}
256
257impl ExecutionPlan for PartialSortExec {
258    fn name(&self) -> &'static str {
259        "PartialSortExec"
260    }
261
262    fn as_any(&self) -> &dyn Any {
263        self
264    }
265
266    fn properties(&self) -> &Arc<PlanProperties> {
267        &self.cache
268    }
269
270    fn fetch(&self) -> Option<usize> {
271        self.fetch
272    }
273
274    fn required_input_distribution(&self) -> Vec<Distribution> {
275        if self.preserve_partitioning {
276            vec![Distribution::UnspecifiedDistribution]
277        } else {
278            vec![Distribution::SinglePartition]
279        }
280    }
281
282    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
283        vec![false]
284    }
285
286    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
287        vec![&self.input]
288    }
289
290    fn with_new_children(
291        self: Arc<Self>,
292        children: Vec<Arc<dyn ExecutionPlan>>,
293    ) -> Result<Arc<dyn ExecutionPlan>> {
294        check_if_same_properties!(self, children);
295        let new_partial_sort = PartialSortExec::new(
296            self.expr.clone(),
297            Arc::clone(&children[0]),
298            self.common_prefix_length,
299        )
300        .with_fetch(self.fetch)
301        .with_preserve_partitioning(self.preserve_partitioning);
302
303        Ok(Arc::new(new_partial_sort))
304    }
305
306    fn execute(
307        &self,
308        partition: usize,
309        context: Arc<TaskContext>,
310    ) -> Result<SendableRecordBatchStream> {
311        trace!(
312            "Start PartialSortExec::execute for partition {} of context session_id {} and task_id {:?}",
313            partition,
314            context.session_id(),
315            context.task_id()
316        );
317
318        let input = self.input.execute(partition, Arc::clone(&context))?;
319
320        trace!("End PartialSortExec's input.execute for partition: {partition}");
321
322        // Make sure common prefix length is larger than 0
323        // Otherwise, we should use SortExec.
324        debug_assert!(self.common_prefix_length > 0);
325
326        Ok(Box::pin(PartialSortStream {
327            input,
328            expr: self.expr.clone(),
329            common_prefix_length: self.common_prefix_length,
330            in_mem_batch: RecordBatch::new_empty(Arc::clone(&self.schema())),
331            fetch: self.fetch,
332            is_closed: false,
333            baseline_metrics: BaselineMetrics::new(&self.metrics_set, partition),
334        }))
335    }
336
337    fn metrics(&self) -> Option<MetricsSet> {
338        Some(self.metrics_set.clone_inner())
339    }
340
341    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
342        self.input.partition_statistics(partition)
343    }
344}
345
/// Streaming state for `PartialSortExec`: accumulates input rows until a
/// safe segment boundary is found, then emits that segment sorted.
struct PartialSortStream {
    /// The input plan
    input: SendableRecordBatchStream,
    /// Sort expressions
    expr: LexOrdering,
    /// Length of prefix common to input ordering and required ordering of plan
    /// should be more than 0 otherwise PartialSort is not applicable
    common_prefix_length: usize,
    /// Used as a buffer for the part of the input not yet ready for sorting
    /// (rows whose segment may continue in future batches)
    in_mem_batch: RecordBatch,
    /// Remaining number of rows to emit when a fetch limit is set;
    /// decremented as sorted batches are produced
    fetch: Option<usize>,
    /// Whether the stream has finished returning all of its data or not
    is_closed: bool,
    /// Execution metrics
    baseline_metrics: BaselineMetrics,
}
363
364impl Stream for PartialSortStream {
365    type Item = Result<RecordBatch>;
366
367    fn poll_next(
368        mut self: Pin<&mut Self>,
369        cx: &mut Context<'_>,
370    ) -> Poll<Option<Self::Item>> {
371        let poll = self.poll_next_inner(cx);
372        self.baseline_metrics.record_poll(poll)
373    }
374
375    fn size_hint(&self) -> (usize, Option<usize>) {
376        // we can't predict the size of incoming batches so re-use the size hint from the input
377        self.input.size_hint()
378    }
379}
380
impl RecordBatchStream for PartialSortStream {
    /// The output schema is identical to the input's; sorting does not
    /// change columns.
    fn schema(&self) -> SchemaRef {
        self.input.schema()
    }
}
386
impl PartialSortStream {
    /// Drives the stream: buffers incoming batches in `in_mem_batch` until a
    /// safe sort boundary (slice point) is found on the common prefix
    /// columns, then emits the rows before that boundary sorted on the full
    /// expression list. Once the input ends, the remaining buffer is sorted
    /// and emitted.
    fn poll_next_inner(
        self: &mut Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        if self.is_closed {
            return Poll::Ready(None);
        }
        loop {
            // Check if we've already reached the fetch limit
            if self.fetch == Some(0) {
                self.is_closed = true;
                return Poll::Ready(None);
            }

            match ready!(self.input.poll_next_unpin(cx)) {
                Some(Ok(batch)) => {
                    // Merge new batch into in_mem_batch
                    self.in_mem_batch = concat_batches(
                        &self.schema(),
                        &[self.in_mem_batch.clone(), batch],
                    )?;

                    // Check if we have a slice point, otherwise keep accumulating in `self.in_mem_batch`.
                    if let Some(slice_point) = self
                        .get_slice_point(self.common_prefix_length, &self.in_mem_batch)?
                    {
                        let sorted = self.in_mem_batch.slice(0, slice_point);
                        // Keep the tail (rows at or past the boundary) buffered;
                        // its segment may continue in future input batches.
                        self.in_mem_batch = self.in_mem_batch.slice(
                            slice_point,
                            self.in_mem_batch.num_rows() - slice_point,
                        );
                        let sorted_batch = sort_batch(&sorted, &self.expr, self.fetch)?;
                        // `sort_batch` with a limit returns at most `fetch`
                        // rows, so this subtraction cannot underflow.
                        if let Some(fetch) = self.fetch.as_mut() {
                            *fetch -= sorted_batch.num_rows();
                        }

                        if sorted_batch.num_rows() > 0 {
                            return Poll::Ready(Some(Ok(sorted_batch)));
                        }
                    }
                }
                Some(Err(e)) => return Poll::Ready(Some(Err(e))),
                None => {
                    self.is_closed = true;
                    // Once input is consumed, sort the rest of the inserted batches
                    let remaining_batch = self.sort_in_mem_batch()?;
                    return if remaining_batch.num_rows() > 0 {
                        Poll::Ready(Some(Ok(remaining_batch)))
                    } else {
                        Poll::Ready(None)
                    };
                }
            };
        }
    }

    /// Returns a sorted RecordBatch from `in_mem_batch` and clears `in_mem_batch`
    ///
    /// If fetch is specified for PartialSortStream `sort_in_mem_batch` will limit
    /// the last RecordBatch returned and will mark the stream as closed
    fn sort_in_mem_batch(self: &mut Pin<&mut Self>) -> Result<RecordBatch> {
        let input_batch = self.in_mem_batch.clone();
        self.in_mem_batch = RecordBatch::new_empty(self.schema());
        let result = sort_batch(&input_batch, &self.expr, self.fetch)?;
        if let Some(remaining_fetch) = self.fetch {
            // `remaining_fetch - result.num_rows()` is always >= 0 because
            // the result length of `sort_batch` with a limit cannot exceed
            // the requested limit
            self.fetch = Some(remaining_fetch - result.num_rows());
            if remaining_fetch == result.num_rows() {
                self.is_closed = true;
            }
        }
        Ok(result)
    }

    /// Return the end index of the second last partition if the batch
    /// can be partitioned based on its already sorted columns
    ///
    /// Return None if the batch cannot be partitioned, which means the
    /// batch does not have the information for a safe sort
    fn get_slice_point(
        &self,
        common_prefix_len: usize,
        batch: &RecordBatch,
    ) -> Result<Option<usize>> {
        // Only evaluate the already-sorted prefix columns; segment
        // boundaries are defined by changes in these values.
        let common_prefix_sort_keys = (0..common_prefix_len)
            .map(|idx| self.expr[idx].evaluate_to_sort_column(batch))
            .collect::<Result<Vec<_>>>()?;
        let partition_points =
            evaluate_partition_ranges(batch.num_rows(), &common_prefix_sort_keys)?;
        // If partition points are [0..100], [100..200], [200..300]
        // we should return 200, which is the safest and furthest partition boundary
        // Please note that we shouldn't return 300 (which is number of rows in the batch),
        // because this boundary may change with new data.
        if partition_points.len() >= 2 {
            Ok(Some(partition_points[partition_points.len() - 2].end))
        } else {
            Ok(None)
        }
    }
}
490
491#[cfg(test)]
492mod tests {
493    use std::collections::HashMap;
494
495    use arrow::array::*;
496    use arrow::compute::SortOptions;
497    use arrow::datatypes::*;
498    use datafusion_common::test_util::batches_to_string;
499    use futures::FutureExt;
500    use insta::allow_duplicates;
501    use insta::assert_snapshot;
502    use itertools::Itertools;
503
504    use crate::collect;
505    use crate::expressions::PhysicalSortExpr;
506    use crate::expressions::col;
507    use crate::sorts::sort::SortExec;
508    use crate::test;
509    use crate::test::TestMemoryExec;
510    use crate::test::assert_is_pending;
511    use crate::test::exec::{BlockingExec, assert_strong_count_converges_to_zero};
512
513    use super::*;
514
    #[tokio::test]
    async fn test_partial_sort() -> Result<()> {
        // Input already sorted on (a, b); sort on (a, b, c) with a common
        // prefix of 2 so only `c` needs sorting within each (a, b) segment.
        let task_ctx = Arc::new(TaskContext::default());
        let source = test::build_table_scan_i32(
            ("a", &vec![0, 0, 0, 1, 1, 1]),
            ("b", &vec![1, 1, 2, 2, 3, 3]),
            ("c", &vec![1, 0, 5, 4, 3, 2]),
        );
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]);
        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };

        let partial_sort_exec = Arc::new(PartialSortExec::new(
            [
                PhysicalSortExpr {
                    expr: col("a", &schema)?,
                    options: option_asc,
                },
                PhysicalSortExpr {
                    expr: col("b", &schema)?,
                    options: option_asc,
                },
                PhysicalSortExpr {
                    expr: col("c", &schema)?,
                    options: option_asc,
                },
            ]
            .into(),
            Arc::clone(&source),
            2,
        ));

        let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;

        // Two output batches: one emitted at the segment boundary, one for
        // the remaining buffered rows when the input ends.
        assert_eq!(2, result.len());
        allow_duplicates! {
            assert_snapshot!(batches_to_string(&result), @r"
            +---+---+---+
            | a | b | c |
            +---+---+---+
            | 0 | 1 | 0 |
            | 0 | 1 | 1 |
            | 0 | 2 | 5 |
            | 1 | 2 | 4 |
            | 1 | 3 | 2 |
            | 1 | 3 | 3 |
            +---+---+---+
            ");
        }
        assert_eq!(
            task_ctx.runtime_env().memory_pool.reserved(),
            0,
            "The sort should have returned all memory used back to the memory manager"
        );

        Ok(())
    }
578
    #[tokio::test]
    async fn test_partial_sort_with_fetch() -> Result<()> {
        // Verifies that a fetch limit truncates output correctly for
        // different common prefix lengths.
        let task_ctx = Arc::new(TaskContext::default());
        let source = test::build_table_scan_i32(
            ("a", &vec![0, 0, 1, 1, 1]),
            ("b", &vec![1, 2, 2, 3, 3]),
            ("c", &vec![4, 3, 2, 1, 0]),
        );
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]);
        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };

        for common_prefix_length in [1, 2] {
            let partial_sort_exec = Arc::new(
                PartialSortExec::new(
                    [
                        PhysicalSortExpr {
                            expr: col("a", &schema)?,
                            options: option_asc,
                        },
                        PhysicalSortExpr {
                            expr: col("b", &schema)?,
                            options: option_asc,
                        },
                        PhysicalSortExpr {
                            expr: col("c", &schema)?,
                            options: option_asc,
                        },
                    ]
                    .into(),
                    Arc::clone(&source),
                    common_prefix_length,
                )
                .with_fetch(Some(4)),
            );

            let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;

            // Either prefix length must yield the same 4 limited rows.
            assert_eq!(2, result.len());
            allow_duplicates! {
                assert_snapshot!(batches_to_string(&result), @r"
                +---+---+---+
                | a | b | c |
                +---+---+---+
                | 0 | 1 | 4 |
                | 0 | 2 | 3 |
                | 1 | 2 | 2 |
                | 1 | 3 | 0 |
                +---+---+---+
                ");
            }
            assert_eq!(
                task_ctx.runtime_env().memory_pool.reserved(),
                0,
                "The sort should have returned all memory used back to the memory manager"
            );
        }

        Ok(())
    }
645
    #[tokio::test]
    async fn test_partial_sort2() -> Result<()> {
        // Two inputs with different existing orderings: the first is sorted
        // on `a` only (prefix 1), the second on (a, b) (prefix 2); both must
        // produce the same fully-sorted output.
        let task_ctx = Arc::new(TaskContext::default());
        let source_tables = [
            test::build_table_scan_i32(
                ("a", &vec![0, 0, 0, 0, 1, 1, 1, 1]),
                ("b", &vec![1, 1, 3, 3, 4, 4, 2, 2]),
                ("c", &vec![7, 6, 5, 4, 3, 2, 1, 0]),
            ),
            test::build_table_scan_i32(
                ("a", &vec![0, 0, 0, 0, 1, 1, 1, 1]),
                ("b", &vec![1, 1, 3, 3, 2, 2, 4, 4]),
                ("c", &vec![7, 6, 5, 4, 1, 0, 3, 2]),
            ),
        ];
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]);
        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };
        for (common_prefix_length, source) in
            [(1, &source_tables[0]), (2, &source_tables[1])]
        {
            let partial_sort_exec = Arc::new(PartialSortExec::new(
                [
                    PhysicalSortExpr {
                        expr: col("a", &schema)?,
                        options: option_asc,
                    },
                    PhysicalSortExpr {
                        expr: col("b", &schema)?,
                        options: option_asc,
                    },
                    PhysicalSortExpr {
                        expr: col("c", &schema)?,
                        options: option_asc,
                    },
                ]
                .into(),
                Arc::clone(source),
                common_prefix_length,
            ));

            let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
            assert_eq!(2, result.len());
            assert_eq!(
                task_ctx.runtime_env().memory_pool.reserved(),
                0,
                "The sort should have returned all memory used back to the memory manager"
            );
            allow_duplicates! {
                assert_snapshot!(batches_to_string(&result), @r"
                +---+---+---+
                | a | b | c |
                +---+---+---+
                | 0 | 1 | 6 |
                | 0 | 1 | 7 |
                | 0 | 3 | 4 |
                | 0 | 3 | 5 |
                | 1 | 2 | 0 |
                | 1 | 2 | 1 |
                | 1 | 4 | 2 |
                | 1 | 4 | 3 |
                +---+---+---+
                ");
            }
        }
        Ok(())
    }
719
720    fn prepare_partitioned_input() -> Arc<dyn ExecutionPlan> {
721        let batch1 = test::build_table_i32(
722            ("a", &vec![1; 100]),
723            ("b", &(0..100).rev().collect()),
724            ("c", &(0..100).rev().collect()),
725        );
726        let batch2 = test::build_table_i32(
727            ("a", &[&vec![1; 25][..], &vec![2; 75][..]].concat()),
728            ("b", &(100..200).rev().collect()),
729            ("c", &(0..100).collect()),
730        );
731        let batch3 = test::build_table_i32(
732            ("a", &[&vec![3; 50][..], &vec![4; 50][..]].concat()),
733            ("b", &(150..250).rev().collect()),
734            ("c", &(0..100).rev().collect()),
735        );
736        let batch4 = test::build_table_i32(
737            ("a", &vec![4; 100]),
738            ("b", &(50..150).rev().collect()),
739            ("c", &(0..100).rev().collect()),
740        );
741        let schema = batch1.schema();
742
743        TestMemoryExec::try_new_exec(
744            &[vec![batch1, batch2, batch3, batch4]],
745            Arc::clone(&schema),
746            None,
747        )
748        .unwrap() as Arc<dyn ExecutionPlan>
749    }
750
    #[tokio::test]
    async fn test_partitioned_input_partial_sort() -> Result<()> {
        // Partial sort over multi-batch input must match a full SortExec
        // over the same input and expressions.
        let task_ctx = Arc::new(TaskContext::default());
        let mem_exec = prepare_partitioned_input();
        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };
        // NOTE(review): despite its name, `option_desc` is configured as
        // ascending (`descending: false`) — identical to `option_asc`. The
        // comparison with SortExec below still holds because both plans use
        // the same options, but confirm whether descending was intended.
        let option_desc = SortOptions {
            descending: false,
            nulls_first: false,
        };
        let schema = mem_exec.schema();
        let partial_sort_exec = PartialSortExec::new(
            [
                PhysicalSortExpr {
                    expr: col("a", &schema)?,
                    options: option_asc,
                },
                PhysicalSortExpr {
                    expr: col("b", &schema)?,
                    options: option_desc,
                },
                PhysicalSortExpr {
                    expr: col("c", &schema)?,
                    options: option_asc,
                },
            ]
            .into(),
            Arc::clone(&mem_exec),
            1,
        );
        // Reference plan: a full sort over the same input and expressions.
        let sort_exec = Arc::new(SortExec::new(
            partial_sort_exec.expr.clone(),
            Arc::clone(&partial_sort_exec.input),
        ));
        let result = collect(Arc::new(partial_sort_exec), Arc::clone(&task_ctx)).await?;
        assert_eq!(
            result.iter().map(|r| r.num_rows()).collect_vec(),
            [125, 125, 150]
        );

        assert_eq!(
            task_ctx.runtime_env().memory_pool.reserved(),
            0,
            "The sort should have returned all memory used back to the memory manager"
        );
        let partial_sort_result = concat_batches(&schema, &result).unwrap();
        let sort_result = collect(sort_exec, Arc::clone(&task_ctx)).await?;
        assert_eq!(sort_result[0], partial_sort_result);

        Ok(())
    }
804
    #[tokio::test]
    async fn test_partitioned_input_partial_sort_with_fetch() -> Result<()> {
        // Partial sort with various fetch limits must match a full SortExec
        // with the same limit, and must emit the expected batch sizes.
        let task_ctx = Arc::new(TaskContext::default());
        let mem_exec = prepare_partitioned_input();
        let schema = mem_exec.schema();
        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };
        // NOTE(review): despite its name, `option_desc` is configured as
        // ascending (`descending: false`) — identical to `option_asc`. The
        // comparison with SortExec below still holds because both plans use
        // the same options, but confirm whether descending was intended.
        let option_desc = SortOptions {
            descending: false,
            nulls_first: false,
        };
        for (fetch_size, expected_batch_num_rows) in [
            (Some(50), vec![50]),
            (Some(120), vec![120]),
            (Some(150), vec![125, 25]),
            (Some(250), vec![125, 125]),
        ] {
            let partial_sort_exec = PartialSortExec::new(
                [
                    PhysicalSortExpr {
                        expr: col("a", &schema)?,
                        options: option_asc,
                    },
                    PhysicalSortExpr {
                        expr: col("b", &schema)?,
                        options: option_desc,
                    },
                    PhysicalSortExpr {
                        expr: col("c", &schema)?,
                        options: option_asc,
                    },
                ]
                .into(),
                Arc::clone(&mem_exec),
                1,
            )
            .with_fetch(fetch_size);

            let sort_exec = Arc::new(
                SortExec::new(
                    partial_sort_exec.expr.clone(),
                    Arc::clone(&partial_sort_exec.input),
                )
                .with_fetch(fetch_size),
            );
            let result =
                collect(Arc::new(partial_sort_exec), Arc::clone(&task_ctx)).await?;
            assert_eq!(
                result.iter().map(|r| r.num_rows()).collect_vec(),
                expected_batch_num_rows
            );

            assert_eq!(
                task_ctx.runtime_env().memory_pool.reserved(),
                0,
                "The sort should have returned all memory used back to the memory manager"
            );
            let partial_sort_result = concat_batches(&schema, &result)?;
            let sort_result = collect(sort_exec, Arc::clone(&task_ctx)).await?;
            assert_eq!(sort_result[0], partial_sort_result);
        }

        Ok(())
    }
871
872    #[tokio::test]
873    async fn test_partial_sort_no_empty_batches() -> Result<()> {
874        let task_ctx = Arc::new(TaskContext::default());
875        let mem_exec = prepare_partitioned_input();
876        let schema = mem_exec.schema();
877        let option_asc = SortOptions {
878            descending: false,
879            nulls_first: false,
880        };
881        let fetch_size = Some(250);
882        let partial_sort_exec = PartialSortExec::new(
883            [
884                PhysicalSortExpr {
885                    expr: col("a", &schema)?,
886                    options: option_asc,
887                },
888                PhysicalSortExpr {
889                    expr: col("c", &schema)?,
890                    options: option_asc,
891                },
892            ]
893            .into(),
894            Arc::clone(&mem_exec),
895            1,
896        )
897        .with_fetch(fetch_size);
898
899        let result = collect(Arc::new(partial_sort_exec), Arc::clone(&task_ctx)).await?;
900        for rb in result {
901            assert!(rb.num_rows() > 0);
902        }
903
904        Ok(())
905    }
906
907    #[tokio::test]
908    async fn test_sort_metadata() -> Result<()> {
909        let task_ctx = Arc::new(TaskContext::default());
910        let field_metadata: HashMap<String, String> =
911            vec![("foo".to_string(), "bar".to_string())]
912                .into_iter()
913                .collect();
914        let schema_metadata: HashMap<String, String> =
915            vec![("baz".to_string(), "barf".to_string())]
916                .into_iter()
917                .collect();
918
919        let mut field = Field::new("field_name", DataType::UInt64, true);
920        field.set_metadata(field_metadata.clone());
921        let schema = Schema::new_with_metadata(vec![field], schema_metadata.clone());
922        let schema = Arc::new(schema);
923
924        let data: ArrayRef =
925            Arc::new(vec![1, 1, 2].into_iter().map(Some).collect::<UInt64Array>());
926
927        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![data])?;
928        let input =
929            TestMemoryExec::try_new_exec(&[vec![batch]], Arc::clone(&schema), None)?;
930
931        let partial_sort_exec = Arc::new(PartialSortExec::new(
932            [PhysicalSortExpr {
933                expr: col("field_name", &schema)?,
934                options: SortOptions::default(),
935            }]
936            .into(),
937            input,
938            1,
939        ));
940
941        let result: Vec<RecordBatch> = collect(partial_sort_exec, task_ctx).await?;
942        let expected_batch = vec![
943            RecordBatch::try_new(
944                Arc::clone(&schema),
945                vec![Arc::new(
946                    vec![1, 1].into_iter().map(Some).collect::<UInt64Array>(),
947                )],
948            )?,
949            RecordBatch::try_new(
950                Arc::clone(&schema),
951                vec![Arc::new(
952                    vec![2].into_iter().map(Some).collect::<UInt64Array>(),
953                )],
954            )?,
955        ];
956
957        // Data is correct
958        assert_eq!(&expected_batch, &result);
959
960        // explicitly ensure the metadata is present
961        assert_eq!(result[0].schema().fields()[0].metadata(), &field_metadata);
962        assert_eq!(result[0].schema().metadata(), &schema_metadata);
963
964        Ok(())
965    }
966
    #[tokio::test]
    async fn test_lex_sort_by_float() -> Result<()> {
        // Lexicographic partial sort over float columns with NULLs and NaNs:
        // `a ASC NULLS FIRST, b ASC NULLS FIRST, c DESC NULLS FIRST`.
        let task_ctx = Arc::new(TaskContext::default());
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Float32, true),
            Field::new("b", DataType::Float64, true),
            Field::new("c", DataType::Float64, true),
        ]));
        let option_asc = SortOptions {
            descending: false,
            nulls_first: true,
        };
        let option_desc = SortOptions {
            descending: true,
            nulls_first: true,
        };

        // define data. The batch is already consistent with `a ASC, b ASC`
        // (note NULLs come before NaN in `b` for the a=3 group, matching
        // NULLS FIRST), which is what justifies the common prefix length of
        // 2 passed to the constructor below; `c` is unsorted within groups.
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Float32Array::from(vec![
                    Some(1.0_f32),
                    Some(1.0_f32),
                    Some(1.0_f32),
                    Some(2.0_f32),
                    Some(2.0_f32),
                    Some(3.0_f32),
                    Some(3.0_f32),
                    Some(3.0_f32),
                ])),
                Arc::new(Float64Array::from(vec![
                    Some(20.0_f64),
                    Some(20.0_f64),
                    Some(40.0_f64),
                    Some(40.0_f64),
                    Some(f64::NAN),
                    None,
                    None,
                    Some(f64::NAN),
                ])),
                Arc::new(Float64Array::from(vec![
                    Some(10.0_f64),
                    Some(20.0_f64),
                    Some(10.0_f64),
                    Some(100.0_f64),
                    Some(f64::NAN),
                    Some(100.0_f64),
                    None,
                    Some(f64::NAN),
                ])),
            ],
        )?;

        // Common prefix length 2: the first two sort expressions (`a`, `b`)
        // are treated as already satisfied by the input ordering.
        let partial_sort_exec = Arc::new(PartialSortExec::new(
            [
                PhysicalSortExpr {
                    expr: col("a", &schema)?,
                    options: option_asc,
                },
                PhysicalSortExpr {
                    expr: col("b", &schema)?,
                    options: option_asc,
                },
                PhysicalSortExpr {
                    expr: col("c", &schema)?,
                    options: option_desc,
                },
            ]
            .into(),
            TestMemoryExec::try_new_exec(&[vec![batch]], schema, None)?,
            2,
        ));

        // The output schema must mirror the input schema.
        assert_eq!(
            DataType::Float32,
            *partial_sort_exec.schema().field(0).data_type()
        );
        assert_eq!(
            DataType::Float64,
            *partial_sort_exec.schema().field(1).data_type()
        );
        assert_eq!(
            DataType::Float64,
            *partial_sort_exec.schema().field(2).data_type()
        );

        let result: Vec<RecordBatch> = collect(
            Arc::clone(&partial_sort_exec) as Arc<dyn ExecutionPlan>,
            task_ctx,
        )
        .await?;
        // Within each (a, b) prefix group, `c` is ordered DESC NULLS FIRST;
        // the snapshot shows NaN ordering after all finite values under ASC
        // (see `b` in the a=2 rows) and after NULLs under NULLS FIRST.
        assert_snapshot!(batches_to_string(&result), @r"
        +-----+------+-------+
        | a   | b    | c     |
        +-----+------+-------+
        | 1.0 | 20.0 | 20.0  |
        | 1.0 | 20.0 | 10.0  |
        | 1.0 | 40.0 | 10.0  |
        | 2.0 | 40.0 | 100.0 |
        | 2.0 | NaN  | NaN   |
        | 3.0 |      |       |
        | 3.0 |      | 100.0 |
        | 3.0 | NaN  | NaN   |
        +-----+------+-------+
        ");
        // NOTE(review): the single input batch is emitted as two output
        // batches — presumably the final prefix group is flushed separately
        // at end-of-stream; confirmed here only via this count assertion.
        assert_eq!(result.len(), 2);
        let metrics = partial_sort_exec.metrics().unwrap();
        assert!(metrics.elapsed_compute().unwrap() > 0);
        assert_eq!(metrics.output_rows().unwrap(), 8);

        let columns = result[0].columns();

        // Column types are preserved end to end.
        assert_eq!(DataType::Float32, *columns[0].data_type());
        assert_eq!(DataType::Float64, *columns[1].data_type());
        assert_eq!(DataType::Float64, *columns[2].data_type());

        Ok(())
    }
1086
1087    #[tokio::test]
1088    async fn test_drop_cancel() -> Result<()> {
1089        let task_ctx = Arc::new(TaskContext::default());
1090        let schema = Arc::new(Schema::new(vec![
1091            Field::new("a", DataType::Float32, true),
1092            Field::new("b", DataType::Float32, true),
1093        ]));
1094
1095        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
1096        let refs = blocking_exec.refs();
1097        let sort_exec = Arc::new(PartialSortExec::new(
1098            [PhysicalSortExpr {
1099                expr: col("a", &schema)?,
1100                options: SortOptions::default(),
1101            }]
1102            .into(),
1103            blocking_exec,
1104            1,
1105        ));
1106
1107        let fut = collect(sort_exec, Arc::clone(&task_ctx));
1108        let mut fut = fut.boxed();
1109
1110        assert_is_pending(&mut fut);
1111        drop(fut);
1112        assert_strong_count_converges_to_zero(refs).await;
1113
1114        assert_eq!(
1115            task_ctx.runtime_env().memory_pool.reserved(),
1116            0,
1117            "The sort should have returned all memory used back to the memory manager"
1118        );
1119
1120        Ok(())
1121    }
1122
1123    #[tokio::test]
1124    async fn test_partial_sort_with_homogeneous_batches() -> Result<()> {
1125        // Test case for the bug where batches with homogeneous sort keys
1126        // (e.g., [1,1,1], [2,2,2]) would not be properly detected as having
1127        // slice points between batches.
1128        let task_ctx = Arc::new(TaskContext::default());
1129
1130        // Create batches where each batch has homogeneous values for sort keys
1131        let batch1 = test::build_table_i32(
1132            ("a", &vec![1; 3]),
1133            ("b", &vec![1; 3]),
1134            ("c", &vec![3, 2, 1]),
1135        );
1136        let batch2 = test::build_table_i32(
1137            ("a", &vec![2; 3]),
1138            ("b", &vec![2; 3]),
1139            ("c", &vec![4, 6, 4]),
1140        );
1141        let batch3 = test::build_table_i32(
1142            ("a", &vec![3; 3]),
1143            ("b", &vec![3; 3]),
1144            ("c", &vec![9, 7, 8]),
1145        );
1146
1147        let schema = batch1.schema();
1148        let mem_exec = TestMemoryExec::try_new_exec(
1149            &[vec![batch1, batch2, batch3]],
1150            Arc::clone(&schema),
1151            None,
1152        )?;
1153
1154        let option_asc = SortOptions {
1155            descending: false,
1156            nulls_first: false,
1157        };
1158
1159        // Partial sort with common prefix of 2 (sorting by a, b, c)
1160        let partial_sort_exec = Arc::new(PartialSortExec::new(
1161            [
1162                PhysicalSortExpr {
1163                    expr: col("a", &schema)?,
1164                    options: option_asc,
1165                },
1166                PhysicalSortExpr {
1167                    expr: col("b", &schema)?,
1168                    options: option_asc,
1169                },
1170                PhysicalSortExpr {
1171                    expr: col("c", &schema)?,
1172                    options: option_asc,
1173                },
1174            ]
1175            .into(),
1176            mem_exec,
1177            2,
1178        ));
1179
1180        let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
1181
1182        assert_eq!(result.len(), 3,);
1183
1184        allow_duplicates! {
1185            assert_snapshot!(batches_to_string(&result), @r"
1186            +---+---+---+
1187            | a | b | c |
1188            +---+---+---+
1189            | 1 | 1 | 1 |
1190            | 1 | 1 | 2 |
1191            | 1 | 1 | 3 |
1192            | 2 | 2 | 4 |
1193            | 2 | 2 | 4 |
1194            | 2 | 2 | 6 |
1195            | 3 | 3 | 7 |
1196            | 3 | 3 | 8 |
1197            | 3 | 3 | 9 |
1198            +---+---+---+
1199            ");
1200        }
1201
1202        assert_eq!(task_ctx.runtime_env().memory_pool.reserved(), 0,);
1203        Ok(())
1204    }
1205}