1use std::any::Any;
55use std::fmt::Debug;
56use std::pin::Pin;
57use std::sync::Arc;
58use std::task::{Context, Poll};
59
60use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
61use crate::sorts::sort::sort_batch;
62use crate::{
63 DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
64 Partitioning, PlanProperties, SendableRecordBatchStream, Statistics,
65};
66
67use arrow::compute::concat_batches;
68use arrow::datatypes::SchemaRef;
69use arrow::record_batch::RecordBatch;
70use datafusion_common::utils::evaluate_partition_ranges;
71use datafusion_common::Result;
72use datafusion_execution::{RecordBatchStream, TaskContext};
73use datafusion_physical_expr::LexOrdering;
74
75use futures::{ready, Stream, StreamExt};
76use log::trace;
77
/// Physical operator that sorts its input by `expr` when the input is
/// (assumed to be) already sorted by the first `common_prefix_length`
/// entries of `expr`.
///
/// At execution time rows are buffered until a run of complete groups of
/// equal common-prefix values is known; those groups are then sorted by the
/// full `expr` and emitted, so the whole input never has to be buffered at
/// once unless it forms a single group.
#[derive(Debug, Clone)]
pub struct PartialSortExec {
    /// Input plan whose rows are sorted.
    pub(crate) input: Arc<dyn ExecutionPlan>,
    /// Full lexicographic ordering to produce.
    expr: LexOrdering,
    /// Number of leading entries of `expr` the input is already sorted by.
    /// Expected to be greater than zero (checked with `debug_assert!`).
    common_prefix_length: usize,
    /// Metrics collected while executing this operator.
    metrics_set: ExecutionPlanMetricsSet,
    /// When `true`, input partitions are sorted independently and the input
    /// partitioning is kept; when `false`, a single input partition is required.
    preserve_partitioning: bool,
    /// Optional limit on the number of rows produced.
    fetch: Option<usize>,
    /// Cached plan properties (equivalences, partitioning, boundedness).
    cache: PlanProperties,
}
98
99impl PartialSortExec {
100 pub fn new(
102 expr: LexOrdering,
103 input: Arc<dyn ExecutionPlan>,
104 common_prefix_length: usize,
105 ) -> Self {
106 debug_assert!(common_prefix_length > 0);
107 let preserve_partitioning = false;
108 let cache = Self::compute_properties(&input, expr.clone(), preserve_partitioning)
109 .unwrap();
110 Self {
111 input,
112 expr,
113 common_prefix_length,
114 metrics_set: ExecutionPlanMetricsSet::new(),
115 preserve_partitioning,
116 fetch: None,
117 cache,
118 }
119 }
120
121 pub fn preserve_partitioning(&self) -> bool {
123 self.preserve_partitioning
124 }
125
126 pub fn with_preserve_partitioning(mut self, preserve_partitioning: bool) -> Self {
134 self.preserve_partitioning = preserve_partitioning;
135 self.cache = self
136 .cache
137 .with_partitioning(Self::output_partitioning_helper(
138 &self.input,
139 self.preserve_partitioning,
140 ));
141 self
142 }
143
144 pub fn with_fetch(mut self, fetch: Option<usize>) -> Self {
152 self.fetch = fetch;
153 self
154 }
155
156 pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
158 &self.input
159 }
160
161 pub fn expr(&self) -> &LexOrdering {
163 &self.expr
164 }
165
166 pub fn fetch(&self) -> Option<usize> {
168 self.fetch
169 }
170
171 pub fn common_prefix_length(&self) -> usize {
173 self.common_prefix_length
174 }
175
176 fn output_partitioning_helper(
177 input: &Arc<dyn ExecutionPlan>,
178 preserve_partitioning: bool,
179 ) -> Partitioning {
180 if preserve_partitioning {
182 input.output_partitioning().clone()
183 } else {
184 Partitioning::UnknownPartitioning(1)
185 }
186 }
187
188 fn compute_properties(
190 input: &Arc<dyn ExecutionPlan>,
191 sort_exprs: LexOrdering,
192 preserve_partitioning: bool,
193 ) -> Result<PlanProperties> {
194 let mut eq_properties = input.equivalence_properties().clone();
197 eq_properties.reorder(sort_exprs)?;
198
199 let output_partitioning =
201 Self::output_partitioning_helper(input, preserve_partitioning);
202
203 Ok(PlanProperties::new(
204 eq_properties,
205 output_partitioning,
206 input.pipeline_behavior(),
207 input.boundedness(),
208 ))
209 }
210}
211
212impl DisplayAs for PartialSortExec {
213 fn fmt_as(
214 &self,
215 t: DisplayFormatType,
216 f: &mut std::fmt::Formatter,
217 ) -> std::fmt::Result {
218 match t {
219 DisplayFormatType::Default | DisplayFormatType::Verbose => {
220 let common_prefix_length = self.common_prefix_length;
221 match self.fetch {
222 Some(fetch) => {
223 write!(f, "PartialSortExec: TopK(fetch={fetch}), expr=[{}], common_prefix_length=[{common_prefix_length}]", self.expr)
224 }
225 None => write!(f, "PartialSortExec: expr=[{}], common_prefix_length=[{common_prefix_length}]", self.expr),
226 }
227 }
228 DisplayFormatType::TreeRender => match self.fetch {
229 Some(fetch) => {
230 writeln!(f, "{}", self.expr)?;
231 writeln!(f, "limit={fetch}")
232 }
233 None => {
234 writeln!(f, "{}", self.expr)
235 }
236 },
237 }
238 }
239}
240
241impl ExecutionPlan for PartialSortExec {
242 fn name(&self) -> &'static str {
243 "PartialSortExec"
244 }
245
246 fn as_any(&self) -> &dyn Any {
247 self
248 }
249
250 fn properties(&self) -> &PlanProperties {
251 &self.cache
252 }
253
254 fn fetch(&self) -> Option<usize> {
255 self.fetch
256 }
257
258 fn required_input_distribution(&self) -> Vec<Distribution> {
259 if self.preserve_partitioning {
260 vec![Distribution::UnspecifiedDistribution]
261 } else {
262 vec![Distribution::SinglePartition]
263 }
264 }
265
266 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
267 vec![false]
268 }
269
270 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
271 vec![&self.input]
272 }
273
274 fn with_new_children(
275 self: Arc<Self>,
276 children: Vec<Arc<dyn ExecutionPlan>>,
277 ) -> Result<Arc<dyn ExecutionPlan>> {
278 let new_partial_sort = PartialSortExec::new(
279 self.expr.clone(),
280 Arc::clone(&children[0]),
281 self.common_prefix_length,
282 )
283 .with_fetch(self.fetch)
284 .with_preserve_partitioning(self.preserve_partitioning);
285
286 Ok(Arc::new(new_partial_sort))
287 }
288
289 fn execute(
290 &self,
291 partition: usize,
292 context: Arc<TaskContext>,
293 ) -> Result<SendableRecordBatchStream> {
294 trace!("Start PartialSortExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id());
295
296 let input = self.input.execute(partition, Arc::clone(&context))?;
297
298 trace!("End PartialSortExec's input.execute for partition: {partition}");
299
300 debug_assert!(self.common_prefix_length > 0);
303
304 Ok(Box::pin(PartialSortStream {
305 input,
306 expr: self.expr.clone(),
307 common_prefix_length: self.common_prefix_length,
308 in_mem_batch: RecordBatch::new_empty(Arc::clone(&self.schema())),
309 fetch: self.fetch,
310 is_closed: false,
311 baseline_metrics: BaselineMetrics::new(&self.metrics_set, partition),
312 }))
313 }
314
315 fn metrics(&self) -> Option<MetricsSet> {
316 Some(self.metrics_set.clone_inner())
317 }
318
319 fn statistics(&self) -> Result<Statistics> {
320 self.input.partition_statistics(None)
321 }
322
323 fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
324 self.input.partition_statistics(partition)
325 }
326}
327
/// Stream produced by [`PartialSortExec::execute`] for one partition.
///
/// Buffers input rows in `in_mem_batch` and emits sorted batches whenever
/// complete groups of equal common-prefix values have been seen.
struct PartialSortStream {
    /// Input stream, assumed to be sorted by the common prefix of `expr`.
    input: SendableRecordBatchStream,
    /// Full ordering to produce.
    expr: LexOrdering,
    /// Number of leading entries of `expr` the input is sorted by.
    common_prefix_length: usize,
    /// Rows received but not yet emitted (at most the trailing, possibly
    /// incomplete, group(s) after each emission).
    in_mem_batch: RecordBatch,
    /// Remaining row budget; decremented as rows are emitted.
    fetch: Option<usize>,
    /// Set once the stream has finished (input exhausted or fetch reached).
    is_closed: bool,
    /// Execution metrics for this partition.
    baseline_metrics: BaselineMetrics,
}
345
346impl Stream for PartialSortStream {
347 type Item = Result<RecordBatch>;
348
349 fn poll_next(
350 mut self: Pin<&mut Self>,
351 cx: &mut Context<'_>,
352 ) -> Poll<Option<Self::Item>> {
353 let poll = self.poll_next_inner(cx);
354 self.baseline_metrics.record_poll(poll)
355 }
356
357 fn size_hint(&self) -> (usize, Option<usize>) {
358 self.input.size_hint()
360 }
361}
362
363impl RecordBatchStream for PartialSortStream {
364 fn schema(&self) -> SchemaRef {
365 self.input.schema()
366 }
367}
368
impl PartialSortStream {
    /// Core polling loop.
    ///
    /// Pulls input batches and appends them to `in_mem_batch`. Whenever the
    /// buffer contains complete groups of equal common-prefix values, it sorts
    /// and emits them. Returns `Poll::Ready(None)` after the remaining buffered
    /// rows are flushed at end of input, or as soon as the fetch budget is spent.
    fn poll_next_inner(
        self: &mut Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        if self.is_closed {
            return Poll::Ready(None);
        }
        loop {
            // Stop without polling the input further once the fetch limit is reached.
            if self.fetch == Some(0) {
                self.is_closed = true;
                return Poll::Ready(None);
            }

            match ready!(self.input.poll_next_unpin(cx)) {
                Some(Ok(batch)) => {
                    // Append the new batch to the rows buffered so far.
                    self.in_mem_batch = concat_batches(
                        &self.schema(),
                        &[self.in_mem_batch.clone(), batch],
                    )?;

                    // If the buffer holds complete groups, split them off, sort and
                    // emit them; otherwise keep accumulating in `self.in_mem_batch`.
                    if let Some(slice_point) = self
                        .get_slice_point(self.common_prefix_length, &self.in_mem_batch)?
                    {
                        let sorted = self.in_mem_batch.slice(0, slice_point);
                        self.in_mem_batch = self.in_mem_batch.slice(
                            slice_point,
                            self.in_mem_batch.num_rows() - slice_point,
                        );
                        let sorted_batch = sort_batch(&sorted, &self.expr, self.fetch)?;
                        // Charge emitted rows to the fetch budget. This relies on
                        // `sort_batch` returning at most `fetch` rows (presumed from
                        // its use here).
                        if let Some(fetch) = self.fetch.as_mut() {
                            *fetch -= sorted_batch.num_rows();
                        }

                        // Never emit empty batches.
                        if sorted_batch.num_rows() > 0 {
                            return Poll::Ready(Some(Ok(sorted_batch)));
                        }
                    }
                }
                Some(Err(e)) => return Poll::Ready(Some(Err(e))),
                None => {
                    self.is_closed = true;
                    // Input exhausted: the buffered rows are now complete groups,
                    // so sort and emit whatever is left.
                    let remaining_batch = self.sort_in_mem_batch()?;
                    return if remaining_batch.num_rows() > 0 {
                        Poll::Ready(Some(Ok(remaining_batch)))
                    } else {
                        Poll::Ready(None)
                    };
                }
            };
        }
    }

