1use std::fmt::{Debug, Formatter};
23use std::pin::Pin;
24use std::sync::Arc;
25use std::task::{Context, Poll};
26use std::{any::Any, vec};
27
28use super::common::SharedMemoryReservation;
29use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
30use super::{
31 DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream,
32};
33use crate::coalesce::LimitedBatchCoalescer;
34use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType};
35use crate::hash_utils::create_hashes;
36use crate::metrics::{BaselineMetrics, SpillMetrics};
37use crate::projection::{ProjectionExec, all_columns, make_with_child, update_expr};
38use crate::sorts::streaming_merge::StreamingMergeBuilder;
39use crate::spill::spill_manager::SpillManager;
40use crate::spill::spill_pool::{self, SpillPoolWriter};
41use crate::stream::RecordBatchStreamAdapter;
42use crate::{DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Statistics};
43
44use arrow::array::{PrimitiveArray, RecordBatch, RecordBatchOptions};
45use arrow::compute::take_arrays;
46use arrow::datatypes::{SchemaRef, UInt32Type};
47use datafusion_common::config::ConfigOptions;
48use datafusion_common::stats::Precision;
49use datafusion_common::utils::transpose;
50use datafusion_common::{
51 ColumnStatistics, DataFusionError, HashMap, assert_or_internal_err, internal_err,
52};
53use datafusion_common::{Result, not_impl_err};
54use datafusion_common_runtime::SpawnedTask;
55use datafusion_execution::TaskContext;
56use datafusion_execution::memory_pool::MemoryConsumer;
57use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr};
58use datafusion_physical_expr_common::sort_expr::LexOrdering;
59
60use crate::filter_pushdown::{
61 ChildPushdownResult, FilterDescription, FilterPushdownPhase,
62 FilterPushdownPropagation,
63};
64use crate::joins::SeededRandomState;
65use crate::sort_pushdown::SortOrderPushdownResult;
66use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
67use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays;
68use futures::stream::Stream;
69use futures::{FutureExt, StreamExt, TryStreamExt, ready};
70use log::trace;
71use parking_lot::Mutex;
72
73mod distributor_channels;
74use distributor_channels::{
75 DistributionReceiver, DistributionSender, channels, partition_aware_channels,
76};
77
/// A batch traveling through a repartition channel: either kept in memory
/// (its size counted against the output partition's shared reservation) or
/// spilled to disk when the reservation could not grow.
#[derive(Debug)]
enum RepartitionBatch {
    /// Batch held in memory; the sender reserved its array memory size and the
    /// receiver shrinks the reservation when it takes the batch out.
    Memory(RecordBatch),
    /// Marker only: the batch was written to the partition's spill pool and
    /// must be read back from the corresponding spill stream.
    Spilled,
}
136
/// Item flowing through a distribution channel: `Some(Ok(..))` carries a
/// batch, `Some(Err(..))` an error, and `None` is the end-of-input marker sent
/// by one input task.
type MaybeBatch = Option<Result<RepartitionBatch>>;
/// One sender per input partition, all feeding a single output partition.
type InputPartitionsToCurrentPartitionSender = Vec<DistributionSender<MaybeBatch>>;
/// One receiver per input partition, for a single output partition.
type InputPartitionsToCurrentPartitionReceiver = Vec<DistributionReceiver<MaybeBatch>>;
140
/// Everything an input pull task needs to deliver batches to one output
/// partition: the channel sender, the partition's memory reservation, and the
/// spill writer used when the reservation cannot grow.
struct OutputChannel {
    /// Sender side of the distribution channel for this output partition.
    sender: DistributionSender<MaybeBatch>,
    /// Shared reservation tracking in-flight in-memory batches.
    reservation: SharedMemoryReservation,
    /// Writer for batches that had to be spilled to disk.
    spill_writer: SpillPoolWriter,
}
147
/// Per-output-partition channel bundle created by `consume_input_streams`.
struct PartitionChannels {
    /// Senders, one per input partition (clones of a single sender when order
    /// is not preserved).
    tx: InputPartitionsToCurrentPartitionSender,
    /// Receivers, one per input partition (a single receiver when order is
    /// not preserved).
    rx: InputPartitionsToCurrentPartitionReceiver,
    /// Memory reservation shared by senders (grow) and the receiver (shrink).
    reservation: SharedMemoryReservation,
    /// Spill writers; one per input partition when preserving order,
    /// otherwise exactly one.
    spill_writers: Vec<SpillPoolWriter>,
    /// Read sides paired with `spill_writers`.
    spill_readers: Vec<SendableRecordBatchStream>,
}
183
/// State once the input streams are being actively consumed: the per-output
/// channel map plus handles keeping the background pull tasks alive.
struct ConsumingInputStreamsState {
    /// Channels for each output partition; entries are removed as each output
    /// partition's stream takes ownership of its receiver.
    channels: HashMap<usize, PartitionChannels>,

    /// Task handles; dropping the last clone aborts the spawned tasks.
    abort_helper: Arc<Vec<SpawnedTask<()>>>,
}
192
193impl Debug for ConsumingInputStreamsState {
194 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
195 f.debug_struct("ConsumingInputStreamsState")
196 .field("num_channels", &self.channels.len())
197 .field("abort_helper", &self.abort_helper)
198 .finish()
199 }
200}
201
/// Lifecycle of the shared repartition state, advanced lazily by `execute`
/// and by the first poll of the returned stream.
#[derive(Default)]
enum RepartitionExecState {
    /// Initial state: the child plan has not been executed yet.
    #[default]
    NotInitialized,
    /// Child executed: one stream + metric set per input partition, not yet
    /// being consumed.
    InputStreamsInitialized(Vec<(SendableRecordBatchStream, RepartitionMetrics)>),
    /// Channels wired and background pull tasks spawned.
    ConsumingInputStreams(ConsumingInputStreamsState),
}
217
218impl Debug for RepartitionExecState {
219 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
220 match self {
221 RepartitionExecState::NotInitialized => write!(f, "NotInitialized"),
222 RepartitionExecState::InputStreamsInitialized(v) => {
223 write!(f, "InputStreamsInitialized({:?})", v.len())
224 }
225 RepartitionExecState::ConsumingInputStreams(v) => {
226 write!(f, "ConsumingInputStreams({v:?})")
227 }
228 }
229 }
230}
231
impl RepartitionExecState {
    /// Execute the child plan once per input partition and transition to
    /// `InputStreamsInitialized`. No-op unless currently `NotInitialized`.
    fn ensure_input_streams_initialized(
        &mut self,
        input: &Arc<dyn ExecutionPlan>,
        metrics: &ExecutionPlanMetricsSet,
        output_partitions: usize,
        ctx: &Arc<TaskContext>,
    ) -> Result<()> {
        if !matches!(self, RepartitionExecState::NotInitialized) {
            return Ok(());
        }

        let num_input_partitions = input.output_partitioning().partition_count();
        let mut streams_and_metrics = Vec::with_capacity(num_input_partitions);

        for i in 0..num_input_partitions {
            let metrics = RepartitionMetrics::new(i, output_partitions, metrics);

            // The child's `execute` call is counted as fetch time.
            let timer = metrics.fetch_time.timer();
            let stream = input.execute(i, Arc::clone(ctx))?;
            timer.done();

            streams_and_metrics.push((stream, metrics));
        }
        *self = RepartitionExecState::InputStreamsInitialized(streams_and_metrics);
        Ok(())
    }

