1use std::fmt::{Debug, Formatter};
23use std::pin::Pin;
24use std::sync::Arc;
25use std::task::{Context, Poll};
26use std::{any::Any, vec};
27
28use super::common::SharedMemoryReservation;
29use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
30use super::{
31 DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream,
32};
33use crate::coalesce::LimitedBatchCoalescer;
34use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType};
35use crate::hash_utils::create_hashes;
36use crate::metrics::{BaselineMetrics, SpillMetrics};
37use crate::projection::{ProjectionExec, all_columns, make_with_child, update_expr};
38use crate::sorts::streaming_merge::StreamingMergeBuilder;
39use crate::spill::spill_manager::SpillManager;
40use crate::spill::spill_pool::{self, SpillPoolWriter};
41use crate::stream::RecordBatchStreamAdapter;
42use crate::{
43 DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Statistics,
44 check_if_same_properties,
45};
46
47use arrow::array::{PrimitiveArray, RecordBatch, RecordBatchOptions};
48use arrow::compute::take_arrays;
49use arrow::datatypes::{SchemaRef, UInt32Type};
50use datafusion_common::config::ConfigOptions;
51use datafusion_common::stats::Precision;
52use datafusion_common::utils::transpose;
53use datafusion_common::{
54 ColumnStatistics, DataFusionError, HashMap, assert_or_internal_err,
55 internal_datafusion_err, internal_err,
56};
57use datafusion_common::{Result, not_impl_err};
58use datafusion_common_runtime::SpawnedTask;
59use datafusion_execution::TaskContext;
60use datafusion_execution::memory_pool::MemoryConsumer;
61use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr};
62use datafusion_physical_expr_common::sort_expr::LexOrdering;
63
64use crate::filter_pushdown::{
65 ChildPushdownResult, FilterDescription, FilterPushdownPhase,
66 FilterPushdownPropagation,
67};
68use crate::joins::SeededRandomState;
69use crate::sort_pushdown::SortOrderPushdownResult;
70use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
71use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays;
72use futures::stream::Stream;
73use futures::{FutureExt, StreamExt, TryStreamExt, ready};
74use log::trace;
75use parking_lot::Mutex;
76
77mod distributor_channels;
78use distributor_channels::{
79 DistributionReceiver, DistributionSender, channels, partition_aware_channels,
80};
81
/// A batch traveling through a repartition channel.
///
/// A `Memory` batch is held in memory and its size is accounted against the
/// output partition's shared memory reservation; when the reservation cannot
/// grow, the batch is written to the spill pool instead and only a `Spilled`
/// marker is sent, telling the reader to pull the batch from its spill stream.
#[derive(Debug)]
enum RepartitionBatch {
    /// Batch kept in memory (size charged to the shared reservation)
    Memory(RecordBatch),
    /// Marker: the corresponding batch was written to the spill pool
    Spilled,
}

/// Channel item: `Some(Ok(..))` carries a batch/marker, `Some(Err(..))`
/// propagates an error, and `None` signals end-of-stream from one input.
type MaybeBatch = Option<Result<RepartitionBatch>>;
/// One sender per input partition, all feeding a single output partition.
type InputPartitionsToCurrentPartitionSender = Vec<DistributionSender<MaybeBatch>>;
/// One receiver per input partition for a single output partition.
type InputPartitionsToCurrentPartitionReceiver = Vec<DistributionReceiver<MaybeBatch>>;
144
/// Everything an input task needs to deliver batches to one output partition.
struct OutputChannel {
    /// Sends batches (or spill markers) to the output partition's stream
    sender: DistributionSender<MaybeBatch>,
    /// Memory reservation charged for in-memory batches in flight
    reservation: SharedMemoryReservation,
    /// Destination for batches that do not fit in the reservation
    spill_writer: SpillPoolWriter,
}
151
/// Per-output-partition channel state, created once when input streams are
/// consumed and handed out to the corresponding output stream on `execute`.
struct PartitionChannels {
    /// Senders, one per input partition (shared clones when order is not preserved)
    tx: InputPartitionsToCurrentPartitionSender,
    /// Receivers; one per input partition when order is preserved, else a single one
    rx: InputPartitionsToCurrentPartitionReceiver,
    /// Reservation shared between senders (grow) and the receiver (shrink)
    reservation: SharedMemoryReservation,
    /// Spill writers paired 1:1 with `spill_readers`
    spill_writers: Vec<SpillPoolWriter>,
    /// Streams that replay spilled batches in write order
    spill_readers: Vec<SendableRecordBatchStream>,
}
187
/// State once the input streams are being actively consumed by spawned tasks.
struct ConsumingInputStreamsState {
    /// Channels keyed by output partition; entries are removed as each
    /// output partition's `execute` claims its channel
    channels: HashMap<usize, PartitionChannels>,

    /// Handles to the spawned repartition tasks; dropping the last clone
    /// aborts them
    abort_helper: Arc<Vec<SpawnedTask<()>>>,
}
196
197impl Debug for ConsumingInputStreamsState {
198 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
199 f.debug_struct("ConsumingInputStreamsState")
200 .field("num_channels", &self.channels.len())
201 .field("abort_helper", &self.abort_helper)
202 .finish()
203 }
204}
205
/// Lazily-built execution state of a [`RepartitionExec`], advanced the first
/// time any output partition is executed / polled.
#[derive(Default)]
enum RepartitionExecState {
    /// No output partition executed yet
    #[default]
    NotInitialized,
    /// Input streams have been created (with their metrics) but no task is
    /// consuming them yet
    InputStreamsInitialized(Vec<(SendableRecordBatchStream, RepartitionMetrics)>),
    /// Spawned tasks are pulling from the inputs and feeding the channels
    ConsumingInputStreams(ConsumingInputStreamsState),
}
221
222impl Debug for RepartitionExecState {
223 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
224 match self {
225 RepartitionExecState::NotInitialized => write!(f, "NotInitialized"),
226 RepartitionExecState::InputStreamsInitialized(v) => {
227 write!(f, "InputStreamsInitialized({:?})", v.len())
228 }
229 RepartitionExecState::ConsumingInputStreams(v) => {
230 write!(f, "ConsumingInputStreams({v:?})")
231 }
232 }
233 }
234}
235
impl RepartitionExecState {
    /// Executes every input partition exactly once, recording per-partition
    /// fetch time. Idempotent: does nothing unless still `NotInitialized`.
    fn ensure_input_streams_initialized(
        &mut self,
        input: &Arc<dyn ExecutionPlan>,
        metrics: &ExecutionPlanMetricsSet,
        output_partitions: usize,
        ctx: &Arc<TaskContext>,
    ) -> Result<()> {
        if !matches!(self, RepartitionExecState::NotInitialized) {
            return Ok(());
        }

        let num_input_partitions = input.output_partitioning().partition_count();
        let mut streams_and_metrics = Vec::with_capacity(num_input_partitions);

        for i in 0..num_input_partitions {
            let metrics = RepartitionMetrics::new(i, output_partitions, metrics);

            // Time spent setting up the child stream counts as fetch time.
            let timer = metrics.fetch_time.timer();
            let stream = input.execute(i, Arc::clone(ctx))?;
            timer.done();

            streams_and_metrics.push((stream, metrics));
        }
        *self = RepartitionExecState::InputStreamsInitialized(streams_and_metrics);
        Ok(())
    }

    /// Transitions to `ConsumingInputStreams`: builds the per-output-partition
    /// channels (partition-aware when order must be preserved), registers one
    /// spillable memory reservation and spill pool per output partition, and
    /// spawns one task per input stream that repartitions its batches into the
    /// channels. Idempotent: returns the existing state if already consuming.
    #[expect(clippy::too_many_arguments)]
    fn consume_input_streams(
        &mut self,
        input: &Arc<dyn ExecutionPlan>,
        metrics: &ExecutionPlanMetricsSet,
        partitioning: &Partitioning,
        preserve_order: bool,
        name: &str,
        context: &Arc<TaskContext>,
        spill_manager: SpillManager,
    ) -> Result<&mut ConsumingInputStreamsState> {
        let streams_and_metrics = match self {
            RepartitionExecState::NotInitialized => {
                self.ensure_input_streams_initialized(
                    input,
                    metrics,
                    partitioning.partition_count(),
                    context,
                )?;
                let RepartitionExecState::InputStreamsInitialized(value) = self else {
                    return internal_err!(
                        "Programming error: RepartitionExecState must be in the InputStreamsInitialized state after calling RepartitionExecState::ensure_input_streams_initialized"
                    );
                };
                value
            }
            RepartitionExecState::ConsumingInputStreams(value) => return Ok(value),
            RepartitionExecState::InputStreamsInitialized(value) => value,
        };

        let num_input_partitions = streams_and_metrics.len();
        let num_output_partitions = partitioning.partition_count();

        let spill_manager = Arc::new(spill_manager);

        let (txs, rxs) = if preserve_order {
            // Order-preserving: a dedicated (input, output) channel pair so
            // each input's ordering survives and can be merge-sorted later.
            let (txs_all, rxs_all) =
                partition_aware_channels(num_input_partitions, num_output_partitions);
            // Transpose so the outer index is the output partition.
            let txs = transpose(txs_all);
            let rxs = transpose(rxs_all);
            (txs, rxs)
        } else {
            // One channel per output partition; every input shares a clone of
            // the sender and the output has a single receiver.
            let (txs, rxs) = channels(num_output_partitions);
            let txs = txs
                .into_iter()
                .map(|item| vec![item; num_input_partitions])
                .collect::<Vec<_>>();
            let rxs = rxs.into_iter().map(|item| vec![item]).collect::<Vec<_>>();
            (txs, rxs)
        };

        let mut channels = HashMap::with_capacity(txs.len());
        for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() {
            // Spillable reservation shared by senders (grow) and receiver (shrink).
            let reservation = Arc::new(Mutex::new(
                MemoryConsumer::new(format!("{name}[{partition}]"))
                    .with_can_spill(true)
                    .register(context.memory_pool()),
            ));

            let max_file_size = context
                .session_config()
                .options()
                .execution
                .max_spill_file_size_bytes;
            // Order preservation needs one spill channel per input so spilled
            // batches replay in per-input order; otherwise one is enough.
            let num_spill_channels = if preserve_order {
                num_input_partitions
            } else {
                1
            };
            let (spill_writers, spill_readers): (Vec<_>, Vec<_>) = (0
                ..num_spill_channels)
                .map(|_| spill_pool::channel(max_file_size, Arc::clone(&spill_manager)))
                .unzip();

            channels.insert(
                partition,
                PartitionChannels {
                    tx,
                    rx,
                    reservation,
                    spill_readers,
                    spill_writers,
                },
            );
        }

        // Spawn one task per input partition that repartitions its stream.
        let mut spawned_tasks = Vec::with_capacity(num_input_partitions);
        for (i, (stream, metrics)) in
            std::mem::take(streams_and_metrics).into_iter().enumerate()
        {
            let txs: HashMap<_, _> = channels
                .iter()
                .map(|(partition, channels)| {
                    // With a single shared spill channel, all inputs use index 0.
                    let spill_writer_idx = if preserve_order { i } else { 0 };
                    (
                        *partition,
                        OutputChannel {
                            sender: channels.tx[i].clone(),
                            reservation: Arc::clone(&channels.reservation),
                            spill_writer: channels.spill_writers[spill_writer_idx]
                                .clone(),
                        },
                    )
                })
                .collect();

            // Plain senders for the wrapper task to report errors / EOF on.
            let senders: HashMap<_, _> = txs
                .iter()
                .map(|(partition, channel)| (*partition, channel.sender.clone()))
                .collect();

            let input_task = SpawnedTask::spawn(RepartitionExec::pull_from_input(
                stream,
                txs,
                partitioning.clone(),
                metrics,
                if preserve_order { 0 } else { i },
                num_input_partitions,
            ));

            // Wrapper task that forwards the input task's outcome downstream.
            let wait_for_task =
                SpawnedTask::spawn(RepartitionExec::wait_for_task(input_task, senders));
            spawned_tasks.push(wait_for_task);
        }
        *self = Self::ConsumingInputStreams(ConsumingInputStreamsState {
            channels,
            abort_helper: Arc::new(spawned_tasks),
        });
        match self {
            RepartitionExecState::ConsumingInputStreams(value) => Ok(value),
            _ => unreachable!(),
        }
    }
}
416
/// Partitions `RecordBatch`es according to a [`Partitioning`] scheme,
/// timing the partitioning work via `timer`.
pub struct BatchPartitioner {
    // Scheme-specific state (hash buffers or round-robin cursor)
    state: BatchPartitionerState,
    // Accumulates time spent computing the partitioning
    timer: metrics::Time,
}
422
/// Internal state for each supported partitioning scheme.
enum BatchPartitionerState {
    Hash {
        /// Expressions hashed to pick a partition per row
        exprs: Vec<Arc<dyn PhysicalExpr>>,
        num_partitions: usize,
        /// Reused scratch buffer of per-row hashes
        hash_buffer: Vec<u64>,
        /// Reused per-partition row-index buffers
        indices: Vec<Vec<u32>>,
    },
    RoundRobin {
        num_partitions: usize,
        /// Next output partition to receive a whole batch
        next_idx: usize,
    },
}
435
/// Fixed-seed hash state so repartitioning is deterministic across runs and
/// consistent between cooperating operators (e.g. hash joins).
pub const REPARTITION_RANDOM_STATE: SeededRandomState =
    SeededRandomState::with_seeds(0, 0, 0, 0);
