1use std::fmt::{Debug, Formatter};
23use std::pin::Pin;
24use std::sync::Arc;
25use std::task::{Context, Poll};
26use std::{any::Any, vec};
27
28use super::common::SharedMemoryReservation;
29use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
30use super::{
31 DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream,
32};
33use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType};
34use crate::hash_utils::create_hashes;
35use crate::metrics::BaselineMetrics;
36use crate::projection::{all_columns, make_with_child, update_expr, ProjectionExec};
37use crate::repartition::distributor_channels::{
38 channels, partition_aware_channels, DistributionReceiver, DistributionSender,
39};
40use crate::sorts::streaming_merge::StreamingMergeBuilder;
41use crate::stream::RecordBatchStreamAdapter;
42use crate::{DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Statistics};
43
44use arrow::array::{PrimitiveArray, RecordBatch, RecordBatchOptions};
45use arrow::compute::take_arrays;
46use arrow::datatypes::{SchemaRef, UInt32Type};
47use datafusion_common::config::ConfigOptions;
48use datafusion_common::stats::Precision;
49use datafusion_common::utils::transpose;
50use datafusion_common::{internal_err, ColumnStatistics, HashMap};
51use datafusion_common::{not_impl_err, DataFusionError, Result};
52use datafusion_common_runtime::SpawnedTask;
53use datafusion_execution::memory_pool::MemoryConsumer;
54use datafusion_execution::TaskContext;
55use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr};
56use datafusion_physical_expr_common::sort_expr::LexOrdering;
57
58use crate::filter_pushdown::{
59 ChildPushdownResult, FilterDescription, FilterPushdownPhase,
60 FilterPushdownPropagation,
61};
62use futures::stream::Stream;
63use futures::{FutureExt, StreamExt, TryStreamExt};
64use log::trace;
65use parking_lot::Mutex;
66
67mod distributor_channels;
68
/// Item flowing through the distribution channels: `Some(result)` carries a
/// batch (or an error); `None` signals that one input partition is finished.
type MaybeBatch = Option<Result<RecordBatch>>;
/// All senders feeding one output partition (one per input partition).
type InputPartitionsToCurrentPartitionSender = Vec<DistributionSender<MaybeBatch>>;
/// All receivers for one output partition (one per input partition).
type InputPartitionsToCurrentPartitionReceiver = Vec<DistributionReceiver<MaybeBatch>>;
72
#[derive(Debug)]
struct ConsumingInputStreamsState {
    /// Channels for sending batches from input partitions to output partitions.
    /// Key is the output partition number; value holds the senders, the
    /// receivers, and the memory reservation tracking batches buffered in the
    /// channel for that partition.
    channels: HashMap<
        usize,
        (
            InputPartitionsToCurrentPartitionSender,
            InputPartitionsToCurrentPartitionReceiver,
            SharedMemoryReservation,
        ),
    >,

    /// Handles to the background tasks; dropping the last clone of this `Arc`
    /// aborts the tasks (see how streams hold `_drop_helper` clones).
    abort_helper: Arc<Vec<SpawnedTask<()>>>,
}
89
/// Lazily-built execution state of a `RepartitionExec`. Advances strictly
/// forward: `NotInitialized` -> `InputStreamsInitialized` ->
/// `ConsumingInputStreams` (see the `impl` below).
enum RepartitionExecState {
    /// Input streams have not been executed yet.
    NotInitialized,
    /// Input streams were executed; holds one (stream, metrics) pair per input
    /// partition, waiting for the repartition tasks to be spawned.
    InputStreamsInitialized(Vec<(SendableRecordBatchStream, RepartitionMetrics)>),
    /// Background tasks are pulling from the inputs and routing batches to the
    /// per-output-partition channels.
    ConsumingInputStreams(ConsumingInputStreamsState),
}
103
104impl Default for RepartitionExecState {
105 fn default() -> Self {
106 Self::NotInitialized
107 }
108}
109
110impl Debug for RepartitionExecState {
111 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
112 match self {
113 RepartitionExecState::NotInitialized => write!(f, "NotInitialized"),
114 RepartitionExecState::InputStreamsInitialized(v) => {
115 write!(f, "InputStreamsInitialized({:?})", v.len())
116 }
117 RepartitionExecState::ConsumingInputStreams(v) => {
118 write!(f, "ConsumingInputStreams({v:?})")
119 }
120 }
121 }
122}
123
impl RepartitionExecState {
    /// Executes every input partition (if not already done), storing the
    /// resulting streams plus per-input metrics in `self`.
    ///
    /// Idempotent: a no-op unless `self` is `NotInitialized`.
    fn ensure_input_streams_initialized(
        &mut self,
        input: Arc<dyn ExecutionPlan>,
        metrics: ExecutionPlanMetricsSet,
        output_partitions: usize,
        ctx: Arc<TaskContext>,
    ) -> Result<()> {
        if !matches!(self, RepartitionExecState::NotInitialized) {
            return Ok(());
        }

        let num_input_partitions = input.output_partitioning().partition_count();
        let mut streams_and_metrics = Vec::with_capacity(num_input_partitions);

        for i in 0..num_input_partitions {
            let metrics = RepartitionMetrics::new(i, output_partitions, &metrics);

            // Time spent starting the child stream is attributed to fetch time
            let timer = metrics.fetch_time.timer();
            let stream = input.execute(i, Arc::clone(&ctx))?;
            timer.done();

            streams_and_metrics.push((stream, metrics));
        }
        *self = RepartitionExecState::InputStreamsInitialized(streams_and_metrics);
        Ok(())
    }

