// datafusion_physical_plan/sorts/partial_sort.rs
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
//! Partial Sort deals with input data that partially
//! satisfies the required sort order. Such an input data can be
//! partitioned into segments where each segment already has the
//! required information for lexicographic sorting so sorting
//! can be done without loading the entire dataset.
//!
//! Consider a sort plan having an input with ordering `a ASC, b ASC`
//!
//! ```text
//! +---+---+---+
//! | a | b | d |
//! +---+---+---+
//! | 0 | 0 | 3 |
//! | 0 | 0 | 2 |
//! | 0 | 1 | 1 |
//! | 0 | 2 | 0 |
//! +---+---+---+
//! ```
//!
//! and required ordering for the plan is `a ASC, b ASC, d ASC`.
//! The first 3 rows (a segment) can be sorted as the segment already
//! has the required information for the sort, but the last row
//! requires further information as the input can continue with a
//! batch with a starting row where a and b do not change, as below
//!
//! ```text
//! +---+---+---+
//! | a | b | d |
//! +---+---+---+
//! | 0 | 2 | 4 |
//! +---+---+---+
//! ```
//!
//! The plan concatenates incoming data with such last rows of previous input
//! and continues partial sorting of the segments.
53
54use std::any::Any;
55use std::fmt::Debug;
56use std::pin::Pin;
57use std::sync::Arc;
58use std::task::{Context, Poll};
59
60use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
61use crate::sorts::sort::sort_batch;
62use crate::{
63    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
64    Partitioning, PlanProperties, SendableRecordBatchStream, Statistics,
65};
66
67use arrow::compute::concat_batches;
68use arrow::datatypes::SchemaRef;
69use arrow::record_batch::RecordBatch;
70use datafusion_common::Result;
71use datafusion_common::utils::evaluate_partition_ranges;
72use datafusion_execution::{RecordBatchStream, TaskContext};
73use datafusion_physical_expr::LexOrdering;
74
75use futures::{Stream, StreamExt, ready};
76use log::trace;
77
/// Partial Sort execution plan.
///
/// Sorts data whose ordering already satisfies the first
/// `common_prefix_length` sort expressions, so the data can be sorted and
/// emitted segment by segment instead of buffering the entire input.
#[derive(Debug, Clone)]
pub struct PartialSortExec {
    /// The input execution plan
    pub(crate) input: Arc<dyn ExecutionPlan>,
    /// Sort expressions
    expr: LexOrdering,
    /// Length of continuous matching columns of input that satisfy
    /// the required ordering for the sort; must be > 0, otherwise a full
    /// `SortExec` should be used instead
    common_prefix_length: usize,
    /// Containing all metrics set created during sort
    metrics_set: ExecutionPlanMetricsSet,
    /// Preserve partitions of input plan. If false, the input partitions
    /// will be sorted and merged into a single output partition.
    preserve_partitioning: bool,
    /// Fetch highest/lowest n results
    fetch: Option<usize>,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
}
98
99impl PartialSortExec {
100    /// Create a new partial sort execution plan
101    pub fn new(
102        expr: LexOrdering,
103        input: Arc<dyn ExecutionPlan>,
104        common_prefix_length: usize,
105    ) -> Self {
106        debug_assert!(common_prefix_length > 0);
107        let preserve_partitioning = false;
108        let cache = Self::compute_properties(&input, expr.clone(), preserve_partitioning)
109            .unwrap();
110        Self {
111            input,
112            expr,
113            common_prefix_length,
114            metrics_set: ExecutionPlanMetricsSet::new(),
115            preserve_partitioning,
116            fetch: None,
117            cache,
118        }
119    }
120
121    /// Whether this `PartialSortExec` preserves partitioning of the children
122    pub fn preserve_partitioning(&self) -> bool {
123        self.preserve_partitioning
124    }
125
126    /// Specify the partitioning behavior of this partial sort exec
127    ///
128    /// If `preserve_partitioning` is true, sorts each partition
129    /// individually, producing one sorted stream for each input partition.
130    ///
131    /// If `preserve_partitioning` is false, sorts and merges all
132    /// input partitions producing a single, sorted partition.
133    pub fn with_preserve_partitioning(mut self, preserve_partitioning: bool) -> Self {
134        self.preserve_partitioning = preserve_partitioning;
135        self.cache = self
136            .cache
137            .with_partitioning(Self::output_partitioning_helper(
138                &self.input,
139                self.preserve_partitioning,
140            ));
141        self
142    }
143
144    /// Modify how many rows to include in the result
145    ///
146    /// If None, then all rows will be returned, in sorted order.
147    /// If Some, then only the top `fetch` rows will be returned.
148    /// This can reduce the memory pressure required by the sort
149    /// operation since rows that are not going to be included
150    /// can be dropped.
151    pub fn with_fetch(mut self, fetch: Option<usize>) -> Self {
152        self.fetch = fetch;
153        self
154    }
155
156    /// Input schema
157    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
158        &self.input
159    }
160
161    /// Sort expressions
162    pub fn expr(&self) -> &LexOrdering {
163        &self.expr
164    }
165
166    /// If `Some(fetch)`, limits output to only the first "fetch" items
167    pub fn fetch(&self) -> Option<usize> {
168        self.fetch
169    }
170
171    /// Common prefix length
172    pub fn common_prefix_length(&self) -> usize {
173        self.common_prefix_length
174    }
175
176    fn output_partitioning_helper(
177        input: &Arc<dyn ExecutionPlan>,
178        preserve_partitioning: bool,
179    ) -> Partitioning {
180        // Get output partitioning:
181        if preserve_partitioning {
182            input.output_partitioning().clone()
183        } else {
184            Partitioning::UnknownPartitioning(1)
185        }
186    }
187
188    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
189    fn compute_properties(
190        input: &Arc<dyn ExecutionPlan>,
191        sort_exprs: LexOrdering,
192        preserve_partitioning: bool,
193    ) -> Result<PlanProperties> {
194        // Calculate equivalence properties; i.e. reset the ordering equivalence
195        // class with the new ordering:
196        let mut eq_properties = input.equivalence_properties().clone();
197        eq_properties.reorder(sort_exprs)?;
198
199        // Get output partitioning:
200        let output_partitioning =
201            Self::output_partitioning_helper(input, preserve_partitioning);
202
203        Ok(PlanProperties::new(
204            eq_properties,
205            output_partitioning,
206            input.pipeline_behavior(),
207            input.boundedness(),
208        ))
209    }
210}
211
212impl DisplayAs for PartialSortExec {
213    fn fmt_as(
214        &self,
215        t: DisplayFormatType,
216        f: &mut std::fmt::Formatter,
217    ) -> std::fmt::Result {
218        match t {
219            DisplayFormatType::Default | DisplayFormatType::Verbose => {
220                let common_prefix_length = self.common_prefix_length;
221                match self.fetch {
222                    Some(fetch) => {
223                        write!(
224                            f,
225                            "PartialSortExec: TopK(fetch={fetch}), expr=[{}], common_prefix_length=[{common_prefix_length}]",
226                            self.expr
227                        )
228                    }
229                    None => write!(
230                        f,
231                        "PartialSortExec: expr=[{}], common_prefix_length=[{common_prefix_length}]",
232                        self.expr
233                    ),
234                }
235            }
236            DisplayFormatType::TreeRender => match self.fetch {
237                Some(fetch) => {
238                    writeln!(f, "{}", self.expr)?;
239                    writeln!(f, "limit={fetch}")
240                }
241                None => {
242                    writeln!(f, "{}", self.expr)
243                }
244            },
245        }
246    }
247}
248
249impl ExecutionPlan for PartialSortExec {
250    fn name(&self) -> &'static str {
251        "PartialSortExec"
252    }
253
254    fn as_any(&self) -> &dyn Any {
255        self
256    }
257
258    fn properties(&self) -> &PlanProperties {
259        &self.cache
260    }
261
262    fn fetch(&self) -> Option<usize> {
263        self.fetch
264    }
265
266    fn required_input_distribution(&self) -> Vec<Distribution> {
267        if self.preserve_partitioning {
268            vec![Distribution::UnspecifiedDistribution]
269        } else {
270            vec![Distribution::SinglePartition]
271        }
272    }
273
274    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
275        vec![false]
276    }
277
278    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
279        vec![&self.input]
280    }
281
282    fn with_new_children(
283        self: Arc<Self>,
284        children: Vec<Arc<dyn ExecutionPlan>>,
285    ) -> Result<Arc<dyn ExecutionPlan>> {
286        let new_partial_sort = PartialSortExec::new(
287            self.expr.clone(),
288            Arc::clone(&children[0]),
289            self.common_prefix_length,
290        )
291        .with_fetch(self.fetch)
292        .with_preserve_partitioning(self.preserve_partitioning);
293
294        Ok(Arc::new(new_partial_sort))
295    }
296
297    fn execute(
298        &self,
299        partition: usize,
300        context: Arc<TaskContext>,
301    ) -> Result<SendableRecordBatchStream> {
302        trace!(
303            "Start PartialSortExec::execute for partition {} of context session_id {} and task_id {:?}",
304            partition,
305            context.session_id(),
306            context.task_id()
307        );
308
309        let input = self.input.execute(partition, Arc::clone(&context))?;
310
311        trace!("End PartialSortExec's input.execute for partition: {partition}");
312
313        // Make sure common prefix length is larger than 0
314        // Otherwise, we should use SortExec.
315        debug_assert!(self.common_prefix_length > 0);
316
317        Ok(Box::pin(PartialSortStream {
318            input,
319            expr: self.expr.clone(),
320            common_prefix_length: self.common_prefix_length,
321            in_mem_batch: RecordBatch::new_empty(Arc::clone(&self.schema())),
322            fetch: self.fetch,
323            is_closed: false,
324            baseline_metrics: BaselineMetrics::new(&self.metrics_set, partition),
325        }))
326    }
327
328    fn metrics(&self) -> Option<MetricsSet> {
329        Some(self.metrics_set.clone_inner())
330    }
331
332    fn statistics(&self) -> Result<Statistics> {
333        self.input.partition_statistics(None)
334    }
335
336    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
337        self.input.partition_statistics(partition)
338    }
339}
340
/// Streaming state machine that performs the partial sort: it buffers input
/// until a safe segment boundary is found, then sorts and emits that segment.
struct PartialSortStream {
    /// The input plan
    input: SendableRecordBatchStream,
    /// Sort expressions
    expr: LexOrdering,
    /// Length of prefix common to input ordering and required ordering of plan
    /// should be more than 0 otherwise PartialSort is not applicable
    common_prefix_length: usize,
    /// Used as a buffer for part of the input not ready for sort
    in_mem_batch: RecordBatch,
    /// Fetch top N results; decremented as sorted rows are emitted
    fetch: Option<usize>,
    /// Whether the stream has finished returning all of its data or not
    is_closed: bool,
    /// Execution metrics
    baseline_metrics: BaselineMetrics,
}
358
359impl Stream for PartialSortStream {
360    type Item = Result<RecordBatch>;
361
362    fn poll_next(
363        mut self: Pin<&mut Self>,
364        cx: &mut Context<'_>,
365    ) -> Poll<Option<Self::Item>> {
366        let poll = self.poll_next_inner(cx);
367        self.baseline_metrics.record_poll(poll)
368    }
369
370    fn size_hint(&self) -> (usize, Option<usize>) {
371        // we can't predict the size of incoming batches so re-use the size hint from the input
372        self.input.size_hint()
373    }
374}
375
impl RecordBatchStream for PartialSortStream {
    /// Sorting only reorders rows, so the output schema is the input schema.
    fn schema(&self) -> SchemaRef {
        self.input.schema()
    }
}
381
impl PartialSortStream {
    /// Inner poll loop: accumulates input batches in `in_mem_batch` until a
    /// safe slice point is found, then sorts and emits the completed prefix.
    fn poll_next_inner(
        self: &mut Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        if self.is_closed {
            return Poll::Ready(None);
        }
        loop {
            // Check if we've already reached the fetch limit
            if self.fetch == Some(0) {
                self.is_closed = true;
                return Poll::Ready(None);
            }

            match ready!(self.input.poll_next_unpin(cx)) {
                Some(Ok(batch)) => {
                    // Merge new batch into in_mem_batch
                    self.in_mem_batch = concat_batches(
                        &self.schema(),
                        &[self.in_mem_batch.clone(), batch],
                    )?;

                    // Check if we have a slice point, otherwise keep accumulating in `self.in_mem_batch`.
                    if let Some(slice_point) = self
                        .get_slice_point(self.common_prefix_length, &self.in_mem_batch)?
                    {
                        // Rows before the slice point are complete: no later
                        // input can sort before them, so sort and emit them
                        // now; the remainder stays buffered.
                        let sorted = self.in_mem_batch.slice(0, slice_point);
                        self.in_mem_batch = self.in_mem_batch.slice(
                            slice_point,
                            self.in_mem_batch.num_rows() - slice_point,
                        );
                        let sorted_batch = sort_batch(&sorted, &self.expr, self.fetch)?;
                        // `sort_batch` with a limit returns at most `fetch`
                        // rows, so this subtraction cannot underflow.
                        if let Some(fetch) = self.fetch.as_mut() {
                            *fetch -= sorted_batch.num_rows();
                        }

                        // Defensive: never forward an empty batch downstream.
                        if sorted_batch.num_rows() > 0 {
                            return Poll::Ready(Some(Ok(sorted_batch)));
                        }
                    }
                }
                Some(Err(e)) => return Poll::Ready(Some(Err(e))),
                None => {
                    self.is_closed = true;
                    // Once input is consumed, sort the rest of the inserted batches
                    let remaining_batch = self.sort_in_mem_batch()?;
                    return if remaining_batch.num_rows() > 0 {
                        Poll::Ready(Some(Ok(remaining_batch)))
                    } else {
                        Poll::Ready(None)
                    };
                }
            };
        }
    }