    /// Wire up the channels and spawn one pull task per input partition,
    /// transitioning to `ConsumingInputStreams`. Idempotent: if streams are
    /// already being consumed, returns the existing state unchanged.
    #[expect(clippy::too_many_arguments)]
    fn consume_input_streams(
        &mut self,
        input: &Arc<dyn ExecutionPlan>,
        metrics: &ExecutionPlanMetricsSet,
        partitioning: &Partitioning,
        preserve_order: bool,
        name: &str,
        context: &Arc<TaskContext>,
        spill_manager: SpillManager,
    ) -> Result<&mut ConsumingInputStreamsState> {
        let streams_and_metrics = match self {
            RepartitionExecState::NotInitialized => {
                // `execute` normally initializes streams eagerly, but it may
                // have skipped this if the state lock was contended.
                self.ensure_input_streams_initialized(
                    input,
                    metrics,
                    partitioning.partition_count(),
                    context,
                )?;
                let RepartitionExecState::InputStreamsInitialized(value) = self else {
                    return internal_err!(
                        "Programming error: RepartitionExecState must be in the InputStreamsInitialized state after calling RepartitionExecState::ensure_input_streams_initialized"
                    );
                };
                value
            }
            RepartitionExecState::ConsumingInputStreams(value) => return Ok(value),
            RepartitionExecState::InputStreamsInitialized(value) => value,
        };

        let num_input_partitions = streams_and_metrics.len();
        let num_output_partitions = partitioning.partition_count();

        let spill_manager = Arc::new(spill_manager);

        // When preserving order every (input, output) pair needs a dedicated
        // channel so per-input ordering survives until the merge; otherwise
        // all inputs share one channel per output partition.
        let (txs, rxs) = if preserve_order {
            let (txs_all, rxs_all) =
                partition_aware_channels(num_input_partitions, num_output_partitions);
            // Transpose from [input][output] orientation to [output][input].
            let txs = transpose(txs_all);
            let rxs = transpose(rxs_all);
            (txs, rxs)
        } else {
            let (txs, rxs) = channels(num_output_partitions);
            // Clone each output's sender once per input partition.
            let txs = txs
                .into_iter()
                .map(|item| vec![item; num_input_partitions])
                .collect::<Vec<_>>();
            let rxs = rxs.into_iter().map(|item| vec![item]).collect::<Vec<_>>();
            (txs, rxs)
        };

        let mut channels = HashMap::with_capacity(txs.len());
        for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() {
            // Spillable per-output reservation: when `try_grow` fails, the
            // sender spills the batch instead of erroring out.
            let reservation = Arc::new(Mutex::new(
                MemoryConsumer::new(format!("{name}[{partition}]"))
                    .with_can_spill(true)
                    .register(context.memory_pool()),
            ));

            let max_file_size = context
                .session_config()
                .options()
                .execution
                .max_spill_file_size_bytes;
            // One spill channel per input when preserving order (keeps spilled
            // batches per-input-ordered), otherwise a single shared channel.
            let num_spill_channels = if preserve_order {
                num_input_partitions
            } else {
                1
            };
            let (spill_writers, spill_readers): (Vec<_>, Vec<_>) = (0
                ..num_spill_channels)
                .map(|_| spill_pool::channel(max_file_size, Arc::clone(&spill_manager)))
                .unzip();

            channels.insert(
                partition,
                PartitionChannels {
                    tx,
                    rx,
                    reservation,
                    spill_readers,
                    spill_writers,
                },
            );
        }

        // Launch one pull task per *input* partition.
        let mut spawned_tasks = Vec::with_capacity(num_input_partitions);
        for (i, (stream, metrics)) in
            std::mem::take(streams_and_metrics).into_iter().enumerate()
        {
            let txs: HashMap<_, _> = channels
                .iter()
                .map(|(partition, channels)| {
                    // Without order preservation there is only one spill
                    // channel, shared by every input task.
                    let spill_writer_idx = if preserve_order { i } else { 0 };
                    (
                        *partition,
                        OutputChannel {
                            sender: channels.tx[i].clone(),
                            reservation: Arc::clone(&channels.reservation),
                            spill_writer: channels.spill_writers[spill_writer_idx]
                                .clone(),
                        },
                    )
                })
                .collect();

            // Bare senders for the wait task to forward errors / completion.
            let senders: HashMap<_, _> = txs
                .iter()
                .map(|(partition, channel)| (*partition, channel.sender.clone()))
                .collect();

            let input_task = SpawnedTask::spawn(RepartitionExec::pull_from_input(
                stream,
                txs,
                partitioning.clone(),
                metrics,
                // Input index staggers the round-robin start; with dedicated
                // per-input channels (preserve_order) no stagger is needed.
                if preserve_order { 0 } else { i },
                num_input_partitions,
            ));

            // Companion task: waits for the puller to finish and notifies all
            // receivers of success (None marker) or failure (wrapped error).
            let wait_for_task =
                SpawnedTask::spawn(RepartitionExec::wait_for_task(input_task, senders));
            spawned_tasks.push(wait_for_task);
        }
        *self = Self::ConsumingInputStreams(ConsumingInputStreamsState {
            channels,
            abort_helper: Arc::new(spawned_tasks),
        });
        match self {
            RepartitionExecState::ConsumingInputStreams(value) => Ok(value),
            _ => unreachable!(),
        }
    }
}
412
/// Splits `RecordBatch`es into per-output-partition batches according to a
/// partitioning scheme (hash or round-robin), timing the work it performs.
pub struct BatchPartitioner {
    // Scheme-specific state (hash exprs + reusable hash buffer, or the
    // round-robin cursor).
    state: BatchPartitionerState,
    // Timer charged with hashing / batch-building work.
    timer: metrics::Time,
}
418
/// Scheme-specific state for [`BatchPartitioner`].
enum BatchPartitionerState {
    /// Route each row by `hash(exprs) % num_partitions`.
    Hash {
        exprs: Vec<Arc<dyn PhysicalExpr>>,
        num_partitions: usize,
        /// Reused buffer of per-row hashes, resized for each batch.
        hash_buffer: Vec<u64>,
    },
    /// Send each whole batch to the next partition in rotation.
    RoundRobin {
        num_partitions: usize,
        /// Output partition that receives the next batch.
        next_idx: usize,
    },
}
430
/// Fixed-seed random state so repartition hashing is deterministic across
/// tasks and executions (every partitioner must agree on row placement).
pub const REPARTITION_RANDOM_STATE: SeededRandomState =
    SeededRandomState::with_seeds(0, 0, 0, 0);
435
436impl BatchPartitioner {
437 pub fn new_hash_partitioner(
447 exprs: Vec<Arc<dyn PhysicalExpr>>,
448 num_partitions: usize,
449 timer: metrics::Time,
450 ) -> Self {
451 Self {
452 state: BatchPartitionerState::Hash {
453 exprs,
454 num_partitions,
455 hash_buffer: vec![],
456 },
457 timer,
458 }
459 }
460
461 pub fn new_round_robin_partitioner(
473 num_partitions: usize,
474 timer: metrics::Time,
475 input_partition: usize,
476 num_input_partitions: usize,
477 ) -> Self {
478 Self {
479 state: BatchPartitionerState::RoundRobin {
480 num_partitions,
481 next_idx: (input_partition * num_partitions) / num_input_partitions,
482 },
483 timer,
484 }
485 }
486 pub fn try_new(
500 partitioning: Partitioning,
501 timer: metrics::Time,
502 input_partition: usize,
503 num_input_partitions: usize,
504 ) -> Result<Self> {
505 match partitioning {
506 Partitioning::Hash(exprs, num_partitions) => {
507 Ok(Self::new_hash_partitioner(exprs, num_partitions, timer))
508 }
509 Partitioning::RoundRobinBatch(num_partitions) => {
510 Ok(Self::new_round_robin_partitioner(
511 num_partitions,
512 timer,
513 input_partition,
514 num_input_partitions,
515 ))
516 }
517 other => {
518 not_impl_err!("Unsupported repartitioning scheme {other:?}")
519 }
520 }
521 }
522
523 pub fn partition<F>(&mut self, batch: RecordBatch, mut f: F) -> Result<()>
533 where
534 F: FnMut(usize, RecordBatch) -> Result<()>,
535 {
536 self.partition_iter(batch)?.try_for_each(|res| match res {
537 Ok((partition, batch)) => f(partition, batch),
538 Err(e) => Err(e),
539 })
540 }
541
542 fn partition_iter(
548 &mut self,
549 batch: RecordBatch,
550 ) -> Result<impl Iterator<Item = Result<(usize, RecordBatch)>> + Send + '_> {
551 let it: Box<dyn Iterator<Item = Result<(usize, RecordBatch)>> + Send> =
552 match &mut self.state {
553 BatchPartitionerState::RoundRobin {
554 num_partitions,
555 next_idx,
556 } => {
557 let idx = *next_idx;
558 *next_idx = (*next_idx + 1) % *num_partitions;
559 Box::new(std::iter::once(Ok((idx, batch))))
560 }
561 BatchPartitionerState::Hash {
562 exprs,
563 num_partitions: partitions,
564 hash_buffer,
565 } => {
566 let timer = self.timer.timer();
568
569 let arrays =
570 evaluate_expressions_to_arrays(exprs.as_slice(), &batch)?;
571
572 hash_buffer.clear();
573 hash_buffer.resize(batch.num_rows(), 0);
574
575 create_hashes(
576 &arrays,
577 REPARTITION_RANDOM_STATE.random_state(),
578 hash_buffer,
579 )?;
580
581 let mut indices: Vec<_> = (0..*partitions)
582 .map(|_| Vec::with_capacity(batch.num_rows()))
583 .collect();
584
585 for (index, hash) in hash_buffer.iter().enumerate() {
586 indices[(*hash % *partitions as u64) as usize].push(index as u32);
587 }
588
589 timer.done();
591
592 let partitioner_timer = &self.timer;
594 let it = indices
595 .into_iter()
596 .enumerate()
597 .filter_map(|(partition, indices)| {
598 let indices: PrimitiveArray<UInt32Type> = indices.into();
599 (!indices.is_empty()).then_some((partition, indices))
600 })
601 .map(move |(partition, indices)| {
602 let _timer = partitioner_timer.timer();
604
605 let columns = take_arrays(batch.columns(), &indices, None)?;
607
608 let mut options = RecordBatchOptions::new();
609 options = options.with_row_count(Some(indices.len()));
610 let batch = RecordBatch::try_new_with_options(
611 batch.schema(),
612 columns,
613 &options,
614 )
615 .unwrap();
616
617 Ok((partition, batch))
618 });
619
620 Box::new(it)
621 }
622 };
623
624 Ok(it)
625 }
626
627 fn num_partitions(&self) -> usize {
629 match self.state {
630 BatchPartitionerState::RoundRobin { num_partitions, .. } => num_partitions,
631 BatchPartitionerState::Hash { num_partitions, .. } => num_partitions,
632 }
633 }
634}
635
/// Physical operator that maps N input partitions to M output partitions
/// according to a [`Partitioning`] scheme, optionally merging sorted inputs
/// to preserve their ordering.
#[derive(Debug, Clone)]
pub struct RepartitionExec {
    /// Child plan producing the batches to redistribute.
    input: Arc<dyn ExecutionPlan>,
    /// Lazily initialized channel/task state shared by all output partitions.
    state: Arc<Mutex<RepartitionExecState>>,
    /// Execution metrics.
    metrics: ExecutionPlanMetricsSet,
    /// If true, merge the per-input streams instead of interleaving them.
    preserve_order: bool,
    /// Cached plan properties (partitioning, equivalences, …).
    cache: PlanProperties,
}
749
/// Timers recorded for a single input partition of the repartition.
#[derive(Debug, Clone)]
struct RepartitionMetrics {
    /// Time spent executing the child and pulling batches from its stream.
    fetch_time: metrics::Time,
    /// Time spent hashing/splitting batches into output partitions.
    repartition_time: metrics::Time,
    /// Time spent sending to each output partition (one timer per output).
    send_time: Vec<metrics::Time>,
}
761
762impl RepartitionMetrics {
763 pub fn new(
764 input_partition: usize,
765 num_output_partitions: usize,
766 metrics: &ExecutionPlanMetricsSet,
767 ) -> Self {
768 let fetch_time =
770 MetricBuilder::new(metrics).subset_time("fetch_time", input_partition);
771
772 let repartition_time =
774 MetricBuilder::new(metrics).subset_time("repartition_time", input_partition);
775
776 let send_time = (0..num_output_partitions)
778 .map(|output_partition| {
779 let label =
780 metrics::Label::new("outputPartition", output_partition.to_string());
781 MetricBuilder::new(metrics)
782 .with_label(label)
783 .subset_time("send_time", input_partition)
784 })
785 .collect();
786
787 Self {
788 fetch_time,
789 repartition_time,
790 send_time,
791 }
792 }
793}
794
795impl RepartitionExec {
796 pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
798 &self.input
799 }
800
801 pub fn partitioning(&self) -> &Partitioning {
803 &self.cache.partitioning
804 }
805
806 pub fn preserve_order(&self) -> bool {
809 self.preserve_order
810 }
811
812 pub fn name(&self) -> &str {
814 "RepartitionExec"
815 }
816}
817
impl DisplayAs for RepartitionExec {
    /// Render this node for `EXPLAIN` output in the requested format.
    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
        let input_partition_count = self.input.output_partitioning().partition_count();
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(
                    f,
                    "{}: partitioning={}, input_partitions={}",
                    self.name(),
                    self.partitioning(),
                    input_partition_count,
                )?;