    /// Sorts every buffered row, leaves the buffer empty, and updates the
    /// remaining fetch budget (closing the stream when the budget is used up).
    fn sort_in_mem_batch(self: &mut Pin<&mut Self>) -> Result<RecordBatch> {
        let input_batch = self.in_mem_batch.clone();
        self.in_mem_batch = RecordBatch::new_empty(self.schema());
        let result = sort_batch(&input_batch, &self.expr, self.fetch)?;
        if let Some(remaining_fetch) = self.fetch {
            // Assumes `sort_batch` with a limit never returns more rows than the
            // limit, so this subtraction cannot underflow.
            self.fetch = Some(remaining_fetch - result.num_rows());
            if remaining_fetch == result.num_rows() {
                self.is_closed = true;
            }
        }
        Ok(result)
    }

    /// Returns the row index up to which `batch` consists of complete groups
    /// of equal values of the first `common_prefix_len` sort keys.
    ///
    /// The end of the second-to-last range from `evaluate_partition_ranges` is
    /// returned. For ranges `[0..100], [100..200], [200..300]` this is `200`,
    /// not `300`: the last group may continue in the next input batch.
    /// Returns `None` when the batch holds fewer than two ranges.
    fn get_slice_point(
        &self,
        common_prefix_len: usize,
        batch: &RecordBatch,
    ) -> Result<Option<usize>> {
        let common_prefix_sort_keys = (0..common_prefix_len)
            .map(|idx| self.expr[idx].evaluate_to_sort_column(batch))
            .collect::<Result<Vec<_>>>()?;
        let partition_points =
            evaluate_partition_ranges(batch.num_rows(), &common_prefix_sort_keys)?;
        if partition_points.len() >= 2 {
            Ok(Some(partition_points[partition_points.len() - 2].end))
        } else {
            Ok(None)
        }
    }
}
472
473#[cfg(test)]
474mod tests {
475 use std::collections::HashMap;
476
477 use arrow::array::*;
478 use arrow::compute::SortOptions;
479 use arrow::datatypes::*;
480 use datafusion_common::test_util::batches_to_string;
481 use futures::FutureExt;
482 use insta::allow_duplicates;
483 use insta::assert_snapshot;
484 use itertools::Itertools;
485
486 use crate::collect;
487 use crate::expressions::col;
488 use crate::expressions::PhysicalSortExpr;
489 use crate::sorts::sort::SortExec;
490 use crate::test;
491 use crate::test::assert_is_pending;
492 use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
493 use crate::test::TestMemoryExec;
494
495 use super::*;
496
    /// Input is ordered on (`a`, `b`); checks that `c` ends up ordered within
    /// each (`a`, `b`) group and that the output arrives in two batches.
    #[tokio::test]
    async fn test_partial_sort() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let source = test::build_table_scan_i32(
            ("a", &vec![0, 0, 0, 1, 1, 1]),
            ("b", &vec![1, 1, 2, 2, 3, 3]),
            ("c", &vec![1, 0, 5, 4, 3, 2]),
        );
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]);
        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };

        let partial_sort_exec = Arc::new(PartialSortExec::new(
            [
                PhysicalSortExpr {
                    expr: col("a", &schema)?,
                    options: option_asc,
                },
                PhysicalSortExpr {
                    expr: col("b", &schema)?,
                    options: option_asc,
                },
                PhysicalSortExpr {
                    expr: col("c", &schema)?,
                    options: option_asc,
                },
            ]
            .into(),
            Arc::clone(&source),
            2,
        ));

        let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;

        // Complete groups are emitted mid-stream; the last group is flushed at the end.
        assert_eq!(2, result.len());
        allow_duplicates! {
            assert_snapshot!(batches_to_string(&result), @r#"
            +---+---+---+
            | a | b | c |
            +---+---+---+
            | 0 | 1 | 0 |
            | 0 | 1 | 1 |
            | 0 | 2 | 5 |
            | 1 | 2 | 4 |
            | 1 | 3 | 2 |
            | 1 | 3 | 3 |
            +---+---+---+
            "#);
        }
        assert_eq!(
            task_ctx.runtime_env().memory_pool.reserved(),
            0,
            "The sort should have returned all memory used back to the memory manager"
        );

        Ok(())
    }
