1use std::fmt::{Debug, Formatter};
23use std::pin::Pin;
24use std::sync::Arc;
25use std::sync::atomic::{AtomicUsize, Ordering};
26use std::task::{Context, Poll};
27use std::vec;
28
29use super::common::SharedMemoryReservation;
30use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
31use super::{
32 DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream,
33};
34use crate::coalesce::LimitedBatchCoalescer;
35use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType};
36use crate::hash_utils::create_hashes;
37use crate::metrics::{BaselineMetrics, SpillMetrics};
38use crate::projection::{ProjectionExec, all_columns, make_with_child, update_expr};
39use crate::sorts::streaming_merge::StreamingMergeBuilder;
40use crate::spill::spill_manager::SpillManager;
41use crate::spill::spill_pool::{self, SpillPoolWriter};
42use crate::stream::{EmptyRecordBatchStream, RecordBatchStreamAdapter};
43use crate::{
44 DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Statistics,
45 check_if_same_properties,
46};
47
48use arrow::array::{PrimitiveArray, RecordBatch, RecordBatchOptions};
49use arrow::compute::take_arrays;
50use arrow::datatypes::{SchemaRef, UInt32Type};
51use datafusion_common::config::ConfigOptions;
52use datafusion_common::stats::Precision;
53use datafusion_common::utils::transpose;
54use datafusion_common::{
55 ColumnStatistics, DataFusionError, HashMap, assert_or_internal_err, internal_err,
56};
57use datafusion_common::{Result, not_impl_err};
58use datafusion_common_runtime::SpawnedTask;
59use datafusion_execution::TaskContext;
60use datafusion_execution::memory_pool::MemoryConsumer;
61use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr};
62use datafusion_physical_expr_common::sort_expr::LexOrdering;
63
64use crate::filter_pushdown::{
65 ChildPushdownResult, FilterDescription, FilterPushdownPhase,
66 FilterPushdownPropagation,
67};
68use crate::joins::SeededRandomState;
69use crate::sort_pushdown::SortOrderPushdownResult;
70use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
71use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays;
72use futures::stream::Stream;
73use futures::{FutureExt, StreamExt, TryStreamExt};
74use log::trace;
75use parking_lot::Mutex;
76
77mod distributor_channels;
78use crate::repartition::distributor_channels::SendError;
79use distributor_channels::{
80 DistributionReceiver, DistributionSender, channels, partition_aware_channels,
81};
82
/// What an input task hands to an output partition for a single batch.
#[derive(Debug)]
enum RepartitionBatch {
    /// The batch itself, carried through the channel.
    Memory(RecordBatch),
    /// Payload-free marker: the batch was pushed to the channel's spill
    /// writer instead of being sent inline.
    Spilled,
}
149
/// Item carried by a distribution channel: an optional, fallible
/// [`RepartitionBatch`].
// NOTE(review): `None` presumably marks end-of-stream; the code that emits it
// is not visible here — confirm.
type MaybeBatch = Option<Result<RepartitionBatch>>;
/// Senders feeding one output partition, indexed by input partition.
type InputPartitionsToCurrentPartitionSender = Vec<DistributionSender<MaybeBatch>>;
/// Receivers of one output partition (one per input partition when order is
/// preserved, otherwise a single receiver).
type InputPartitionsToCurrentPartitionReceiver = Vec<DistributionReceiver<MaybeBatch>>;
153
/// Everything one input task needs to deliver batches to one output partition.
struct OutputChannel {
    /// Channel endpoint used to deliver batches (or spill markers).
    sender: DistributionSender<MaybeBatch>,
    /// Memory reservation that is grown for every batch sent in memory.
    reservation: SharedMemoryReservation,
    /// Receives batches that could not be accounted for in `reservation`.
    spill_writer: SpillPoolWriter,
    /// Optional coalescer applied to batches before they are sent.
    shared_coalescer: Option<SharedCoalescer>,
}
166
167impl OutputChannel {
168 fn coalesce(&mut self, batch: RecordBatch) -> Result<Vec<RecordBatch>> {
169 match &self.shared_coalescer {
170 Some(shared) => Ok(shared.push_and_drain(batch)?),
171 None => Ok(vec![batch]),
172 }
173 }
174
175 async fn send(&mut self, batch: RecordBatch) -> Result<(), SendError<MaybeBatch>> {
181 let size = batch.get_array_memory_size();
182
183 let (payload, is_memory_batch) = {
186 match self.reservation.try_grow(size) {
187 Ok(_) => (Ok(RepartitionBatch::Memory(batch)), true),
188 Err(_) => match self.spill_writer.push_batch(&batch) {
189 Ok(()) => (Ok(RepartitionBatch::Spilled), false),
190 Err(err) => (Err(err), false),
191 },
192 }
193 };
194
195 let result = self.sender.send(Some(payload)).await;
196 if result.is_err() && is_memory_batch {
197 self.reservation.shrink(size);
198 }
199 result
200 }
201
202 async fn finalize(mut self) -> Result<()> {
203 let Some(shared) = self.shared_coalescer.take() else {
204 return Ok(());
205 };
206 for batch in shared.finalize()? {
207 let _ = self.send(batch).await;
210 }
211 Ok(())
212 }
213}
214
/// A [`LimitedBatchCoalescer`] shared between senders, together with a count
/// of the senders that have not yet called `finalize`.
#[derive(Clone)]
struct SharedCoalescer {
    /// The coalescer, guarded by a mutex.
    inner: Arc<Mutex<LimitedBatchCoalescer>>,
    /// Number of senders that still have to call `finalize`.
    active_senders: Arc<AtomicUsize>,
}
229
230impl SharedCoalescer {
231 fn new(schema: SchemaRef, target_batch_size: usize, num_senders: usize) -> Self {
232 Self {
233 inner: Arc::new(Mutex::new(LimitedBatchCoalescer::new(
234 schema,
235 target_batch_size,
236 None,
237 ))),
238 active_senders: Arc::new(AtomicUsize::new(num_senders)),
239 }
240 }
241
242 fn push_and_drain(&self, batch: RecordBatch) -> Result<Vec<RecordBatch>> {
245 let mut acc = Vec::new();
246 let mut c = self.inner.lock();
247 c.push_batch(batch)?;
248 while let Some(b) = c.next_completed_batch() {
249 acc.push(b);
250 }
251 Ok(acc)
252 }
253
254 fn finalize(&self) -> Result<Vec<RecordBatch>> {
258 let was_last = self.active_senders.fetch_sub(1, Ordering::AcqRel) == 1;
259 if !was_last {
260 return Ok(vec![]);
261 }
262 let mut acc = Vec::new();
263 let mut c = self.inner.lock();
264 c.finish()?;
265 while let Some(b) = c.next_completed_batch() {
266 acc.push(b);
267 }
268 Ok(acc)
269 }
270}
271
/// Channels and resources belonging to a single output partition.
struct PartitionChannels {
    /// One sender per input partition.
    tx: InputPartitionsToCurrentPartitionSender,
    /// One receiver per input partition when order is preserved, otherwise a
    /// single receiver.
    rx: InputPartitionsToCurrentPartitionReceiver,
    /// Memory reservation for this output partition.
    reservation: SharedMemoryReservation,
    /// Coalescer shared by all input tasks; only set when order is not
    /// preserved.
    shared_coalescer: Option<SharedCoalescer>,
    /// Spill writers: one per input partition when order is preserved,
    /// otherwise exactly one.
    spill_writers: Vec<SpillPoolWriter>,
    /// Read side of the spill channels, in the same order as `spill_writers`.
    spill_readers: Vec<SendableRecordBatchStream>,
}
311
/// State held once the input tasks have been spawned.
struct ConsumingInputStreamsState {
    /// Channels keyed by output partition. An entry is removed when that
    /// partition's output stream is created.
    channels: HashMap<usize, PartitionChannels>,