    /// Returns a sorted RecordBatch from in_mem_batches and clears in_mem_batches
    ///
    /// If fetch is specified for PartialSortStream `sort_in_mem_batch` will limit
    /// the last RecordBatch returned and will mark the stream as closed
    fn sort_in_mem_batch(self: &mut Pin<&mut Self>) -> Result<RecordBatch> {
        let input_batch = self.in_mem_batch.clone();
        // Reset the buffer so the batch is not emitted twice.
        self.in_mem_batch = RecordBatch::new_empty(self.schema());
        let result = sort_batch(&input_batch, &self.expr, self.fetch)?;
        if let Some(remaining_fetch) = self.fetch {
            // remaining_fetch - result.num_rows() is always >= 0
            // because result length of sort_batch with limit cannot be
            // more than the requested limit
            self.fetch = Some(remaining_fetch - result.num_rows());
            if remaining_fetch == result.num_rows() {
                self.is_closed = true;
            }
        }
        Ok(result)
    }

    /// Return the end index of the second last partition if the batch
    /// can be partitioned based on its already sorted columns
    ///
    /// Return None if the batch cannot be partitioned, which means the
    /// batch does not have the information for a safe sort
    fn get_slice_point(
        &self,
        common_prefix_len: usize,
        batch: &RecordBatch,
    ) -> Result<Option<usize>> {
        // Evaluate only the already-sorted prefix columns to find segment
        // boundaries within the buffered batch.
        let common_prefix_sort_keys = (0..common_prefix_len)
            .map(|idx| self.expr[idx].evaluate_to_sort_column(batch))
            .collect::<Result<Vec<_>>>()?;
        let partition_points =
            evaluate_partition_ranges(batch.num_rows(), &common_prefix_sort_keys)?;
        // If partition points are [0..100], [100..200], [200..300]
        // we should return 200, which is the safest and furthest partition boundary
        // Please note that we shouldn't return 300 (which is number of rows in the batch),
        // because this boundary may change with new data.
        if partition_points.len() >= 2 {
            Ok(Some(partition_points[partition_points.len() - 2].end))
        } else {
            Ok(None)
        }
    }
}
485
486#[cfg(test)]
487mod tests {
488    use std::collections::HashMap;
489
490    use arrow::array::*;
491    use arrow::compute::SortOptions;
492    use arrow::datatypes::*;
493    use datafusion_common::test_util::batches_to_string;
494    use futures::FutureExt;
495    use insta::allow_duplicates;
496    use insta::assert_snapshot;
497    use itertools::Itertools;
498
499    use crate::collect;
500    use crate::expressions::PhysicalSortExpr;
501    use crate::expressions::col;
502    use crate::sorts::sort::SortExec;
503    use crate::test;
504    use crate::test::TestMemoryExec;
505    use crate::test::assert_is_pending;
506    use crate::test::exec::{BlockingExec, assert_strong_count_converges_to_zero};
507
508    use super::*;
509
    #[tokio::test]
    async fn test_partial_sort() -> Result<()> {
        // Input is ordered by (a, b); sorting on (a, b, c) with a common
        // prefix length of 2 means only `c` must be sorted within segments.
        let task_ctx = Arc::new(TaskContext::default());
        let source = test::build_table_scan_i32(
            ("a", &vec![0, 0, 0, 1, 1, 1]),
            ("b", &vec![1, 1, 2, 2, 3, 3]),
            ("c", &vec![1, 0, 5, 4, 3, 2]),
        );
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]);
        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };

        let partial_sort_exec = Arc::new(PartialSortExec::new(
            [
                PhysicalSortExpr {
                    expr: col("a", &schema)?,
                    options: option_asc,
                },
                PhysicalSortExpr {
                    expr: col("b", &schema)?,
                    options: option_asc,
                },
                PhysicalSortExpr {
                    expr: col("c", &schema)?,
                    options: option_asc,
                },
            ]
            .into(),
            Arc::clone(&source),
            2,
        ));

        let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;

        // Two segment boundaries on (a, b) => two output batches.
        assert_eq!(2, result.len());
        allow_duplicates! {
            assert_snapshot!(batches_to_string(&result), @r"
            +---+---+---+
            | a | b | c |
            +---+---+---+
            | 0 | 1 | 0 |
            | 0 | 1 | 1 |
            | 0 | 2 | 5 |
            | 1 | 2 | 4 |
            | 1 | 3 | 2 |
            | 1 | 3 | 3 |
            +---+---+---+
            ");
        }
        assert_eq!(
            task_ctx.runtime_env().memory_pool.reserved(),
            0,
            "The sort should have returned all memory used back to the memory manager"
        );

        Ok(())
    }
