1use std::any::Any;
55use std::fmt::Debug;
56use std::pin::Pin;
57use std::sync::Arc;
58use std::task::{Context, Poll};
59
60use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
61use crate::sorts::sort::sort_batch;
62use crate::{
63 DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
64 Partitioning, PlanProperties, SendableRecordBatchStream, Statistics,
65};
66
67use arrow::compute::concat_batches;
68use arrow::datatypes::SchemaRef;
69use arrow::record_batch::RecordBatch;
70use datafusion_common::Result;
71use datafusion_common::utils::evaluate_partition_ranges;
72use datafusion_execution::{RecordBatchStream, TaskContext};
73use datafusion_physical_expr::LexOrdering;
74
75use futures::{Stream, StreamExt, ready};
76use log::trace;
77
/// Sort execution plan that exploits an existing ordering on a leading
/// prefix of the requested sort expressions: rows only need to be sorted
/// *within* each run of equal prefix values, so results can be emitted
/// incrementally without buffering the whole input.
#[derive(Debug, Clone)]
pub struct PartialSortExec {
    /// Input plan; its output must already be ordered on the first
    /// `common_prefix_length` expressions of `expr`.
    pub(crate) input: Arc<dyn ExecutionPlan>,
    /// The complete ordering this operator produces.
    expr: LexOrdering,
    /// Number of leading expressions in `expr` the input is already sorted
    /// on. Must be > 0 (otherwise a full `SortExec` would be needed).
    common_prefix_length: usize,
    /// Execution metrics (output rows, elapsed compute, ...).
    metrics_set: ExecutionPlanMetricsSet,
    /// When true, each input partition is sorted independently; when false,
    /// a single input partition is required.
    preserve_partitioning: bool,
    /// Optional limit on the total number of output rows (TopK).
    fetch: Option<usize>,
    /// Cached plan properties (schema, ordering, partitioning).
    cache: PlanProperties,
}
98
99impl PartialSortExec {
100 pub fn new(
102 expr: LexOrdering,
103 input: Arc<dyn ExecutionPlan>,
104 common_prefix_length: usize,
105 ) -> Self {
106 debug_assert!(common_prefix_length > 0);
107 let preserve_partitioning = false;
108 let cache = Self::compute_properties(&input, expr.clone(), preserve_partitioning)
109 .unwrap();
110 Self {
111 input,
112 expr,
113 common_prefix_length,
114 metrics_set: ExecutionPlanMetricsSet::new(),
115 preserve_partitioning,
116 fetch: None,
117 cache,
118 }
119 }
120
121 pub fn preserve_partitioning(&self) -> bool {
123 self.preserve_partitioning
124 }
125
126 pub fn with_preserve_partitioning(mut self, preserve_partitioning: bool) -> Self {
134 self.preserve_partitioning = preserve_partitioning;
135 self.cache = self
136 .cache
137 .with_partitioning(Self::output_partitioning_helper(
138 &self.input,
139 self.preserve_partitioning,
140 ));
141 self
142 }
143
144 pub fn with_fetch(mut self, fetch: Option<usize>) -> Self {
152 self.fetch = fetch;
153 self
154 }
155
156 pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
158 &self.input
159 }
160
161 pub fn expr(&self) -> &LexOrdering {
163 &self.expr
164 }
165
166 pub fn fetch(&self) -> Option<usize> {
168 self.fetch
169 }
170
171 pub fn common_prefix_length(&self) -> usize {
173 self.common_prefix_length
174 }
175
176 fn output_partitioning_helper(
177 input: &Arc<dyn ExecutionPlan>,
178 preserve_partitioning: bool,
179 ) -> Partitioning {
180 if preserve_partitioning {
182 input.output_partitioning().clone()
183 } else {
184 Partitioning::UnknownPartitioning(1)
185 }
186 }
187
188 fn compute_properties(
190 input: &Arc<dyn ExecutionPlan>,
191 sort_exprs: LexOrdering,
192 preserve_partitioning: bool,
193 ) -> Result<PlanProperties> {
194 let mut eq_properties = input.equivalence_properties().clone();
197 eq_properties.reorder(sort_exprs)?;
198
199 let output_partitioning =
201 Self::output_partitioning_helper(input, preserve_partitioning);
202
203 Ok(PlanProperties::new(
204 eq_properties,
205 output_partitioning,
206 input.pipeline_behavior(),
207 input.boundedness(),
208 ))
209 }
210}
211
212impl DisplayAs for PartialSortExec {
213 fn fmt_as(
214 &self,
215 t: DisplayFormatType,
216 f: &mut std::fmt::Formatter,
217 ) -> std::fmt::Result {
218 match t {
219 DisplayFormatType::Default | DisplayFormatType::Verbose => {
220 let common_prefix_length = self.common_prefix_length;
221 match self.fetch {
222 Some(fetch) => {
223 write!(
224 f,
225 "PartialSortExec: TopK(fetch={fetch}), expr=[{}], common_prefix_length=[{common_prefix_length}]",
226 self.expr
227 )
228 }
229 None => write!(
230 f,
231 "PartialSortExec: expr=[{}], common_prefix_length=[{common_prefix_length}]",
232 self.expr
233 ),
234 }
235 }
236 DisplayFormatType::TreeRender => match self.fetch {
237 Some(fetch) => {
238 writeln!(f, "{}", self.expr)?;
239 writeln!(f, "limit={fetch}")
240 }
241 None => {
242 writeln!(f, "{}", self.expr)
243 }
244 },
245 }
246 }
247}
248
249impl ExecutionPlan for PartialSortExec {
250 fn name(&self) -> &'static str {
251 "PartialSortExec"
252 }
253
254 fn as_any(&self) -> &dyn Any {
255 self
256 }
257
258 fn properties(&self) -> &PlanProperties {
259 &self.cache
260 }
261
262 fn fetch(&self) -> Option<usize> {
263 self.fetch
264 }
265
266 fn required_input_distribution(&self) -> Vec<Distribution> {
267 if self.preserve_partitioning {
268 vec![Distribution::UnspecifiedDistribution]
269 } else {
270 vec![Distribution::SinglePartition]
271 }
272 }
273
274 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
275 vec![false]
276 }
277
278 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
279 vec![&self.input]
280 }
281
282 fn with_new_children(
283 self: Arc<Self>,
284 children: Vec<Arc<dyn ExecutionPlan>>,
285 ) -> Result<Arc<dyn ExecutionPlan>> {
286 let new_partial_sort = PartialSortExec::new(
287 self.expr.clone(),
288 Arc::clone(&children[0]),
289 self.common_prefix_length,
290 )
291 .with_fetch(self.fetch)
292 .with_preserve_partitioning(self.preserve_partitioning);
293
294 Ok(Arc::new(new_partial_sort))
295 }
296
297 fn execute(
298 &self,
299 partition: usize,
300 context: Arc<TaskContext>,
301 ) -> Result<SendableRecordBatchStream> {
302 trace!(
303 "Start PartialSortExec::execute for partition {} of context session_id {} and task_id {:?}",
304 partition,
305 context.session_id(),
306 context.task_id()
307 );
308
309 let input = self.input.execute(partition, Arc::clone(&context))?;
310
311 trace!("End PartialSortExec's input.execute for partition: {partition}");
312
313 debug_assert!(self.common_prefix_length > 0);
316
317 Ok(Box::pin(PartialSortStream {
318 input,
319 expr: self.expr.clone(),
320 common_prefix_length: self.common_prefix_length,
321 in_mem_batch: RecordBatch::new_empty(Arc::clone(&self.schema())),
322 fetch: self.fetch,
323 is_closed: false,
324 baseline_metrics: BaselineMetrics::new(&self.metrics_set, partition),
325 }))
326 }
327
328 fn metrics(&self) -> Option<MetricsSet> {
329 Some(self.metrics_set.clone_inner())
330 }
331
332 fn statistics(&self) -> Result<Statistics> {
333 self.input.partition_statistics(None)
334 }
335
336 fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
337 self.input.partition_statistics(partition)
338 }
339}
340
/// Stream adapter implementing the partial sort: buffers input rows until a
/// complete run of equal prefix values is available, sorts that run, and
/// emits it.
struct PartialSortStream {
    /// The child stream being sorted.
    input: SendableRecordBatchStream,
    /// The full sort ordering to produce.
    expr: LexOrdering,
    /// Number of leading sort expressions the input is already ordered on.
    common_prefix_length: usize,
    /// Rows buffered so far that have not yet been emitted.
    in_mem_batch: RecordBatch,
    /// Remaining number of rows to emit, if a limit was set; decremented as
    /// batches are produced.
    fetch: Option<usize>,
    /// Set once the stream has finished (input exhausted or fetch reached).
    is_closed: bool,
    /// Baseline metrics recorded on every poll.
    baseline_metrics: BaselineMetrics,
}
358
impl Stream for PartialSortStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // Delegate to the inner poll and fold the outcome into the baseline
        // metrics (output row count, poll accounting).
        let poll = self.poll_next_inner(cx);
        self.baseline_metrics.record_poll(poll)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Forward the input's hint; it is only an estimate (a fetch limit may
        // cause fewer rows to be produced).
        self.input.size_hint()
    }
}
375
impl RecordBatchStream for PartialSortStream {
    /// Sorting does not alter the schema, so the input's schema is forwarded.
    fn schema(&self) -> SchemaRef {
        self.input.schema()
    }
}
381
impl PartialSortStream {
    /// Core polling loop: accumulates input batches until the buffered rows
    /// contain at least one complete run of equal prefix values, then sorts
    /// and emits that run. Returns `None` once the fetch limit is reached or
    /// the input is exhausted (after flushing the final buffered run).
    fn poll_next_inner(
        self: &mut Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        if self.is_closed {
            return Poll::Ready(None);
        }
        loop {
            // Fetch exhausted: close without polling the input again.
            if self.fetch == Some(0) {
                self.is_closed = true;
                return Poll::Ready(None);
            }

            match ready!(self.input.poll_next_unpin(cx)) {
                Some(Ok(batch)) => {
                    // Append the new batch to the buffered rows.
                    self.in_mem_batch = concat_batches(
                        &self.schema(),
                        &[self.in_mem_batch.clone(), batch],
                    )?;

                    // If the buffer spans more than one prefix value, every
                    // run except the last (possibly incomplete) one can be
                    // sorted and emitted now.
                    if let Some(slice_point) = self
                        .get_slice_point(self.common_prefix_length, &self.in_mem_batch)?
                    {
                        let sorted = self.in_mem_batch.slice(0, slice_point);
                        // Keep only the trailing (incomplete) run buffered.
                        self.in_mem_batch = self.in_mem_batch.slice(
                            slice_point,
                            self.in_mem_batch.num_rows() - slice_point,
                        );
                        let sorted_batch = sort_batch(&sorted, &self.expr, self.fetch)?;
                        // `sort_batch` truncates to `fetch`, so this cannot
                        // underflow.
                        if let Some(fetch) = self.fetch.as_mut() {
                            *fetch -= sorted_batch.num_rows();
                        }

                        if sorted_batch.num_rows() > 0 {
                            return Poll::Ready(Some(Ok(sorted_batch)));
                        }
                    }
                }
                Some(Err(e)) => return Poll::Ready(Some(Err(e))),
                None => {
                    // Input exhausted: flush whatever is still buffered.
                    self.is_closed = true;
                    let remaining_batch = self.sort_in_mem_batch()?;
                    return if remaining_batch.num_rows() > 0 {
                        Poll::Ready(Some(Ok(remaining_batch)))
                    } else {
                        Poll::Ready(None)
                    };
                }
            };
        }
    }