    /// The spawned per-input tasks, shared with every output stream.
    // NOTE(review): presumably dropping the last reference aborts the tasks;
    // `SpawnedTask` is defined elsewhere — confirm.
    abort_helper: Arc<Vec<SpawnedTask<()>>>,
}
320
321impl Debug for ConsumingInputStreamsState {
322 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
323 f.debug_struct("ConsumingInputStreamsState")
324 .field("num_channels", &self.channels.len())
325 .field("abort_helper", &self.abort_helper)
326 .finish()
327 }
328}
329
/// Lifecycle of a `RepartitionExec`'s shared execution state.
#[derive(Default)]
enum RepartitionExecState {
    /// Nothing has been set up yet.
    #[default]
    NotInitialized,
    /// The input streams have been opened (one per input partition, paired
    /// with their metrics) but no task is consuming them yet.
    InputStreamsInitialized(Vec<(SendableRecordBatchStream, RepartitionMetrics)>),
    /// Channels exist and tasks have been spawned to drive the inputs.
    ConsumingInputStreams(ConsumingInputStreamsState),
}
345
346impl Debug for RepartitionExecState {
347 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
348 match self {
349 RepartitionExecState::NotInitialized => write!(f, "NotInitialized"),
350 RepartitionExecState::InputStreamsInitialized(v) => {
351 write!(f, "InputStreamsInitialized({:?})", v.len())
352 }
353 RepartitionExecState::ConsumingInputStreams(v) => {
354 write!(f, "ConsumingInputStreams({v:?})")
355 }
356 }
357 }
358}
359
360impl RepartitionExecState {
361 fn ensure_input_streams_initialized(
362 &mut self,
363 input: &Arc<dyn ExecutionPlan>,
364 metrics: &ExecutionPlanMetricsSet,
365 output_partitions: usize,
366 ctx: &Arc<TaskContext>,
367 ) -> Result<()> {
368 if !matches!(self, RepartitionExecState::NotInitialized) {
369 return Ok(());
370 }
371
372 let num_input_partitions = input.output_partitioning().partition_count();
373 let mut streams_and_metrics = Vec::with_capacity(num_input_partitions);
374
375 for i in 0..num_input_partitions {
376 let metrics = RepartitionMetrics::new(i, output_partitions, metrics);
377
378 let timer = metrics.fetch_time.timer();
379 let stream = input.execute(i, Arc::clone(ctx))?;
380 timer.done();
381
382 streams_and_metrics.push((stream, metrics));
383 }
384 *self = RepartitionExecState::InputStreamsInitialized(streams_and_metrics);
385 Ok(())
386 }
387
    /// Moves the state to `ConsumingInputStreams` (opening the input streams
    /// first if needed) and returns it. If the state is already
    /// `ConsumingInputStreams` it is returned unchanged.
    ///
    /// On the first call this builds the channels, memory reservations, spill
    /// channels and (when order is not preserved) shared coalescers for every
    /// output partition, then spawns one task pair per input partition.
    ///
    /// # Errors
    /// Propagates errors from opening the input streams.
    #[expect(clippy::too_many_arguments)]
    fn consume_input_streams(
        &mut self,
        input: &Arc<dyn ExecutionPlan>,
        metrics: &ExecutionPlanMetricsSet,
        partitioning: &Partitioning,
        preserve_order: bool,
        name: &str,
        context: &Arc<TaskContext>,
        spill_manager: SpillManager,
    ) -> Result<&mut ConsumingInputStreamsState> {
        let streams_and_metrics = match self {
            RepartitionExecState::NotInitialized => {
                self.ensure_input_streams_initialized(
                    input,
                    metrics,
                    partitioning.partition_count(),
                    context,
                )?;
                let RepartitionExecState::InputStreamsInitialized(value) = self else {
                    return internal_err!(
                        "Programming error: RepartitionExecState must be in the InputStreamsInitialized state after calling RepartitionExecState::ensure_input_streams_initialized"
                    );
                };
                value
            }
            // Already set up by an earlier call: nothing to do.
            RepartitionExecState::ConsumingInputStreams(value) => return Ok(value),
            RepartitionExecState::InputStreamsInitialized(value) => value,
        };

        let num_input_partitions = streams_and_metrics.len();
        let num_output_partitions = partitioning.partition_count();

        let spill_manager = Arc::new(spill_manager);

        // Both results are indexed by output partition first.
        let (txs, rxs) = if preserve_order {
            // One dedicated channel per (input, output) pair; transposed so
            // the outer index becomes the output partition.
            let (txs_all, rxs_all) =
                partition_aware_channels(num_input_partitions, num_output_partitions);
            let txs = transpose(txs_all);
            let rxs = transpose(rxs_all);
            (txs, rxs)
        } else {
            // One channel per output partition; its sender is cloned for
            // every input partition.
            let (txs, rxs) = channels(num_output_partitions);
            let txs = txs
                .into_iter()
                .map(|item| vec![item; num_input_partitions])
                .collect::<Vec<_>>();
            let rxs = rxs.into_iter().map(|item| vec![item]).collect::<Vec<_>>();
            (txs, rxs)
        };

        let mut channels = HashMap::with_capacity(txs.len());
        for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() {
            let reservation = Arc::new(
                MemoryConsumer::new(format!("{name}[{partition}]"))
                    .with_can_spill(true)
                    .register(context.memory_pool()),
            );

            let max_file_size = context
                .session_config()
                .options()
                .execution
                .max_spill_file_size_bytes;
            // One spill channel per input when order is preserved, otherwise
            // a single one shared by all inputs.
            let num_spill_channels = if preserve_order {
                num_input_partitions
            } else {
                1
            };
            let (spill_writers, spill_readers): (Vec<_>, Vec<_>) = (0
                ..num_spill_channels)
                .map(|_| spill_pool::channel(max_file_size, Arc::clone(&spill_manager)))
                .unzip();

            // Coalescing is only set up when order is not preserved.
            let shared_coalescer = (!preserve_order).then(|| {
                SharedCoalescer::new(
                    input.schema(),
                    context.session_config().batch_size(),
                    num_input_partitions,
                )
            });

            channels.insert(
                partition,
                PartitionChannels {
                    tx,
                    rx,
                    reservation,
                    spill_readers,
                    spill_writers,
                    shared_coalescer,
                },
            );
        }

        // Spawn the tasks that drive each input stream.
        let mut spawned_tasks = Vec::with_capacity(num_input_partitions);
        for (i, (stream, metrics)) in
            std::mem::take(streams_and_metrics).into_iter().enumerate()
        {
            // Per-output-partition handles for input partition `i`.
            let txs: HashMap<_, _> = channels
                .iter()
                .map(|(partition, channels)| {
                    let spill_writer_idx = if preserve_order { i } else { 0 };
                    (
                        *partition,
                        OutputChannel {
                            sender: channels.tx[i].clone(),
                            reservation: Arc::clone(&channels.reservation),
                            spill_writer: channels.spill_writers[spill_writer_idx]
                                .clone(),
                            shared_coalescer: channels.shared_coalescer.clone(),
                        },
                    )
                })
                .collect();

            // Extra sender clones handed to the companion `wait_for_task`.
            let senders: HashMap<_, _> = txs
                .iter()
                .map(|(partition, channel)| (*partition, channel.sender.clone()))
                .collect();

            let input_task = SpawnedTask::spawn(RepartitionExec::pull_from_input(
                stream,
                txs,
                partitioning.clone(),
                metrics,
                // Input partition index passed to the partitioner; always 0
                // when order is preserved.
                if preserve_order { 0 } else { i },
                num_input_partitions,
            ));

            let wait_for_task =
                SpawnedTask::spawn(RepartitionExec::wait_for_task(input_task, senders));
            spawned_tasks.push(wait_for_task);
        }
        *self = Self::ConsumingInputStreams(ConsumingInputStreamsState {
            channels,
            abort_helper: Arc::new(spawned_tasks),
        });
        match self {
            RepartitionExecState::ConsumingInputStreams(value) => Ok(value),
            // `self` was assigned this variant immediately above.
            _ => unreachable!(),
        }
    }
553}
554
/// Splits incoming [`RecordBatch`]es across output partitions using either
/// hash or round-robin partitioning.
pub struct BatchPartitioner {
    /// Scheme-specific state.
    state: BatchPartitionerState,
    /// Metric that accumulates the time spent partitioning.
    timer: metrics::Time,
}
560
/// Per-scheme state of a [`BatchPartitioner`].
enum BatchPartitionerState {
    /// Rows are routed by the hash of `exprs`.
    Hash {
        /// Expressions whose values are hashed.
        exprs: Vec<Arc<dyn PhysicalExpr>>,
        /// Maps a hash value to a partition index.
        partition_reducer: StrengthReducedU64,
        /// Reusable buffer holding one hash per row.
        hash_buffer: Vec<u64>,
        /// Reusable row-index lists, one per output partition.
        indices: Vec<Vec<u32>>,
    },
    /// Whole batches are handed to partitions in turn.
    RoundRobin {
        /// Number of output partitions.
        num_partitions: usize,
        /// Partition that receives the next batch.
        next_idx: usize,
    },
}
573
/// Random state (fixed seed `0`) used when hashing rows for hash
/// repartitioning.
pub const REPARTITION_RANDOM_STATE: SeededRandomState = SeededRandomState::with_seed(0);
577
/// Precomputed helper that evaluates `value % divisor` for `u64` values
/// without a hardware division per value.
#[derive(Debug, Clone, Copy)]
enum StrengthReducedU64 {
    /// The divisor is a power of two: the remainder is `value & mask`.
    PowerOfTwo { mask: u64 },
    /// Any other divisor: the quotient is obtained by multiplying with a
    /// precomputed 128-bit reciprocal.
    Reciprocal { divisor: u64, reciprocal: u128 },
}

impl StrengthReducedU64 {
    /// Builds the reducer for `divisor`, which must be non-zero.
    fn new(divisor: u64) -> Self {
        debug_assert!(divisor > 0);

        match divisor.is_power_of_two() {
            true => Self::PowerOfTwo { mask: divisor - 1 },
            false => Self::Reciprocal {
                divisor,
                reciprocal: u128::MAX / u128::from(divisor) + 1,
            },
        }
    }

    /// Appends each row number `i` to `indices[hash_buffer[i] % divisor]`.
    ///
    /// `indices` must contain at least `divisor` lists.
    fn partition_indices(self, hash_buffer: &[u64], indices: &mut [Vec<u32>]) {
        match self {
            Self::PowerOfTwo { mask } => {
                for (row, &hash) in hash_buffer.iter().enumerate() {
                    let bucket = (hash & mask) as usize;
                    indices[bucket].push(row as u32);
                }
            }
            Self::Reciprocal {
                divisor,
                reciprocal,
            } => {
                for (row, &hash) in hash_buffer.iter().enumerate() {
                    let bucket = hash - Self::quotient(hash, reciprocal) * divisor;
                    indices[bucket as usize].push(row as u32);
                }
            }
        }
    }

    /// Returns `value % divisor` (test-only helper).
    #[cfg(test)]
    fn remainder(self, value: u64) -> u64 {
        match self {
            Self::Reciprocal {
                divisor,
                reciprocal,
            } => value - Self::quotient(value, reciprocal) * divisor,
            Self::PowerOfTwo { mask } => value & mask,
        }
    }

    /// Returns the top 64 bits of the 192-bit product `value * reciprocal`,
    /// i.e. `(value * reciprocal) >> 128`.
    #[inline]
    fn quotient(value: u64, reciprocal: u128) -> u64 {
        const LOW_64: u128 = u64::MAX as u128;

        let wide_value = u128::from(value);
        // Multiply against each 64-bit half of the reciprocal separately so
        // that neither partial product overflows a u128.
        let low_product = wide_value * (reciprocal & LOW_64);
        let high_product = wide_value * (reciprocal >> 64);
        let carry = ((high_product & LOW_64) + (low_product >> 64)) >> 64;

        ((high_product >> 64) + carry) as u64
    }
}
649
650impl BatchPartitioner {
651 pub fn new_hash_partitioner(
664 exprs: Vec<Arc<dyn PhysicalExpr>>,
665 num_partitions: usize,
666 timer: metrics::Time,
667 ) -> Result<Self> {
668 if num_partitions == 0 {
669 return internal_err!("Hash repartition requires at least one partition");
670 }
671
672 Ok(Self {
673 state: BatchPartitionerState::Hash {
674 exprs,
675 partition_reducer: StrengthReducedU64::new(num_partitions as u64),
676 hash_buffer: vec![],
677 indices: vec![vec![]; num_partitions],
678 },
679 timer,
680 })
681 }
682
683 pub fn new_round_robin_partitioner(
695 num_partitions: usize,
696 timer: metrics::Time,
697 input_partition: usize,
698 num_input_partitions: usize,
699 ) -> Self {
700 Self {
701 state: BatchPartitionerState::RoundRobin {
702 num_partitions,
703 next_idx: (input_partition * num_partitions) / num_input_partitions,
704 },
705 timer,
706 }
707 }
708 pub fn try_new(
723 partitioning: Partitioning,
724 timer: metrics::Time,
725 input_partition: usize,
726 num_input_partitions: usize,
727 ) -> Result<Self> {
728 match partitioning {
729 Partitioning::Hash(exprs, num_partitions) => {
730 Self::new_hash_partitioner(exprs, num_partitions, timer)
731 }
732 Partitioning::RoundRobinBatch(num_partitions) => {
733 Ok(Self::new_round_robin_partitioner(
734 num_partitions,
735 timer,
736 input_partition,
737 num_input_partitions,
738 ))
739 }
740 other => {
741 not_impl_err!("Unsupported repartitioning scheme {other:?}")
742 }
743 }
744 }
745
746 pub fn partition<F>(&mut self, batch: RecordBatch, mut f: F) -> Result<()>
756 where
757 F: FnMut(usize, RecordBatch) -> Result<()>,
758 {
759 self.partition_iter(batch)?.try_for_each(|res| match res {
760 Ok((partition, batch)) => f(partition, batch),
761 Err(e) => Err(e),
762 })
763 }
764
    /// Partitions `batch` and returns an iterator over
    /// `(output partition, batch)` pairs.
    ///
    /// Round-robin yields the whole batch for the next partition in the
    /// rotation. Hash partitioning evaluates the expressions, hashes every
    /// row and yields one batch per non-empty partition.
    ///
    /// # Errors
    /// Propagates errors from expression evaluation, hashing and the take
    /// kernel.
    pub fn partition_iter(
        &mut self,
        batch: RecordBatch,
    ) -> Result<impl Iterator<Item = Result<(usize, RecordBatch)>> + Send + '_> {
        let it: Box<dyn Iterator<Item = Result<(usize, RecordBatch)>> + Send> =
            match &mut self.state {
                BatchPartitionerState::RoundRobin {
                    num_partitions,
                    next_idx,
                } => {
                    let idx = *next_idx;
                    // Advance the rotation, wrapping around at the end.
                    *next_idx = (*next_idx + 1) % *num_partitions;
                    Box::new(std::iter::once(Ok((idx, batch))))
                }
                BatchPartitionerState::Hash {
                    exprs,
                    partition_reducer,
                    hash_buffer,
                    indices,
                } => {
                    // Covers expression evaluation, hashing and bucketing.
                    let timer = self.timer.timer();

                    let arrays =
                        evaluate_expressions_to_arrays(exprs.as_slice(), &batch)?;

                    // Reuse the buffer: one zero-initialized slot per row.
                    hash_buffer.clear();
                    hash_buffer.resize(batch.num_rows(), 0);

                    create_hashes(
                        &arrays,
                        REPARTITION_RANDOM_STATE.random_state(),
                        hash_buffer,
                    )?;

                    indices.iter_mut().for_each(|v| v.clear());

                    partition_reducer.partition_indices(hash_buffer, indices);

                    timer.done();

                    // The take itself is timed inside the helper.
                    let partitioned_batches =
                        Self::partition_grouped_take(&batch, indices, &self.timer)?;

                    Box::new(partitioned_batches.into_iter())
                }
            };

        Ok(it)
    }