440
441impl BatchPartitioner {
442 pub fn new_hash_partitioner(
452 exprs: Vec<Arc<dyn PhysicalExpr>>,
453 num_partitions: usize,
454 timer: metrics::Time,
455 ) -> Self {
456 Self {
457 state: BatchPartitionerState::Hash {
458 exprs,
459 num_partitions,
460 hash_buffer: vec![],
461 indices: vec![vec![]; num_partitions],
462 },
463 timer,
464 }
465 }
466
467 pub fn new_round_robin_partitioner(
479 num_partitions: usize,
480 timer: metrics::Time,
481 input_partition: usize,
482 num_input_partitions: usize,
483 ) -> Self {
484 Self {
485 state: BatchPartitionerState::RoundRobin {
486 num_partitions,
487 next_idx: (input_partition * num_partitions) / num_input_partitions,
488 },
489 timer,
490 }
491 }
492 pub fn try_new(
506 partitioning: Partitioning,
507 timer: metrics::Time,
508 input_partition: usize,
509 num_input_partitions: usize,
510 ) -> Result<Self> {
511 match partitioning {
512 Partitioning::Hash(exprs, num_partitions) => {
513 Ok(Self::new_hash_partitioner(exprs, num_partitions, timer))
514 }
515 Partitioning::RoundRobinBatch(num_partitions) => {
516 Ok(Self::new_round_robin_partitioner(
517 num_partitions,
518 timer,
519 input_partition,
520 num_input_partitions,
521 ))
522 }
523 other => {
524 not_impl_err!("Unsupported repartitioning scheme {other:?}")
525 }
526 }
527 }
528
529 pub fn partition<F>(&mut self, batch: RecordBatch, mut f: F) -> Result<()>
539 where
540 F: FnMut(usize, RecordBatch) -> Result<()>,
541 {
542 self.partition_iter(batch)?.try_for_each(|res| match res {
543 Ok((partition, batch)) => f(partition, batch),
544 Err(e) => Err(e),
545 })
546 }
547
548 fn partition_iter(
554 &mut self,
555 batch: RecordBatch,
556 ) -> Result<impl Iterator<Item = Result<(usize, RecordBatch)>> + Send + '_> {
557 let it: Box<dyn Iterator<Item = Result<(usize, RecordBatch)>> + Send> =
558 match &mut self.state {
559 BatchPartitionerState::RoundRobin {
560 num_partitions,
561 next_idx,
562 } => {
563 let idx = *next_idx;
564 *next_idx = (*next_idx + 1) % *num_partitions;
565 Box::new(std::iter::once(Ok((idx, batch))))
566 }
567 BatchPartitionerState::Hash {
568 exprs,
569 num_partitions: partitions,
570 hash_buffer,
571 indices,
572 } => {
573 let timer = self.timer.timer();
575
576 let arrays =
577 evaluate_expressions_to_arrays(exprs.as_slice(), &batch)?;
578
579 hash_buffer.clear();
580 hash_buffer.resize(batch.num_rows(), 0);
581
582 create_hashes(
583 &arrays,
584 REPARTITION_RANDOM_STATE.random_state(),
585 hash_buffer,
586 )?;
587
588 indices.iter_mut().for_each(|v| v.clear());
589
590 for (index, hash) in hash_buffer.iter().enumerate() {
591 indices[(*hash % *partitions as u64) as usize].push(index as u32);
592 }
593
594 timer.done();
596
597 let partitioner_timer = &self.timer;
599
600 let mut partitioned_batches = vec![];
601 for (partition, p_indices) in indices.iter_mut().enumerate() {
602 if !p_indices.is_empty() {
603 let taken_indices = std::mem::take(p_indices);
604 let indices_array: PrimitiveArray<UInt32Type> =
605 taken_indices.into();
606
607 let _timer = partitioner_timer.timer();
609
610 let columns =
612 take_arrays(batch.columns(), &indices_array, None)?;
613
614 let mut options = RecordBatchOptions::new();
615 options = options.with_row_count(Some(indices_array.len()));
616 let batch = RecordBatch::try_new_with_options(
617 batch.schema(),
618 columns,
619 &options,
620 )
621 .unwrap();
622
623 partitioned_batches.push(Ok((partition, batch)));
624
625 let (_, buffer, _) = indices_array.into_parts();
627 let mut vec =
628 buffer.into_inner().into_vec::<u32>().map_err(|e| {
629 internal_datafusion_err!(
630 "Could not convert buffer to vec: {e:?}"
631 )
632 })?;
633 vec.clear();
634 *p_indices = vec;
635 }
636 }
637
638 Box::new(partitioned_batches.into_iter())
639 }
640 };
641
642 Ok(it)
643 }
644
645 fn num_partitions(&self) -> usize {
647 match self.state {
648 BatchPartitionerState::RoundRobin { num_partitions, .. } => num_partitions,
649 BatchPartitionerState::Hash { num_partitions, .. } => num_partitions,
650 }
651 }
652}
653
/// Maps N input partitions to M output partitions, repartitioning batches
/// according to a [`Partitioning`] scheme; can optionally preserve the input
/// sort order by merge-sorting the per-input streams.
#[derive(Debug, Clone)]
pub struct RepartitionExec {
    /// The child plan whose output is repartitioned
    input: Arc<dyn ExecutionPlan>,
    /// Lazily-initialized execution state shared by all output partitions
    state: Arc<Mutex<RepartitionExecState>>,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Whether to merge per-input streams to preserve a sorted input order
    preserve_order: bool,
    /// Cached plan properties (equivalences, partitioning, boundedness)
    cache: Arc<PlanProperties>,
}
771
/// Timing metrics for one input partition of a [`RepartitionExec`].
#[derive(Debug, Clone)]
struct RepartitionMetrics {
    /// Time pulling batches from the input stream
    fetch_time: metrics::Time,
    /// Time computing the partitioning (hashing / index building / take)
    repartition_time: metrics::Time,
    /// Time sending to each output partition, indexed by output partition
    send_time: Vec<metrics::Time>,
}
783
784impl RepartitionMetrics {
785 pub fn new(
786 input_partition: usize,
787 num_output_partitions: usize,
788 metrics: &ExecutionPlanMetricsSet,
789 ) -> Self {
790 let fetch_time =
792 MetricBuilder::new(metrics).subset_time("fetch_time", input_partition);
793
794 let repartition_time =
796 MetricBuilder::new(metrics).subset_time("repartition_time", input_partition);
797
798 let send_time = (0..num_output_partitions)
800 .map(|output_partition| {
801 let label =
802 metrics::Label::new("outputPartition", output_partition.to_string());
803 MetricBuilder::new(metrics)
804 .with_label(label)
805 .subset_time("send_time", input_partition)
806 })
807 .collect();
808
809 Self {
810 fetch_time,
811 repartition_time,
812 send_time,
813 }
814 }
815}
816
impl RepartitionExec {
    /// Returns the child execution plan.
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

    /// Returns the output partitioning scheme.
    pub fn partitioning(&self) -> &Partitioning {
        &self.cache.partitioning
    }

    /// Returns true if the output streams are merged to preserve the input
    /// sort order.
    pub fn preserve_order(&self) -> bool {
        self.preserve_order
    }

    /// Operator display name.
    pub fn name(&self) -> &str {
        "RepartitionExec"
    }