573
    #[tokio::test]
    async fn test_partial_sort_with_fetch() -> Result<()> {
        // Verifies the fetch (top-N) limit is honored for different common
        // prefix lengths over the same input.
        let task_ctx = Arc::new(TaskContext::default());
        let source = test::build_table_scan_i32(
            ("a", &vec![0, 0, 1, 1, 1]),
            ("b", &vec![1, 2, 2, 3, 3]),
            ("c", &vec![4, 3, 2, 1, 0]),
        );
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]);
        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };

        for common_prefix_length in [1, 2] {
            let partial_sort_exec = Arc::new(
                PartialSortExec::new(
                    [
                        PhysicalSortExpr {
                            expr: col("a", &schema)?,
                            options: option_asc,
                        },
                        PhysicalSortExpr {
                            expr: col("b", &schema)?,
                            options: option_asc,
                        },
                        PhysicalSortExpr {
                            expr: col("c", &schema)?,
                            options: option_asc,
                        },
                    ]
                    .into(),
                    Arc::clone(&source),
                    common_prefix_length,
                )
                .with_fetch(Some(4)),
            );

            let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;

            assert_eq!(2, result.len());
            allow_duplicates! {
                assert_snapshot!(batches_to_string(&result), @r"
                +---+---+---+
                | a | b | c |
                +---+---+---+
                | 0 | 1 | 4 |
                | 0 | 2 | 3 |
                | 1 | 2 | 2 |
                | 1 | 3 | 0 |
                +---+---+---+
                ");
            }
            assert_eq!(
                task_ctx.runtime_env().memory_pool.reserved(),
                0,
                "The sort should have returned all memory used back to the memory manager"
            );
        }

        Ok(())
    }