560
    /// Checks that a fetch limit of 4 returns the first 4 rows of the full
    /// ordering, for common prefix lengths 1 and 2.
    #[tokio::test]
    async fn test_partial_sort_with_fetch() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let source = test::build_table_scan_i32(
            ("a", &vec![0, 0, 1, 1, 1]),
            ("b", &vec![1, 2, 2, 3, 3]),
            ("c", &vec![4, 3, 2, 1, 0]),
        );
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]);
        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };

        for common_prefix_length in [1, 2] {
            let partial_sort_exec = Arc::new(
                PartialSortExec::new(
                    [
                        PhysicalSortExpr {
                            expr: col("a", &schema)?,
                            options: option_asc,
                        },
                        PhysicalSortExpr {
                            expr: col("b", &schema)?,
                            options: option_asc,
                        },
                        PhysicalSortExpr {
                            expr: col("c", &schema)?,
                            options: option_asc,
                        },
                    ]
                    .into(),
                    Arc::clone(&source),
                    common_prefix_length,
                )
                .with_fetch(Some(4)),
            );

            let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;

            assert_eq!(2, result.len());
            // The same snapshot is expected for every prefix length in the loop.
            allow_duplicates! {
                assert_snapshot!(batches_to_string(&result), @r#"
                +---+---+---+
                | a | b | c |
                +---+---+---+
                | 0 | 1 | 4 |
                | 0 | 2 | 3 |
                | 1 | 2 | 2 |
                | 1 | 3 | 0 |
                +---+---+---+
                "#);
            }
            assert_eq!(
                task_ctx.runtime_env().memory_pool.reserved(),
                0,
                "The sort should have returned all memory used back to the memory manager"
            );
        }

        Ok(())
    }