830
831 fn num_partitions(&self) -> usize {
833 match &self.state {
834 BatchPartitionerState::RoundRobin { num_partitions, .. } => *num_partitions,
835 BatchPartitionerState::Hash { indices, .. } => indices.len(),
836 }
837 }
838
839 fn partition_grouped_take(
856 batch: &RecordBatch,
857 indices: &mut [Vec<u32>],
858 timer: &metrics::Time,
859 ) -> Result<Vec<Result<(usize, RecordBatch)>>> {
860 let mut partition_ranges = Vec::with_capacity(indices.len());
861 let mut reordered_indices = Vec::with_capacity(batch.num_rows());
862
863 for (partition, p_indices) in indices.iter_mut().enumerate() {
864 if p_indices.is_empty() {
865 continue;
866 }
867
868 let start = reordered_indices.len();
869 reordered_indices.extend_from_slice(p_indices);
870 partition_ranges.push((partition, start, p_indices.len()));
871 p_indices.clear();
872 }
873
874 if reordered_indices.is_empty() {
875 return Ok(vec![]);
876 }
877
878 let batches = {
879 let _timer = timer.timer();
880 let indices_array: PrimitiveArray<UInt32Type> = reordered_indices.into();
881 let columns = take_arrays(batch.columns(), &indices_array, None)?;
882
883 let mut options = RecordBatchOptions::new();
884 options = options.with_row_count(Some(indices_array.len()));
885 let reordered_batch =
886 RecordBatch::try_new_with_options(batch.schema(), columns, &options)?;
887
888 partition_ranges
889 .into_iter()
890 .map(|(partition, start, len)| {
891 Ok((partition, reordered_batch.slice(start, len)))
892 })
893 .collect()
894 };
895
896 Ok(batches)
897 }
898}
899
/// Execution plan node that redistributes the rows of its input across a
/// different set of output partitions.
#[derive(Debug, Clone)]
pub struct RepartitionExec {
    /// The plan whose output is repartitioned.
    input: Arc<dyn ExecutionPlan>,
    /// Execution state shared by all output partitions of this node.
    state: Arc<Mutex<RepartitionExecState>>,
    /// Metrics collected during execution.
    metrics: ExecutionPlanMetricsSet,
    /// Whether the input's ordering is preserved in the output.
    preserve_order: bool,
    /// Cached plan properties (partitioning, equivalences, ...).
    cache: Arc<PlanProperties>,
}
1045
/// Timing metrics recorded for a single input partition.
#[derive(Debug, Clone)]
struct RepartitionMetrics {
    /// Time spent fetching batches from the input.
    fetch_time: metrics::Time,
    /// Time spent partitioning batches.
    repartition_time: metrics::Time,
    /// Time spent sending, one entry per output partition.
    send_time: Vec<metrics::Time>,
}
1057
1058impl RepartitionMetrics {
1059 pub fn new(
1060 input_partition: usize,
1061 num_output_partitions: usize,
1062 metrics: &ExecutionPlanMetricsSet,
1063 ) -> Self {
1064 let fetch_time =
1066 MetricBuilder::new(metrics).subset_time("fetch_time", input_partition);
1067
1068 let repartition_time =
1070 MetricBuilder::new(metrics).subset_time("repartition_time", input_partition);
1071
1072 let send_time = (0..num_output_partitions)
1074 .map(|output_partition| {
1075 let label =
1076 metrics::Label::new("outputPartition", output_partition.to_string());
1077 MetricBuilder::new(metrics)
1078 .with_label(label)
1079 .subset_time("send_time", input_partition)
1080 })
1081 .collect();
1082
1083 Self {
1084 fetch_time,
1085 repartition_time,
1086 send_time,
1087 }
1088 }
1089}
1090
1091impl RepartitionExec {
1092 pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
1094 &self.input
1095 }
1096
1097 pub fn partitioning(&self) -> &Partitioning {
1099 &self.cache.partitioning
1100 }
1101
1102 pub fn preserve_order(&self) -> bool {
1105 self.preserve_order
1106 }
1107
1108 pub fn name(&self) -> &str {
1110 "RepartitionExec"
1111 }
1112
1113 fn with_new_children_and_same_properties(
1114 &self,
1115 mut children: Vec<Arc<dyn ExecutionPlan>>,
1116 ) -> Self {
1117 Self {
1118 input: children.swap_remove(0),
1119 metrics: ExecutionPlanMetricsSet::new(),
1120 state: Default::default(),
1121 ..Self::clone(self)
1122 }
1123 }
1124}
1125
1126impl DisplayAs for RepartitionExec {
1127 fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
1128 let input_partition_count = self.input.output_partitioning().partition_count();
1129 match t {
1130 DisplayFormatType::Default | DisplayFormatType::Verbose => {
1131 write!(
1132 f,
1133 "{}: partitioning={}, input_partitions={}",
1134 self.name(),
1135 self.partitioning(),
1136 input_partition_count,
1137 )?;
1138
1139 if self.preserve_order {
1140 write!(f, ", preserve_order=true")?;
1141 } else if input_partition_count <= 1
1142 && self.input.output_ordering().is_some()
1143 {
1144 write!(f, ", maintains_sort_order=true")?;
1147 }
1148
1149 if let Some(sort_exprs) = self.sort_exprs() {
1150 write!(f, ", sort_exprs={}", sort_exprs.clone())?;
1151 }
1152 Ok(())
1153 }
1154 DisplayFormatType::TreeRender => {
1155 writeln!(f, "partitioning_scheme={}", self.partitioning(),)?;
1156 let output_partition_count = self.partitioning().partition_count();
1157 let input_to_output_partition_str =
1158 format!("{input_partition_count} -> {output_partition_count}");
1159 writeln!(
1160 f,
1161 "partition_count(in->out)={input_to_output_partition_str}"
1162 )?;
1163
1164 if self.preserve_order {
1165 writeln!(f, "preserve_order={}", self.preserve_order)?;
1166 }
1167 Ok(())
1168 }
1169 }
1170 }
1171}
1172
1173impl ExecutionPlan for RepartitionExec {
    /// Returns the operator name, `"RepartitionExec"`.
    fn name(&self) -> &'static str {
        "RepartitionExec"
    }
1177
    /// Returns the cached plan properties.
    fn properties(&self) -> &Arc<PlanProperties> {
        &self.cache
    }
1182
    /// Returns the single child: the input plan.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }
1186
1187 fn with_new_children(
1188 self: Arc<Self>,
1189 mut children: Vec<Arc<dyn ExecutionPlan>>,
1190 ) -> Result<Arc<dyn ExecutionPlan>> {
1191 check_if_same_properties!(self, children);
1192 let mut repartition = RepartitionExec::try_new(
1193 children.swap_remove(0),
1194 self.partitioning().clone(),
1195 )?;
1196 if self.preserve_order {
1197 repartition = repartition.with_preserve_order();
1198 }
1199 Ok(Arc::new(repartition))
1200 }
1201
    /// Reports `true` for the single child only when hash partitioning is
    /// used.
    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![matches!(self.partitioning(), Partitioning::Hash(_, _))]
    }
1205
    /// Reports whether the input order is kept: either order preservation is
    /// enabled or the input has at most one partition.
    fn maintains_input_order(&self) -> Vec<bool> {
        Self::maintains_input_order_helper(self.input(), self.preserve_order)
    }
1209
    /// Returns the stream for output `partition`.
    ///
    /// The input streams are opened right away when the state lock is free.
    /// Channels and input tasks are only set up inside the returned stream's
    /// future, i.e. once the stream is polled.
    ///
    /// # Panics
    /// The returned stream panics when polled if `partition` has no channel
    /// entry left, e.g. because that partition was already executed.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        trace!(
            "Start {}::execute for partition: {}",
            self.name(),
            partition
        );

        let spill_metrics = SpillMetrics::new(&self.metrics, partition);

        // Owned copies of everything the async block below needs.
        let input = Arc::clone(&self.input);
        let partitioning = self.partitioning().clone();
        let metrics = self.metrics.clone();
        let preserve_order = self.sort_exprs().is_some();
        let name = self.name().to_owned();
        let schema = self.schema();
        let schema_captured = Arc::clone(&schema);

        let spill_manager = SpillManager::new(
            Arc::clone(&context.runtime_env()),
            spill_metrics,
            input.schema(),
        );

        let sort_exprs = self.sort_exprs().cloned();

        let state = Arc::clone(&self.state);
        // Best effort: open the input streams now if the lock is free;
        // otherwise this happens inside the stream below.
        if let Some(mut state) = state.try_lock() {
            state.ensure_input_streams_initialized(
                &input,
                &metrics,
                partitioning.partition_count(),
                &context,
            )?;
        }

        let num_input_partitions = input.output_partitioning().partition_count();

        let stream = futures::stream::once(async move {
            let (rx, reservation, spill_readers, abort_helper) = {
                // The lock is released at the end of this block.
                let mut state = state.lock();
                let state = state.consume_input_streams(
                    &input,
                    &metrics,
                    &partitioning,
                    preserve_order,
                    &name,
                    &context,
                    spill_manager.clone(),
                )?;

                // Take ownership of this partition's receiving side.
                let PartitionChannels {
                    rx,
                    reservation,
                    spill_readers,
                    ..
                } = state
                    .channels
                    .remove(&partition)
                    .expect("partition not used yet");

                (
                    rx,
                    reservation,
                    spill_readers,
                    Arc::clone(&state.abort_helper),
                )
            };

            trace!(
                "Before returning stream in {name}::execute for partition: {partition}"
            );

            if preserve_order {
                // One stream per input partition, merged with a streaming
                // merge on the sort expressions.
                let input_streams = rx
                    .into_iter()
                    .zip(spill_readers)
                    .map(|(receiver, spill_stream)| {
                        Box::pin(PerPartitionStream::new(
                            Arc::clone(&schema_captured),
                            receiver,
                            Arc::clone(&abort_helper),
                            Arc::clone(&reservation),
                            spill_stream,
                            1, BaselineMetrics::new(&metrics, partition),
                        )) as SendableRecordBatchStream
                    })
                    .collect::<Vec<_>>();
                let fetch = None;
                let merge_reservation =
                    MemoryConsumer::new(format!("{name}[Merge {partition}]"))
                        .register(context.memory_pool());
                StreamingMergeBuilder::new()
                    .with_streams(input_streams)
                    .with_schema(schema_captured)
                    .with_expressions(&sort_exprs.unwrap())
                    .with_metrics(BaselineMetrics::new(&metrics, partition))
                    .with_batch_size(context.session_config().batch_size())
                    .with_fetch(fetch)
                    .with_reservation(merge_reservation)
                    .with_spill_manager(spill_manager)
                    .build()
            } else {
                // A single receiver and a single spill reader serve all
                // input partitions.
                let spill_stream = spill_readers
                    .into_iter()
                    .next()
                    .expect("at least one spill reader should exist");

                Ok(Box::pin(PerPartitionStream::new(
                    schema_captured,
                    rx.into_iter()
                        .next()
                        .expect("at least one receiver should exist"),
                    abort_helper,
                    reservation,
                    spill_stream,
                    num_input_partitions,
                    BaselineMetrics::new(&metrics, partition),
                )) as SendableRecordBatchStream)
            }
        })
        .try_flatten();
        let stream = RecordBatchStreamAdapter::new(schema, stream);
        Ok(Box::pin(stream))
    }