    /// Transitions into `ConsumingInputStreams`: builds the output channels
    /// and spawns one background task per input partition that pulls batches
    /// and routes them to those channels.
    ///
    /// Idempotent: if already consuming, returns the existing state.
    fn consume_input_streams(
        &mut self,
        input: Arc<dyn ExecutionPlan>,
        metrics: ExecutionPlanMetricsSet,
        partitioning: Partitioning,
        preserve_order: bool,
        name: String,
        context: Arc<TaskContext>,
    ) -> Result<&mut ConsumingInputStreamsState> {
        let streams_and_metrics = match self {
            RepartitionExecState::NotInitialized => {
                // Initialize first, then re-match to obtain the stream vector
                self.ensure_input_streams_initialized(
                    input,
                    metrics,
                    partitioning.partition_count(),
                    Arc::clone(&context),
                )?;
                let RepartitionExecState::InputStreamsInitialized(value) = self else {
                    return internal_err!("Programming error: RepartitionExecState must be in the InputStreamsInitialized state after calling RepartitionExecState::ensure_input_streams_initialized");
                };
                value
            }
            RepartitionExecState::ConsumingInputStreams(value) => return Ok(value),
            RepartitionExecState::InputStreamsInitialized(value) => value,
        };

        let num_input_partitions = streams_and_metrics.len();
        let num_output_partitions = partitioning.partition_count();

        let (txs, rxs) = if preserve_order {
            // One channel per (input, output) pair so each output partition can
            // merge its inputs without interleaving them out of order.
            let (txs, rxs) =
                partition_aware_channels(num_input_partitions, num_output_partitions);
            // Transpose to index by output partition first: txs[x][y] -> txs[y][x]
            let txs = transpose(txs);
            let rxs = transpose(rxs);
            (txs, rxs)
        } else {
            // One shared channel per output partition; every input partition
            // gets a clone of the same sender.
            let (txs, rxs) = channels(num_output_partitions);
            let txs = txs
                .into_iter()
                .map(|item| vec![item; num_input_partitions])
                .collect::<Vec<_>>();
            let rxs = rxs.into_iter().map(|item| vec![item]).collect::<Vec<_>>();
            (txs, rxs)
        };

        let mut channels = HashMap::with_capacity(txs.len());
        for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() {
            // Each output partition tracks its buffered batches in its own
            // memory reservation
            let reservation = Arc::new(Mutex::new(
                MemoryConsumer::new(format!("{name}[{partition}]"))
                    .register(context.memory_pool()),
            ));
            channels.insert(partition, (tx, rx, reservation));
        }

        // Launch one pair of tasks per *input* partition
        let mut spawned_tasks = Vec::with_capacity(num_input_partitions);
        for (i, (stream, metrics)) in
            std::mem::take(streams_and_metrics).into_iter().enumerate()
        {
            // Senders (plus reservations) from input partition `i` to every output
            let txs: HashMap<_, _> = channels
                .iter()
                .map(|(partition, (tx, _rx, reservation))| {
                    (*partition, (tx[i].clone(), Arc::clone(reservation)))
                })
                .collect();

            let input_task = SpawnedTask::spawn(RepartitionExec::pull_from_input(
                stream,
                txs.clone(),
                partitioning.clone(),
                metrics,
            ));

            // A second task waits for the puller to finish and forwards its
            // outcome (completion marker or error) to every output channel
            let wait_for_task = SpawnedTask::spawn(RepartitionExec::wait_for_task(
                input_task,
                txs.into_iter()
                    .map(|(partition, (tx, _reservation))| (partition, tx))
                    .collect(),
            ));
            spawned_tasks.push(wait_for_task);
        }
        *self = Self::ConsumingInputStreams(ConsumingInputStreamsState {
            channels,
            abort_helper: Arc::new(spawned_tasks),
        });
        match self {
            RepartitionExecState::ConsumingInputStreams(value) => Ok(value),
            // Just assigned above, so any other variant is impossible
            _ => unreachable!(),
        }
    }
}
252
/// Routes [`RecordBatch`]es to output partitions according to a
/// [`Partitioning`] scheme, timing the work with the supplied metric.
pub struct BatchPartitioner {
    /// Scheme-specific routing state (round-robin cursor or hash machinery)
    state: BatchPartitionerState,
    /// Metric accumulating time spent computing the partitioning
    timer: metrics::Time,
}
258
enum BatchPartitionerState {
    /// Route individual rows by hashing the partition expressions.
    Hash {
        /// Fixed-seed hash state so the row -> partition mapping is deterministic
        random_state: ahash::RandomState,
        /// Expressions whose values are hashed to pick a partition
        exprs: Vec<Arc<dyn PhysicalExpr>>,
        num_partitions: usize,
        /// Scratch buffer of per-row hashes, reused across batches
        hash_buffer: Vec<u64>,
    },
    /// Send each whole batch to the next partition in rotation.
    RoundRobin {
        num_partitions: usize,
        /// Index of the partition that receives the next batch
        next_idx: usize,
    },
}
271
272impl BatchPartitioner {
273 pub fn try_new(partitioning: Partitioning, timer: metrics::Time) -> Result<Self> {
277 let state = match partitioning {
278 Partitioning::RoundRobinBatch(num_partitions) => {
279 BatchPartitionerState::RoundRobin {
280 num_partitions,
281 next_idx: 0,
282 }
283 }
284 Partitioning::Hash(exprs, num_partitions) => BatchPartitionerState::Hash {
285 exprs,
286 num_partitions,
287 random_state: ahash::RandomState::with_seeds(0, 0, 0, 0),
289 hash_buffer: vec![],
290 },
291 other => return not_impl_err!("Unsupported repartitioning scheme {other:?}"),
292 };
293
294 Ok(Self { state, timer })
295 }
296
297 pub fn partition<F>(&mut self, batch: RecordBatch, mut f: F) -> Result<()>
307 where
308 F: FnMut(usize, RecordBatch) -> Result<()>,
309 {
310 self.partition_iter(batch)?.try_for_each(|res| match res {
311 Ok((partition, batch)) => f(partition, batch),
312 Err(e) => Err(e),
313 })
314 }
315
316 fn partition_iter(
322 &mut self,
323 batch: RecordBatch,
324 ) -> Result<impl Iterator<Item = Result<(usize, RecordBatch)>> + Send + '_> {
325 let it: Box<dyn Iterator<Item = Result<(usize, RecordBatch)>> + Send> =
326 match &mut self.state {
327 BatchPartitionerState::RoundRobin {
328 num_partitions,
329 next_idx,
330 } => {
331 let idx = *next_idx;
332 *next_idx = (*next_idx + 1) % *num_partitions;
333 Box::new(std::iter::once(Ok((idx, batch))))
334 }
335 BatchPartitionerState::Hash {
336 random_state,
337 exprs,
338 num_partitions: partitions,
339 hash_buffer,
340 } => {
341 let timer = self.timer.timer();
343
344 let arrays = exprs
345 .iter()
346 .map(|expr| expr.evaluate(&batch)?.into_array(batch.num_rows()))
347 .collect::<Result<Vec<_>>>()?;
348
349 hash_buffer.clear();
350 hash_buffer.resize(batch.num_rows(), 0);
351
352 create_hashes(&arrays, random_state, hash_buffer)?;
353
354 let mut indices: Vec<_> = (0..*partitions)
355 .map(|_| Vec::with_capacity(batch.num_rows()))
356 .collect();
357
358 for (index, hash) in hash_buffer.iter().enumerate() {
359 indices[(*hash % *partitions as u64) as usize].push(index as u32);
360 }
361
362 timer.done();
364
365 let partitioner_timer = &self.timer;
367 let it = indices
368 .into_iter()
369 .enumerate()
370 .filter_map(|(partition, indices)| {
371 let indices: PrimitiveArray<UInt32Type> = indices.into();
372 (!indices.is_empty()).then_some((partition, indices))
373 })
374 .map(move |(partition, indices)| {
375 let _timer = partitioner_timer.timer();
377
378 let columns = take_arrays(batch.columns(), &indices, None)?;
380
381 let mut options = RecordBatchOptions::new();
382 options = options.with_row_count(Some(indices.len()));
383 let batch = RecordBatch::try_new_with_options(
384 batch.schema(),
385 columns,
386 &options,
387 )
388 .unwrap();
389
390 Ok((partition, batch))
391 });
392
393 Box::new(it)
394 }
395 };
396
397 Ok(it)
398 }
399
400 fn num_partitions(&self) -> usize {
402 match self.state {
403 BatchPartitionerState::RoundRobin { num_partitions, .. } => num_partitions,
404 BatchPartitionerState::Hash { num_partitions, .. } => num_partitions,
405 }
406 }
407}
408
/// Execution plan that maps N input partitions to M output partitions,
/// repartitioning batches according to its [`Partitioning`] scheme.
#[derive(Debug, Clone)]
pub struct RepartitionExec {
    /// Input execution plan
    input: Arc<dyn ExecutionPlan>,
    /// Lazily-initialized repartitioning state, shared by all output streams
    /// (built on first `execute`, consumed when streams are first polled)
    state: Arc<Mutex<RepartitionExecState>>,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// If true, merge input partitions in sorted order instead of interleaving
    /// (the "sort-preserving repartition" mode)
    preserve_order: bool,
    /// Cached plan properties (schema, partitioning, equivalences, ...)
    cache: PlanProperties,
}
490
/// Metrics recorded for one input partition of a [`RepartitionExec`].
#[derive(Debug, Clone)]
struct RepartitionMetrics {
    /// Time spent executing the child operator and fetching its batches
    fetch_time: metrics::Time,
    /// Time spent computing the repartitioning of fetched batches
    repartition_time: metrics::Time,
    /// Time spent sending batches to the output channels, one timer per
    /// output partition (indexed by output partition number)
    send_time: Vec<metrics::Time>,
}
502
503impl RepartitionMetrics {
504 pub fn new(
505 input_partition: usize,
506 num_output_partitions: usize,
507 metrics: &ExecutionPlanMetricsSet,
508 ) -> Self {
509 let fetch_time =
511 MetricBuilder::new(metrics).subset_time("fetch_time", input_partition);
512
513 let repartition_time =
515 MetricBuilder::new(metrics).subset_time("repartition_time", input_partition);
516
517 let send_time = (0..num_output_partitions)
519 .map(|output_partition| {
520 let label =
521 metrics::Label::new("outputPartition", output_partition.to_string());
522 MetricBuilder::new(metrics)
523 .with_label(label)
524 .subset_time("send_time", input_partition)
525 })
526 .collect();
527
528 Self {
529 fetch_time,
530 repartition_time,
531 send_time,
532 }
533 }
534}
535
536impl RepartitionExec {
537 pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
539 &self.input
540 }
541
542 pub fn partitioning(&self) -> &Partitioning {
544 &self.cache.partitioning
545 }
546
547 pub fn preserve_order(&self) -> bool {
550 self.preserve_order
551 }
552
553 pub fn name(&self) -> &str {
555 "RepartitionExec"
556 }
557}
558
559impl DisplayAs for RepartitionExec {
560 fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
561 match t {
562 DisplayFormatType::Default | DisplayFormatType::Verbose => {
563 write!(
564 f,
565 "{}: partitioning={}, input_partitions={}",
566 self.name(),
567 self.partitioning(),
568 self.input.output_partitioning().partition_count()
569 )?;
570
571 if self.preserve_order {
572 write!(f, ", preserve_order=true")?;
573 }
574
575 if let Some(sort_exprs) = self.sort_exprs() {
576 write!(f, ", sort_exprs={}", sort_exprs.clone())?;
577 }
578 Ok(())
579 }
580 DisplayFormatType::TreeRender => {
581 writeln!(f, "partitioning_scheme={}", self.partitioning(),)?;
582
583 let input_partition_count =
584 self.input.output_partitioning().partition_count();
585 let output_partition_count = self.partitioning().partition_count();
586 let input_to_output_partition_str =
587 format!("{input_partition_count} -> {output_partition_count}");
588 writeln!(
589 f,
590 "partition_count(in->out)={input_to_output_partition_str}"
591 )?;
592
593 if self.preserve_order {
594 writeln!(f, "preserve_order={}", self.preserve_order)?;
595 }
596 Ok(())
597 }
598 }
599 }
600}
601
impl ExecutionPlan for RepartitionExec {
    fn name(&self) -> &'static str {
        "RepartitionExec"
    }

    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // Rebuild around the new child, carrying over the partitioning scheme
        // and the preserve_order flag
        let mut repartition = RepartitionExec::try_new(
            children.swap_remove(0),
            self.partitioning().clone(),
        )?;
        if self.preserve_order {
            repartition = repartition.with_preserve_order();
        }
        Ok(Arc::new(repartition))
    }

    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        // Hash repartitioning can spread hashing work across input partitions;
        // other schemes gain nothing from pre-partitioned input
        vec![matches!(self.partitioning(), Partitioning::Hash(_, _))]
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        Self::maintains_input_order_helper(self.input(), self.preserve_order)
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        trace!(
            "Start {}::execute for partition: {}",
            self.name(),
            partition
        );

        // Capture everything the lazily-evaluated stream below will need
        let input = Arc::clone(&self.input);
        let partitioning = self.partitioning().clone();
        let metrics = self.metrics.clone();
        let preserve_order = self.sort_exprs().is_some();
        let name = self.name().to_owned();
        let schema = self.schema();
        let schema_captured = Arc::clone(&schema);

        // Existing input ordering, used for merging in preserve_order mode
        let sort_exprs = self.sort_exprs().cloned();

        let state = Arc::clone(&self.state);
        // Opportunistically execute the input streams now (if the lock is
        // free) so errors surface here rather than inside the lazy stream
        if let Some(mut state) = state.try_lock() {
            state.ensure_input_streams_initialized(
                Arc::clone(&input),
                metrics.clone(),
                partitioning.partition_count(),
                Arc::clone(&context),
            )?;
        }

        // The real work happens on first poll of the returned stream
        let stream = futures::stream::once(async move {
            let num_input_partitions = input.output_partitioning().partition_count();

            // Take this output partition's receivers out of the shared state
            // (lock scope kept minimal)
            let (mut rx, reservation, abort_helper) = {
                let mut state = state.lock();
                let state = state.consume_input_streams(
                    Arc::clone(&input),
                    metrics.clone(),
                    partitioning,
                    preserve_order,
                    name.clone(),
                    Arc::clone(&context),
                )?;

                // `remove` means each partition can only be executed once
                let (_tx, rx, reservation) = state
                    .channels
                    .remove(&partition)
                    .expect("partition not used yet");

                (rx, reservation, Arc::clone(&state.abort_helper))
            };

            trace!(
                "Before returning stream in {name}::execute for partition: {partition}"
            );

            if preserve_order {
                // One stream per input partition (rx has one receiver each);
                // merge them while maintaining the sort order
                let input_streams = rx
                    .into_iter()
                    .map(|receiver| {
                        Box::pin(PerPartitionStream {
                            schema: Arc::clone(&schema_captured),
                            receiver,
                            _drop_helper: Arc::clone(&abort_helper),
                            reservation: Arc::clone(&reservation),
                        }) as SendableRecordBatchStream
                    })
                    .collect::<Vec<_>>();
                let fetch = None;
                let merge_reservation =
                    MemoryConsumer::new(format!("{name}[Merge {partition}]"))
                        .register(context.memory_pool());
                StreamingMergeBuilder::new()
                    .with_streams(input_streams)
                    .with_schema(schema_captured)
                    // `preserve_order == true` implies `sort_exprs` is `Some`
                    .with_expressions(&sort_exprs.unwrap())
                    .with_metrics(BaselineMetrics::new(&metrics, partition))
                    .with_batch_size(context.session_config().batch_size())
                    .with_fetch(fetch)
                    .with_reservation(merge_reservation)
                    .build()
            } else {
                // Single shared channel: one receiver drains all inputs
                Ok(Box::pin(RepartitionStream {
                    num_input_partitions,
                    num_input_partitions_processed: 0,
                    schema: input.schema(),
                    input: rx.swap_remove(0),
                    _drop_helper: abort_helper,
                    reservation,
                }) as SendableRecordBatchStream)
            }
        })
        .try_flatten();
        let stream = RecordBatchStreamAdapter::new(schema, stream);
        Ok(Box::pin(stream))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        self.input.partition_statistics(None)
    }

    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
        if let Some(partition) = partition {
            let partition_count = self.partitioning().partition_count();
            if partition_count == 0 {
                return Ok(Statistics::new_unknown(&self.schema()));
            }

            if partition >= partition_count {
                return internal_err!(
                    "RepartitionExec invalid partition {} (expected less than {})",
                    partition,
                    self.partitioning().partition_count()
                );
            }

            let mut stats = self.input.partition_statistics(None)?;

            // Row/byte totals are spread over the output partitions: scale
            // down and mark Inexact since the split need not be uniform
            stats.num_rows = stats
                .num_rows
                .get_value()
                .map(|rows| Precision::Inexact(rows / partition_count))
                .unwrap_or(Precision::Absent);
            stats.total_byte_size = stats
                .total_byte_size
                .get_value()
                .map(|bytes| Precision::Inexact(bytes / partition_count))
                .unwrap_or(Precision::Absent);

            // Per-column statistics no longer hold for a single output partition
            stats.column_statistics = stats
                .column_statistics
                .iter()
                .map(|_| ColumnStatistics::new_unknown())
                .collect();

            Ok(stats)
        } else {
            self.input.partition_statistics(None)
        }
    }

    fn cardinality_effect(&self) -> CardinalityEffect {
        // Repartitioning reshuffles rows but never adds or removes any
        CardinalityEffect::Equal
    }

    fn try_swapping_with_projection(
        &self,
        projection: &ProjectionExec,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        // Only push the projection down if it actually narrows the schema
        if projection.expr().len() >= projection.input().schema().fields().len() {
            return Ok(None);
        }

        // Only push down pure column selections that don't themselves
        // benefit from running after the repartition
        if projection.benefits_from_input_partitioning()[0]
            || !all_columns(projection.expr())
        {
            return Ok(None);
        }

        let new_projection = make_with_child(projection, self.input())?;

        let new_partitioning = match self.partitioning() {
            Partitioning::Hash(partitions, size) => {
                // Rewrite each hash expression in terms of the projection's
                // input; bail out if any expression cannot be rewritten
                let mut new_partitions = vec![];
                for partition in partitions {
                    let Some(new_partition) =
                        update_expr(partition, projection.expr(), false)?
                    else {
                        return Ok(None);
                    };
                    new_partitions.push(new_partition);
                }
                Partitioning::Hash(new_partitions, *size)
            }
            others => others.clone(),
        };

        Ok(Some(Arc::new(RepartitionExec::try_new(
            new_projection,
            new_partitioning,
        )?)))
    }

    fn gather_filters_for_pushdown(
        &self,
        _phase: FilterPushdownPhase,
        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
        _config: &ConfigOptions,
    ) -> Result<FilterDescription> {
        // Filters pass straight through to the child unchanged
        FilterDescription::from_children(parent_filters, &self.children())
    }

    fn handle_child_pushdown_result(
        &self,
        _phase: FilterPushdownPhase,
        child_pushdown_result: ChildPushdownResult,
        _config: &ConfigOptions,
    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
        Ok(FilterPushdownPropagation::if_all(child_pushdown_result))
    }
}
862
impl RepartitionExec {
    /// Create a new `RepartitionExec` producing `partitioning`, without
    /// preserving the input order (see [`Self::with_preserve_order`]).
    pub fn try_new(
        input: Arc<dyn ExecutionPlan>,
        partitioning: Partitioning,
    ) -> Result<Self> {
        let preserve_order = false;
        let cache =
            Self::compute_properties(&input, partitioning.clone(), preserve_order);
        Ok(RepartitionExec {
            input,
            state: Default::default(),
            metrics: ExecutionPlanMetricsSet::new(),
            preserve_order,
            cache,
        })
    }