    /// Sorts and drains the buffered rows (respecting any remaining fetch
    /// limit), leaving an empty buffer behind. Closes the stream if the
    /// limit is exactly consumed.
    fn sort_in_mem_batch(self: &mut Pin<&mut Self>) -> Result<RecordBatch> {
        let input_batch = self.in_mem_batch.clone();
        self.in_mem_batch = RecordBatch::new_empty(self.schema());
        let result = sort_batch(&input_batch, &self.expr, self.fetch)?;
        if let Some(remaining_fetch) = self.fetch {
            // `result.num_rows() <= remaining_fetch` because `sort_batch`
            // applies the fetch limit, so the subtraction is safe.
            self.fetch = Some(remaining_fetch - result.num_rows());
            if remaining_fetch == result.num_rows() {
                self.is_closed = true;
            }
        }
        Ok(result)
    }

    /// Returns the row index separating complete prefix runs from the last
    /// (possibly incomplete) run in `batch`, or `None` when the whole batch
    /// shares a single prefix value and nothing can be emitted yet.
    fn get_slice_point(
        &self,
        common_prefix_len: usize,
        batch: &RecordBatch,
    ) -> Result<Option<usize>> {
        let common_prefix_sort_keys = (0..common_prefix_len)
            .map(|idx| self.expr[idx].evaluate_to_sort_column(batch))
            .collect::<Result<Vec<_>>>()?;
        let partition_points =
            evaluate_partition_ranges(batch.num_rows(), &common_prefix_sort_keys)?;
        // The last range may continue into not-yet-seen input, so emit only
        // up to the end of the second-to-last range.
        if partition_points.len() >= 2 {
            Ok(Some(partition_points[partition_points.len() - 2].end))
        } else {
            Ok(None)
        }
    }
}
485
486#[cfg(test)]
487mod tests {
488 use std::collections::HashMap;
489
490 use arrow::array::*;
491 use arrow::compute::SortOptions;
492 use arrow::datatypes::*;
493 use datafusion_common::test_util::batches_to_string;
494 use futures::FutureExt;
495 use insta::allow_duplicates;
496 use insta::assert_snapshot;
497 use itertools::Itertools;
498
499 use crate::collect;
500 use crate::expressions::PhysicalSortExpr;
501 use crate::expressions::col;
502 use crate::sorts::sort::SortExec;
503 use crate::test;
504 use crate::test::TestMemoryExec;
505 use crate::test::assert_is_pending;
506 use crate::test::exec::{BlockingExec, assert_strong_count_converges_to_zero};
507
508 use super::*;
509
    #[tokio::test]
    async fn test_partial_sort() -> Result<()> {
        // Input is pre-sorted on (a, b); with prefix length 2 the operator
        // only has to order `c` within each (a, b) group.
        let task_ctx = Arc::new(TaskContext::default());
        let source = test::build_table_scan_i32(
            ("a", &vec![0, 0, 0, 1, 1, 1]),
            ("b", &vec![1, 1, 2, 2, 3, 3]),
            ("c", &vec![1, 0, 5, 4, 3, 2]),
        );
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]);
        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };

        let partial_sort_exec = Arc::new(PartialSortExec::new(
            [
                PhysicalSortExpr {
                    expr: col("a", &schema)?,
                    options: option_asc,
                },
                PhysicalSortExpr {
                    expr: col("b", &schema)?,
                    options: option_asc,
                },
                PhysicalSortExpr {
                    expr: col("c", &schema)?,
                    options: option_asc,
                },
            ]
            .into(),
            Arc::clone(&source),
            2,
        ));

        let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;

        // Two batches: one per emitted prefix run.
        assert_eq!(2, result.len());
        allow_duplicates! {
            assert_snapshot!(batches_to_string(&result), @r"
            +---+---+---+
            | a | b | c |
            +---+---+---+
            | 0 | 1 | 0 |
            | 0 | 1 | 1 |
            | 0 | 2 | 5 |
            | 1 | 2 | 4 |
            | 1 | 3 | 2 |
            | 1 | 3 | 3 |
            +---+---+---+
            ");
        }
        assert_eq!(
            task_ctx.runtime_env().memory_pool.reserved(),
            0,
            "The sort should have returned all memory used back to the memory manager"
        );

        Ok(())
    }
