1use std::fmt::{Debug, Formatter};
23use std::pin::Pin;
24use std::sync::Arc;
25use std::task::{Context, Poll};
26use std::{any::Any, vec};
27
28use super::common::SharedMemoryReservation;
29use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
30use super::{
31 DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream,
32};
33use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType};
34use crate::hash_utils::create_hashes;
35use crate::metrics::{BaselineMetrics, SpillMetrics};
36use crate::projection::{all_columns, make_with_child, update_expr, ProjectionExec};
37use crate::sorts::streaming_merge::StreamingMergeBuilder;
38use crate::spill::spill_manager::SpillManager;
39use crate::spill::spill_pool::{self, SpillPoolWriter};
40use crate::stream::RecordBatchStreamAdapter;
41use crate::{DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Statistics};
42
43use arrow::array::{PrimitiveArray, RecordBatch, RecordBatchOptions};
44use arrow::compute::take_arrays;
45use arrow::datatypes::{SchemaRef, UInt32Type};
46use datafusion_common::config::ConfigOptions;
47use datafusion_common::stats::Precision;
48use datafusion_common::utils::transpose;
49use datafusion_common::{internal_err, ColumnStatistics, HashMap};
50use datafusion_common::{not_impl_err, DataFusionError, Result};
51use datafusion_common_runtime::SpawnedTask;
52use datafusion_execution::memory_pool::MemoryConsumer;
53use datafusion_execution::TaskContext;
54use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr};
55use datafusion_physical_expr_common::sort_expr::LexOrdering;
56
57use crate::filter_pushdown::{
58 ChildPushdownResult, FilterDescription, FilterPushdownPhase,
59 FilterPushdownPropagation,
60};
61use futures::stream::Stream;
62use futures::{FutureExt, StreamExt, TryStreamExt};
63use log::trace;
64use parking_lot::Mutex;
65
66mod distributor_channels;
67use distributor_channels::{
68 channels, partition_aware_channels, DistributionReceiver, DistributionSender,
69};
70
71#[derive(Debug)]
120enum RepartitionBatch {
121 Memory(RecordBatch),
123 Spilled,
128}
129
130type MaybeBatch = Option<Result<RepartitionBatch>>;
131type InputPartitionsToCurrentPartitionSender = Vec<DistributionSender<MaybeBatch>>;
132type InputPartitionsToCurrentPartitionReceiver = Vec<DistributionReceiver<MaybeBatch>>;
133
134struct OutputChannel {
136 sender: DistributionSender<MaybeBatch>,
137 reservation: SharedMemoryReservation,
138 spill_writer: SpillPoolWriter,
139}
140
141struct PartitionChannels {
163 tx: InputPartitionsToCurrentPartitionSender,
165 rx: InputPartitionsToCurrentPartitionReceiver,
167 reservation: SharedMemoryReservation,
169 spill_writers: Vec<SpillPoolWriter>,
172 spill_readers: Vec<SendableRecordBatchStream>,
175}
176
177struct ConsumingInputStreamsState {
178 channels: HashMap<usize, PartitionChannels>,
181
182 abort_helper: Arc<Vec<SpawnedTask<()>>>,
184}
185
186impl Debug for ConsumingInputStreamsState {
187 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
188 f.debug_struct("ConsumingInputStreamsState")
189 .field("num_channels", &self.channels.len())
190 .field("abort_helper", &self.abort_helper)
191 .finish()
192 }
193}
194
195#[derive(Default)]
197enum RepartitionExecState {
198 #[default]
201 NotInitialized,
202 InputStreamsInitialized(Vec<(SendableRecordBatchStream, RepartitionMetrics)>),
206 ConsumingInputStreams(ConsumingInputStreamsState),
209}
210
211impl Debug for RepartitionExecState {
212 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
213 match self {
214 RepartitionExecState::NotInitialized => write!(f, "NotInitialized"),
215 RepartitionExecState::InputStreamsInitialized(v) => {
216 write!(f, "InputStreamsInitialized({:?})", v.len())
217 }
218 RepartitionExecState::ConsumingInputStreams(v) => {
219 write!(f, "ConsumingInputStreams({v:?})")
220 }
221 }
222 }
223}
224
225impl RepartitionExecState {
226 fn ensure_input_streams_initialized(
227 &mut self,
228 input: Arc<dyn ExecutionPlan>,
229 metrics: ExecutionPlanMetricsSet,
230 output_partitions: usize,
231 ctx: Arc<TaskContext>,
232 ) -> Result<()> {
233 if !matches!(self, RepartitionExecState::NotInitialized) {
234 return Ok(());
235 }
236
237 let num_input_partitions = input.output_partitioning().partition_count();
238 let mut streams_and_metrics = Vec::with_capacity(num_input_partitions);
239
240 for i in 0..num_input_partitions {
241 let metrics = RepartitionMetrics::new(i, output_partitions, &metrics);
242
243 let timer = metrics.fetch_time.timer();
244 let stream = input.execute(i, Arc::clone(&ctx))?;
245 timer.done();
246
247 streams_and_metrics.push((stream, metrics));
248 }
249 *self = RepartitionExecState::InputStreamsInitialized(streams_and_metrics);
250 Ok(())
251 }
252
253 #[expect(clippy::too_many_arguments)]
254 fn consume_input_streams(
255 &mut self,
256 input: Arc<dyn ExecutionPlan>,
257 metrics: ExecutionPlanMetricsSet,
258 partitioning: Partitioning,
259 preserve_order: bool,
260 name: String,
261 context: Arc<TaskContext>,
262 spill_manager: SpillManager,
263 ) -> Result<&mut ConsumingInputStreamsState> {
264 let streams_and_metrics = match self {
265 RepartitionExecState::NotInitialized => {
266 self.ensure_input_streams_initialized(
267 Arc::clone(&input),
268 metrics.clone(),
269 partitioning.partition_count(),
270 Arc::clone(&context),
271 )?;
272 let RepartitionExecState::InputStreamsInitialized(value) = self else {
273 return internal_err!("Programming error: RepartitionExecState must be in the InputStreamsInitialized state after calling RepartitionExecState::ensure_input_streams_initialized");
276 };
277 value
278 }
279 RepartitionExecState::ConsumingInputStreams(value) => return Ok(value),
280 RepartitionExecState::InputStreamsInitialized(value) => value,
281 };
282
283 let num_input_partitions = streams_and_metrics.len();
284 let num_output_partitions = partitioning.partition_count();
285
286 let spill_manager = Arc::new(spill_manager);
287
288 let (txs, rxs) = if preserve_order {
289 let (txs_all, rxs_all) =
292 partition_aware_channels(num_input_partitions, num_output_partitions);
293 let txs = transpose(txs_all);
295 let rxs = transpose(rxs_all);
296 (txs, rxs)
297 } else {
298 let (txs, rxs) = channels(num_output_partitions);
300 let txs = txs
302 .into_iter()
303 .map(|item| vec![item; num_input_partitions])
304 .collect::<Vec<_>>();
305 let rxs = rxs.into_iter().map(|item| vec![item]).collect::<Vec<_>>();
306 (txs, rxs)
307 };
308
309 let mut channels = HashMap::with_capacity(txs.len());
310 for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() {
311 let reservation = Arc::new(Mutex::new(
312 MemoryConsumer::new(format!("{name}[{partition}]"))
313 .with_can_spill(true)
314 .register(context.memory_pool()),
315 ));
316
317 let max_file_size = context
322 .session_config()
323 .options()
324 .execution
325 .max_spill_file_size_bytes;
326 let num_spill_channels = if preserve_order {
327 num_input_partitions
328 } else {
329 1
330 };
331 let (spill_writers, spill_readers): (Vec<_>, Vec<_>) = (0
332 ..num_spill_channels)
333 .map(|_| spill_pool::channel(max_file_size, Arc::clone(&spill_manager)))
334 .unzip();
335
336 channels.insert(
337 partition,
338 PartitionChannels {
339 tx,
340 rx,
341 reservation,
342 spill_readers,
343 spill_writers,
344 },
345 );
346 }
347
348 let mut spawned_tasks = Vec::with_capacity(num_input_partitions);
350 for (i, (stream, metrics)) in
351 std::mem::take(streams_and_metrics).into_iter().enumerate()
352 {
353 let txs: HashMap<_, _> = channels
354 .iter()
355 .map(|(partition, channels)| {
356 let spill_writer_idx = if preserve_order { i } else { 0 };
359 (
360 *partition,
361 OutputChannel {
362 sender: channels.tx[i].clone(),
363 reservation: Arc::clone(&channels.reservation),
364 spill_writer: channels.spill_writers[spill_writer_idx]
365 .clone(),
366 },
367 )
368 })
369 .collect();
370
371 let senders: HashMap<_, _> = txs
373 .iter()
374 .map(|(partition, channel)| (*partition, channel.sender.clone()))
375 .collect();
376
377 let input_task = SpawnedTask::spawn(RepartitionExec::pull_from_input(
378 stream,
379 txs,
380 partitioning.clone(),
381 metrics,
382 ));
383
384 let wait_for_task =
387 SpawnedTask::spawn(RepartitionExec::wait_for_task(input_task, senders));
388 spawned_tasks.push(wait_for_task);
389 }
390 *self = Self::ConsumingInputStreams(ConsumingInputStreamsState {
391 channels,
392 abort_helper: Arc::new(spawned_tasks),
393 });
394 match self {
395 RepartitionExecState::ConsumingInputStreams(value) => Ok(value),
396 _ => unreachable!(),
397 }
398 }
399}
400
401pub struct BatchPartitioner {
403 state: BatchPartitionerState,
404 timer: metrics::Time,
405}
406
407enum BatchPartitionerState {
408 Hash {
409 random_state: ahash::RandomState,
410 exprs: Vec<Arc<dyn PhysicalExpr>>,
411 num_partitions: usize,
412 hash_buffer: Vec<u64>,
413 },
414 RoundRobin {
415 num_partitions: usize,
416 next_idx: usize,
417 },
418}
419
420impl BatchPartitioner {
421 pub fn try_new(partitioning: Partitioning, timer: metrics::Time) -> Result<Self> {
425 let state = match partitioning {
426 Partitioning::RoundRobinBatch(num_partitions) => {
427 BatchPartitionerState::RoundRobin {
428 num_partitions,
429 next_idx: 0,
430 }
431 }
432 Partitioning::Hash(exprs, num_partitions) => BatchPartitionerState::Hash {
433 exprs,
434 num_partitions,
435 random_state: ahash::RandomState::with_seeds(0, 0, 0, 0),
437 hash_buffer: vec![],
438 },
439 other => return not_impl_err!("Unsupported repartitioning scheme {other:?}"),
440 };
441
442 Ok(Self { state, timer })
443 }
444
445 pub fn partition<F>(&mut self, batch: RecordBatch, mut f: F) -> Result<()>
455 where
456 F: FnMut(usize, RecordBatch) -> Result<()>,
457 {
458 self.partition_iter(batch)?.try_for_each(|res| match res {
459 Ok((partition, batch)) => f(partition, batch),
460 Err(e) => Err(e),
461 })
462 }
463
464 fn partition_iter(
470 &mut self,
471 batch: RecordBatch,
472 ) -> Result<impl Iterator<Item = Result<(usize, RecordBatch)>> + Send + '_> {
473 let it: Box<dyn Iterator<Item = Result<(usize, RecordBatch)>> + Send> =
474 match &mut self.state {
475 BatchPartitionerState::RoundRobin {
476 num_partitions,
477 next_idx,
478 } => {
479 let idx = *next_idx;
480 *next_idx = (*next_idx + 1) % *num_partitions;
481 Box::new(std::iter::once(Ok((idx, batch))))
482 }
483 BatchPartitionerState::Hash {
484 random_state,
485 exprs,
486 num_partitions: partitions,
487 hash_buffer,
488 } => {
489 let timer = self.timer.timer();
491
492 let arrays = exprs
493 .iter()
494 .map(|expr| expr.evaluate(&batch)?.into_array(batch.num_rows()))
495 .collect::<Result<Vec<_>>>()?;
496
497 hash_buffer.clear();
498 hash_buffer.resize(batch.num_rows(), 0);
499
500 create_hashes(&arrays, random_state, hash_buffer)?;
501
502 let mut indices: Vec<_> = (0..*partitions)
503 .map(|_| Vec::with_capacity(batch.num_rows()))
504 .collect();
505
506 for (index, hash) in hash_buffer.iter().enumerate() {
507 indices[(*hash % *partitions as u64) as usize].push(index as u32);
508 }
509
510 timer.done();
512
513 let partitioner_timer = &self.timer;
515 let it = indices
516 .into_iter()
517 .enumerate()
518 .filter_map(|(partition, indices)| {
519 let indices: PrimitiveArray<UInt32Type> = indices.into();
520 (!indices.is_empty()).then_some((partition, indices))
521 })
522 .map(move |(partition, indices)| {
523 let _timer = partitioner_timer.timer();
525
526 let columns = take_arrays(batch.columns(), &indices, None)?;
528
529 let mut options = RecordBatchOptions::new();
530 options = options.with_row_count(Some(indices.len()));
531 let batch = RecordBatch::try_new_with_options(
532 batch.schema(),
533 columns,
534 &options,
535 )
536 .unwrap();
537
538 Ok((partition, batch))
539 });
540
541 Box::new(it)
542 }
543 };
544
545 Ok(it)
546 }
547
548 fn num_partitions(&self) -> usize {
550 match self.state {
551 BatchPartitionerState::RoundRobin { num_partitions, .. } => num_partitions,
552 BatchPartitionerState::Hash { num_partitions, .. } => num_partitions,
553 }
554 }
555}
556
557#[derive(Debug, Clone)]
656pub struct RepartitionExec {
657 input: Arc<dyn ExecutionPlan>,
659 state: Arc<Mutex<RepartitionExecState>>,
662 metrics: ExecutionPlanMetricsSet,
664 preserve_order: bool,
667 cache: PlanProperties,
669}
670
671#[derive(Debug, Clone)]
672struct RepartitionMetrics {
673 fetch_time: metrics::Time,
675 repartition_time: metrics::Time,
677 send_time: Vec<metrics::Time>,
681}
682
683impl RepartitionMetrics {
684 pub fn new(
685 input_partition: usize,
686 num_output_partitions: usize,
687 metrics: &ExecutionPlanMetricsSet,
688 ) -> Self {
689 let fetch_time =
691 MetricBuilder::new(metrics).subset_time("fetch_time", input_partition);
692
693 let repartition_time =
695 MetricBuilder::new(metrics).subset_time("repartition_time", input_partition);
696
697 let send_time = (0..num_output_partitions)
699 .map(|output_partition| {
700 let label =
701 metrics::Label::new("outputPartition", output_partition.to_string());
702 MetricBuilder::new(metrics)
703 .with_label(label)
704 .subset_time("send_time", input_partition)
705 })
706 .collect();
707
708 Self {
709 fetch_time,
710 repartition_time,
711 send_time,
712 }
713 }
714}
715
716impl RepartitionExec {
717 pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
719 &self.input
720 }
721
722 pub fn partitioning(&self) -> &Partitioning {
724 &self.cache.partitioning
725 }
726
727 pub fn preserve_order(&self) -> bool {
730 self.preserve_order
731 }
732
733 pub fn name(&self) -> &str {
735 "RepartitionExec"
736 }
737}
738
739impl DisplayAs for RepartitionExec {
740 fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
741 match t {
742 DisplayFormatType::Default | DisplayFormatType::Verbose => {
743 write!(
744 f,
745 "{}: partitioning={}, input_partitions={}",
746 self.name(),
747 self.partitioning(),
748 self.input.output_partitioning().partition_count()
749 )?;
750
751 if self.preserve_order {
752 write!(f, ", preserve_order=true")?;
753 }
754
755 if let Some(sort_exprs) = self.sort_exprs() {
756 write!(f, ", sort_exprs={}", sort_exprs.clone())?;
757 }
758 Ok(())
759 }
760 DisplayFormatType::TreeRender => {
761 writeln!(f, "partitioning_scheme={}", self.partitioning(),)?;
762
763 let input_partition_count =
764 self.input.output_partitioning().partition_count();
765 let output_partition_count = self.partitioning().partition_count();
766 let input_to_output_partition_str =
767 format!("{input_partition_count} -> {output_partition_count}");
768 writeln!(
769 f,
770 "partition_count(in->out)={input_to_output_partition_str}"
771 )?;
772
773 if self.preserve_order {
774 writeln!(f, "preserve_order={}", self.preserve_order)?;
775 }
776 Ok(())
777 }
778 }
779 }
780}
781
782impl ExecutionPlan for RepartitionExec {
783 fn name(&self) -> &'static str {
784 "RepartitionExec"
785 }
786
787 fn as_any(&self) -> &dyn Any {
789 self
790 }
791
792 fn properties(&self) -> &PlanProperties {
793 &self.cache
794 }
795
796 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
797 vec![&self.input]
798 }
799
800 fn with_new_children(
801 self: Arc<Self>,
802 mut children: Vec<Arc<dyn ExecutionPlan>>,
803 ) -> Result<Arc<dyn ExecutionPlan>> {
804 let mut repartition = RepartitionExec::try_new(
805 children.swap_remove(0),
806 self.partitioning().clone(),
807 )?;
808 if self.preserve_order {
809 repartition = repartition.with_preserve_order();
810 }
811 Ok(Arc::new(repartition))
812 }
813
814 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
815 vec![matches!(self.partitioning(), Partitioning::Hash(_, _))]
816 }
817
818 fn maintains_input_order(&self) -> Vec<bool> {
819 Self::maintains_input_order_helper(self.input(), self.preserve_order)
820 }
821
822 fn execute(
823 &self,
824 partition: usize,
825 context: Arc<TaskContext>,
826 ) -> Result<SendableRecordBatchStream> {
827 trace!(
828 "Start {}::execute for partition: {}",
829 self.name(),
830 partition
831 );
832
833 let spill_metrics = SpillMetrics::new(&self.metrics, partition);
834
835 let input = Arc::clone(&self.input);
836 let partitioning = self.partitioning().clone();
837 let metrics = self.metrics.clone();
838 let preserve_order = self.sort_exprs().is_some();
839 let name = self.name().to_owned();
840 let schema = self.schema();
841 let schema_captured = Arc::clone(&schema);
842
843 let spill_manager = SpillManager::new(
844 Arc::clone(&context.runtime_env()),
845 spill_metrics,
846 input.schema(),
847 );
848
849 let sort_exprs = self.sort_exprs().cloned();
851
852 let state = Arc::clone(&self.state);
853 if let Some(mut state) = state.try_lock() {
854 state.ensure_input_streams_initialized(
855 Arc::clone(&input),
856 metrics.clone(),
857 partitioning.partition_count(),
858 Arc::clone(&context),
859 )?;
860 }
861
862 let num_input_partitions = input.output_partitioning().partition_count();
863
864 let stream = futures::stream::once(async move {
865 let (rx, reservation, spill_readers, abort_helper) = {
867 let mut state = state.lock();
869 let state = state.consume_input_streams(
870 Arc::clone(&input),
871 metrics.clone(),
872 partitioning,
873 preserve_order,
874 name.clone(),
875 Arc::clone(&context),
876 spill_manager.clone(),
877 )?;
878
879 let PartitionChannels {
882 rx,
883 reservation,
884 spill_readers,
885 ..
886 } = state
887 .channels
888 .remove(&partition)
889 .expect("partition not used yet");
890
891 (
892 rx,
893 reservation,
894 spill_readers,
895 Arc::clone(&state.abort_helper),
896 )
897 };
898
899 trace!(
900 "Before returning stream in {name}::execute for partition: {partition}"
901 );
902
903 if preserve_order {
904 let input_streams = rx
907 .into_iter()
908 .zip(spill_readers)
909 .map(|(receiver, spill_stream)| {
910 Box::pin(PerPartitionStream::new(
912 Arc::clone(&schema_captured),
913 receiver,
914 Arc::clone(&abort_helper),
915 Arc::clone(&reservation),
916 spill_stream,
917 1, )) as SendableRecordBatchStream
919 })
920 .collect::<Vec<_>>();
921 let fetch = None;
926 let merge_reservation =
927 MemoryConsumer::new(format!("{name}[Merge {partition}]"))
928 .register(context.memory_pool());
929 StreamingMergeBuilder::new()
930 .with_streams(input_streams)
931 .with_schema(schema_captured)
932 .with_expressions(&sort_exprs.unwrap())
933 .with_metrics(BaselineMetrics::new(&metrics, partition))
934 .with_batch_size(context.session_config().batch_size())
935 .with_fetch(fetch)
936 .with_reservation(merge_reservation)
937 .with_spill_manager(spill_manager)
938 .build()
939 } else {
940 let spill_stream = spill_readers
942 .into_iter()
943 .next()
944 .expect("at least one spill reader should exist");
945
946 Ok(Box::pin(PerPartitionStream::new(
947 schema_captured,
948 rx.into_iter()
949 .next()
950 .expect("at least one receiver should exist"),
951 abort_helper,
952 reservation,
953 spill_stream,
954 num_input_partitions,
955 )) as SendableRecordBatchStream)
956 }
957 })
958 .try_flatten();
959 let stream = RecordBatchStreamAdapter::new(schema, stream);
960 Ok(Box::pin(stream))
961 }
962
963 fn metrics(&self) -> Option<MetricsSet> {
964 Some(self.metrics.clone_inner())
965 }
966
967 fn statistics(&self) -> Result<Statistics> {
968 self.input.partition_statistics(None)
969 }
970
971 fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
972 if let Some(partition) = partition {
973 let partition_count = self.partitioning().partition_count();
974 if partition_count == 0 {
975 return Ok(Statistics::new_unknown(&self.schema()));
976 }
977
978 if partition >= partition_count {
979 return internal_err!(
980 "RepartitionExec invalid partition {} (expected less than {})",
981 partition,
982 self.partitioning().partition_count()
983 );
984 }
985
986 let mut stats = self.input.partition_statistics(None)?;
987
988 stats.num_rows = stats
990 .num_rows
991 .get_value()
992 .map(|rows| Precision::Inexact(rows / partition_count))
993 .unwrap_or(Precision::Absent);
994 stats.total_byte_size = stats
995 .total_byte_size
996 .get_value()
997 .map(|bytes| Precision::Inexact(bytes / partition_count))
998 .unwrap_or(Precision::Absent);
999
1000 stats.column_statistics = stats
1002 .column_statistics
1003 .iter()
1004 .map(|_| ColumnStatistics::new_unknown())
1005 .collect();
1006
1007 Ok(stats)
1008 } else {
1009 self.input.partition_statistics(None)
1010 }
1011 }
1012
1013 fn cardinality_effect(&self) -> CardinalityEffect {
1014 CardinalityEffect::Equal
1015 }
1016
1017 fn try_swapping_with_projection(
1018 &self,
1019 projection: &ProjectionExec,
1020 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
1021 if projection.expr().len() >= projection.input().schema().fields().len() {
1023 return Ok(None);
1024 }
1025
1026 if projection.benefits_from_input_partitioning()[0]
1028 || !all_columns(projection.expr())
1029 {
1030 return Ok(None);
1031 }
1032
1033 let new_projection = make_with_child(projection, self.input())?;
1034
1035 let new_partitioning = match self.partitioning() {
1036 Partitioning::Hash(partitions, size) => {
1037 let mut new_partitions = vec![];
1038 for partition in partitions {
1039 let Some(new_partition) =
1040 update_expr(partition, projection.expr(), false)?
1041 else {
1042 return Ok(None);
1043 };
1044 new_partitions.push(new_partition);
1045 }
1046 Partitioning::Hash(new_partitions, *size)
1047 }
1048 others => others.clone(),
1049 };
1050
1051 Ok(Some(Arc::new(RepartitionExec::try_new(
1052 new_projection,
1053 new_partitioning,
1054 )?)))
1055 }
1056
1057 fn gather_filters_for_pushdown(
1058 &self,
1059 _phase: FilterPushdownPhase,
1060 parent_filters: Vec<Arc<dyn PhysicalExpr>>,
1061 _config: &ConfigOptions,
1062 ) -> Result<FilterDescription> {
1063 FilterDescription::from_children(parent_filters, &self.children())
1064 }
1065
1066 fn handle_child_pushdown_result(
1067 &self,
1068 _phase: FilterPushdownPhase,
1069 child_pushdown_result: ChildPushdownResult,
1070 _config: &ConfigOptions,
1071 ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
1072 Ok(FilterPushdownPropagation::if_all(child_pushdown_result))
1073 }
1074
1075 fn repartitioned(
1076 &self,
1077 target_partitions: usize,
1078 _config: &ConfigOptions,
1079 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
1080 use Partitioning::*;
1081 let mut new_properties = self.cache.clone();
1082 new_properties.partitioning = match new_properties.partitioning {
1083 RoundRobinBatch(_) => RoundRobinBatch(target_partitions),
1084 Hash(hash, _) => Hash(hash, target_partitions),
1085 UnknownPartitioning(_) => UnknownPartitioning(target_partitions),
1086 };
1087 Ok(Some(Arc::new(Self {
1088 input: Arc::clone(&self.input),
1089 state: Arc::clone(&self.state),
1090 metrics: self.metrics.clone(),
1091 preserve_order: self.preserve_order,
1092 cache: new_properties,
1093 })))
1094 }
1095}
1096
1097impl RepartitionExec {
1098 pub fn try_new(
1102 input: Arc<dyn ExecutionPlan>,
1103 partitioning: Partitioning,
1104 ) -> Result<Self> {
1105 let preserve_order = false;
1106 let cache =
1107 Self::compute_properties(&input, partitioning.clone(), preserve_order);
1108 Ok(RepartitionExec {
1109 input,
1110 state: Default::default(),
1111 metrics: ExecutionPlanMetricsSet::new(),
1112 preserve_order,
1113 cache,
1114 })
1115 }
1116
1117 fn maintains_input_order_helper(
1118 input: &Arc<dyn ExecutionPlan>,
1119 preserve_order: bool,
1120 ) -> Vec<bool> {
1121 vec![preserve_order || input.output_partitioning().partition_count() <= 1]
1123 }
1124
1125 fn eq_properties_helper(
1126 input: &Arc<dyn ExecutionPlan>,
1127 preserve_order: bool,
1128 ) -> EquivalenceProperties {
1129 let mut eq_properties = input.equivalence_properties().clone();
1131 if !Self::maintains_input_order_helper(input, preserve_order)[0] {
1133 eq_properties.clear_orderings();
1134 }
1135 if input.output_partitioning().partition_count() > 1 {
1138 eq_properties.clear_per_partition_constants();
1139 }
1140 eq_properties
1141 }
1142
1143 fn compute_properties(
1145 input: &Arc<dyn ExecutionPlan>,
1146 partitioning: Partitioning,
1147 preserve_order: bool,
1148 ) -> PlanProperties {
1149 PlanProperties::new(
1150 Self::eq_properties_helper(input, preserve_order),
1151 partitioning,
1152 input.pipeline_behavior(),
1153 input.boundedness(),
1154 )
1155 .with_scheduling_type(SchedulingType::Cooperative)
1156 .with_evaluation_type(EvaluationType::Eager)
1157 }
1158
1159 pub fn with_preserve_order(mut self) -> Self {
1167 self.preserve_order =
1168 self.input.output_ordering().is_some() &&
1170 self.input.output_partitioning().partition_count() > 1;
1173 let eq_properties = Self::eq_properties_helper(&self.input, self.preserve_order);
1174 self.cache = self.cache.with_eq_properties(eq_properties);
1175 self
1176 }
1177
1178 fn sort_exprs(&self) -> Option<&LexOrdering> {
1180 if self.preserve_order {
1181 self.input.output_ordering()
1182 } else {
1183 None
1184 }
1185 }
1186
1187 async fn pull_from_input(
1192 mut stream: SendableRecordBatchStream,
1193 mut output_channels: HashMap<usize, OutputChannel>,
1194 partitioning: Partitioning,
1195 metrics: RepartitionMetrics,
1196 ) -> Result<()> {
1197 let mut partitioner =
1198 BatchPartitioner::try_new(partitioning, metrics.repartition_time.clone())?;
1199
1200 let mut batches_until_yield = partitioner.num_partitions();
1202 while !output_channels.is_empty() {
1203 let timer = metrics.fetch_time.timer();
1205 let result = stream.next().await;
1206 timer.done();
1207
1208 let batch = match result {
1210 Some(result) => result?,
1211 None => break,
1212 };
1213
1214 if batch.num_rows() == 0 {
1216 continue;
1217 }
1218
1219 for res in partitioner.partition_iter(batch)? {
1220 let (partition, batch) = res?;
1221 let size = batch.get_array_memory_size();
1222
1223 let timer = metrics.send_time[partition].timer();
1224 if let Some(channel) = output_channels.get_mut(&partition) {
1226 let (batch_to_send, is_memory_batch) =
1227 match channel.reservation.lock().try_grow(size) {
1228 Ok(_) => {
1229 (RepartitionBatch::Memory(batch), true)
1231 }
1232 Err(_) => {
1233 channel.spill_writer.push_batch(&batch)?;
1236 (RepartitionBatch::Spilled, false)
1238 }
1239 };
1240
1241 if channel.sender.send(Some(Ok(batch_to_send))).await.is_err() {
1242 if is_memory_batch {
1245 channel.reservation.lock().shrink(size);
1246 }
1247 output_channels.remove(&partition);
1248 }
1249 }
1250 timer.done();
1251 }
1252
1253 if batches_until_yield == 0 {
1270 tokio::task::yield_now().await;
1271 batches_until_yield = partitioner.num_partitions();
1272 } else {
1273 batches_until_yield -= 1;
1274 }
1275 }
1276
1277 Ok(())
1280 }
1281
1282 async fn wait_for_task(
1288 input_task: SpawnedTask<Result<()>>,
1289 txs: HashMap<usize, DistributionSender<MaybeBatch>>,
1290 ) {
1291 match input_task.join().await {
1295 Err(e) => {
1297 let e = Arc::new(e);
1298
1299 for (_, tx) in txs {
1300 let err = Err(DataFusionError::Context(
1301 "Join Error".to_string(),
1302 Box::new(DataFusionError::External(Box::new(Arc::clone(&e)))),
1303 ));
1304 tx.send(Some(err)).await.ok();
1305 }
1306 }
1307 Ok(Err(e)) => {
1309 let e = Arc::new(e);
1311
1312 for (_, tx) in txs {
1313 let err = Err(DataFusionError::from(&e));
1315 tx.send(Some(err)).await.ok();
1316 }
1317 }
1318 Ok(Ok(())) => {
1320 for (_partition, tx) in txs {
1322 tx.send(None).await.ok();
1323 }
1324 }
1325 }
1326 }
1327}
1328
/// Which source a [`PerPartitionStream`] is currently reading from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum StreamState {
    /// Reading in-memory batches (or spill markers) from the channel.
    ReadingMemory,
    /// A spill marker was received; the next batch comes from the spill stream.
    ReadingSpilled,
}
1379
1380struct PerPartitionStream {
1383 schema: SchemaRef,
1385
1386 receiver: DistributionReceiver<MaybeBatch>,
1388
1389 _drop_helper: Arc<Vec<SpawnedTask<()>>>,
1391
1392 reservation: SharedMemoryReservation,
1394
1395 spill_stream: SendableRecordBatchStream,
1397
1398 state: StreamState,
1400
1401 remaining_partitions: usize,
1405}
1406
1407impl PerPartitionStream {
1408 fn new(
1409 schema: SchemaRef,
1410 receiver: DistributionReceiver<MaybeBatch>,
1411 drop_helper: Arc<Vec<SpawnedTask<()>>>,
1412 reservation: SharedMemoryReservation,
1413 spill_stream: SendableRecordBatchStream,
1414 num_input_partitions: usize,
1415 ) -> Self {
1416 Self {
1417 schema,
1418 receiver,
1419 _drop_helper: drop_helper,
1420 reservation,
1421 spill_stream,
1422 state: StreamState::ReadingMemory,
1423 remaining_partitions: num_input_partitions,
1424 }
1425 }
1426}
1427
impl Stream for PerPartitionStream {
    type Item = Result<RecordBatch>;