    /// Clones self with a replacement child, resetting metrics and execution
    /// state while keeping the cached plan properties unchanged.
    fn with_new_children_and_same_properties(
        &self,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Self {
        Self {
            input: children.swap_remove(0),
            metrics: ExecutionPlanMetricsSet::new(),
            state: Default::default(),
            ..Self::clone(self)
        }
    }
}
851
impl DisplayAs for RepartitionExec {
    /// Formats the operator for EXPLAIN output in the requested style.
    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
        let input_partition_count = self.input.output_partitioning().partition_count();
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(
                    f,
                    "{}: partitioning={}, input_partitions={}",
                    self.name(),
                    self.partitioning(),
                    input_partition_count,
                )?;

                if self.preserve_order {
                    write!(f, ", preserve_order=true")?;
                } else if input_partition_count <= 1
                    && self.input.output_ordering().is_some()
                {
                    // A single sorted input stays sorted even without merging.
                    write!(f, ", maintains_sort_order=true")?;
                }

                if let Some(sort_exprs) = self.sort_exprs() {
                    write!(f, ", sort_exprs={}", sort_exprs.clone())?;
                }
                Ok(())
            }
            DisplayFormatType::TreeRender => {
                writeln!(f, "partitioning_scheme={}", self.partitioning(),)?;
                let output_partition_count = self.partitioning().partition_count();
                let input_to_output_partition_str =
                    format!("{input_partition_count} -> {output_partition_count}");
                writeln!(
                    f,
                    "partition_count(in->out)={input_to_output_partition_str}"
                )?;

                if self.preserve_order {
                    writeln!(f, "preserve_order={}", self.preserve_order)?;
                }
                Ok(())
            }
        }
    }
}
898
impl ExecutionPlan for RepartitionExec {
    fn name(&self) -> &'static str {
        "RepartitionExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &Arc<PlanProperties> {
        &self.cache
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    /// Rebuilds the node with a new child, re-deriving properties and
    /// re-applying order preservation.
    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        check_if_same_properties!(self, children);
        let mut repartition = RepartitionExec::try_new(
            children.swap_remove(0),
            self.partitioning().clone(),
        )?;
        if self.preserve_order {
            repartition = repartition.with_preserve_order();
        }
        Ok(Arc::new(repartition))
    }

    // Only hash repartitioning does actual per-row work that parallelizes.
    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![matches!(self.partitioning(), Partitioning::Hash(_, _))]
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        Self::maintains_input_order_helper(self.input(), self.preserve_order)
    }

    /// Returns the stream for one output partition.
    ///
    /// The first call initializes the shared input streams; the first stream
    /// to be *polled* spawns the repartition tasks (work happens lazily
    /// inside `stream::once`). With `preserve_order`, the per-input streams
    /// are merge-sorted; otherwise a single channel/spill pair is drained.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        trace!(
            "Start {}::execute for partition: {}",
            self.name(),
            partition
        );

        let spill_metrics = SpillMetrics::new(&self.metrics, partition);

        let input = Arc::clone(&self.input);
        let partitioning = self.partitioning().clone();
        let metrics = self.metrics.clone();
        let preserve_order = self.sort_exprs().is_some();
        let name = self.name().to_owned();
        let schema = self.schema();
        let schema_captured = Arc::clone(&schema);

        let spill_manager = SpillManager::new(
            Arc::clone(&context.runtime_env()),
            spill_metrics,
            input.schema(),
        );

        let sort_exprs = self.sort_exprs().cloned();

        // Eagerly create the input streams if no one else holds the lock;
        // otherwise another partition is doing (or has done) it.
        let state = Arc::clone(&self.state);
        if let Some(mut state) = state.try_lock() {
            state.ensure_input_streams_initialized(
                &input,
                &metrics,
                partitioning.partition_count(),
                &context,
            )?;
        }

        let num_input_partitions = input.output_partitioning().partition_count();

        let stream = futures::stream::once(async move {
            // Claim this output partition's channels; also clone the abort
            // helper so the repartition tasks stay alive while we read.
            let (rx, reservation, spill_readers, abort_helper) = {
                let mut state = state.lock();
                let state = state.consume_input_streams(
                    &input,
                    &metrics,
                    &partitioning,
                    preserve_order,
                    &name,
                    &context,
                    spill_manager.clone(),
                )?;

                let PartitionChannels {
                    rx,
                    reservation,
                    spill_readers,
                    ..
                } = state
                    .channels
                    .remove(&partition)
                    .expect("partition not used yet");

                (
                    rx,
                    reservation,
                    spill_readers,
                    Arc::clone(&state.abort_helper),
                )
            };

            trace!(
                "Before returning stream in {name}::execute for partition: {partition}"
            );

            if preserve_order {
                // One stream per input partition, merge-sorted together.
                let input_streams = rx
                    .into_iter()
                    .zip(spill_readers)
                    .map(|(receiver, spill_stream)| {
                        Box::pin(PerPartitionStream::new(
                            Arc::clone(&schema_captured),
                            receiver,
                            Arc::clone(&abort_helper),
                            Arc::clone(&reservation),
                            spill_stream,
                            // Each stream tracks exactly one input partition.
                            1,
                            BaselineMetrics::new(&metrics, partition),
                            // No coalescing: the merge handles batch sizing.
                            None,
                        )) as SendableRecordBatchStream
                    })
                    .collect::<Vec<_>>();
                let fetch = None;
                let merge_reservation =
                    MemoryConsumer::new(format!("{name}[Merge {partition}]"))
                        .register(context.memory_pool());
                StreamingMergeBuilder::new()
                    .with_streams(input_streams)
                    .with_schema(schema_captured)
                    .with_expressions(&sort_exprs.unwrap())
                    .with_metrics(BaselineMetrics::new(&metrics, partition))
                    .with_batch_size(context.session_config().batch_size())
                    .with_fetch(fetch)
                    .with_reservation(merge_reservation)
                    .with_spill_manager(spill_manager)
                    .build()
            } else {
                // Single shared channel and spill stream for all inputs.
                let spill_stream = spill_readers
                    .into_iter()
                    .next()
                    .expect("at least one spill reader should exist");

                Ok(Box::pin(PerPartitionStream::new(
                    schema_captured,
                    rx.into_iter()
                        .next()
                        .expect("at least one receiver should exist"),
                    abort_helper,
                    reservation,
                    spill_stream,
                    num_input_partitions,
                    BaselineMetrics::new(&metrics, partition),
                    Some(context.session_config().batch_size()),
                )) as SendableRecordBatchStream)
            }
        })
        .try_flatten();
        let stream = RecordBatchStreamAdapter::new(schema, stream);
        Ok(Box::pin(stream))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    /// Per-partition statistics: rows/bytes are divided evenly across output
    /// partitions (inexact); column statistics are dropped since row
    /// distribution is unknown.
    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
        if let Some(partition) = partition {
            let partition_count = self.partitioning().partition_count();
            if partition_count == 0 {
                return Ok(Statistics::new_unknown(&self.schema()));
            }

            assert_or_internal_err!(
                partition < partition_count,
                "RepartitionExec invalid partition {} (expected less than {})",
                partition,
                partition_count
            );

            let mut stats = self.input.partition_statistics(None)?;

            stats.num_rows = stats
                .num_rows
                .get_value()
                .map(|rows| Precision::Inexact(rows / partition_count))
                .unwrap_or(Precision::Absent);
            stats.total_byte_size = stats
                .total_byte_size
                .get_value()
                .map(|bytes| Precision::Inexact(bytes / partition_count))
                .unwrap_or(Precision::Absent);

            stats.column_statistics = stats
                .column_statistics
                .iter()
                .map(|_| ColumnStatistics::new_unknown())
                .collect();

            Ok(stats)
        } else {
            self.input.partition_statistics(None)
        }
    }

    // Repartitioning moves rows between partitions but never adds or drops any.
    fn cardinality_effect(&self) -> CardinalityEffect {
        CardinalityEffect::Equal
    }

    /// Pushes a narrowing, column-only projection below this node, rewriting
    /// hash expressions to the projected columns. Returns `None` when the
    /// swap is not beneficial or not expressible.
    fn try_swapping_with_projection(
        &self,
        projection: &ProjectionExec,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        // Only worthwhile if the projection actually narrows the schema.
        if projection.expr().len() >= projection.input().schema().fields().len() {
            return Ok(None);
        }

        if projection.benefits_from_input_partitioning()[0]
            || !all_columns(projection.expr())
        {
            return Ok(None);
        }

        let new_projection = make_with_child(projection, self.input())?;

        let new_partitioning = match self.partitioning() {
            Partitioning::Hash(partitions, size) => {
                let mut new_partitions = vec![];
                for partition in partitions {
                    let Some(new_partition) =
                        update_expr(partition, projection.expr(), false)?
                    else {
                        // Hash expr not expressible after projection: abort.
                        return Ok(None);
                    };
                    new_partitions.push(new_partition);
                }
                Partitioning::Hash(new_partitions, *size)
            }
            others => others.clone(),
        };

        Ok(Some(Arc::new(RepartitionExec::try_new(
            new_projection,
            new_partitioning,
        )?)))
    }

    // Filters pass straight through to the child unchanged.
    fn gather_filters_for_pushdown(
        &self,
        _phase: FilterPushdownPhase,
        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
        _config: &ConfigOptions,
    ) -> Result<FilterDescription> {
        FilterDescription::from_children(parent_filters, &self.children())
    }

    fn handle_child_pushdown_result(
        &self,
        _phase: FilterPushdownPhase,
        child_pushdown_result: ChildPushdownResult,
        _config: &ConfigOptions,
    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
        Ok(FilterPushdownPropagation::if_all(child_pushdown_result))
    }

    /// Forwards a sort-order pushdown to the child when this node maintains
    /// order, rebuilding self on top of the child's sorted replacement.
    fn try_pushdown_sort(
        &self,
        order: &[PhysicalSortExpr],
    ) -> Result<SortOrderPushdownResult<Arc<dyn ExecutionPlan>>> {
        if !self.maintains_input_order()[0] {
            return Ok(SortOrderPushdownResult::Unsupported);
        }

        self.input.try_pushdown_sort(order)?.try_map(|new_input| {
            let mut new_repartition =
                RepartitionExec::try_new(new_input, self.partitioning().clone())?;
            if self.preserve_order {
                new_repartition = new_repartition.with_preserve_order();
            }
            Ok(Arc::new(new_repartition) as Arc<dyn ExecutionPlan>)
        })
    }

    /// Changes only the output partition count, keeping the same scheme and
    /// sharing the existing state/metrics.
    fn repartitioned(
        &self,
        target_partitions: usize,
        _config: &ConfigOptions,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        use Partitioning::*;
        let mut new_properties = PlanProperties::clone(&self.cache);
        new_properties.partitioning = match new_properties.partitioning {
            RoundRobinBatch(_) => RoundRobinBatch(target_partitions),
            Hash(hash, _) => Hash(hash, target_partitions),
            UnknownPartitioning(_) => UnknownPartitioning(target_partitions),
        };
        Ok(Some(Arc::new(Self {
            input: Arc::clone(&self.input),
            state: Arc::clone(&self.state),
            metrics: self.metrics.clone(),
            preserve_order: self.preserve_order,
            cache: new_properties.into(),
        })))
    }
}
1234
impl RepartitionExec {
    /// Creates a `RepartitionExec` that repartitions `input` according to
    /// `partitioning`, without order preservation.
    pub fn try_new(
        input: Arc<dyn ExecutionPlan>,
        partitioning: Partitioning,
    ) -> Result<Self> {
        let preserve_order = false;
        let cache = Self::compute_properties(&input, partitioning, preserve_order);
        Ok(RepartitionExec {
            input,
            state: Default::default(),
            metrics: ExecutionPlanMetricsSet::new(),
            preserve_order,
            cache: Arc::new(cache),
        })
    }