573
    #[tokio::test]
    async fn test_partial_sort_with_fetch() -> Result<()> {
        // Verifies TopK behavior: with fetch=4 the stream stops after four
        // rows, for both prefix lengths 1 and 2.
        let task_ctx = Arc::new(TaskContext::default());
        let source = test::build_table_scan_i32(
            ("a", &vec![0, 0, 1, 1, 1]),
            ("b", &vec![1, 2, 2, 3, 3]),
            ("c", &vec![4, 3, 2, 1, 0]),
        );
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]);
        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };

        for common_prefix_length in [1, 2] {
            let partial_sort_exec = Arc::new(
                PartialSortExec::new(
                    [
                        PhysicalSortExpr {
                            expr: col("a", &schema)?,
                            options: option_asc,
                        },
                        PhysicalSortExpr {
                            expr: col("b", &schema)?,
                            options: option_asc,
                        },
                        PhysicalSortExpr {
                            expr: col("c", &schema)?,
                            options: option_asc,
                        },
                    ]
                    .into(),
                    Arc::clone(&source),
                    common_prefix_length,
                )
                .with_fetch(Some(4)),
            );

            let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;

            assert_eq!(2, result.len());
            allow_duplicates! {
                assert_snapshot!(batches_to_string(&result), @r"
                +---+---+---+
                | a | b | c |
                +---+---+---+
                | 0 | 1 | 4 |
                | 0 | 2 | 3 |
                | 1 | 2 | 2 |
                | 1 | 3 | 0 |
                +---+---+---+
                ");
            }
            assert_eq!(
                task_ctx.runtime_env().memory_pool.reserved(),
                0,
                "The sort should have returned all memory used back to the memory manager"
            );
        }

        Ok(())
    }