                if self.preserve_order {
                    write!(f, ", preserve_order=true")?;
                } else if input_partition_count <= 1
                    && self.input.output_ordering().is_some()
                {
                    // A single sorted input stays sorted without merging (see
                    // `maintains_input_order_helper`).
                    write!(f, ", maintains_sort_order=true")?;
                }

                if let Some(sort_exprs) = self.sort_exprs() {
                    write!(f, ", sort_exprs={}", sort_exprs.clone())?;
                }
                Ok(())
            }
            DisplayFormatType::TreeRender => {
                writeln!(f, "partitioning_scheme={}", self.partitioning(),)?;
                let output_partition_count = self.partitioning().partition_count();
                let input_to_output_partition_str =
                    format!("{input_partition_count} -> {output_partition_count}");
                writeln!(
                    f,
                    "partition_count(in->out)={input_to_output_partition_str}"
                )?;

                if self.preserve_order {
                    writeln!(f, "preserve_order={}", self.preserve_order)?;
                }
                Ok(())
            }
        }
    }
}
864
impl ExecutionPlan for RepartitionExec {
    fn name(&self) -> &'static str {
        "RepartitionExec"
    }

    /// Returns `self` as `Any` so callers can downcast.
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    /// Rebuild this node over a new child, carrying over the partitioning and
    /// the order-preservation flag.
    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let mut repartition = RepartitionExec::try_new(
            children.swap_remove(0),
            self.partitioning().clone(),
        )?;
        if self.preserve_order {
            repartition = repartition.with_preserve_order();
        }
        Ok(Arc::new(repartition))
    }

    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        // Only hash partitioning performs per-row work worth parallelizing.
        vec![matches!(self.partitioning(), Partitioning::Hash(_, _))]
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        Self::maintains_input_order_helper(self.input(), self.preserve_order)
    }

    /// Return the stream for output `partition`.
    ///
    /// Child execution happens eagerly (if the state lock is uncontended);
    /// channel wiring and task spawning are deferred to the first poll of the
    /// returned stream via `futures::stream::once`.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        trace!(
            "Start {}::execute for partition: {}",
            self.name(),
            partition
        );

        let spill_metrics = SpillMetrics::new(&self.metrics, partition);

        // Clone everything the deferred setup closure needs to own.
        let input = Arc::clone(&self.input);
        let partitioning = self.partitioning().clone();
        let metrics = self.metrics.clone();
        let preserve_order = self.sort_exprs().is_some();
        let name = self.name().to_owned();
        let schema = self.schema();
        let schema_captured = Arc::clone(&schema);

        let spill_manager = SpillManager::new(
            Arc::clone(&context.runtime_env()),
            spill_metrics,
            input.schema(),
        );

        let sort_exprs = self.sort_exprs().cloned();

        // Eagerly execute the child if no other output partition holds the
        // lock; otherwise initialization happens inside the stream below.
        let state = Arc::clone(&self.state);
        if let Some(mut state) = state.try_lock() {
            state.ensure_input_streams_initialized(
                &input,
                &metrics,
                partitioning.partition_count(),
                &context,
            )?;
        }

        let num_input_partitions = input.output_partitioning().partition_count();

        let stream = futures::stream::once(async move {
            // Lazily wire channels / spawn tasks (idempotent across outputs),
            // then take this output partition's receiving end out of the map.
            let (rx, reservation, spill_readers, abort_helper) = {
                let mut state = state.lock();
                let state = state.consume_input_streams(
                    &input,
                    &metrics,
                    &partitioning,
                    preserve_order,
                    &name,
                    &context,
                    spill_manager.clone(),
                )?;

                let PartitionChannels {
                    rx,
                    reservation,
                    spill_readers,
                    ..
                } = state
                    .channels
                    .remove(&partition)
                    .expect("partition not used yet");

                (
                    rx,
                    reservation,
                    spill_readers,
                    Arc::clone(&state.abort_helper),
                )
            };

            trace!(
                "Before returning stream in {name}::execute for partition: {partition}"
            );

            if preserve_order {
                // One stream per input partition, merged by sort key so the
                // input ordering is preserved.
                let input_streams = rx
                    .into_iter()
                    .zip(spill_readers)
                    .map(|(receiver, spill_stream)| {
                        Box::pin(PerPartitionStream::new(
                            Arc::clone(&schema_captured),
                            receiver,
                            Arc::clone(&abort_helper),
                            Arc::clone(&reservation),
                            spill_stream,
                            // Exactly one sender feeds each per-input channel.
                            1,
                            BaselineMetrics::new(&metrics, partition),
                            // No coalescing: the merge handles batch sizing.
                            None,
                        )) as SendableRecordBatchStream
                    })
                    .collect::<Vec<_>>();
                let fetch = None;
                let merge_reservation =
                    MemoryConsumer::new(format!("{name}[Merge {partition}]"))
                        .register(context.memory_pool());
                StreamingMergeBuilder::new()
                    .with_streams(input_streams)
                    .with_schema(schema_captured)
                    .with_expressions(&sort_exprs.unwrap())
                    .with_metrics(BaselineMetrics::new(&metrics, partition))
                    .with_batch_size(context.session_config().batch_size())
                    .with_fetch(fetch)
                    .with_reservation(merge_reservation)
                    .with_spill_manager(spill_manager)
                    .build()
            } else {
                // Single shared channel / spill pool for this output.
                let spill_stream = spill_readers
                    .into_iter()
                    .next()
                    .expect("at least one spill reader should exist");

                Ok(Box::pin(PerPartitionStream::new(
                    schema_captured,
                    rx.into_iter()
                        .next()
                        .expect("at least one receiver should exist"),
                    abort_helper,
                    reservation,
                    spill_stream,
                    num_input_partitions,
                    BaselineMetrics::new(&metrics, partition),
                    // Coalesce small sub-batches up to the session batch size.
                    Some(context.session_config().batch_size()),
                )) as SendableRecordBatchStream)
            }
        })
        .try_flatten();
        let stream = RecordBatchStreamAdapter::new(schema, stream);
        Ok(Box::pin(stream))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        self.input.partition_statistics(None)
    }

    /// Per-partition statistics: totals divided evenly across the output
    /// partitions (inexact), column statistics reset to unknown.
    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
        if let Some(partition) = partition {
            let partition_count = self.partitioning().partition_count();
            if partition_count == 0 {
                return Ok(Statistics::new_unknown(&self.schema()));
            }

            assert_or_internal_err!(
                partition < partition_count,
                "RepartitionExec invalid partition {} (expected less than {})",
                partition,
                partition_count
            );

            let mut stats = self.input.partition_statistics(None)?;

            // Assume rows and bytes distribute evenly over the outputs.
            stats.num_rows = stats
                .num_rows
                .get_value()
                .map(|rows| Precision::Inexact(rows / partition_count))
                .unwrap_or(Precision::Absent);
            stats.total_byte_size = stats
                .total_byte_size
                .get_value()
                .map(|bytes| Precision::Inexact(bytes / partition_count))
                .unwrap_or(Precision::Absent);

            // Column-level stats do not survive redistribution.
            stats.column_statistics = stats
                .column_statistics
                .iter()
                .map(|_| ColumnStatistics::new_unknown())
                .collect();

            Ok(stats)
        } else {
            self.input.partition_statistics(None)
        }
    }

    fn cardinality_effect(&self) -> CardinalityEffect {
        // Repartitioning moves rows; it never adds or removes them.
        CardinalityEffect::Equal
    }

    /// Try to push a projection below this repartition (projection first,
    /// repartition on the narrower schema). Only column-only, schema-narrowing
    /// projections qualify; hash expressions are rewritten accordingly.
    fn try_swapping_with_projection(
        &self,
        projection: &ProjectionExec,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        // Not narrowing the schema -> nothing to gain.
        if projection.expr().len() >= projection.input().schema().fields().len() {
            return Ok(None);
        }

        if projection.benefits_from_input_partitioning()[0]
            || !all_columns(projection.expr())
        {
            return Ok(None);
        }

        let new_projection = make_with_child(projection, self.input())?;

        let new_partitioning = match self.partitioning() {
            Partitioning::Hash(partitions, size) => {
                // Re-express each hash expression against the projected schema;
                // bail out if any expression cannot be rewritten.
                let mut new_partitions = vec![];
                for partition in partitions {
                    let Some(new_partition) =
                        update_expr(partition, projection.expr(), false)?
                    else {
                        return Ok(None);
                    };
                    new_partitions.push(new_partition);
                }
                Partitioning::Hash(new_partitions, *size)
            }
            others => others.clone(),
        };

        Ok(Some(Arc::new(RepartitionExec::try_new(
            new_projection,
            new_partitioning,
        )?)))
    }

    /// Filters pass through to the child unchanged.
    fn gather_filters_for_pushdown(
        &self,
        _phase: FilterPushdownPhase,
        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
        _config: &ConfigOptions,
    ) -> Result<FilterDescription> {
        FilterDescription::from_children(parent_filters, &self.children())
    }

    fn handle_child_pushdown_result(
        &self,
        _phase: FilterPushdownPhase,
        child_pushdown_result: ChildPushdownResult,
        _config: &ConfigOptions,
    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
        Ok(FilterPushdownPropagation::if_all(child_pushdown_result))
    }

    /// Push a sort requirement through to the child, but only if this node
    /// maintains its input's order (otherwise the sort would be destroyed).
    fn try_pushdown_sort(
        &self,
        order: &[PhysicalSortExpr],
    ) -> Result<SortOrderPushdownResult<Arc<dyn ExecutionPlan>>> {
        if !self.maintains_input_order()[0] {
            return Ok(SortOrderPushdownResult::Unsupported);
        }

        self.input.try_pushdown_sort(order)?.try_map(|new_input| {
            let mut new_repartition =
                RepartitionExec::try_new(new_input, self.partitioning().clone())?;
            if self.preserve_order {
                new_repartition = new_repartition.with_preserve_order();
            }
            Ok(Arc::new(new_repartition) as Arc<dyn ExecutionPlan>)
        })
    }

    /// Produce a copy of this node with a new output partition count, keeping
    /// the same partitioning scheme.
    fn repartitioned(
        &self,
        target_partitions: usize,
        _config: &ConfigOptions,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        use Partitioning::*;
        let mut new_properties = self.cache.clone();
        new_properties.partitioning = match new_properties.partitioning {
            RoundRobinBatch(_) => RoundRobinBatch(target_partitions),
            Hash(hash, _) => Hash(hash, target_partitions),
            UnknownPartitioning(_) => UnknownPartitioning(target_partitions),
        };
        Ok(Some(Arc::new(Self {
            input: Arc::clone(&self.input),
            // NOTE(review): this shares the existing execution-state Arc —
            // presumably safe because repartitioning happens before execution;
            // confirm.
            state: Arc::clone(&self.state),
            metrics: self.metrics.clone(),
            preserve_order: self.preserve_order,
            cache: new_properties,
        })))
    }
}
1203
impl RepartitionExec {
    /// Create a repartition node over `input` with the given `partitioning`.
    /// Order preservation is off; enable it with [`Self::with_preserve_order`].
    /// (Construction itself cannot fail here; the `Result` keeps the API
    /// uniform with other constructors.)
    pub fn try_new(
        input: Arc<dyn ExecutionPlan>,
        partitioning: Partitioning,
    ) -> Result<Self> {
        let preserve_order = false;
        let cache = Self::compute_properties(&input, partitioning, preserve_order);
        Ok(RepartitionExec {
            input,
            state: Default::default(),
            metrics: ExecutionPlanMetricsSet::new(),
            preserve_order,
            cache,
        })
    }