640
    #[tokio::test]
    async fn test_partial_sort2() -> Result<()> {
        // Two sources with different pre-sorted prefixes: the first is only
        // ordered on `a` (prefix 1), the second on (a, b) (prefix 2); both
        // should produce the same fully sorted output.
        let task_ctx = Arc::new(TaskContext::default());
        let source_tables = [
            test::build_table_scan_i32(
                ("a", &vec![0, 0, 0, 0, 1, 1, 1, 1]),
                ("b", &vec![1, 1, 3, 3, 4, 4, 2, 2]),
                ("c", &vec![7, 6, 5, 4, 3, 2, 1, 0]),
            ),
            test::build_table_scan_i32(
                ("a", &vec![0, 0, 0, 0, 1, 1, 1, 1]),
                ("b", &vec![1, 1, 3, 3, 2, 2, 4, 4]),
                ("c", &vec![7, 6, 5, 4, 1, 0, 3, 2]),
            ),
        ];
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]);
        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };
        for (common_prefix_length, source) in
            [(1, &source_tables[0]), (2, &source_tables[1])]
        {
            let partial_sort_exec = Arc::new(PartialSortExec::new(
                [
                    PhysicalSortExpr {
                        expr: col("a", &schema)?,
                        options: option_asc,
                    },
                    PhysicalSortExpr {
                        expr: col("b", &schema)?,
                        options: option_asc,
                    },
                    PhysicalSortExpr {
                        expr: col("c", &schema)?,
                        options: option_asc,
                    },
                ]
                .into(),
                Arc::clone(source),
                common_prefix_length,
            ));

            let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
            assert_eq!(2, result.len());
            assert_eq!(
                task_ctx.runtime_env().memory_pool.reserved(),
                0,
                "The sort should have returned all memory used back to the memory manager"
            );
            allow_duplicates! {
                assert_snapshot!(batches_to_string(&result), @r"
                +---+---+---+
                | a | b | c |
                +---+---+---+
                | 0 | 1 | 6 |
                | 0 | 1 | 7 |
                | 0 | 3 | 4 |
                | 0 | 3 | 5 |
                | 1 | 2 | 0 |
                | 1 | 2 | 1 |
                | 1 | 4 | 2 |
                | 1 | 4 | 3 |
                +---+---+---+
                ");
            }
        }
        Ok(())
    }