    // Order is maintained when explicitly preserved, or trivially when the
    // input has at most one partition.
    fn maintains_input_order_helper(
        input: &Arc<dyn ExecutionPlan>,
        preserve_order: bool,
    ) -> Vec<bool> {
        vec![preserve_order || input.output_partitioning().partition_count() <= 1]
    }

    fn eq_properties_helper(
        input: &Arc<dyn ExecutionPlan>,
        preserve_order: bool,
    ) -> EquivalenceProperties {
        // Start from the input's equivalences.
        let mut eq_properties = input.equivalence_properties().clone();
        // If the ordering is lost by interleaving, drop ordering properties.
        if !Self::maintains_input_order_helper(input, preserve_order)[0] {
            eq_properties.clear_orderings();
        }
        // Fusing multiple input partitions invalidates per-partition constants.
        if input.output_partitioning().partition_count() > 1 {
            eq_properties.clear_per_partition_constants();
        }
        eq_properties
    }

    /// Derives the plan properties for a given partitioning / order mode.
    fn compute_properties(
        input: &Arc<dyn ExecutionPlan>,
        partitioning: Partitioning,
        preserve_order: bool,
    ) -> PlanProperties {
        PlanProperties::new(
            Self::eq_properties_helper(input, preserve_order),
            partitioning,
            input.pipeline_behavior(),
            input.boundedness(),
        )
        .with_scheduling_type(SchedulingType::Cooperative)
        .with_evaluation_type(EvaluationType::Eager)
    }

    /// Enables order preservation. Effective only when the input is sorted
    /// AND has more than one partition — otherwise order is already
    /// maintained trivially and merging would be wasted work.
    pub fn with_preserve_order(mut self) -> Self {
        self.preserve_order =
            self.input.output_ordering().is_some() &&
            self.input.output_partitioning().partition_count() > 1;
        let eq_properties = Self::eq_properties_helper(&self.input, self.preserve_order);
        Arc::make_mut(&mut self.cache).set_eq_properties(eq_properties);
        self
    }

    /// The sort expressions the streaming merge must respect, if any.
    fn sort_exprs(&self) -> Option<&LexOrdering> {
        if self.preserve_order {
            self.input.output_ordering()
        } else {
            None
        }
    }

    /// Task body: pulls batches from one input stream, partitions each one,
    /// and sends the sub-batches to their output channels. Batches that do
    /// not fit in the shared reservation are spilled and only a marker is
    /// sent. Exits when the input is exhausted or every receiver hung up.
    async fn pull_from_input(
        mut stream: SendableRecordBatchStream,
        mut output_channels: HashMap<usize, OutputChannel>,
        partitioning: Partitioning,
        metrics: RepartitionMetrics,
        input_partition: usize,
        num_input_partitions: usize,
    ) -> Result<()> {
        let mut partitioner = match &partitioning {
            Partitioning::Hash(exprs, num_partitions) => {
                BatchPartitioner::new_hash_partitioner(
                    exprs.clone(),
                    *num_partitions,
                    metrics.repartition_time.clone(),
                )
            }
            Partitioning::RoundRobinBatch(num_partitions) => {
                BatchPartitioner::new_round_robin_partitioner(
                    *num_partitions,
                    metrics.repartition_time.clone(),
                    input_partition,
                    num_input_partitions,
                )
            }
            other => {
                return not_impl_err!("Unsupported repartitioning scheme {other:?}");
            }
        };

        // While there is at least one live receiver, keep pulling input.
        let mut batches_until_yield = partitioner.num_partitions();
        while !output_channels.is_empty() {
            // Fetch the next batch, attributing the wait to fetch_time.
            let timer = metrics.fetch_time.timer();
            let result = stream.next().await;
            timer.done();

            // Input is done
            let batch = match result {
                Some(result) => result?,
                None => break,
            };

            // Empty batches carry no data; don't forward them.
            if batch.num_rows() == 0 {
                continue;
            }

            for res in partitioner.partition_iter(batch)? {
                let (partition, batch) = res?;
                let size = batch.get_array_memory_size();

                let timer = metrics.send_time[partition].timer();
                // If there is still a receiver, send to it.
                if let Some(channel) = output_channels.get_mut(&partition) {
                    let (batch_to_send, is_memory_batch) =
                        match channel.reservation.lock().try_grow(size) {
                            Ok(_) => {
                                (RepartitionBatch::Memory(batch), true)
                            }
                            Err(_) => {
                                // Reservation full: spill the batch and send
                                // only a marker through the channel.
                                channel.spill_writer.push_batch(&batch)?;
                                (RepartitionBatch::Spilled, false)
                            }
                        };

                    if channel.sender.send(Some(Ok(batch_to_send))).await.is_err() {
                        // Receiver hung up (e.g. LIMIT satisfied): release the
                        // reservation we just took and stop sending there.
                        if is_memory_batch {
                            channel.reservation.lock().shrink(size);
                        }
                        output_channels.remove(&partition);
                    }
                }
                timer.done();
            }

            // Yield to the runtime periodically so cancellation and other
            // tasks can make progress even if this loop never awaits sends.
            if batches_until_yield == 0 {
                tokio::task::yield_now().await;
                batches_until_yield = partitioner.num_partitions();
            } else {
                batches_until_yield -= 1;
            }
        }

        Ok(())
    }

    /// Waits for the repartition task and forwards its outcome downstream:
    /// join/panic errors and execution errors are fanned out to every output
    /// channel; on success, `None` signals end-of-stream to each receiver.
    async fn wait_for_task(
        input_task: SpawnedTask<Result<()>>,
        txs: HashMap<usize, DistributionSender<MaybeBatch>>,
    ) {
        match input_task.join().await {
            // Task panicked or was cancelled.
            Err(e) => {
                let e = Arc::new(e);

                for (_, tx) in txs {
                    let err = Err(DataFusionError::Context(
                        "Join Error".to_string(),
                        Box::new(DataFusionError::External(Box::new(Arc::clone(&e)))),
                    ));
                    tx.send(Some(err)).await.ok();
                }
            }
            // Task completed but returned an execution error.
            Ok(Err(e)) => {
                let e = Arc::new(e);

                for (_, tx) in txs {
                    let err = Err(DataFusionError::from(&e));
                    tx.send(Some(err)).await.ok();
                }
            }
            // Clean completion: signal end-of-stream to every receiver.
            Ok(Ok(())) => {
                for (_partition, tx) in txs {
                    tx.send(None).await.ok();
                }
            }
        }
    }
}
1485
/// Read mode of a [`PerPartitionStream`]: pulling in-memory batches from the
/// channel, or draining a batch that was spilled to disk.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum StreamState {
    /// Receiving batches / markers from the distribution channel
    ReadingMemory,
    /// A `Spilled` marker was received; reading from the spill stream
    ReadingSpilled,
}
1536
/// Output stream for one partition of a [`RepartitionExec`]; interleaves
/// in-memory batches from the channel with batches replayed from spill files.
struct PerPartitionStream {
    // Schema of the produced batches
    schema: SchemaRef,

    // Channel carrying batches, spill markers, errors, and EOF signals
    receiver: DistributionReceiver<MaybeBatch>,

    // Keeps the repartition tasks alive for as long as any stream exists
    _drop_helper: Arc<Vec<SpawnedTask<()>>>,

    // Shrunk by the size of each in-memory batch as it is consumed
    reservation: SharedMemoryReservation,

    // Replays spilled batches in the order they were written
    spill_stream: SendableRecordBatchStream,

    // Whether the next poll reads from the channel or the spill stream
    state: StreamState,

    // Inputs that have not yet sent their `None` EOF; the stream ends when
    // this reaches zero
    remaining_partitions: usize,

    // Elapsed-compute / output-row accounting
    baseline_metrics: BaselineMetrics,