    fn maintains_input_order_helper(
        input: &Arc<dyn ExecutionPlan>,
        preserve_order: bool,
    ) -> Vec<bool> {
        // Ordering survives if we merge-preserve it, or there is at most one
        // input partition so nothing gets interleaved.
        vec![preserve_order || input.output_partitioning().partition_count() <= 1]
    }

    fn eq_properties_helper(
        input: &Arc<dyn ExecutionPlan>,
        preserve_order: bool,
    ) -> EquivalenceProperties {
        let mut eq_properties = input.equivalence_properties().clone();
        // Interleaving multiple partitions destroys any input orderings.
        if !Self::maintains_input_order_helper(input, preserve_order)[0] {
            eq_properties.clear_orderings();
        }
        // Constants that held only within one input partition no longer hold
        // once partitions are mixed.
        if input.output_partitioning().partition_count() > 1 {
            eq_properties.clear_per_partition_constants();
        }
        eq_properties
    }

    fn compute_properties(
        input: &Arc<dyn ExecutionPlan>,
        partitioning: Partitioning,
        preserve_order: bool,
    ) -> PlanProperties {
        PlanProperties::new(
            Self::eq_properties_helper(input, preserve_order),
            partitioning,
            input.pipeline_behavior(),
            input.boundedness(),
        )
        .with_scheduling_type(SchedulingType::Cooperative)
        .with_evaluation_type(EvaluationType::Eager)
    }

    /// Request order preservation. The flag only actually turns on when the
    /// input is sorted and has more than one partition (otherwise merging is
    /// pointless).
    pub fn with_preserve_order(mut self) -> Self {
        self.preserve_order =
            self.input.output_ordering().is_some() &&
            self.input.output_partitioning().partition_count() > 1;
        let eq_properties = Self::eq_properties_helper(&self.input, self.preserve_order);
        self.cache = self.cache.with_eq_properties(eq_properties);
        self
    }

    /// The input's sort expressions, but only when order is being preserved.
    fn sort_exprs(&self) -> Option<&LexOrdering> {
        if self.preserve_order {
            self.input.output_ordering()
        } else {
            None
        }
    }

    /// Background task body: pull batches from one input `stream`, partition
    /// each, and send the pieces (or spill markers) to the output channels.
    /// Runs until the input ends or every receiver has hung up.
    async fn pull_from_input(
        mut stream: SendableRecordBatchStream,
        mut output_channels: HashMap<usize, OutputChannel>,
        partitioning: Partitioning,
        metrics: RepartitionMetrics,
        input_partition: usize,
        num_input_partitions: usize,
    ) -> Result<()> {
        let mut partitioner = match &partitioning {
            Partitioning::Hash(exprs, num_partitions) => {
                BatchPartitioner::new_hash_partitioner(
                    exprs.clone(),
                    *num_partitions,
                    metrics.repartition_time.clone(),
                )
            }
            Partitioning::RoundRobinBatch(num_partitions) => {
                BatchPartitioner::new_round_robin_partitioner(
                    *num_partitions,
                    metrics.repartition_time.clone(),
                    input_partition,
                    num_input_partitions,
                )
            }
            other => {
                return not_impl_err!("Unsupported repartitioning scheme {other:?}");
            }
        };

        // Yield to the runtime every `num_partitions` batches so a long run of
        // ready input cannot starve other tasks.
        let mut batches_until_yield = partitioner.num_partitions();
        while !output_channels.is_empty() {
            // Fetch the next input batch (timed).
            let timer = metrics.fetch_time.timer();
            let result = stream.next().await;
            timer.done();

            let batch = match result {
                Some(result) => result?,
                None => break,
            };

            // Skip empty batches entirely.
            if batch.num_rows() == 0 {
                continue;
            }

            for res in partitioner.partition_iter(batch)? {
                let (partition, batch) = res?;
                let size = batch.get_array_memory_size();

                let timer = metrics.send_time[partition].timer();
                // The channel may already have been removed after a hang-up.
                if let Some(channel) = output_channels.get_mut(&partition) {
                    // Keep the batch in memory if the reservation can grow;
                    // otherwise spill it and send only a marker.
                    let (batch_to_send, is_memory_batch) =
                        match channel.reservation.lock().try_grow(size) {
                            Ok(_) => {
                                (RepartitionBatch::Memory(batch), true)
                            }
                            Err(_) => {
                                channel.spill_writer.push_batch(&batch)?;
                                (RepartitionBatch::Spilled, false)
                            }
                        };

                    if channel.sender.send(Some(Ok(batch_to_send))).await.is_err() {
                        // Receiver dropped: release the reservation we just
                        // grew (spilled batches reserved nothing) and stop
                        // sending to this partition.
                        if is_memory_batch {
                            channel.reservation.lock().shrink(size);
                        }
                        output_channels.remove(&partition);
                    }
                }
                timer.done();
            }

            if batches_until_yield == 0 {
                tokio::task::yield_now().await;
                batches_until_yield = partitioner.num_partitions();
            } else {
                batches_until_yield -= 1;
            }
        }

        Ok(())
    }

    /// Wait for `input_task` and notify every receiver: a wrapped error on
    /// failure/panic, or a `None` end-of-input marker on clean completion.
    async fn wait_for_task(
        input_task: SpawnedTask<Result<()>>,
        txs: HashMap<usize, DistributionSender<MaybeBatch>>,
    ) {
        match input_task.join().await {
            // The task panicked or was cancelled.
            Err(e) => {
                let e = Arc::new(e);

                for (_, tx) in txs {
                    let err = Err(DataFusionError::Context(
                        "Join Error".to_string(),
                        Box::new(DataFusionError::External(Box::new(Arc::clone(&e)))),
                    ));
                    // Receivers may already be gone; ignore send failures.
                    tx.send(Some(err)).await.ok();
                }
            }
            // The task ran to completion but returned an error.
            Ok(Err(e)) => {
                let e = Arc::new(e);

                for (_, tx) in txs {
                    let err = Err(DataFusionError::from(&e));
                    tx.send(Some(err)).await.ok();
                }
            }
            // Clean completion: signal end-of-input to every output partition.
            Ok(Ok(())) => {
                for (_partition, tx) in txs {
                    tx.send(None).await.ok();
                }
            }
        }
    }
}
1454
/// Which source `PerPartitionStream` is currently reading from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum StreamState {
    /// Reading in-memory batches (and control markers) from the channel.
    ReadingMemory,
    /// A `Spilled` marker arrived; reading the next batch from the spill
    /// stream before returning to the channel.
    ReadingSpilled,
}
1505
/// Output stream for one partition: consumes batches sent by the input pull
/// tasks, transparently reading spilled batches back from disk, and
/// optionally coalescing small batches to the configured batch size.
struct PerPartitionStream {
    /// Schema of the produced batches.
    schema: SchemaRef,

    /// Receiver for `MaybeBatch` items from the input task(s).
    receiver: DistributionReceiver<MaybeBatch>,

    /// Keeps the background pull tasks alive while this stream exists.
    _drop_helper: Arc<Vec<SpawnedTask<()>>>,

    /// Shared reservation; shrunk as in-memory batches are taken out.
    reservation: SharedMemoryReservation,

    /// Stream of batches that were spilled to disk.
    spill_stream: SendableRecordBatchStream,

    /// Whether the next batch comes from memory or from the spill stream.
    state: StreamState,

    /// Input senders that have not yet sent their `None` completion marker.
    remaining_partitions: usize,

    /// Baseline metrics (elapsed-compute timing) for this stream.
    baseline_metrics: BaselineMetrics,

    /// Optional coalescer combining small batches up to a target size; `None`
    /// disables coalescing (used under the order-preserving merge).
    batch_coalescer: Option<LimitedBatchCoalescer>,
}
1538
impl PerPartitionStream {
    /// Create a stream over `receiver`, switching to `spill_stream` whenever
    /// a `Spilled` marker is received. `batch_size: Some(n)` enables output
    /// coalescing toward `n`-row batches; `None` disables it.
    #[expect(clippy::too_many_arguments)]
    fn new(
        schema: SchemaRef,
        receiver: DistributionReceiver<MaybeBatch>,
        drop_helper: Arc<Vec<SpawnedTask<()>>>,
        reservation: SharedMemoryReservation,
        spill_stream: SendableRecordBatchStream,
        num_input_partitions: usize,
        baseline_metrics: BaselineMetrics,
        batch_size: Option<usize>,
    ) -> Self {
        let batch_coalescer =
            batch_size.map(|s| LimitedBatchCoalescer::new(Arc::clone(&schema), s, None));
        Self {
            schema,
            receiver,
            _drop_helper: drop_helper,
            reservation,
            spill_stream,
            state: StreamState::ReadingMemory,
            remaining_partitions: num_input_partitions,
            baseline_metrics,
            batch_coalescer,
        }
    }