714
    /// Builds a single-partition memory source of four batches whose `a`
    /// column is non-decreasing across batches (1s; 1/2; 3/4; 4s), so a
    /// one-column prefix of the sort order is already satisfied.
    fn prepare_partitioned_input() -> Arc<dyn ExecutionPlan> {
        let batch1 = test::build_table_i32(
            ("a", &vec![1; 100]),
            ("b", &(0..100).rev().collect()),
            ("c", &(0..100).rev().collect()),
        );
        let batch2 = test::build_table_i32(
            ("a", &[&vec![1; 25][..], &vec![2; 75][..]].concat()),
            ("b", &(100..200).rev().collect()),
            ("c", &(0..100).collect()),
        );
        let batch3 = test::build_table_i32(
            ("a", &[&vec![3; 50][..], &vec![4; 50][..]].concat()),
            ("b", &(150..250).rev().collect()),
            ("c", &(0..100).rev().collect()),
        );
        let batch4 = test::build_table_i32(
            ("a", &vec![4; 100]),
            ("b", &(50..150).rev().collect()),
            ("c", &(0..100).rev().collect()),
        );
        let schema = batch1.schema();

        TestMemoryExec::try_new_exec(
            &[vec![batch1, batch2, batch3, batch4]],
            Arc::clone(&schema),
            None,
        )
        .unwrap() as Arc<dyn ExecutionPlan>
    }