640
    #[tokio::test]
    async fn test_partial_sort2() -> Result<()> {
        // Two input variants paired with matching prefix lengths (1 and 2)
        // must produce the same fully sorted output.
        let task_ctx = Arc::new(TaskContext::default());
        let source_tables = [
            test::build_table_scan_i32(
                ("a", &vec![0, 0, 0, 0, 1, 1, 1, 1]),
                ("b", &vec![1, 1, 3, 3, 4, 4, 2, 2]),
                ("c", &vec![7, 6, 5, 4, 3, 2, 1, 0]),
            ),
            test::build_table_scan_i32(
                ("a", &vec![0, 0, 0, 0, 1, 1, 1, 1]),
                ("b", &vec![1, 1, 3, 3, 2, 2, 4, 4]),
                ("c", &vec![7, 6, 5, 4, 1, 0, 3, 2]),
            ),
        ];
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]);
        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };
        for (common_prefix_length, source) in
            [(1, &source_tables[0]), (2, &source_tables[1])]
        {
            let partial_sort_exec = Arc::new(PartialSortExec::new(
                [
                    PhysicalSortExpr {
                        expr: col("a", &schema)?,
                        options: option_asc,
                    },
                    PhysicalSortExpr {
                        expr: col("b", &schema)?,
                        options: option_asc,
                    },
                    PhysicalSortExpr {
                        expr: col("c", &schema)?,
                        options: option_asc,
                    },
                ]
                .into(),
                Arc::clone(source),
                common_prefix_length,
            ));

            let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
            assert_eq!(2, result.len());
            assert_eq!(
                task_ctx.runtime_env().memory_pool.reserved(),
                0,
                "The sort should have returned all memory used back to the memory manager"
            );
            allow_duplicates! {
                assert_snapshot!(batches_to_string(&result), @r"
                +---+---+---+
                | a | b | c |
                +---+---+---+
                | 0 | 1 | 6 |
                | 0 | 1 | 7 |
                | 0 | 3 | 4 |
                | 0 | 3 | 5 |
                | 1 | 2 | 0 |
                | 1 | 2 | 1 |
                | 1 | 4 | 2 |
                | 1 | 4 | 3 |
                +---+---+---+
                ");
            }
        }
        Ok(())
    }