    // When set, merges small batches up to the configured batch size
    batch_coalescer: Option<LimitedBatchCoalescer>,
}
1569
1570impl PerPartitionStream {
1571 #[expect(clippy::too_many_arguments)]
1572 fn new(
1573 schema: SchemaRef,
1574 receiver: DistributionReceiver<MaybeBatch>,
1575 drop_helper: Arc<Vec<SpawnedTask<()>>>,
1576 reservation: SharedMemoryReservation,
1577 spill_stream: SendableRecordBatchStream,
1578 num_input_partitions: usize,
1579 baseline_metrics: BaselineMetrics,
1580 batch_size: Option<usize>,
1581 ) -> Self {
1582 let batch_coalescer =
1583 batch_size.map(|s| LimitedBatchCoalescer::new(Arc::clone(&schema), s, None));
1584 Self {
1585 schema,
1586 receiver,
1587 _drop_helper: drop_helper,
1588 reservation,
1589 spill_stream,
1590 state: StreamState::ReadingMemory,
1591 remaining_partitions: num_input_partitions,
1592 baseline_metrics,
1593 batch_coalescer,
1594 }
1595 }
1596
    /// Core poll loop: alternates between the channel and the spill stream.
    ///
    /// A `Memory` batch is returned directly (shrinking the reservation by
    /// its size); a `Spilled` marker switches to the spill stream for exactly
    /// one batch. The stream ends once every input partition has sent its
    /// `None` EOF, or the channel itself closes.
    fn poll_next_inner(
        self: &mut Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        use futures::StreamExt;
        // Attribute time spent in this poll to elapsed_compute.
        let cloned_time = self.baseline_metrics.elapsed_compute().clone();
        let _timer = cloned_time.timer();

        loop {
            match self.state {
                StreamState::ReadingMemory => {
                    let value = match self.receiver.recv().poll_unpin(cx) {
                        Poll::Ready(v) => v,
                        Poll::Pending => {
                            return Poll::Pending;
                        }
                    };

                    match value {
                        Some(Some(v)) => match v {
                            Ok(RepartitionBatch::Memory(batch)) => {
                                // Batch is leaving the channel: release its
                                // share of the memory reservation.
                                self.reservation
                                    .lock()
                                    .shrink(batch.get_array_memory_size());
                                return Poll::Ready(Some(Ok(batch)));
                            }
                            Ok(RepartitionBatch::Spilled) => {
                                // Next batch lives on disk; read it from the
                                // spill stream on the next loop iteration.
                                self.state = StreamState::ReadingSpilled;
                                continue;
                            }
                            Err(e) => {
                                return Poll::Ready(Some(Err(e)));
                            }
                        },
                        Some(None) => {
                            // One input partition finished; end the stream
                            // once all of them have.
                            self.remaining_partitions -= 1;
                            if self.remaining_partitions == 0 {
                                return Poll::Ready(None);
                            }
                            continue;
                        }
                        None => {
                            // Channel closed (senders dropped): end of stream.
                            return Poll::Ready(None);
                        }
                    }
                }
                StreamState::ReadingSpilled => {
                    match self.spill_stream.poll_next_unpin(cx) {
                        Poll::Ready(Some(Ok(batch))) => {
                            // One marker corresponds to one spilled batch;
                            // return to the channel afterwards.
                            self.state = StreamState::ReadingMemory;
                            return Poll::Ready(Some(Ok(batch)));
                        }
                        Poll::Ready(Some(Err(e))) => {
                            return Poll::Ready(Some(Err(e)));
                        }
                        Poll::Ready(None) => {
                            // Spill stream exhausted; fall back to the channel.
                            self.state = StreamState::ReadingMemory;
                        }
                        Poll::Pending => {
                            return Poll::Pending;
                        }
                    }
                }
            }
        }
    }
1677
1678 fn poll_next_and_coalesce(
1679 self: &mut Pin<&mut Self>,
1680 cx: &mut Context<'_>,
1681 coalescer: &mut LimitedBatchCoalescer,
1682 ) -> Poll<Option<Result<RecordBatch>>> {
1683 let cloned_time = self.baseline_metrics.elapsed_compute().clone();
1684 let mut completed = false;
1685
1686 loop {
1687 if let Some(batch) = coalescer.next_completed_batch() {
1688 return Poll::Ready(Some(Ok(batch)));
1689 }
1690 if completed {
1691 return Poll::Ready(None);
1692 }
1693
1694 match ready!(self.poll_next_inner(cx)) {
1695 Some(Ok(batch)) => {
1696 let _timer = cloned_time.timer();
1697 if let Err(err) = coalescer.push_batch(batch) {
1698 return Poll::Ready(Some(Err(err)));
1699 }
1700 }
1701 Some(err) => {
1702 return Poll::Ready(Some(err));
1703 }
1704 None => {
1705 completed = true;
1706 let _timer = cloned_time.timer();
1707 if let Err(err) = coalescer.finish() {
1708 return Poll::Ready(Some(Err(err)));
1709 }
1710 }
1711 }
1712 }
1713 }
1714}
1715
1716impl Stream for PerPartitionStream {
1717 type Item = Result<RecordBatch>;
1718
1719 fn poll_next(
1720 mut self: Pin<&mut Self>,
1721 cx: &mut Context<'_>,
1722 ) -> Poll<Option<Self::Item>> {
1723 let poll;
1724 if let Some(mut coalescer) = self.batch_coalescer.take() {
1725 poll = self.poll_next_and_coalesce(cx, &mut coalescer);
1726 self.batch_coalescer = Some(coalescer);
1727 } else {
1728 poll = self.poll_next_inner(cx);
1729 }
1730 self.baseline_metrics.record_poll(poll)
1731 }
1732}
1733
impl RecordBatchStream for PerPartitionStream {
    /// Returns the schema of the batches produced by this stream.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
1740
1741#[cfg(test)]
1742mod tests {
1743 use std::collections::HashSet;
1744
1745 use super::*;
1746 use crate::test::TestMemoryExec;
1747 use crate::{
1748 test::{
1749 assert_is_pending,
1750 exec::{
1751 BarrierExec, BlockingExec, ErrorExec, MockExec,
1752 assert_strong_count_converges_to_zero,
1753 },
1754 },
1755 {collect, expressions::col},
1756 };
1757
1758 use arrow::array::{ArrayRef, StringArray, UInt32Array};
1759 use arrow::datatypes::{DataType, Field, Schema};
1760 use datafusion_common::cast::as_string_array;
1761 use datafusion_common::exec_err;
1762 use datafusion_common::test_util::batches_to_sort_string;
1763 use datafusion_common_runtime::JoinSet;
1764 use datafusion_execution::config::SessionConfig;
1765 use datafusion_execution::runtime_env::RuntimeEnvBuilder;
1766 use insta::assert_snapshot;
1767
1768 #[tokio::test]
1769 async fn one_to_many_round_robin() -> Result<()> {
1770 let schema = test_schema();
1772 let partition = create_vec_batches(50);
1773 let partitions = vec![partition];
1774
1775 let output_partitions =
1777 repartition(&schema, partitions, Partitioning::RoundRobinBatch(4)).await?;
1778
1779 assert_eq!(4, output_partitions.len());
1780 for partition in &output_partitions {
1781 assert_eq!(1, partition.len());
1782 }
1783 assert_eq!(13 * 8, output_partitions[0][0].num_rows());
1784 assert_eq!(13 * 8, output_partitions[1][0].num_rows());
1785 assert_eq!(12 * 8, output_partitions[2][0].num_rows());
1786 assert_eq!(12 * 8, output_partitions[3][0].num_rows());
1787
1788 Ok(())
1789 }
1790
1791 #[tokio::test]
1792 async fn many_to_one_round_robin() -> Result<()> {
1793 let schema = test_schema();
1795 let partition = create_vec_batches(50);
1796 let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
1797
1798 let output_partitions =
1800 repartition(&schema, partitions, Partitioning::RoundRobinBatch(1)).await?;
1801
1802 assert_eq!(1, output_partitions.len());
1803 assert_eq!(150 * 8, output_partitions[0][0].num_rows());
1804
1805 Ok(())
1806 }
1807
1808 #[tokio::test]
1809 async fn many_to_many_round_robin() -> Result<()> {
1810 let schema = test_schema();
1812 let partition = create_vec_batches(50);
1813 let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
1814
1815 let output_partitions =
1817 repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await?;
1818
1819 let total_rows_per_partition = 8 * 50 * 3 / 5;
1820 assert_eq!(5, output_partitions.len());
1821 for partition in output_partitions {
1822 assert_eq!(1, partition.len());
1823 assert_eq!(total_rows_per_partition, partition[0].num_rows());
1824 }
1825
1826 Ok(())
1827 }
1828
1829 #[tokio::test]
1830 async fn many_to_many_hash_partition() -> Result<()> {
1831 let schema = test_schema();
1833 let partition = create_vec_batches(50);
1834 let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
1835
1836 let output_partitions = repartition(
1837 &schema,
1838 partitions,
1839 Partitioning::Hash(vec![col("c0", &schema)?], 8),
1840 )
1841 .await?;
1842
1843 let total_rows: usize = output_partitions
1844 .iter()
1845 .map(|x| x.iter().map(|x| x.num_rows()).sum::<usize>())
1846 .sum();
1847
1848 assert_eq!(8, output_partitions.len());
1849 assert_eq!(total_rows, 8 * 50 * 3);
1850
1851 Ok(())
1852 }
1853
1854 #[tokio::test]
1855 async fn test_repartition_with_coalescing() -> Result<()> {
1856 let schema = test_schema();
1857 let partition = create_vec_batches(50);
1859 let partitions = vec![partition.clone(), partition.clone()];
1860 let partitioning = Partitioning::RoundRobinBatch(1);
1861
1862 let session_config = SessionConfig::new().with_batch_size(200);
1863 let task_ctx = TaskContext::default().with_session_config(session_config);
1864 let task_ctx = Arc::new(task_ctx);
1865
1866 let exec = TestMemoryExec::try_new_exec(&partitions, Arc::clone(&schema), None)?;
1868 let exec = RepartitionExec::try_new(exec, partitioning)?;
1869
1870 for i in 0..exec.partitioning().partition_count() {
1871 let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
1872 while let Some(result) = stream.next().await {
1873 let batch = result?;
1874 assert_eq!(200, batch.num_rows());
1875 }
1876 }
1877 Ok(())
1878 }
1879
1880 fn test_schema() -> Arc<Schema> {
1881 Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
1882 }
1883
    /// Runs `input_partitions` through a [`RepartitionExec`] with the given
    /// `partitioning` and collects the batches of every output partition.
    async fn repartition(
        schema: &SchemaRef,
        input_partitions: Vec<Vec<RecordBatch>>,
        partitioning: Partitioning,
    ) -> Result<Vec<Vec<RecordBatch>>> {
        let task_ctx = Arc::new(TaskContext::default());
        // create physical plan: memory source -> repartition
        let exec =
            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(schema), None)?;
        let exec = RepartitionExec::try_new(exec, partitioning)?;

        // execute each output partition and collect its batches
        let mut output_partitions = vec![];
        for i in 0..exec.partitioning().partition_count() {
            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
            let mut batches = vec![];
            while let Some(result) = stream.next().await {
                batches.push(result?);
            }
            output_partitions.push(batches);
        }
        Ok(output_partitions)
    }
1908
    /// Same as `many_to_many_round_robin`, but run inside a spawned tokio
    /// task to verify repartitioning also works off the main test task.
    #[tokio::test]
    async fn many_to_many_round_robin_within_tokio_task() -> Result<()> {
        let handle: SpawnedTask<Result<Vec<Vec<RecordBatch>>>> =
            SpawnedTask::spawn(async move {
                let schema = test_schema();
                let partition = create_vec_batches(50);
                let partitions =
                    vec![partition.clone(), partition.clone(), partition.clone()];

                repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await
            });

        let output_partitions = handle.join().await.unwrap().unwrap();

        // 3 inputs * 50 batches * 8 rows split evenly across 5 outputs
        let total_rows_per_partition = 8 * 50 * 3 / 5;
        assert_eq!(5, output_partitions.len());
        for partition in output_partitions {
            assert_eq!(1, partition.len());
            assert_eq!(total_rows_per_partition, partition[0].num_rows());
        }

        Ok(())
    }
1934
    /// `UnknownPartitioning` is accepted at plan construction time but the
    /// error surfaces when the output stream is actually consumed.
    #[tokio::test]
    async fn unsupported_partitioning() {
        let task_ctx = Arc::new(TaskContext::default());
        // At least one batch is needed so the repartition loop runs
        let batch = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();

        let schema = batch.schema();
        let input = MockExec::new(vec![Ok(batch)], schema);
        let partitioning = Partitioning::UnknownPartitioning(1);
        // Construction and `execute` both succeed ...
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
        let output_stream = exec.execute(0, task_ctx).unwrap();

        // ... but collecting the stream reports the unsupported scheme
        let result_string = crate::common::collect(output_stream)
            .await
            .unwrap_err()
            .to_string();
        assert!(
            result_string
                .contains("Unsupported repartitioning scheme UnknownPartitioning(1)"),
            "actual: {result_string}"
        );
    }
1965
1966 #[tokio::test]
1967 async fn error_for_input_exec() {
1968 let task_ctx = Arc::new(TaskContext::default());
1972 let input = ErrorExec::new();
1973 let partitioning = Partitioning::RoundRobinBatch(1);
1974 let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
1975
1976 let result_string = exec.execute(0, task_ctx).err().unwrap().to_string();
1978
1979 assert!(
1980 result_string.contains("ErrorExec, unsurprisingly, errored in partition 0"),
1981 "actual: {result_string}"
1982 );
1983 }
1984
    /// An `Err` produced mid-stream by the input is forwarded to the
    /// consumer of the repartitioned output stream.
    #[tokio::test]
    async fn repartition_with_error_in_stream() {
        let task_ctx = Arc::new(TaskContext::default());
        let batch = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();

        // input yields one good batch followed by an error
        let err = exec_err!("bad data error");

        let schema = batch.schema();
        let input = MockExec::new(vec![Ok(batch), err], schema);
        let partitioning = Partitioning::RoundRobinBatch(1);
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();

        // `execute` itself succeeds; the error appears during collection
        let output_stream = exec.execute(0, task_ctx).unwrap();

        let result_string = crate::common::collect(output_stream)
            .await
            .unwrap_err()
            .to_string();
        assert!(
            result_string.contains("bad data error"),
            "actual: {result_string}"
        );
    }
2017
    /// All batches from a `MockExec` input arrive through a single-output
    /// round-robin repartition; output contents match the input batches.
    #[tokio::test]
    async fn repartition_with_delayed_stream() {
        let task_ctx = Arc::new(TaskContext::default());
        let batch1 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();

        let batch2 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
        )])
        .unwrap();

        let schema = batch1.schema();
        let expected_batches = vec![batch1.clone(), batch2.clone()];
        let input = MockExec::new(vec![Ok(batch1), Ok(batch2)], schema);
        let partitioning = Partitioning::RoundRobinBatch(1);

        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();

        // Sanity-check the expected (input) contents
        assert_snapshot!(batches_to_sort_string(&expected_batches), @r"
        +------------------+
        | my_awesome_field |
        +------------------+
        | bar              |
        | baz              |
        | foo              |
        | frob             |
        +------------------+
        ");

        let output_stream = exec.execute(0, task_ctx).unwrap();
        let batches = crate::common::collect(output_stream).await.unwrap();

        // The repartitioned output carries the same rows
        assert_snapshot!(batches_to_sort_string(&batches), @r"
        +------------------+
        | my_awesome_field |
        +------------------+
        | bar              |
        | baz              |
        | foo              |
        | frob             |
        +------------------+
        ");
    }
2067
    /// Dropping one of two round-robin output streams must not stall or
    /// corrupt the other: the surviving stream still yields its rows.
    #[tokio::test]
    async fn robin_repartition_with_dropping_output_stream() {
        let task_ctx = Arc::new(TaskContext::default());
        let partitioning = Partitioning::RoundRobinBatch(2);
        // BarrierExec holds its batches back until `wait()` is called
        let input = Arc::new(make_barrier_exec());

        let exec = RepartitionExec::try_new(
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            partitioning,
        )
        .unwrap();

        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();

        // Purposely drop output stream 0 *before* any batches are produced
        drop(output_stream0);

        // Now release the input batches
        let mut background_task = JoinSet::new();
        background_task.spawn(async move {
            input.wait().await;
        });

        // Output stream 1 should still complete with its share of the rows
        let batches = crate::common::collect(output_stream1).await.unwrap();

        assert_snapshot!(batches_to_sort_string(&batches), @r"
        +------------------+
        | my_awesome_field |
        +------------------+
        | baz              |
        | frob             |
        | gar              |
        | goo              |
        +------------------+
        ");
    }
2110
    /// Hash repartitioning: the set of values routed to output partition 1
    /// must be identical whether or not output partition 0's stream was
    /// dropped before execution started.
    #[tokio::test]
    async fn hash_repartition_with_dropping_output_stream() {
        let task_ctx = Arc::new(TaskContext::default());
        let partitioning = Partitioning::Hash(
            vec![Arc::new(crate::expressions::Column::new(
                "my_awesome_field",
                0,
            ))],
            2,
        );

        // Phase 1: collect partition 1 without dropping anything
        let input = Arc::new(make_barrier_exec());
        let exec = RepartitionExec::try_new(
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            partitioning.clone(),
        )
        .unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
        let mut background_task = JoinSet::new();
        background_task.spawn(async move {
            input.wait().await;
        });
        let batches_without_drop = crate::common::collect(output_stream1).await.unwrap();

        // No duplicates, and every value comes from the source set
        let items_vec = str_batches_to_vec(&batches_without_drop);
        let items_set: HashSet<&str> = items_vec.iter().copied().collect();
        assert_eq!(items_vec.len(), items_set.len());
        let source_str_set: HashSet<&str> =
            ["foo", "bar", "frob", "baz", "goo", "gar", "grob", "gaz"]
                .iter()
                .copied()
                .collect();
        assert_eq!(items_set.difference(&source_str_set).count(), 0);

        // Phase 2: same plan, but drop partition 0's stream up-front
        let input = Arc::new(make_barrier_exec());
        let exec = RepartitionExec::try_new(
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            partitioning,
        )
        .unwrap();
        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
        // Purposely drop output stream 0 *before* any batches are produced
        drop(output_stream0);
        let mut background_task = JoinSet::new();
        background_task.spawn(async move {
            input.wait().await;
        });
        let batches_with_drop = crate::common::collect(output_stream1).await.unwrap();

        // Partition 1 must receive exactly the same set of rows in both runs
        let items_vec_with_drop = str_batches_to_vec(&batches_with_drop);
        let items_set_with_drop: HashSet<&str> =
            items_vec_with_drop.iter().copied().collect();
        assert_eq!(
            items_set_with_drop.symmetric_difference(&items_set).count(),
            0
        );
    }
2176
2177 fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> {
2178 batches
2179 .iter()
2180 .flat_map(|batch| {
2181 assert_eq!(batch.columns().len(), 1);
2182 let string_array = as_string_array(batch.column(0))
2183 .expect("Unexpected type for repartitioned batch");
2184
2185 string_array
2186 .iter()
2187 .map(|v| v.expect("Unexpected null"))
2188 .collect::<Vec<_>>()
2189 })
2190 .collect::<Vec<_>>()
2191 }
2192
2193 fn make_barrier_exec() -> BarrierExec {
2195 let batch1 = RecordBatch::try_from_iter(vec![(
2196 "my_awesome_field",
2197 Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
2198 )])
2199 .unwrap();
2200
2201 let batch2 = RecordBatch::try_from_iter(vec![(
2202 "my_awesome_field",
2203 Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
2204 )])
2205 .unwrap();
2206
2207 let batch3 = RecordBatch::try_from_iter(vec![(
2208 "my_awesome_field",
2209 Arc::new(StringArray::from(vec!["goo", "gar"])) as ArrayRef,
2210 )])
2211 .unwrap();
2212
2213 let batch4 = RecordBatch::try_from_iter(vec![(
2214 "my_awesome_field",
2215 Arc::new(StringArray::from(vec!["grob", "gaz"])) as ArrayRef,
2216 )])
2217 .unwrap();
2218
2219 let schema = batch1.schema();
2222 BarrierExec::new(vec![vec![batch1, batch2], vec![batch3, batch4]], schema)
2223 }
2224
    /// Dropping the collect future cancels execution: the blocking input's
    /// strong references converge to zero instead of leaking tasks.
    #[tokio::test]
    async fn test_drop_cancel() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema =
            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));

        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2));
        let refs = blocking_exec.refs();
        let repartition_exec = Arc::new(RepartitionExec::try_new(
            blocking_exec,
            Partitioning::UnknownPartitioning(1),
        )?);

        let fut = collect(repartition_exec, task_ctx);
        let mut fut = fut.boxed();

        // The input never produces, so the future must stay pending ...
        assert_is_pending(&mut fut);
        // ... and dropping it must release all references to the input
        drop(fut);
        assert_strong_count_converges_to_zero(refs).await;

        Ok(())
    }