    /// Order is maintained iff preserve_order merging is enabled, or there is
    /// at most one input partition (nothing gets interleaved).
    fn maintains_input_order_helper(
        input: &Arc<dyn ExecutionPlan>,
        preserve_order: bool,
    ) -> Vec<bool> {
        vec![preserve_order || input.output_partitioning().partition_count() <= 1]
    }

    /// Derive equivalence properties from the input, dropping whatever the
    /// repartition invalidates.
    fn eq_properties_helper(
        input: &Arc<dyn ExecutionPlan>,
        preserve_order: bool,
    ) -> EquivalenceProperties {
        let mut eq_properties = input.equivalence_properties().clone();
        // If the output ordering is lost, clear the ordering equivalences
        if !Self::maintains_input_order_helper(input, preserve_order)[0] {
            eq_properties.clear_orderings();
        }
        // Multiple input partitions get fused at the output, so values that
        // were constant only within a single partition no longer are
        if input.output_partitioning().partition_count() > 1 {
            eq_properties.clear_per_partition_constants();
        }
        eq_properties
    }

    /// Build the cached [`PlanProperties`] (equivalences, partitioning,
    /// pipeline behavior, boundedness, scheduling/evaluation type).
    fn compute_properties(
        input: &Arc<dyn ExecutionPlan>,
        partitioning: Partitioning,
        preserve_order: bool,
    ) -> PlanProperties {
        PlanProperties::new(
            Self::eq_properties_helper(input, preserve_order),
            partitioning,
            input.pipeline_behavior(),
            input.boundedness(),
        )
        .with_scheduling_type(SchedulingType::Cooperative)
        .with_evaluation_type(EvaluationType::Eager)
    }