    /// Polls the channel, switching to the spill stream whenever a spill
    /// marker arrives.
    ///
    /// - An in-memory batch releases its memory reservation before it is
    ///   returned.
    /// - A spill marker switches to `ReadingSpilled` until the spill stream
    ///   yields a batch (or ends). Because the marker is handled in channel
    ///   order, batches keep the order in which they were sent.
    /// - Each `Some(None)` counts down `remaining_partitions`. The stream ends
    ///   when that reaches zero, or when the channel itself yields `None`.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        use futures::StreamExt;

        loop {
            match self.state {
                StreamState::ReadingMemory => {
                    // NOTE(review): a fresh `recv()` future is created on every
                    // poll; this presumably relies on it being cancel-safe —
                    // confirm in `distributor_channels`.
                    let value = match self.receiver.recv().poll_unpin(cx) {
                        Poll::Ready(v) => v,
                        Poll::Pending => {
                            return Poll::Pending;
                        }
                    };

                    match value {
                        Some(Some(v)) => match v {
                            Ok(RepartitionBatch::Memory(batch)) => {
                                // Releases the same amount the sender reserved
                                // (`get_array_memory_size` of this batch).
                                self.reservation
                                    .lock()
                                    .shrink(batch.get_array_memory_size());
                                return Poll::Ready(Some(Ok(batch)));
                            }
                            Ok(RepartitionBatch::Spilled) => {
                                // The announced batch must be read from the
                                // spill stream before anything else.
                                self.state = StreamState::ReadingSpilled;
                                continue;
                            }
                            Err(e) => {
                                return Poll::Ready(Some(Err(e)));
                            }
                        },
                        Some(None) => {
                            // One input partition finished.
                            self.remaining_partitions -= 1;
                            if self.remaining_partitions == 0 {
                                return Poll::Ready(None);
                            }
                            continue;
                        }
                        None => {
                            // The channel yielded no value; presumably all
                            // senders are gone — confirm channel semantics.
                            return Poll::Ready(None);
                        }
                    }
                }
                StreamState::ReadingSpilled => {
                    match self.spill_stream.poll_next_unpin(cx) {
                        Poll::Ready(Some(Ok(batch))) => {
                            self.state = StreamState::ReadingMemory;
                            return Poll::Ready(Some(Ok(batch)));
                        }
                        Poll::Ready(Some(Err(e))) => {
                            return Poll::Ready(Some(Err(e)));
                        }
                        Poll::Ready(None) => {
                            // Spill stream ended; go back to draining the channel.
                            self.state = StreamState::ReadingMemory;
                        }
                        Poll::Pending => {
                            // Wait for the spilled batch rather than reading
                            // ahead from memory, so order is kept.
                            return Poll::Pending;
                        }
                    }
                }
            }
        }
    }
}
1510
1511impl RecordBatchStream for PerPartitionStream {
1512 fn schema(&self) -> SchemaRef {
1514 Arc::clone(&self.schema)
1515 }
1516}
1517
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;
    use crate::test::TestMemoryExec;
    use crate::{
        test::{
            assert_is_pending,
            exec::{
                assert_strong_count_converges_to_zero, BarrierExec, BlockingExec,
                ErrorExec, MockExec,
            },
        },
        {collect, expressions::col},
    };

    use arrow::array::{ArrayRef, StringArray, UInt32Array};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::cast::as_string_array;
    use datafusion_common::exec_err;
    use datafusion_common::test_util::batches_to_sort_string;
    use datafusion_common_runtime::JoinSet;
    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
    use insta::assert_snapshot;
    use itertools::Itertools;

    #[tokio::test]
    async fn one_to_many_round_robin() -> Result<()> {
        // A single input partition holding 50 batches, fanned out to 4 outputs.
        let schema = test_schema();
        let partitions = vec![create_vec_batches(50)];

        let output_partitions =
            repartition(&schema, partitions, Partitioning::RoundRobinBatch(4)).await?;

        // 50 batches dealt round-robin: 13 + 13 + 12 + 12.
        assert_eq!(4, output_partitions.len());
        assert_eq!(13, output_partitions[0].len());
        assert_eq!(13, output_partitions[1].len());
        assert_eq!(12, output_partitions[2].len());
        assert_eq!(12, output_partitions[3].len());

        Ok(())
    }

    #[tokio::test]
    async fn many_to_one_round_robin() -> Result<()> {
        // Three input partitions of 50 batches each, merged into one output.
        let schema = test_schema();
        let partition = create_vec_batches(50);
        let partitions = vec![partition.clone(), partition.clone(), partition];

        let output_partitions =
            repartition(&schema, partitions, Partitioning::RoundRobinBatch(1)).await?;

        assert_eq!(1, output_partitions.len());
        assert_eq!(150, output_partitions[0].len());

        Ok(())
    }

    #[tokio::test]
    async fn many_to_many_round_robin() -> Result<()> {
        // Three input partitions of 50 batches each, spread over 5 outputs.
        let schema = test_schema();
        let partition = create_vec_batches(50);
        let partitions = vec![partition.clone(), partition.clone(), partition];

        let output_partitions =
            repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await?;

        // 150 batches split evenly: 30 per output partition.
        assert_eq!(5, output_partitions.len());
        assert_eq!(30, output_partitions[0].len());
        assert_eq!(30, output_partitions[1].len());
        assert_eq!(30, output_partitions[2].len());
        assert_eq!(30, output_partitions[3].len());
        assert_eq!(30, output_partitions[4].len());

        Ok(())
    }

    #[tokio::test]
    async fn many_to_many_hash_partition() -> Result<()> {
        let schema = test_schema();
        let partition = create_vec_batches(50);
        let partitions = vec![partition.clone(), partition.clone(), partition];

        let output_partitions = repartition(
            &schema,
            partitions,
            Partitioning::Hash(vec![col("c0", &schema)?], 8),
        )
        .await?;

        // Hashing does not guarantee an even split, so only the total is
        // checked: 3 inputs x 50 batches x 8 rows.
        assert_eq!(8, output_partitions.len());
        assert_eq!(count_rows(&output_partitions), 8 * 50 * 3);

        Ok(())
    }

    /// Schema with a single non-nullable `UInt32` column named `c0`.
    fn test_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
    }

    /// Repartitions the in-memory `input_partitions` with `partitioning` under a
    /// default [`TaskContext`] and returns the batches of every output partition.
    async fn repartition(
        schema: &SchemaRef,
        input_partitions: Vec<Vec<RecordBatch>>,
        partitioning: Partitioning,
    ) -> Result<Vec<Vec<RecordBatch>>> {
        let task_ctx = Arc::new(TaskContext::default());
        let exec =
            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(schema), None)?;
        let exec = RepartitionExec::try_new(exec, partitioning)?;
        collect_partitions(&exec, &task_ctx).await
    }

    /// Executes each output partition of `exec` in index order and drains it
    /// completely, returning the batches per partition or the first error.
    async fn collect_partitions(
        exec: &RepartitionExec,
        task_ctx: &Arc<TaskContext>,
    ) -> Result<Vec<Vec<RecordBatch>>> {
        let mut output_partitions = vec![];
        for i in 0..exec.partitioning().partition_count() {
            let stream = exec.execute(i, Arc::clone(task_ctx))?;
            output_partitions.push(crate::common::collect(stream).await?);
        }
        Ok(output_partitions)
    }

    /// Total number of rows across all batches of all `partitions`.
    fn count_rows(partitions: &[Vec<RecordBatch>]) -> usize {
        partitions
            .iter()
            .map(|batches| batches.iter().map(RecordBatch::num_rows).sum::<usize>())
            .sum()
    }

    /// Builds a task context whose runtime memory pool is limited to
    /// `memory_limit` bytes (memory fraction 1.0).
    fn memory_limited_task_ctx(memory_limit: usize) -> Result<Arc<TaskContext>> {
        let runtime = RuntimeEnvBuilder::default()
            .with_memory_limit(memory_limit, 1.0)
            .build_arc()?;
        Ok(Arc::new(TaskContext::default().with_runtime(runtime)))
    }

    #[tokio::test]
    async fn many_to_many_round_robin_within_tokio_task() -> Result<()> {
        // Same as `many_to_many_round_robin`, but driven from a spawned task.
        let handle: SpawnedTask<Result<Vec<Vec<RecordBatch>>>> =
            SpawnedTask::spawn(async move {
                let schema = test_schema();
                let partition = create_vec_batches(50);
                let partitions =
                    vec![partition.clone(), partition.clone(), partition.clone()];

                repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await
            });

        let output_partitions = handle.join().await.unwrap().unwrap();

        assert_eq!(5, output_partitions.len());
        assert_eq!(30, output_partitions[0].len());
        assert_eq!(30, output_partitions[1].len());
        assert_eq!(30, output_partitions[2].len());
        assert_eq!(30, output_partitions[3].len());
        assert_eq!(30, output_partitions[4].len());

        Ok(())
    }

    #[tokio::test]
    async fn unsupported_partitioning() {
        let task_ctx = Arc::new(TaskContext::default());
        let batch = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();

        let schema = batch.schema();
        let input = MockExec::new(vec![Ok(batch)], schema);
        // Building the plan and calling `execute` both succeed; the unsupported
        // scheme is only reported once the output stream is polled.
        let partitioning = Partitioning::UnknownPartitioning(1);
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
        let output_stream = exec.execute(0, task_ctx).unwrap();

        let result_string = crate::common::collect(output_stream)
            .await
            .unwrap_err()
            .to_string();
        assert!(
            result_string
                .contains("Unsupported repartitioning scheme UnknownPartitioning(1)"),
            "actual: {result_string}"
        );
    }

    #[tokio::test]
    async fn error_for_input_exec() {
        let task_ctx = Arc::new(TaskContext::default());
        let input = ErrorExec::new();
        let partitioning = Partitioning::RoundRobinBatch(1);
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();

        // The input's failure is returned directly from `execute`.
        let result_string = exec.execute(0, task_ctx).err().unwrap().to_string();

        assert!(
            result_string.contains("ErrorExec, unsurprisingly, errored in partition 0"),
            "actual: {result_string}"
        );
    }

    #[tokio::test]
    async fn repartition_with_error_in_stream() {
        let task_ctx = Arc::new(TaskContext::default());
        let batch = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();

        // The input yields one good batch followed by an error.
        let err = exec_err!("bad data error");

        let schema = batch.schema();
        let input = MockExec::new(vec![Ok(batch), err], schema);
        let partitioning = Partitioning::RoundRobinBatch(1);
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();

        // `execute` succeeds; the input error must surface through the stream.
        let output_stream = exec.execute(0, task_ctx).unwrap();

        let result_string = crate::common::collect(output_stream)
            .await
            .unwrap_err()
            .to_string();
        assert!(
            result_string.contains("bad data error"),
            "actual: {result_string}"
        );
    }

    #[tokio::test]
    async fn repartition_with_delayed_stream() {
        let task_ctx = Arc::new(TaskContext::default());
        let batch1 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();

        let batch2 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
        )])
        .unwrap();

        let schema = batch1.schema();
        let expected_batches = vec![batch1.clone(), batch2.clone()];
        let input = MockExec::new(vec![Ok(batch1), Ok(batch2)], schema);
        let partitioning = Partitioning::RoundRobinBatch(1);

        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();

        // Sanity-check the input before comparing it with the output.
        assert_snapshot!(batches_to_sort_string(&expected_batches), @r"
        +------------------+
        | my_awesome_field |
        +------------------+
        | bar              |
        | baz              |
        | foo              |
        | frob             |
        +------------------+
        ");

        let output_stream = exec.execute(0, task_ctx).unwrap();
        let batches = crate::common::collect(output_stream).await.unwrap();

        assert_snapshot!(batches_to_sort_string(&batches), @r"
        +------------------+
        | my_awesome_field |
        +------------------+
        | bar              |
        | baz              |
        | foo              |
        | frob             |
        +------------------+
        ");
    }

    #[tokio::test]
    async fn robin_repartition_with_dropping_output_stream() {
        let task_ctx = Arc::new(TaskContext::default());
        let partitioning = Partitioning::RoundRobinBatch(2);
        let input = Arc::new(make_barrier_exec());

        let exec = RepartitionExec::try_new(
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            partitioning,
        )
        .unwrap();

        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();

        // Dropping output 0 must not prevent output 1 from completing.
        drop(output_stream0);

        // Drive the barrier input's `wait()` in the background while output 1
        // is collected.
        let mut background_task = JoinSet::new();
        background_task.spawn(async move {
            input.wait().await;
        });

        let batches = crate::common::collect(output_stream1).await.unwrap();

        assert_snapshot!(batches_to_sort_string(&batches), @r#"
        +------------------+
        | my_awesome_field |
        +------------------+
        | baz              |
        | frob             |
        | gaz              |
        | grob             |
        +------------------+
        "#);
    }

    #[tokio::test]
    async fn hash_repartition_with_dropping_output_stream() {
        // Output partition 1 must yield the same rows whether or not output
        // partition 0 is dropped early.
        let task_ctx = Arc::new(TaskContext::default());
        let partitioning = Partitioning::Hash(
            vec![Arc::new(crate::expressions::Column::new(
                "my_awesome_field",
                0,
            ))],
            2,
        );

        // First run: keep every output stream alive.
        let input = Arc::new(make_barrier_exec());
        let exec = RepartitionExec::try_new(
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            partitioning.clone(),
        )
        .unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
        let mut background_task = JoinSet::new();
        background_task.spawn(async move {
            input.wait().await;
        });
        let batches_without_drop = crate::common::collect(output_stream1).await.unwrap();

        // Every value is unique and drawn from the barrier input's data.
        let items_vec = str_batches_to_vec(&batches_without_drop);
        let items_set: HashSet<&str> = items_vec.iter().copied().collect();
        assert_eq!(items_vec.len(), items_set.len());
        let source_str_set: HashSet<&str> =
            ["foo", "bar", "frob", "baz", "goo", "gar", "grob", "gaz"]
                .iter()
                .copied()
                .collect();
        assert_eq!(items_set.difference(&source_str_set).count(), 0);

        // Second run: drop output 0 before consuming output 1.
        let input = Arc::new(make_barrier_exec());
        let exec = RepartitionExec::try_new(
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            partitioning,
        )
        .unwrap();
        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
        drop(output_stream0);
        let mut background_task = JoinSet::new();
        background_task.spawn(async move {
            input.wait().await;
        });
        let batches_with_drop = crate::common::collect(output_stream1).await.unwrap();

        // Batch arrival order may differ between runs; compare order-insensitively.
        fn sort(batch: Vec<RecordBatch>) -> Vec<RecordBatch> {
            batch
                .into_iter()
                .sorted_by_key(|b| format!("{b:?}"))
                .collect()
        }

        assert_eq!(sort(batches_without_drop), sort(batches_with_drop));
    }

    /// Flattens single-column, non-null string batches into a vector of values.
    ///
    /// Panics if a batch has more than one column, is not a string column, or
    /// contains a null.
    fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> {
        batches
            .iter()
            .flat_map(|batch| {
                assert_eq!(batch.columns().len(), 1);
                let string_array = as_string_array(batch.column(0))
                    .expect("Unexpected type for repartitioned batch");

                string_array
                    .iter()
                    .map(|v| v.expect("Unexpected null"))
                    .collect::<Vec<_>>()
            })
            .collect::<Vec<_>>()
    }

    /// Builds a two-partition `BarrierExec` with two string batches per partition.
    fn make_barrier_exec() -> BarrierExec {
        let batch1 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();

        let batch2 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
        )])
        .unwrap();

        let batch3 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["goo", "gar"])) as ArrayRef,
        )])
        .unwrap();

        let batch4 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["grob", "gaz"])) as ArrayRef,
        )])
        .unwrap();

        let schema = batch1.schema();
        BarrierExec::new(vec![vec![batch1, batch2], vec![batch3, batch4]], schema)
    }

    #[tokio::test]
    async fn test_drop_cancel() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema =
            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));

        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2));
        let refs = blocking_exec.refs();
        let repartition_exec = Arc::new(RepartitionExec::try_new(
            blocking_exec,
            Partitioning::UnknownPartitioning(1),
        )?);

        let fut = collect(repartition_exec, task_ctx);
        let mut fut = fut.boxed();

        // Dropping a pending query must release every reference to the input.
        assert_is_pending(&mut fut);
        drop(fut);
        assert_strong_count_converges_to_zero(refs).await;

        Ok(())
    }

    #[tokio::test]
    async fn hash_repartition_avoid_empty_batch() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let batch = RecordBatch::try_from_iter(vec![(
            "a",
            Arc::new(StringArray::from(vec!["foo"])) as ArrayRef,
        )])
        .unwrap();
        let partitioning = Partitioning::Hash(
            vec![Arc::new(crate::expressions::Column::new("a", 0))],
            2,
        );
        let schema = batch.schema();
        let input = MockExec::new(vec![Ok(batch)], schema);
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
        let batch0 = crate::common::collect(output_stream0).await.unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
        let batch1 = crate::common::collect(output_stream1).await.unwrap();

        // The single input row hashes to exactly one output. The other output
        // must receive no batches at all rather than an empty batch.
        assert!(batch0.is_empty() || batch1.is_empty());
        assert!(
            batch0.iter().chain(batch1.iter()).all(|b| b.num_rows() > 0),
            "Repartition emitted an empty batch"
        );
        assert_eq!(count_rows(&[batch0, batch1]), 1);
        Ok(())
    }

    #[tokio::test]
    async fn repartition_with_spilling() -> Result<()> {
        let schema = test_schema();
        let input_partitions = vec![create_vec_batches(50)];
        let partitioning = Partitioning::RoundRobinBatch(4);

        // A 1-byte pool is far too small to buffer any batch, so the
        // repartition has to spill.
        let task_ctx = memory_limited_task_ctx(1)?;

        let exec =
            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
        let exec = RepartitionExec::try_new(exec, partitioning)?;

        // Spilling must not lose rows: 50 batches x 8 rows.
        let output_partitions = collect_partitions(&exec, &task_ctx).await?;
        assert_eq!(count_rows(&output_partitions), 50 * 8);

        // The metrics confirm that spilling actually happened.
        let metrics = exec.metrics().unwrap();
        assert!(
            metrics.spill_count().unwrap() > 0,
            "Expected spill_count > 0, but got {:?}",
            metrics.spill_count()
        );
        assert!(
            metrics.spilled_bytes().unwrap() > 0,
            "Expected spilled_bytes > 0, but got {:?}",
            metrics.spilled_bytes()
        );
        assert!(
            metrics.spilled_rows().unwrap() > 0,
            "Expected spilled_rows > 0, but got {:?}",
            metrics.spilled_rows()
        );

        Ok(())
    }

    #[tokio::test]
    async fn repartition_with_partial_spilling() -> Result<()> {
        let schema = test_schema();
        let input_partitions = vec![create_vec_batches(50)];
        let partitioning = Partitioning::RoundRobinBatch(4);

        // A 2 KiB pool is meant to hold some, but not all, buffered batches.
        let task_ctx = memory_limited_task_ctx(2 * 1024)?;

        let exec =
            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
        let exec = RepartitionExec::try_new(exec, partitioning)?;

        let output_partitions = collect_partitions(&exec, &task_ctx).await?;
        let total_rows = count_rows(&output_partitions);
        assert_eq!(total_rows, 50 * 8);

        let metrics = exec.metrics().unwrap();
        let spill_count = metrics.spill_count().unwrap();
        let spilled_rows = metrics.spilled_rows().unwrap();
        let spilled_bytes = metrics.spilled_bytes().unwrap();

        assert!(
            spill_count > 0,
            "Expected some spilling to occur, but got spill_count={spill_count}"
        );
        assert!(
            spilled_rows > 0 && spilled_rows < total_rows,
            "Expected partial spilling (0 < spilled_rows < {total_rows}), but got spilled_rows={spilled_rows}"
        );
        assert!(
            spilled_bytes > 0,
            "Expected some bytes to be spilled, but got spilled_bytes={spilled_bytes}"
        );

        Ok(())
    }

    #[tokio::test]
    async fn repartition_without_spilling() -> Result<()> {
        let schema = test_schema();
        let input_partitions = vec![create_vec_batches(50)];
        let partitioning = Partitioning::RoundRobinBatch(4);

        // 10 MiB comfortably holds this small input, so nothing should spill.
        let task_ctx = memory_limited_task_ctx(10 * 1024 * 1024)?;

        let exec =
            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
        let exec = RepartitionExec::try_new(exec, partitioning)?;

        let output_partitions = collect_partitions(&exec, &task_ctx).await?;
        assert_eq!(count_rows(&output_partitions), 50 * 8);

        let metrics = exec.metrics().unwrap();
        assert_eq!(
            metrics.spill_count(),
            Some(0),
            "Expected no spilling, but got spill_count={:?}",
            metrics.spill_count()
        );
        assert_eq!(
            metrics.spilled_bytes(),
            Some(0),
            "Expected no bytes spilled, but got spilled_bytes={:?}",
            metrics.spilled_bytes()
        );
        assert_eq!(
            metrics.spilled_rows(),
            Some(0),
            "Expected no rows spilled, but got spilled_rows={:?}",
            metrics.spilled_rows()
        );

        Ok(())
    }

    #[tokio::test]
    async fn oom() -> Result<()> {
        use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode};

        let schema = test_schema();
        let partition = create_vec_batches(50);
        let input_partitions = vec![partition];
        let partitioning = Partitioning::RoundRobinBatch(4);

        // Same 1-byte pool as the spilling tests, but with the disk manager
        // disabled, so every output partition must fail with an OOM error.
        let runtime = RuntimeEnvBuilder::default()
            .with_memory_limit(1, 1.0)
            .with_disk_manager_builder(
                DiskManagerBuilder::default().with_mode(DiskManagerMode::Disabled),
            )
            .build_arc()?;

        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        let exec =
            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
        let exec = RepartitionExec::try_new(exec, partitioning)?;

        for i in 0..exec.partitioning().partition_count() {
            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
            let err = stream.next().await.unwrap().unwrap_err();
            let err = err.find_root();
            assert!(
                matches!(err, DataFusionError::ResourcesExhausted(_)),
                "Wrong error type: {err}",
            );
        }

        Ok(())
    }

    /// Returns `n` copies of [`create_batch`].
    fn create_vec_batches(n: usize) -> Vec<RecordBatch> {
        let batch = create_batch();
        (0..n).map(|_| batch.clone()).collect()
    }

    /// A single 8-row batch with values `1..=8` in column `c0`.
    fn create_batch() -> RecordBatch {
        let schema = test_schema();
        RecordBatch::try_new(
            schema,
            vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]))],
        )
        .unwrap()
    }

    /// `num_batches` 8-row batches whose `c0` values increase across batches
    /// (batch `i` holds `8 * i .. 8 * i + 8`).
    fn create_ordered_batches(num_batches: usize) -> Vec<RecordBatch> {
        let schema = test_schema();
        (0..num_batches)
            .map(|i| {
                let start = (i * 8) as u32;
                RecordBatch::try_new(
                    Arc::clone(&schema),
                    vec![Arc::new(UInt32Array::from(
                        (start..start + 8).collect::<Vec<_>>(),
                    ))],
                )
                .unwrap()
            })
            .collect()
    }

    #[tokio::test]
    async fn test_repartition_ordering_with_spilling() -> Result<()> {
        // One input partition of 20 batches with strictly increasing values.
        let schema = test_schema();
        let input_partitions = vec![create_ordered_batches(20)];
        let partitioning = Partitioning::RoundRobinBatch(2);

        // Tiny pool so buffered batches have to spill.
        let task_ctx = memory_limited_task_ctx(1)?;

        let exec =
            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
        let exec = RepartitionExec::try_new(exec, partitioning)?;

        let all_batches = collect_partitions(&exec, &task_ctx).await?;

        let metrics = exec.metrics().unwrap();
        assert!(
            metrics.spill_count().unwrap() > 0,
            "Expected spilling to occur, but spill_count = 0"
        );

        // Without a row count, empty outputs would pass the ordering check
        // below vacuously.
        assert_eq!(count_rows(&all_batches), 20 * 8);

        // Within each output partition, values must stay strictly increasing,
        // i.e. spilling must not reorder the batches a partition receives.
        for (partition_idx, batches) in all_batches.iter().enumerate() {
            let mut last_value = None;
            for batch in batches {
                let array = batch
                    .column(0)
                    .as_any()
                    .downcast_ref::<UInt32Array>()
                    .unwrap();

                for &value in array.values().iter() {
                    if let Some(last) = last_value {
                        assert!(
                            value > last,
                            "Ordering violated in partition {partition_idx}: {value} is not greater than {last}"
                        );
                    }
                    last_value = Some(value);
                }
            }
        }

        Ok(())
    }
}
2355
#[cfg(test)]
mod test {
    use arrow::array::record_batch;
    use arrow::compute::SortOptions;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::assert_batches_eq;