1352
    /// Returns a snapshot of the metrics collected so far.
    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
1356
1357 fn partition_statistics(&self, partition: Option<usize>) -> Result<Arc<Statistics>> {
1358 if let Some(partition) = partition {
1359 let partition_count = self.partitioning().partition_count();
1360 if partition_count == 0 {
1361 return Ok(Arc::new(Statistics::new_unknown(&self.schema())));
1362 }
1363
1364 assert_or_internal_err!(
1365 partition < partition_count,
1366 "RepartitionExec invalid partition {} (expected less than {})",
1367 partition,
1368 partition_count
1369 );
1370
1371 let mut stats = Arc::unwrap_or_clone(self.input.partition_statistics(None)?);
1372
1373 stats.num_rows = stats
1375 .num_rows
1376 .get_value()
1377 .map(|rows| Precision::Inexact(rows / partition_count))
1378 .unwrap_or(Precision::Absent);
1379 stats.total_byte_size = stats
1380 .total_byte_size
1381 .get_value()
1382 .map(|bytes| Precision::Inexact(bytes / partition_count))
1383 .unwrap_or(Precision::Absent);
1384
1385 stats.column_statistics = stats
1387 .column_statistics
1388 .iter()
1389 .map(|_| ColumnStatistics::new_unknown())
1390 .collect();
1391
1392 Ok(Arc::new(stats))
1393 } else {
1394 self.input.partition_statistics(None)
1395 }
1396 }
1397
    /// Repartitioning reports `CardinalityEffect::Equal`.
    fn cardinality_effect(&self) -> CardinalityEffect {
        CardinalityEffect::Equal
    }
1401
1402 fn try_swapping_with_projection(
1403 &self,
1404 projection: &ProjectionExec,
1405 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
1406 if projection.expr().len() >= projection.input().schema().fields().len() {
1408 return Ok(None);
1409 }
1410
1411 if projection.benefits_from_input_partitioning()[0]
1413 || !all_columns(projection.expr())
1414 {
1415 return Ok(None);
1416 }
1417
1418 let new_projection = make_with_child(projection, self.input())?;
1419
1420 let new_partitioning = match self.partitioning() {
1421 Partitioning::Hash(partitions, size) => {
1422 let mut new_partitions = vec![];
1423 for partition in partitions {
1424 let Some(new_partition) =
1425 update_expr(partition, projection.expr(), false)?
1426 else {
1427 return Ok(None);
1428 };
1429 new_partitions.push(new_partition);
1430 }
1431 Partitioning::Hash(new_partitions, *size)
1432 }
1433 others => others.clone(),
1434 };
1435
1436 Ok(Some(Arc::new(RepartitionExec::try_new(
1437 new_projection,
1438 new_partitioning,
1439 )?)))
1440 }
1441
    /// Builds the filter description from the parent filters and this node's
    /// children; this node contributes no filters of its own.
    fn gather_filters_for_pushdown(
        &self,
        _phase: FilterPushdownPhase,
        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
        _config: &ConfigOptions,
    ) -> Result<FilterDescription> {
        FilterDescription::from_children(parent_filters, &self.children())
    }
1450
    /// Forwards the child's pushdown result via
    /// `FilterPushdownPropagation::if_all`.
    fn handle_child_pushdown_result(
        &self,
        _phase: FilterPushdownPhase,
        child_pushdown_result: ChildPushdownResult,
        _config: &ConfigOptions,
    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
        Ok(FilterPushdownPropagation::if_all(child_pushdown_result))
    }
1459
1460 fn try_pushdown_sort(
1461 &self,
1462 order: &[PhysicalSortExpr],
1463 ) -> Result<SortOrderPushdownResult<Arc<dyn ExecutionPlan>>> {
1464 if !self.maintains_input_order()[0] {
1467 return Ok(SortOrderPushdownResult::Unsupported);
1468 }
1469
1470 self.input.try_pushdown_sort(order)?.try_map(|new_input| {
1472 let mut new_repartition =
1473 RepartitionExec::try_new(new_input, self.partitioning().clone())?;
1474 if self.preserve_order {
1475 new_repartition = new_repartition.with_preserve_order();
1476 }
1477 Ok(Arc::new(new_repartition) as Arc<dyn ExecutionPlan>)
1478 })
1479 }
1480
    /// Returns a copy of this node whose output partitioning uses
    /// `target_partitions` partitions; the scheme itself is kept.
    // NOTE(review): the copy shares `state` (same `Arc`) and `metrics` with
    // this node, unlike `with_new_children_and_same_properties`, which resets
    // both. If this node's state was already initialized for a different
    // partition count the copy would reuse it — confirm this is intended.
    fn repartitioned(
        &self,
        target_partitions: usize,
        _config: &ConfigOptions,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        use Partitioning::*;
        let mut new_properties = PlanProperties::clone(&self.cache);
        new_properties.partitioning = match new_properties.partitioning {
            RoundRobinBatch(_) => RoundRobinBatch(target_partitions),
            Hash(hash, _) => Hash(hash, target_partitions),
            UnknownPartitioning(_) => UnknownPartitioning(target_partitions),
        };
        Ok(Some(Arc::new(Self {
            input: Arc::clone(&self.input),
            state: Arc::clone(&self.state),
            metrics: self.metrics.clone(),
            preserve_order: self.preserve_order,
            cache: new_properties.into(),
        })))
    }