745
    #[tokio::test]
    async fn test_partitioned_input_partial_sort() -> Result<()> {
        // Compares partial sort output against a full SortExec on the same
        // input and expressions; they must agree row for row.
        let task_ctx = Arc::new(TaskContext::default());
        let mem_exec = prepare_partitioned_input();
        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };
        let option_desc = SortOptions {
            // NOTE(review): named `option_desc` but `descending` is false, so
            // this actually sorts ascending — likely intended
            // `descending: true`; verify against upstream.
            descending: false,
            nulls_first: false,
        };
        let schema = mem_exec.schema();
        let partial_sort_exec = PartialSortExec::new(
            [
                PhysicalSortExpr {
                    expr: col("a", &schema)?,
                    options: option_asc,
                },
                PhysicalSortExpr {
                    expr: col("b", &schema)?,
                    options: option_desc,
                },
                PhysicalSortExpr {
                    expr: col("c", &schema)?,
                    options: option_asc,
                },
            ]
            .into(),
            Arc::clone(&mem_exec),
            1,
        );
        let sort_exec = Arc::new(SortExec::new(
            partial_sort_exec.expr.clone(),
            Arc::clone(&partial_sort_exec.input),
        ));
        let result = collect(Arc::new(partial_sort_exec), Arc::clone(&task_ctx)).await?;
        // Segment boundaries on `a` yield batches of 125, 125, and 150 rows.
        assert_eq!(
            result.iter().map(|r| r.num_rows()).collect_vec(),
            [125, 125, 150]
        );

        assert_eq!(
            task_ctx.runtime_env().memory_pool.reserved(),
            0,
            "The sort should have returned all memory used back to the memory manager"
        );
        let partial_sort_result = concat_batches(&schema, &result).unwrap();
        let sort_result = collect(sort_exec, Arc::clone(&task_ctx)).await?;
        assert_eq!(sort_result[0], partial_sort_result);

        Ok(())
    }
799
    #[tokio::test]
    async fn test_partitioned_input_partial_sort_with_fetch() -> Result<()> {
        // Same comparison against SortExec as above, but with a range of
        // fetch limits and their expected per-batch row counts.
        let task_ctx = Arc::new(TaskContext::default());
        let mem_exec = prepare_partitioned_input();
        let schema = mem_exec.schema();
        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };
        let option_desc = SortOptions {
            // NOTE(review): named `option_desc` but `descending` is false, so
            // this actually sorts ascending — likely intended
            // `descending: true`; verify against upstream.
            descending: false,
            nulls_first: false,
        };
        for (fetch_size, expected_batch_num_rows) in [
            (Some(50), vec![50]),
            (Some(120), vec![120]),
            (Some(150), vec![125, 25]),
            (Some(250), vec![125, 125]),
        ] {
            let partial_sort_exec = PartialSortExec::new(
                [
                    PhysicalSortExpr {
                        expr: col("a", &schema)?,
                        options: option_asc,
                    },
                    PhysicalSortExpr {
                        expr: col("b", &schema)?,
                        options: option_desc,
                    },
                    PhysicalSortExpr {
                        expr: col("c", &schema)?,
                        options: option_asc,
                    },
                ]
                .into(),
                Arc::clone(&mem_exec),
                1,
            )
            .with_fetch(fetch_size);

            let sort_exec = Arc::new(
                SortExec::new(
                    partial_sort_exec.expr.clone(),
                    Arc::clone(&partial_sort_exec.input),
                )
                .with_fetch(fetch_size),
            );
            let result =
                collect(Arc::new(partial_sort_exec), Arc::clone(&task_ctx)).await?;
            assert_eq!(
                result.iter().map(|r| r.num_rows()).collect_vec(),
                expected_batch_num_rows
            );

            assert_eq!(
                task_ctx.runtime_env().memory_pool.reserved(),
                0,
                "The sort should have returned all memory used back to the memory manager"
            );
            let partial_sort_result = concat_batches(&schema, &result)?;
            let sort_result = collect(sort_exec, Arc::clone(&task_ctx)).await?;
            assert_eq!(sort_result[0], partial_sort_result);
        }

        Ok(())
    }