    /// Core polling loop: alternates between the in-memory channel and the
    /// spill stream according to `self.state`.
    fn poll_next_inner(
        self: &mut Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        use futures::StreamExt;
        let cloned_time = self.baseline_metrics.elapsed_compute().clone();
        let _timer = cloned_time.timer();

        loop {
            match self.state {
                StreamState::ReadingMemory => {
                    let value = match self.receiver.recv().poll_unpin(cx) {
                        Poll::Ready(v) => v,
                        Poll::Pending => {
                            return Poll::Pending;
                        }
                    };

                    match value {
                        Some(Some(v)) => match v {
                            Ok(RepartitionBatch::Memory(batch)) => {
                                // The batch left the channel: release its
                                // reservation before handing it to the caller.
                                self.reservation
                                    .lock()
                                    .shrink(batch.get_array_memory_size());
                                return Poll::Ready(Some(Ok(batch)));
                            }
                            Ok(RepartitionBatch::Spilled) => {
                                // Marker only; read the actual batch back from
                                // the spill stream.
                                self.state = StreamState::ReadingSpilled;
                                continue;
                            }
                            Err(e) => {
                                return Poll::Ready(Some(Err(e)));
                            }
                        },
                        Some(None) => {
                            // One input sender finished; the stream is done
                            // once all of them have.
                            self.remaining_partitions -= 1;
                            if self.remaining_partitions == 0 {
                                return Poll::Ready(None);
                            }
                            continue;
                        }
                        None => {
                            // Channel closed entirely.
                            return Poll::Ready(None);
                        }
                    }
                }
                StreamState::ReadingSpilled => {
                    match self.spill_stream.poll_next_unpin(cx) {
                        Poll::Ready(Some(Ok(batch))) => {
                            // Got the spilled batch; resume channel reads.
                            self.state = StreamState::ReadingMemory;
                            return Poll::Ready(Some(Ok(batch)));
                        }
                        Poll::Ready(Some(Err(e))) => {
                            return Poll::Ready(Some(Err(e)));
                        }
                        Poll::Ready(None) => {
                            // Spill stream exhausted; go back to the channel.
                            self.state = StreamState::ReadingMemory;
                        }
                        Poll::Pending => {
                            // Spilled data not yet readable. NOTE(review):
                            // relies on the spill stream registering the waker
                            // — confirm in `spill_pool`.
                            return Poll::Pending;
                        }
                    }
                }
            }
        }
    }