    use super::*;
    use crate::test::TestMemoryExec;
    use crate::union::UnionExec;

    use datafusion_physical_expr::expressions::col;
    use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};

    /// Asserts that the indented display of the plan `$PLAN` matches the inline
    /// `insta` snapshot `$EXPECTED`.
    macro_rules! assert_plan {
        ($PLAN: expr, @ $EXPECTED: expr) => {
            let formatted = crate::displayable($PLAN).indent(true).to_string();

            insta::assert_snapshot!(
                formatted,
                @$EXPECTED
            );
        };
    }

    #[tokio::test]
    async fn test_preserve_order() -> Result<()> {
        let schema = test_schema();
        let sort_exprs = sort_exprs(&schema);
        let source1 = sorted_memory_exec(&schema, sort_exprs.clone());
        let source2 = sorted_memory_exec(&schema, sort_exprs);
        // Two sorted inputs: the repartition should preserve their order.
        let union = UnionExec::try_new(vec![source1, source2])?;
        let exec = RepartitionExec::try_new(union, Partitioning::RoundRobinBatch(10))?
            .with_preserve_order();

        assert_plan!(&exec, @r"
        RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true, sort_exprs=c0@0 ASC
          UnionExec
            DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
            DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
        ");
        Ok(())
    }

    #[tokio::test]
    async fn test_preserve_order_one_partition() -> Result<()> {
        let schema = test_schema();
        let sort_exprs = sort_exprs(&schema);
        let source = sorted_memory_exec(&schema, sort_exprs);
        // With a single input partition, no `preserve_order` marker is displayed.
        let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(10))?
            .with_preserve_order();

        assert_plan!(&exec, @r"
        RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1
          DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
        ");

        Ok(())
    }

    #[tokio::test]
    async fn test_preserve_order_input_not_sorted() -> Result<()> {
        let schema = test_schema();
        let source1 = memory_exec(&schema);
        let source2 = memory_exec(&schema);
        // Unsorted inputs: there is no ordering to preserve, so none is displayed.
        let union = UnionExec::try_new(vec![source1, source2])?;
        let exec = RepartitionExec::try_new(union, Partitioning::RoundRobinBatch(10))?
            .with_preserve_order();

        assert_plan!(&exec, @r"
        RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2
          UnionExec
            DataSourceExec: partitions=1, partition_sizes=[0]
            DataSourceExec: partitions=1, partition_sizes=[0]
        ");
        Ok(())
    }

    #[tokio::test]
    async fn test_preserve_order_with_spilling() -> Result<()> {
        use datafusion_execution::runtime_env::RuntimeEnvBuilder;
        use datafusion_execution::TaskContext;

        // Two sorted input partitions whose values interleave. Each output
        // partition receives one batch from each input and must merge them.
        let batch1 = record_batch!(("c0", UInt32, [1, 3])).unwrap();
        let batch2 = record_batch!(("c0", UInt32, [2, 4])).unwrap();
        let batch3 = record_batch!(("c0", UInt32, [5, 7])).unwrap();
        let batch4 = record_batch!(("c0", UInt32, [6, 8])).unwrap();
        let batch5 = record_batch!(("c0", UInt32, [9, 11])).unwrap();
        let batch6 = record_batch!(("c0", UInt32, [10, 12])).unwrap();
        let schema = batch1.schema();
        let sort_exprs = LexOrdering::new([PhysicalSortExpr {
            expr: col("c0", &schema).unwrap(),
            options: SortOptions::default().asc(),
        }])
        .unwrap();
        let partition1 = vec![batch1.clone(), batch3.clone(), batch5.clone()];
        let partition2 = vec![batch2.clone(), batch4.clone(), batch6.clone()];
        let input_partitions = vec![partition1, partition2];

        // Totals over the whole input, used by the row and metric checks below.
        let all_batches = [batch1, batch2, batch3, batch4, batch5, batch6];
        let total_input_rows: usize = all_batches.iter().map(|b| b.num_rows()).sum();
        let total_input_bytes: usize = all_batches
            .iter()
            .map(|b| b.get_array_memory_size())
            .sum();

        // A 64-byte pool so buffered batches have to spill.
        let runtime = RuntimeEnvBuilder::default()
            .with_memory_limit(64, 1.0)
            .build_arc()?;

        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        let exec = TestMemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?
            .try_with_sort_information(vec![sort_exprs.clone(), sort_exprs])?;
        let exec = Arc::new(TestMemoryExec::update_cache(Arc::new(exec)));
        // Round-robin into 3 outputs while preserving the sort order.
        let exec = RepartitionExec::try_new(exec, Partitioning::RoundRobinBatch(3))?
            .with_preserve_order();

        let mut batches = vec![];

        // Drain the output partitions in index order.
        for i in 0..exec.partitioning().partition_count() {
            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
            while let Some(result) = stream.next().await {
                let batch = result?;
                batches.push(batch);
            }
        }

        // `zip` below stops at the shorter side, so first make sure no rows
        // went missing; otherwise a short output would pass silently.
        let total_output_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(
            total_output_rows, total_input_rows,
            "Order-preserving repartition with spilling lost or duplicated rows"
        );

        #[rustfmt::skip]
        let expected = [
            [
                "+----+",
                "| c0 |",
                "+----+",
                "| 1  |",
                "| 2  |",
                "| 3  |",
                "| 4  |",
                "+----+",
            ],
            [
                "+----+",
                "| c0 |",
                "+----+",
                "| 5  |",
                "| 6  |",
                "| 7  |",
                "| 8  |",
                "+----+",
            ],
            [
                "+----+",
                "| c0 |",
                "+----+",
                "| 9  |",
                "| 10 |",
                "| 11 |",
                "| 12 |",
                "+----+",
            ],
        ];

        for (batch, expected) in batches.iter().zip(expected.iter()) {
            assert_batches_eq!(expected, std::slice::from_ref(batch));
        }

        let metrics = exec.metrics().unwrap();
        let num_input_partitions = input_partitions.len();
        let spill_count = metrics.spill_count().unwrap();
        let spilled_bytes = metrics.spilled_bytes().unwrap();
        let spilled_rows = metrics.spilled_rows().unwrap();
        assert!(
            spill_count > num_input_partitions,
            "Expected spill_count > {num_input_partitions} for order-preserving repartition, but got {spill_count}"
        );
        assert!(
            spilled_bytes > total_input_bytes,
            "Expected spilled_bytes > {total_input_bytes} for order-preserving repartition, got {spilled_bytes}"
        );
        assert!(
            spilled_rows >= total_input_rows,
            "Expected spilled_rows >= {total_input_rows} for order-preserving repartition, got {spilled_rows}"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_hash_partitioning_with_spilling() -> Result<()> {
        use datafusion_execution::runtime_env::RuntimeEnvBuilder;
        use datafusion_execution::TaskContext;

        // Two input partitions of two small batches each.
        let batch1 = record_batch!(("c0", UInt32, [1, 3])).unwrap();
        let batch2 = record_batch!(("c0", UInt32, [2, 4])).unwrap();
        let batch3 = record_batch!(("c0", UInt32, [5, 7])).unwrap();
        let batch4 = record_batch!(("c0", UInt32, [6, 8])).unwrap();
        let schema = batch1.schema();

        let partition1 = vec![batch1.clone(), batch3.clone()];
        let partition2 = vec![batch2.clone(), batch4.clone()];
        let input_partitions = vec![partition1, partition2];

        // A 1-byte pool so buffered batches have to spill.
        let runtime = RuntimeEnvBuilder::default()
            .with_memory_limit(1, 1.0)
            .build_arc()?;

        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        let exec = TestMemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?;
        let exec = Arc::new(TestMemoryExec::update_cache(Arc::new(exec)));
        // Hash-partition on `c0` into 2 outputs.
        let hash_expr = col("c0", &schema)?;
        let exec =
            RepartitionExec::try_new(exec, Partitioning::Hash(vec![hash_expr], 2))?;

        // Drain every output partition concurrently, each on its own task,
        // counting the rows it yields.
        let mut join_set = tokio::task::JoinSet::new();
        for i in 0..exec.partitioning().partition_count() {
            let stream = exec.execute(i, Arc::clone(&task_ctx))?;
            join_set.spawn(async move {
                let mut count = 0;
                futures::pin_mut!(stream);
                while let Some(result) = stream.next().await {
                    let batch = result?;
                    count += batch.num_rows();
                }
                Ok::<usize, DataFusionError>(count)
            });
        }

        let mut total_rows = 0;
        while let Some(result) = join_set.join_next().await {
            total_rows += result.unwrap()?;
        }

        // Every input row must come out exactly once.
        let all_batches = [batch1, batch2, batch3, batch4];
        let expected_rows: usize = all_batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, expected_rows);

        let metrics = exec.metrics().unwrap();
        let spill_count = metrics.spill_count().unwrap_or(0);
        assert!(
            spill_count > 0,
            "Expected spill_count > 0, but got {spill_count}"
        );
        let spilled_bytes = metrics.spilled_bytes().unwrap_or(0);
        assert!(
            spilled_bytes > 0,
            "Expected spilled_bytes > 0, but got {spilled_bytes}"
        );
        let spilled_rows = metrics.spilled_rows().unwrap_or(0);
        assert!(
            spilled_rows > 0,
            "Expected spilled_rows > 0, but got {spilled_rows}"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_repartition() -> Result<()> {
        let schema = test_schema();
        let sort_exprs = sort_exprs(&schema);
        let source = sorted_memory_exec(&schema, sort_exprs);
        // Re-targeting an existing round-robin repartition updates its fan-out.
        let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(10))?
            .repartitioned(20, &Default::default())?
            .unwrap();

        assert_plan!(exec.as_ref(), @r"
        RepartitionExec: partitioning=RoundRobinBatch(20), input_partitions=1
          DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
        ");
        Ok(())
    }

    /// Schema with a single non-nullable `UInt32` column named `c0`.
    fn test_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
    }

    /// Ordering on column `c0` using the default sort options.
    fn sort_exprs(schema: &Schema) -> LexOrdering {
        [PhysicalSortExpr {
            expr: col("c0", schema).unwrap(),
            options: SortOptions::default(),
        }]
        .into()
    }

    /// An empty, single-partition, unsorted in-memory source.
    fn memory_exec(schema: &SchemaRef) -> Arc<dyn ExecutionPlan> {
        TestMemoryExec::try_new_exec(&[vec![]], Arc::clone(schema), None).unwrap()
    }

    /// An empty, single-partition in-memory source that declares `sort_exprs`
    /// as its output ordering.
    fn sorted_memory_exec(
        schema: &SchemaRef,
        sort_exprs: LexOrdering,
    ) -> Arc<dyn ExecutionPlan> {
        Arc::new(TestMemoryExec::update_cache(Arc::new(
            TestMemoryExec::try_new(&[vec![]], Arc::clone(schema), None)
                .unwrap()
                .try_with_sort_information(vec![sort_exprs])
                .unwrap(),
        )))
    }
}