    /// Request that this repartition preserve the input's row order when
    /// producing output. Only takes effect when the input is actually ordered
    /// and has more than one partition (otherwise there is nothing to merge);
    /// order-preserving merging is more expensive at runtime.
    pub fn with_preserve_order(mut self) -> Self {
        self.preserve_order =
            // no ordering to preserve unless the input is ordered ...
            self.input.output_ordering().is_some() &&
                // ... and merging is only needed with multiple input partitions
                self.input.output_partitioning().partition_count() > 1;
        let eq_properties = Self::eq_properties_helper(&self.input, self.preserve_order);
        self.cache = self.cache.with_eq_properties(eq_properties);
        self
    }

    /// Sort expressions used for order-preserving merging, if enabled.
    fn sort_exprs(&self) -> Option<&LexOrdering> {
        if self.preserve_order {
            self.input.output_ordering()
        } else {
            None
        }
    }

    /// Pulls batches from one input stream and routes them to the output
    /// channels according to `partitioning`.
    ///
    /// `output_channels` maps output partition -> (sender, reservation);
    /// channels whose receiver hung up are dropped from the map.
    async fn pull_from_input(
        mut stream: SendableRecordBatchStream,
        mut output_channels: HashMap<
            usize,
            (DistributionSender<MaybeBatch>, SharedMemoryReservation),
        >,
        partitioning: Partitioning,
        metrics: RepartitionMetrics,
    ) -> Result<()> {
        let mut partitioner =
            BatchPartitioner::try_new(partitioning, metrics.repartition_time.clone())?;

        // Keep pulling while at least one output still wants data
        let mut batches_until_yield = partitioner.num_partitions();
        while !output_channels.is_empty() {
            // Fetch the next batch, attributing the wait to fetch time
            let timer = metrics.fetch_time.timer();
            let result = stream.next().await;
            timer.done();

            // `None` means the input stream is exhausted
            let batch = match result {
                Some(result) => result?,
                None => break,
            };

            for res in partitioner.partition_iter(batch)? {
                let (partition, batch) = res?;
                let size = batch.get_array_memory_size();

                let timer = metrics.send_time[partition].timer();
                // Only send if this output's receiver is still alive
                if let Some((tx, reservation)) = output_channels.get_mut(&partition) {
                    // Account for the batch while it sits in the channel
                    reservation.lock().try_grow(size)?;

                    if tx.send(Some(Ok(batch))).await.is_err() {
                        // Receiver hung up (e.g. LIMIT satisfied): undo the
                        // reservation and stop sending to this partition
                        reservation.lock().shrink(size);
                        output_channels.remove(&partition);
                    }
                }
                timer.done();
            }

            // Yield back to the runtime periodically so an always-ready input
            // cannot starve other tasks on this executor
            if batches_until_yield == 0 {
                tokio::task::yield_now().await;
                batches_until_yield = partitioner.num_partitions();
            } else {
                batches_until_yield -= 1;
            }
        }

        Ok(())
    }