    /// Drive `poll_next_inner` through `coalescer` so callers receive batches
    /// of roughly the configured size.
    fn poll_next_and_coalesce(
        self: &mut Pin<&mut Self>,
        cx: &mut Context<'_>,
        coalescer: &mut LimitedBatchCoalescer,
    ) -> Poll<Option<Result<RecordBatch>>> {
        let cloned_time = self.baseline_metrics.elapsed_compute().clone();
        let mut completed = false;

        loop {
            // Emit a completed batch as soon as the coalescer has one.
            if let Some(batch) = coalescer.next_completed_batch() {
                return Poll::Ready(Some(Ok(batch)));
            }
            if completed {
                return Poll::Ready(None);
            }

            match ready!(self.poll_next_inner(cx)) {
                Some(Ok(batch)) => {
                    let _timer = cloned_time.timer();
                    if let Err(err) = coalescer.push_batch(batch) {
                        return Poll::Ready(Some(Err(err)));
                    }
                }
                Some(err) => {
                    return Poll::Ready(Some(err));
                }
                None => {
                    // Input exhausted: flush the coalescer, then loop once more
                    // to emit any final partial batch.
                    completed = true;
                    let _timer = cloned_time.timer();
                    if let Err(err) = coalescer.finish() {
                        return Poll::Ready(Some(Err(err)));
                    }
                }
            }
        }
    }
}
1684
1685impl Stream for PerPartitionStream {
1686 type Item = Result<RecordBatch>;
1687
1688 fn poll_next(
1689 mut self: Pin<&mut Self>,
1690 cx: &mut Context<'_>,
1691 ) -> Poll<Option<Self::Item>> {
1692 let poll;
1693 if let Some(mut coalescer) = self.batch_coalescer.take() {
1694 poll = self.poll_next_and_coalesce(cx, &mut coalescer);
1695 self.batch_coalescer = Some(coalescer);
1696 } else {
1697 poll = self.poll_next_inner(cx);
1698 }
1699 self.baseline_metrics.record_poll(poll)
1700 }
1701}
1702
impl RecordBatchStream for PerPartitionStream {
    /// Returns the schema of the batches produced by this stream.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
1709
1710#[cfg(test)]
1711mod tests {
1712 use std::collections::HashSet;
1713
1714 use super::*;
1715 use crate::test::TestMemoryExec;
1716 use crate::{
1717 test::{
1718 assert_is_pending,
1719 exec::{
1720 BarrierExec, BlockingExec, ErrorExec, MockExec,
1721 assert_strong_count_converges_to_zero,
1722 },
1723 },
1724 {collect, expressions::col},
1725 };
1726
1727 use arrow::array::{ArrayRef, StringArray, UInt32Array};
1728 use arrow::datatypes::{DataType, Field, Schema};
1729 use datafusion_common::cast::as_string_array;
1730 use datafusion_common::exec_err;
1731 use datafusion_common::test_util::batches_to_sort_string;
1732 use datafusion_common_runtime::JoinSet;
1733 use datafusion_execution::config::SessionConfig;
1734 use datafusion_execution::runtime_env::RuntimeEnvBuilder;
1735 use insta::assert_snapshot;
1736
1737 #[tokio::test]
1738 async fn one_to_many_round_robin() -> Result<()> {
1739 let schema = test_schema();
1741 let partition = create_vec_batches(50);
1742 let partitions = vec![partition];
1743
1744 let output_partitions =
1746 repartition(&schema, partitions, Partitioning::RoundRobinBatch(4)).await?;
1747
1748 assert_eq!(4, output_partitions.len());
1749 for partition in &output_partitions {
1750 assert_eq!(1, partition.len());
1751 }
1752 assert_eq!(13 * 8, output_partitions[0][0].num_rows());
1753 assert_eq!(13 * 8, output_partitions[1][0].num_rows());
1754 assert_eq!(12 * 8, output_partitions[2][0].num_rows());
1755 assert_eq!(12 * 8, output_partitions[3][0].num_rows());
1756
1757 Ok(())
1758 }
1759
1760 #[tokio::test]
1761 async fn many_to_one_round_robin() -> Result<()> {
1762 let schema = test_schema();
1764 let partition = create_vec_batches(50);
1765 let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
1766
1767 let output_partitions =
1769 repartition(&schema, partitions, Partitioning::RoundRobinBatch(1)).await?;
1770
1771 assert_eq!(1, output_partitions.len());
1772 assert_eq!(150 * 8, output_partitions[0][0].num_rows());
1773
1774 Ok(())
1775 }
1776
1777 #[tokio::test]
1778 async fn many_to_many_round_robin() -> Result<()> {
1779 let schema = test_schema();
1781 let partition = create_vec_batches(50);
1782 let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
1783
1784 let output_partitions =
1786 repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await?;
1787
1788 let total_rows_per_partition = 8 * 50 * 3 / 5;
1789 assert_eq!(5, output_partitions.len());
1790 for partition in output_partitions {
1791 assert_eq!(1, partition.len());
1792 assert_eq!(total_rows_per_partition, partition[0].num_rows());
1793 }
1794
1795 Ok(())
1796 }
1797
1798 #[tokio::test]
1799 async fn many_to_many_hash_partition() -> Result<()> {
1800 let schema = test_schema();
1802 let partition = create_vec_batches(50);
1803 let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
1804
1805 let output_partitions = repartition(
1806 &schema,
1807 partitions,
1808 Partitioning::Hash(vec![col("c0", &schema)?], 8),
1809 )
1810 .await?;
1811
1812 let total_rows: usize = output_partitions
1813 .iter()
1814 .map(|x| x.iter().map(|x| x.num_rows()).sum::<usize>())
1815 .sum();
1816
1817 assert_eq!(8, output_partitions.len());
1818 assert_eq!(total_rows, 8 * 50 * 3);
1819
1820 Ok(())
1821 }
1822
1823 #[tokio::test]
1824 async fn test_repartition_with_coalescing() -> Result<()> {
1825 let schema = test_schema();
1826 let partition = create_vec_batches(50);
1828 let partitions = vec![partition.clone(), partition.clone()];
1829 let partitioning = Partitioning::RoundRobinBatch(1);
1830
1831 let session_config = SessionConfig::new().with_batch_size(200);
1832 let task_ctx = TaskContext::default().with_session_config(session_config);
1833 let task_ctx = Arc::new(task_ctx);
1834
1835 let exec = TestMemoryExec::try_new_exec(&partitions, Arc::clone(&schema), None)?;
1837 let exec = RepartitionExec::try_new(exec, partitioning)?;
1838
1839 for i in 0..exec.partitioning().partition_count() {
1840 let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
1841 while let Some(result) = stream.next().await {
1842 let batch = result?;
1843 assert_eq!(200, batch.num_rows());
1844 }
1845 }
1846 Ok(())
1847 }
1848
1849 fn test_schema() -> Arc<Schema> {
1850 Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
1851 }
1852
1853 async fn repartition(
1854 schema: &SchemaRef,
1855 input_partitions: Vec<Vec<RecordBatch>>,
1856 partitioning: Partitioning,
1857 ) -> Result<Vec<Vec<RecordBatch>>> {
1858 let task_ctx = Arc::new(TaskContext::default());
1859 let exec =
1861 TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(schema), None)?;
1862 let exec = RepartitionExec::try_new(exec, partitioning)?;
1863
1864 let mut output_partitions = vec![];
1866 for i in 0..exec.partitioning().partition_count() {
1867 let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
1869 let mut batches = vec![];
1870 while let Some(result) = stream.next().await {
1871 batches.push(result?);
1872 }
1873 output_partitions.push(batches);
1874 }
1875 Ok(output_partitions)
1876 }
1877
1878 #[tokio::test]
1879 async fn many_to_many_round_robin_within_tokio_task() -> Result<()> {
1880 let handle: SpawnedTask<Result<Vec<Vec<RecordBatch>>>> =
1881 SpawnedTask::spawn(async move {
1882 let schema = test_schema();
1884 let partition = create_vec_batches(50);
1885 let partitions =
1886 vec![partition.clone(), partition.clone(), partition.clone()];
1887
1888 repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await
1890 });
1891
1892 let output_partitions = handle.join().await.unwrap().unwrap();
1893
1894 let total_rows_per_partition = 8 * 50 * 3 / 5;
1895 assert_eq!(5, output_partitions.len());
1896 for partition in output_partitions {
1897 assert_eq!(1, partition.len());
1898 assert_eq!(total_rows_per_partition, partition[0].num_rows());
1899 }
1900
1901 Ok(())
1902 }
1903
    #[tokio::test]
    async fn unsupported_partitioning() {
        // UnknownPartitioning is accepted at plan-construction time but must
        // produce a runtime error once the output stream is consumed.
        let task_ctx = Arc::new(TaskContext::default());
        let batch = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();

        let schema = batch.schema();
        let input = MockExec::new(vec![Ok(batch)], schema);
        let partitioning = Partitioning::UnknownPartitioning(1);
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
        let output_stream = exec.execute(0, task_ctx).unwrap();

        // The error surfaces from collect(), not from execute().
        let result_string = crate::common::collect(output_stream)
            .await
            .unwrap_err()
            .to_string();
        assert!(
            result_string
                .contains("Unsupported repartitioning scheme UnknownPartitioning(1)"),
            "actual: {result_string}"
        );
    }
1934
    #[tokio::test]
    async fn error_for_input_exec() {
        // An input whose execute() fails must make RepartitionExec's
        // execute() fail as well.
        let task_ctx = Arc::new(TaskContext::default());
        let input = ErrorExec::new();
        let partitioning = Partitioning::RoundRobinBatch(1);
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();

        // Unlike mid-stream errors, this one is returned by execute() itself.
        let result_string = exec.execute(0, task_ctx).err().unwrap().to_string();

        assert!(
            result_string.contains("ErrorExec, unsurprisingly, errored in partition 0"),
            "actual: {result_string}"
        );
    }
1953
    #[tokio::test]
    async fn repartition_with_error_in_stream() {
        // An error produced mid-stream by the input must be propagated to
        // the consumer of the repartitioned stream.
        let task_ctx = Arc::new(TaskContext::default());
        let batch = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();

        let err = exec_err!("bad data error");

        let schema = batch.schema();
        let input = MockExec::new(vec![Ok(batch), err], schema);
        let partitioning = Partitioning::RoundRobinBatch(1);
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();

        // execute() itself succeeds; the error appears while collecting.
        let output_stream = exec.execute(0, task_ctx).unwrap();

        let result_string = crate::common::collect(output_stream)
            .await
            .unwrap_err()
            .to_string();
        assert!(
            result_string.contains("bad data error"),
            "actual: {result_string}"
        );
    }
1986
    #[tokio::test]
    async fn repartition_with_delayed_stream() {
        // MockExec delays delivery of its batches; the repartitioned output
        // must still contain exactly the same rows as the input.
        let task_ctx = Arc::new(TaskContext::default());
        let batch1 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();

        let batch2 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
        )])
        .unwrap();

        let schema = batch1.schema();
        let expected_batches = vec![batch1.clone(), batch2.clone()];
        let input = MockExec::new(vec![Ok(batch1), Ok(batch2)], schema);
        let partitioning = Partitioning::RoundRobinBatch(1);

        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();

        // Baseline: what the raw input batches look like.
        assert_snapshot!(batches_to_sort_string(&expected_batches), @r"
        +------------------+
        | my_awesome_field |
        +------------------+
        | bar              |
        | baz              |
        | foo              |
        | frob             |
        +------------------+
        ");

        let output_stream = exec.execute(0, task_ctx).unwrap();
        let batches = crate::common::collect(output_stream).await.unwrap();

        // Output must match the input row-for-row (sorted for stability).
        assert_snapshot!(batches_to_sort_string(&batches), @r"
        +------------------+
        | my_awesome_field |
        +------------------+
        | bar              |
        | baz              |
        | foo              |
        | frob             |
        +------------------+
        ");
    }
2036
    #[tokio::test]
    async fn robin_repartition_with_dropping_output_stream() {
        // Dropping one of the two output streams must not disturb the data
        // delivered to the remaining stream.
        let task_ctx = Arc::new(TaskContext::default());
        let partitioning = Partitioning::RoundRobinBatch(2);
        let input = Arc::new(make_barrier_exec());

        let exec = RepartitionExec::try_new(
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            partitioning,
        )
        .unwrap();

        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();

        // Drop output 0 before the barrier releases any batches.
        drop(output_stream0);

        // Unblock the input so batches start flowing.
        let mut background_task = JoinSet::new();
        background_task.spawn(async move {
            input.wait().await;
        });

        let batches = crate::common::collect(output_stream1).await.unwrap();

        // Stream 1 still receives its round-robin share of the input.
        assert_snapshot!(batches_to_sort_string(&batches), @r"
        +------------------+
        | my_awesome_field |
        +------------------+
        | baz              |
        | frob             |
        | gar              |
        | goo              |
        +------------------+
        ");
    }
2079
    #[tokio::test]
    async fn hash_repartition_with_dropping_output_stream() {
        // Hash repartitioning must route the same values to the same output
        // partition regardless of whether the other output stream is
        // consumed or dropped.
        let task_ctx = Arc::new(TaskContext::default());
        let partitioning = Partitioning::Hash(
            vec![Arc::new(crate::expressions::Column::new(
                "my_awesome_field",
                0,
            ))],
            2,
        );

        // First run: consume only output 1 (output 0 is never executed).
        let input = Arc::new(make_barrier_exec());
        let exec = RepartitionExec::try_new(
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            partitioning.clone(),
        )
        .unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
        let mut background_task = JoinSet::new();
        background_task.spawn(async move {
            input.wait().await;
        });
        let batches_without_drop = crate::common::collect(output_stream1).await.unwrap();

        // Sanity checks: no duplicates, every value comes from the input.
        let items_vec = str_batches_to_vec(&batches_without_drop);
        let items_set: HashSet<&str> = items_vec.iter().copied().collect();
        assert_eq!(items_vec.len(), items_set.len());
        let source_str_set: HashSet<&str> =
            ["foo", "bar", "frob", "baz", "goo", "gar", "grob", "gaz"]
                .iter()
                .copied()
                .collect();
        assert_eq!(items_set.difference(&source_str_set).count(), 0);

        // Second run: execute both outputs but drop output 0 immediately.
        let input = Arc::new(make_barrier_exec());
        let exec = RepartitionExec::try_new(
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            partitioning,
        )
        .unwrap();
        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
        drop(output_stream0);
        let mut background_task = JoinSet::new();
        background_task.spawn(async move {
            input.wait().await;
        });
        let batches_with_drop = crate::common::collect(output_stream1).await.unwrap();

        // Output 1 must receive exactly the same set of values either way.
        let items_vec_with_drop = str_batches_to_vec(&batches_with_drop);
        let items_set_with_drop: HashSet<&str> =
            items_vec_with_drop.iter().copied().collect();
        assert_eq!(
            items_set_with_drop.symmetric_difference(&items_set).count(),
            0
        );
    }
2145
2146 fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> {
2147 batches
2148 .iter()
2149 .flat_map(|batch| {
2150 assert_eq!(batch.columns().len(), 1);
2151 let string_array = as_string_array(batch.column(0))
2152 .expect("Unexpected type for repartitioned batch");
2153
2154 string_array
2155 .iter()
2156 .map(|v| v.expect("Unexpected null"))
2157 .collect::<Vec<_>>()
2158 })
2159 .collect::<Vec<_>>()
2160 }
2161
2162 fn make_barrier_exec() -> BarrierExec {
2164 let batch1 = RecordBatch::try_from_iter(vec![(
2165 "my_awesome_field",
2166 Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
2167 )])
2168 .unwrap();
2169
2170 let batch2 = RecordBatch::try_from_iter(vec![(
2171 "my_awesome_field",
2172 Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
2173 )])
2174 .unwrap();
2175
2176 let batch3 = RecordBatch::try_from_iter(vec![(
2177 "my_awesome_field",
2178 Arc::new(StringArray::from(vec!["goo", "gar"])) as ArrayRef,
2179 )])
2180 .unwrap();
2181
2182 let batch4 = RecordBatch::try_from_iter(vec![(
2183 "my_awesome_field",
2184 Arc::new(StringArray::from(vec!["grob", "gaz"])) as ArrayRef,
2185 )])
2186 .unwrap();
2187
2188 let schema = batch1.schema();
2191 BarrierExec::new(vec![vec![batch1, batch2], vec![batch3, batch4]], schema)
2192 }
2193
    #[tokio::test]
    async fn test_drop_cancel() -> Result<()> {
        // Dropping the collect future must cancel the background work and
        // release all references to the (blocked) input.
        let task_ctx = Arc::new(TaskContext::default());
        let schema =
            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));

        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2));
        let refs = blocking_exec.refs();
        let repartition_exec = Arc::new(RepartitionExec::try_new(
            blocking_exec,
            Partitioning::UnknownPartitioning(1),
        )?);

        let fut = collect(repartition_exec, task_ctx);
        let mut fut = fut.boxed();

        assert_is_pending(&mut fut);
        drop(fut);
        // All strong references to the input should eventually be dropped.
        assert_strong_count_converges_to_zero(refs).await;

        Ok(())
    }
2216
    #[tokio::test]
    async fn hash_repartition_avoid_empty_batch() -> Result<()> {
        // A single row hashed into 2 partitions lands in exactly one of
        // them; the other partition must emit no (empty) batches at all.
        let task_ctx = Arc::new(TaskContext::default());
        let batch = RecordBatch::try_from_iter(vec![(
            "a",
            Arc::new(StringArray::from(vec!["foo"])) as ArrayRef,
        )])
        .unwrap();
        let partitioning = Partitioning::Hash(
            vec![Arc::new(crate::expressions::Column::new("a", 0))],
            2,
        );
        let schema = batch.schema();
        let input = MockExec::new(vec![Ok(batch)], schema);
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
        let batch0 = crate::common::collect(output_stream0).await.unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
        let batch1 = crate::common::collect(output_stream1).await.unwrap();
        // Exactly one of the two outputs received the row.
        assert!(batch0.is_empty() || batch1.is_empty());
        Ok(())
    }