627
    /// Runs two inputs through the partial sort and expects identical output.
    /// The first input is ordered on `a` only (prefix 1); the second is
    /// ordered on (`a`, `b`) (prefix 2).
    #[tokio::test]
    async fn test_partial_sort2() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let source_tables = [
            test::build_table_scan_i32(
                ("a", &vec![0, 0, 0, 0, 1, 1, 1, 1]),
                ("b", &vec![1, 1, 3, 3, 4, 4, 2, 2]),
                ("c", &vec![7, 6, 5, 4, 3, 2, 1, 0]),
            ),
            test::build_table_scan_i32(
                ("a", &vec![0, 0, 0, 0, 1, 1, 1, 1]),
                ("b", &vec![1, 1, 3, 3, 2, 2, 4, 4]),
                ("c", &vec![7, 6, 5, 4, 1, 0, 3, 2]),
            ),
        ];
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]);
        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };
        for (common_prefix_length, source) in
            [(1, &source_tables[0]), (2, &source_tables[1])]
        {
            let partial_sort_exec = Arc::new(PartialSortExec::new(
                [
                    PhysicalSortExpr {
                        expr: col("a", &schema)?,
                        options: option_asc,
                    },
                    PhysicalSortExpr {
                        expr: col("b", &schema)?,
                        options: option_asc,
                    },
                    PhysicalSortExpr {
                        expr: col("c", &schema)?,
                        options: option_asc,
                    },
                ]
                .into(),
                Arc::clone(source),
                common_prefix_length,
            ));

            let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
            assert_eq!(2, result.len());
            assert_eq!(
                task_ctx.runtime_env().memory_pool.reserved(),
                0,
                "The sort should have returned all memory used back to the memory manager"
            );
            allow_duplicates! {
                assert_snapshot!(batches_to_string(&result), @r#"
                +---+---+---+
                | a | b | c |
                +---+---+---+
                | 0 | 1 | 6 |
                | 0 | 1 | 7 |
                | 0 | 3 | 4 |
                | 0 | 3 | 5 |
                | 1 | 2 | 0 |
                | 1 | 2 | 1 |
                | 1 | 4 | 2 |
                | 1 | 4 | 3 |
                +---+---+---+
                "#);
            }
        }
        Ok(())
    }