    /// Waits for `input_task` (one input's `pull_from_input`) to complete.
    /// On success, sends `None` on every output channel to signal this input
    /// is done; on failure (including join errors / panics), forwards the
    /// error to every output channel.
    ///
    /// Send failures are ignored (`.ok()`): they mean the receiver already
    /// shut down.
    async fn wait_for_task(
        input_task: SpawnedTask<Result<()>>,
        txs: HashMap<usize, DistributionSender<MaybeBatch>>,
    ) {
        match input_task.join().await {
            // The task itself failed to join (e.g. panicked or was cancelled)
            Err(e) => {
                // share one Arc'd error across all output partitions
                let e = Arc::new(e);

                for (_, tx) in txs {
                    let err = Err(DataFusionError::Context(
                        "Join Error".to_string(),
                        Box::new(DataFusionError::External(Box::new(Arc::clone(&e)))),
                    ));
                    tx.send(Some(err)).await.ok();
                }
            }
            // The task ran but returned an execution error
            Ok(Err(e)) => {
                // share one Arc'd error across all output partitions
                let e = Arc::new(e);

                for (_, tx) in txs {
                    let err = Err(DataFusionError::from(&e));
                    tx.send(Some(err)).await.ok();
                }
            }
            // Clean completion: signal end-of-input to every output
            Ok(Ok(())) => {
                for (_, tx) in txs {
                    tx.send(None).await.ok();
                }
            }
        }
    }
}
1074
/// Output stream for the non-order-preserving mode: drains a single shared
/// channel fed by all input partitions.
struct RepartitionStream {
    /// Number of input partitions that send to this channel
    num_input_partitions: usize,

    /// Number of input partitions that have signalled completion (`None`)
    num_input_partitions_processed: usize,

    /// Schema reported via `RecordBatchStream`
    schema: SchemaRef,

    /// Channel carrying the repartitioned batches for this output partition
    input: DistributionReceiver<MaybeBatch>,

    /// Keeps the background tasks alive while any output stream exists
    _drop_helper: Arc<Vec<SpawnedTask<()>>>,

    /// Memory reservation shrunk as batches leave the channel
    reservation: SharedMemoryReservation,
}
1094
impl Stream for RepartitionStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // Loop so per-input completion markers are consumed without
        // returning Pending spuriously
        loop {
            match self.input.recv().poll_unpin(cx) {
                Poll::Ready(Some(Some(v))) => {
                    // Batch left the channel: release its memory reservation
                    if let Ok(batch) = &v {
                        self.reservation
                            .lock()
                            .shrink(batch.get_array_memory_size());
                    }

                    return Poll::Ready(Some(v));
                }
                Poll::Ready(Some(None)) => {
                    // `None` marker: one input partition finished
                    self.num_input_partitions_processed += 1;

                    if self.num_input_partitions == self.num_input_partitions_processed {
                        // all inputs are done, so is this stream
                        return Poll::Ready(None);
                    } else {
                        // other inputs may still produce data; poll again
                        continue;
                    }
                }
                Poll::Ready(None) => {
                    // Channel closed entirely
                    return Poll::Ready(None);
                }
                Poll::Pending => {
                    return Poll::Pending;
                }
            }
        }
    }
}
1134
impl RecordBatchStream for RepartitionStream {
    /// Get the schema
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
1141
/// Output stream for the order-preserving mode: drains the channel of a
/// single input partition (the merge happens downstream, one such stream
/// per input).
struct PerPartitionStream {
    /// Schema reported via `RecordBatchStream`
    schema: SchemaRef,

    /// Channel carrying batches from exactly one input partition
    receiver: DistributionReceiver<MaybeBatch>,

    /// Keeps the background tasks alive while any output stream exists
    _drop_helper: Arc<Vec<SpawnedTask<()>>>,