2239
    #[tokio::test]
    async fn repartition_with_spilling() -> Result<()> {
        let schema = test_schema();
        let partition = create_vec_batches(50);
        let input_partitions = vec![partition];
        let partitioning = Partitioning::RoundRobinBatch(4);

        // A 1-byte memory limit forces batches to spill to disk.
        let runtime = RuntimeEnvBuilder::default()
            .with_memory_limit(1, 1.0)
            .build_arc()?;

        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        let exec =
            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
        let exec = RepartitionExec::try_new(exec, partitioning)?;

        // All rows must still come through despite the spilling.
        let mut total_rows = 0;
        for i in 0..exec.partitioning().partition_count() {
            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
            while let Some(result) = stream.next().await {
                let batch = result?;
                total_rows += batch.num_rows();
            }
        }

        assert_eq!(total_rows, 50 * 8);

        // Spill metrics must all be non-zero.
        let metrics = exec.metrics().unwrap();
        assert!(
            metrics.spill_count().unwrap() > 0,
            "Expected spill_count > 0, but got {:?}",
            metrics.spill_count()
        );
        println!("Spilled {} times", metrics.spill_count().unwrap());
        assert!(
            metrics.spilled_bytes().unwrap() > 0,
            "Expected spilled_bytes > 0, but got {:?}",
            metrics.spilled_bytes()
        );
        println!(
            "Spilled {} bytes in {} spills",
            metrics.spilled_bytes().unwrap(),
            metrics.spill_count().unwrap()
        );
        assert!(
            metrics.spilled_rows().unwrap() > 0,
            "Expected spilled_rows > 0, but got {:?}",
            metrics.spilled_rows()
        );
        println!("Spilled {} rows", metrics.spilled_rows().unwrap());

        Ok(())
    }
2301
    #[tokio::test]
    async fn repartition_with_partial_spilling() -> Result<()> {
        let schema = test_schema();
        let partition = create_vec_batches(50);
        let input_partitions = vec![partition];
        let partitioning = Partitioning::RoundRobinBatch(4);

        // A 2 KiB limit lets some batches stay in memory while others
        // spill, exercising the mixed memory/spill read path.
        let runtime = RuntimeEnvBuilder::default()
            .with_memory_limit(2 * 1024, 1.0)
            .build_arc()?;

        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        let exec =
            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
        let exec = RepartitionExec::try_new(exec, partitioning)?;

        // All rows must survive the round trip.
        let mut total_rows = 0;
        for i in 0..exec.partitioning().partition_count() {
            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
            while let Some(result) = stream.next().await {
                let batch = result?;
                total_rows += batch.num_rows();
            }
        }

        assert_eq!(total_rows, 50 * 8);

        // Expect spilling, but only for part of the rows.
        let metrics = exec.metrics().unwrap();
        let spill_count = metrics.spill_count().unwrap();
        let spilled_rows = metrics.spilled_rows().unwrap();
        let spilled_bytes = metrics.spilled_bytes().unwrap();

        assert!(
            spill_count > 0,
            "Expected some spilling to occur, but got spill_count={spill_count}"
        );
        assert!(
            spilled_rows > 0 && spilled_rows < total_rows,
            "Expected partial spilling (0 < spilled_rows < {total_rows}), but got spilled_rows={spilled_rows}"
        );
        assert!(
            spilled_bytes > 0,
            "Expected some bytes to be spilled, but got spilled_bytes={spilled_bytes}"
        );

        println!(
            "Partial spilling: spilled {} out of {} rows ({:.1}%) in {} spills, {} bytes",
            spilled_rows,
            total_rows,
            (spilled_rows as f64 / total_rows as f64) * 100.0,
            spill_count,
            spilled_bytes
        );

        Ok(())
    }
2367
    #[tokio::test]
    async fn repartition_without_spilling() -> Result<()> {
        let schema = test_schema();
        let partition = create_vec_batches(50);
        let input_partitions = vec![partition];
        let partitioning = Partitioning::RoundRobinBatch(4);

        // 10 MiB is far more than the test data needs, so nothing should
        // spill.
        let runtime = RuntimeEnvBuilder::default()
            .with_memory_limit(10 * 1024 * 1024, 1.0)
            .build_arc()?;

        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        let exec =
            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
        let exec = RepartitionExec::try_new(exec, partitioning)?;

        let mut total_rows = 0;
        for i in 0..exec.partitioning().partition_count() {
            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
            while let Some(result) = stream.next().await {
                let batch = result?;
                total_rows += batch.num_rows();
            }
        }

        assert_eq!(total_rows, 50 * 8);

        // All spill metrics must be zero.
        let metrics = exec.metrics().unwrap();
        assert_eq!(
            metrics.spill_count(),
            Some(0),
            "Expected no spilling, but got spill_count={:?}",
            metrics.spill_count()
        );
        assert_eq!(
            metrics.spilled_bytes(),
            Some(0),
            "Expected no bytes spilled, but got spilled_bytes={:?}",
            metrics.spilled_bytes()
        );
        assert_eq!(
            metrics.spilled_rows(),
            Some(0),
            "Expected no rows spilled, but got spilled_rows={:?}",
            metrics.spilled_rows()
        );

        println!("No spilling occurred - all data processed in memory");

        Ok(())
    }
2427
    #[tokio::test]
    async fn oom() -> Result<()> {
        use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode};

        let schema = test_schema();
        let partition = create_vec_batches(50);
        let input_partitions = vec![partition];
        let partitioning = Partitioning::RoundRobinBatch(4);

        // 1 byte of memory AND a disabled disk manager: spilling is
        // impossible, so execution must fail with resources-exhausted.
        let runtime = RuntimeEnvBuilder::default()
            .with_memory_limit(1, 1.0)
            .with_disk_manager_builder(
                DiskManagerBuilder::default().with_mode(DiskManagerMode::Disabled),
            )
            .build_arc()?;

        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        let exec =
            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
        let exec = RepartitionExec::try_new(exec, partitioning)?;

        // Every output partition should surface the OOM as the root error.
        for i in 0..exec.partitioning().partition_count() {
            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
            let err = stream.next().await.unwrap().unwrap_err();
            let err = err.find_root();
            assert!(
                matches!(err, DataFusionError::ResourcesExhausted(_)),
                "Wrong error type: {err}",
            );
        }

        Ok(())
    }
2467
2468 fn create_vec_batches(n: usize) -> Vec<RecordBatch> {
2470 let batch = create_batch();
2471 (0..n).map(|_| batch.clone()).collect()
2472 }
2473
2474 fn create_batch() -> RecordBatch {
2476 let schema = test_schema();
2477 RecordBatch::try_new(
2478 schema,
2479 vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]))],
2480 )
2481 .unwrap()
2482 }
2483
2484 fn create_ordered_batches(num_batches: usize) -> Vec<RecordBatch> {
2486 let schema = test_schema();
2487 (0..num_batches)
2488 .map(|i| {
2489 let start = (i * 8) as u32;
2490 RecordBatch::try_new(
2491 Arc::clone(&schema),
2492 vec![Arc::new(UInt32Array::from(
2493 (start..start + 8).collect::<Vec<_>>(),
2494 ))],
2495 )
2496 .unwrap()
2497 })
2498 .collect()
2499 }
2500
    #[tokio::test]
    async fn test_repartition_ordering_with_spilling() -> Result<()> {
        // Round-robin repartitioning with forced spilling must not reorder
        // rows within an output partition: the input batches are globally
        // increasing, so each partition's values must stay strictly
        // ascending.
        let schema = test_schema();
        let partition = create_ordered_batches(20);
        let input_partitions = vec![partition];

        let partitioning = Partitioning::RoundRobinBatch(2);

        // A 1-byte limit forces batches through the spill path.
        let runtime = RuntimeEnvBuilder::default()
            .with_memory_limit(1, 1.0)
            .build_arc()?;

        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        let exec =
            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
        let exec = RepartitionExec::try_new(exec, partitioning)?;

        let mut all_batches = Vec::new();
        for i in 0..exec.partitioning().partition_count() {
            let mut partition_batches = Vec::new();
            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
            while let Some(result) = stream.next().await {
                let batch = result?;
                partition_batches.push(batch);
            }
            all_batches.push(partition_batches);
        }

        // The test is only meaningful if spilling actually happened.
        let metrics = exec.metrics().unwrap();
        assert!(
            metrics.spill_count().unwrap() > 0,
            "Expected spilling to occur, but spill_count = 0"
        );

        // Verify values are strictly increasing within each partition.
        for (partition_idx, batches) in all_batches.iter().enumerate() {
            let mut last_value = None;
            for batch in batches {
                let array = batch
                    .column(0)
                    .as_any()
                    .downcast_ref::<UInt32Array>()
                    .unwrap();

                for i in 0..array.len() {
                    let value = array.value(i);
                    if let Some(last) = last_value {
                        assert!(
                            value > last,
                            "Ordering violated in partition {partition_idx}: {value} is not greater than {last}"
                        );
                    }
                    last_value = Some(value);
                }
            }
        }

        Ok(())
    }
