datafusion_physical_plan/repartition/mod.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! This file implements the [`RepartitionExec`] operator, which maps N input
//! partitions to M output partitions based on a partitioning scheme, optionally
//! maintaining the order of the input rows in the output.
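//!
//! A minimal usage sketch (hedged; `input` stands for any
//! `Arc<dyn ExecutionPlan>` and is not defined here):
//! ```ignore
//! // Fan the input's partitions out to 4 round-robin output partitions.
//! let exec = RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(4))?;
//! assert_eq!(exec.partitioning().partition_count(), 4);
//! ```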

use std::fmt::{Debug, Formatter};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::{any::Any, vec};

use super::common::SharedMemoryReservation;
use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
use super::{
    DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream,
};
use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType};
use crate::hash_utils::create_hashes;
use crate::metrics::BaselineMetrics;
use crate::projection::{all_columns, make_with_child, update_expr, ProjectionExec};
use crate::repartition::distributor_channels::{
    channels, partition_aware_channels, DistributionReceiver, DistributionSender,
};
use crate::sorts::streaming_merge::StreamingMergeBuilder;
use crate::stream::RecordBatchStreamAdapter;
use crate::{DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Statistics};

use arrow::array::{PrimitiveArray, RecordBatch, RecordBatchOptions};
use arrow::compute::take_arrays;
use arrow::datatypes::{SchemaRef, UInt32Type};
use datafusion_common::config::ConfigOptions;
use datafusion_common::stats::Precision;
use datafusion_common::utils::transpose;
use datafusion_common::{internal_err, ColumnStatistics, HashMap};
use datafusion_common::{not_impl_err, DataFusionError, Result};
use datafusion_common_runtime::SpawnedTask;
use datafusion_execution::memory_pool::MemoryConsumer;
use datafusion_execution::TaskContext;
use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr};
use datafusion_physical_expr_common::sort_expr::LexOrdering;

use crate::filter_pushdown::{
    ChildPushdownResult, FilterDescription, FilterPushdownPhase,
    FilterPushdownPropagation,
};
use futures::stream::Stream;
use futures::{FutureExt, StreamExt, TryStreamExt};
use log::trace;
use parking_lot::Mutex;

mod distributor_channels;

type MaybeBatch = Option<Result<RecordBatch>>;
type InputPartitionsToCurrentPartitionSender = Vec<DistributionSender<MaybeBatch>>;
type InputPartitionsToCurrentPartitionReceiver = Vec<DistributionReceiver<MaybeBatch>>;

#[derive(Debug)]
struct ConsumingInputStreamsState {
    /// Channels for sending batches from input partitions to output partitions.
    /// Key is the partition number.
    channels: HashMap<
        usize,
        (
            InputPartitionsToCurrentPartitionSender,
            InputPartitionsToCurrentPartitionReceiver,
            SharedMemoryReservation,
        ),
    >,

    /// Helper that ensures that the background job is killed once it is no longer needed.
    abort_helper: Arc<Vec<SpawnedTask<()>>>,
}

/// Inner state of [`RepartitionExec`].
enum RepartitionExecState {
    /// Not initialized yet. This is the default state stored in the RepartitionExec node
    /// upon instantiation.
    NotInitialized,
    /// Input streams are initialized, but not yet consumed. The node
    /// transitions to this state when the Arrow `RecordBatch` stream is created in
    /// RepartitionExec::execute(), but before any message is polled.
    InputStreamsInitialized(Vec<(SendableRecordBatchStream, RepartitionMetrics)>),
    /// The input streams are being consumed. The node transitions to this state when
    /// the first message in the Arrow `RecordBatch` stream is consumed.
    ConsumingInputStreams(ConsumingInputStreamsState),
}

impl Default for RepartitionExecState {
    fn default() -> Self {
        Self::NotInitialized
    }
}

impl Debug for RepartitionExecState {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            RepartitionExecState::NotInitialized => write!(f, "NotInitialized"),
            RepartitionExecState::InputStreamsInitialized(v) => {
                write!(f, "InputStreamsInitialized({:?})", v.len())
            }
            RepartitionExecState::ConsumingInputStreams(v) => {
                write!(f, "ConsumingInputStreams({v:?})")
            }
        }
    }
}

impl RepartitionExecState {
    fn ensure_input_streams_initialized(
        &mut self,
        input: Arc<dyn ExecutionPlan>,
        metrics: ExecutionPlanMetricsSet,
        output_partitions: usize,
        ctx: Arc<TaskContext>,
    ) -> Result<()> {
        if !matches!(self, RepartitionExecState::NotInitialized) {
            return Ok(());
        }

        let num_input_partitions = input.output_partitioning().partition_count();
        let mut streams_and_metrics = Vec::with_capacity(num_input_partitions);

        for i in 0..num_input_partitions {
            let metrics = RepartitionMetrics::new(i, output_partitions, &metrics);

            let timer = metrics.fetch_time.timer();
            let stream = input.execute(i, Arc::clone(&ctx))?;
            timer.done();

            streams_and_metrics.push((stream, metrics));
        }
        *self = RepartitionExecState::InputStreamsInitialized(streams_and_metrics);
        Ok(())
    }

    fn consume_input_streams(
        &mut self,
        input: Arc<dyn ExecutionPlan>,
        metrics: ExecutionPlanMetricsSet,
        partitioning: Partitioning,
        preserve_order: bool,
        name: String,
        context: Arc<TaskContext>,
    ) -> Result<&mut ConsumingInputStreamsState> {
        let streams_and_metrics = match self {
            RepartitionExecState::NotInitialized => {
                self.ensure_input_streams_initialized(
                    input,
                    metrics,
                    partitioning.partition_count(),
                    Arc::clone(&context),
                )?;
                let RepartitionExecState::InputStreamsInitialized(value) = self else {
                    // This cannot happen, as ensure_input_streams_initialized() was just called,
                    // but the compiler does not know.
                    return internal_err!("Programming error: RepartitionExecState must be in the InputStreamsInitialized state after calling RepartitionExecState::ensure_input_streams_initialized");
                };
                value
            }
            RepartitionExecState::ConsumingInputStreams(value) => return Ok(value),
            RepartitionExecState::InputStreamsInitialized(value) => value,
        };

        let num_input_partitions = streams_and_metrics.len();
        let num_output_partitions = partitioning.partition_count();

        let (txs, rxs) = if preserve_order {
            let (txs, rxs) =
                partition_aware_channels(num_input_partitions, num_output_partitions);
            // Take transpose of senders and receivers. `state.channels` keeps track of entries per output partition
            let txs = transpose(txs);
            let rxs = transpose(rxs);
            (txs, rxs)
        } else {
            // create one channel per *output* partition
            // note we use a custom channel that ensures there is always data for each receiver
            // but limits the amount of buffering if required.
            let (txs, rxs) = channels(num_output_partitions);
            // Clone the sender for each input partition
            let txs = txs
                .into_iter()
                .map(|item| vec![item; num_input_partitions])
                .collect::<Vec<_>>();
            let rxs = rxs.into_iter().map(|item| vec![item]).collect::<Vec<_>>();
            (txs, rxs)
        };

        let mut channels = HashMap::with_capacity(txs.len());
        for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() {
            let reservation = Arc::new(Mutex::new(
                MemoryConsumer::new(format!("{name}[{partition}]"))
                    .register(context.memory_pool()),
            ));
            channels.insert(partition, (tx, rx, reservation));
        }

        // launch one async task per *input* partition
        let mut spawned_tasks = Vec::with_capacity(num_input_partitions);
        for (i, (stream, metrics)) in
            std::mem::take(streams_and_metrics).into_iter().enumerate()
        {
            let txs: HashMap<_, _> = channels
                .iter()
                .map(|(partition, (tx, _rx, reservation))| {
                    (*partition, (tx[i].clone(), Arc::clone(reservation)))
                })
                .collect();

            let input_task = SpawnedTask::spawn(RepartitionExec::pull_from_input(
                stream,
                txs.clone(),
                partitioning.clone(),
                metrics,
            ));

            // In a separate task, wait for each input to be done
            // (and pass along any errors, including panic!s)
            let wait_for_task = SpawnedTask::spawn(RepartitionExec::wait_for_task(
                input_task,
                txs.into_iter()
                    .map(|(partition, (tx, _reservation))| (partition, tx))
                    .collect(),
            ));
            spawned_tasks.push(wait_for_task);
        }
        *self = Self::ConsumingInputStreams(ConsumingInputStreamsState {
            channels,
            abort_helper: Arc::new(spawned_tasks),
        });
        match self {
            RepartitionExecState::ConsumingInputStreams(value) => Ok(value),
            _ => unreachable!(),
        }
    }
}

/// A utility that can be used to partition batches based on [`Partitioning`]
pub struct BatchPartitioner {
    state: BatchPartitionerState,
    timer: metrics::Time,
}

enum BatchPartitionerState {
    Hash {
        random_state: ahash::RandomState,
        exprs: Vec<Arc<dyn PhysicalExpr>>,
        num_partitions: usize,
        hash_buffer: Vec<u64>,
    },
    RoundRobin {
        num_partitions: usize,
        next_idx: usize,
    },
}

impl BatchPartitioner {
    /// Create a new [`BatchPartitioner`] with the provided [`Partitioning`]
    ///
    /// The time spent repartitioning will be recorded to `timer`
    pub fn try_new(partitioning: Partitioning, timer: metrics::Time) -> Result<Self> {
        let state = match partitioning {
            Partitioning::RoundRobinBatch(num_partitions) => {
                BatchPartitionerState::RoundRobin {
                    num_partitions,
                    next_idx: 0,
                }
            }
            Partitioning::Hash(exprs, num_partitions) => BatchPartitionerState::Hash {
                exprs,
                num_partitions,
                // Use a fixed random state so hash partitioning is deterministic
                random_state: ahash::RandomState::with_seeds(0, 0, 0, 0),
                hash_buffer: vec![],
            },
            other => return not_impl_err!("Unsupported repartitioning scheme {other:?}"),
        };

        Ok(Self { state, timer })
    }

    /// Partition the provided [`RecordBatch`] into one or more partitioned [`RecordBatch`]es
    /// based on the [`Partitioning`] specified on construction
    ///
    /// `f` will be called for each partitioned [`RecordBatch`] with the corresponding
    /// partition index. Any error returned by `f` will be immediately returned by this
    /// function without attempting to publish further [`RecordBatch`]es
    ///
    /// The time spent repartitioning, not including time spent in `f`, will be recorded
    /// to the [`metrics::Time`] provided on construction
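    ///
    /// # Example
    ///
    /// A hedged sketch (not from the original source); it assumes `Arc`,
    /// `ArrayRef` and `UInt32Array` are in scope and uses a fresh
    /// `metrics::Time` as the timer:
    /// ```ignore
    /// let batch = RecordBatch::try_from_iter(vec![(
    ///     "c0",
    ///     Arc::new(UInt32Array::from(vec![1, 2, 3, 4])) as ArrayRef,
    /// )])?;
    /// let mut partitioner =
    ///     BatchPartitioner::try_new(Partitioning::RoundRobinBatch(2), metrics::Time::new())?;
    /// // Round-robin routes each whole input batch to the next partition in turn,
    /// // so `f` is invoked exactly once here.
    /// partitioner.partition(batch, |partition, partitioned| {
    ///     println!("partition {partition}: {} rows", partitioned.num_rows());
    ///     Ok(())
    /// })?;
    /// ```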
    pub fn partition<F>(&mut self, batch: RecordBatch, mut f: F) -> Result<()>
    where
        F: FnMut(usize, RecordBatch) -> Result<()>,
    {
        self.partition_iter(batch)?.try_for_each(|res| match res {
            Ok((partition, batch)) => f(partition, batch),
            Err(e) => Err(e),
        })
    }

    /// Actual implementation of [`partition`](Self::partition).
    ///
    /// The reason this was pulled out is that we need to have a variant of `partition` that works w/ sync functions,
    /// and one that works w/ async. Using an iterator as an intermediate representation was the best way to achieve
    /// this (so we don't need to clone the entire implementation).
    fn partition_iter(
        &mut self,
        batch: RecordBatch,
    ) -> Result<impl Iterator<Item = Result<(usize, RecordBatch)>> + Send + '_> {
        let it: Box<dyn Iterator<Item = Result<(usize, RecordBatch)>> + Send> =
            match &mut self.state {
                BatchPartitionerState::RoundRobin {
                    num_partitions,
                    next_idx,
                } => {
                    let idx = *next_idx;
                    *next_idx = (*next_idx + 1) % *num_partitions;
                    Box::new(std::iter::once(Ok((idx, batch))))
                }
                BatchPartitionerState::Hash {
                    random_state,
                    exprs,
                    num_partitions: partitions,
                    hash_buffer,
                } => {
                    // Tracking time required for distributing indexes across output partitions
                    let timer = self.timer.timer();

                    let arrays = exprs
                        .iter()
                        .map(|expr| expr.evaluate(&batch)?.into_array(batch.num_rows()))
                        .collect::<Result<Vec<_>>>()?;

                    hash_buffer.clear();
                    hash_buffer.resize(batch.num_rows(), 0);

                    create_hashes(&arrays, random_state, hash_buffer)?;

                    let mut indices: Vec<_> = (0..*partitions)
                        .map(|_| Vec::with_capacity(batch.num_rows()))
                        .collect();

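                    // Route each row index to the output partition given by its
                    // hash modulo the number of partitions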
                    for (index, hash) in hash_buffer.iter().enumerate() {
                        indices[(*hash % *partitions as u64) as usize].push(index as u32);
                    }

                    // Finished building index-arrays for output partitions
                    timer.done();

                    // Borrowing partitioner timer to prevent moving `self` to closure
                    let partitioner_timer = &self.timer;
                    let it = indices
                        .into_iter()
                        .enumerate()
                        .filter_map(|(partition, indices)| {
                            let indices: PrimitiveArray<UInt32Type> = indices.into();
                            (!indices.is_empty()).then_some((partition, indices))
                        })
                        .map(move |(partition, indices)| {
                            // Tracking time required for repartitioned batches construction
                            let _timer = partitioner_timer.timer();

                            // Produce batches based on indices
                            let columns = take_arrays(batch.columns(), &indices, None)?;

                            let mut options = RecordBatchOptions::new();
                            options = options.with_row_count(Some(indices.len()));
                            let batch = RecordBatch::try_new_with_options(
                                batch.schema(),
                                columns,
                                &options,
                            )
                            .unwrap();

                            Ok((partition, batch))
                        });

                    Box::new(it)
                }
            };

        Ok(it)
    }

    // return the number of output partitions
    fn num_partitions(&self) -> usize {
        match self.state {
            BatchPartitionerState::RoundRobin { num_partitions, .. } => num_partitions,
            BatchPartitionerState::Hash { num_partitions, .. } => num_partitions,
        }
    }
}

/// Maps `N` input partitions to `M` output partitions based on a
/// [`Partitioning`] scheme.
///
/// # Background
///
/// DataFusion, like most other commercial systems (with the notable
/// exception of DuckDB), uses the "Exchange Operator" based approach
/// to parallelism, which works well in practice given sufficient care
/// in implementation.
///
/// DataFusion's planner picks the target number of partitions and
/// then [`RepartitionExec`] redistributes [`RecordBatch`]es to that number
/// of output partitions.
///
/// For example, given `target_partitions=3` (trying to use 3 cores)
/// but scanning an input with 2 partitions, `RepartitionExec` can be
/// used to get 3 even streams of `RecordBatch`es
///
///
///```text
///        ▲                  ▲                  ▲
///        │                  │                  │
///        │                  │                  │
///        │                  │                  │
///┌───────────────┐  ┌───────────────┐  ┌───────────────┐
///│    GroupBy    │  │    GroupBy    │  │    GroupBy    │
///│   (Partial)   │  │   (Partial)   │  │   (Partial)   │
///└───────────────┘  └───────────────┘  └───────────────┘
///        ▲                  ▲                  ▲
///        └──────────────────┼──────────────────┘
///                           │
///              ┌─────────────────────────┐
///              │     RepartitionExec     │
///              │   (hash/round robin)    │
///              └─────────────────────────┘
///                         ▲   ▲
///             ┌───────────┘   └───────────┐
///             │                           │
///             │                           │
///        .─────────.                 .─────────.
///     ,─'           '─.           ,─'           '─.
///    ;      Input      :         ;      Input      :
///    :   Partition 0   ;         :   Partition 1   ;
///     ╲               ╱           ╲               ╱
///      '─.         ,─'             '─.         ,─'
///         `───────'                   `───────'
///```
///
/// # Error Handling
///
/// If any of the input partitions return an error, the error is propagated to
/// all output partitions and inputs are not polled again.
///
/// # Output Ordering
///
/// If more than one stream is being repartitioned, the output will be some
/// arbitrary interleaving (and thus unordered) unless
/// [`Self::with_preserve_order`] specifies otherwise.
///
/// # Footnote
///
/// The "Exchange Operator" was first described in the 1989 paper
/// [Encapsulation of parallelism in the Volcano query processing
/// system](https://dl.acm.org/doi/pdf/10.1145/93605.98720),
/// which uses the term "Exchange" for the concept of repartitioning
/// data across threads.
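///
/// # Example
///
/// A hedged construction sketch (not from the original source; `input` is any
/// `Arc<dyn ExecutionPlan>` and `schema` is its schema):
/// ```ignore
/// // Hash-partition the input on column "c0" into 8 output partitions.
/// let exec = RepartitionExec::try_new(
///     input,
///     Partitioning::Hash(vec![col("c0", &schema)?], 8),
/// )?;
/// ```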
#[derive(Debug, Clone)]
pub struct RepartitionExec {
    /// Input execution plan
    input: Arc<dyn ExecutionPlan>,
    /// Inner state that is initialized when the parent calls .execute() on this node
    /// and consumed as soon as the parent starts consuming this node.
    state: Arc<Mutex<RepartitionExecState>>,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Whether to preserve ordering. If true, this operator acts as
    /// `SortPreservingRepartitionExec`; if false, as a plain `RepartitionExec`.
    preserve_order: bool,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
}

#[derive(Debug, Clone)]
struct RepartitionMetrics {
    /// Time in nanos to execute child operator and fetch batches
    fetch_time: metrics::Time,
    /// Repartitioning elapsed time in nanos
    repartition_time: metrics::Time,
    /// Time in nanos for sending resulting batches to channels.
    ///
    /// One metric per output partition.
    send_time: Vec<metrics::Time>,
}

impl RepartitionMetrics {
    pub fn new(
        input_partition: usize,
        num_output_partitions: usize,
        metrics: &ExecutionPlanMetricsSet,
    ) -> Self {
        // Time in nanos to execute child operator and fetch batches
        let fetch_time =
            MetricBuilder::new(metrics).subset_time("fetch_time", input_partition);

        // Time in nanos to perform repartitioning
        let repartition_time =
            MetricBuilder::new(metrics).subset_time("repartition_time", input_partition);

        // Time in nanos for sending resulting batches to channels
        let send_time = (0..num_output_partitions)
            .map(|output_partition| {
                let label =
                    metrics::Label::new("outputPartition", output_partition.to_string());
                MetricBuilder::new(metrics)
                    .with_label(label)
                    .subset_time("send_time", input_partition)
            })
            .collect();

        Self {
            fetch_time,
            repartition_time,
            send_time,
        }
    }
}

impl RepartitionExec {
    /// Input execution plan
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

    /// Partitioning scheme to use
    pub fn partitioning(&self) -> &Partitioning {
        &self.cache.partitioning
    }

    /// Get the `preserve_order` flag of this `RepartitionExec`:
    /// `true` means `SortPreservingRepartitionExec`, `false` means `RepartitionExec`
    pub fn preserve_order(&self) -> bool {
        self.preserve_order
    }

    /// Get name used to display this Exec
    pub fn name(&self) -> &str {
        "RepartitionExec"
    }
}

impl DisplayAs for RepartitionExec {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(
                    f,
                    "{}: partitioning={}, input_partitions={}",
                    self.name(),
                    self.partitioning(),
                    self.input.output_partitioning().partition_count()
                )?;

                if self.preserve_order {
                    write!(f, ", preserve_order=true")?;
                }

                if let Some(sort_exprs) = self.sort_exprs() {
                    write!(f, ", sort_exprs={}", sort_exprs.clone())?;
                }
                Ok(())
            }
            DisplayFormatType::TreeRender => {
                writeln!(f, "partitioning_scheme={}", self.partitioning(),)?;

                let input_partition_count =
                    self.input.output_partitioning().partition_count();
                let output_partition_count = self.partitioning().partition_count();
                let input_to_output_partition_str =
                    format!("{input_partition_count} -> {output_partition_count}");
                writeln!(
                    f,
                    "partition_count(in->out)={input_to_output_partition_str}"
                )?;

                if self.preserve_order {
                    writeln!(f, "preserve_order={}", self.preserve_order)?;
                }
                Ok(())
            }
        }
    }
}

impl ExecutionPlan for RepartitionExec {
    fn name(&self) -> &'static str {
        "RepartitionExec"
    }

    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let mut repartition = RepartitionExec::try_new(
            children.swap_remove(0),
            self.partitioning().clone(),
        )?;
        if self.preserve_order {
            repartition = repartition.with_preserve_order();
        }
        Ok(Arc::new(repartition))
    }

    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![matches!(self.partitioning(), Partitioning::Hash(_, _))]
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        Self::maintains_input_order_helper(self.input(), self.preserve_order)
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        trace!(
            "Start {}::execute for partition: {}",
            self.name(),
            partition
        );

        let input = Arc::clone(&self.input);
        let partitioning = self.partitioning().clone();
        let metrics = self.metrics.clone();
        let preserve_order = self.sort_exprs().is_some();
        let name = self.name().to_owned();
        let schema = self.schema();
        let schema_captured = Arc::clone(&schema);

        // Get existing ordering to use for merging
        let sort_exprs = self.sort_exprs().cloned();

        let state = Arc::clone(&self.state);
        if let Some(mut state) = state.try_lock() {
            state.ensure_input_streams_initialized(
                Arc::clone(&input),
                metrics.clone(),
                partitioning.partition_count(),
                Arc::clone(&context),
            )?;
        }

        let stream = futures::stream::once(async move {
            let num_input_partitions = input.output_partitioning().partition_count();

            // lock scope
            let (mut rx, reservation, abort_helper) = {
                // lock mutexes
                let mut state = state.lock();
                let state = state.consume_input_streams(
                    Arc::clone(&input),
                    metrics.clone(),
                    partitioning,
                    preserve_order,
                    name.clone(),
                    Arc::clone(&context),
                )?;

                // now return stream for the specified *output* partition which will
                // read from the channel
                let (_tx, rx, reservation) = state
                    .channels
                    .remove(&partition)
                    .expect("partition not used yet");

                (rx, reservation, Arc::clone(&state.abort_helper))
            };

            trace!(
                "Before returning stream in {name}::execute for partition: {partition}"
            );

            if preserve_order {
                // Store streams from all the input partitions:
                let input_streams = rx
                    .into_iter()
                    .map(|receiver| {
                        Box::pin(PerPartitionStream {
                            schema: Arc::clone(&schema_captured),
                            receiver,
                            _drop_helper: Arc::clone(&abort_helper),
                            reservation: Arc::clone(&reservation),
                        }) as SendableRecordBatchStream
                    })
                    .collect::<Vec<_>>();
                // Note that the receiver count (`rx.len()`) and `num_input_partitions` are the same.

                // Merge streams (while preserving ordering) coming from
                // input partitions to this partition:
                let fetch = None;
                let merge_reservation =
                    MemoryConsumer::new(format!("{name}[Merge {partition}]"))
                        .register(context.memory_pool());
                StreamingMergeBuilder::new()
                    .with_streams(input_streams)
                    .with_schema(schema_captured)
                    .with_expressions(&sort_exprs.unwrap())
                    .with_metrics(BaselineMetrics::new(&metrics, partition))
                    .with_batch_size(context.session_config().batch_size())
                    .with_fetch(fetch)
                    .with_reservation(merge_reservation)
                    .build()
            } else {
                Ok(Box::pin(RepartitionStream {
                    num_input_partitions,
                    num_input_partitions_processed: 0,
                    schema: input.schema(),
                    input: rx.swap_remove(0),
                    _drop_helper: abort_helper,
                    reservation,
                }) as SendableRecordBatchStream)
            }
        })
        .try_flatten();
        let stream = RecordBatchStreamAdapter::new(schema, stream);
        Ok(Box::pin(stream))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        self.input.partition_statistics(None)
    }

    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
        if let Some(partition) = partition {
            let partition_count = self.partitioning().partition_count();
            if partition_count == 0 {
                return Ok(Statistics::new_unknown(&self.schema()));
            }

            if partition >= partition_count {
                return internal_err!(
                    "RepartitionExec invalid partition {} (expected less than {})",
                    partition,
                    self.partitioning().partition_count()
                );
            }

            let mut stats = self.input.partition_statistics(None)?;

            // Distribute statistics across partitions
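            // (e.g. Exact(1000) input rows split over 4 output partitions becomes
            // Inexact(250) rows per partition, since the true split is unknown)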
            stats.num_rows = stats
                .num_rows
                .get_value()
                .map(|rows| Precision::Inexact(rows / partition_count))
                .unwrap_or(Precision::Absent);
            stats.total_byte_size = stats
                .total_byte_size
                .get_value()
                .map(|bytes| Precision::Inexact(bytes / partition_count))
                .unwrap_or(Precision::Absent);

            // Make all column stats unknown
            stats.column_statistics = stats
                .column_statistics
                .iter()
                .map(|_| ColumnStatistics::new_unknown())
                .collect();

            Ok(stats)
        } else {
            self.input.partition_statistics(None)
        }
    }

    fn cardinality_effect(&self) -> CardinalityEffect {
        CardinalityEffect::Equal
    }

    fn try_swapping_with_projection(
        &self,
        projection: &ProjectionExec,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        // If the projection does not narrow the schema, we should not try to push it down.
        if projection.expr().len() >= projection.input().schema().fields().len() {
            return Ok(None);
        }

        // If pushdown is not beneficial or applicable, do nothing.
        if projection.benefits_from_input_partitioning()[0]
            || !all_columns(projection.expr())
        {
            return Ok(None);
        }

        let new_projection = make_with_child(projection, self.input())?;

        let new_partitioning = match self.partitioning() {
            Partitioning::Hash(partitions, size) => {
                let mut new_partitions = vec![];
                for partition in partitions {
                    let Some(new_partition) =
                        update_expr(partition, projection.expr(), false)?
                    else {
                        return Ok(None);
                    };
                    new_partitions.push(new_partition);
                }
                Partitioning::Hash(new_partitions, *size)
            }
            others => others.clone(),
        };

        Ok(Some(Arc::new(RepartitionExec::try_new(
            new_projection,
            new_partitioning,
        )?)))
    }

    fn gather_filters_for_pushdown(
        &self,
        _phase: FilterPushdownPhase,
        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
        _config: &ConfigOptions,
    ) -> Result<FilterDescription> {
        FilterDescription::from_children(parent_filters, &self.children())
    }

    fn handle_child_pushdown_result(
        &self,
        _phase: FilterPushdownPhase,
        child_pushdown_result: ChildPushdownResult,
        _config: &ConfigOptions,
    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
        Ok(FilterPushdownPropagation::if_all(child_pushdown_result))
    }
}

impl RepartitionExec {
    /// Create a new `RepartitionExec` that produces the output `partitioning`
    /// and does not preserve the order of the input (see
    /// [`Self::with_preserve_order`] for more details)
    pub fn try_new(
        input: Arc<dyn ExecutionPlan>,
        partitioning: Partitioning,
    ) -> Result<Self> {
        let preserve_order = false;
        let cache =
            Self::compute_properties(&input, partitioning.clone(), preserve_order);
        Ok(RepartitionExec {
            input,
            state: Default::default(),
            metrics: ExecutionPlanMetricsSet::new(),
            preserve_order,
            cache,
        })
    }

    fn maintains_input_order_helper(
        input: &Arc<dyn ExecutionPlan>,
        preserve_order: bool,
    ) -> Vec<bool> {
        // We preserve ordering when this is the order-preserving variant or when the input has at most one partition
        vec![preserve_order || input.output_partitioning().partition_count() <= 1]
    }

    fn eq_properties_helper(
        input: &Arc<dyn ExecutionPlan>,
        preserve_order: bool,
    ) -> EquivalenceProperties {
        // Equivalence Properties
        let mut eq_properties = input.equivalence_properties().clone();
        // If the ordering is lost, reset the ordering equivalence class:
        if !Self::maintains_input_order_helper(input, preserve_order)[0] {
            eq_properties.clear_orderings();
        }
        // When there is more than one input partition, the partitions are fused
        // at the output. Therefore, remove per-partition constants.
        if input.output_partitioning().partition_count() > 1 {
            eq_properties.clear_per_partition_constants();
        }
        eq_properties
    }

    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
    fn compute_properties(
        input: &Arc<dyn ExecutionPlan>,
        partitioning: Partitioning,
        preserve_order: bool,
    ) -> PlanProperties {
        PlanProperties::new(
            Self::eq_properties_helper(input, preserve_order),
            partitioning,
            input.pipeline_behavior(),
            input.boundedness(),
        )
        .with_scheduling_type(SchedulingType::Cooperative)
        .with_evaluation_type(EvaluationType::Eager)
    }
    /// Specify if this repartitioning operation should preserve the order of
    /// rows from its input when producing output. Preserving order is more
    /// expensive at runtime, so should only be set if the output of this
    /// operator can take advantage of it.
    ///
    /// If the input is not ordered, or has only one partition, this is a no-op,
    /// and the node remains a `RepartitionExec`.
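    ///
    /// A hedged sketch (not from the original source): for an ordered,
    /// multi-partition `input`, this yields the order-preserving variant.
    /// ```ignore
    /// let exec = RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(8))?
    ///     .with_preserve_order();
    /// assert!(exec.preserve_order());
    /// ```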
    pub fn with_preserve_order(mut self) -> Self {
        self.preserve_order =
                // If the input isn't ordered, there is no ordering to preserve
                self.input.output_ordering().is_some() &&
                // if there is only one input partition, merging is not required
                // to maintain order
                self.input.output_partitioning().partition_count() > 1;
        let eq_properties = Self::eq_properties_helper(&self.input, self.preserve_order);
        self.cache = self.cache.with_eq_properties(eq_properties);
        self
    }

    /// Return the sort expressions that are used to merge
    fn sort_exprs(&self) -> Option<&LexOrdering> {
        if self.preserve_order {
            self.input.output_ordering()
        } else {
            None
        }
    }

    /// Pulls data from the specified input plan, feeding it to the
    /// output partitions based on the desired partitioning
    ///
    /// `output_channels` holds the sending channels for each output partition
    async fn pull_from_input(
        mut stream: SendableRecordBatchStream,
        mut output_channels: HashMap<
            usize,
            (DistributionSender<MaybeBatch>, SharedMemoryReservation),
        >,
        partitioning: Partitioning,
        metrics: RepartitionMetrics,
    ) -> Result<()> {
        let mut partitioner =
            BatchPartitioner::try_new(partitioning, metrics.repartition_time.clone())?;

        // While there are still outputs to send to, keep pulling inputs
        let mut batches_until_yield = partitioner.num_partitions();
        while !output_channels.is_empty() {
            // fetch the next batch
            let timer = metrics.fetch_time.timer();
            let result = stream.next().await;
            timer.done();

            // Input is done
            let batch = match result {
                Some(result) => result?,
                None => break,
            };

            for res in partitioner.partition_iter(batch)? {
                let (partition, batch) = res?;
                let size = batch.get_array_memory_size();

                let timer = metrics.send_time[partition].timer();
                // if there is still a receiver, send to it
                if let Some((tx, reservation)) = output_channels.get_mut(&partition) {
                    reservation.lock().try_grow(size)?;

                    if tx.send(Some(Ok(batch))).await.is_err() {
                        // If the other end has hung up, it was an early shutdown (e.g. LIMIT)
                        reservation.lock().shrink(size);
                        output_channels.remove(&partition);
                    }
                }
                timer.done();
            }

            // If the input stream is endless, we may spin forever and
            // never yield back to tokio.  See
            // https://github.com/apache/datafusion/issues/5278.
            //
            // However, yielding on every batch causes a bottleneck
            // when running with multiple cores. See
            // https://github.com/apache/datafusion/issues/6290
            //
            // Thus, heuristically yield after producing num_partition
            // batches
            //
            // In round robin this is ideal as each input will get a
            // new batch. In hash partitioning it may yield too often
            // on uneven distributions even if some partition can not
            // make progress, but parallelism is going to be limited
            // in that case anyways
            if batches_until_yield == 0 {
                tokio::task::yield_now().await;
                batches_until_yield = partitioner.num_partitions();
            } else {
                batches_until_yield -= 1;
            }
        }

        Ok(())
    }

    /// Waits for `input_task` which is consuming one of the inputs to
    /// complete. Upon each successful completion, sends a `None` to
    /// each of the output tx channels to signal one of the inputs is
    /// complete. Upon error, propagates the errors to all output tx
    /// channels.
    async fn wait_for_task(
        input_task: SpawnedTask<Result<()>>,
        txs: HashMap<usize, DistributionSender<MaybeBatch>>,
    ) {
        // wait for completion, and propagate error
        // note we ignore errors on send (.ok) as that means the receiver has already shut down.

        match input_task.join().await {
            // Error in joining task
            Err(e) => {
                let e = Arc::new(e);

                for (_, tx) in txs {
                    let err = Err(DataFusionError::Context(
                        "Join Error".to_string(),
                        Box::new(DataFusionError::External(Box::new(Arc::clone(&e)))),
                    ));
                    tx.send(Some(err)).await.ok();
                }
            }
            // Error from running input task
            Ok(Err(e)) => {
                // send the same Arc'd error to all output partitions
                let e = Arc::new(e);

                for (_, tx) in txs {
                    // wrap it because we need to send the error to all output partitions
                    let err = Err(DataFusionError::from(&e));
                    tx.send(Some(err)).await.ok();
                }
            }
            // Input task completed successfully
            Ok(Ok(())) => {
                // notify each output partition that this input partition has no more data
                for (_, tx) in txs {
                    tx.send(None).await.ok();
                }
            }
        }
    }
}

struct RepartitionStream {
    /// Number of input partitions that will be sending batches to this output channel
    num_input_partitions: usize,

    /// Number of input partitions that have finished sending batches to this output channel
    num_input_partitions_processed: usize,

    /// Schema wrapped by Arc
    schema: SchemaRef,

    /// channel containing the repartitioned batches
    input: DistributionReceiver<MaybeBatch>,

    /// Handle to ensure background tasks are killed when no longer needed.
    _drop_helper: Arc<Vec<SpawnedTask<()>>>,

    /// Memory reservation.
    reservation: SharedMemoryReservation,
}

impl Stream for RepartitionStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        loop {
            match self.input.recv().poll_unpin(cx) {
                Poll::Ready(Some(Some(v))) => {
                    if let Ok(batch) = &v {
                        self.reservation
                            .lock()
                            .shrink(batch.get_array_memory_size());
                    }

                    return Poll::Ready(Some(v));
                }
                Poll::Ready(Some(None)) => {
                    self.num_input_partitions_processed += 1;

                    if self.num_input_partitions == self.num_input_partitions_processed {
                        // all input partitions have finished sending batches
                        return Poll::Ready(None);
                    } else {
                        // other partitions still have data to send
                        continue;
                    }
                }
                Poll::Ready(None) => {
                    return Poll::Ready(None);
                }
                Poll::Pending => {
                    return Poll::Pending;
                }
            }
        }
    }
}

impl RecordBatchStream for RepartitionStream {
    /// Get the schema
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

/// This struct converts a receiver to a stream.
/// The receiver receives data on an SPSC channel.
struct PerPartitionStream {
    /// Schema wrapped by Arc
    schema: SchemaRef,

    /// channel containing the repartitioned batches
    receiver: DistributionReceiver<MaybeBatch>,

    /// Handle to ensure background tasks are killed when no longer needed.
    _drop_helper: Arc<Vec<SpawnedTask<()>>>,

    /// Memory reservation.
    reservation: SharedMemoryReservation,
}

impl Stream for PerPartitionStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        match self.receiver.recv().poll_unpin(cx) {
            Poll::Ready(Some(Some(v))) => {
                if let Ok(batch) = &v {
                    self.reservation
                        .lock()
                        .shrink(batch.get_array_memory_size());
                }
                Poll::Ready(Some(v))
            }
            Poll::Ready(Some(None)) => {
                // Input partition has finished sending batches
                Poll::Ready(None)
            }
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}

impl RecordBatchStream for PerPartitionStream {
    /// Get the schema
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

1191#[cfg(test)]
1192mod tests {
1193    use std::collections::HashSet;
1194
1195    use super::*;
1196    use crate::test::TestMemoryExec;
1197    use crate::{
1198        test::{
1199            assert_is_pending,
1200            exec::{
1201                assert_strong_count_converges_to_zero, BarrierExec, BlockingExec,
1202                ErrorExec, MockExec,
1203            },
1204        },
1205        {collect, expressions::col},
1206    };
1207
1208    use arrow::array::{ArrayRef, StringArray, UInt32Array};
1209    use arrow::datatypes::{DataType, Field, Schema};
1210    use datafusion_common::cast::as_string_array;
1211    use datafusion_common::test_util::batches_to_sort_string;
1212    use datafusion_common::{arrow_datafusion_err, exec_err};
1213    use datafusion_common_runtime::JoinSet;
1214    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
1215    use insta::assert_snapshot;
1216    use itertools::Itertools;
1217
1218    #[tokio::test]
1219    async fn one_to_many_round_robin() -> Result<()> {
1220        // define input partitions
1221        let schema = test_schema();
1222        let partition = create_vec_batches(50);
1223        let partitions = vec![partition];
1224
1225        // repartition from 1 input to 4 output
1226        let output_partitions =
1227            repartition(&schema, partitions, Partitioning::RoundRobinBatch(4)).await?;
1228
1229        assert_eq!(4, output_partitions.len());
1230        assert_eq!(13, output_partitions[0].len());
1231        assert_eq!(13, output_partitions[1].len());
1232        assert_eq!(12, output_partitions[2].len());
1233        assert_eq!(12, output_partitions[3].len());
1234
1235        Ok(())
1236    }
1237
1238    #[tokio::test]
1239    async fn many_to_one_round_robin() -> Result<()> {
1240        // define input partitions
1241        let schema = test_schema();
1242        let partition = create_vec_batches(50);
1243        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
1244
1245        // repartition from 3 input to 1 output
1246        let output_partitions =
1247            repartition(&schema, partitions, Partitioning::RoundRobinBatch(1)).await?;
1248
1249        assert_eq!(1, output_partitions.len());
1250        assert_eq!(150, output_partitions[0].len());
1251
1252        Ok(())
1253    }
1254
1255    #[tokio::test]
1256    async fn many_to_many_round_robin() -> Result<()> {
1257        // define input partitions
1258        let schema = test_schema();
1259        let partition = create_vec_batches(50);
1260        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
1261
1262        // repartition from 3 input to 5 output
1263        let output_partitions =
1264            repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await?;
1265
1266        assert_eq!(5, output_partitions.len());
1267        assert_eq!(30, output_partitions[0].len());
1268        assert_eq!(30, output_partitions[1].len());
1269        assert_eq!(30, output_partitions[2].len());
1270        assert_eq!(30, output_partitions[3].len());
1271        assert_eq!(30, output_partitions[4].len());
1272
1273        Ok(())
1274    }
1275
1276    #[tokio::test]
1277    async fn many_to_many_hash_partition() -> Result<()> {
1278        // define input partitions
1279        let schema = test_schema();
1280        let partition = create_vec_batches(50);
1281        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
1282
1283        let output_partitions = repartition(
1284            &schema,
1285            partitions,
1286            Partitioning::Hash(vec![col("c0", &schema)?], 8),
1287        )
1288        .await?;
1289
1290        let total_rows: usize = output_partitions
1291            .iter()
1292            .map(|x| x.iter().map(|x| x.num_rows()).sum::<usize>())
1293            .sum();
1294
1295        assert_eq!(8, output_partitions.len());
1296        assert_eq!(total_rows, 8 * 50 * 3);
1297
1298        Ok(())
1299    }
1300
1301    fn test_schema() -> Arc<Schema> {
1302        Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
1303    }
1304
1305    async fn repartition(
1306        schema: &SchemaRef,
1307        input_partitions: Vec<Vec<RecordBatch>>,
1308        partitioning: Partitioning,
1309    ) -> Result<Vec<Vec<RecordBatch>>> {
1310        let task_ctx = Arc::new(TaskContext::default());
1311        // create physical plan
1312        let exec =
1313            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(schema), None)?;
1314        let exec = RepartitionExec::try_new(exec, partitioning)?;
1315
1316        // execute and collect results
1317        let mut output_partitions = vec![];
1318        for i in 0..exec.partitioning().partition_count() {
1319            // execute this *output* partition and collect all batches
1320            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
1321            let mut batches = vec![];
1322            while let Some(result) = stream.next().await {
1323                batches.push(result?);
1324            }
1325            output_partitions.push(batches);
1326        }
1327        Ok(output_partitions)
1328    }
1329
1330    #[tokio::test]
1331    async fn many_to_many_round_robin_within_tokio_task() -> Result<()> {
1332        let handle: SpawnedTask<Result<Vec<Vec<RecordBatch>>>> =
1333            SpawnedTask::spawn(async move {
1334                // define input partitions
1335                let schema = test_schema();
1336                let partition = create_vec_batches(50);
1337                let partitions =
1338                    vec![partition.clone(), partition.clone(), partition.clone()];
1339
1340                // repartition from 3 input to 5 output
1341                repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await
1342            });
1343
1344        let output_partitions = handle.join().await.unwrap().unwrap();
1345
1346        assert_eq!(5, output_partitions.len());
1347        assert_eq!(30, output_partitions[0].len());
1348        assert_eq!(30, output_partitions[1].len());
1349        assert_eq!(30, output_partitions[2].len());
1350        assert_eq!(30, output_partitions[3].len());
1351        assert_eq!(30, output_partitions[4].len());
1352
1353        Ok(())
1354    }
1355
1356    #[tokio::test]
1357    async fn unsupported_partitioning() {
1358        let task_ctx = Arc::new(TaskContext::default());
1359        // have to send at least one batch through to provoke error
1360        let batch = RecordBatch::try_from_iter(vec![(
1361            "my_awesome_field",
1362            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
1363        )])
1364        .unwrap();
1365
1366        let schema = batch.schema();
1367        let input = MockExec::new(vec![Ok(batch)], schema);
1368        // This generates an error (partitioning type not supported)
1369        // but only after the plan is executed. The error should be
1370        // returned and no results produced
1371        let partitioning = Partitioning::UnknownPartitioning(1);
1372        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
1373        let output_stream = exec.execute(0, task_ctx).unwrap();
1374
1375        // Expect that an error is returned
1376        let result_string = crate::common::collect(output_stream)
1377            .await
1378            .unwrap_err()
1379            .to_string();
1380        assert!(
1381            result_string
1382                .contains("Unsupported repartitioning scheme UnknownPartitioning(1)"),
1383            "actual: {result_string}"
1384        );
1385    }
1386
    #[tokio::test]
    async fn error_for_input_exec() {
        // This generates an error on a call to execute. The error
        // should be returned and no results produced.

        let task_ctx = Arc::new(TaskContext::default());
        let input = ErrorExec::new();
        let partitioning = Partitioning::RoundRobinBatch(1);
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();

        // Expect that an error is returned
        let result_string = exec.execute(0, task_ctx).err().unwrap().to_string();

        assert!(
            result_string.contains("ErrorExec, unsurprisingly, errored in partition 0"),
            "actual: {result_string}"
        );
    }

    #[tokio::test]
    async fn repartition_with_error_in_stream() {
        let task_ctx = Arc::new(TaskContext::default());
        let batch = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();

        // input stream returns one good batch and then one error. The
        // error should be returned.
        let err = exec_err!("bad data error");

        let schema = batch.schema();
        let input = MockExec::new(vec![Ok(batch), err], schema);
        let partitioning = Partitioning::RoundRobinBatch(1);
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();

        // Note: this should pass (the stream can be created) but the
        // error when the input is executed should get passed back
        let output_stream = exec.execute(0, task_ctx).unwrap();

        // Expect that an error is returned
        let result_string = crate::common::collect(output_stream)
            .await
            .unwrap_err()
            .to_string();
        assert!(
            result_string.contains("bad data error"),
            "actual: {result_string}"
        );
    }

    #[tokio::test]
    async fn repartition_with_delayed_stream() {
        let task_ctx = Arc::new(TaskContext::default());
        let batch1 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();

        let batch2 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
        )])
        .unwrap();

        // The mock exec doesn't return immediately (instead it
        // requires the input to wait at least once)
        let schema = batch1.schema();
        let expected_batches = vec![batch1.clone(), batch2.clone()];
        let input = MockExec::new(vec![Ok(batch1), Ok(batch2)], schema);
        let partitioning = Partitioning::RoundRobinBatch(1);

        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();

        assert_snapshot!(batches_to_sort_string(&expected_batches), @r"
        +------------------+
        | my_awesome_field |
        +------------------+
        | bar              |
        | baz              |
        | foo              |
        | frob             |
        +------------------+
        ");

        let output_stream = exec.execute(0, task_ctx).unwrap();
        let batches = crate::common::collect(output_stream).await.unwrap();

        assert_snapshot!(batches_to_sort_string(&batches), @r"
        +------------------+
        | my_awesome_field |
        +------------------+
        | bar              |
        | baz              |
        | foo              |
        | frob             |
        +------------------+
        ");
    }

    #[tokio::test]
    async fn robin_repartition_with_dropping_output_stream() {
        let task_ctx = Arc::new(TaskContext::default());
        let partitioning = Partitioning::RoundRobinBatch(2);
        // The barrier exec waits to be pinged before emitting anything
        // (i.e. it requires the input to wait at least once)
        let input = Arc::new(make_barrier_exec());

        // partition into two output streams
        let exec = RepartitionExec::try_new(
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            partitioning,
        )
        .unwrap();

        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();

        // now, purposely drop output stream 0
        // *before* any outputs are produced
        drop(output_stream0);

        // Now, start sending input
        let mut background_task = JoinSet::new();
        background_task.spawn(async move {
            input.wait().await;
        });

        // output stream 1 should *not* error and should receive its share of
        // the input batches
        let batches = crate::common::collect(output_stream1).await.unwrap();

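        // Each input partition round-robins its two batches across the two
        // outputs, so output 1 receives the second batch from each input:
        // ["frob", "baz"] and ["grob", "gaz"].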
        assert_snapshot!(batches_to_sort_string(&batches), @r#"
            +------------------+
            | my_awesome_field |
            +------------------+
            | baz              |
            | frob             |
            | gaz              |
            | grob             |
            +------------------+
            "#);
    }

    #[tokio::test]
    // As the hash results might be different on different platforms or
    // with different compilers, we will compare the same execution with
    // and without dropping the output stream.
    async fn hash_repartition_with_dropping_output_stream() {
        let task_ctx = Arc::new(TaskContext::default());
        let partitioning = Partitioning::Hash(
            vec![Arc::new(crate::expressions::Column::new(
                "my_awesome_field",
                0,
            ))],
            2,
        );

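        // Hash partitioning routes each row to the output whose index is the
        // row's hash modulo the partition count (2 here), so the exact
        // assignment depends on the computed hash values.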
        // We first collect the results without dropping the output stream.
        let input = Arc::new(make_barrier_exec());
        let exec = RepartitionExec::try_new(
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            partitioning.clone(),
        )
        .unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
        let mut background_task = JoinSet::new();
        background_task.spawn(async move {
            input.wait().await;
        });
        let batches_without_drop = crate::common::collect(output_stream1).await.unwrap();

        // run some checks on the result
        let items_vec = str_batches_to_vec(&batches_without_drop);
        let items_set: HashSet<&str> = items_vec.iter().copied().collect();
        assert_eq!(items_vec.len(), items_set.len());
        let source_str_set: HashSet<&str> =
            ["foo", "bar", "frob", "baz", "goo", "gar", "grob", "gaz"]
                .iter()
                .copied()
                .collect();
        assert_eq!(items_set.difference(&source_str_set).count(), 0);

        // Now do the same, but drop the stream before waiting for the barrier
        let input = Arc::new(make_barrier_exec());
        let exec = RepartitionExec::try_new(
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            partitioning,
        )
        .unwrap();
        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
        // now, purposely drop output stream 0
        // *before* any outputs are produced
        drop(output_stream0);
        let mut background_task = JoinSet::new();
        background_task.spawn(async move {
            input.wait().await;
        });
        let batches_with_drop = crate::common::collect(output_stream1).await.unwrap();

        fn sort(batch: Vec<RecordBatch>) -> Vec<RecordBatch> {
            batch
                .into_iter()
                .sorted_by_key(|b| format!("{b:?}"))
                .collect()
        }

        assert_eq!(sort(batches_without_drop), sort(batches_with_drop));
    }

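    /// Flatten the single string column of each batch into one `Vec<&str>`,
    /// panicking on any other schema or on null values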
    fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> {
        batches
            .iter()
            .flat_map(|batch| {
                assert_eq!(batch.columns().len(), 1);
                let string_array = as_string_array(batch.column(0))
                    .expect("Unexpected type for repartitioned batch");

                string_array
                    .iter()
                    .map(|v| v.expect("Unexpected null"))
                    .collect::<Vec<_>>()
            })
            .collect::<Vec<_>>()
    }

    /// Create a BarrierExec that returns two partitions of two batches each
    fn make_barrier_exec() -> BarrierExec {
        let batch1 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();

        let batch2 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
        )])
        .unwrap();

        let batch3 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["goo", "gar"])) as ArrayRef,
        )])
        .unwrap();

        let batch4 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["grob", "gaz"])) as ArrayRef,
        )])
        .unwrap();

        // The barrier exec waits to be pinged before emitting anything
        // (i.e. it requires the input to wait at least once)
        let schema = batch1.schema();
        BarrierExec::new(vec![vec![batch1, batch2], vec![batch3, batch4]], schema)
    }

    #[tokio::test]
    async fn test_drop_cancel() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema =
            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));

        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2));
        let refs = blocking_exec.refs();
        let repartition_exec = Arc::new(RepartitionExec::try_new(
            blocking_exec,
            Partitioning::UnknownPartitioning(1),
        )?);

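        // Dropping the still-pending collect future must cancel the spawned
        // input tasks and release every reference to the blocking input.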
        let fut = collect(repartition_exec, task_ctx);
        let mut fut = fut.boxed();

        assert_is_pending(&mut fut);
        drop(fut);
        assert_strong_count_converges_to_zero(refs).await;

        Ok(())
    }

    #[tokio::test]
    async fn hash_repartition_avoid_empty_batch() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let batch = RecordBatch::try_from_iter(vec![(
            "a",
            Arc::new(StringArray::from(vec!["foo"])) as ArrayRef,
        )])
        .unwrap();
        let partitioning = Partitioning::Hash(
            vec![Arc::new(crate::expressions::Column::new("a", 0))],
            2,
        );
        let schema = batch.schema();
        let input = MockExec::new(vec![Ok(batch)], schema);
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
        let batch0 = crate::common::collect(output_stream0).await.unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
        let batch1 = crate::common::collect(output_stream1).await.unwrap();
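        // The single input row hashes to exactly one of the two outputs; the
        // other output must produce no batches at all rather than an empty one.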
        assert!(batch0.is_empty() || batch1.is_empty());
        Ok(())
    }

    #[tokio::test]
    async fn oom() -> Result<()> {
        // define input partitions
        let schema = test_schema();
        let partition = create_vec_batches(50);
        let input_partitions = vec![partition];
        let partitioning = Partitioning::RoundRobinBatch(4);

        // set up a context with a tiny memory limit
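        // a 1-byte memory limit guarantees that buffering even a single
        // repartitioned batch exceeds the pool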
        let runtime = RuntimeEnvBuilder::default()
            .with_memory_limit(1, 1.0)
            .build_arc()?;

        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        // create physical plan
        let exec =
            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
        let exec = RepartitionExec::try_new(exec, partitioning)?;

        // pull partitions
        for i in 0..exec.partitioning().partition_count() {
            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
            let err =
                arrow_datafusion_err!(stream.next().await.unwrap().unwrap_err().into());
            let err = err.find_root();
            assert!(
                matches!(err, DataFusionError::ResourcesExhausted(_)),
                "Wrong error type: {err}",
            );
        }

        Ok(())
    }

    /// Create a vector of `n` clones of a simple test batch
    fn create_vec_batches(n: usize) -> Vec<RecordBatch> {
        let batch = create_batch();
        (0..n).map(|_| batch.clone()).collect()
    }

    /// Create a batch with a single UInt32 column containing eight values
    fn create_batch() -> RecordBatch {
        let schema = test_schema();
        RecordBatch::try_new(
            schema,
            vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]))],
        )
        .unwrap()
    }
}

#[cfg(test)]
mod test {
    use arrow::compute::SortOptions;
    use arrow::datatypes::{DataType, Field, Schema};

    use super::*;
    use crate::test::TestMemoryExec;
    use crate::union::UnionExec;

    use datafusion_physical_expr::expressions::col;
    use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};

    /// Asserts that the plan is as expected
    ///
    /// `$EXPECTED_PLAN_LINES`: the expected plan, as a slice of strings
    /// `$PLAN`: the plan to check
    ///
    macro_rules! assert_plan {
        ($EXPECTED_PLAN_LINES: expr, $PLAN: expr) => {
            let physical_plan = $PLAN;
            let formatted = crate::displayable(&physical_plan).indent(true).to_string();
            let actual: Vec<&str> = formatted.trim().lines().collect();

            let expected_plan_lines: Vec<&str> =
                $EXPECTED_PLAN_LINES.iter().copied().collect();

            assert_eq!(
                expected_plan_lines, actual,
                "\n**Original Plan Mismatch\n\nexpected:\n\n{expected_plan_lines:#?}\nactual:\n\n{actual:#?}\n\n"
            );
        };
    }

    #[tokio::test]
    async fn test_preserve_order() -> Result<()> {
        let schema = test_schema();
        let sort_exprs = sort_exprs(&schema);
        let source1 = sorted_memory_exec(&schema, sort_exprs.clone());
        let source2 = sorted_memory_exec(&schema, sort_exprs);
        // output has multiple partitions and is sorted
        let union = UnionExec::new(vec![source1, source2]);
        let exec =
            RepartitionExec::try_new(Arc::new(union), Partitioning::RoundRobinBatch(10))
                .unwrap()
                .with_preserve_order();

        // Repartition should preserve order
        let expected_plan = [
            "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true, sort_exprs=c0@0 ASC",
            "  UnionExec",
            "    DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC",
            "    DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC",
        ];
        assert_plan!(expected_plan, exec);
        Ok(())
    }

    #[tokio::test]
    async fn test_preserve_order_one_partition() -> Result<()> {
        let schema = test_schema();
        let sort_exprs = sort_exprs(&schema);
        let source = sorted_memory_exec(&schema, sort_exprs);
        // output is sorted, but has only a single partition, so there is
        // nothing to merge
        let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(10))
            .unwrap()
            .with_preserve_order();

        // Repartition should not preserve order
        let expected_plan = [
            "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1",
            "  DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC",
        ];
        assert_plan!(expected_plan, exec);
        Ok(())
    }

    #[tokio::test]
    async fn test_preserve_order_input_not_sorted() -> Result<()> {
        let schema = test_schema();
        let source1 = memory_exec(&schema);
        let source2 = memory_exec(&schema);
        // output has multiple partitions, but is not sorted
        let union = UnionExec::new(vec![source1, source2]);
        let exec =
            RepartitionExec::try_new(Arc::new(union), Partitioning::RoundRobinBatch(10))
                .unwrap()
                .with_preserve_order();

        // Repartition should not preserve order, as there is no order to preserve
        let expected_plan = [
            "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2",
            "  UnionExec",
            "    DataSourceExec: partitions=1, partition_sizes=[0]",
            "    DataSourceExec: partitions=1, partition_sizes=[0]",
        ];
        assert_plan!(expected_plan, exec);
        Ok(())
    }

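    /// A schema with a single non-nullable UInt32 column named "c0"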
    fn test_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
    }

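    /// A lexicographic ordering: "c0" ascending with default sort options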
    fn sort_exprs(schema: &Schema) -> LexOrdering {
        [PhysicalSortExpr {
            expr: col("c0", schema).unwrap(),
            options: SortOptions::default(),
        }]
        .into()
    }

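    /// An empty, single-partition memory source with no declared ordering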
    fn memory_exec(schema: &SchemaRef) -> Arc<dyn ExecutionPlan> {
        TestMemoryExec::try_new_exec(&[vec![]], Arc::clone(schema), None).unwrap()
    }

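    /// An empty, single-partition memory source that reports the given ordering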
    fn sorted_memory_exec(
        schema: &SchemaRef,
        sort_exprs: LexOrdering,
    ) -> Arc<dyn ExecutionPlan> {
        Arc::new(TestMemoryExec::update_cache(Arc::new(
            TestMemoryExec::try_new(&[vec![]], Arc::clone(schema), None)
                .unwrap()
                .try_with_sort_information(vec![sort_exprs])
                .unwrap(),
        )))
    }
}