1501}
1502
1503impl RepartitionExec {
1504 pub fn try_new(
1508 input: Arc<dyn ExecutionPlan>,
1509 partitioning: Partitioning,
1510 ) -> Result<Self> {
1511 let preserve_order = false;
1512 let cache = Self::compute_properties(&input, partitioning, preserve_order);
1513 Ok(RepartitionExec {
1514 input,
1515 state: Default::default(),
1516 metrics: ExecutionPlanMetricsSet::new(),
1517 preserve_order,
1518 cache: Arc::new(cache),
1519 })
1520 }
1521
    /// The input order is maintained when order preservation is enabled or
    /// the input has at most one partition.
    fn maintains_input_order_helper(
        input: &Arc<dyn ExecutionPlan>,
        preserve_order: bool,
    ) -> Vec<bool> {
        vec![preserve_order || input.output_partitioning().partition_count() <= 1]
    }
1529
1530 fn eq_properties_helper(
1531 input: &Arc<dyn ExecutionPlan>,
1532 preserve_order: bool,
1533 ) -> EquivalenceProperties {
1534 let mut eq_properties = input.equivalence_properties().clone();
1536 if !Self::maintains_input_order_helper(input, preserve_order)[0] {
1538 eq_properties.clear_orderings();
1539 }
1540 if input.output_partitioning().partition_count() > 1 {
1543 eq_properties.clear_per_partition_constants();
1544 }
1545 eq_properties
1546 }
1547
    /// Builds the cached plan properties: the derived equivalence properties,
    /// the requested partitioning, and the input's pipeline behavior and
    /// boundedness, with cooperative scheduling and eager evaluation.
    fn compute_properties(
        input: &Arc<dyn ExecutionPlan>,
        partitioning: Partitioning,
        preserve_order: bool,
    ) -> PlanProperties {
        PlanProperties::new(
            Self::eq_properties_helper(input, preserve_order),
            partitioning,
            input.pipeline_behavior(),
            input.boundedness(),
        )
        .with_scheduling_type(SchedulingType::Cooperative)
        .with_evaluation_type(EvaluationType::Eager)
    }
1563
1564 pub fn with_preserve_order(mut self) -> Self {
1572 self.preserve_order =
1573 self.input.output_ordering().is_some() &&
1575 self.input.output_partitioning().partition_count() > 1;
1578 let eq_properties = Self::eq_properties_helper(&self.input, self.preserve_order);
1579 Arc::make_mut(&mut self.cache).set_eq_properties(eq_properties);
1580 self
1581 }
1582
1583 fn sort_exprs(&self) -> Option<&LexOrdering> {
1585 if self.preserve_order {
1586 self.input.output_ordering()
1587 } else {
1588 None
1589 }
1590 }
1591
1592 async fn pull_from_input(
1597 mut stream: SendableRecordBatchStream,
1598 mut output_channels: HashMap<usize, OutputChannel>,
1599 partitioning: Partitioning,
1600 metrics: RepartitionMetrics,
1601 input_partition: usize,
1602 num_input_partitions: usize,
1603 ) -> Result<()> {
1604 let mut partitioner = match &partitioning {
1605 Partitioning::Hash(exprs, num_partitions) => {
1606 BatchPartitioner::new_hash_partitioner(
1607 exprs.clone(),
1608 *num_partitions,
1609 metrics.repartition_time.clone(),
1610 )?
1611 }
1612 Partitioning::RoundRobinBatch(num_partitions) => {
1613 BatchPartitioner::new_round_robin_partitioner(
1614 *num_partitions,
1615 metrics.repartition_time.clone(),
1616 input_partition,
1617 num_input_partitions,
1618 )
1619 }
1620 other => {
1621 return not_impl_err!("Unsupported repartitioning scheme {other:?}");
1622 }
1623 };
1624
1625 let mut batches_until_yield = partitioner.num_partitions();
1627 while !output_channels.is_empty() {
1628 let timer = metrics.fetch_time.timer();
1630 let result = stream.next().await;
1631 timer.done();
1632
1633 let batch = match result {
1635 Some(result) => result?,
1636 None => break,
1637 };
1638
1639 if batch.num_rows() == 0 {
1641 continue;
1642 }
1643
1644 for res in partitioner.partition_iter(batch)? {
1645 let (partition, batch) = res?;
1646
1647 let timer = metrics.send_time[partition].timer();
1648 if let Some(output_channel) = output_channels.get_mut(&partition) {
1650 for batch in output_channel.coalesce(batch)? {
1651 if output_channel.send(batch).await.is_err() {
1652 output_channels.remove(&partition);
1655 break;
1656 }
1657 }
1658 }
1659 timer.done();
1660 }
1661
1662 if batches_until_yield == 0 {
1679 tokio::task::yield_now().await;
1680 batches_until_yield = partitioner.num_partitions();
1681 } else {
1682 batches_until_yield -= 1;
1683 }
1684 }
1685
1686 for (_, output_channel) in output_channels.drain() {
1691 output_channel.finalize().await?;
1692 }
1693
1694 Ok(())
1697 }
1698
1699 async fn wait_for_task(
1705 input_task: SpawnedTask<Result<()>>,
1706 txs: HashMap<usize, DistributionSender<MaybeBatch>>,
1707 ) {
1708 match input_task.join().await {
1712 Err(e) => {
1714 let e = Arc::new(e);
1715
1716 for (_, tx) in txs {
1717 let err = Err(DataFusionError::Context(
1718 "Join Error".to_string(),
1719 Box::new(DataFusionError::External(Box::new(Arc::clone(&e)))),
1720 ));
1721 tx.send(Some(err)).await.ok();
1722 }
1723 }
1724 Ok(Err(e)) => {
1726 let e = Arc::new(e);
1728
1729 for (_, tx) in txs {
1730 let err = Err(DataFusionError::from(&e));
1732 tx.send(Some(err)).await.ok();
1733 }
1734 }
1735 Ok(Ok(())) => {
1737 for (_partition, tx) in txs {
1739 tx.send(None).await.ok();
1740 }
1741 }
1742 }
1743 }
1744}
1745
/// Which source a per-partition output stream polls next.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum StreamState {
    /// Poll the in-memory channel for the next message.
    ReadingMemory,
    /// Poll the spill stream for the next batch.
    ReadingSpilled,
}
1796
1797struct PerPartitionStream {
1800 schema: SchemaRef,
1802
1803 receiver: DistributionReceiver<MaybeBatch>,
1805
1806 _drop_helper: Arc<Vec<SpawnedTask<()>>>,
1808
1809 reservation: SharedMemoryReservation,
1811
1812 spill_stream: SendableRecordBatchStream,
1814
1815 state: StreamState,
1817
1818 remaining_partitions: usize,
1822
1823 baseline_metrics: BaselineMetrics,
1825}
1826
1827impl PerPartitionStream {
1828 fn new(
1829 schema: SchemaRef,
1830 receiver: DistributionReceiver<MaybeBatch>,
1831 drop_helper: Arc<Vec<SpawnedTask<()>>>,
1832 reservation: SharedMemoryReservation,
1833 spill_stream: SendableRecordBatchStream,
1834 num_input_partitions: usize,
1835 baseline_metrics: BaselineMetrics,
1836 ) -> Self {
1837 Self {
1838 schema,
1839 receiver,
1840 _drop_helper: drop_helper,
1841 reservation,
1842 spill_stream,
1843 state: StreamState::ReadingMemory,
1844 remaining_partitions: num_input_partitions,
1845 baseline_metrics,
1846 }
1847 }
1848
1849 fn poll_next_inner(
1850 self: &mut Pin<&mut Self>,
1851 cx: &mut Context<'_>,
1852 ) -> Poll<Option<Result<RecordBatch>>> {
1853 use futures::StreamExt;
1854 let cloned_time = self.baseline_metrics.elapsed_compute().clone();
1855 let _timer = cloned_time.timer();
1856
1857 loop {
1858 match self.state {
1859 StreamState::ReadingMemory => {
1860 let value = match self.receiver.recv().poll_unpin(cx) {
1862 Poll::Ready(v) => v,
1863 Poll::Pending => {
1864 return Poll::Pending;
1866 }
1867 };
1868
1869 match value {
1870 Some(Some(v)) => match v {
1871 Ok(RepartitionBatch::Memory(batch)) => {
1872 self.reservation.shrink(batch.get_array_memory_size());
1874 return Poll::Ready(Some(Ok(batch)));
1875 }
1876 Ok(RepartitionBatch::Spilled) => {
1877 self.state = StreamState::ReadingSpilled;
1881 continue;
1882 }
1883 Err(e) => {
1884 return Poll::Ready(Some(Err(e)));
1885 }
1886 },
1887 Some(None) => {
1888 self.remaining_partitions -= 1;
1890 if self.remaining_partitions == 0 {
1891 return Poll::Ready(None);
1893 }
1894 continue;
1896 }
1897 None => {
1898 return Poll::Ready(None);
1900 }
1901 }
1902 }
1903 StreamState::ReadingSpilled => {
1904 match self.spill_stream.poll_next_unpin(cx) {
1906 Poll::Ready(Some(Ok(batch))) => {
1907 self.state = StreamState::ReadingMemory;
1908 return Poll::Ready(Some(Ok(batch)));
1909 }
1910 Poll::Ready(Some(Err(e))) => {
1911 return Poll::Ready(Some(Err(e)));
1912 }
1913 Poll::Ready(None) => {
1914 let spill_schema = self.spill_stream.schema();
1917 self.spill_stream =
1918 Box::pin(EmptyRecordBatchStream::new(spill_schema));
1919 self.state = StreamState::ReadingMemory;
1920 }
1921 Poll::Pending => {
1922 return Poll::Pending;
1925 }
1926 }
1927 }
1928 }
1929 }
1930 }
1931}
1932
1933impl Stream for PerPartitionStream {
1934 type Item = Result<RecordBatch>;
1935
1936 fn poll_next(
1937 mut self: Pin<&mut Self>,
1938 cx: &mut Context<'_>,
1939 ) -> Poll<Option<Self::Item>> {
1940 let poll = self.poll_next_inner(cx);
1941 self.baseline_metrics.record_poll(poll)
1942 }
1943}
1944
1945impl RecordBatchStream for PerPartitionStream {
1946 fn schema(&self) -> SchemaRef {
1948 Arc::clone(&self.schema)
1949 }
1950}
1951
1952#[cfg(test)]
1953mod tests {
1954 use std::collections::HashSet;
1955
1956 use super::*;
1957 use crate::test::TestMemoryExec;
1958 use crate::{
1959 test::{
1960 assert_is_pending,
1961 exec::{
1962 BarrierExec, BlockingExec, ErrorExec, MockExec,
1963 assert_strong_count_converges_to_zero,
1964 },
1965 },
1966 {collect, expressions::col},
1967 };
1968
1969 use arrow::array::{ArrayRef, StringArray, UInt32Array};
1970 use arrow::datatypes::{DataType, Field, Schema};
1971 use datafusion_common::cast::as_string_array;
1972 use datafusion_common::exec_err;
1973 use datafusion_common::test_util::batches_to_sort_string;
1974 use datafusion_common_runtime::JoinSet;
1975 use datafusion_execution::config::SessionConfig;
1976 use datafusion_execution::runtime_env::RuntimeEnvBuilder;
1977 use insta::assert_snapshot;
1978
/// `StrengthReducedU64::remainder` must agree with `%` for a fixed set of
/// boundary values and for 10,000 pseudo-random values per divisor.
#[test]
fn strength_reduced_u64_remainder_matches_modulo() {
    let divisors: &[u64] = &[
        1,
        2,
        3,
        4,
        5,
        7,
        8,
        10,
        16,
        31,
        32,
        63,
        64,
        65,
        97,
        u64::from(u32::MAX),
        u64::from(u32::MAX) + 1,
        1_u64 << 32,
        (1_u64 << 63) - 1,
        1_u64 << 63,
        u64::MAX - 1,
        u64::MAX,
    ];
    let values: &[u64] = &[
        0,
        1,
        2,
        3,
        4,
        5,
        31,
        32,
        33,
        63,
        64,
        65,
        u64::from(u32::MAX) - 1,
        u64::from(u32::MAX),
        u64::from(u32::MAX) + 1,
        (1_u64 << 32) - 1,
        1_u64 << 32,
        (1_u64 << 32) + 1,
        (1_u64 << 63) - 1,
        1_u64 << 63,
        (1_u64 << 63) + 1,
        u64::MAX - 1,
        u64::MAX,
    ];

    // One step of the linear congruential generator used for random values.
    let lcg_step = |state: u64| {
        state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407)
    };

    for &divisor in divisors {
        let reducer = StrengthReducedU64::new(divisor);

        // The generator is seeded per divisor; the seed itself is not tested.
        let seed = 0x1234_5678_9abc_def0 ^ divisor;
        let pseudo_random =
            std::iter::successors(Some(lcg_step(seed)), |&state| Some(lcg_step(state)))
                .take(10_000);

        for value in values.iter().copied().chain(pseudo_random) {
            assert_eq!(
                reducer.remainder(value),
                value % divisor,
                "value={value} divisor={divisor}"
            );
        }
    }
}
2054
/// Creating a hash partitioner with zero output partitions must fail with
/// a descriptive error.
#[test]
fn hash_partitioner_requires_nonzero_partitions() {
    let metrics = ExecutionPlanMetricsSet::new();
    let timer = MetricBuilder::new(&metrics).subset_time("test", 0);

    let outcome = BatchPartitioner::new_hash_partitioner(vec![], 0, timer);
    let err = outcome
        .err()
        .expect("zero hash partitions should fail")
        .to_string();

    let expected = "Hash repartition requires at least one partition";
    assert!(err.contains(expected), "actual: {err}");
}
2070
2071 #[tokio::test]
2072 async fn one_to_many_round_robin() -> Result<()> {
2073 let schema = test_schema();
2075 let partition = create_vec_batches(50);
2076 let partitions = vec![partition];
2077
2078 let output_partitions =
2080 repartition(&schema, partitions, Partitioning::RoundRobinBatch(4)).await?;
2081
2082 assert_eq!(4, output_partitions.len());
2083 for partition in &output_partitions {
2084 assert_eq!(1, partition.len());
2085 }
2086 assert_eq!(13 * 8, output_partitions[0][0].num_rows());
2087 assert_eq!(13 * 8, output_partitions[1][0].num_rows());
2088 assert_eq!(12 * 8, output_partitions[2][0].num_rows());
2089 assert_eq!(12 * 8, output_partitions[3][0].num_rows());
2090
2091 Ok(())
2092 }
2093
2094 #[tokio::test]
2095 async fn many_to_one_round_robin() -> Result<()> {
2096 let schema = test_schema();
2098 let partition = create_vec_batches(50);
2099 let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
2100
2101 let output_partitions =
2103 repartition(&schema, partitions, Partitioning::RoundRobinBatch(1)).await?;
2104
2105 assert_eq!(1, output_partitions.len());
2106 assert_eq!(150 * 8, output_partitions[0][0].num_rows());
2107
2108 Ok(())
2109 }
2110
2111 #[tokio::test]
2112 async fn many_to_many_round_robin() -> Result<()> {
2113 let schema = test_schema();
2115 let partition = create_vec_batches(50);
2116 let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
2117
2118 let output_partitions =
2120 repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await?;
2121
2122 let total_rows_per_partition = 8 * 50 * 3 / 5;
2123 assert_eq!(5, output_partitions.len());
2124 for partition in output_partitions {
2125 assert_eq!(1, partition.len());
2126 assert_eq!(total_rows_per_partition, partition[0].num_rows());
2127 }
2128
2129 Ok(())
2130 }
2131
2132 #[tokio::test]
2133 async fn many_to_many_hash_partition() -> Result<()> {
2134 let schema = test_schema();
2136 let partition = create_vec_batches(50);
2137 let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
2138
2139 let output_partitions = repartition(
2140 &schema,
2141 partitions,
2142 Partitioning::Hash(vec![col("c0", &schema)?], 8),
2143 )
2144 .await?;
2145
2146 let total_rows: usize = output_partitions
2147 .iter()
2148 .map(|x| x.iter().map(|x| x.num_rows()).sum::<usize>())
2149 .sum();
2150
2151 assert_eq!(8, output_partitions.len());
2152 assert_eq!(total_rows, 8 * 50 * 3);
2153
2154 Ok(())
2155 }
2156
2157 #[tokio::test]
2158 async fn test_repartition_with_coalescing() -> Result<()> {
2159 let schema = test_schema();
2160 let partition = create_vec_batches(50);
2162 let partitions = vec![partition.clone(), partition.clone()];
2163 let partitioning = Partitioning::RoundRobinBatch(1);
2164
2165 let session_config = SessionConfig::new().with_batch_size(200);
2166 let task_ctx = TaskContext::default().with_session_config(session_config);
2167 let task_ctx = Arc::new(task_ctx);
2168
2169 let exec = TestMemoryExec::try_new_exec(&partitions, Arc::clone(&schema), None)?;
2171 let exec = RepartitionExec::try_new(exec, partitioning)?;
2172
2173 for i in 0..exec.partitioning().partition_count() {
2174 let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
2175 while let Some(result) = stream.next().await {
2176 let batch = result?;
2177 assert_eq!(200, batch.num_rows());
2178 }
2179 }
2180 Ok(())
2181 }
2182
2183 fn test_schema() -> Arc<Schema> {
2184 Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
2185 }
2186
2187 async fn repartition(
2188 schema: &SchemaRef,
2189 input_partitions: Vec<Vec<RecordBatch>>,
2190 partitioning: Partitioning,
2191 ) -> Result<Vec<Vec<RecordBatch>>> {
2192 let task_ctx = Arc::new(TaskContext::default());
2193 let exec =
2195 TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(schema), None)?;
2196 let exec = RepartitionExec::try_new(exec, partitioning)?;
2197
2198 let mut output_partitions = vec![];
2200 for i in 0..exec.partitioning().partition_count() {
2201 let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
2203 let mut batches = vec![];
2204 while let Some(result) = stream.next().await {
2205 batches.push(result?);
2206 }
2207 output_partitions.push(batches);
2208 }
2209 Ok(output_partitions)
2210 }
2211
2212 #[tokio::test]
2213 async fn many_to_many_round_robin_within_tokio_task() -> Result<()> {
2214 let handle: SpawnedTask<Result<Vec<Vec<RecordBatch>>>> =
2215 SpawnedTask::spawn(async move {
2216 let schema = test_schema();
2218 let partition = create_vec_batches(50);
2219 let partitions =
2220 vec![partition.clone(), partition.clone(), partition.clone()];
2221
2222 repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await
2224 });
2225
2226 let output_partitions = handle.join().await.unwrap().unwrap();
2227
2228 let total_rows_per_partition = 8 * 50 * 3 / 5;
2229 assert_eq!(5, output_partitions.len());
2230 for partition in output_partitions {
2231 assert_eq!(1, partition.len());
2232 assert_eq!(total_rows_per_partition, partition[0].num_rows());
2233 }
2234
2235 Ok(())
2236 }
2237
2238 #[tokio::test]
2239 async fn unsupported_partitioning() {
2240 let task_ctx = Arc::new(TaskContext::default());
2241 let batch = RecordBatch::try_from_iter(vec![(
2243 "my_awesome_field",
2244 Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
2245 )])
2246 .unwrap();
2247
2248 let schema = batch.schema();
2249 let input = MockExec::new(vec![Ok(batch)], schema);
2250 let partitioning = Partitioning::UnknownPartitioning(1);
2254 let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
2255 let output_stream = exec.execute(0, task_ctx).unwrap();
2256
2257 let result_string = crate::common::collect(output_stream)
2259 .await
2260 .unwrap_err()
2261 .to_string();
2262 assert!(
2263 result_string
2264 .contains("Unsupported repartitioning scheme UnknownPartitioning(1)"),
2265 "actual: {result_string}"
2266 );
2267 }
2268
2269 #[tokio::test]
2270 async fn error_for_input_exec() {
2271 let task_ctx = Arc::new(TaskContext::default());
2275 let input = ErrorExec::new();
2276 let partitioning = Partitioning::RoundRobinBatch(1);
2277 let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
2278
2279 let result_string = exec.execute(0, task_ctx).err().unwrap().to_string();
2281
2282 assert!(
2283 result_string.contains("ErrorExec, unsurprisingly, errored in partition 0"),
2284 "actual: {result_string}"
2285 );
2286 }
2287
2288 #[tokio::test]
2289 async fn repartition_with_error_in_stream() {
2290 let task_ctx = Arc::new(TaskContext::default());
2291 let batch = RecordBatch::try_from_iter(vec![(
2292 "my_awesome_field",
2293 Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
2294 )])
2295 .unwrap();
2296
2297 let err = exec_err!("bad data error");
2300
2301 let schema = batch.schema();
2302 let input = MockExec::new(vec![Ok(batch), err], schema);
2303 let partitioning = Partitioning::RoundRobinBatch(1);
2304 let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
2305
2306 let output_stream = exec.execute(0, task_ctx).unwrap();
2309
2310 let result_string = crate::common::collect(output_stream)
2312 .await
2313 .unwrap_err()
2314 .to_string();
2315 assert!(
2316 result_string.contains("bad data error"),
2317 "actual: {result_string}"
2318 );
2319 }
2320
2321 #[tokio::test]
2322 async fn repartition_with_delayed_stream() {
2323 let task_ctx = Arc::new(TaskContext::default());
2324 let batch1 = RecordBatch::try_from_iter(vec![(
2325 "my_awesome_field",
2326 Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
2327 )])
2328 .unwrap();
2329
2330 let batch2 = RecordBatch::try_from_iter(vec![(
2331 "my_awesome_field",
2332 Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
2333 )])
2334 .unwrap();
2335
2336 let schema = batch1.schema();
2339 let expected_batches = vec![batch1.clone(), batch2.clone()];
2340 let input = MockExec::new(vec![Ok(batch1), Ok(batch2)], schema);
2341 let partitioning = Partitioning::RoundRobinBatch(1);
2342
2343 let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
2344
2345 assert_snapshot!(batches_to_sort_string(&expected_batches), @r"
2346 +------------------+
2347 | my_awesome_field |
2348 +------------------+
2349 | bar |
2350 | baz |
2351 | foo |
2352 | frob |
2353 +------------------+
2354 ");
2355
2356 let output_stream = exec.execute(0, task_ctx).unwrap();
2357 let batches = crate::common::collect(output_stream).await.unwrap();
2358
2359 assert_snapshot!(batches_to_sort_string(&batches), @r"
2360 +------------------+
2361 | my_awesome_field |
2362 +------------------+
2363 | bar |
2364 | baz |
2365 | foo |
2366 | frob |
2367 +------------------+
2368 ");
2369 }
2370
2371 #[tokio::test]
2372 async fn robin_repartition_with_dropping_output_stream() {
2373 let task_ctx = Arc::new(TaskContext::default());
2374 let partitioning = Partitioning::RoundRobinBatch(2);
2375 let input = Arc::new(make_barrier_exec());
2378
2379 let exec = RepartitionExec::try_new(
2381 Arc::clone(&input) as Arc<dyn ExecutionPlan>,
2382 partitioning,
2383 )
2384 .unwrap();
2385
2386 let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
2387 let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
2388
2389 drop(output_stream0);
2392
2393 let mut background_task = JoinSet::new();
2395 background_task.spawn(async move {
2396 input.wait().await;
2397 });
2398
2399 let batches = crate::common::collect(output_stream1).await.unwrap();
2401
2402 assert_snapshot!(batches_to_sort_string(&batches), @r"
2403 +------------------+
2404 | my_awesome_field |
2405 +------------------+
2406 | baz |
2407 | frob |
2408 | gar |
2409 | goo |
2410 +------------------+
2411 ");
2412 }
2413
2414 #[tokio::test]
2415 async fn hash_repartition_with_dropping_output_stream() {
2419 let task_ctx = Arc::new(TaskContext::default());
2420 let partitioning = Partitioning::Hash(
2421 vec![Arc::new(crate::expressions::Column::new(
2422 "my_awesome_field",
2423 0,
2424 ))],
2425 2,
2426 );
2427
2428 let input = Arc::new(make_barrier_exec());
2430 let exec = RepartitionExec::try_new(
2431 Arc::clone(&input) as Arc<dyn ExecutionPlan>,
2432 partitioning.clone(),
2433 )
2434 .unwrap();
2435 let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
2436 let mut background_task = JoinSet::new();
2437 background_task.spawn(async move {
2438 input.wait().await;
2439 });
2440 let batches_without_drop = crate::common::collect(output_stream1).await.unwrap();
2441
2442 let items_vec = str_batches_to_vec(&batches_without_drop);
2444 let items_set: HashSet<&str> = items_vec.iter().copied().collect();
2445 assert_eq!(items_vec.len(), items_set.len());
2446 let source_str_set: HashSet<&str> =
2447 ["foo", "bar", "frob", "baz", "goo", "gar", "grob", "gaz"]
2448 .iter()
2449 .copied()
2450 .collect();
2451 assert_eq!(items_set.difference(&source_str_set).count(), 0);
2452
2453 let input = Arc::new(make_barrier_exec());
2455 let exec = RepartitionExec::try_new(
2456 Arc::clone(&input) as Arc<dyn ExecutionPlan>,
2457 partitioning,
2458 )
2459 .unwrap();
2460 let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
2461 let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
2462 drop(output_stream0);
2465 let mut background_task = JoinSet::new();
2466 background_task.spawn(async move {
2467 input.wait().await;
2468 });
2469 let batches_with_drop = crate::common::collect(output_stream1).await.unwrap();
2470
2471 let items_vec_with_drop = str_batches_to_vec(&batches_with_drop);
2472 let items_set_with_drop: HashSet<&str> =
2473 items_vec_with_drop.iter().copied().collect();
2474 assert_eq!(
2475 items_set_with_drop.symmetric_difference(&items_set).count(),
2476 0
2477 );
2478 }
2479
2480 fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> {
2481 batches
2482 .iter()
2483 .flat_map(|batch| {
2484 assert_eq!(batch.columns().len(), 1);
2485 let string_array = as_string_array(batch.column(0))
2486 .expect("Unexpected type for repartitioned batch");
2487
2488 string_array
2489 .iter()
2490 .map(|v| v.expect("Unexpected null"))
2491 .collect::<Vec<_>>()
2492 })
2493 .collect::<Vec<_>>()
2494 }
2495
2496 fn make_barrier_exec() -> BarrierExec {
2498 let batch1 = RecordBatch::try_from_iter(vec![(
2499 "my_awesome_field",
2500 Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
2501 )])
2502 .unwrap();
2503
2504 let batch2 = RecordBatch::try_from_iter(vec![(
2505 "my_awesome_field",
2506 Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
2507 )])
2508 .unwrap();
2509
2510 let batch3 = RecordBatch::try_from_iter(vec![(
2511 "my_awesome_field",
2512 Arc::new(StringArray::from(vec!["goo", "gar"])) as ArrayRef,
2513 )])
2514 .unwrap();
2515
2516 let batch4 = RecordBatch::try_from_iter(vec![(
2517 "my_awesome_field",
2518 Arc::new(StringArray::from(vec!["grob", "gaz"])) as ArrayRef,
2519 )])
2520 .unwrap();
2521
2522 let schema = batch1.schema();
2525 BarrierExec::new(vec![vec![batch1, batch2], vec![batch3, batch4]], schema)
2526 }
2527
2528 #[tokio::test]
2529 async fn test_drop_cancel() -> Result<()> {
2530 let task_ctx = Arc::new(TaskContext::default());
2531 let schema =
2532 Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));
2533
2534 let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2));
2535 let refs = blocking_exec.refs();
2536 let repartition_exec = Arc::new(RepartitionExec::try_new(
2537 blocking_exec,
2538 Partitioning::UnknownPartitioning(1),
2539 )?);
2540
2541 let fut = collect(repartition_exec, task_ctx);
2542 let mut fut = fut.boxed();
2543
2544 assert_is_pending(&mut fut);
2545 drop(fut);
2546 assert_strong_count_converges_to_zero(refs).await;
2547
2548 Ok(())
2549 }
2550
2551 #[tokio::test]
2552 async fn hash_repartition_avoid_empty_batch() -> Result<()> {
2553 let task_ctx = Arc::new(TaskContext::default());
2554 let batch = RecordBatch::try_from_iter(vec![(
2555 "a",
2556 Arc::new(StringArray::from(vec!["foo"])) as ArrayRef,
2557 )])
2558 .unwrap();
2559 let partitioning = Partitioning::Hash(
2560 vec![Arc::new(crate::expressions::Column::new("a", 0))],
2561 2,
2562 );
2563 let schema = batch.schema();
2564 let input = MockExec::new(vec![Ok(batch)], schema);
2565 let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
2566 let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
2567 let batch0 = crate::common::collect(output_stream0).await.unwrap();
2568 let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
2569 let batch1 = crate::common::collect(output_stream1).await.unwrap();
2570 assert!(batch0.is_empty() || batch1.is_empty());
2571 Ok(())
2572 }
2573
2574 #[tokio::test]
2575 async fn repartition_with_spilling() -> Result<()> {
2576 let schema = test_schema();
2578 let partition = create_vec_batches(50);
2579 let input_partitions = vec![partition];
2580 let partitioning = Partitioning::RoundRobinBatch(4);
2581
2582 let runtime = RuntimeEnvBuilder::default()
2584 .with_memory_limit(1, 1.0)
2585 .build_arc()?;
2586
2587 let task_ctx = TaskContext::default().with_runtime(runtime);
2588 let task_ctx = Arc::new(task_ctx);
2589
2590 let exec =
2592 TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
2593 let exec = RepartitionExec::try_new(exec, partitioning)?;
2594
2595 let mut total_rows = 0;
2597 for i in 0..exec.partitioning().partition_count() {
2598 let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
2599 while let Some(result) = stream.next().await {
2600 let batch = result?;
2601 total_rows += batch.num_rows();
2602 }
2603 }
2604
2605 assert_eq!(total_rows, 50 * 8);
2607
2608 let metrics = exec.metrics().unwrap();
2610 assert!(
2611 metrics.spill_count().unwrap() > 0,
2612 "Expected spill_count > 0, but got {:?}",
2613 metrics.spill_count()
2614 );
2615 println!("Spilled {} times", metrics.spill_count().unwrap());
2616 assert!(
2617 metrics.spilled_bytes().unwrap() > 0,
2618 "Expected spilled_bytes > 0, but got {:?}",
2619 metrics.spilled_bytes()
2620 );
2621 println!(
2622 "Spilled {} bytes in {} spills",
2623 metrics.spilled_bytes().unwrap(),
2624 metrics.spill_count().unwrap()
2625 );
2626 assert!(
2627 metrics.spilled_rows().unwrap() > 0,
2628 "Expected spilled_rows > 0, but got {:?}",
2629 metrics.spilled_rows()
2630 );
2631 println!("Spilled {} rows", metrics.spilled_rows().unwrap());
2632
2633 Ok(())
2634 }
2635
2636 #[tokio::test]
2637 async fn repartition_with_partial_spilling() -> Result<()> {
2638 let schema = test_schema();
2640 let partition = create_vec_batches(50);
2641 let input_partitions = vec![partition];
2642 let partitioning = Partitioning::RoundRobinBatch(4);
2643
2644 let runtime = RuntimeEnvBuilder::default()
2648 .with_memory_limit(8 * 1024, 1.0)
2649 .build_arc()?;
2650
2651 let session_config = SessionConfig::new().with_batch_size(1024);
2652 let task_ctx = TaskContext::default()
2653 .with_runtime(runtime)
2654 .with_session_config(session_config);
2655 let task_ctx = Arc::new(task_ctx);
2656
2657 let exec =
2659 TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
2660 let exec = RepartitionExec::try_new(exec, partitioning)?;
2661
2662 let mut total_rows = 0;
2664 for i in 0..exec.partitioning().partition_count() {
2665 let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
2666 while let Some(result) = stream.next().await {
2667 let batch = result?;
2668 total_rows += batch.num_rows();
2669 }
2670 }
2671
2672 assert_eq!(total_rows, 50 * 8);
2674
2675 let metrics = exec.metrics().unwrap();
2677 let spill_count = metrics.spill_count().unwrap();
2678 let spilled_rows = metrics.spilled_rows().unwrap();
2679 let spilled_bytes = metrics.spilled_bytes().unwrap();
2680
2681 assert!(
2682 spill_count > 0,
2683 "Expected some spilling to occur, but got spill_count={spill_count}"
2684 );
2685 assert!(
2686 spilled_rows > 0 && spilled_rows < total_rows,
2687 "Expected partial spilling (0 < spilled_rows < {total_rows}), but got spilled_rows={spilled_rows}"
2688 );
2689 assert!(
2690 spilled_bytes > 0,
2691 "Expected some bytes to be spilled, but got spilled_bytes={spilled_bytes}"
2692 );
2693
2694 println!(
2695 "Partial spilling: spilled {} out of {} rows ({:.1}%) in {} spills, {} bytes",
2696 spilled_rows,
2697 total_rows,
2698 (spilled_rows as f64 / total_rows as f64) * 100.0,
2699 spill_count,
2700 spilled_bytes
2701 );
2702
2703 Ok(())
2704 }
2705
2706 #[tokio::test]
2707 async fn repartition_without_spilling() -> Result<()> {
2708 let schema = test_schema();
2710 let partition = create_vec_batches(50);
2711 let input_partitions = vec![partition];
2712 let partitioning = Partitioning::RoundRobinBatch(4);
2713
2714 let runtime = RuntimeEnvBuilder::default()
2716 .with_memory_limit(10 * 1024 * 1024, 1.0) .build_arc()?;
2718
2719 let task_ctx = TaskContext::default().with_runtime(runtime);
2720 let task_ctx = Arc::new(task_ctx);
2721
2722 let exec =
2724 TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
2725 let exec = RepartitionExec::try_new(exec, partitioning)?;
2726
2727 let mut total_rows = 0;
2729 for i in 0..exec.partitioning().partition_count() {
2730 let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
2731 while let Some(result) = stream.next().await {
2732 let batch = result?;
2733 total_rows += batch.num_rows();
2734 }
2735 }
2736
2737 assert_eq!(total_rows, 50 * 8);
2739
2740 let metrics = exec.metrics().unwrap();
2742 assert_eq!(
2743 metrics.spill_count(),
2744 Some(0),
2745 "Expected no spilling, but got spill_count={:?}",
2746 metrics.spill_count()
2747 );
2748 assert_eq!(
2749 metrics.spilled_bytes(),
2750 Some(0),
2751 "Expected no bytes spilled, but got spilled_bytes={:?}",
2752 metrics.spilled_bytes()
2753 );
2754 assert_eq!(
2755 metrics.spilled_rows(),
2756 Some(0),
2757 "Expected no rows spilled, but got spilled_rows={:?}",
2758 metrics.spilled_rows()
2759 );
2760
2761 println!("No spilling occurred - all data processed in memory");
2762
2763 Ok(())
2764 }
2765
2766 #[tokio::test]
2767 async fn oom() -> Result<()> {
2768 use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode};
2769
2770 let schema = test_schema();
2772 let partition = create_vec_batches(50);
2773 let input_partitions = vec![partition];
2774 let partitioning = Partitioning::RoundRobinBatch(4);
2775
2776 let runtime = RuntimeEnvBuilder::default()
2778 .with_memory_limit(1, 1.0)
2779 .with_disk_manager_builder(
2780 DiskManagerBuilder::default().with_mode(DiskManagerMode::Disabled),
2781 )
2782 .build_arc()?;
2783
2784 let task_ctx = TaskContext::default().with_runtime(runtime);
2785 let task_ctx = Arc::new(task_ctx);
2786
2787 let exec =
2789 TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
2790 let exec = RepartitionExec::try_new(exec, partitioning)?;
2791
2792 for i in 0..exec.partitioning().partition_count() {
2794 let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
2795 let err = stream.next().await.unwrap().unwrap_err();
2796 let err = err.find_root();
2797 assert!(
2798 matches!(err, DataFusionError::ResourcesExhausted(_)),
2799 "Wrong error type: {err}",
2800 );
2801 }
2802
2803 Ok(())
2804 }
2805
2806 fn create_vec_batches(n: usize) -> Vec<RecordBatch> {
2808 let batch = create_batch();
2809 std::iter::repeat_n(batch, n).collect()
2810 }
2811
2812 fn create_batch() -> RecordBatch {
2814 let schema = test_schema();
2815 RecordBatch::try_new(
2816 schema,
2817 vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]))],
2818 )
2819 .unwrap()
2820 }
2821
2822 fn create_ordered_batches(num_batches: usize) -> Vec<RecordBatch> {
2824 let schema = test_schema();
2825 (0..num_batches)
2826 .map(|i| {
2827 let start = (i * 8) as u32;
2828 RecordBatch::try_new(
2829 Arc::clone(&schema),
2830 vec![Arc::new(UInt32Array::from(
2831 (start..start + 8).collect::<Vec<_>>(),
2832 ))],
2833 )
2834 .unwrap()
2835 })
2836 .collect()
2837 }
2838
2839 #[tokio::test]
2840 async fn test_repartition_ordering_with_spilling() -> Result<()> {
2841 let schema = test_schema();
2846 let partition = create_ordered_batches(20);
2849 let input_partitions = vec![partition];
2850
2851 let partitioning = Partitioning::RoundRobinBatch(2);
2853
2854 let runtime = RuntimeEnvBuilder::default()
2856 .with_memory_limit(1, 1.0)
2857 .build_arc()?;
2858
2859 let task_ctx = TaskContext::default().with_runtime(runtime);
2860 let task_ctx = Arc::new(task_ctx);
2861
2862 let exec =
2864 TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
2865 let exec = RepartitionExec::try_new(exec, partitioning)?;
2866
2867 let mut all_batches = Vec::new();
2869 for i in 0..exec.partitioning().partition_count() {
2870 let mut partition_batches = Vec::new();
2871 let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
2872 while let Some(result) = stream.next().await {
2873 let batch = result?;
2874 partition_batches.push(batch);
2875 }
2876 all_batches.push(partition_batches);
2877 }
2878
2879 let metrics = exec.metrics().unwrap();
2881 assert!(
2882 metrics.spill_count().unwrap() > 0,
2883 "Expected spilling to occur, but spill_count = 0"
2884 );
2885
2886 for (partition_idx, batches) in all_batches.iter().enumerate() {
2889 let mut last_value = None;
2890 for batch in batches {
2891 let array = batch
2892 .column(0)
2893 .as_any()
2894 .downcast_ref::<UInt32Array>()
2895 .unwrap();
2896
2897 for i in 0..array.len() {
2898 let value = array.value(i);
2899 if let Some(last) = last_value {
2900 assert!(
2901 value > last,
2902 "Ordering violated in partition {partition_idx}: {value} is not greater than {last}"
2903 );
2904 }
2905 last_value = Some(value);
2906 }
2907 }
2908 }
2909
2910 Ok(())
2911 }
2912}
2913
2914#[cfg(test)]
2915mod test {
2916 use arrow::array::record_batch;
2917 use arrow::compute::SortOptions;
2918 use arrow::datatypes::{DataType, Field, Schema};
2919 use datafusion_common::assert_batches_eq;
2920
2921 use super::*;
2922 use crate::test::TestMemoryExec;
2923 use crate::union::UnionExec;
2924
2925 use datafusion_physical_expr::expressions::col;
2926
2927 macro_rules! assert_plan {
2932 ($PLAN: expr, @ $EXPECTED: expr) => {
2933 let formatted = crate::displayable($PLAN).indent(true).to_string();
2934
2935 insta::assert_snapshot!(
2936 formatted,
2937 @$EXPECTED
2938 );
2939 };
2940 }
2941
2942 #[tokio::test]
2943 async fn test_preserve_order() -> Result<()> {
2944 let schema = test_schema();
2945 let sort_exprs = sort_exprs(&schema);
2946 let source1 = sorted_memory_exec(&schema, sort_exprs.clone());
2947 let source2 = sorted_memory_exec(&schema, sort_exprs);
2948 let union = UnionExec::try_new(vec![source1, source2])?;
2950 let exec = RepartitionExec::try_new(union, Partitioning::RoundRobinBatch(10))?
2951 .with_preserve_order();
2952
2953 assert_plan!(&exec, @r"
2955 RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true, sort_exprs=c0@0 ASC
2956 UnionExec
2957 DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
2958 DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
2959 ");
2960 Ok(())
2961 }
2962
2963 #[tokio::test]
2964 async fn test_preserve_order_one_partition() -> Result<()> {
2965 let schema = test_schema();
2966 let sort_exprs = sort_exprs(&schema);
2967 let source = sorted_memory_exec(&schema, sort_exprs);
2968 let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(10))?
2970 .with_preserve_order();
2971
2972 assert_plan!(&exec, @r"
2974 RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true
2975 DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
2976 ");
2977
2978 Ok(())
2979 }
2980
2981 #[tokio::test]
2982 async fn test_preserve_order_input_not_sorted() -> Result<()> {
2983 let schema = test_schema();
2984 let source1 = memory_exec(&schema);
2985 let source2 = memory_exec(&schema);
2986 let union = UnionExec::try_new(vec![source1, source2])?;
2988 let exec = RepartitionExec::try_new(union, Partitioning::RoundRobinBatch(10))?
2989 .with_preserve_order();
2990
2991 assert_plan!(&exec, @r"
2993 RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2
2994 UnionExec
2995 DataSourceExec: partitions=1, partition_sizes=[0]
2996 DataSourceExec: partitions=1, partition_sizes=[0]
2997 ");
2998 Ok(())
2999 }
3000
3001 #[tokio::test]
3002 async fn test_preserve_order_with_spilling() -> Result<()> {
3003 use datafusion_execution::runtime_env::RuntimeEnvBuilder;
3004
3005 let batch1 = record_batch!(("c0", UInt32, [1, 3])).unwrap();
3009 let batch2 = record_batch!(("c0", UInt32, [2, 4])).unwrap();
3010 let batch3 = record_batch!(("c0", UInt32, [5, 7])).unwrap();
3011 let batch4 = record_batch!(("c0", UInt32, [6, 8])).unwrap();
3012 let batch5 = record_batch!(("c0", UInt32, [9, 11])).unwrap();
3013 let batch6 = record_batch!(("c0", UInt32, [10, 12])).unwrap();
3014 let schema = batch1.schema();
3015 let sort_exprs = LexOrdering::new([PhysicalSortExpr {
3016 expr: col("c0", &schema).unwrap(),
3017 options: SortOptions::default().asc(),
3018 }])
3019 .unwrap();
3020 let partition1 = vec![batch1.clone(), batch3.clone(), batch5.clone()];
3021 let partition2 = vec![batch2.clone(), batch4.clone(), batch6.clone()];
3022 let input_partitions = vec![partition1, partition2];
3023
3024 let runtime = RuntimeEnvBuilder::default()
3027 .with_memory_limit(64, 1.0)
3028 .build_arc()?;
3029
3030 let task_ctx = TaskContext::default().with_runtime(runtime);
3031 let task_ctx = Arc::new(task_ctx);
3032
3033 let exec = TestMemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?
3035 .try_with_sort_information(vec![sort_exprs.clone(), sort_exprs])?;
3036 let exec = Arc::new(exec);
3037 let exec = Arc::new(TestMemoryExec::update_cache(&exec));
3038 let exec = RepartitionExec::try_new(exec, Partitioning::RoundRobinBatch(3))?
3041 .with_preserve_order();
3042
3043 let mut batches = vec![];
3044
3045 for i in 0..exec.partitioning().partition_count() {
3047 let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
3048 while let Some(result) = stream.next().await {
3049 let batch = result?;
3050 batches.push(batch);
3051 }
3052 }
3053
3054 #[rustfmt::skip]
3055 let expected = [
3056 [
3057 "+----+",
3058 "| c0 |",
3059 "+----+",
3060 "| 1 |",
3061 "| 2 |",
3062 "| 3 |",
3063 "| 4 |",
3064 "+----+",
3065 ],
3066 [
3067 "+----+",
3068 "| c0 |",
3069 "+----+",
3070 "| 5 |",
3071 "| 6 |",
3072 "| 7 |",
3073 "| 8 |",
3074 "+----+",
3075 ],
3076 [
3077 "+----+",
3078 "| c0 |",
3079 "+----+",
3080 "| 9 |",
3081 "| 10 |",
3082 "| 11 |",
3083 "| 12 |",
3084 "+----+",
3085 ],
3086 ];
3087
3088 for (batch, expected) in batches.iter().zip(expected.iter()) {
3089 assert_batches_eq!(expected, std::slice::from_ref(batch));
3090 }
3091
3092 let all_batches = [batch1, batch2, batch3, batch4, batch5, batch6];
3096 let metrics = exec.metrics().unwrap();
3097 assert!(
3098 metrics.spill_count().unwrap() > input_partitions.len(),
3099 "Expected spill_count > {} for order-preserving repartition, but got {:?}",
3100 input_partitions.len(),
3101 metrics.spill_count()
3102 );
3103 assert!(
3104 metrics.spilled_bytes().unwrap()
3105 > all_batches
3106 .iter()
3107 .map(|b| b.get_array_memory_size())
3108 .sum::<usize>(),
3109 "Expected spilled_bytes > {} for order-preserving repartition, got {}",
3110 all_batches
3111 .iter()
3112 .map(|b| b.get_array_memory_size())
3113 .sum::<usize>(),
3114 metrics.spilled_bytes().unwrap()
3115 );
3116 assert!(
3117 metrics.spilled_rows().unwrap()
3118 >= all_batches.iter().map(|b| b.num_rows()).sum::<usize>(),
3119 "Expected spilled_rows > {} for order-preserving repartition, got {}",
3120 all_batches.iter().map(|b| b.num_rows()).sum::<usize>(),
3121 metrics.spilled_rows().unwrap()
3122 );
3123
3124 Ok(())
3125 }
3126
3127 #[tokio::test]
3128 async fn test_hash_partitioning_with_spilling() -> Result<()> {
3129 use datafusion_execution::runtime_env::RuntimeEnvBuilder;
3130
3131 let batch1 = record_batch!(("c0", UInt32, [1, 3])).unwrap();
3133 let batch2 = record_batch!(("c0", UInt32, [2, 4])).unwrap();
3134 let batch3 = record_batch!(("c0", UInt32, [5, 7])).unwrap();
3135 let batch4 = record_batch!(("c0", UInt32, [6, 8])).unwrap();
3136 let schema = batch1.schema();
3137
3138 let partition1 = vec![batch1.clone(), batch3.clone()];
3139 let partition2 = vec![batch2.clone(), batch4.clone()];
3140 let input_partitions = vec![partition1, partition2];
3141
3142 let runtime = RuntimeEnvBuilder::default()
3144 .with_memory_limit(1, 1.0)
3145 .build_arc()?;
3146
3147 let task_ctx = TaskContext::default().with_runtime(runtime);
3148 let task_ctx = Arc::new(task_ctx);
3149
3150 let exec = TestMemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?;
3152 let exec = Arc::new(exec);
3153 let exec = Arc::new(TestMemoryExec::update_cache(&exec));
3154 let hash_expr = col("c0", &schema)?;
3156 let exec =
3157 RepartitionExec::try_new(exec, Partitioning::Hash(vec![hash_expr], 2))?;
3158
3159 let mut join_set = tokio::task::JoinSet::new();
3162 for i in 0..exec.partitioning().partition_count() {
3163 let stream = exec.execute(i, Arc::clone(&task_ctx))?;
3164 join_set.spawn(async move {
3165 let mut count = 0;
3166 futures::pin_mut!(stream);
3167 while let Some(result) = stream.next().await {
3168 let batch = result?;
3169 count += batch.num_rows();
3170 }
3171 Ok::<usize, DataFusionError>(count)
3172 });
3173 }
3174
3175 let mut total_rows = 0;
3177 while let Some(result) = join_set.join_next().await {
3178 total_rows += result.unwrap()?;
3179 }
3180
3181 let all_batches = [batch1, batch2, batch3, batch4];
3183 let expected_rows: usize = all_batches.iter().map(|b| b.num_rows()).sum();
3184 assert_eq!(total_rows, expected_rows);
3185
3186 let metrics = exec.metrics().unwrap();
3188 let spill_count = metrics.spill_count().unwrap_or(0);
3190 assert!(spill_count > 0);
3191 let spilled_bytes = metrics.spilled_bytes().unwrap_or(0);
3192 assert!(spilled_bytes > 0);
3193 let spilled_rows = metrics.spilled_rows().unwrap_or(0);
3194 assert!(spilled_rows > 0);
3195
3196 Ok(())
3197 }
3198
3199 #[tokio::test]
3200 async fn test_repartition() -> Result<()> {
3201 let schema = test_schema();
3202 let sort_exprs = sort_exprs(&schema);
3203 let source = sorted_memory_exec(&schema, sort_exprs);
3204 let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(10))?
3206 .repartitioned(20, &Default::default())?
3207 .unwrap();
3208
3209 assert_plan!(exec.as_ref(), @r"
3211 RepartitionExec: partitioning=RoundRobinBatch(20), input_partitions=1, maintains_sort_order=true
3212 DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
3213 ");
3214 Ok(())
3215 }
3216
3217 fn test_schema() -> Arc<Schema> {
3218 Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
3219 }
3220
3221 fn sort_exprs(schema: &Schema) -> LexOrdering {
3222 [PhysicalSortExpr {
3223 expr: col("c0", schema).unwrap(),
3224 options: SortOptions::default(),
3225 }]
3226 .into()
3227 }
3228
3229 fn memory_exec(schema: &SchemaRef) -> Arc<dyn ExecutionPlan> {
3230 TestMemoryExec::try_new_exec(&[vec![]], Arc::clone(schema), None).unwrap()
3231 }
3232
3233 fn sorted_memory_exec(
3234 schema: &SchemaRef,
3235 sort_exprs: LexOrdering,
3236 ) -> Arc<dyn ExecutionPlan> {
3237 let exec = TestMemoryExec::try_new(&[vec![]], Arc::clone(schema), None)
3238 .unwrap()
3239 .try_with_sort_information(vec![sort_exprs])
3240 .unwrap();
3241 let exec = Arc::new(exec);
3242 Arc::new(TestMemoryExec::update_cache(&exec))
3243 }
3244}