701
    /// Builds a single-partition in-memory input of four 100-row batches.
    ///
    /// The batches are ordered on `a` across batch boundaries. The `a` groups
    /// are split unevenly across batches (1: 125 rows, 2: 75, 3: 50, 4: 150),
    /// so partial-sort output is cut at group boundaries, not batch boundaries.
    fn prepare_partitioned_input() -> Arc<dyn ExecutionPlan> {
        let batch1 = test::build_table_i32(
            ("a", &vec![1; 100]),
            ("b", &(0..100).rev().collect()),
            ("c", &(0..100).rev().collect()),
        );
        let batch2 = test::build_table_i32(
            ("a", &[&vec![1; 25][..], &vec![2; 75][..]].concat()),
            ("b", &(100..200).rev().collect()),
            ("c", &(0..100).collect()),
        );
        let batch3 = test::build_table_i32(
            ("a", &[&vec![3; 50][..], &vec![4; 50][..]].concat()),
            ("b", &(150..250).rev().collect()),
            ("c", &(0..100).rev().collect()),
        );
        let batch4 = test::build_table_i32(
            ("a", &vec![4; 100]),
            ("b", &(50..150).rev().collect()),
            ("c", &(0..100).rev().collect()),
        );
        let schema = batch1.schema();

        TestMemoryExec::try_new_exec(
            &[vec![batch1, batch2, batch3, batch4]],
            Arc::clone(&schema),
            None,
        )
        .unwrap() as Arc<dyn ExecutionPlan>
    }