    /// Memory reservation shrunk as batches leave the channel
    reservation: SharedMemoryReservation,
}
1157
impl Stream for PerPartitionStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        match self.receiver.recv().poll_unpin(cx) {
            Poll::Ready(Some(Some(v))) => {
                // Batch left the channel: release its memory reservation
                if let Ok(batch) = &v {
                    self.reservation
                        .lock()
                        .shrink(batch.get_array_memory_size());
                }
                Poll::Ready(Some(v))
            }
            Poll::Ready(Some(None)) => {
                // `None` marker: the single input partition is finished
                Poll::Ready(None)
            }
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}
1183
impl RecordBatchStream for PerPartitionStream {
    /// Get the schema
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
1190
1191#[cfg(test)]
1192mod tests {
1193 use std::collections::HashSet;
1194
1195 use super::*;
1196 use crate::test::TestMemoryExec;
1197 use crate::{
1198 test::{
1199 assert_is_pending,
1200 exec::{
1201 assert_strong_count_converges_to_zero, BarrierExec, BlockingExec,
1202 ErrorExec, MockExec,
1203 },
1204 },
1205 {collect, expressions::col},
1206 };
1207
1208 use arrow::array::{ArrayRef, StringArray, UInt32Array};
1209 use arrow::datatypes::{DataType, Field, Schema};
1210 use datafusion_common::cast::as_string_array;
1211 use datafusion_common::test_util::batches_to_sort_string;
1212 use datafusion_common::{arrow_datafusion_err, exec_err};
1213 use datafusion_common_runtime::JoinSet;
1214 use datafusion_execution::runtime_env::RuntimeEnvBuilder;
1215 use insta::assert_snapshot;
1216 use itertools::Itertools;
1217
1218 #[tokio::test]
1219 async fn one_to_many_round_robin() -> Result<()> {
1220 let schema = test_schema();
1222 let partition = create_vec_batches(50);
1223 let partitions = vec![partition];
1224
1225 let output_partitions =
1227 repartition(&schema, partitions, Partitioning::RoundRobinBatch(4)).await?;
1228
1229 assert_eq!(4, output_partitions.len());
1230 assert_eq!(13, output_partitions[0].len());
1231 assert_eq!(13, output_partitions[1].len());
1232 assert_eq!(12, output_partitions[2].len());
1233 assert_eq!(12, output_partitions[3].len());
1234
1235 Ok(())
1236 }
1237
1238 #[tokio::test]
1239 async fn many_to_one_round_robin() -> Result<()> {
1240 let schema = test_schema();
1242 let partition = create_vec_batches(50);
1243 let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
1244
1245 let output_partitions =
1247 repartition(&schema, partitions, Partitioning::RoundRobinBatch(1)).await?;
1248
1249 assert_eq!(1, output_partitions.len());
1250 assert_eq!(150, output_partitions[0].len());
1251
1252 Ok(())
1253 }
1254
1255 #[tokio::test]
1256 async fn many_to_many_round_robin() -> Result<()> {
1257 let schema = test_schema();
1259 let partition = create_vec_batches(50);
1260 let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
1261
1262 let output_partitions =
1264 repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await?;
1265
1266 assert_eq!(5, output_partitions.len());
1267 assert_eq!(30, output_partitions[0].len());
1268 assert_eq!(30, output_partitions[1].len());
1269 assert_eq!(30, output_partitions[2].len());
1270 assert_eq!(30, output_partitions[3].len());
1271 assert_eq!(30, output_partitions[4].len());
1272
1273 Ok(())
1274 }
1275
1276 #[tokio::test]
1277 async fn many_to_many_hash_partition() -> Result<()> {
1278 let schema = test_schema();
1280 let partition = create_vec_batches(50);
1281 let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
1282
1283 let output_partitions = repartition(
1284 &schema,
1285 partitions,
1286 Partitioning::Hash(vec![col("c0", &schema)?], 8),
1287 )
1288 .await?;
1289
1290 let total_rows: usize = output_partitions
1291 .iter()
1292 .map(|x| x.iter().map(|x| x.num_rows()).sum::<usize>())
1293 .sum();
1294
1295 assert_eq!(8, output_partitions.len());
1296 assert_eq!(total_rows, 8 * 50 * 3);
1297
1298 Ok(())
1299 }
1300
1301 fn test_schema() -> Arc<Schema> {
1302 Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
1303 }
1304
1305 async fn repartition(
1306 schema: &SchemaRef,
1307 input_partitions: Vec<Vec<RecordBatch>>,
1308 partitioning: Partitioning,
1309 ) -> Result<Vec<Vec<RecordBatch>>> {
1310 let task_ctx = Arc::new(TaskContext::default());
1311 let exec =
1313 TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(schema), None)?;
1314 let exec = RepartitionExec::try_new(exec, partitioning)?;
1315
1316 let mut output_partitions = vec![];
1318 for i in 0..exec.partitioning().partition_count() {
1319 let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
1321 let mut batches = vec![];
1322 while let Some(result) = stream.next().await {
1323 batches.push(result?);
1324 }
1325 output_partitions.push(batches);
1326 }
1327 Ok(output_partitions)
1328 }
1329
1330 #[tokio::test]
1331 async fn many_to_many_round_robin_within_tokio_task() -> Result<()> {
1332 let handle: SpawnedTask<Result<Vec<Vec<RecordBatch>>>> =
1333 SpawnedTask::spawn(async move {
1334 let schema = test_schema();
1336 let partition = create_vec_batches(50);
1337 let partitions =
1338 vec![partition.clone(), partition.clone(), partition.clone()];
1339
1340 repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await
1342 });
1343
1344 let output_partitions = handle.join().await.unwrap().unwrap();
1345
1346 assert_eq!(5, output_partitions.len());
1347 assert_eq!(30, output_partitions[0].len());
1348 assert_eq!(30, output_partitions[1].len());
1349 assert_eq!(30, output_partitions[2].len());
1350 assert_eq!(30, output_partitions[3].len());
1351 assert_eq!(30, output_partitions[4].len());
1352
1353 Ok(())
1354 }
1355
1356 #[tokio::test]
1357 async fn unsupported_partitioning() {
1358 let task_ctx = Arc::new(TaskContext::default());
1359 let batch = RecordBatch::try_from_iter(vec![(
1361 "my_awesome_field",
1362 Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
1363 )])
1364 .unwrap();
1365
1366 let schema = batch.schema();
1367 let input = MockExec::new(vec![Ok(batch)], schema);
1368 let partitioning = Partitioning::UnknownPartitioning(1);
1372 let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
1373 let output_stream = exec.execute(0, task_ctx).unwrap();
1374
1375 let result_string = crate::common::collect(output_stream)
1377 .await
1378 .unwrap_err()
1379 .to_string();
1380 assert!(
1381 result_string
1382 .contains("Unsupported repartitioning scheme UnknownPartitioning(1)"),
1383 "actual: {result_string}"
1384 );
1385 }
1386
1387 #[tokio::test]
1388 async fn error_for_input_exec() {
1389 let task_ctx = Arc::new(TaskContext::default());
1393 let input = ErrorExec::new();
1394 let partitioning = Partitioning::RoundRobinBatch(1);
1395 let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
1396
1397 let result_string = exec.execute(0, task_ctx).err().unwrap().to_string();
1399
1400 assert!(
1401 result_string.contains("ErrorExec, unsurprisingly, errored in partition 0"),
1402 "actual: {result_string}"
1403 );
1404 }
1405
1406 #[tokio::test]
1407 async fn repartition_with_error_in_stream() {
1408 let task_ctx = Arc::new(TaskContext::default());
1409 let batch = RecordBatch::try_from_iter(vec![(
1410 "my_awesome_field",
1411 Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
1412 )])
1413 .unwrap();
1414
1415 let err = exec_err!("bad data error");
1418
1419 let schema = batch.schema();
1420 let input = MockExec::new(vec![Ok(batch), err], schema);
1421 let partitioning = Partitioning::RoundRobinBatch(1);
1422 let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
1423
1424 let output_stream = exec.execute(0, task_ctx).unwrap();
1427
1428 let result_string = crate::common::collect(output_stream)
1430 .await
1431 .unwrap_err()
1432 .to_string();
1433 assert!(
1434 result_string.contains("bad data error"),
1435 "actual: {result_string}"
1436 );
1437 }
1438
    #[tokio::test]
    async fn repartition_with_delayed_stream() {
        // Repartitioning into a single round-robin output should forward all
        // input batches, even when the input (a `MockExec`) yields its
        // batches lazily rather than all at once.
        let task_ctx = Arc::new(TaskContext::default());
        let batch1 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();

        let batch2 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
        )])
        .unwrap();

        let schema = batch1.schema();
        let expected_batches = vec![batch1.clone(), batch2.clone()];
        let input = MockExec::new(vec![Ok(batch1), Ok(batch2)], schema);
        let partitioning = Partitioning::RoundRobinBatch(1);

        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();

        // Sanity-check the expected rendering before running the plan.
        assert_snapshot!(batches_to_sort_string(&expected_batches), @r"
        +------------------+
        | my_awesome_field |
        +------------------+
        | bar              |
        | baz              |
        | foo              |
        | frob             |
        +------------------+
        ");

        let output_stream = exec.execute(0, task_ctx).unwrap();
        let batches = crate::common::collect(output_stream).await.unwrap();

        // The single output partition must contain exactly the input rows.
        assert_snapshot!(batches_to_sort_string(&batches), @r"
        +------------------+
        | my_awesome_field |
        +------------------+
        | bar              |
        | baz              |
        | foo              |
        | frob             |
        +------------------+
        ");
    }
1488
    #[tokio::test]
    async fn robin_repartition_with_dropping_output_stream() {
        // Dropping one output partition's stream must not disturb the other:
        // partition 1 should still receive its share of the input batches.
        let task_ctx = Arc::new(TaskContext::default());
        let partitioning = Partitioning::RoundRobinBatch(2);
        // Barrier-gated input: it produces nothing until `wait()` is called.
        let input = Arc::new(make_barrier_exec());

        let exec = RepartitionExec::try_new(
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            partitioning,
        )
        .unwrap();

        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();

        // Drop output 0 *before* the input starts producing.
        drop(output_stream0);

        // Release the barrier so the input runs in the background.
        let mut background_task = JoinSet::new();
        background_task.spawn(async move {
            input.wait().await;
        });

        let batches = crate::common::collect(output_stream1).await.unwrap();

        // Round-robin over 2 outputs: partition 1 receives the second batch
        // of each input partition (["frob","baz"] and ["grob","gaz"]).
        assert_snapshot!(batches_to_sort_string(&batches), @r#"
        +------------------+
        | my_awesome_field |
        +------------------+
        | baz              |
        | frob             |
        | gaz              |
        | grob             |
        +------------------+
        "#);
    }
1531
    #[tokio::test]
    async fn hash_repartition_with_dropping_output_stream() {
        // Hash repartitioning into two outputs: dropping the stream of one
        // output partition must not change what the other partition receives.
        let task_ctx = Arc::new(TaskContext::default());
        let partitioning = Partitioning::Hash(
            vec![Arc::new(crate::expressions::Column::new(
                "my_awesome_field",
                0,
            ))],
            2,
        );

        // Phase 1: run with both outputs alive and collect partition 1.
        let input = Arc::new(make_barrier_exec());
        let exec = RepartitionExec::try_new(
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            partitioning.clone(),
        )
        .unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
        // Unblock the barrier-gated input in the background so collect()
        // below can make progress.
        let mut background_task = JoinSet::new();
        background_task.spawn(async move {
            input.wait().await;
        });
        let batches_without_drop = crate::common::collect(output_stream1).await.unwrap();

        // Sanity check: no duplicate values, and everything came from the
        // known set of input strings.
        let items_vec = str_batches_to_vec(&batches_without_drop);
        let items_set: HashSet<&str> = items_vec.iter().copied().collect();
        assert_eq!(items_vec.len(), items_set.len());
        let source_str_set: HashSet<&str> =
            ["foo", "bar", "frob", "baz", "goo", "gar", "grob", "gaz"]
                .iter()
                .copied()
                .collect();
        assert_eq!(items_set.difference(&source_str_set).count(), 0);

        // Phase 2: fresh input and plan, but drop output 0 before running.
        let input = Arc::new(make_barrier_exec());
        let exec = RepartitionExec::try_new(
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            partitioning,
        )
        .unwrap();
        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
        drop(output_stream0);
        let mut background_task = JoinSet::new();
        background_task.spawn(async move {
            input.wait().await;
        });
        let batches_with_drop = crate::common::collect(output_stream1).await.unwrap();

        // Hashing is deterministic, so partition 1 must see the same data in
        // both phases (compared order-insensitively by sorting on the debug
        // rendering of each batch).
        fn sort(batch: Vec<RecordBatch>) -> Vec<RecordBatch> {
            batch
                .into_iter()
                .sorted_by_key(|b| format!("{b:?}"))
                .collect()
        }

        assert_eq!(sort(batches_without_drop), sort(batches_with_drop));
    }
1598
1599 fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> {
1600 batches
1601 .iter()
1602 .flat_map(|batch| {
1603 assert_eq!(batch.columns().len(), 1);
1604 let string_array = as_string_array(batch.column(0))
1605 .expect("Unexpected type for repartitioned batch");
1606
1607 string_array
1608 .iter()
1609 .map(|v| v.expect("Unexpected null"))
1610 .collect::<Vec<_>>()
1611 })
1612 .collect::<Vec<_>>()
1613 }
1614
1615 fn make_barrier_exec() -> BarrierExec {
1617 let batch1 = RecordBatch::try_from_iter(vec![(
1618 "my_awesome_field",
1619 Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
1620 )])
1621 .unwrap();
1622
1623 let batch2 = RecordBatch::try_from_iter(vec![(
1624 "my_awesome_field",
1625 Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
1626 )])
1627 .unwrap();
1628
1629 let batch3 = RecordBatch::try_from_iter(vec![(
1630 "my_awesome_field",
1631 Arc::new(StringArray::from(vec!["goo", "gar"])) as ArrayRef,
1632 )])
1633 .unwrap();
1634
1635 let batch4 = RecordBatch::try_from_iter(vec![(
1636 "my_awesome_field",
1637 Arc::new(StringArray::from(vec!["grob", "gaz"])) as ArrayRef,
1638 )])
1639 .unwrap();
1640
1641 let schema = batch1.schema();
1644 BarrierExec::new(vec![vec![batch1, batch2], vec![batch3, batch4]], schema)
1645 }
1646
1647 #[tokio::test]
1648 async fn test_drop_cancel() -> Result<()> {
1649 let task_ctx = Arc::new(TaskContext::default());
1650 let schema =
1651 Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));
1652
1653 let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2));
1654 let refs = blocking_exec.refs();
1655 let repartition_exec = Arc::new(RepartitionExec::try_new(
1656 blocking_exec,
1657 Partitioning::UnknownPartitioning(1),
1658 )?);
1659
1660 let fut = collect(repartition_exec, task_ctx);
1661 let mut fut = fut.boxed();
1662
1663 assert_is_pending(&mut fut);
1664 drop(fut);
1665 assert_strong_count_converges_to_zero(refs).await;
1666
1667 Ok(())
1668 }
1669
1670 #[tokio::test]
1671 async fn hash_repartition_avoid_empty_batch() -> Result<()> {
1672 let task_ctx = Arc::new(TaskContext::default());
1673 let batch = RecordBatch::try_from_iter(vec![(
1674 "a",
1675 Arc::new(StringArray::from(vec!["foo"])) as ArrayRef,
1676 )])
1677 .unwrap();
1678 let partitioning = Partitioning::Hash(
1679 vec![Arc::new(crate::expressions::Column::new("a", 0))],
1680 2,
1681 );
1682 let schema = batch.schema();
1683 let input = MockExec::new(vec![Ok(batch)], schema);
1684 let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
1685 let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
1686 let batch0 = crate::common::collect(output_stream0).await.unwrap();
1687 let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
1688 let batch1 = crate::common::collect(output_stream1).await.unwrap();
1689 assert!(batch0.is_empty() || batch1.is_empty());
1690 Ok(())
1691 }
1692
1693 #[tokio::test]
1694 async fn oom() -> Result<()> {
1695 let schema = test_schema();
1697 let partition = create_vec_batches(50);
1698 let input_partitions = vec![partition];
1699 let partitioning = Partitioning::RoundRobinBatch(4);
1700
1701 let runtime = RuntimeEnvBuilder::default()
1703 .with_memory_limit(1, 1.0)
1704 .build_arc()?;
1705
1706 let task_ctx = TaskContext::default().with_runtime(runtime);
1707 let task_ctx = Arc::new(task_ctx);
1708
1709 let exec =
1711 TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
1712 let exec = RepartitionExec::try_new(exec, partitioning)?;
1713
1714 for i in 0..exec.partitioning().partition_count() {
1716 let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
1717 let err =
1718 arrow_datafusion_err!(stream.next().await.unwrap().unwrap_err().into());
1719 let err = err.find_root();
1720 assert!(
1721 matches!(err, DataFusionError::ResourcesExhausted(_)),
1722 "Wrong error type: {err}",
1723 );
1724 }
1725
1726 Ok(())
1727 }
1728
1729 fn create_vec_batches(n: usize) -> Vec<RecordBatch> {
1731 let batch = create_batch();
1732 (0..n).map(|_| batch.clone()).collect()
1733 }
1734
1735 fn create_batch() -> RecordBatch {
1737 let schema = test_schema();
1738 RecordBatch::try_new(
1739 schema,
1740 vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]))],
1741 )
1742 .unwrap()
1743 }
1744}
1745
#[cfg(test)]
mod test {
    use arrow::compute::SortOptions;
    use arrow::datatypes::{DataType, Field, Schema};