714
715 fn prepare_partitioned_input() -> Arc<dyn ExecutionPlan> {
716 let batch1 = test::build_table_i32(
717 ("a", &vec![1; 100]),
718 ("b", &(0..100).rev().collect()),
719 ("c", &(0..100).rev().collect()),
720 );
721 let batch2 = test::build_table_i32(
722 ("a", &[&vec![1; 25][..], &vec![2; 75][..]].concat()),
723 ("b", &(100..200).rev().collect()),
724 ("c", &(0..100).collect()),
725 );
726 let batch3 = test::build_table_i32(
727 ("a", &[&vec![3; 50][..], &vec![4; 50][..]].concat()),
728 ("b", &(150..250).rev().collect()),
729 ("c", &(0..100).rev().collect()),
730 );
731 let batch4 = test::build_table_i32(
732 ("a", &vec![4; 100]),
733 ("b", &(50..150).rev().collect()),
734 ("c", &(0..100).rev().collect()),
735 );
736 let schema = batch1.schema();
737
738 TestMemoryExec::try_new_exec(
739 &[vec![batch1, batch2, batch3, batch4]],
740 Arc::clone(&schema),
741 None,
742 )
743 .unwrap() as Arc<dyn ExecutionPlan>
744 }
745
746 #[tokio::test]
747 async fn test_partitioned_input_partial_sort() -> Result<()> {
748 let task_ctx = Arc::new(TaskContext::default());
749 let mem_exec = prepare_partitioned_input();
750 let option_asc = SortOptions {
751 descending: false,
752 nulls_first: false,
753 };
754 let option_desc = SortOptions {
755 descending: false,
756 nulls_first: false,
757 };
758 let schema = mem_exec.schema();
759 let partial_sort_exec = PartialSortExec::new(
760 [
761 PhysicalSortExpr {
762 expr: col("a", &schema)?,
763 options: option_asc,
764 },
765 PhysicalSortExpr {
766 expr: col("b", &schema)?,
767 options: option_desc,
768 },
769 PhysicalSortExpr {
770 expr: col("c", &schema)?,
771 options: option_asc,
772 },
773 ]
774 .into(),
775 Arc::clone(&mem_exec),
776 1,
777 );
778 let sort_exec = Arc::new(SortExec::new(
779 partial_sort_exec.expr.clone(),
780 Arc::clone(&partial_sort_exec.input),
781 ));
782 let result = collect(Arc::new(partial_sort_exec), Arc::clone(&task_ctx)).await?;
783 assert_eq!(
784 result.iter().map(|r| r.num_rows()).collect_vec(),
785 [125, 125, 150]
786 );
787
788 assert_eq!(
789 task_ctx.runtime_env().memory_pool.reserved(),
790 0,
791 "The sort should have returned all memory used back to the memory manager"
792 );
793 let partial_sort_result = concat_batches(&schema, &result).unwrap();
794 let sort_result = collect(sort_exec, Arc::clone(&task_ctx)).await?;
795 assert_eq!(sort_result[0], partial_sort_result);
796
797 Ok(())
798 }
799
800 #[tokio::test]
801 async fn test_partitioned_input_partial_sort_with_fetch() -> Result<()> {
802 let task_ctx = Arc::new(TaskContext::default());
803 let mem_exec = prepare_partitioned_input();
804 let schema = mem_exec.schema();
805 let option_asc = SortOptions {
806 descending: false,
807 nulls_first: false,
808 };
809 let option_desc = SortOptions {
810 descending: false,
811 nulls_first: false,
812 };
813 for (fetch_size, expected_batch_num_rows) in [
814 (Some(50), vec![50]),
815 (Some(120), vec![120]),
816 (Some(150), vec![125, 25]),
817 (Some(250), vec![125, 125]),
818 ] {
819 let partial_sort_exec = PartialSortExec::new(
820 [
821 PhysicalSortExpr {
822 expr: col("a", &schema)?,
823 options: option_asc,
824 },
825 PhysicalSortExpr {
826 expr: col("b", &schema)?,
827 options: option_desc,
828 },
829 PhysicalSortExpr {
830 expr: col("c", &schema)?,
831 options: option_asc,
832 },
833 ]
834 .into(),
835 Arc::clone(&mem_exec),
836 1,
837 )
838 .with_fetch(fetch_size);
839
840 let sort_exec = Arc::new(
841 SortExec::new(
842 partial_sort_exec.expr.clone(),
843 Arc::clone(&partial_sort_exec.input),
844 )
845 .with_fetch(fetch_size),
846 );
847 let result =
848 collect(Arc::new(partial_sort_exec), Arc::clone(&task_ctx)).await?;
849 assert_eq!(
850 result.iter().map(|r| r.num_rows()).collect_vec(),
851 expected_batch_num_rows
852 );
853
854 assert_eq!(
855 task_ctx.runtime_env().memory_pool.reserved(),
856 0,
857 "The sort should have returned all memory used back to the memory manager"
858 );
859 let partial_sort_result = concat_batches(&schema, &result)?;
860 let sort_result = collect(sort_exec, Arc::clone(&task_ctx)).await?;
861 assert_eq!(sort_result[0], partial_sort_result);
862 }
863
864 Ok(())
865 }
866
    #[tokio::test]
    async fn test_partial_sort_no_empty_batches() -> Result<()> {
        // The stream must never emit zero-row batches, even when a fetch
        // limit interacts with prefix-run boundaries.
        let task_ctx = Arc::new(TaskContext::default());
        let mem_exec = prepare_partitioned_input();
        let schema = mem_exec.schema();
        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };
        let fetch_size = Some(250);
        let partial_sort_exec = PartialSortExec::new(
            [
                PhysicalSortExpr {
                    expr: col("a", &schema)?,
                    options: option_asc,
                },
                PhysicalSortExpr {
                    expr: col("c", &schema)?,
                    options: option_asc,
                },
            ]
            .into(),
            Arc::clone(&mem_exec),
            1,
        )
        .with_fetch(fetch_size);

        let result = collect(Arc::new(partial_sort_exec), Arc::clone(&task_ctx)).await?;
        for rb in result {
            assert!(rb.num_rows() > 0);
        }

        Ok(())
    }