2247
    /// Hash-partitioning a single row into 2 partitions: the row lands in
    /// exactly one partition and the other emits no (empty) batches.
    #[tokio::test]
    async fn hash_repartition_avoid_empty_batch() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let batch = RecordBatch::try_from_iter(vec![(
            "a",
            Arc::new(StringArray::from(vec!["foo"])) as ArrayRef,
        )])
        .unwrap();
        let partitioning = Partitioning::Hash(
            vec![Arc::new(crate::expressions::Column::new("a", 0))],
            2,
        );
        let schema = batch.schema();
        let input = MockExec::new(vec![Ok(batch)], schema);
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
        let batch0 = crate::common::collect(output_stream0).await.unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
        let batch1 = crate::common::collect(output_stream1).await.unwrap();
        // The single row hashes to one partition; the other must be empty
        assert!(batch0.is_empty() || batch1.is_empty());
        Ok(())
    }
2270
    /// With a 1-byte memory limit every batch must be spilled to disk, yet
    /// all rows still arrive and the spill metrics are populated.
    #[tokio::test]
    async fn repartition_with_spilling() -> Result<()> {
        let schema = test_schema();
        let partition = create_vec_batches(50);
        let input_partitions = vec![partition];
        let partitioning = Partitioning::RoundRobinBatch(4);

        // Tiny memory limit forces every batch onto the spill path
        let runtime = RuntimeEnvBuilder::default()
            .with_memory_limit(1, 1.0)
            .build_arc()?;

        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        // create physical plan
        let exec =
            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
        let exec = RepartitionExec::try_new(exec, partitioning)?;

        // execute and count rows - must succeed despite spilling
        let mut total_rows = 0;
        for i in 0..exec.partitioning().partition_count() {
            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
            while let Some(result) = stream.next().await {
                let batch = result?;
                total_rows += batch.num_rows();
            }
        }

        // no rows may be lost
        assert_eq!(total_rows, 50 * 8);

        // spill metrics must all be non-zero
        let metrics = exec.metrics().unwrap();
        assert!(
            metrics.spill_count().unwrap() > 0,
            "Expected spill_count > 0, but got {:?}",
            metrics.spill_count()
        );
        println!("Spilled {} times", metrics.spill_count().unwrap());
        assert!(
            metrics.spilled_bytes().unwrap() > 0,
            "Expected spilled_bytes > 0, but got {:?}",
            metrics.spilled_bytes()
        );
        println!(
            "Spilled {} bytes in {} spills",
            metrics.spilled_bytes().unwrap(),
            metrics.spill_count().unwrap()
        );
        assert!(
            metrics.spilled_rows().unwrap() > 0,
            "Expected spilled_rows > 0, but got {:?}",
            metrics.spilled_rows()
        );
        println!("Spilled {} rows", metrics.spilled_rows().unwrap());

        Ok(())
    }