732
733 #[tokio::test]
734 async fn test_partitioned_input_partial_sort() -> Result<()> {
735 let task_ctx = Arc::new(TaskContext::default());
736 let mem_exec = prepare_partitioned_input();
737 let option_asc = SortOptions {
738 descending: false,
739 nulls_first: false,
740 };
741 let option_desc = SortOptions {
742 descending: false,
743 nulls_first: false,
744 };
745 let schema = mem_exec.schema();
746 let partial_sort_exec = PartialSortExec::new(
747 [
748 PhysicalSortExpr {
749 expr: col("a", &schema)?,
750 options: option_asc,
751 },
752 PhysicalSortExpr {
753 expr: col("b", &schema)?,
754 options: option_desc,
755 },
756 PhysicalSortExpr {
757 expr: col("c", &schema)?,
758 options: option_asc,
759 },
760 ]
761 .into(),
762 Arc::clone(&mem_exec),
763 1,
764 );
765 let sort_exec = Arc::new(SortExec::new(
766 partial_sort_exec.expr.clone(),
767 Arc::clone(&partial_sort_exec.input),
768 ));
769 let result = collect(Arc::new(partial_sort_exec), Arc::clone(&task_ctx)).await?;
770 assert_eq!(
771 result.iter().map(|r| r.num_rows()).collect_vec(),
772 [125, 125, 150]
773 );
774
775 assert_eq!(
776 task_ctx.runtime_env().memory_pool.reserved(),
777 0,
778 "The sort should have returned all memory used back to the memory manager"
779 );
780 let partial_sort_result = concat_batches(&schema, &result).unwrap();
781 let sort_result = collect(sort_exec, Arc::clone(&task_ctx)).await?;
782 assert_eq!(sort_result[0], partial_sort_result);
783
784 Ok(())
785 }
786
787 #[tokio::test]
788 async fn test_partitioned_input_partial_sort_with_fetch() -> Result<()> {
789 let task_ctx = Arc::new(TaskContext::default());
790 let mem_exec = prepare_partitioned_input();
791 let schema = mem_exec.schema();
792 let option_asc = SortOptions {
793 descending: false,
794 nulls_first: false,
795 };
796 let option_desc = SortOptions {
797 descending: false,
798 nulls_first: false,
799 };
800 for (fetch_size, expected_batch_num_rows) in [
801 (Some(50), vec![50]),
802 (Some(120), vec![120]),
803 (Some(150), vec![125, 25]),
804 (Some(250), vec![125, 125]),
805 ] {
806 let partial_sort_exec = PartialSortExec::new(
807 [
808 PhysicalSortExpr {
809 expr: col("a", &schema)?,
810 options: option_asc,
811 },
812 PhysicalSortExpr {
813 expr: col("b", &schema)?,
814 options: option_desc,
815 },
816 PhysicalSortExpr {
817 expr: col("c", &schema)?,
818 options: option_asc,
819 },
820 ]
821 .into(),
822 Arc::clone(&mem_exec),
823 1,
824 )
825 .with_fetch(fetch_size);
826
827 let sort_exec = Arc::new(
828 SortExec::new(
829 partial_sort_exec.expr.clone(),
830 Arc::clone(&partial_sort_exec.input),
831 )
832 .with_fetch(fetch_size),
833 );
834 let result =
835 collect(Arc::new(partial_sort_exec), Arc::clone(&task_ctx)).await?;
836 assert_eq!(
837 result.iter().map(|r| r.num_rows()).collect_vec(),
838 expected_batch_num_rows
839 );
840
841 assert_eq!(
842 task_ctx.runtime_env().memory_pool.reserved(),
843 0,
844 "The sort should have returned all memory used back to the memory manager"
845 );
846 let partial_sort_result = concat_batches(&schema, &result)?;
847 let sort_result = collect(sort_exec, Arc::clone(&task_ctx)).await?;
848 assert_eq!(sort_result[0], partial_sort_result);
849 }
850
851 Ok(())
852 }
853
854 #[tokio::test]
855 async fn test_partial_sort_no_empty_batches() -> Result<()> {
856 let task_ctx = Arc::new(TaskContext::default());
857 let mem_exec = prepare_partitioned_input();
858 let schema = mem_exec.schema();
859 let option_asc = SortOptions {
860 descending: false,
861 nulls_first: false,
862 };
863 let fetch_size = Some(250);
864 let partial_sort_exec = PartialSortExec::new(
865 [
866 PhysicalSortExpr {
867 expr: col("a", &schema)?,
868 options: option_asc,
869 },
870 PhysicalSortExpr {
871 expr: col("c", &schema)?,
872 options: option_asc,
873 },
874 ]
875 .into(),
876 Arc::clone(&mem_exec),
877 1,
878 )
879 .with_fetch(fetch_size);
880
881 let result = collect(Arc::new(partial_sort_exec), Arc::clone(&task_ctx)).await?;
882 for rb in result {
883 assert!(rb.num_rows() > 0);
884 }
885
886 Ok(())
887 }
888
    /// Checks that field-level and schema-level metadata of the input schema
    /// are preserved on the sorted output batches.
    #[tokio::test]
    async fn test_sort_metadata() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let field_metadata: HashMap<String, String> =
            vec![("foo".to_string(), "bar".to_string())]
                .into_iter()
                .collect();
        let schema_metadata: HashMap<String, String> =
            vec![("baz".to_string(), "barf".to_string())]
                .into_iter()
                .collect();

        let mut field = Field::new("field_name", DataType::UInt64, true);
        field.set_metadata(field_metadata.clone());
        let schema = Schema::new_with_metadata(vec![field], schema_metadata.clone());
        let schema = Arc::new(schema);

        let data: ArrayRef =
            Arc::new(vec![1, 1, 2].into_iter().map(Some).collect::<UInt64Array>());

        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![data])?;
        let input =
            TestMemoryExec::try_new_exec(&[vec![batch]], Arc::clone(&schema), None)?;

        let partial_sort_exec = Arc::new(PartialSortExec::new(
            [PhysicalSortExpr {
                expr: col("field_name", &schema)?,
                options: SortOptions::default(),
            }]
            .into(),
            input,
            1,
        ));

        let result: Vec<RecordBatch> = collect(partial_sort_exec, task_ctx).await?;
        // One input batch with two groups yields two output batches: the
        // complete group [1, 1] mid-stream and the final group [2] at the end.
        let expected_batch = vec![
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![Arc::new(
                    vec![1, 1].into_iter().map(Some).collect::<UInt64Array>(),
                )],
            )?,
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![Arc::new(
                    vec![2].into_iter().map(Some).collect::<UInt64Array>(),
                )],
            )?,
        ];

        assert_eq!(&expected_batch, &result);

        // Field-level metadata survives.
        assert_eq!(result[0].schema().fields()[0].metadata(), &field_metadata);
        // Schema-level metadata survives.
        assert_eq!(result[0].schema().metadata(), &schema_metadata);

        Ok(())
    }