901
    #[tokio::test]
    async fn test_sort_metadata() -> Result<()> {
        // Field- and schema-level metadata must survive the partial sort.
        let task_ctx = Arc::new(TaskContext::default());
        let field_metadata: HashMap<String, String> =
            vec![("foo".to_string(), "bar".to_string())]
                .into_iter()
                .collect();
        let schema_metadata: HashMap<String, String> =
            vec![("baz".to_string(), "barf".to_string())]
                .into_iter()
                .collect();

        let mut field = Field::new("field_name", DataType::UInt64, true);
        field.set_metadata(field_metadata.clone());
        let schema = Schema::new_with_metadata(vec![field], schema_metadata.clone());
        let schema = Arc::new(schema);

        let data: ArrayRef =
            Arc::new(vec![1, 1, 2].into_iter().map(Some).collect::<UInt64Array>());

        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![data])?;
        let input =
            TestMemoryExec::try_new_exec(&[vec![batch]], Arc::clone(&schema), None)?;

        let partial_sort_exec = Arc::new(PartialSortExec::new(
            [PhysicalSortExpr {
                expr: col("field_name", &schema)?,
                options: SortOptions::default(),
            }]
            .into(),
            input,
            1,
        ));

        let result: Vec<RecordBatch> = collect(partial_sort_exec, task_ctx).await?;
        // Two prefix runs: [1, 1] and [2].
        let expected_batch = vec![
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![Arc::new(
                    vec![1, 1].into_iter().map(Some).collect::<UInt64Array>(),
                )],
            )?,
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![Arc::new(
                    vec![2].into_iter().map(Some).collect::<UInt64Array>(),
                )],
            )?,
        ];

        assert_eq!(&expected_batch, &result);

        assert_eq!(result[0].schema().fields()[0].metadata(), &field_metadata);
        assert_eq!(result[0].schema().metadata(), &schema_metadata);

        Ok(())
    }