2332
    /// With a moderate (2 KiB) memory limit only part of the data spills:
    /// spill metrics are non-zero but fewer rows spill than were produced.
    #[tokio::test]
    async fn repartition_with_partial_spilling() -> Result<()> {
        let schema = test_schema();
        let partition = create_vec_batches(50);
        let input_partitions = vec![partition];
        let partitioning = Partitioning::RoundRobinBatch(4);

        // Limit sized to hold some, but not all, of the in-flight batches
        let runtime = RuntimeEnvBuilder::default()
            .with_memory_limit(2 * 1024, 1.0)
            .build_arc()?;

        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        // create physical plan
        let exec =
            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
        let exec = RepartitionExec::try_new(exec, partitioning)?;

        // execute and count rows
        let mut total_rows = 0;
        for i in 0..exec.partitioning().partition_count() {
            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
            while let Some(result) = stream.next().await {
                let batch = result?;
                total_rows += batch.num_rows();
            }
        }

        // no rows may be lost
        assert_eq!(total_rows, 50 * 8);

        // some, but not all, rows should have spilled
        let metrics = exec.metrics().unwrap();
        let spill_count = metrics.spill_count().unwrap();
        let spilled_rows = metrics.spilled_rows().unwrap();
        let spilled_bytes = metrics.spilled_bytes().unwrap();

        assert!(
            spill_count > 0,
            "Expected some spilling to occur, but got spill_count={spill_count}"
        );
        assert!(
            spilled_rows > 0 && spilled_rows < total_rows,
            "Expected partial spilling (0 < spilled_rows < {total_rows}), but got spilled_rows={spilled_rows}"
        );
        assert!(
            spilled_bytes > 0,
            "Expected some bytes to be spilled, but got spilled_bytes={spilled_bytes}"
        );

        println!(
            "Partial spilling: spilled {} out of {} rows ({:.1}%) in {} spills, {} bytes",
            spilled_rows,
            total_rows,
            (spilled_rows as f64 / total_rows as f64) * 100.0,
            spill_count,
            spilled_bytes
        );

        Ok(())
    }
2398
2399 #[tokio::test]
2400 async fn repartition_without_spilling() -> Result<()> {
2401 let schema = test_schema();
2403 let partition = create_vec_batches(50);
2404 let input_partitions = vec![partition];
2405 let partitioning = Partitioning::RoundRobinBatch(4);
2406
2407 let runtime = RuntimeEnvBuilder::default()
2409 .with_memory_limit(10 * 1024 * 1024, 1.0) .build_arc()?;
2411
2412 let task_ctx = TaskContext::default().with_runtime(runtime);
2413 let task_ctx = Arc::new(task_ctx);
2414
2415 let exec =
2417 TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
2418 let exec = RepartitionExec::try_new(exec, partitioning)?;
2419
2420 let mut total_rows = 0;
2422 for i in 0..exec.partitioning().partition_count() {
2423 let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
2424 while let Some(result) = stream.next().await {
2425 let batch = result?;
2426 total_rows += batch.num_rows();
2427 }
2428 }
2429
2430 assert_eq!(total_rows, 50 * 8);
2432
2433 let metrics = exec.metrics().unwrap();
2435 assert_eq!(
2436 metrics.spill_count(),
2437 Some(0),
2438 "Expected no spilling, but got spill_count={:?}",
2439 metrics.spill_count()
2440 );
2441 assert_eq!(
2442 metrics.spilled_bytes(),
2443 Some(0),
2444 "Expected no bytes spilled, but got spilled_bytes={:?}",
2445 metrics.spilled_bytes()
2446 );
2447 assert_eq!(
2448 metrics.spilled_rows(),
2449 Some(0),
2450 "Expected no rows spilled, but got spilled_rows={:?}",
2451 metrics.spilled_rows()
2452 );
2453
2454 println!("No spilling occurred - all data processed in memory");
2455
2456 Ok(())
2457 }
2458
    /// With a 1-byte memory limit *and* the disk manager disabled, spilling
    /// is impossible, so execution must fail with `ResourcesExhausted`.
    #[tokio::test]
    async fn oom() -> Result<()> {
        use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode};

        let schema = test_schema();
        let partition = create_vec_batches(50);
        let input_partitions = vec![partition];
        let partitioning = Partitioning::RoundRobinBatch(4);

        // Tiny memory budget and no disk: there is no escape hatch
        let runtime = RuntimeEnvBuilder::default()
            .with_memory_limit(1, 1.0)
            .with_disk_manager_builder(
                DiskManagerBuilder::default().with_mode(DiskManagerMode::Disabled),
            )
            .build_arc()?;

        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        // create physical plan
        let exec =
            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
        let exec = RepartitionExec::try_new(exec, partitioning)?;

        // every output partition should report the memory error
        for i in 0..exec.partitioning().partition_count() {
            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
            let err = stream.next().await.unwrap().unwrap_err();
            let err = err.find_root();
            assert!(
                matches!(err, DataFusionError::ResourcesExhausted(_)),
                "Wrong error type: {err}",
            );
        }

        Ok(())
    }
2498
2499 fn create_vec_batches(n: usize) -> Vec<RecordBatch> {
2501 let batch = create_batch();
2502 std::iter::repeat_n(batch, n).collect()
2503 }
2504
2505 fn create_batch() -> RecordBatch {
2507 let schema = test_schema();
2508 RecordBatch::try_new(
2509 schema,
2510 vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]))],
2511 )
2512 .unwrap()
2513 }
2514
2515 fn create_ordered_batches(num_batches: usize) -> Vec<RecordBatch> {
2517 let schema = test_schema();
2518 (0..num_batches)
2519 .map(|i| {
2520 let start = (i * 8) as u32;
2521 RecordBatch::try_new(
2522 Arc::clone(&schema),
2523 vec![Arc::new(UInt32Array::from(
2524 (start..start + 8).collect::<Vec<_>>(),
2525 ))],
2526 )
2527 .unwrap()
2528 })
2529 .collect()
2530 }
2531
    /// Round-robin repartitioning under a 1-byte memory limit (so batches
    /// spill) must still preserve the relative order of rows within each
    /// output partition.
    #[tokio::test]
    async fn test_repartition_ordering_with_spilling() -> Result<()> {
        let schema = test_schema();
        // 20 batches whose values strictly increase across batches
        let partition = create_ordered_batches(20);
        let input_partitions = vec![partition];

        let partitioning = Partitioning::RoundRobinBatch(2);

        // Tiny memory limit forces the spill path
        let runtime = RuntimeEnvBuilder::default()
            .with_memory_limit(1, 1.0)
            .build_arc()?;

        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        let exec =
            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
        let exec = RepartitionExec::try_new(exec, partitioning)?;

        // collect the batches of each output partition separately
        let mut all_batches = Vec::new();
        for i in 0..exec.partitioning().partition_count() {
            let mut partition_batches = Vec::new();
            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
            while let Some(result) = stream.next().await {
                let batch = result?;
                partition_batches.push(batch);
            }
            all_batches.push(partition_batches);
        }

        // sanity-check that spilling actually happened
        let metrics = exec.metrics().unwrap();
        assert!(
            metrics.spill_count().unwrap() > 0,
            "Expected spilling to occur, but spill_count = 0"
        );

        // within each output partition, values must be strictly increasing
        for (partition_idx, batches) in all_batches.iter().enumerate() {
            let mut last_value = None;
            for batch in batches {
                let array = batch
                    .column(0)
                    .as_any()
                    .downcast_ref::<UInt32Array>()
                    .unwrap();

                for i in 0..array.len() {
                    let value = array.value(i);
                    if let Some(last) = last_value {
                        assert!(
                            value > last,
                            "Ordering violated in partition {partition_idx}: {value} is not greater than {last}"
                        );
                    }
                    last_value = Some(value);
                }
            }
        }

        Ok(())
    }