    use super::*;
    use crate::test::TestMemoryExec;
    use crate::union::UnionExec;

    use datafusion_physical_expr::expressions::col;
    use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};

    /// Asserts that the indented `displayable` rendering of `$PLAN` matches
    /// the expected plan lines exactly.
    macro_rules! assert_plan {
        ($EXPECTED_PLAN_LINES: expr, $PLAN: expr) => {
            let physical_plan = $PLAN;
            let formatted = crate::displayable(&physical_plan).indent(true).to_string();
            let actual: Vec<&str> = formatted.trim().lines().collect();

            let expected_plan_lines: Vec<&str> =
                $EXPECTED_PLAN_LINES.iter().copied().collect();

            assert_eq!(
                expected_plan_lines, actual,
                "\n**Original Plan Mismatch\n\nexpected:\n\n{expected_plan_lines:#?}\nactual:\n\n{actual:#?}\n\n"
            );
        };
    }

    #[tokio::test]
    async fn test_preserve_order() -> Result<()> {
        let schema = test_schema();
        let ordering = sort_exprs(&schema);
        // Two sorted inputs: order preservation applies and is displayed.
        let left = sorted_memory_exec(&schema, ordering.clone());
        let right = sorted_memory_exec(&schema, ordering);
        let union = UnionExec::new(vec![left, right]);
        let exec =
            RepartitionExec::try_new(Arc::new(union), Partitioning::RoundRobinBatch(10))
                .unwrap()
                .with_preserve_order();

        let expected_plan = [
            "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true, sort_exprs=c0@0 ASC",
            "  UnionExec",
            "    DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC",
            "    DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC",
        ];
        assert_plan!(expected_plan, exec);
        Ok(())
    }

    #[tokio::test]
    async fn test_preserve_order_one_partition() -> Result<()> {
        let schema = test_schema();
        // Only one sorted input partition: nothing to merge, so the display
        // omits the preserve_order annotation.
        let source = sorted_memory_exec(&schema, sort_exprs(&schema));
        let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(10))
            .unwrap()
            .with_preserve_order();

        let expected_plan = [
            "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1",
            "  DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC",
        ];
        assert_plan!(expected_plan, exec);
        Ok(())
    }

    #[tokio::test]
    async fn test_preserve_order_input_not_sorted() -> Result<()> {
        let schema = test_schema();
        // Unsorted inputs: preserve_order is requested but cannot apply, so
        // the display omits it.
        let union = UnionExec::new(vec![memory_exec(&schema), memory_exec(&schema)]);
        let exec =
            RepartitionExec::try_new(Arc::new(union), Partitioning::RoundRobinBatch(10))
                .unwrap()
                .with_preserve_order();

        let expected_plan = [
            "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2",
            "  UnionExec",
            "    DataSourceExec: partitions=1, partition_sizes=[0]",
            "    DataSourceExec: partitions=1, partition_sizes=[0]",
        ];
        assert_plan!(expected_plan, exec);
        Ok(())
    }

    /// Schema with a single non-nullable `UInt32` column `c0`.
    fn test_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
    }

    /// Default ascending sort on `c0`.
    fn sort_exprs(schema: &Schema) -> LexOrdering {
        let expr = col("c0", schema).unwrap();
        [PhysicalSortExpr {
            expr,
            options: SortOptions::default(),
        }]
        .into()
    }

    /// Empty, unsorted in-memory source.
    fn memory_exec(schema: &SchemaRef) -> Arc<dyn ExecutionPlan> {
        TestMemoryExec::try_new_exec(&[vec![]], Arc::clone(schema), None).unwrap()
    }

    /// Empty in-memory source that advertises `sort_exprs` as its output ordering.
    fn sorted_memory_exec(
        schema: &SchemaRef,
        sort_exprs: LexOrdering,
    ) -> Arc<dyn ExecutionPlan> {
        let source = TestMemoryExec::try_new(&[vec![]], Arc::clone(schema), None)
            .unwrap()
            .try_with_sort_information(vec![sort_exprs])
            .unwrap();
        Arc::new(TestMemoryExec::update_cache(Arc::new(source)))
    }
}