961
    #[tokio::test]
    async fn test_lex_sort_by_float() -> Result<()> {
        // Exercises float ordering semantics (NaN and NULL placement) with a
        // mixed asc/desc ordering over a two-column pre-sorted prefix.
        let task_ctx = Arc::new(TaskContext::default());
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Float32, true),
            Field::new("b", DataType::Float64, true),
            Field::new("c", DataType::Float64, true),
        ]));
        let option_asc = SortOptions {
            descending: false,
            nulls_first: true,
        };
        let option_desc = SortOptions {
            descending: true,
            nulls_first: true,
        };

        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Float32Array::from(vec![
                    Some(1.0_f32),
                    Some(1.0_f32),
                    Some(1.0_f32),
                    Some(2.0_f32),
                    Some(2.0_f32),
                    Some(3.0_f32),
                    Some(3.0_f32),
                    Some(3.0_f32),
                ])),
                Arc::new(Float64Array::from(vec![
                    Some(20.0_f64),
                    Some(20.0_f64),
                    Some(40.0_f64),
                    Some(40.0_f64),
                    Some(f64::NAN),
                    None,
                    None,
                    Some(f64::NAN),
                ])),
                Arc::new(Float64Array::from(vec![
                    Some(10.0_f64),
                    Some(20.0_f64),
                    Some(10.0_f64),
                    Some(100.0_f64),
                    Some(f64::NAN),
                    Some(100.0_f64),
                    None,
                    Some(f64::NAN),
                ])),
            ],
        )?;

        let partial_sort_exec = Arc::new(PartialSortExec::new(
            [
                PhysicalSortExpr {
                    expr: col("a", &schema)?,
                    options: option_asc,
                },
                PhysicalSortExpr {
                    expr: col("b", &schema)?,
                    options: option_asc,
                },
                PhysicalSortExpr {
                    expr: col("c", &schema)?,
                    options: option_desc,
                },
            ]
            .into(),
            TestMemoryExec::try_new_exec(&[vec![batch]], schema, None)?,
            2,
        ));

        assert_eq!(
            DataType::Float32,
            *partial_sort_exec.schema().field(0).data_type()
        );
        assert_eq!(
            DataType::Float64,
            *partial_sort_exec.schema().field(1).data_type()
        );
        assert_eq!(
            DataType::Float64,
            *partial_sort_exec.schema().field(2).data_type()
        );

        let result: Vec<RecordBatch> = collect(
            Arc::clone(&partial_sort_exec) as Arc<dyn ExecutionPlan>,
            task_ctx,
        )
        .await?;
        assert_snapshot!(batches_to_string(&result), @r"
        +-----+------+-------+
        | a   | b    | c     |
        +-----+------+-------+
        | 1.0 | 20.0 | 20.0  |
        | 1.0 | 20.0 | 10.0  |
        | 1.0 | 40.0 | 10.0  |
        | 2.0 | 40.0 | 100.0 |
        | 2.0 | NaN  | NaN   |
        | 3.0 |      |       |
        | 3.0 |      | 100.0 |
        | 3.0 | NaN  | NaN   |
        +-----+------+-------+
        ");
        assert_eq!(result.len(), 2);
        let metrics = partial_sort_exec.metrics().unwrap();
        assert!(metrics.elapsed_compute().unwrap() > 0);
        assert_eq!(metrics.output_rows().unwrap(), 8);

        let columns = result[0].columns();

        assert_eq!(DataType::Float32, *columns[0].data_type());
        assert_eq!(DataType::Float64, *columns[1].data_type());
        assert_eq!(DataType::Float64, *columns[2].data_type());

        Ok(())
    }