2605}
2606
2607#[cfg(test)]
2608mod test {
2609 use arrow::array::record_batch;
2610 use arrow::compute::SortOptions;
2611 use arrow::datatypes::{DataType, Field, Schema};
2612 use datafusion_common::assert_batches_eq;
2613
2614 use super::*;
2615 use crate::test::TestMemoryExec;
2616 use crate::union::UnionExec;
2617
2618 use datafusion_physical_expr::expressions::col;
2619 use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
2620
    /// Asserts that the indented `displayable` rendering of `$PLAN` matches
    /// the inline insta snapshot `$EXPECTED`.
    macro_rules! assert_plan {
        ($PLAN: expr, @ $EXPECTED: expr) => {
            let formatted = crate::displayable($PLAN).indent(true).to_string();

            insta::assert_snapshot!(
                formatted,
                @$EXPECTED
            );
        };
    }
2635
    /// With two sorted input partitions, `with_preserve_order` keeps the
    /// ordering and the plan display reports `preserve_order=true`.
    #[tokio::test]
    async fn test_preserve_order() -> Result<()> {
        let schema = test_schema();
        let sort_exprs = sort_exprs(&schema);
        let source1 = sorted_memory_exec(&schema, sort_exprs.clone());
        let source2 = sorted_memory_exec(&schema, sort_exprs);
        // union gives multiple sorted input partitions
        let union = UnionExec::try_new(vec![source1, source2])?;
        let exec = RepartitionExec::try_new(union, Partitioning::RoundRobinBatch(10))?
            .with_preserve_order();

        assert_plan!(&exec, @r"
        RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true, sort_exprs=c0@0 ASC
          UnionExec
            DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
            DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
        ");
        Ok(())
    }
2656
    /// With a single sorted input partition there is nothing to interleave,
    /// so the plan reports `maintains_sort_order=true` without the
    /// order-preserving merge machinery.
    #[tokio::test]
    async fn test_preserve_order_one_partition() -> Result<()> {
        let schema = test_schema();
        let sort_exprs = sort_exprs(&schema);
        let source = sorted_memory_exec(&schema, sort_exprs);
        let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(10))?
            .with_preserve_order();

        assert_plan!(&exec, @r"
        RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true
          DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
        ");

        Ok(())
    }
2674
    /// If the input has no sort order, `with_preserve_order` is a no-op and
    /// the plan displays as a plain repartition.
    #[tokio::test]
    async fn test_preserve_order_input_not_sorted() -> Result<()> {
        let schema = test_schema();
        let source1 = memory_exec(&schema);
        let source2 = memory_exec(&schema);
        // multiple input partitions, but unsorted
        let union = UnionExec::try_new(vec![source1, source2])?;
        let exec = RepartitionExec::try_new(union, Partitioning::RoundRobinBatch(10))?
            .with_preserve_order();

        assert_plan!(&exec, @r"
        RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2
          UnionExec
            DataSourceExec: partitions=1, partition_sizes=[0]
            DataSourceExec: partitions=1, partition_sizes=[0]
        ");
        Ok(())
    }
2694
2695 #[tokio::test]
2696 async fn test_preserve_order_with_spilling() -> Result<()> {
2697 use datafusion_execution::TaskContext;
2698 use datafusion_execution::runtime_env::RuntimeEnvBuilder;
2699
2700 let batch1 = record_batch!(("c0", UInt32, [1, 3])).unwrap();
2704 let batch2 = record_batch!(("c0", UInt32, [2, 4])).unwrap();
2705 let batch3 = record_batch!(("c0", UInt32, [5, 7])).unwrap();
2706 let batch4 = record_batch!(("c0", UInt32, [6, 8])).unwrap();
2707 let batch5 = record_batch!(("c0", UInt32, [9, 11])).unwrap();
2708 let batch6 = record_batch!(("c0", UInt32, [10, 12])).unwrap();
2709 let schema = batch1.schema();
2710 let sort_exprs = LexOrdering::new([PhysicalSortExpr {
2711 expr: col("c0", &schema).unwrap(),
2712 options: SortOptions::default().asc(),
2713 }])
2714 .unwrap();
2715 let partition1 = vec![batch1.clone(), batch3.clone(), batch5.clone()];
2716 let partition2 = vec![batch2.clone(), batch4.clone(), batch6.clone()];
2717 let input_partitions = vec![partition1, partition2];
2718
2719 let runtime = RuntimeEnvBuilder::default()
2722 .with_memory_limit(64, 1.0)
2723 .build_arc()?;
2724
2725 let task_ctx = TaskContext::default().with_runtime(runtime);
2726 let task_ctx = Arc::new(task_ctx);
2727
2728 let exec = TestMemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?
2730 .try_with_sort_information(vec![sort_exprs.clone(), sort_exprs])?;
2731 let exec = Arc::new(exec);
2732 let exec = Arc::new(TestMemoryExec::update_cache(&exec));
2733 let exec = RepartitionExec::try_new(exec, Partitioning::RoundRobinBatch(3))?
2736 .with_preserve_order();
2737
2738 let mut batches = vec![];
2739
2740 for i in 0..exec.partitioning().partition_count() {
2742 let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
2743 while let Some(result) = stream.next().await {
2744 let batch = result?;
2745 batches.push(batch);
2746 }
2747 }
2748
2749 #[rustfmt::skip]
2750 let expected = [
2751 [
2752 "+----+",
2753 "| c0 |",
2754 "+----+",
2755 "| 1 |",
2756 "| 2 |",
2757 "| 3 |",
2758 "| 4 |",
2759 "+----+",
2760 ],
2761 [
2762 "+----+",
2763 "| c0 |",
2764 "+----+",
2765 "| 5 |",
2766 "| 6 |",
2767 "| 7 |",
2768 "| 8 |",
2769 "+----+",
2770 ],
2771 [
2772 "+----+",
2773 "| c0 |",
2774 "+----+",
2775 "| 9 |",
2776 "| 10 |",
2777 "| 11 |",
2778 "| 12 |",
2779 "+----+",
2780 ],
2781 ];
2782
2783 for (batch, expected) in batches.iter().zip(expected.iter()) {
2784 assert_batches_eq!(expected, std::slice::from_ref(batch));
2785 }
2786
2787 let all_batches = [batch1, batch2, batch3, batch4, batch5, batch6];
2791 let metrics = exec.metrics().unwrap();
2792 assert!(
2793 metrics.spill_count().unwrap() > input_partitions.len(),
2794 "Expected spill_count > {} for order-preserving repartition, but got {:?}",
2795 input_partitions.len(),
2796 metrics.spill_count()
2797 );
2798 assert!(
2799 metrics.spilled_bytes().unwrap()
2800 > all_batches
2801 .iter()
2802 .map(|b| b.get_array_memory_size())
2803 .sum::<usize>(),
2804 "Expected spilled_bytes > {} for order-preserving repartition, got {}",
2805 all_batches
2806 .iter()
2807 .map(|b| b.get_array_memory_size())
2808 .sum::<usize>(),
2809 metrics.spilled_bytes().unwrap()
2810 );
2811 assert!(
2812 metrics.spilled_rows().unwrap()
2813 >= all_batches.iter().map(|b| b.num_rows()).sum::<usize>(),
2814 "Expected spilled_rows > {} for order-preserving repartition, got {}",
2815 all_batches.iter().map(|b| b.num_rows()).sum::<usize>(),
2816 metrics.spilled_rows().unwrap()
2817 );
2818
2819 Ok(())
2820 }
2821
2822 #[tokio::test]
2823 async fn test_hash_partitioning_with_spilling() -> Result<()> {
2824 use datafusion_execution::TaskContext;
2825 use datafusion_execution::runtime_env::RuntimeEnvBuilder;
2826
2827 let batch1 = record_batch!(("c0", UInt32, [1, 3])).unwrap();
2829 let batch2 = record_batch!(("c0", UInt32, [2, 4])).unwrap();
2830 let batch3 = record_batch!(("c0", UInt32, [5, 7])).unwrap();
2831 let batch4 = record_batch!(("c0", UInt32, [6, 8])).unwrap();
2832 let schema = batch1.schema();
2833
2834 let partition1 = vec![batch1.clone(), batch3.clone()];
2835 let partition2 = vec![batch2.clone(), batch4.clone()];
2836 let input_partitions = vec![partition1, partition2];
2837
2838 let runtime = RuntimeEnvBuilder::default()
2840 .with_memory_limit(1, 1.0)
2841 .build_arc()?;
2842
2843 let task_ctx = TaskContext::default().with_runtime(runtime);
2844 let task_ctx = Arc::new(task_ctx);
2845
2846 let exec = TestMemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?;
2848 let exec = Arc::new(exec);
2849 let exec = Arc::new(TestMemoryExec::update_cache(&exec));
2850 let hash_expr = col("c0", &schema)?;
2852 let exec =
2853 RepartitionExec::try_new(exec, Partitioning::Hash(vec![hash_expr], 2))?;
2854
2855 let mut join_set = tokio::task::JoinSet::new();
2858 for i in 0..exec.partitioning().partition_count() {
2859 let stream = exec.execute(i, Arc::clone(&task_ctx))?;
2860 join_set.spawn(async move {
2861 let mut count = 0;
2862 futures::pin_mut!(stream);
2863 while let Some(result) = stream.next().await {
2864 let batch = result?;
2865 count += batch.num_rows();
2866 }
2867 Ok::<usize, DataFusionError>(count)
2868 });
2869 }
2870
2871 let mut total_rows = 0;
2873 while let Some(result) = join_set.join_next().await {
2874 total_rows += result.unwrap()?;
2875 }
2876
2877 let all_batches = [batch1, batch2, batch3, batch4];
2879 let expected_rows: usize = all_batches.iter().map(|b| b.num_rows()).sum();
2880 assert_eq!(total_rows, expected_rows);
2881
2882 let metrics = exec.metrics().unwrap();
2884 let spill_count = metrics.spill_count().unwrap_or(0);
2886 assert!(spill_count > 0);
2887 let spilled_bytes = metrics.spilled_bytes().unwrap_or(0);
2888 assert!(spilled_bytes > 0);
2889 let spilled_rows = metrics.spilled_rows().unwrap_or(0);
2890 assert!(spilled_rows > 0);
2891
2892 Ok(())
2893 }
2894
2895 #[tokio::test]
2896 async fn test_repartition() -> Result<()> {
2897 let schema = test_schema();
2898 let sort_exprs = sort_exprs(&schema);
2899 let source = sorted_memory_exec(&schema, sort_exprs);
2900 let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(10))?
2902 .repartitioned(20, &Default::default())?
2903 .unwrap();
2904
2905 assert_plan!(exec.as_ref(), @r"
2907 RepartitionExec: partitioning=RoundRobinBatch(20), input_partitions=1, maintains_sort_order=true
2908 DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
2909 ");
2910 Ok(())
2911 }
2912
2913 fn test_schema() -> Arc<Schema> {
2914 Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
2915 }
2916
2917 fn sort_exprs(schema: &Schema) -> LexOrdering {
2918 [PhysicalSortExpr {
2919 expr: col("c0", schema).unwrap(),
2920 options: SortOptions::default(),
2921 }]
2922 .into()
2923 }
2924
2925 fn memory_exec(schema: &SchemaRef) -> Arc<dyn ExecutionPlan> {
2926 TestMemoryExec::try_new_exec(&[vec![]], Arc::clone(schema), None).unwrap()
2927 }
2928
2929 fn sorted_memory_exec(
2930 schema: &SchemaRef,
2931 sort_exprs: LexOrdering,
2932 ) -> Arc<dyn ExecutionPlan> {
2933 let exec = TestMemoryExec::try_new(&[vec![]], Arc::clone(schema), None)
2934 .unwrap()
2935 .try_with_sort_information(vec![sort_exprs])
2936 .unwrap();
2937 let exec = Arc::new(exec);
2938 Arc::new(TestMemoryExec::update_cache(&exec))
2939 }
2940}