1use std::any::Any;
55use std::fmt::Debug;
56use std::pin::Pin;
57use std::sync::Arc;
58use std::task::{Context, Poll};
59
60use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
61use crate::sorts::sort::sort_batch;
62use crate::{
63 DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
64 Partitioning, PlanProperties, SendableRecordBatchStream, Statistics,
65 check_if_same_properties,
66};
67
68use arrow::compute::concat_batches;
69use arrow::datatypes::SchemaRef;
70use arrow::record_batch::RecordBatch;
71use datafusion_common::Result;
72use datafusion_common::utils::evaluate_partition_ranges;
73use datafusion_execution::{RecordBatchStream, TaskContext};
74use datafusion_physical_expr::LexOrdering;
75
76use futures::{Stream, StreamExt, ready};
77use log::trace;
78
/// Partial sort execution plan.
///
/// Sorts its input by `expr`, exploiting the fact that the first
/// `common_prefix_length` sort expressions are already satisfied by the
/// input's ordering: rows are buffered only until a complete prefix group
/// is available, then that group is sorted and emitted incrementally.
#[derive(Debug, Clone)]
pub struct PartialSortExec {
    /// Input plan whose rows are sorted
    pub(crate) input: Arc<dyn ExecutionPlan>,
    /// Full lexicographical ordering to produce
    expr: LexOrdering,
    /// Number of leading expressions in `expr` already satisfied by the
    /// input ordering (must be > 0; enforced by `debug_assert!` in `new`)
    common_prefix_length: usize,
    /// Execution metrics
    metrics_set: ExecutionPlanMetricsSet,
    /// If true, each input partition is sorted independently; otherwise a
    /// single output partition is produced
    preserve_partitioning: bool,
    /// Optional limit on the total number of rows emitted
    fetch: Option<usize>,
    /// Cached plan properties (equivalences, partitioning, boundedness)
    cache: Arc<PlanProperties>,
}
99
100impl PartialSortExec {
101 pub fn new(
103 expr: LexOrdering,
104 input: Arc<dyn ExecutionPlan>,
105 common_prefix_length: usize,
106 ) -> Self {
107 debug_assert!(common_prefix_length > 0);
108 let preserve_partitioning = false;
109 let cache = Self::compute_properties(&input, expr.clone(), preserve_partitioning)
110 .unwrap();
111 Self {
112 input,
113 expr,
114 common_prefix_length,
115 metrics_set: ExecutionPlanMetricsSet::new(),
116 preserve_partitioning,
117 fetch: None,
118 cache: Arc::new(cache),
119 }
120 }
121
122 pub fn preserve_partitioning(&self) -> bool {
124 self.preserve_partitioning
125 }
126
127 pub fn with_preserve_partitioning(mut self, preserve_partitioning: bool) -> Self {
135 self.preserve_partitioning = preserve_partitioning;
136 Arc::make_mut(&mut self.cache).partitioning =
137 Self::output_partitioning_helper(&self.input, self.preserve_partitioning);
138 self
139 }
140
141 pub fn with_fetch(mut self, fetch: Option<usize>) -> Self {
149 self.fetch = fetch;
150 self
151 }
152
153 pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
155 &self.input
156 }
157
158 pub fn expr(&self) -> &LexOrdering {
160 &self.expr
161 }
162
163 pub fn fetch(&self) -> Option<usize> {
165 self.fetch
166 }
167
168 pub fn common_prefix_length(&self) -> usize {
170 self.common_prefix_length
171 }
172
173 fn output_partitioning_helper(
174 input: &Arc<dyn ExecutionPlan>,
175 preserve_partitioning: bool,
176 ) -> Partitioning {
177 if preserve_partitioning {
179 input.output_partitioning().clone()
180 } else {
181 Partitioning::UnknownPartitioning(1)
182 }
183 }
184
185 fn compute_properties(
187 input: &Arc<dyn ExecutionPlan>,
188 sort_exprs: LexOrdering,
189 preserve_partitioning: bool,
190 ) -> Result<PlanProperties> {
191 let mut eq_properties = input.equivalence_properties().clone();
194 eq_properties.reorder(sort_exprs)?;
195
196 let output_partitioning =
198 Self::output_partitioning_helper(input, preserve_partitioning);
199
200 Ok(PlanProperties::new(
201 eq_properties,
202 output_partitioning,
203 input.pipeline_behavior(),
204 input.boundedness(),
205 ))
206 }
207
208 fn with_new_children_and_same_properties(
209 &self,
210 mut children: Vec<Arc<dyn ExecutionPlan>>,
211 ) -> Self {
212 Self {
213 input: children.swap_remove(0),
214 metrics_set: ExecutionPlanMetricsSet::new(),
215 ..Self::clone(self)
216 }
217 }
218}
219
220impl DisplayAs for PartialSortExec {
221 fn fmt_as(
222 &self,
223 t: DisplayFormatType,
224 f: &mut std::fmt::Formatter,
225 ) -> std::fmt::Result {
226 match t {
227 DisplayFormatType::Default | DisplayFormatType::Verbose => {
228 let common_prefix_length = self.common_prefix_length;
229 match self.fetch {
230 Some(fetch) => {
231 write!(
232 f,
233 "PartialSortExec: TopK(fetch={fetch}), expr=[{}], common_prefix_length=[{common_prefix_length}]",
234 self.expr
235 )
236 }
237 None => write!(
238 f,
239 "PartialSortExec: expr=[{}], common_prefix_length=[{common_prefix_length}]",
240 self.expr
241 ),
242 }
243 }
244 DisplayFormatType::TreeRender => match self.fetch {
245 Some(fetch) => {
246 writeln!(f, "{}", self.expr)?;
247 writeln!(f, "limit={fetch}")
248 }
249 None => {
250 writeln!(f, "{}", self.expr)
251 }
252 },
253 }
254 }
255}
256
257impl ExecutionPlan for PartialSortExec {
258 fn name(&self) -> &'static str {
259 "PartialSortExec"
260 }
261
262 fn as_any(&self) -> &dyn Any {
263 self
264 }
265
266 fn properties(&self) -> &Arc<PlanProperties> {
267 &self.cache
268 }
269
270 fn fetch(&self) -> Option<usize> {
271 self.fetch
272 }
273
274 fn required_input_distribution(&self) -> Vec<Distribution> {
275 if self.preserve_partitioning {
276 vec![Distribution::UnspecifiedDistribution]
277 } else {
278 vec![Distribution::SinglePartition]
279 }
280 }
281
282 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
283 vec![false]
284 }
285
286 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
287 vec![&self.input]
288 }
289
290 fn with_new_children(
291 self: Arc<Self>,
292 children: Vec<Arc<dyn ExecutionPlan>>,
293 ) -> Result<Arc<dyn ExecutionPlan>> {
294 check_if_same_properties!(self, children);
295 let new_partial_sort = PartialSortExec::new(
296 self.expr.clone(),
297 Arc::clone(&children[0]),
298 self.common_prefix_length,
299 )
300 .with_fetch(self.fetch)
301 .with_preserve_partitioning(self.preserve_partitioning);
302
303 Ok(Arc::new(new_partial_sort))
304 }
305
306 fn execute(
307 &self,
308 partition: usize,
309 context: Arc<TaskContext>,
310 ) -> Result<SendableRecordBatchStream> {
311 trace!(
312 "Start PartialSortExec::execute for partition {} of context session_id {} and task_id {:?}",
313 partition,
314 context.session_id(),
315 context.task_id()
316 );
317
318 let input = self.input.execute(partition, Arc::clone(&context))?;
319
320 trace!("End PartialSortExec's input.execute for partition: {partition}");
321
322 debug_assert!(self.common_prefix_length > 0);
325
326 Ok(Box::pin(PartialSortStream {
327 input,
328 expr: self.expr.clone(),
329 common_prefix_length: self.common_prefix_length,
330 in_mem_batch: RecordBatch::new_empty(Arc::clone(&self.schema())),
331 fetch: self.fetch,
332 is_closed: false,
333 baseline_metrics: BaselineMetrics::new(&self.metrics_set, partition),
334 }))
335 }
336
337 fn metrics(&self) -> Option<MetricsSet> {
338 Some(self.metrics_set.clone_inner())
339 }
340
341 fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
342 self.input.partition_statistics(partition)
343 }
344}
345
/// Stream that implements the partial-sort algorithm: buffers input rows
/// until at least one complete common-prefix group is present, then sorts
/// and emits that group while retaining the (possibly incomplete) last
/// group for the next poll.
struct PartialSortStream {
    /// The input stream being sorted
    input: SendableRecordBatchStream,
    /// Full lexicographical ordering to produce
    expr: LexOrdering,
    /// Number of leading expressions in `expr` already satisfied by the input
    common_prefix_length: usize,
    /// Rows buffered so far whose prefix group may still grow
    in_mem_batch: RecordBatch,
    /// Remaining number of rows to emit, if a limit was requested;
    /// decremented as batches are produced
    fetch: Option<usize>,
    /// Set once the stream is exhausted or the fetch limit is reached
    is_closed: bool,
    /// Runtime metrics (output rows, elapsed compute, ...)
    baseline_metrics: BaselineMetrics,
}
363
impl Stream for PartialSortStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // Delegate to the inner implementation and fold the result into the
        // baseline metrics (output row count, elapsed compute) on the way out.
        let poll = self.poll_next_inner(cx);
        self.baseline_metrics.record_poll(poll)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // NOTE(review): this reports the *input* stream's batch-count hint.
        // The partial sort concatenates buffered batches, so it may emit
        // fewer items than the input's lower bound suggests — confirm
        // callers treat this only as a rough estimate.
        self.input.size_hint()
    }
}
380
impl RecordBatchStream for PartialSortStream {
    /// Sorting does not change the schema; report the input's schema.
    fn schema(&self) -> SchemaRef {
        self.input.schema()
    }
}
386
impl PartialSortStream {
    /// Core polling loop.
    ///
    /// Accumulates input batches into `in_mem_batch`; whenever the buffer
    /// contains at least one *complete* common-prefix group, that group is
    /// sorted and emitted, and the trailing (possibly incomplete) group is
    /// retained. When the input is exhausted, the remaining buffer is
    /// sorted and emitted as the final batch.
    fn poll_next_inner(
        self: &mut Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        if self.is_closed {
            return Poll::Ready(None);
        }
        loop {
            // Fetch limit fully consumed: terminate without draining input.
            if self.fetch == Some(0) {
                self.is_closed = true;
                return Poll::Ready(None);
            }

            match ready!(self.input.poll_next_unpin(cx)) {
                Some(Ok(batch)) => {
                    // Append the new rows to the buffered (unsorted) tail.
                    self.in_mem_batch = concat_batches(
                        &self.schema(),
                        &[self.in_mem_batch.clone(), batch],
                    )?;

                    // A slice point exists iff the buffer spans more than
                    // one prefix group; everything before it is complete.
                    if let Some(slice_point) = self
                        .get_slice_point(self.common_prefix_length, &self.in_mem_batch)?
                    {
                        let sorted = self.in_mem_batch.slice(0, slice_point);
                        // Keep only the trailing, still-growing group.
                        self.in_mem_batch = self.in_mem_batch.slice(
                            slice_point,
                            self.in_mem_batch.num_rows() - slice_point,
                        );
                        let sorted_batch = sort_batch(&sorted, &self.expr, self.fetch)?;
                        // `sort_batch` truncates to `fetch`, so this
                        // subtraction cannot underflow.
                        if let Some(fetch) = self.fetch.as_mut() {
                            *fetch -= sorted_batch.num_rows();
                        }

                        // Skip empty results; keep polling for more input.
                        if sorted_batch.num_rows() > 0 {
                            return Poll::Ready(Some(Ok(sorted_batch)));
                        }
                    }
                }
                Some(Err(e)) => return Poll::Ready(Some(Err(e))),
                None => {
                    // Input exhausted: flush whatever remains in the buffer.
                    self.is_closed = true;
                    let remaining_batch = self.sort_in_mem_batch()?;
                    return if remaining_batch.num_rows() > 0 {
                        Poll::Ready(Some(Ok(remaining_batch)))
                    } else {
                        Poll::Ready(None)
                    };
                }
            };
        }
    }