948
    /// Partially sorts float columns containing NaN and nulls, using mixed
    /// ascending/descending options with `nulls_first: true`. Checks the output
    /// types, row order, batch count and metrics.
    #[tokio::test]
    async fn test_lex_sort_by_float() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Float32, true),
            Field::new("b", DataType::Float64, true),
            Field::new("c", DataType::Float64, true),
        ]));
        let option_asc = SortOptions {
            descending: false,
            nulls_first: true,
        };
        let option_desc = SortOptions {
            descending: true,
            nulls_first: true,
        };

        // Input is ordered on (`a`, `b`) under `option_asc` (nulls first, NaN last).
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Float32Array::from(vec![
                    Some(1.0_f32),
                    Some(1.0_f32),
                    Some(1.0_f32),
                    Some(2.0_f32),
                    Some(2.0_f32),
                    Some(3.0_f32),
                    Some(3.0_f32),
                    Some(3.0_f32),
                ])),
                Arc::new(Float64Array::from(vec![
                    Some(20.0_f64),
                    Some(20.0_f64),
                    Some(40.0_f64),
                    Some(40.0_f64),
                    Some(f64::NAN),
                    None,
                    None,
                    Some(f64::NAN),
                ])),
                Arc::new(Float64Array::from(vec![
                    Some(10.0_f64),
                    Some(20.0_f64),
                    Some(10.0_f64),
                    Some(100.0_f64),
                    Some(f64::NAN),
                    Some(100.0_f64),
                    None,
                    Some(f64::NAN),
                ])),
            ],
        )?;

        let partial_sort_exec = Arc::new(PartialSortExec::new(
            [
                PhysicalSortExpr {
                    expr: col("a", &schema)?,
                    options: option_asc,
                },
                PhysicalSortExpr {
                    expr: col("b", &schema)?,
                    options: option_asc,
                },
                PhysicalSortExpr {
                    expr: col("c", &schema)?,
                    options: option_desc,
                },
            ]
            .into(),
            TestMemoryExec::try_new_exec(&[vec![batch]], schema, None)?,
            2,
        ));

        assert_eq!(
            DataType::Float32,
            *partial_sort_exec.schema().field(0).data_type()
        );
        assert_eq!(
            DataType::Float64,
            *partial_sort_exec.schema().field(1).data_type()
        );
        assert_eq!(
            DataType::Float64,
            *partial_sort_exec.schema().field(2).data_type()
        );

        let result: Vec<RecordBatch> = collect(
            Arc::clone(&partial_sort_exec) as Arc<dyn ExecutionPlan>,
            task_ctx,
        )
        .await?;
        assert_snapshot!(batches_to_string(&result), @r#"
        +-----+------+-------+
        | a   | b    | c     |
        +-----+------+-------+
        | 1.0 | 20.0 | 20.0  |
        | 1.0 | 20.0 | 10.0  |
        | 1.0 | 40.0 | 10.0  |
        | 2.0 | 40.0 | 100.0 |
        | 2.0 | NaN  | NaN   |
        | 3.0 |      |       |
        | 3.0 |      | 100.0 |
        | 3.0 | NaN  | NaN   |
        +-----+------+-------+
        "#);
        assert_eq!(result.len(), 2);
        let metrics = partial_sort_exec.metrics().unwrap();
        assert!(metrics.elapsed_compute().unwrap() > 0);
        assert_eq!(metrics.output_rows().unwrap(), 8);

        let columns = result[0].columns();

        assert_eq!(DataType::Float32, *columns[0].data_type());
        assert_eq!(DataType::Float64, *columns[1].data_type());
        assert_eq!(DataType::Float64, *columns[2].data_type());

        Ok(())
    }