866
867    #[tokio::test]
868    async fn test_partial_sort_no_empty_batches() -> Result<()> {
869        let task_ctx = Arc::new(TaskContext::default());
870        let mem_exec = prepare_partitioned_input();
871        let schema = mem_exec.schema();
872        let option_asc = SortOptions {
873            descending: false,
874            nulls_first: false,
875        };
876        let fetch_size = Some(250);
877        let partial_sort_exec = PartialSortExec::new(
878            [
879                PhysicalSortExpr {
880                    expr: col("a", &schema)?,
881                    options: option_asc,
882                },
883                PhysicalSortExpr {
884                    expr: col("c", &schema)?,
885                    options: option_asc,
886                },
887            ]
888            .into(),
889            Arc::clone(&mem_exec),
890            1,
891        )
892        .with_fetch(fetch_size);
893
894        let result = collect(Arc::new(partial_sort_exec), Arc::clone(&task_ctx)).await?;
895        for rb in result {
896            assert!(rb.num_rows() > 0);
897        }
898
899        Ok(())
900    }
901
902    #[tokio::test]
903    async fn test_sort_metadata() -> Result<()> {
904        let task_ctx = Arc::new(TaskContext::default());
905        let field_metadata: HashMap<String, String> =
906            vec![("foo".to_string(), "bar".to_string())]
907                .into_iter()
908                .collect();
909        let schema_metadata: HashMap<String, String> =
910            vec![("baz".to_string(), "barf".to_string())]
911                .into_iter()
912                .collect();
913
914        let mut field = Field::new("field_name", DataType::UInt64, true);
915        field.set_metadata(field_metadata.clone());
916        let schema = Schema::new_with_metadata(vec![field], schema_metadata.clone());
917        let schema = Arc::new(schema);
918
919        let data: ArrayRef =
920            Arc::new(vec![1, 1, 2].into_iter().map(Some).collect::<UInt64Array>());
921
922        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![data])?;
923        let input =
924            TestMemoryExec::try_new_exec(&[vec![batch]], Arc::clone(&schema), None)?;
925
926        let partial_sort_exec = Arc::new(PartialSortExec::new(
927            [PhysicalSortExpr {
928                expr: col("field_name", &schema)?,
929                options: SortOptions::default(),
930            }]
931            .into(),
932            input,
933            1,
934        ));
935
936        let result: Vec<RecordBatch> = collect(partial_sort_exec, task_ctx).await?;
937        let expected_batch = vec![
938            RecordBatch::try_new(
939                Arc::clone(&schema),
940                vec![Arc::new(
941                    vec![1, 1].into_iter().map(Some).collect::<UInt64Array>(),
942                )],
943            )?,
944            RecordBatch::try_new(
945                Arc::clone(&schema),
946                vec![Arc::new(
947                    vec![2].into_iter().map(Some).collect::<UInt64Array>(),
948                )],
949            )?,
950        ];
951
952        // Data is correct
953        assert_eq!(&expected_batch, &result);
954
955        // explicitly ensure the metadata is present
956        assert_eq!(result[0].schema().fields()[0].metadata(), &field_metadata);
957        assert_eq!(result[0].schema().metadata(), &schema_metadata);
958
959        Ok(())
960    }
961
    #[tokio::test]
    async fn test_lex_sort_by_float() -> Result<()> {
        // Exercises lexicographic partial sort over floating-point columns,
        // including NaN and NULL values, with a mix of ascending and
        // descending sort options.
        let task_ctx = Arc::new(TaskContext::default());
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Float32, true),
            Field::new("b", DataType::Float64, true),
            Field::new("c", DataType::Float64, true),
        ]));
        // Ascending, NULLs first — used for the common-prefix columns a and b.
        let option_asc = SortOptions {
            descending: false,
            nulls_first: true,
        };
        // Descending, NULLs first — used for the trailing column c.
        let option_desc = SortOptions {
            descending: true,
            nulls_first: true,
        };

        // define data.
        // Columns a and b are already ordered (ascending) so they can serve
        // as the common prefix; column c is unsorted within each (a, b) group.
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Float32Array::from(vec![
                    Some(1.0_f32),
                    Some(1.0_f32),
                    Some(1.0_f32),
                    Some(2.0_f32),
                    Some(2.0_f32),
                    Some(3.0_f32),
                    Some(3.0_f32),
                    Some(3.0_f32),
                ])),
                Arc::new(Float64Array::from(vec![
                    Some(20.0_f64),
                    Some(20.0_f64),
                    Some(40.0_f64),
                    Some(40.0_f64),
                    Some(f64::NAN),
                    None,
                    None,
                    Some(f64::NAN),
                ])),
                Arc::new(Float64Array::from(vec![
                    Some(10.0_f64),
                    Some(20.0_f64),
                    Some(10.0_f64),
                    Some(100.0_f64),
                    Some(f64::NAN),
                    Some(100.0_f64),
                    None,
                    Some(f64::NAN),
                ])),
            ],
        )?;

        // Sort by (a ASC, b ASC, c DESC) with common prefix length 2: only c
        // needs sorting within each (a, b) segment.
        let partial_sort_exec = Arc::new(PartialSortExec::new(
            [
                PhysicalSortExpr {
                    expr: col("a", &schema)?,
                    options: option_asc,
                },
                PhysicalSortExpr {
                    expr: col("b", &schema)?,
                    options: option_asc,
                },
                PhysicalSortExpr {
                    expr: col("c", &schema)?,
                    options: option_desc,
                },
            ]
            .into(),
            TestMemoryExec::try_new_exec(&[vec![batch]], schema, None)?,
            2,
        ));

        // The sort must not alter the output schema's column types.
        assert_eq!(
            DataType::Float32,
            *partial_sort_exec.schema().field(0).data_type()
        );
        assert_eq!(
            DataType::Float64,
            *partial_sort_exec.schema().field(1).data_type()
        );
        assert_eq!(
            DataType::Float64,
            *partial_sort_exec.schema().field(2).data_type()
        );

        let result: Vec<RecordBatch> = collect(
            Arc::clone(&partial_sort_exec) as Arc<dyn ExecutionPlan>,
            task_ctx,
        )
        .await?;
        // Snapshot pins the expected NULL (blank) and NaN placement: within
        // each (a, b) group, c is descending with NULLs first.
        assert_snapshot!(batches_to_string(&result), @r"
        +-----+------+-------+
        | a   | b    | c     |
        +-----+------+-------+
        | 1.0 | 20.0 | 20.0  |
        | 1.0 | 20.0 | 10.0  |
        | 1.0 | 40.0 | 10.0  |
        | 2.0 | 40.0 | 100.0 |
        | 2.0 | NaN  | NaN   |
        | 3.0 |      |       |
        | 3.0 |      | 100.0 |
        | 3.0 | NaN  | NaN   |
        +-----+------+-------+
        ");
        // The 8 input rows are emitted as 2 output batches.
        assert_eq!(result.len(), 2);
        let metrics = partial_sort_exec.metrics().unwrap();
        assert!(metrics.elapsed_compute().unwrap() > 0);
        assert_eq!(metrics.output_rows().unwrap(), 8);

        let columns = result[0].columns();

        // Column types are preserved on the produced batches as well.
        assert_eq!(DataType::Float32, *columns[0].data_type());
        assert_eq!(DataType::Float64, *columns[1].data_type());
        assert_eq!(DataType::Float64, *columns[2].data_type());

        Ok(())
    }