1081
    #[tokio::test]
    async fn test_drop_cancel() -> Result<()> {
        // Dropping the pending future must release the input stream (no
        // leaked Arcs) and all reserved memory.
        let task_ctx = Arc::new(TaskContext::default());
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Float32, true),
            Field::new("b", DataType::Float32, true),
        ]));

        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
        let refs = blocking_exec.refs();
        let sort_exec = Arc::new(PartialSortExec::new(
            [PhysicalSortExpr {
                expr: col("a", &schema)?,
                options: SortOptions::default(),
            }]
            .into(),
            blocking_exec,
            1,
        ));

        let fut = collect(sort_exec, Arc::clone(&task_ctx));
        let mut fut = fut.boxed();

        assert_is_pending(&mut fut);
        drop(fut);
        assert_strong_count_converges_to_zero(refs).await;

        assert_eq!(
            task_ctx.runtime_env().memory_pool.reserved(),
            0,
            "The sort should have returned all memory used back to the memory manager"
        );

        Ok(())
    }
1117
    #[tokio::test]
    async fn test_partial_sort_with_homogeneous_batches() -> Result<()> {
        // Each input batch holds a single (a, b) prefix value, so every
        // batch boundary is also a prefix-run boundary: three output batches.
        let task_ctx = Arc::new(TaskContext::default());

        let batch1 = test::build_table_i32(
            ("a", &vec![1; 3]),
            ("b", &vec![1; 3]),
            ("c", &vec![3, 2, 1]),
        );
        let batch2 = test::build_table_i32(
            ("a", &vec![2; 3]),
            ("b", &vec![2; 3]),
            ("c", &vec![4, 6, 4]),
        );
        let batch3 = test::build_table_i32(
            ("a", &vec![3; 3]),
            ("b", &vec![3; 3]),
            ("c", &vec![9, 7, 8]),
        );

        let schema = batch1.schema();
        let mem_exec = TestMemoryExec::try_new_exec(
            &[vec![batch1, batch2, batch3]],
            Arc::clone(&schema),
            None,
        )?;

        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };

        let partial_sort_exec = Arc::new(PartialSortExec::new(
            [
                PhysicalSortExpr {
                    expr: col("a", &schema)?,
                    options: option_asc,
                },
                PhysicalSortExpr {
                    expr: col("b", &schema)?,
                    options: option_asc,
                },
                PhysicalSortExpr {
                    expr: col("c", &schema)?,
                    options: option_asc,
                },
            ]
            .into(),
            mem_exec,
            2,
        ));

        let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;

        assert_eq!(result.len(), 3,);

        allow_duplicates! {
            assert_snapshot!(batches_to_string(&result), @r"
            +---+---+---+
            | a | b | c |
            +---+---+---+
            | 1 | 1 | 1 |
            | 1 | 1 | 2 |
            | 1 | 1 | 3 |
            | 2 | 2 | 4 |
            | 2 | 2 | 4 |
            | 2 | 2 | 6 |
            | 3 | 3 | 7 |
            | 3 | 3 | 8 |
            | 3 | 3 | 9 |
            +---+---+---+
            ");
        }

        assert_eq!(task_ctx.runtime_env().memory_pool.reserved(), 0,);
        Ok(())
    }
1200}