1068
    /// Drops a pending collection over a never-completing input. Checks that
    /// all references to the input are released and no memory stays reserved.
    #[tokio::test]
    async fn test_drop_cancel() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Float32, true),
            Field::new("b", DataType::Float32, true),
        ]));

        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
        let refs = blocking_exec.refs();
        let sort_exec = Arc::new(PartialSortExec::new(
            [PhysicalSortExpr {
                expr: col("a", &schema)?,
                options: SortOptions::default(),
            }]
            .into(),
            blocking_exec,
            1,
        ));

        let fut = collect(sort_exec, Arc::clone(&task_ctx));
        let mut fut = fut.boxed();

        assert_is_pending(&mut fut);
        drop(fut);
        assert_strong_count_converges_to_zero(refs).await;

        assert_eq!(
            task_ctx.runtime_env().memory_pool.reserved(),
            0,
            "The sort should have returned all memory used back to the memory manager"
        );

        Ok(())
    }
1104
    /// Each input batch holds exactly one (`a`, `b`) group. Checks that every
    /// group is emitted as its own sorted batch once the next group arrives,
    /// with the last group flushed at end of input.
    #[tokio::test]
    async fn test_partial_sort_with_homogeneous_batches() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());

        let batch1 = test::build_table_i32(
            ("a", &vec![1; 3]),
            ("b", &vec![1; 3]),
            ("c", &vec![3, 2, 1]),
        );
        let batch2 = test::build_table_i32(
            ("a", &vec![2; 3]),
            ("b", &vec![2; 3]),
            ("c", &vec![4, 6, 4]),
        );
        let batch3 = test::build_table_i32(
            ("a", &vec![3; 3]),
            ("b", &vec![3; 3]),
            ("c", &vec![9, 7, 8]),
        );

        let schema = batch1.schema();
        let mem_exec = TestMemoryExec::try_new_exec(
            &[vec![batch1, batch2, batch3]],
            Arc::clone(&schema),
            None,
        )?;

        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };

        let partial_sort_exec = Arc::new(PartialSortExec::new(
            [
                PhysicalSortExpr {
                    expr: col("a", &schema)?,
                    options: option_asc,
                },
                PhysicalSortExpr {
                    expr: col("b", &schema)?,
                    options: option_asc,
                },
                PhysicalSortExpr {
                    expr: col("c", &schema)?,
                    options: option_asc,
                },
            ]
            .into(),
            mem_exec,
            2,
        ));

        let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;

        // One output batch per (`a`, `b`) group.
        assert_eq!(result.len(), 3,);

        allow_duplicates! {
            assert_snapshot!(batches_to_string(&result), @r#"
            +---+---+---+
            | a | b | c |
            +---+---+---+
            | 1 | 1 | 1 |
            | 1 | 1 | 2 |
            | 1 | 1 | 3 |
            | 2 | 2 | 4 |
            | 2 | 2 | 4 |
            | 2 | 2 | 6 |
            | 3 | 3 | 7 |
            | 3 | 3 | 8 |
            | 3 | 3 | 9 |
            +---+---+---+
            "#);
        }

        assert_eq!(task_ctx.runtime_env().memory_pool.reserved(), 0,);
        Ok(())
    }
1187}