1081
1082    #[tokio::test]
1083    async fn test_drop_cancel() -> Result<()> {
1084        let task_ctx = Arc::new(TaskContext::default());
1085        let schema = Arc::new(Schema::new(vec![
1086            Field::new("a", DataType::Float32, true),
1087            Field::new("b", DataType::Float32, true),
1088        ]));
1089
1090        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
1091        let refs = blocking_exec.refs();
1092        let sort_exec = Arc::new(PartialSortExec::new(
1093            [PhysicalSortExpr {
1094                expr: col("a", &schema)?,
1095                options: SortOptions::default(),
1096            }]
1097            .into(),
1098            blocking_exec,
1099            1,
1100        ));
1101
1102        let fut = collect(sort_exec, Arc::clone(&task_ctx));
1103        let mut fut = fut.boxed();
1104
1105        assert_is_pending(&mut fut);
1106        drop(fut);
1107        assert_strong_count_converges_to_zero(refs).await;
1108
1109        assert_eq!(
1110            task_ctx.runtime_env().memory_pool.reserved(),
1111            0,
1112            "The sort should have returned all memory used back to the memory manager"
1113        );
1114
1115        Ok(())
1116    }
1117
1118    #[tokio::test]
1119    async fn test_partial_sort_with_homogeneous_batches() -> Result<()> {
1120        // Test case for the bug where batches with homogeneous sort keys
1121        // (e.g., [1,1,1], [2,2,2]) would not be properly detected as having
1122        // slice points between batches.
1123        let task_ctx = Arc::new(TaskContext::default());
1124
1125        // Create batches where each batch has homogeneous values for sort keys
1126        let batch1 = test::build_table_i32(
1127            ("a", &vec![1; 3]),
1128            ("b", &vec![1; 3]),
1129            ("c", &vec![3, 2, 1]),
1130        );
1131        let batch2 = test::build_table_i32(
1132            ("a", &vec![2; 3]),
1133            ("b", &vec![2; 3]),
1134            ("c", &vec![4, 6, 4]),
1135        );
1136        let batch3 = test::build_table_i32(
1137            ("a", &vec![3; 3]),
1138            ("b", &vec![3; 3]),
1139            ("c", &vec![9, 7, 8]),
1140        );
1141
1142        let schema = batch1.schema();
1143        let mem_exec = TestMemoryExec::try_new_exec(
1144            &[vec![batch1, batch2, batch3]],
1145            Arc::clone(&schema),
1146            None,
1147        )?;
1148
1149        let option_asc = SortOptions {
1150            descending: false,
1151            nulls_first: false,
1152        };
1153
1154        // Partial sort with common prefix of 2 (sorting by a, b, c)
1155        let partial_sort_exec = Arc::new(PartialSortExec::new(
1156            [
1157                PhysicalSortExpr {
1158                    expr: col("a", &schema)?,
1159                    options: option_asc,
1160                },
1161                PhysicalSortExpr {
1162                    expr: col("b", &schema)?,
1163                    options: option_asc,
1164                },
1165                PhysicalSortExpr {
1166                    expr: col("c", &schema)?,
1167                    options: option_asc,
1168                },
1169            ]
1170            .into(),
1171            mem_exec,
1172            2,
1173        ));
1174
1175        let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
1176
1177        assert_eq!(result.len(), 3,);
1178
1179        allow_duplicates! {
1180            assert_snapshot!(batches_to_string(&result), @r"
1181            +---+---+---+
1182            | a | b | c |
1183            +---+---+---+
1184            | 1 | 1 | 1 |
1185            | 1 | 1 | 2 |
1186            | 1 | 1 | 3 |
1187            | 2 | 2 | 4 |
1188            | 2 | 2 | 4 |
1189            | 2 | 2 | 6 |
1190            | 3 | 3 | 7 |
1191            | 3 | 3 | 8 |
1192            | 3 | 3 | 9 |
1193            +---+---+---+
1194            ");
1195        }
1196
1197        assert_eq!(task_ctx.runtime_env().memory_pool.reserved(), 0,);
1198        Ok(())
1199    }
1200}