2574}
2575
2576#[cfg(test)]
2577mod test {
2578 use arrow::array::record_batch;
2579 use arrow::compute::SortOptions;
2580 use arrow::datatypes::{DataType, Field, Schema};
2581 use datafusion_common::assert_batches_eq;
2582
2583 use super::*;
2584 use crate::test::TestMemoryExec;
2585 use crate::union::UnionExec;
2586
2587 use datafusion_physical_expr::expressions::col;
2588 use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
2589
    /// Asserts that the indented `displayable` rendering of `$PLAN` matches
    /// the inline insta snapshot `$EXPECTED`.
    macro_rules! assert_plan {
        ($PLAN: expr, @ $EXPECTED: expr) => {
            let formatted = crate::displayable($PLAN).indent(true).to_string();

            insta::assert_snapshot!(
                formatted,
                @$EXPECTED
            );
        };
    }
2604
    #[tokio::test]
    async fn test_preserve_order() -> Result<()> {
        // With two sorted input partitions, preserve_order mode is active
        // and the plan display shows preserve_order=true plus sort_exprs.
        let schema = test_schema();
        let sort_exprs = sort_exprs(&schema);
        let source1 = sorted_memory_exec(&schema, sort_exprs.clone());
        let source2 = sorted_memory_exec(&schema, sort_exprs);
        // The union yields 2 sorted input partitions.
        let union = UnionExec::try_new(vec![source1, source2])?;
        let exec = RepartitionExec::try_new(union, Partitioning::RoundRobinBatch(10))?
            .with_preserve_order();

        assert_plan!(&exec, @r"
        RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true, sort_exprs=c0@0 ASC
          UnionExec
            DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
            DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
        ");
        Ok(())
    }
2625
    #[tokio::test]
    async fn test_preserve_order_one_partition() -> Result<()> {
        // With a single input partition there is nothing to merge, so the
        // plan reports maintains_sort_order instead of preserve_order.
        let schema = test_schema();
        let sort_exprs = sort_exprs(&schema);
        let source = sorted_memory_exec(&schema, sort_exprs);
        let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(10))?
            .with_preserve_order();

        assert_plan!(&exec, @r"
        RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true
          DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
        ");

        Ok(())
    }
2643
    #[tokio::test]
    async fn test_preserve_order_input_not_sorted() -> Result<()> {
        // When the input has no ordering, with_preserve_order() has no
        // effect and the plan display shows neither preserve_order nor
        // sort_exprs.
        let schema = test_schema();
        let source1 = memory_exec(&schema);
        let source2 = memory_exec(&schema);
        let union = UnionExec::try_new(vec![source1, source2])?;
        let exec = RepartitionExec::try_new(union, Partitioning::RoundRobinBatch(10))?
            .with_preserve_order();

        assert_plan!(&exec, @r"
        RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2
          UnionExec
            DataSourceExec: partitions=1, partition_sizes=[0]
            DataSourceExec: partitions=1, partition_sizes=[0]
        ");
        Ok(())
    }
2663
2664 #[tokio::test]
2665 async fn test_preserve_order_with_spilling() -> Result<()> {
2666 use datafusion_execution::TaskContext;
2667 use datafusion_execution::runtime_env::RuntimeEnvBuilder;
2668
2669 let batch1 = record_batch!(("c0", UInt32, [1, 3])).unwrap();
2673 let batch2 = record_batch!(("c0", UInt32, [2, 4])).unwrap();
2674 let batch3 = record_batch!(("c0", UInt32, [5, 7])).unwrap();
2675 let batch4 = record_batch!(("c0", UInt32, [6, 8])).unwrap();
2676 let batch5 = record_batch!(("c0", UInt32, [9, 11])).unwrap();
2677 let batch6 = record_batch!(("c0", UInt32, [10, 12])).unwrap();
2678 let schema = batch1.schema();
2679 let sort_exprs = LexOrdering::new([PhysicalSortExpr {
2680 expr: col("c0", &schema).unwrap(),
2681 options: SortOptions::default().asc(),
2682 }])
2683 .unwrap();
2684 let partition1 = vec![batch1.clone(), batch3.clone(), batch5.clone()];
2685 let partition2 = vec![batch2.clone(), batch4.clone(), batch6.clone()];
2686 let input_partitions = vec![partition1, partition2];
2687
2688 let runtime = RuntimeEnvBuilder::default()
2691 .with_memory_limit(64, 1.0)
2692 .build_arc()?;
2693
2694 let task_ctx = TaskContext::default().with_runtime(runtime);
2695 let task_ctx = Arc::new(task_ctx);
2696
2697 let exec = TestMemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?
2699 .try_with_sort_information(vec![sort_exprs.clone(), sort_exprs])?;
2700 let exec = Arc::new(exec);
2701 let exec = Arc::new(TestMemoryExec::update_cache(&exec));
2702 let exec = RepartitionExec::try_new(exec, Partitioning::RoundRobinBatch(3))?
2705 .with_preserve_order();
2706
2707 let mut batches = vec![];
2708
2709 for i in 0..exec.partitioning().partition_count() {
2711 let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
2712 while let Some(result) = stream.next().await {
2713 let batch = result?;
2714 batches.push(batch);
2715 }
2716 }
2717
2718 #[rustfmt::skip]
2719 let expected = [
2720 [
2721 "+----+",
2722 "| c0 |",
2723 "+----+",
2724 "| 1 |",
2725 "| 2 |",
2726 "| 3 |",
2727 "| 4 |",
2728 "+----+",
2729 ],
2730 [
2731 "+----+",
2732 "| c0 |",
2733 "+----+",
2734 "| 5 |",
2735 "| 6 |",
2736 "| 7 |",
2737 "| 8 |",
2738 "+----+",
2739 ],
2740 [
2741 "+----+",
2742 "| c0 |",
2743 "+----+",
2744 "| 9 |",
2745 "| 10 |",
2746 "| 11 |",
2747 "| 12 |",
2748 "+----+",
2749 ],
2750 ];
2751
2752 for (batch, expected) in batches.iter().zip(expected.iter()) {
2753 assert_batches_eq!(expected, std::slice::from_ref(batch));
2754 }
2755
2756 let all_batches = [batch1, batch2, batch3, batch4, batch5, batch6];
2760 let metrics = exec.metrics().unwrap();
2761 assert!(
2762 metrics.spill_count().unwrap() > input_partitions.len(),
2763 "Expected spill_count > {} for order-preserving repartition, but got {:?}",
2764 input_partitions.len(),
2765 metrics.spill_count()
2766 );
2767 assert!(
2768 metrics.spilled_bytes().unwrap()
2769 > all_batches
2770 .iter()
2771 .map(|b| b.get_array_memory_size())
2772 .sum::<usize>(),
2773 "Expected spilled_bytes > {} for order-preserving repartition, got {}",
2774 all_batches
2775 .iter()
2776 .map(|b| b.get_array_memory_size())
2777 .sum::<usize>(),
2778 metrics.spilled_bytes().unwrap()
2779 );
2780 assert!(
2781 metrics.spilled_rows().unwrap()
2782 >= all_batches.iter().map(|b| b.num_rows()).sum::<usize>(),
2783 "Expected spilled_rows > {} for order-preserving repartition, got {}",
2784 all_batches.iter().map(|b| b.num_rows()).sum::<usize>(),
2785 metrics.spilled_rows().unwrap()
2786 );
2787
2788 Ok(())
2789 }
2790
2791 #[tokio::test]
2792 async fn test_hash_partitioning_with_spilling() -> Result<()> {
2793 use datafusion_execution::TaskContext;
2794 use datafusion_execution::runtime_env::RuntimeEnvBuilder;
2795
2796 let batch1 = record_batch!(("c0", UInt32, [1, 3])).unwrap();
2798 let batch2 = record_batch!(("c0", UInt32, [2, 4])).unwrap();
2799 let batch3 = record_batch!(("c0", UInt32, [5, 7])).unwrap();
2800 let batch4 = record_batch!(("c0", UInt32, [6, 8])).unwrap();
2801 let schema = batch1.schema();
2802
2803 let partition1 = vec![batch1.clone(), batch3.clone()];
2804 let partition2 = vec![batch2.clone(), batch4.clone()];
2805 let input_partitions = vec![partition1, partition2];
2806
2807 let runtime = RuntimeEnvBuilder::default()
2809 .with_memory_limit(1, 1.0)
2810 .build_arc()?;
2811
2812 let task_ctx = TaskContext::default().with_runtime(runtime);
2813 let task_ctx = Arc::new(task_ctx);
2814
2815 let exec = TestMemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?;
2817 let exec = Arc::new(exec);
2818 let exec = Arc::new(TestMemoryExec::update_cache(&exec));
2819 let hash_expr = col("c0", &schema)?;
2821 let exec =
2822 RepartitionExec::try_new(exec, Partitioning::Hash(vec![hash_expr], 2))?;
2823
2824 let mut join_set = tokio::task::JoinSet::new();
2827 for i in 0..exec.partitioning().partition_count() {
2828 let stream = exec.execute(i, Arc::clone(&task_ctx))?;
2829 join_set.spawn(async move {
2830 let mut count = 0;
2831 futures::pin_mut!(stream);
2832 while let Some(result) = stream.next().await {
2833 let batch = result?;
2834 count += batch.num_rows();
2835 }
2836 Ok::<usize, DataFusionError>(count)
2837 });
2838 }
2839
2840 let mut total_rows = 0;
2842 while let Some(result) = join_set.join_next().await {
2843 total_rows += result.unwrap()?;
2844 }
2845
2846 let all_batches = [batch1, batch2, batch3, batch4];
2848 let expected_rows: usize = all_batches.iter().map(|b| b.num_rows()).sum();
2849 assert_eq!(total_rows, expected_rows);
2850
2851 let metrics = exec.metrics().unwrap();
2853 let spill_count = metrics.spill_count().unwrap_or(0);
2855 assert!(spill_count > 0);
2856 let spilled_bytes = metrics.spilled_bytes().unwrap_or(0);
2857 assert!(spilled_bytes > 0);
2858 let spilled_rows = metrics.spilled_rows().unwrap_or(0);
2859 assert!(spilled_rows > 0);
2860
2861 Ok(())
2862 }
2863
    #[tokio::test]
    async fn test_repartition() -> Result<()> {
        // A round-robin RepartitionExec over a sorted single-partition source
        // should accept a re-plan to a higher target partition count via
        // `repartitioned`, keeping the input's ordering information intact.
        let schema = test_schema();
        let sort_exprs = sort_exprs(&schema);
        let source = sorted_memory_exec(&schema, sort_exprs);
        // Request scaling from 10 to 20 output partitions.
        let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(10))?
            .repartitioned(20, &Default::default())?
            .unwrap();

        // Inline snapshot: partitioning is updated to RoundRobinBatch(20) and
        // the plan reports that the single input's sort order is maintained.
        assert_plan!(exec.as_ref(), @r"
        RepartitionExec: partitioning=RoundRobinBatch(20), input_partitions=1, maintains_sort_order=true
        DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
        ");
        Ok(())
    }
2881
2882 fn test_schema() -> Arc<Schema> {
2883 Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
2884 }
2885
2886 fn sort_exprs(schema: &Schema) -> LexOrdering {
2887 [PhysicalSortExpr {
2888 expr: col("c0", schema).unwrap(),
2889 options: SortOptions::default(),
2890 }]
2891 .into()
2892 }
2893
2894 fn memory_exec(schema: &SchemaRef) -> Arc<dyn ExecutionPlan> {
2895 TestMemoryExec::try_new_exec(&[vec![]], Arc::clone(schema), None).unwrap()
2896 }
2897
2898 fn sorted_memory_exec(
2899 schema: &SchemaRef,
2900 sort_exprs: LexOrdering,
2901 ) -> Arc<dyn ExecutionPlan> {
2902 let exec = TestMemoryExec::try_new(&[vec![]], Arc::clone(schema), None)
2903 .unwrap()
2904 .try_with_sort_information(vec![sort_exprs])
2905 .unwrap();
2906 let exec = Arc::new(exec);
2907 Arc::new(TestMemoryExec::update_cache(&exec))
2908 }
2909}