    /// Sort and drain the buffered rows, applying any remaining `fetch`
    /// limit and updating the bookkeeping (`fetch`, `is_closed`).
    fn sort_in_mem_batch(self: &mut Pin<&mut Self>) -> Result<RecordBatch> {
        let input_batch = self.in_mem_batch.clone();
        // Reset the buffer so the rows are not emitted twice.
        self.in_mem_batch = RecordBatch::new_empty(self.schema());
        let result = sort_batch(&input_batch, &self.expr, self.fetch)?;
        if let Some(remaining_fetch) = self.fetch {
            // `sort_batch` truncated to `remaining_fetch`, so this cannot
            // underflow.
            self.fetch = Some(remaining_fetch - result.num_rows());
            if remaining_fetch == result.num_rows() {
                self.is_closed = true;
            }
        }
        Ok(result)
    }

    /// Find where the last (possibly incomplete) common-prefix group in
    /// `batch` begins.
    ///
    /// Returns `Some(index)` of the first row of the final group — i.e.
    /// rows `[0, index)` form complete groups and are safe to sort/emit —
    /// or `None` when the batch holds at most one group (nothing can be
    /// emitted yet).
    fn get_slice_point(
        &self,
        common_prefix_len: usize,
        batch: &RecordBatch,
    ) -> Result<Option<usize>> {
        // Evaluate only the prefix columns; group boundaries are defined by
        // equal prefix values.
        let common_prefix_sort_keys = (0..common_prefix_len)
            .map(|idx| self.expr[idx].evaluate_to_sort_column(batch))
            .collect::<Result<Vec<_>>>()?;
        let partition_points =
            evaluate_partition_ranges(batch.num_rows(), &common_prefix_sort_keys)?;
        // With >= 2 ranges, all but the last range are complete groups; the
        // slice point is the end of the second-to-last range.
        if partition_points.len() >= 2 {
            Ok(Some(partition_points[partition_points.len() - 2].end))
        } else {
            Ok(None)
        }
    }
}
490
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use arrow::array::*;
    use arrow::compute::SortOptions;
    use arrow::datatypes::*;
    use datafusion_common::test_util::batches_to_string;
    use futures::FutureExt;
    use insta::allow_duplicates;
    use insta::assert_snapshot;
    use itertools::Itertools;

    use crate::collect;
    use crate::expressions::PhysicalSortExpr;
    use crate::expressions::col;
    use crate::sorts::sort::SortExec;
    use crate::test;
    use crate::test::TestMemoryExec;
    use crate::test::assert_is_pending;
    use crate::test::exec::{BlockingExec, assert_strong_count_converges_to_zero};

    use super::*;

    /// Basic partial sort over a single batch with a 2-column sorted prefix.
    #[tokio::test]
    async fn test_partial_sort() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let source = test::build_table_scan_i32(
            ("a", &vec![0, 0, 0, 1, 1, 1]),
            ("b", &vec![1, 1, 2, 2, 3, 3]),
            ("c", &vec![1, 0, 5, 4, 3, 2]),
        );
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]);
        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };

        let partial_sort_exec = Arc::new(PartialSortExec::new(
            [
                PhysicalSortExpr {
                    expr: col("a", &schema)?,
                    options: option_asc,
                },
                PhysicalSortExpr {
                    expr: col("b", &schema)?,
                    options: option_asc,
                },
                PhysicalSortExpr {
                    expr: col("c", &schema)?,
                    options: option_asc,
                },
            ]
            .into(),
            Arc::clone(&source),
            2,
        ));

        let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;

        assert_eq!(2, result.len());
        allow_duplicates! {
            assert_snapshot!(batches_to_string(&result), @r"
            +---+---+---+
            | a | b | c |
            +---+---+---+
            | 0 | 1 | 0 |
            | 0 | 1 | 1 |
            | 0 | 2 | 5 |
            | 1 | 2 | 4 |
            | 1 | 3 | 2 |
            | 1 | 3 | 3 |
            +---+---+---+
            ");
        }
        assert_eq!(
            task_ctx.runtime_env().memory_pool.reserved(),
            0,
            "The sort should have returned all memory used back to the memory manager"
        );

        Ok(())
    }

    /// Partial sort with a fetch limit, for both 1- and 2-column prefixes.
    #[tokio::test]
    async fn test_partial_sort_with_fetch() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let source = test::build_table_scan_i32(
            ("a", &vec![0, 0, 1, 1, 1]),
            ("b", &vec![1, 2, 2, 3, 3]),
            ("c", &vec![4, 3, 2, 1, 0]),
        );
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]);
        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };

        for common_prefix_length in [1, 2] {
            let partial_sort_exec = Arc::new(
                PartialSortExec::new(
                    [
                        PhysicalSortExpr {
                            expr: col("a", &schema)?,
                            options: option_asc,
                        },
                        PhysicalSortExpr {
                            expr: col("b", &schema)?,
                            options: option_asc,
                        },
                        PhysicalSortExpr {
                            expr: col("c", &schema)?,
                            options: option_asc,
                        },
                    ]
                    .into(),
                    Arc::clone(&source),
                    common_prefix_length,
                )
                .with_fetch(Some(4)),
            );

            let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;

            assert_eq!(2, result.len());
            allow_duplicates! {
                assert_snapshot!(batches_to_string(&result), @r"
                +---+---+---+
                | a | b | c |
                +---+---+---+
                | 0 | 1 | 4 |
                | 0 | 2 | 3 |
                | 1 | 2 | 2 |
                | 1 | 3 | 0 |
                +---+---+---+
                ");
            }
            assert_eq!(
                task_ctx.runtime_env().memory_pool.reserved(),
                0,
                "The sort should have returned all memory used back to the memory manager"
            );
        }

        Ok(())
    }

    /// Partial sort over multiple input batches with different prefix lengths.
    #[tokio::test]
    async fn test_partial_sort2() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let source_tables = [
            test::build_table_scan_i32(
                ("a", &vec![0, 0, 0, 0, 1, 1, 1, 1]),
                ("b", &vec![1, 1, 3, 3, 4, 4, 2, 2]),
                ("c", &vec![7, 6, 5, 4, 3, 2, 1, 0]),
            ),
            test::build_table_scan_i32(
                ("a", &vec![0, 0, 0, 0, 1, 1, 1, 1]),
                ("b", &vec![1, 1, 3, 3, 2, 2, 4, 4]),
                ("c", &vec![7, 6, 5, 4, 1, 0, 3, 2]),
            ),
        ];
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]);
        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };
        for (common_prefix_length, source) in
            [(1, &source_tables[0]), (2, &source_tables[1])]
        {
            let partial_sort_exec = Arc::new(PartialSortExec::new(
                [
                    PhysicalSortExpr {
                        expr: col("a", &schema)?,
                        options: option_asc,
                    },
                    PhysicalSortExpr {
                        expr: col("b", &schema)?,
                        options: option_asc,
                    },
                    PhysicalSortExpr {
                        expr: col("c", &schema)?,
                        options: option_asc,
                    },
                ]
                .into(),
                Arc::clone(source),
                common_prefix_length,
            ));

            let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
            assert_eq!(2, result.len());
            assert_eq!(
                task_ctx.runtime_env().memory_pool.reserved(),
                0,
                "The sort should have returned all memory used back to the memory manager"
            );
            allow_duplicates! {
                assert_snapshot!(batches_to_string(&result), @r"
                +---+---+---+
                | a | b | c |
                +---+---+---+
                | 0 | 1 | 6 |
                | 0 | 1 | 7 |
                | 0 | 3 | 4 |
                | 0 | 3 | 5 |
                | 1 | 2 | 0 |
                | 1 | 2 | 1 |
                | 1 | 4 | 2 |
                | 1 | 4 | 3 |
                +---+---+---+
                ");
            }
        }
        Ok(())
    }

    /// Build a 4-batch input whose column `a` is non-decreasing across
    /// batches, so a 1-column prefix partial sort can slice at group
    /// boundaries that span batches.
    fn prepare_partitioned_input() -> Arc<dyn ExecutionPlan> {
        let batch1 = test::build_table_i32(
            ("a", &vec![1; 100]),
            ("b", &(0..100).rev().collect()),
            ("c", &(0..100).rev().collect()),
        );
        let batch2 = test::build_table_i32(
            ("a", &[&vec![1; 25][..], &vec![2; 75][..]].concat()),
            ("b", &(100..200).rev().collect()),
            ("c", &(0..100).collect()),
        );
        let batch3 = test::build_table_i32(
            ("a", &[&vec![3; 50][..], &vec![4; 50][..]].concat()),
            ("b", &(150..250).rev().collect()),
            ("c", &(0..100).rev().collect()),
        );
        let batch4 = test::build_table_i32(
            ("a", &vec![4; 100]),
            ("b", &(50..150).rev().collect()),
            ("c", &(0..100).rev().collect()),
        );
        let schema = batch1.schema();

        TestMemoryExec::try_new_exec(
            &[vec![batch1, batch2, batch3, batch4]],
            Arc::clone(&schema),
            None,
        )
        .unwrap() as Arc<dyn ExecutionPlan>
    }

    /// Mixed-direction partial sort across batch boundaries; the result is
    /// checked against a full `SortExec` with the same expressions.
    #[tokio::test]
    async fn test_partitioned_input_partial_sort() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let mem_exec = prepare_partitioned_input();
        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };
        // Descending order for column `b` (previously mis-configured as
        // ascending despite the name; matches `option_desc` elsewhere in
        // this module).
        let option_desc = SortOptions {
            descending: true,
            nulls_first: true,
        };
        let schema = mem_exec.schema();
        let partial_sort_exec = PartialSortExec::new(
            [
                PhysicalSortExpr {
                    expr: col("a", &schema)?,
                    options: option_asc,
                },
                PhysicalSortExpr {
                    expr: col("b", &schema)?,
                    options: option_desc,
                },
                PhysicalSortExpr {
                    expr: col("c", &schema)?,
                    options: option_asc,
                },
            ]
            .into(),
            Arc::clone(&mem_exec),
            1,
        );
        let sort_exec = Arc::new(SortExec::new(
            partial_sort_exec.expr.clone(),
            Arc::clone(&partial_sort_exec.input),
        ));
        let result = collect(Arc::new(partial_sort_exec), Arc::clone(&task_ctx)).await?;
        // Batch boundaries depend only on the prefix column `a`.
        assert_eq!(
            result.iter().map(|r| r.num_rows()).collect_vec(),
            [125, 125, 150]
        );

        assert_eq!(
            task_ctx.runtime_env().memory_pool.reserved(),
            0,
            "The sort should have returned all memory used back to the memory manager"
        );
        let partial_sort_result = concat_batches(&schema, &result).unwrap();
        let sort_result = collect(sort_exec, Arc::clone(&task_ctx)).await?;
        assert_eq!(sort_result[0], partial_sort_result);

        Ok(())
    }

    /// As above, but with various fetch limits; verifies output batch
    /// sizes and equality with a fetch-limited full sort.
    #[tokio::test]
    async fn test_partitioned_input_partial_sort_with_fetch() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let mem_exec = prepare_partitioned_input();
        let schema = mem_exec.schema();
        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };
        // Descending order for column `b` (previously mis-configured as
        // ascending despite the name; matches `option_desc` elsewhere in
        // this module).
        let option_desc = SortOptions {
            descending: true,
            nulls_first: true,
        };
        for (fetch_size, expected_batch_num_rows) in [
            (Some(50), vec![50]),
            (Some(120), vec![120]),
            (Some(150), vec![125, 25]),
            (Some(250), vec![125, 125]),
        ] {
            let partial_sort_exec = PartialSortExec::new(
                [
                    PhysicalSortExpr {
                        expr: col("a", &schema)?,
                        options: option_asc,
                    },
                    PhysicalSortExpr {
                        expr: col("b", &schema)?,
                        options: option_desc,
                    },
                    PhysicalSortExpr {
                        expr: col("c", &schema)?,
                        options: option_asc,
                    },
                ]
                .into(),
                Arc::clone(&mem_exec),
                1,
            )
            .with_fetch(fetch_size);

            let sort_exec = Arc::new(
                SortExec::new(
                    partial_sort_exec.expr.clone(),
                    Arc::clone(&partial_sort_exec.input),
                )
                .with_fetch(fetch_size),
            );
            let result =
                collect(Arc::new(partial_sort_exec), Arc::clone(&task_ctx)).await?;
            assert_eq!(
                result.iter().map(|r| r.num_rows()).collect_vec(),
                expected_batch_num_rows
            );

            assert_eq!(
                task_ctx.runtime_env().memory_pool.reserved(),
                0,
                "The sort should have returned all memory used back to the memory manager"
            );
            let partial_sort_result = concat_batches(&schema, &result)?;
            let sort_result = collect(sort_exec, Arc::clone(&task_ctx)).await?;
            assert_eq!(sort_result[0], partial_sort_result);
        }

        Ok(())
    }

    /// The stream must never emit empty batches, even when the fetch limit
    /// is hit exactly at a group boundary.
    #[tokio::test]
    async fn test_partial_sort_no_empty_batches() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let mem_exec = prepare_partitioned_input();
        let schema = mem_exec.schema();
        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };
        let fetch_size = Some(250);
        let partial_sort_exec = PartialSortExec::new(
            [
                PhysicalSortExpr {
                    expr: col("a", &schema)?,
                    options: option_asc,
                },
                PhysicalSortExpr {
                    expr: col("c", &schema)?,
                    options: option_asc,
                },
            ]
            .into(),
            Arc::clone(&mem_exec),
            1,
        )
        .with_fetch(fetch_size);

        let result = collect(Arc::new(partial_sort_exec), Arc::clone(&task_ctx)).await?;
        for rb in result {
            assert!(rb.num_rows() > 0);
        }

        Ok(())
    }

    /// Field- and schema-level metadata must survive the sort.
    #[tokio::test]
    async fn test_sort_metadata() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let field_metadata: HashMap<String, String> =
            vec![("foo".to_string(), "bar".to_string())]
                .into_iter()
                .collect();
        let schema_metadata: HashMap<String, String> =
            vec![("baz".to_string(), "barf".to_string())]
                .into_iter()
                .collect();

        let mut field = Field::new("field_name", DataType::UInt64, true);
        field.set_metadata(field_metadata.clone());
        let schema = Schema::new_with_metadata(vec![field], schema_metadata.clone());
        let schema = Arc::new(schema);

        let data: ArrayRef =
            Arc::new(vec![1, 1, 2].into_iter().map(Some).collect::<UInt64Array>());

        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![data])?;
        let input =
            TestMemoryExec::try_new_exec(&[vec![batch]], Arc::clone(&schema), None)?;

        let partial_sort_exec = Arc::new(PartialSortExec::new(
            [PhysicalSortExpr {
                expr: col("field_name", &schema)?,
                options: SortOptions::default(),
            }]
            .into(),
            input,
            1,
        ));

        let result: Vec<RecordBatch> = collect(partial_sort_exec, task_ctx).await?;
        let expected_batch = vec![
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![Arc::new(
                    vec![1, 1].into_iter().map(Some).collect::<UInt64Array>(),
                )],
            )?,
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![Arc::new(
                    vec![2].into_iter().map(Some).collect::<UInt64Array>(),
                )],
            )?,
        ];

        assert_eq!(&expected_batch, &result);

        assert_eq!(result[0].schema().fields()[0].metadata(), &field_metadata);
        assert_eq!(result[0].schema().metadata(), &schema_metadata);

        Ok(())
    }

    /// Floats with NaN/null ordering, mixed asc/desc options, plus metrics.
    #[tokio::test]
    async fn test_lex_sort_by_float() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Float32, true),
            Field::new("b", DataType::Float64, true),
            Field::new("c", DataType::Float64, true),
        ]));
        let option_asc = SortOptions {
            descending: false,
            nulls_first: true,
        };
        let option_desc = SortOptions {
            descending: true,
            nulls_first: true,
        };

        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Float32Array::from(vec![
                    Some(1.0_f32),
                    Some(1.0_f32),
                    Some(1.0_f32),
                    Some(2.0_f32),
                    Some(2.0_f32),
                    Some(3.0_f32),
                    Some(3.0_f32),
                    Some(3.0_f32),
                ])),
                Arc::new(Float64Array::from(vec![
                    Some(20.0_f64),
                    Some(20.0_f64),
                    Some(40.0_f64),
                    Some(40.0_f64),
                    Some(f64::NAN),
                    None,
                    None,
                    Some(f64::NAN),
                ])),
                Arc::new(Float64Array::from(vec![
                    Some(10.0_f64),
                    Some(20.0_f64),
                    Some(10.0_f64),
                    Some(100.0_f64),
                    Some(f64::NAN),
                    Some(100.0_f64),
                    None,
                    Some(f64::NAN),
                ])),
            ],
        )?;

        let partial_sort_exec = Arc::new(PartialSortExec::new(
            [
                PhysicalSortExpr {
                    expr: col("a", &schema)?,
                    options: option_asc,
                },
                PhysicalSortExpr {
                    expr: col("b", &schema)?,
                    options: option_asc,
                },
                PhysicalSortExpr {
                    expr: col("c", &schema)?,
                    options: option_desc,
                },
            ]
            .into(),
            TestMemoryExec::try_new_exec(&[vec![batch]], schema, None)?,
            2,
        ));

        assert_eq!(
            DataType::Float32,
            *partial_sort_exec.schema().field(0).data_type()
        );
        assert_eq!(
            DataType::Float64,
            *partial_sort_exec.schema().field(1).data_type()
        );
        assert_eq!(
            DataType::Float64,
            *partial_sort_exec.schema().field(2).data_type()
        );

        let result: Vec<RecordBatch> = collect(
            Arc::clone(&partial_sort_exec) as Arc<dyn ExecutionPlan>,
            task_ctx,
        )
        .await?;
        assert_snapshot!(batches_to_string(&result), @r"
        +-----+------+-------+
        | a   | b    | c     |
        +-----+------+-------+
        | 1.0 | 20.0 | 20.0  |
        | 1.0 | 20.0 | 10.0  |
        | 1.0 | 40.0 | 10.0  |
        | 2.0 | 40.0 | 100.0 |
        | 2.0 | NaN  | NaN   |
        | 3.0 |      |       |
        | 3.0 |      | 100.0 |
        | 3.0 | NaN  | NaN   |
        +-----+------+-------+
        ");
        assert_eq!(result.len(), 2);
        let metrics = partial_sort_exec.metrics().unwrap();
        assert!(metrics.elapsed_compute().unwrap() > 0);
        assert_eq!(metrics.output_rows().unwrap(), 8);

        let columns = result[0].columns();

        assert_eq!(DataType::Float32, *columns[0].data_type());
        assert_eq!(DataType::Float64, *columns[1].data_type());
        assert_eq!(DataType::Float64, *columns[2].data_type());

        Ok(())
    }

    /// Dropping the output future must release the input without leaking.
    #[tokio::test]
    async fn test_drop_cancel() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Float32, true),
            Field::new("b", DataType::Float32, true),
        ]));

        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
        let refs = blocking_exec.refs();
        let sort_exec = Arc::new(PartialSortExec::new(
            [PhysicalSortExpr {
                expr: col("a", &schema)?,
                options: SortOptions::default(),
            }]
            .into(),
            blocking_exec,
            1,
        ));

        let fut = collect(sort_exec, Arc::clone(&task_ctx));
        let mut fut = fut.boxed();

        assert_is_pending(&mut fut);
        drop(fut);
        assert_strong_count_converges_to_zero(refs).await;

        assert_eq!(
            task_ctx.runtime_env().memory_pool.reserved(),
            0,
            "The sort should have returned all memory used back to the memory manager"
        );

        Ok(())
    }

    /// Each input batch is a single prefix group, so each should be sorted
    /// and emitted on its own (3 batches in, 3 batches out).
    #[tokio::test]
    async fn test_partial_sort_with_homogeneous_batches() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());

        let batch1 = test::build_table_i32(
            ("a", &vec![1; 3]),
            ("b", &vec![1; 3]),
            ("c", &vec![3, 2, 1]),
        );
        let batch2 = test::build_table_i32(
            ("a", &vec![2; 3]),
            ("b", &vec![2; 3]),
            ("c", &vec![4, 6, 4]),
        );
        let batch3 = test::build_table_i32(
            ("a", &vec![3; 3]),
            ("b", &vec![3; 3]),
            ("c", &vec![9, 7, 8]),
        );

        let schema = batch1.schema();
        let mem_exec = TestMemoryExec::try_new_exec(
            &[vec![batch1, batch2, batch3]],
            Arc::clone(&schema),
            None,
        )?;

        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };

        let partial_sort_exec = Arc::new(PartialSortExec::new(
            [
                PhysicalSortExpr {
                    expr: col("a", &schema)?,
                    options: option_asc,
                },
                PhysicalSortExpr {
                    expr: col("b", &schema)?,
                    options: option_asc,
                },
                PhysicalSortExpr {
                    expr: col("c", &schema)?,
                    options: option_asc,
                },
            ]
            .into(),
            mem_exec,
            2,
        ));

        let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;

        assert_eq!(result.len(), 3,);

        allow_duplicates! {
            assert_snapshot!(batches_to_string(&result), @r"
            +---+---+---+
            | a | b | c |
            +---+---+---+
            | 1 | 1 | 1 |
            | 1 | 1 | 2 |
            | 1 | 1 | 3 |
            | 2 | 2 | 4 |
            | 2 | 2 | 4 |
            | 2 | 2 | 6 |
            | 3 | 3 | 7 |
            | 3 | 3 | 8 |
            | 3 | 3 | 9 |
            +---+---+---+
            ");
        }

        assert_eq!(task_ctx.runtime_env().memory_pool.reserved(), 0,);
        Ok(())
    }
}