Skip to main content

datafusion_physical_plan/repartition/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! This file implements the [`RepartitionExec`] operator, which maps N input
//! partitions to M output partitions based on a partitioning scheme, optionally
//! maintaining the order of the input rows in the output.
21
22use std::fmt::{Debug, Formatter};
23use std::pin::Pin;
24use std::sync::Arc;
25use std::task::{Context, Poll};
26use std::{any::Any, vec};
27
28use super::common::SharedMemoryReservation;
29use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
30use super::{
31    DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream,
32};
33use crate::coalesce::LimitedBatchCoalescer;
34use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType};
35use crate::hash_utils::create_hashes;
36use crate::metrics::{BaselineMetrics, SpillMetrics};
37use crate::projection::{ProjectionExec, all_columns, make_with_child, update_expr};
38use crate::sorts::streaming_merge::StreamingMergeBuilder;
39use crate::spill::spill_manager::SpillManager;
40use crate::spill::spill_pool::{self, SpillPoolWriter};
41use crate::stream::RecordBatchStreamAdapter;
42use crate::{DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Statistics};
43
44use arrow::array::{PrimitiveArray, RecordBatch, RecordBatchOptions};
45use arrow::compute::take_arrays;
46use arrow::datatypes::{SchemaRef, UInt32Type};
47use datafusion_common::config::ConfigOptions;
48use datafusion_common::stats::Precision;
49use datafusion_common::utils::transpose;
50use datafusion_common::{
51    ColumnStatistics, DataFusionError, HashMap, assert_or_internal_err, internal_err,
52};
53use datafusion_common::{Result, not_impl_err};
54use datafusion_common_runtime::SpawnedTask;
55use datafusion_execution::TaskContext;
56use datafusion_execution::memory_pool::MemoryConsumer;
57use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr};
58use datafusion_physical_expr_common::sort_expr::LexOrdering;
59
60use crate::filter_pushdown::{
61    ChildPushdownResult, FilterDescription, FilterPushdownPhase,
62    FilterPushdownPropagation,
63};
64use crate::joins::SeededRandomState;
65use crate::sort_pushdown::SortOrderPushdownResult;
66use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
67use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays;
68use futures::stream::Stream;
69use futures::{FutureExt, StreamExt, TryStreamExt, ready};
70use log::trace;
71use parking_lot::Mutex;
72
73mod distributor_channels;
74use distributor_channels::{
75    DistributionReceiver, DistributionSender, channels, partition_aware_channels,
76};
77
78/// A batch in the repartition queue - either in memory or spilled to disk.
79///
80/// This enum represents the two states a batch can be in during repartitioning.
81/// The decision to spill is made based on memory availability when sending a batch
82/// to an output partition.
83///
84/// # Batch Flow with Spilling
85///
86/// ```text
87/// Input Stream ──▶ Partition Logic ──▶ try_grow()
88///                                            │
89///                            ┌───────────────┴────────────────┐
90///                            │                                │
91///                            ▼                                ▼
92///                   try_grow() succeeds            try_grow() fails
93///                   (Memory Available)              (Memory Pressure)
94///                            │                                │
95///                            ▼                                ▼
96///                  RepartitionBatch::Memory         spill_writer.push_batch()
97///                  (batch held in memory)           (batch written to disk)
98///                            │                                │
99///                            │                                ▼
100///                            │                      RepartitionBatch::Spilled
101///                            │                      (marker - no batch data)
102///                            │                                │
103///                            └────────┬───────────────────────┘
104///                                     │
105///                                     ▼
106///                              Send to channel
107///                                     │
108///                                     ▼
109///                            Output Stream (poll)
110///                                     │
111///                      ┌──────────────┴─────────────┐
112///                      │                            │
113///                      ▼                            ▼
114///         RepartitionBatch::Memory      RepartitionBatch::Spilled
115///         Return batch immediately       Poll spill_stream (blocks)
116///                      │                            │
117///                      └────────┬───────────────────┘
118///                               │
119///                               ▼
120///                          Return batch
121///                    (FIFO order preserved)
122/// ```
123///
124/// See [`RepartitionExec`] for overall architecture and [`StreamState`] for
125/// the state machine that handles reading these batches.
126#[derive(Debug)]
127enum RepartitionBatch {
128    /// Batch held in memory (counts against memory reservation)
129    Memory(RecordBatch),
130    /// Marker indicating a batch was spilled to the partition's SpillPool.
131    /// The actual batch can be retrieved by reading from the SpillPoolStream.
132    /// This variant contains no data itself - it's just a signal to the reader
133    /// to fetch the next batch from the spill stream.
134    Spilled,
135}
136
137type MaybeBatch = Option<Result<RepartitionBatch>>;
138type InputPartitionsToCurrentPartitionSender = Vec<DistributionSender<MaybeBatch>>;
139type InputPartitionsToCurrentPartitionReceiver = Vec<DistributionReceiver<MaybeBatch>>;
140
141/// Output channel with its associated memory reservation and spill writer
142struct OutputChannel {
143    sender: DistributionSender<MaybeBatch>,
144    reservation: SharedMemoryReservation,
145    spill_writer: SpillPoolWriter,
146}
147
148/// Channels and resources for a single output partition.
149///
150/// Each output partition has channels to receive data from all input partitions.
151/// To handle memory pressure, each (input, output) pair gets its own
152/// [`SpillPool`](crate::spill::spill_pool) channel via [`spill_pool::channel`].
153///
154/// # Structure
155///
156/// For an output partition receiving from N input partitions:
157/// - `tx`: N senders (one per input) for sending batches to this output
158/// - `rx`: N receivers (one per input) for receiving batches at this output
159/// - `spill_writers`: N spill writers (one per input) for writing spilled data
160/// - `spill_readers`: N spill readers (one per input) for reading spilled data
161///
162/// This 1:1 mapping between input partitions and spill channels ensures that
163/// batches from each input are processed in FIFO order, even when some batches
164/// are spilled to disk and others remain in memory.
165///
166/// See [`RepartitionExec`] for the overall N×M architecture.
167///
168/// [`spill_pool::channel`]: crate::spill::spill_pool::channel
169struct PartitionChannels {
170    /// Senders for each input partition to send data to this output partition
171    tx: InputPartitionsToCurrentPartitionSender,
172    /// Receivers for each input partition sending data to this output partition
173    rx: InputPartitionsToCurrentPartitionReceiver,
174    /// Memory reservation for this output partition
175    reservation: SharedMemoryReservation,
176    /// Spill writers for writing spilled data.
177    /// SpillPoolWriter is Clone, so multiple writers can share state in non-preserve-order mode.
178    spill_writers: Vec<SpillPoolWriter>,
179    /// Spill readers for reading spilled data - one per input partition (FIFO semantics).
180    /// Each (input, output) pair gets its own reader to maintain proper ordering.
181    spill_readers: Vec<SendableRecordBatchStream>,
182}
183
184struct ConsumingInputStreamsState {
185    /// Channels for sending batches from input partitions to output partitions.
186    /// Key is the partition number.
187    channels: HashMap<usize, PartitionChannels>,
188
189    /// Helper that ensures that background jobs are killed once they are no longer needed.
190    abort_helper: Arc<Vec<SpawnedTask<()>>>,
191}
192
193impl Debug for ConsumingInputStreamsState {
194    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
195        f.debug_struct("ConsumingInputStreamsState")
196            .field("num_channels", &self.channels.len())
197            .field("abort_helper", &self.abort_helper)
198            .finish()
199    }
200}
201
202/// Inner state of [`RepartitionExec`].
203#[derive(Default)]
204enum RepartitionExecState {
205    /// Not initialized yet. This is the default state stored in the RepartitionExec node
206    /// upon instantiation.
207    #[default]
208    NotInitialized,
209    /// Input streams are initialized, but they are still not being consumed. The node
210    /// transitions to this state when the arrow's RecordBatch stream is created in
211    /// RepartitionExec::execute(), but before any message is polled.
212    InputStreamsInitialized(Vec<(SendableRecordBatchStream, RepartitionMetrics)>),
213    /// The input streams are being consumed. The node transitions to this state when
214    /// the first message in the arrow's RecordBatch stream is consumed.
215    ConsumingInputStreams(ConsumingInputStreamsState),
216}
217
218impl Debug for RepartitionExecState {
219    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
220        match self {
221            RepartitionExecState::NotInitialized => write!(f, "NotInitialized"),
222            RepartitionExecState::InputStreamsInitialized(v) => {
223                write!(f, "InputStreamsInitialized({:?})", v.len())
224            }
225            RepartitionExecState::ConsumingInputStreams(v) => {
226                write!(f, "ConsumingInputStreams({v:?})")
227            }
228        }
229    }
230}
231
232impl RepartitionExecState {
233    fn ensure_input_streams_initialized(
234        &mut self,
235        input: &Arc<dyn ExecutionPlan>,
236        metrics: &ExecutionPlanMetricsSet,
237        output_partitions: usize,
238        ctx: &Arc<TaskContext>,
239    ) -> Result<()> {
240        if !matches!(self, RepartitionExecState::NotInitialized) {
241            return Ok(());
242        }
243
244        let num_input_partitions = input.output_partitioning().partition_count();
245        let mut streams_and_metrics = Vec::with_capacity(num_input_partitions);
246
247        for i in 0..num_input_partitions {
248            let metrics = RepartitionMetrics::new(i, output_partitions, metrics);
249
250            let timer = metrics.fetch_time.timer();
251            let stream = input.execute(i, Arc::clone(ctx))?;
252            timer.done();
253
254            streams_and_metrics.push((stream, metrics));
255        }
256        *self = RepartitionExecState::InputStreamsInitialized(streams_and_metrics);
257        Ok(())
258    }
259
260    #[expect(clippy::too_many_arguments)]
261    fn consume_input_streams(
262        &mut self,
263        input: &Arc<dyn ExecutionPlan>,
264        metrics: &ExecutionPlanMetricsSet,
265        partitioning: &Partitioning,
266        preserve_order: bool,
267        name: &str,
268        context: &Arc<TaskContext>,
269        spill_manager: SpillManager,
270    ) -> Result<&mut ConsumingInputStreamsState> {
271        let streams_and_metrics = match self {
272            RepartitionExecState::NotInitialized => {
273                self.ensure_input_streams_initialized(
274                    input,
275                    metrics,
276                    partitioning.partition_count(),
277                    context,
278                )?;
279                let RepartitionExecState::InputStreamsInitialized(value) = self else {
280                    // This cannot happen, as ensure_input_streams_initialized() was just called,
281                    // but the compiler does not know.
282                    return internal_err!(
283                        "Programming error: RepartitionExecState must be in the InputStreamsInitialized state after calling RepartitionExecState::ensure_input_streams_initialized"
284                    );
285                };
286                value
287            }
288            RepartitionExecState::ConsumingInputStreams(value) => return Ok(value),
289            RepartitionExecState::InputStreamsInitialized(value) => value,
290        };
291
292        let num_input_partitions = streams_and_metrics.len();
293        let num_output_partitions = partitioning.partition_count();
294
295        let spill_manager = Arc::new(spill_manager);
296
297        let (txs, rxs) = if preserve_order {
298            // Create partition-aware channels with one channel per (input, output) pair
299            // This provides backpressure while maintaining proper ordering
300            let (txs_all, rxs_all) =
301                partition_aware_channels(num_input_partitions, num_output_partitions);
302            // Take transpose of senders and receivers. `state.channels` keeps track of entries per output partition
303            let txs = transpose(txs_all);
304            let rxs = transpose(rxs_all);
305            (txs, rxs)
306        } else {
307            // Create one channel per *output* partition with backpressure
308            let (txs, rxs) = channels(num_output_partitions);
309            // Clone sender for each input partitions
310            let txs = txs
311                .into_iter()
312                .map(|item| vec![item; num_input_partitions])
313                .collect::<Vec<_>>();
314            let rxs = rxs.into_iter().map(|item| vec![item]).collect::<Vec<_>>();
315            (txs, rxs)
316        };
317
318        let mut channels = HashMap::with_capacity(txs.len());
319        for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() {
320            let reservation = Arc::new(Mutex::new(
321                MemoryConsumer::new(format!("{name}[{partition}]"))
322                    .with_can_spill(true)
323                    .register(context.memory_pool()),
324            ));
325
326            // Create spill channels based on mode:
327            // - preserve_order: one spill channel per (input, output) pair for proper FIFO ordering
328            // - non-preserve-order: one shared spill channel per output partition since all inputs
329            //   share the same receiver
330            let max_file_size = context
331                .session_config()
332                .options()
333                .execution
334                .max_spill_file_size_bytes;
335            let num_spill_channels = if preserve_order {
336                num_input_partitions
337            } else {
338                1
339            };
340            let (spill_writers, spill_readers): (Vec<_>, Vec<_>) = (0
341                ..num_spill_channels)
342                .map(|_| spill_pool::channel(max_file_size, Arc::clone(&spill_manager)))
343                .unzip();
344
345            channels.insert(
346                partition,
347                PartitionChannels {
348                    tx,
349                    rx,
350                    reservation,
351                    spill_readers,
352                    spill_writers,
353                },
354            );
355        }
356
357        // launch one async task per *input* partition
358        let mut spawned_tasks = Vec::with_capacity(num_input_partitions);
359        for (i, (stream, metrics)) in
360            std::mem::take(streams_and_metrics).into_iter().enumerate()
361        {
362            let txs: HashMap<_, _> = channels
363                .iter()
364                .map(|(partition, channels)| {
365                    // In preserve_order mode: each input gets its own spill writer (index i)
366                    // In non-preserve-order mode: all inputs share spill writer 0 via clone
367                    let spill_writer_idx = if preserve_order { i } else { 0 };
368                    (
369                        *partition,
370                        OutputChannel {
371                            sender: channels.tx[i].clone(),
372                            reservation: Arc::clone(&channels.reservation),
373                            spill_writer: channels.spill_writers[spill_writer_idx]
374                                .clone(),
375                        },
376                    )
377                })
378                .collect();
379
380            // Extract senders for wait_for_task before moving txs
381            let senders: HashMap<_, _> = txs
382                .iter()
383                .map(|(partition, channel)| (*partition, channel.sender.clone()))
384                .collect();
385
386            let input_task = SpawnedTask::spawn(RepartitionExec::pull_from_input(
387                stream,
388                txs,
389                partitioning.clone(),
390                metrics,
391                // preserve_order depends on partition index to start from 0
392                if preserve_order { 0 } else { i },
393                num_input_partitions,
394            ));
395
396            // In a separate task, wait for each input to be done
397            // (and pass along any errors, including panic!s)
398            let wait_for_task =
399                SpawnedTask::spawn(RepartitionExec::wait_for_task(input_task, senders));
400            spawned_tasks.push(wait_for_task);
401        }
402        *self = Self::ConsumingInputStreams(ConsumingInputStreamsState {
403            channels,
404            abort_helper: Arc::new(spawned_tasks),
405        });
406        match self {
407            RepartitionExecState::ConsumingInputStreams(value) => Ok(value),
408            _ => unreachable!(),
409        }
410    }
411}
412
413/// A utility that can be used to partition batches based on [`Partitioning`]
414pub struct BatchPartitioner {
415    state: BatchPartitionerState,
416    timer: metrics::Time,
417}
418
419enum BatchPartitionerState {
420    Hash {
421        exprs: Vec<Arc<dyn PhysicalExpr>>,
422        num_partitions: usize,
423        hash_buffer: Vec<u64>,
424    },
425    RoundRobin {
426        num_partitions: usize,
427        next_idx: usize,
428    },
429}
430
431/// Fixed RandomState used for hash repartitioning to ensure consistent behavior across
432/// executions and runs.
433pub const REPARTITION_RANDOM_STATE: SeededRandomState =
434    SeededRandomState::with_seeds(0, 0, 0, 0);
435
436impl BatchPartitioner {
437    /// Create a new [`BatchPartitioner`] for hash-based repartitioning.
438    ///
439    /// # Parameters
440    /// - `exprs`: Expressions used to compute the hash for each input row.
441    /// - `num_partitions`: Total number of output partitions.
442    /// - `timer`: Metric used to record time spent during repartitioning.
443    ///
444    /// # Notes
445    /// This constructor cannot fail and performs no validation.
446    pub fn new_hash_partitioner(
447        exprs: Vec<Arc<dyn PhysicalExpr>>,
448        num_partitions: usize,
449        timer: metrics::Time,
450    ) -> Self {
451        Self {
452            state: BatchPartitionerState::Hash {
453                exprs,
454                num_partitions,
455                hash_buffer: vec![],
456            },
457            timer,
458        }
459    }
460
461    /// Create a new [`BatchPartitioner`] for round-robin repartitioning.
462    ///
463    /// # Parameters
464    /// - `num_partitions`: Total number of output partitions.
465    /// - `timer`: Metric used to record time spent during repartitioning.
466    /// - `input_partition`: Index of the current input partition.
467    /// - `num_input_partitions`: Total number of input partitions.
468    ///
469    /// # Notes
470    /// The starting output partition is derived from the input partition
471    /// to avoid skew when multiple input partitions are used.
472    pub fn new_round_robin_partitioner(
473        num_partitions: usize,
474        timer: metrics::Time,
475        input_partition: usize,
476        num_input_partitions: usize,
477    ) -> Self {
478        Self {
479            state: BatchPartitionerState::RoundRobin {
480                num_partitions,
481                next_idx: (input_partition * num_partitions) / num_input_partitions,
482            },
483            timer,
484        }
485    }
486    /// Create a new [`BatchPartitioner`] based on the provided [`Partitioning`] scheme.
487    ///
488    /// This is a convenience constructor that delegates to the specialized
489    /// hash or round-robin constructors depending on the partitioning variant.
490    ///
491    /// # Parameters
492    /// - `partitioning`: Partitioning scheme to apply (hash or round-robin).
493    /// - `timer`: Metric used to record time spent during repartitioning.
494    /// - `input_partition`: Index of the current input partition.
495    /// - `num_input_partitions`: Total number of input partitions.
496    ///
497    /// # Errors
498    /// Returns an error if the provided partitioning scheme is not supported.
499    pub fn try_new(
500        partitioning: Partitioning,
501        timer: metrics::Time,
502        input_partition: usize,
503        num_input_partitions: usize,
504    ) -> Result<Self> {
505        match partitioning {
506            Partitioning::Hash(exprs, num_partitions) => {
507                Ok(Self::new_hash_partitioner(exprs, num_partitions, timer))
508            }
509            Partitioning::RoundRobinBatch(num_partitions) => {
510                Ok(Self::new_round_robin_partitioner(
511                    num_partitions,
512                    timer,
513                    input_partition,
514                    num_input_partitions,
515                ))
516            }
517            other => {
518                not_impl_err!("Unsupported repartitioning scheme {other:?}")
519            }
520        }
521    }
522
523    /// Partition the provided [`RecordBatch`] into one or more partitioned [`RecordBatch`]
524    /// based on the [`Partitioning`] specified on construction
525    ///
526    /// `f` will be called for each partitioned [`RecordBatch`] with the corresponding
527    /// partition index. Any error returned by `f` will be immediately returned by this
528    /// function without attempting to publish further [`RecordBatch`]
529    ///
530    /// The time spent repartitioning, not including time spent in `f` will be recorded
531    /// to the [`metrics::Time`] provided on construction
532    pub fn partition<F>(&mut self, batch: RecordBatch, mut f: F) -> Result<()>
533    where
534        F: FnMut(usize, RecordBatch) -> Result<()>,
535    {
536        self.partition_iter(batch)?.try_for_each(|res| match res {
537            Ok((partition, batch)) => f(partition, batch),
538            Err(e) => Err(e),
539        })
540    }
541
542    /// Actual implementation of [`partition`](Self::partition).
543    ///
544    /// The reason this was pulled out is that we need to have a variant of `partition` that works w/ sync functions,
545    /// and one that works w/ async. Using an iterator as an intermediate representation was the best way to achieve
546    /// this (so we don't need to clone the entire implementation).
547    fn partition_iter(
548        &mut self,
549        batch: RecordBatch,
550    ) -> Result<impl Iterator<Item = Result<(usize, RecordBatch)>> + Send + '_> {
551        let it: Box<dyn Iterator<Item = Result<(usize, RecordBatch)>> + Send> =
552            match &mut self.state {
553                BatchPartitionerState::RoundRobin {
554                    num_partitions,
555                    next_idx,
556                } => {
557                    let idx = *next_idx;
558                    *next_idx = (*next_idx + 1) % *num_partitions;
559                    Box::new(std::iter::once(Ok((idx, batch))))
560                }
561                BatchPartitionerState::Hash {
562                    exprs,
563                    num_partitions: partitions,
564                    hash_buffer,
565                } => {
566                    // Tracking time required for distributing indexes across output partitions
567                    let timer = self.timer.timer();
568
569                    let arrays =
570                        evaluate_expressions_to_arrays(exprs.as_slice(), &batch)?;
571
572                    hash_buffer.clear();
573                    hash_buffer.resize(batch.num_rows(), 0);
574
575                    create_hashes(
576                        &arrays,
577                        REPARTITION_RANDOM_STATE.random_state(),
578                        hash_buffer,
579                    )?;
580
581                    let mut indices: Vec<_> = (0..*partitions)
582                        .map(|_| Vec::with_capacity(batch.num_rows()))
583                        .collect();
584
585                    for (index, hash) in hash_buffer.iter().enumerate() {
586                        indices[(*hash % *partitions as u64) as usize].push(index as u32);
587                    }
588
589                    // Finished building index-arrays for output partitions
590                    timer.done();
591
592                    // Borrowing partitioner timer to prevent moving `self` to closure
593                    let partitioner_timer = &self.timer;
594                    let it = indices
595                        .into_iter()
596                        .enumerate()
597                        .filter_map(|(partition, indices)| {
598                            let indices: PrimitiveArray<UInt32Type> = indices.into();
599                            (!indices.is_empty()).then_some((partition, indices))
600                        })
601                        .map(move |(partition, indices)| {
602                            // Tracking time required for repartitioned batches construction
603                            let _timer = partitioner_timer.timer();
604
605                            // Produce batches based on indices
606                            let columns = take_arrays(batch.columns(), &indices, None)?;
607
608                            let mut options = RecordBatchOptions::new();
609                            options = options.with_row_count(Some(indices.len()));
610                            let batch = RecordBatch::try_new_with_options(
611                                batch.schema(),
612                                columns,
613                                &options,
614                            )
615                            .unwrap();
616
617                            Ok((partition, batch))
618                        });
619
620                    Box::new(it)
621                }
622            };
623
624        Ok(it)
625    }
626
627    // return the number of output partitions
628    fn num_partitions(&self) -> usize {
629        match self.state {
630            BatchPartitionerState::RoundRobin { num_partitions, .. } => num_partitions,
631            BatchPartitionerState::Hash { num_partitions, .. } => num_partitions,
632        }
633    }
634}
635
636/// Maps `N` input partitions to `M` output partitions based on a
637/// [`Partitioning`] scheme.
638///
639/// # Background
640///
641/// DataFusion, like most other commercial systems, with the
642/// notable exception of DuckDB, uses the "Exchange Operator" based
643/// approach to parallelism which works well in practice given
644/// sufficient care in implementation.
645///
646/// DataFusion's planner picks the target number of partitions and
647/// then [`RepartitionExec`] redistributes [`RecordBatch`]es to that number
648/// of output partitions.
649///
650/// For example, given `target_partitions=3` (trying to use 3 cores)
651/// but scanning an input with 2 partitions, `RepartitionExec` can be
652/// used to get 3 even streams of `RecordBatch`es
653///
654///
655///```text
656///        ▲                  ▲                  ▲
657///        │                  │                  │
658///        │                  │                  │
659///        │                  │                  │
660/// ┌───────────────┐  ┌───────────────┐  ┌───────────────┐
661/// │    GroupBy    │  │    GroupBy    │  │    GroupBy    │
662/// │   (Partial)   │  │   (Partial)   │  │   (Partial)   │
663/// └───────────────┘  └───────────────┘  └───────────────┘
664///        ▲                  ▲                  ▲
665///        └──────────────────┼──────────────────┘
666///                           │
667///              ┌─────────────────────────┐
668///              │     RepartitionExec     │
669///              │   (hash/round robin)    │
670///              └─────────────────────────┘
671///                         ▲   ▲
672///             ┌───────────┘   └───────────┐
673///             │                           │
674///             │                           │
675///        .─────────.                 .─────────.
676///     ,─'           '─.           ,─'           '─.
677///    ;      Input      :         ;      Input      :
678///    :   Partition 0   ;         :   Partition 1   ;
679///     ╲               ╱           ╲               ╱
680///      '─.         ,─'             '─.         ,─'
681///         `───────'                   `───────'
682/// ```
683///
684/// # Error Handling
685///
686/// If any of the input partitions return an error, the error is propagated to
687/// all output partitions and inputs are not polled again.
688///
689/// # Output Ordering
690///
691/// If more than one stream is being repartitioned, the output will be some
692/// arbitrary interleaving (and thus unordered) unless
693/// [`Self::with_preserve_order`] specifies otherwise.
694///
695/// # Spilling Architecture
696///
697/// RepartitionExec uses [`SpillPool`](crate::spill::spill_pool) channels to handle
698/// memory pressure during repartitioning. Each (input partition, output partition)
699/// pair gets its own SpillPool channel for FIFO ordering.
700///
701/// ```text
702/// Input Partitions (N)          Output Partitions (M)
703/// ────────────────────          ─────────────────────
704///
705///    Input 0 ──┐                      ┌──▶ Output 0
706///              │  ┌──────────────┐    │
707///              ├─▶│ SpillPool    │────┤
708///              │  │ [In0→Out0]   │    │
709///    Input 1 ──┤  └──────────────┘    ├──▶ Output 1
710///              │                       │
711///              │  ┌──────────────┐    │
712///              ├─▶│ SpillPool    │────┤
713///              │  │ [In1→Out0]   │    │
714///    Input 2 ──┤  └──────────────┘    ├──▶ Output 2
715///              │                      │
716///              │       ... (N×M SpillPools total)
717///              │                      │
718///              │  ┌──────────────┐    │
719///              └─▶│ SpillPool    │────┘
720///                 │ [InN→OutM]   │
721///                 └──────────────┘
/// ```
///
/// Each SpillPool maintains FIFO order for its (input, output) pair.
/// See `RepartitionBatch` for details on the memory/spill decision logic.
726///
727/// # Footnote
728///
/// The "Exchange Operator" was first described in the 1989 paper
/// [Encapsulation of parallelism in the Volcano query processing
/// system](https://dl.acm.org/doi/pdf/10.1145/93605.98720),
/// which uses the term "Exchange" for the concept of repartitioning
/// data across threads.
#[derive(Debug, Clone)]
pub struct RepartitionExec {
    /// Input execution plan whose partitions are redistributed.
    input: Arc<dyn ExecutionPlan>,
    /// Inner state that is initialized when the parent calls `.execute()` on this node
    /// and consumed as soon as the parent starts consuming this node.
    ///
    /// It is shared (`Arc<Mutex<..>>`) because every output partition's `execute`
    /// call works on the same state and takes its own channels out of it. Note that
    /// the derived `Clone` shares this state with the clone rather than copying it.
    state: Arc<Mutex<RepartitionExecState>>,
    /// Execution metrics (spill metrics and the per-input fetch / repartition /
    /// send timings are registered here).
    metrics: ExecutionPlanMetricsSet,
    /// Boolean flag to decide whether to preserve ordering. If true means
    /// `SortPreservingRepartitionExec`, false means `RepartitionExec`.
    ///
    /// `try_new` always sets it to `false`; `with_preserve_order` only sets it to
    /// `true` when the input is ordered and has more than one partition.
    preserve_order: bool,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
}
749
/// Timing metrics for the task that pulls one input partition and routes its
/// batches to the output partitions (see `RepartitionExec::pull_from_input`).
#[derive(Debug, Clone)]
struct RepartitionMetrics {
    /// Time in nanos spent polling the input stream for the next batch
    fetch_time: metrics::Time,
    /// Repartitioning elapsed time in nanos (handed to the `BatchPartitioner`,
    /// which records the time spent splitting batches)
    repartition_time: metrics::Time,
    /// Time in nanos for sending resulting batches to channels. This also covers
    /// writing a batch to the spill pool when the memory reservation cannot grow.
    ///
    /// One metric per output partition, indexed by output partition number.
    send_time: Vec<metrics::Time>,
}
761
762impl RepartitionMetrics {
763    pub fn new(
764        input_partition: usize,
765        num_output_partitions: usize,
766        metrics: &ExecutionPlanMetricsSet,
767    ) -> Self {
768        // Time in nanos to execute child operator and fetch batches
769        let fetch_time =
770            MetricBuilder::new(metrics).subset_time("fetch_time", input_partition);
771
772        // Time in nanos to perform repartitioning
773        let repartition_time =
774            MetricBuilder::new(metrics).subset_time("repartition_time", input_partition);
775
776        // Time in nanos for sending resulting batches to channels
777        let send_time = (0..num_output_partitions)
778            .map(|output_partition| {
779                let label =
780                    metrics::Label::new("outputPartition", output_partition.to_string());
781                MetricBuilder::new(metrics)
782                    .with_label(label)
783                    .subset_time("send_time", input_partition)
784            })
785            .collect();
786
787        Self {
788            fetch_time,
789            repartition_time,
790            send_time,
791        }
792    }
793}
794
795impl RepartitionExec {
796    /// Input execution plan
797    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
798        &self.input
799    }
800
801    /// Partitioning scheme to use
802    pub fn partitioning(&self) -> &Partitioning {
803        &self.cache.partitioning
804    }
805
806    /// Get preserve_order flag of the RepartitionExec
807    /// `true` means `SortPreservingRepartitionExec`, `false` means `RepartitionExec`
808    pub fn preserve_order(&self) -> bool {
809        self.preserve_order
810    }
811
812    /// Get name used to display this Exec
813    pub fn name(&self) -> &str {
814        "RepartitionExec"
815    }
816}
817
818impl DisplayAs for RepartitionExec {
819    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
820        let input_partition_count = self.input.output_partitioning().partition_count();
821        match t {
822            DisplayFormatType::Default | DisplayFormatType::Verbose => {
823                write!(
824                    f,
825                    "{}: partitioning={}, input_partitions={}",
826                    self.name(),
827                    self.partitioning(),
828                    input_partition_count,
829                )?;
830
831                if self.preserve_order {
832                    write!(f, ", preserve_order=true")?;
833                } else if input_partition_count <= 1
834                    && self.input.output_ordering().is_some()
835                {
836                    // Make it explicit that repartition maintains sortedness for a single input partition even
837                    // when `preserve_sort order` is false
838                    write!(f, ", maintains_sort_order=true")?;
839                }
840
841                if let Some(sort_exprs) = self.sort_exprs() {
842                    write!(f, ", sort_exprs={}", sort_exprs.clone())?;
843                }
844                Ok(())
845            }
846            DisplayFormatType::TreeRender => {
847                writeln!(f, "partitioning_scheme={}", self.partitioning(),)?;
848                let output_partition_count = self.partitioning().partition_count();
849                let input_to_output_partition_str =
850                    format!("{input_partition_count} -> {output_partition_count}");
851                writeln!(
852                    f,
853                    "partition_count(in->out)={input_to_output_partition_str}"
854                )?;
855
856                if self.preserve_order {
857                    writeln!(f, "preserve_order={}", self.preserve_order)?;
858                }
859                Ok(())
860            }
861        }
862    }
863}
864
865impl ExecutionPlan for RepartitionExec {
866    fn name(&self) -> &'static str {
867        "RepartitionExec"
868    }
869
870    /// Return a reference to Any that can be used for downcasting
871    fn as_any(&self) -> &dyn Any {
872        self
873    }
874
875    fn properties(&self) -> &PlanProperties {
876        &self.cache
877    }
878
879    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
880        vec![&self.input]
881    }
882
883    fn with_new_children(
884        self: Arc<Self>,
885        mut children: Vec<Arc<dyn ExecutionPlan>>,
886    ) -> Result<Arc<dyn ExecutionPlan>> {
887        let mut repartition = RepartitionExec::try_new(
888            children.swap_remove(0),
889            self.partitioning().clone(),
890        )?;
891        if self.preserve_order {
892            repartition = repartition.with_preserve_order();
893        }
894        Ok(Arc::new(repartition))
895    }
896
897    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
898        vec![matches!(self.partitioning(), Partitioning::Hash(_, _))]
899    }
900
901    fn maintains_input_order(&self) -> Vec<bool> {
902        Self::maintains_input_order_helper(self.input(), self.preserve_order)
903    }
904
905    fn execute(
906        &self,
907        partition: usize,
908        context: Arc<TaskContext>,
909    ) -> Result<SendableRecordBatchStream> {
910        trace!(
911            "Start {}::execute for partition: {}",
912            self.name(),
913            partition
914        );
915
916        let spill_metrics = SpillMetrics::new(&self.metrics, partition);
917
918        let input = Arc::clone(&self.input);
919        let partitioning = self.partitioning().clone();
920        let metrics = self.metrics.clone();
921        let preserve_order = self.sort_exprs().is_some();
922        let name = self.name().to_owned();
923        let schema = self.schema();
924        let schema_captured = Arc::clone(&schema);
925
926        let spill_manager = SpillManager::new(
927            Arc::clone(&context.runtime_env()),
928            spill_metrics,
929            input.schema(),
930        );
931
932        // Get existing ordering to use for merging
933        let sort_exprs = self.sort_exprs().cloned();
934
935        let state = Arc::clone(&self.state);
936        if let Some(mut state) = state.try_lock() {
937            state.ensure_input_streams_initialized(
938                &input,
939                &metrics,
940                partitioning.partition_count(),
941                &context,
942            )?;
943        }
944
945        let num_input_partitions = input.output_partitioning().partition_count();
946
947        let stream = futures::stream::once(async move {
948            // lock scope
949            let (rx, reservation, spill_readers, abort_helper) = {
950                // lock mutexes
951                let mut state = state.lock();
952                let state = state.consume_input_streams(
953                    &input,
954                    &metrics,
955                    &partitioning,
956                    preserve_order,
957                    &name,
958                    &context,
959                    spill_manager.clone(),
960                )?;
961
962                // now return stream for the specified *output* partition which will
963                // read from the channel
964                let PartitionChannels {
965                    rx,
966                    reservation,
967                    spill_readers,
968                    ..
969                } = state
970                    .channels
971                    .remove(&partition)
972                    .expect("partition not used yet");
973
974                (
975                    rx,
976                    reservation,
977                    spill_readers,
978                    Arc::clone(&state.abort_helper),
979                )
980            };
981
982            trace!(
983                "Before returning stream in {name}::execute for partition: {partition}"
984            );
985
986            if preserve_order {
987                // Store streams from all the input partitions:
988                // Each input partition gets its own spill reader to maintain proper FIFO ordering
989                let input_streams = rx
990                    .into_iter()
991                    .zip(spill_readers)
992                    .map(|(receiver, spill_stream)| {
993                        // In preserve_order mode, each receiver corresponds to exactly one input partition
994                        Box::pin(PerPartitionStream::new(
995                            Arc::clone(&schema_captured),
996                            receiver,
997                            Arc::clone(&abort_helper),
998                            Arc::clone(&reservation),
999                            spill_stream,
1000                            1, // Each receiver handles one input partition
1001                            BaselineMetrics::new(&metrics, partition),
1002                            None, // subsequent merge sort already does batching https://github.com/apache/datafusion/blob/e4dcf0c85611ad0bd291f03a8e03fe56d773eb16/datafusion/physical-plan/src/sorts/merge.rs#L286
1003                        )) as SendableRecordBatchStream
1004                    })
1005                    .collect::<Vec<_>>();
1006                // Note that receiver size (`rx.len()`) and `num_input_partitions` are same.
1007
1008                // Merge streams (while preserving ordering) coming from
1009                // input partitions to this partition:
1010                let fetch = None;
1011                let merge_reservation =
1012                    MemoryConsumer::new(format!("{name}[Merge {partition}]"))
1013                        .register(context.memory_pool());
1014                StreamingMergeBuilder::new()
1015                    .with_streams(input_streams)
1016                    .with_schema(schema_captured)
1017                    .with_expressions(&sort_exprs.unwrap())
1018                    .with_metrics(BaselineMetrics::new(&metrics, partition))
1019                    .with_batch_size(context.session_config().batch_size())
1020                    .with_fetch(fetch)
1021                    .with_reservation(merge_reservation)
1022                    .with_spill_manager(spill_manager)
1023                    .build()
1024            } else {
1025                // Non-preserve-order case: single input stream, so use the first spill reader
1026                let spill_stream = spill_readers
1027                    .into_iter()
1028                    .next()
1029                    .expect("at least one spill reader should exist");
1030
1031                Ok(Box::pin(PerPartitionStream::new(
1032                    schema_captured,
1033                    rx.into_iter()
1034                        .next()
1035                        .expect("at least one receiver should exist"),
1036                    abort_helper,
1037                    reservation,
1038                    spill_stream,
1039                    num_input_partitions,
1040                    BaselineMetrics::new(&metrics, partition),
1041                    Some(context.session_config().batch_size()),
1042                )) as SendableRecordBatchStream)
1043            }
1044        })
1045        .try_flatten();
1046        let stream = RecordBatchStreamAdapter::new(schema, stream);
1047        Ok(Box::pin(stream))
1048    }
1049
1050    fn metrics(&self) -> Option<MetricsSet> {
1051        Some(self.metrics.clone_inner())
1052    }
1053
1054    fn statistics(&self) -> Result<Statistics> {
1055        self.input.partition_statistics(None)
1056    }
1057
1058    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
1059        if let Some(partition) = partition {
1060            let partition_count = self.partitioning().partition_count();
1061            if partition_count == 0 {
1062                return Ok(Statistics::new_unknown(&self.schema()));
1063            }
1064
1065            assert_or_internal_err!(
1066                partition < partition_count,
1067                "RepartitionExec invalid partition {} (expected less than {})",
1068                partition,
1069                partition_count
1070            );
1071
1072            let mut stats = self.input.partition_statistics(None)?;
1073
1074            // Distribute statistics across partitions
1075            stats.num_rows = stats
1076                .num_rows
1077                .get_value()
1078                .map(|rows| Precision::Inexact(rows / partition_count))
1079                .unwrap_or(Precision::Absent);
1080            stats.total_byte_size = stats
1081                .total_byte_size
1082                .get_value()
1083                .map(|bytes| Precision::Inexact(bytes / partition_count))
1084                .unwrap_or(Precision::Absent);
1085
1086            // Make all column stats unknown
1087            stats.column_statistics = stats
1088                .column_statistics
1089                .iter()
1090                .map(|_| ColumnStatistics::new_unknown())
1091                .collect();
1092
1093            Ok(stats)
1094        } else {
1095            self.input.partition_statistics(None)
1096        }
1097    }
1098
1099    fn cardinality_effect(&self) -> CardinalityEffect {
1100        CardinalityEffect::Equal
1101    }
1102
1103    fn try_swapping_with_projection(
1104        &self,
1105        projection: &ProjectionExec,
1106    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
1107        // If the projection does not narrow the schema, we should not try to push it down.
1108        if projection.expr().len() >= projection.input().schema().fields().len() {
1109            return Ok(None);
1110        }
1111
1112        // If pushdown is not beneficial or applicable, break it.
1113        if projection.benefits_from_input_partitioning()[0]
1114            || !all_columns(projection.expr())
1115        {
1116            return Ok(None);
1117        }
1118
1119        let new_projection = make_with_child(projection, self.input())?;
1120
1121        let new_partitioning = match self.partitioning() {
1122            Partitioning::Hash(partitions, size) => {
1123                let mut new_partitions = vec![];
1124                for partition in partitions {
1125                    let Some(new_partition) =
1126                        update_expr(partition, projection.expr(), false)?
1127                    else {
1128                        return Ok(None);
1129                    };
1130                    new_partitions.push(new_partition);
1131                }
1132                Partitioning::Hash(new_partitions, *size)
1133            }
1134            others => others.clone(),
1135        };
1136
1137        Ok(Some(Arc::new(RepartitionExec::try_new(
1138            new_projection,
1139            new_partitioning,
1140        )?)))
1141    }
1142
1143    fn gather_filters_for_pushdown(
1144        &self,
1145        _phase: FilterPushdownPhase,
1146        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
1147        _config: &ConfigOptions,
1148    ) -> Result<FilterDescription> {
1149        FilterDescription::from_children(parent_filters, &self.children())
1150    }
1151
1152    fn handle_child_pushdown_result(
1153        &self,
1154        _phase: FilterPushdownPhase,
1155        child_pushdown_result: ChildPushdownResult,
1156        _config: &ConfigOptions,
1157    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
1158        Ok(FilterPushdownPropagation::if_all(child_pushdown_result))
1159    }
1160
1161    fn try_pushdown_sort(
1162        &self,
1163        order: &[PhysicalSortExpr],
1164    ) -> Result<SortOrderPushdownResult<Arc<dyn ExecutionPlan>>> {
1165        // RepartitionExec only maintains input order if preserve_order is set
1166        // or if there's only one partition
1167        if !self.maintains_input_order()[0] {
1168            return Ok(SortOrderPushdownResult::Unsupported);
1169        }
1170
1171        // Delegate to the child and wrap with a new RepartitionExec
1172        self.input.try_pushdown_sort(order)?.try_map(|new_input| {
1173            let mut new_repartition =
1174                RepartitionExec::try_new(new_input, self.partitioning().clone())?;
1175            if self.preserve_order {
1176                new_repartition = new_repartition.with_preserve_order();
1177            }
1178            Ok(Arc::new(new_repartition) as Arc<dyn ExecutionPlan>)
1179        })
1180    }
1181
1182    fn repartitioned(
1183        &self,
1184        target_partitions: usize,
1185        _config: &ConfigOptions,
1186    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
1187        use Partitioning::*;
1188        let mut new_properties = self.cache.clone();
1189        new_properties.partitioning = match new_properties.partitioning {
1190            RoundRobinBatch(_) => RoundRobinBatch(target_partitions),
1191            Hash(hash, _) => Hash(hash, target_partitions),
1192            UnknownPartitioning(_) => UnknownPartitioning(target_partitions),
1193        };
1194        Ok(Some(Arc::new(Self {
1195            input: Arc::clone(&self.input),
1196            state: Arc::clone(&self.state),
1197            metrics: self.metrics.clone(),
1198            preserve_order: self.preserve_order,
1199            cache: new_properties,
1200        })))
1201    }
1202}
1203
impl RepartitionExec {
    /// Create a new RepartitionExec, that produces output `partitioning`, and
    /// does not preserve the order of the input (see [`Self::with_preserve_order`]
    /// for more details)
    ///
    /// The node starts with empty execution state and a fresh metrics set.
    /// This currently always returns `Ok`.
    pub fn try_new(
        input: Arc<dyn ExecutionPlan>,
        partitioning: Partitioning,
    ) -> Result<Self> {
        let preserve_order = false;
        let cache = Self::compute_properties(&input, partitioning, preserve_order);
        Ok(RepartitionExec {
            input,
            state: Default::default(),
            metrics: ExecutionPlanMetricsSet::new(),
            preserve_order,
            cache,
        })
    }

    /// Returns a one-element vector telling whether the output keeps the
    /// input's ordering.
    fn maintains_input_order_helper(
        input: &Arc<dyn ExecutionPlan>,
        preserve_order: bool,
    ) -> Vec<bool> {
        // Order is kept when this is the order-preserving (merging) variant, or when
        // there is at most one input partition, so nothing gets interleaved.
        vec![preserve_order || input.output_partitioning().partition_count() <= 1]
    }

    /// Derives the output equivalence properties from the input's.
    ///
    /// Orderings are dropped unless the input order is maintained. Per-partition
    /// constants are dropped when several input partitions are combined.
    fn eq_properties_helper(
        input: &Arc<dyn ExecutionPlan>,
        preserve_order: bool,
    ) -> EquivalenceProperties {
        // Equivalence Properties
        let mut eq_properties = input.equivalence_properties().clone();
        // If the ordering is lost, reset the ordering equivalence class:
        if !Self::maintains_input_order_helper(input, preserve_order)[0] {
            eq_properties.clear_orderings();
        }
        // When there are more than one input partitions, they will be fused at the output.
        // Therefore, remove per partition constants.
        if input.output_partitioning().partition_count() > 1 {
            eq_properties.clear_per_partition_constants();
        }
        eq_properties
    }

    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
    ///
    /// Pipeline behavior and boundedness are taken from the input unchanged.
    fn compute_properties(
        input: &Arc<dyn ExecutionPlan>,
        partitioning: Partitioning,
        preserve_order: bool,
    ) -> PlanProperties {
        PlanProperties::new(
            Self::eq_properties_helper(input, preserve_order),
            partitioning,
            input.pipeline_behavior(),
            input.boundedness(),
        )
        .with_scheduling_type(SchedulingType::Cooperative)
        .with_evaluation_type(EvaluationType::Eager)
    }

    /// Specify if this repartitioning operation should preserve the order of
    /// rows from its input when producing output. Preserving order is more
    /// expensive at runtime, so should only be set if the output of this
    /// operator can take advantage of it.
    ///
    /// If the input is not ordered, or has only one partition, this is a no op,
    /// and the node remains a `RepartitionExec`.
    ///
    /// The cached equivalence properties are recomputed for the resulting flag.
    pub fn with_preserve_order(mut self) -> Self {
        self.preserve_order =
                // If the input isn't ordered, there is no ordering to preserve
                self.input.output_ordering().is_some() &&
                // if there is only one input partition, merging is not required
                // to maintain order
                self.input.output_partitioning().partition_count() > 1;
        let eq_properties = Self::eq_properties_helper(&self.input, self.preserve_order);
        self.cache = self.cache.with_eq_properties(eq_properties);
        self
    }

    /// Return the sort expressions that are used to merge
    ///
    /// This is `Some` only when `preserve_order` is set and the input reports an
    /// output ordering. `None` means no order-preserving merge is performed.
    fn sort_exprs(&self) -> Option<&LexOrdering> {
        if self.preserve_order {
            self.input.output_ordering()
        } else {
            None
        }
    }

    /// Pulls data from the specified input plan, feeding it to the
    /// output partitions based on the desired partitioning
    ///
    /// `output_channels` holds the output sending channels for each output partition
    ///
    /// For each partitioned batch, memory is first reserved on the output's
    /// shared reservation. If that fails, the batch is written to the spill pool
    /// and only a `RepartitionBatch::Spilled` marker is sent. The loop ends when
    /// the input is exhausted or every output has hung up.
    ///
    /// Returns an error for partitioning schemes other than `Hash` and
    /// `RoundRobinBatch`, or if the input, partitioning, or spilling fails.
    async fn pull_from_input(
        mut stream: SendableRecordBatchStream,
        mut output_channels: HashMap<usize, OutputChannel>,
        partitioning: Partitioning,
        metrics: RepartitionMetrics,
        input_partition: usize,
        num_input_partitions: usize,
    ) -> Result<()> {
        let mut partitioner = match &partitioning {
            Partitioning::Hash(exprs, num_partitions) => {
                BatchPartitioner::new_hash_partitioner(
                    exprs.clone(),
                    *num_partitions,
                    metrics.repartition_time.clone(),
                )
            }
            Partitioning::RoundRobinBatch(num_partitions) => {
                BatchPartitioner::new_round_robin_partitioner(
                    *num_partitions,
                    metrics.repartition_time.clone(),
                    input_partition,
                    num_input_partitions,
                )
            }
            other => {
                return not_impl_err!("Unsupported repartitioning scheme {other:?}");
            }
        };

        // While there are still outputs to send to, keep pulling inputs
        let mut batches_until_yield = partitioner.num_partitions();
        while !output_channels.is_empty() {
            // fetch the next batch
            let timer = metrics.fetch_time.timer();
            let result = stream.next().await;
            timer.done();

            // Input is done
            let batch = match result {
                Some(result) => result?,
                None => break,
            };

            // Handle empty batch
            if batch.num_rows() == 0 {
                continue;
            }

            for res in partitioner.partition_iter(batch)? {
                let (partition, batch) = res?;
                // Memory size accounted against the output's reservation.
                let size = batch.get_array_memory_size();

                // send_time covers both the reservation/spill decision and the send.
                let timer = metrics.send_time[partition].timer();
                // if there is still a receiver, send to it
                if let Some(channel) = output_channels.get_mut(&partition) {
                    let (batch_to_send, is_memory_batch) =
                        match channel.reservation.lock().try_grow(size) {
                            Ok(_) => {
                                // Memory available - send in-memory batch
                                (RepartitionBatch::Memory(batch), true)
                            }
                            Err(_) => {
                                // We're memory limited - spill to SpillPool
                                // SpillPool handles file handle reuse and rotation
                                // The batch is written *before* the marker is sent, so the
                                // receiver always finds it once it sees the marker.
                                channel.spill_writer.push_batch(&batch)?;
                                // Send marker indicating batch was spilled
                                (RepartitionBatch::Spilled, false)
                            }
                        };

                    if channel.sender.send(Some(Ok(batch_to_send))).await.is_err() {
                        // If the other end has hung up, it was an early shutdown (e.g. LIMIT)
                        // Only shrink memory if it was a memory batch
                        if is_memory_batch {
                            channel.reservation.lock().shrink(size);
                        }
                        output_channels.remove(&partition);
                    }
                }
                timer.done();
            }

            // If the input stream is endless, we may spin forever and
            // never yield back to tokio.  See
            // https://github.com/apache/datafusion/issues/5278.
            //
            // However, yielding on every batch causes a bottleneck
            // when running with multiple cores. See
            // https://github.com/apache/datafusion/issues/6290
            //
            // Thus, heuristically yield after producing num_partition
            // batches
            //
            // In round robin this is ideal as each input will get a
            // new batch. In hash partitioning it may yield too often
            // on uneven distributions even if some partition can not
            // make progress, but parallelism is going to be limited
            // in that case anyways
            if batches_until_yield == 0 {
                tokio::task::yield_now().await;
                batches_until_yield = partitioner.num_partitions();
            } else {
                batches_until_yield -= 1;
            }
        }

        // Spill writers will auto-finalize when dropped
        // No need for explicit flush
        Ok(())
    }

    /// Waits for `input_task` which is consuming one of the inputs to
    /// complete. Upon each successful completion, sends a `None` to
    /// each of the output tx channels to signal one of the inputs is
    /// complete. Upon error, propagates the errors to all output tx
    /// channels.
    ///
    /// A join failure of the task is reported to every output as a
    /// `Context("Join Error", ..)` error. An error returned by the task is
    /// shared through one `Arc` and re-wrapped for each output.
    async fn wait_for_task(
        input_task: SpawnedTask<Result<()>>,
        txs: HashMap<usize, DistributionSender<MaybeBatch>>,
    ) {
        // wait for completion, and propagate error
        // note we ignore errors on send (.ok) as that means the receiver has already shutdown.

        match input_task.join().await {
            // Error in joining task
            Err(e) => {
                let e = Arc::new(e);

                for (_, tx) in txs {
                    let err = Err(DataFusionError::Context(
                        "Join Error".to_string(),
                        Box::new(DataFusionError::External(Box::new(Arc::clone(&e)))),
                    ));
                    tx.send(Some(err)).await.ok();
                }
            }
            // Error from running input task
            Ok(Err(e)) => {
                // send the same Arc'd error to all output partitions
                let e = Arc::new(e);

                for (_, tx) in txs {
                    // wrap it because need to send error to all output partitions
                    let err = Err(DataFusionError::from(&e));
                    tx.send(Some(err)).await.ok();
                }
            }
            // Input task completed successfully
            Ok(Ok(())) => {
                // notify each output partition that this input partition has no more data
                for (_partition, tx) in txs {
                    tx.send(None).await.ok();
                }
            }
        }
    }
}
1454
/// State for tracking whether we're reading from memory channel or spill stream.
///
/// This state machine ensures proper ordering when batches are mixed between memory
/// and spilled storage. When a [`RepartitionBatch::Spilled`] marker is received,
/// the stream must block on the spill stream until the corresponding batch arrives.
///
/// Transitions:
///
/// * `ReadingMemory` + in-memory batch from the channel: return the batch and
///   stay in `ReadingMemory`.
/// * `ReadingMemory` + `Spilled` marker from the channel: switch to
///   `ReadingSpilled`.
/// * `ReadingSpilled` + batch from the spill stream: return the batch and go
///   back to `ReadingMemory`.
///
/// # State Machine
///
/// ```text
///                        ┌─────────────────┐
///                   ┌───▶│  ReadingMemory  │◀───┐
///                   │    └────────┬────────┘    │
///                   │             │             │
///                   │     Poll channel          │
///                   │             │             │
///                   │  ┌──────────┼─────────────┐
///                   │  │          │             │
///                   │  ▼          ▼             │
///                   │ Memory   Spilled          │
///       Got batch   │ batch    marker           │
///       from spill  │  │          │             │
///                   │  │          ▼             │
///                   │  │  ┌──────────────────┐  │
///                   │  │  │ ReadingSpilled   │  │
///                   │  │  └────────┬─────────┘  │
///                   │  │           │            │
///                   │  │   Poll spill_stream    │
///                   │  │           │            │
///                   │  │           ▼            │
///                   │  │      Get batch         │
///                   │  │           │            │
///                   └──┴───────────┴────────────┘
///                                  │
///                                  ▼
///                           Return batch
///                     (Order preserved within
///                      (input, output) pair)
/// ```
///
/// The transition to `ReadingSpilled` blocks further channel polling to maintain
/// FIFO ordering - we cannot read the next item from the channel until the spill
/// stream provides the current batch. The sender writes a batch to the spill pool
/// before it sends the corresponding `Spilled` marker, so the batch is already
/// queued when the marker is observed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum StreamState {
    /// Reading from the memory channel (normal operation)
    ReadingMemory,
    /// Waiting for a spilled batch from the spill stream.
    /// Must not poll channel until spilled batch is received to preserve ordering.
    ReadingSpilled,
}
1505
/// Stream of [`RecordBatch`]es for a single output partition of a repartition.
///
/// It converts a [`DistributionReceiver`] into a stream. The channel is not
/// necessarily single-producer: in non-preserve-order mode several input
/// partitions send on the same channel, which is why `remaining_partitions`
/// tracks how many of them still have to finish.
struct PerPartitionStream {
    /// Schema of every batch produced by this stream.
    schema: SchemaRef,

    /// Channel carrying in-memory batches, [`RepartitionBatch::Spilled`] markers,
    /// errors, and one end-of-data (`None`) marker per finished input partition.
    receiver: DistributionReceiver<MaybeBatch>,

    /// Handle to ensure background tasks are killed when no longer needed.
    _drop_helper: Arc<Vec<SpawnedTask<()>>>,

    /// Memory reservation for in-memory batches. It is shrunk by each in-memory
    /// batch's size when that batch is returned to the consumer.
    reservation: SharedMemoryReservation,

    /// Infinite stream for reading from the spill pool. One batch is read from
    /// it for every spill marker received on `receiver`.
    spill_stream: SendableRecordBatchStream,

    /// Internal state indicating if we are reading from memory or spill stream
    state: StreamState,

    /// Number of input partitions that have not yet finished.
    /// In non-preserve-order mode, multiple input partitions send to the same channel,
    /// each sending None when complete. We must wait for all of them.
    remaining_partitions: usize,

    /// Execution metrics
    baseline_metrics: BaselineMetrics,

    /// Coalesces small output batches up to the configured batch size.
    /// None for sort preserving variant (merge sort already does coalescing)
    batch_coalescer: Option<LimitedBatchCoalescer>,
}
1538
1539impl PerPartitionStream {
1540    #[expect(clippy::too_many_arguments)]
1541    fn new(
1542        schema: SchemaRef,
1543        receiver: DistributionReceiver<MaybeBatch>,
1544        drop_helper: Arc<Vec<SpawnedTask<()>>>,
1545        reservation: SharedMemoryReservation,
1546        spill_stream: SendableRecordBatchStream,
1547        num_input_partitions: usize,
1548        baseline_metrics: BaselineMetrics,
1549        batch_size: Option<usize>,
1550    ) -> Self {
1551        let batch_coalescer =
1552            batch_size.map(|s| LimitedBatchCoalescer::new(Arc::clone(&schema), s, None));
1553        Self {
1554            schema,
1555            receiver,
1556            _drop_helper: drop_helper,
1557            reservation,
1558            spill_stream,
1559            state: StreamState::ReadingMemory,
1560            remaining_partitions: num_input_partitions,
1561            baseline_metrics,
1562            batch_coalescer,
1563        }
1564    }
1565
1566    fn poll_next_inner(
1567        self: &mut Pin<&mut Self>,
1568        cx: &mut Context<'_>,
1569    ) -> Poll<Option<Result<RecordBatch>>> {
1570        use futures::StreamExt;
1571        let cloned_time = self.baseline_metrics.elapsed_compute().clone();
1572        let _timer = cloned_time.timer();
1573
1574        loop {
1575            match self.state {
1576                StreamState::ReadingMemory => {
1577                    // Poll the memory channel for next message
1578                    let value = match self.receiver.recv().poll_unpin(cx) {
1579                        Poll::Ready(v) => v,
1580                        Poll::Pending => {
1581                            // Nothing from channel, wait
1582                            return Poll::Pending;
1583                        }
1584                    };
1585
1586                    match value {
1587                        Some(Some(v)) => match v {
1588                            Ok(RepartitionBatch::Memory(batch)) => {
1589                                // Release memory and return batch
1590                                self.reservation
1591                                    .lock()
1592                                    .shrink(batch.get_array_memory_size());
1593                                return Poll::Ready(Some(Ok(batch)));
1594                            }
1595                            Ok(RepartitionBatch::Spilled) => {
1596                                // Batch was spilled, transition to reading from spill stream
1597                                // We must block on spill stream until we get the batch
1598                                // to preserve ordering
1599                                self.state = StreamState::ReadingSpilled;
1600                                continue;
1601                            }
1602                            Err(e) => {
1603                                return Poll::Ready(Some(Err(e)));
1604                            }
1605                        },
1606                        Some(None) => {
1607                            // One input partition finished
1608                            self.remaining_partitions -= 1;
1609                            if self.remaining_partitions == 0 {
1610                                // All input partitions finished
1611                                return Poll::Ready(None);
1612                            }
1613                            // Continue to poll for more data from other partitions
1614                            continue;
1615                        }
1616                        None => {
1617                            // Channel closed unexpectedly
1618                            return Poll::Ready(None);
1619                        }
1620                    }
1621                }
1622                StreamState::ReadingSpilled => {
1623                    // Poll spill stream for the spilled batch
1624                    match self.spill_stream.poll_next_unpin(cx) {
1625                        Poll::Ready(Some(Ok(batch))) => {
1626                            self.state = StreamState::ReadingMemory;
1627                            return Poll::Ready(Some(Ok(batch)));
1628                        }
1629                        Poll::Ready(Some(Err(e))) => {
1630                            return Poll::Ready(Some(Err(e)));
1631                        }
1632                        Poll::Ready(None) => {
1633                            // Spill stream ended, keep draining the memory channel
1634                            self.state = StreamState::ReadingMemory;
1635                        }
1636                        Poll::Pending => {
1637                            // Spilled batch not ready yet, must wait
1638                            // This preserves ordering by blocking until spill data arrives
1639                            return Poll::Pending;
1640                        }
1641                    }
1642                }
1643            }
1644        }
1645    }
1646
1647    fn poll_next_and_coalesce(
1648        self: &mut Pin<&mut Self>,
1649        cx: &mut Context<'_>,
1650        coalescer: &mut LimitedBatchCoalescer,
1651    ) -> Poll<Option<Result<RecordBatch>>> {
1652        let cloned_time = self.baseline_metrics.elapsed_compute().clone();
1653        let mut completed = false;
1654
1655        loop {
1656            if let Some(batch) = coalescer.next_completed_batch() {
1657                return Poll::Ready(Some(Ok(batch)));
1658            }
1659            if completed {
1660                return Poll::Ready(None);
1661            }
1662
1663            match ready!(self.poll_next_inner(cx)) {
1664                Some(Ok(batch)) => {
1665                    let _timer = cloned_time.timer();
1666                    if let Err(err) = coalescer.push_batch(batch) {
1667                        return Poll::Ready(Some(Err(err)));
1668                    }
1669                }
1670                Some(err) => {
1671                    return Poll::Ready(Some(err));
1672                }
1673                None => {
1674                    completed = true;
1675                    let _timer = cloned_time.timer();
1676                    if let Err(err) = coalescer.finish() {
1677                        return Poll::Ready(Some(Err(err)));
1678                    }
1679                }
1680            }
1681        }
1682    }
1683}
1684
1685impl Stream for PerPartitionStream {
1686    type Item = Result<RecordBatch>;
1687
1688    fn poll_next(
1689        mut self: Pin<&mut Self>,
1690        cx: &mut Context<'_>,
1691    ) -> Poll<Option<Self::Item>> {
1692        let poll;
1693        if let Some(mut coalescer) = self.batch_coalescer.take() {
1694            poll = self.poll_next_and_coalesce(cx, &mut coalescer);
1695            self.batch_coalescer = Some(coalescer);
1696        } else {
1697            poll = self.poll_next_inner(cx);
1698        }
1699        self.baseline_metrics.record_poll(poll)
1700    }
1701}
1702
1703impl RecordBatchStream for PerPartitionStream {
1704    /// Get the schema
1705    fn schema(&self) -> SchemaRef {
1706        Arc::clone(&self.schema)
1707    }
1708}
1709
1710#[cfg(test)]
1711mod tests {
1712    use std::collections::HashSet;
1713
1714    use super::*;
1715    use crate::test::TestMemoryExec;
1716    use crate::{
1717        test::{
1718            assert_is_pending,
1719            exec::{
1720                BarrierExec, BlockingExec, ErrorExec, MockExec,
1721                assert_strong_count_converges_to_zero,
1722            },
1723        },
1724        {collect, expressions::col},
1725    };
1726
1727    use arrow::array::{ArrayRef, StringArray, UInt32Array};
1728    use arrow::datatypes::{DataType, Field, Schema};
1729    use datafusion_common::cast::as_string_array;
1730    use datafusion_common::exec_err;
1731    use datafusion_common::test_util::batches_to_sort_string;
1732    use datafusion_common_runtime::JoinSet;
1733    use datafusion_execution::config::SessionConfig;
1734    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
1735    use insta::assert_snapshot;
1736
1737    #[tokio::test]
1738    async fn one_to_many_round_robin() -> Result<()> {
1739        // define input partitions
1740        let schema = test_schema();
1741        let partition = create_vec_batches(50);
1742        let partitions = vec![partition];
1743
1744        // repartition from 1 input to 4 output
1745        let output_partitions =
1746            repartition(&schema, partitions, Partitioning::RoundRobinBatch(4)).await?;
1747
1748        assert_eq!(4, output_partitions.len());
1749        for partition in &output_partitions {
1750            assert_eq!(1, partition.len());
1751        }
1752        assert_eq!(13 * 8, output_partitions[0][0].num_rows());
1753        assert_eq!(13 * 8, output_partitions[1][0].num_rows());
1754        assert_eq!(12 * 8, output_partitions[2][0].num_rows());
1755        assert_eq!(12 * 8, output_partitions[3][0].num_rows());
1756
1757        Ok(())
1758    }
1759
1760    #[tokio::test]
1761    async fn many_to_one_round_robin() -> Result<()> {
1762        // define input partitions
1763        let schema = test_schema();
1764        let partition = create_vec_batches(50);
1765        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
1766
1767        // repartition from 3 input to 1 output
1768        let output_partitions =
1769            repartition(&schema, partitions, Partitioning::RoundRobinBatch(1)).await?;
1770
1771        assert_eq!(1, output_partitions.len());
1772        assert_eq!(150 * 8, output_partitions[0][0].num_rows());
1773
1774        Ok(())
1775    }
1776
1777    #[tokio::test]
1778    async fn many_to_many_round_robin() -> Result<()> {
1779        // define input partitions
1780        let schema = test_schema();
1781        let partition = create_vec_batches(50);
1782        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
1783
1784        // repartition from 3 input to 5 output
1785        let output_partitions =
1786            repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await?;
1787
1788        let total_rows_per_partition = 8 * 50 * 3 / 5;
1789        assert_eq!(5, output_partitions.len());
1790        for partition in output_partitions {
1791            assert_eq!(1, partition.len());
1792            assert_eq!(total_rows_per_partition, partition[0].num_rows());
1793        }
1794
1795        Ok(())
1796    }
1797
1798    #[tokio::test]
1799    async fn many_to_many_hash_partition() -> Result<()> {
1800        // define input partitions
1801        let schema = test_schema();
1802        let partition = create_vec_batches(50);
1803        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
1804
1805        let output_partitions = repartition(
1806            &schema,
1807            partitions,
1808            Partitioning::Hash(vec![col("c0", &schema)?], 8),
1809        )
1810        .await?;
1811
1812        let total_rows: usize = output_partitions
1813            .iter()
1814            .map(|x| x.iter().map(|x| x.num_rows()).sum::<usize>())
1815            .sum();
1816
1817        assert_eq!(8, output_partitions.len());
1818        assert_eq!(total_rows, 8 * 50 * 3);
1819
1820        Ok(())
1821    }
1822
1823    #[tokio::test]
1824    async fn test_repartition_with_coalescing() -> Result<()> {
1825        let schema = test_schema();
1826        // create 50 batches, each having 8 rows
1827        let partition = create_vec_batches(50);
1828        let partitions = vec![partition.clone(), partition.clone()];
1829        let partitioning = Partitioning::RoundRobinBatch(1);
1830
1831        let session_config = SessionConfig::new().with_batch_size(200);
1832        let task_ctx = TaskContext::default().with_session_config(session_config);
1833        let task_ctx = Arc::new(task_ctx);
1834
1835        // create physical plan
1836        let exec = TestMemoryExec::try_new_exec(&partitions, Arc::clone(&schema), None)?;
1837        let exec = RepartitionExec::try_new(exec, partitioning)?;
1838
1839        for i in 0..exec.partitioning().partition_count() {
1840            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
1841            while let Some(result) = stream.next().await {
1842                let batch = result?;
1843                assert_eq!(200, batch.num_rows());
1844            }
1845        }
1846        Ok(())
1847    }
1848
1849    fn test_schema() -> Arc<Schema> {
1850        Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
1851    }
1852
1853    async fn repartition(
1854        schema: &SchemaRef,
1855        input_partitions: Vec<Vec<RecordBatch>>,
1856        partitioning: Partitioning,
1857    ) -> Result<Vec<Vec<RecordBatch>>> {
1858        let task_ctx = Arc::new(TaskContext::default());
1859        // create physical plan
1860        let exec =
1861            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(schema), None)?;
1862        let exec = RepartitionExec::try_new(exec, partitioning)?;
1863
1864        // execute and collect results
1865        let mut output_partitions = vec![];
1866        for i in 0..exec.partitioning().partition_count() {
1867            // execute this *output* partition and collect all batches
1868            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
1869            let mut batches = vec![];
1870            while let Some(result) = stream.next().await {
1871                batches.push(result?);
1872            }
1873            output_partitions.push(batches);
1874        }
1875        Ok(output_partitions)
1876    }
1877
1878    #[tokio::test]
1879    async fn many_to_many_round_robin_within_tokio_task() -> Result<()> {
1880        let handle: SpawnedTask<Result<Vec<Vec<RecordBatch>>>> =
1881            SpawnedTask::spawn(async move {
1882                // define input partitions
1883                let schema = test_schema();
1884                let partition = create_vec_batches(50);
1885                let partitions =
1886                    vec![partition.clone(), partition.clone(), partition.clone()];
1887
1888                // repartition from 3 input to 5 output
1889                repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await
1890            });
1891
1892        let output_partitions = handle.join().await.unwrap().unwrap();
1893
1894        let total_rows_per_partition = 8 * 50 * 3 / 5;
1895        assert_eq!(5, output_partitions.len());
1896        for partition in output_partitions {
1897            assert_eq!(1, partition.len());
1898            assert_eq!(total_rows_per_partition, partition[0].num_rows());
1899        }
1900
1901        Ok(())
1902    }
1903
1904    #[tokio::test]
1905    async fn unsupported_partitioning() {
1906        let task_ctx = Arc::new(TaskContext::default());
1907        // have to send at least one batch through to provoke error
1908        let batch = RecordBatch::try_from_iter(vec![(
1909            "my_awesome_field",
1910            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
1911        )])
1912        .unwrap();
1913
1914        let schema = batch.schema();
1915        let input = MockExec::new(vec![Ok(batch)], schema);
1916        // This generates an error (partitioning type not supported)
1917        // but only after the plan is executed. The error should be
1918        // returned and no results produced
1919        let partitioning = Partitioning::UnknownPartitioning(1);
1920        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
1921        let output_stream = exec.execute(0, task_ctx).unwrap();
1922
1923        // Expect that an error is returned
1924        let result_string = crate::common::collect(output_stream)
1925            .await
1926            .unwrap_err()
1927            .to_string();
1928        assert!(
1929            result_string
1930                .contains("Unsupported repartitioning scheme UnknownPartitioning(1)"),
1931            "actual: {result_string}"
1932        );
1933    }
1934
1935    #[tokio::test]
1936    async fn error_for_input_exec() {
1937        // This generates an error on a call to execute. The error
1938        // should be returned and no results produced.
1939
1940        let task_ctx = Arc::new(TaskContext::default());
1941        let input = ErrorExec::new();
1942        let partitioning = Partitioning::RoundRobinBatch(1);
1943        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
1944
1945        // Expect that an error is returned
1946        let result_string = exec.execute(0, task_ctx).err().unwrap().to_string();
1947
1948        assert!(
1949            result_string.contains("ErrorExec, unsurprisingly, errored in partition 0"),
1950            "actual: {result_string}"
1951        );
1952    }
1953
1954    #[tokio::test]
1955    async fn repartition_with_error_in_stream() {
1956        let task_ctx = Arc::new(TaskContext::default());
1957        let batch = RecordBatch::try_from_iter(vec![(
1958            "my_awesome_field",
1959            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
1960        )])
1961        .unwrap();
1962
1963        // input stream returns one good batch and then one error. The
1964        // error should be returned.
1965        let err = exec_err!("bad data error");
1966
1967        let schema = batch.schema();
1968        let input = MockExec::new(vec![Ok(batch), err], schema);
1969        let partitioning = Partitioning::RoundRobinBatch(1);
1970        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
1971
1972        // Note: this should pass (the stream can be created) but the
1973        // error when the input is executed should get passed back
1974        let output_stream = exec.execute(0, task_ctx).unwrap();
1975
1976        // Expect that an error is returned
1977        let result_string = crate::common::collect(output_stream)
1978            .await
1979            .unwrap_err()
1980            .to_string();
1981        assert!(
1982            result_string.contains("bad data error"),
1983            "actual: {result_string}"
1984        );
1985    }
1986
    #[tokio::test]
    async fn repartition_with_delayed_stream() {
        let task_ctx = Arc::new(TaskContext::default());
        let batch1 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();

        let batch2 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
        )])
        .unwrap();

        // The mock exec doesn't return its batches immediately: the input has to
        // wait at least once before data arrives.
        let schema = batch1.schema();
        let expected_batches = vec![batch1.clone(), batch2.clone()];
        let input = MockExec::new(vec![Ok(batch1), Ok(batch2)], schema);
        let partitioning = Partitioning::RoundRobinBatch(1);

        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();

        // Pin down the expected (sorted) rows first...
        assert_snapshot!(batches_to_sort_string(&expected_batches), @r"
        +------------------+
        | my_awesome_field |
        +------------------+
        | bar              |
        | baz              |
        | foo              |
        | frob             |
        +------------------+
        ");

        // ...then check that the single output partition yields the same rows.
        let output_stream = exec.execute(0, task_ctx).unwrap();
        let batches = crate::common::collect(output_stream).await.unwrap();

        assert_snapshot!(batches_to_sort_string(&batches), @r"
        +------------------+
        | my_awesome_field |
        +------------------+
        | bar              |
        | baz              |
        | foo              |
        | frob             |
        +------------------+
        ");
    }
2036
    #[tokio::test]
    async fn robin_repartition_with_dropping_output_stream() {
        let task_ctx = Arc::new(TaskContext::default());
        let partitioning = Partitioning::RoundRobinBatch(2);
        // The barrier exec waits to be pinged (via `wait`) before it produces
        // output, so the input has to wait at least once.
        let input = Arc::new(make_barrier_exec());

        // partition into two output streams
        let exec = RepartitionExec::try_new(
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            partitioning,
        )
        .unwrap();

        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();

        // now, purposely drop output stream 0
        // *before* any outputs are produced
        drop(output_stream0);

        // Now, start sending input. `background_task` stays bound until the end
        // of the test so the spawned task is still alive while output stream 1 is
        // collected (NOTE(review): assumes dropping the JoinSet would abort it, as
        // tokio's JoinSet does; verify).
        let mut background_task = JoinSet::new();
        background_task.spawn(async move {
            input.wait().await;
        });

        // output stream 1 should *not* error and must still receive its share of
        // the input batches (see the snapshot below)
        let batches = crate::common::collect(output_stream1).await.unwrap();

        assert_snapshot!(batches_to_sort_string(&batches), @r"
        +------------------+
        | my_awesome_field |
        +------------------+
        | baz              |
        | frob             |
        | gar              |
        | goo              |
        +------------------+
        ");
    }
2079
    #[tokio::test]
    // As the hash results might be different on different platforms or
    // with different compilers, we will compare the same execution with
    // and without dropping the output stream.
    async fn hash_repartition_with_dropping_output_stream() {
        let task_ctx = Arc::new(TaskContext::default());
        let partitioning = Partitioning::Hash(
            vec![Arc::new(crate::expressions::Column::new(
                "my_awesome_field",
                0,
            ))],
            2,
        );

        // We first collect the results without dropping the output stream.
        let input = Arc::new(make_barrier_exec());
        let exec = RepartitionExec::try_new(
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            partitioning.clone(),
        )
        .unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
        // Release the barrier from a background task. The JoinSet must stay alive
        // while output stream 1 is collected (NOTE(review): assumes dropping it
        // would abort the task, as tokio's JoinSet does; verify).
        let mut background_task = JoinSet::new();
        background_task.spawn(async move {
            input.wait().await;
        });
        let batches_without_drop = crate::common::collect(output_stream1).await.unwrap();

        // run some checks on the result: no duplicated values, and every value
        // comes from the barrier exec's input batches
        let items_vec = str_batches_to_vec(&batches_without_drop);
        let items_set: HashSet<&str> = items_vec.iter().copied().collect();
        assert_eq!(items_vec.len(), items_set.len());
        let source_str_set: HashSet<&str> =
            ["foo", "bar", "frob", "baz", "goo", "gar", "grob", "gaz"]
                .iter()
                .copied()
                .collect();
        assert_eq!(items_set.difference(&source_str_set).count(), 0);

        // Now do the same but dropping the stream before waiting for the barrier
        let input = Arc::new(make_barrier_exec());
        let exec = RepartitionExec::try_new(
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            partitioning,
        )
        .unwrap();
        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
        // now, purposely drop output stream 0
        // *before* any outputs are produced
        drop(output_stream0);
        // This binding shadows, but does not drop, the first JoinSet; both stay
        // alive until the end of the test.
        let mut background_task = JoinSet::new();
        background_task.spawn(async move {
            input.wait().await;
        });
        let batches_with_drop = crate::common::collect(output_stream1).await.unwrap();

        // Dropping output stream 0 must not change what output stream 1 receives.
        let items_vec_with_drop = str_batches_to_vec(&batches_with_drop);
        let items_set_with_drop: HashSet<&str> =
            items_vec_with_drop.iter().copied().collect();
        assert_eq!(
            items_set_with_drop.symmetric_difference(&items_set).count(),
            0
        );
    }
2145
2146    fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> {
2147        batches
2148            .iter()
2149            .flat_map(|batch| {
2150                assert_eq!(batch.columns().len(), 1);
2151                let string_array = as_string_array(batch.column(0))
2152                    .expect("Unexpected type for repartitioned batch");
2153
2154                string_array
2155                    .iter()
2156                    .map(|v| v.expect("Unexpected null"))
2157                    .collect::<Vec<_>>()
2158            })
2159            .collect::<Vec<_>>()
2160    }
2161
2162    /// Create a BarrierExec that returns two partitions of two batches each
2163    fn make_barrier_exec() -> BarrierExec {
2164        let batch1 = RecordBatch::try_from_iter(vec![(
2165            "my_awesome_field",
2166            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
2167        )])
2168        .unwrap();
2169
2170        let batch2 = RecordBatch::try_from_iter(vec![(
2171            "my_awesome_field",
2172            Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
2173        )])
2174        .unwrap();
2175
2176        let batch3 = RecordBatch::try_from_iter(vec![(
2177            "my_awesome_field",
2178            Arc::new(StringArray::from(vec!["goo", "gar"])) as ArrayRef,
2179        )])
2180        .unwrap();
2181
2182        let batch4 = RecordBatch::try_from_iter(vec![(
2183            "my_awesome_field",
2184            Arc::new(StringArray::from(vec!["grob", "gaz"])) as ArrayRef,
2185        )])
2186        .unwrap();
2187
2188        // The barrier exec waits to be pinged
2189        // requires the input to wait at least once)
2190        let schema = batch1.schema();
2191        BarrierExec::new(vec![vec![batch1, batch2], vec![batch3, batch4]], schema)
2192    }
2193
2194    #[tokio::test]
2195    async fn test_drop_cancel() -> Result<()> {
2196        let task_ctx = Arc::new(TaskContext::default());
2197        let schema =
2198            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));
2199
2200        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2));
2201        let refs = blocking_exec.refs();
2202        let repartition_exec = Arc::new(RepartitionExec::try_new(
2203            blocking_exec,
2204            Partitioning::UnknownPartitioning(1),
2205        )?);
2206
2207        let fut = collect(repartition_exec, task_ctx);
2208        let mut fut = fut.boxed();
2209
2210        assert_is_pending(&mut fut);
2211        drop(fut);
2212        assert_strong_count_converges_to_zero(refs).await;
2213
2214        Ok(())
2215    }
2216
2217    #[tokio::test]
2218    async fn hash_repartition_avoid_empty_batch() -> Result<()> {
2219        let task_ctx = Arc::new(TaskContext::default());
2220        let batch = RecordBatch::try_from_iter(vec![(
2221            "a",
2222            Arc::new(StringArray::from(vec!["foo"])) as ArrayRef,
2223        )])
2224        .unwrap();
2225        let partitioning = Partitioning::Hash(
2226            vec![Arc::new(crate::expressions::Column::new("a", 0))],
2227            2,
2228        );
2229        let schema = batch.schema();
2230        let input = MockExec::new(vec![Ok(batch)], schema);
2231        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
2232        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
2233        let batch0 = crate::common::collect(output_stream0).await.unwrap();
2234        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
2235        let batch1 = crate::common::collect(output_stream1).await.unwrap();
2236        assert!(batch0.is_empty() || batch1.is_empty());
2237        Ok(())
2238    }
2239
2240    #[tokio::test]
2241    async fn repartition_with_spilling() -> Result<()> {
2242        // Test that repartition successfully spills to disk when memory is constrained
2243        let schema = test_schema();
2244        let partition = create_vec_batches(50);
2245        let input_partitions = vec![partition];
2246        let partitioning = Partitioning::RoundRobinBatch(4);
2247
2248        // Set up context with very tight memory limit to force spilling
2249        let runtime = RuntimeEnvBuilder::default()
2250            .with_memory_limit(1, 1.0)
2251            .build_arc()?;
2252
2253        let task_ctx = TaskContext::default().with_runtime(runtime);
2254        let task_ctx = Arc::new(task_ctx);
2255
2256        // create physical plan
2257        let exec =
2258            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
2259        let exec = RepartitionExec::try_new(exec, partitioning)?;
2260
2261        // Collect all partitions - should succeed by spilling to disk
2262        let mut total_rows = 0;
2263        for i in 0..exec.partitioning().partition_count() {
2264            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
2265            while let Some(result) = stream.next().await {
2266                let batch = result?;
2267                total_rows += batch.num_rows();
2268            }
2269        }
2270
2271        // Verify we got all the data (50 batches * 8 rows each)
2272        assert_eq!(total_rows, 50 * 8);
2273
2274        // Verify spilling metrics to confirm spilling actually happened
2275        let metrics = exec.metrics().unwrap();
2276        assert!(
2277            metrics.spill_count().unwrap() > 0,
2278            "Expected spill_count > 0, but got {:?}",
2279            metrics.spill_count()
2280        );
2281        println!("Spilled {} times", metrics.spill_count().unwrap());
2282        assert!(
2283            metrics.spilled_bytes().unwrap() > 0,
2284            "Expected spilled_bytes > 0, but got {:?}",
2285            metrics.spilled_bytes()
2286        );
2287        println!(
2288            "Spilled {} bytes in {} spills",
2289            metrics.spilled_bytes().unwrap(),
2290            metrics.spill_count().unwrap()
2291        );
2292        assert!(
2293            metrics.spilled_rows().unwrap() > 0,
2294            "Expected spilled_rows > 0, but got {:?}",
2295            metrics.spilled_rows()
2296        );
2297        println!("Spilled {} rows", metrics.spilled_rows().unwrap());
2298
2299        Ok(())
2300    }
2301
2302    #[tokio::test]
2303    async fn repartition_with_partial_spilling() -> Result<()> {
2304        // Test that repartition can handle partial spilling (some batches in memory, some spilled)
2305        let schema = test_schema();
2306        let partition = create_vec_batches(50);
2307        let input_partitions = vec![partition];
2308        let partitioning = Partitioning::RoundRobinBatch(4);
2309
2310        // Set up context with moderate memory limit to force partial spilling
2311        // 2KB should allow some batches in memory but force others to spill
2312        let runtime = RuntimeEnvBuilder::default()
2313            .with_memory_limit(2 * 1024, 1.0)
2314            .build_arc()?;
2315
2316        let task_ctx = TaskContext::default().with_runtime(runtime);
2317        let task_ctx = Arc::new(task_ctx);
2318
2319        // create physical plan
2320        let exec =
2321            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
2322        let exec = RepartitionExec::try_new(exec, partitioning)?;
2323
2324        // Collect all partitions - should succeed with partial spilling
2325        let mut total_rows = 0;
2326        for i in 0..exec.partitioning().partition_count() {
2327            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
2328            while let Some(result) = stream.next().await {
2329                let batch = result?;
2330                total_rows += batch.num_rows();
2331            }
2332        }
2333
2334        // Verify we got all the data (50 batches * 8 rows each)
2335        assert_eq!(total_rows, 50 * 8);
2336
2337        // Verify partial spilling metrics
2338        let metrics = exec.metrics().unwrap();
2339        let spill_count = metrics.spill_count().unwrap();
2340        let spilled_rows = metrics.spilled_rows().unwrap();
2341        let spilled_bytes = metrics.spilled_bytes().unwrap();
2342
2343        assert!(
2344            spill_count > 0,
2345            "Expected some spilling to occur, but got spill_count={spill_count}"
2346        );
2347        assert!(
2348            spilled_rows > 0 && spilled_rows < total_rows,
2349            "Expected partial spilling (0 < spilled_rows < {total_rows}), but got spilled_rows={spilled_rows}"
2350        );
2351        assert!(
2352            spilled_bytes > 0,
2353            "Expected some bytes to be spilled, but got spilled_bytes={spilled_bytes}"
2354        );
2355
2356        println!(
2357            "Partial spilling: spilled {} out of {} rows ({:.1}%) in {} spills, {} bytes",
2358            spilled_rows,
2359            total_rows,
2360            (spilled_rows as f64 / total_rows as f64) * 100.0,
2361            spill_count,
2362            spilled_bytes
2363        );
2364
2365        Ok(())
2366    }
2367
2368    #[tokio::test]
2369    async fn repartition_without_spilling() -> Result<()> {
2370        // Test that repartition does not spill when there's ample memory
2371        let schema = test_schema();
2372        let partition = create_vec_batches(50);
2373        let input_partitions = vec![partition];
2374        let partitioning = Partitioning::RoundRobinBatch(4);
2375
2376        // Set up context with generous memory limit - no spilling should occur
2377        let runtime = RuntimeEnvBuilder::default()
2378            .with_memory_limit(10 * 1024 * 1024, 1.0) // 10MB
2379            .build_arc()?;
2380
2381        let task_ctx = TaskContext::default().with_runtime(runtime);
2382        let task_ctx = Arc::new(task_ctx);
2383
2384        // create physical plan
2385        let exec =
2386            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
2387        let exec = RepartitionExec::try_new(exec, partitioning)?;
2388
2389        // Collect all partitions - should succeed without spilling
2390        let mut total_rows = 0;
2391        for i in 0..exec.partitioning().partition_count() {
2392            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
2393            while let Some(result) = stream.next().await {
2394                let batch = result?;
2395                total_rows += batch.num_rows();
2396            }
2397        }
2398
2399        // Verify we got all the data (50 batches * 8 rows each)
2400        assert_eq!(total_rows, 50 * 8);
2401
2402        // Verify no spilling occurred
2403        let metrics = exec.metrics().unwrap();
2404        assert_eq!(
2405            metrics.spill_count(),
2406            Some(0),
2407            "Expected no spilling, but got spill_count={:?}",
2408            metrics.spill_count()
2409        );
2410        assert_eq!(
2411            metrics.spilled_bytes(),
2412            Some(0),
2413            "Expected no bytes spilled, but got spilled_bytes={:?}",
2414            metrics.spilled_bytes()
2415        );
2416        assert_eq!(
2417            metrics.spilled_rows(),
2418            Some(0),
2419            "Expected no rows spilled, but got spilled_rows={:?}",
2420            metrics.spilled_rows()
2421        );
2422
2423        println!("No spilling occurred - all data processed in memory");
2424
2425        Ok(())
2426    }
2427
2428    #[tokio::test]
2429    async fn oom() -> Result<()> {
2430        use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode};
2431
2432        // Test that repartition fails with OOM when disk manager is disabled
2433        let schema = test_schema();
2434        let partition = create_vec_batches(50);
2435        let input_partitions = vec![partition];
2436        let partitioning = Partitioning::RoundRobinBatch(4);
2437
2438        // Setup context with memory limit but NO disk manager (explicitly disabled)
2439        let runtime = RuntimeEnvBuilder::default()
2440            .with_memory_limit(1, 1.0)
2441            .with_disk_manager_builder(
2442                DiskManagerBuilder::default().with_mode(DiskManagerMode::Disabled),
2443            )
2444            .build_arc()?;
2445
2446        let task_ctx = TaskContext::default().with_runtime(runtime);
2447        let task_ctx = Arc::new(task_ctx);
2448
2449        // create physical plan
2450        let exec =
2451            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
2452        let exec = RepartitionExec::try_new(exec, partitioning)?;
2453
2454        // Attempt to execute - should fail with ResourcesExhausted error
2455        for i in 0..exec.partitioning().partition_count() {
2456            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
2457            let err = stream.next().await.unwrap().unwrap_err();
2458            let err = err.find_root();
2459            assert!(
2460                matches!(err, DataFusionError::ResourcesExhausted(_)),
2461                "Wrong error type: {err}",
2462            );
2463        }
2464
2465        Ok(())
2466    }
2467
2468    /// Create vector batches
2469    fn create_vec_batches(n: usize) -> Vec<RecordBatch> {
2470        let batch = create_batch();
2471        (0..n).map(|_| batch.clone()).collect()
2472    }
2473
2474    /// Create batch
2475    fn create_batch() -> RecordBatch {
2476        let schema = test_schema();
2477        RecordBatch::try_new(
2478            schema,
2479            vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]))],
2480        )
2481        .unwrap()
2482    }
2483
2484    /// Create batches with sequential values for ordering tests
2485    fn create_ordered_batches(num_batches: usize) -> Vec<RecordBatch> {
2486        let schema = test_schema();
2487        (0..num_batches)
2488            .map(|i| {
2489                let start = (i * 8) as u32;
2490                RecordBatch::try_new(
2491                    Arc::clone(&schema),
2492                    vec![Arc::new(UInt32Array::from(
2493                        (start..start + 8).collect::<Vec<_>>(),
2494                    ))],
2495                )
2496                .unwrap()
2497            })
2498            .collect()
2499    }
2500
2501    #[tokio::test]
2502    async fn test_repartition_ordering_with_spilling() -> Result<()> {
2503        // Test that repartition preserves ordering when spilling occurs
2504        // This tests the state machine fix where we must block on spill_stream
2505        // when a Spilled marker is received, rather than continuing to poll the channel
2506
2507        let schema = test_schema();
2508        // Create batches with sequential values: batch 0 has [0,1,2,3,4,5,6,7],
2509        // batch 1 has [8,9,10,11,12,13,14,15], etc.
2510        let partition = create_ordered_batches(20);
2511        let input_partitions = vec![partition];
2512
2513        // Use RoundRobinBatch to ensure predictable ordering
2514        let partitioning = Partitioning::RoundRobinBatch(2);
2515
2516        // Set up context with very tight memory limit to force spilling
2517        let runtime = RuntimeEnvBuilder::default()
2518            .with_memory_limit(1, 1.0)
2519            .build_arc()?;
2520
2521        let task_ctx = TaskContext::default().with_runtime(runtime);
2522        let task_ctx = Arc::new(task_ctx);
2523
2524        // create physical plan
2525        let exec =
2526            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
2527        let exec = RepartitionExec::try_new(exec, partitioning)?;
2528
2529        // Collect all output partitions
2530        let mut all_batches = Vec::new();
2531        for i in 0..exec.partitioning().partition_count() {
2532            let mut partition_batches = Vec::new();
2533            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
2534            while let Some(result) = stream.next().await {
2535                let batch = result?;
2536                partition_batches.push(batch);
2537            }
2538            all_batches.push(partition_batches);
2539        }
2540
2541        // Verify spilling occurred
2542        let metrics = exec.metrics().unwrap();
2543        assert!(
2544            metrics.spill_count().unwrap() > 0,
2545            "Expected spilling to occur, but spill_count = 0"
2546        );
2547
2548        // Verify ordering is preserved within each partition
2549        // With RoundRobinBatch, even batches go to partition 0, odd batches to partition 1
2550        for (partition_idx, batches) in all_batches.iter().enumerate() {
2551            let mut last_value = None;
2552            for batch in batches {
2553                let array = batch
2554                    .column(0)
2555                    .as_any()
2556                    .downcast_ref::<UInt32Array>()
2557                    .unwrap();
2558
2559                for i in 0..array.len() {
2560                    let value = array.value(i);
2561                    if let Some(last) = last_value {
2562                        assert!(
2563                            value > last,
2564                            "Ordering violated in partition {partition_idx}: {value} is not greater than {last}"
2565                        );
2566                    }
2567                    last_value = Some(value);
2568                }
2569            }
2570        }
2571
2572        Ok(())
2573    }
2574}
2575
#[cfg(test)]
mod test {
    use arrow::array::record_batch;
    use arrow::compute::SortOptions;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::assert_batches_eq;

    use super::*;
    use crate::test::TestMemoryExec;
    use crate::union::UnionExec;

    use datafusion_physical_expr::expressions::col;
    use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};

    /// Asserts that the indented display of a plan matches an inline snapshot
    ///
    /// `$PLAN`: the plan to display
    /// `$EXPECTED`: the expected indented plan, as an inline `insta` snapshot
    macro_rules! assert_plan {
        ($PLAN: expr,  @ $EXPECTED: expr) => {
            let formatted = crate::displayable($PLAN).indent(true).to_string();

            insta::assert_snapshot!(
                formatted,
                @$EXPECTED
            );
        };
    }

    #[tokio::test]
    async fn test_preserve_order() -> Result<()> {
        let schema = test_schema();
        let sort_exprs = sort_exprs(&schema);
        let source1 = sorted_memory_exec(&schema, sort_exprs.clone());
        let source2 = sorted_memory_exec(&schema, sort_exprs);
        // output has multiple partitions, and is sorted
        let union = UnionExec::try_new(vec![source1, source2])?;
        let exec = RepartitionExec::try_new(union, Partitioning::RoundRobinBatch(10))?
            .with_preserve_order();

        // Repartition should preserve order
        assert_plan!(&exec, @r"
        RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true, sort_exprs=c0@0 ASC
          UnionExec
            DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
            DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
        ");
        Ok(())
    }

    #[tokio::test]
    async fn test_preserve_order_one_partition() -> Result<()> {
        let schema = test_schema();
        let sort_exprs = sort_exprs(&schema);
        let source = sorted_memory_exec(&schema, sort_exprs);
        // output is sorted, but has only a single partition, so no need to sort
        let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(10))?
            .with_preserve_order();

        // With a single input partition there is nothing to merge, so the plan
        // does not show preserve_order=true; the sort order is still reported as
        // maintained (maintains_sort_order=true)
        assert_plan!(&exec, @r"
        RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true
          DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
        ");

        Ok(())
    }

    #[tokio::test]
    async fn test_preserve_order_input_not_sorted() -> Result<()> {
        let schema = test_schema();
        let source1 = memory_exec(&schema);
        let source2 = memory_exec(&schema);
        // output has multiple partitions, but is not sorted
        let union = UnionExec::try_new(vec![source1, source2])?;
        let exec = RepartitionExec::try_new(union, Partitioning::RoundRobinBatch(10))?
            .with_preserve_order();

        // Repartition should not preserve order, as there is no order to preserve
        assert_plan!(&exec, @r"
        RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2
          UnionExec
            DataSourceExec: partitions=1, partition_sizes=[0]
            DataSourceExec: partitions=1, partition_sizes=[0]
        ");
        Ok(())
    }

    #[tokio::test]
    async fn test_preserve_order_with_spilling() -> Result<()> {
        use datafusion_execution::TaskContext;
        use datafusion_execution::runtime_env::RuntimeEnvBuilder;

        // Create sorted input data across multiple partitions
        // Partition1: [1,3], [5,7], [9,11]
        // Partition2: [2,4], [6,8], [10,12]
        let batch1 = record_batch!(("c0", UInt32, [1, 3])).unwrap();
        let batch2 = record_batch!(("c0", UInt32, [2, 4])).unwrap();
        let batch3 = record_batch!(("c0", UInt32, [5, 7])).unwrap();
        let batch4 = record_batch!(("c0", UInt32, [6, 8])).unwrap();
        let batch5 = record_batch!(("c0", UInt32, [9, 11])).unwrap();
        let batch6 = record_batch!(("c0", UInt32, [10, 12])).unwrap();
        let schema = batch1.schema();
        let sort_exprs = LexOrdering::new([PhysicalSortExpr {
            expr: col("c0", &schema).unwrap(),
            options: SortOptions::default().asc(),
        }])
        .unwrap();
        let partition1 = vec![batch1.clone(), batch3.clone(), batch5.clone()];
        let partition2 = vec![batch2.clone(), batch4.clone(), batch6.clone()];
        let input_partitions = vec![partition1, partition2];

        // Set up context with tight memory limit to force spilling
        // Sorting needs some non-spillable memory, so 64 bytes should force spilling while still allowing the query to complete
        let runtime = RuntimeEnvBuilder::default()
            .with_memory_limit(64, 1.0)
            .build_arc()?;

        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        // Create physical plan with order preservation
        let exec = TestMemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?
            .try_with_sort_information(vec![sort_exprs.clone(), sort_exprs])?;
        let exec = Arc::new(exec);
        let exec = Arc::new(TestMemoryExec::update_cache(&exec));
        // Repartition into 3 partitions with order preservation
        // We expect 1 batch per output partition after repartitioning
        let exec = RepartitionExec::try_new(exec, Partitioning::RoundRobinBatch(3))?
            .with_preserve_order();

        let mut batches = vec![];

        // Collect all partitions - should succeed by spilling to disk
        for i in 0..exec.partitioning().partition_count() {
            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
            while let Some(result) = stream.next().await {
                let batch = result?;
                batches.push(batch);
            }
        }

        #[rustfmt::skip]
        let expected = [
            [
                "+----+",
                "| c0 |",
                "+----+",
                "| 1  |",
                "| 2  |",
                "| 3  |",
                "| 4  |",
                "+----+",
            ],
            [
                "+----+",
                "| c0 |",
                "+----+",
                "| 5  |",
                "| 6  |",
                "| 7  |",
                "| 8  |",
                "+----+",
            ],
            [
                "+----+",
                "| c0 |",
                "+----+",
                "| 9  |",
                "| 10 |",
                "| 11 |",
                "| 12 |",
                "+----+",
            ],
        ];

        // `zip` stops at the shorter side, so check the batch count first;
        // otherwise missing output batches would go unnoticed
        assert_eq!(
            batches.len(),
            expected.len(),
            "Expected one output batch per output partition"
        );
        for (batch, expected) in batches.iter().zip(expected.iter()) {
            assert_batches_eq!(expected, std::slice::from_ref(batch));
        }

        // We should have spilled ~ all of the data.
        // - We spill data during the repartitioning phase
        // - We may also spill during the final merge sort
        let all_batches = [batch1, batch2, batch3, batch4, batch5, batch6];
        let input_bytes: usize = all_batches
            .iter()
            .map(|b| b.get_array_memory_size())
            .sum();
        let input_rows: usize = all_batches.iter().map(|b| b.num_rows()).sum();

        let metrics = exec.metrics().unwrap();
        let spilled_bytes = metrics.spilled_bytes().unwrap();
        let spilled_rows = metrics.spilled_rows().unwrap();
        assert!(
            metrics.spill_count().unwrap() > input_partitions.len(),
            "Expected spill_count > {} for order-preserving repartition, but got {:?}",
            input_partitions.len(),
            metrics.spill_count()
        );
        assert!(
            spilled_bytes > input_bytes,
            "Expected spilled_bytes > {input_bytes} for order-preserving repartition, got {spilled_bytes}"
        );
        assert!(
            spilled_rows >= input_rows,
            "Expected spilled_rows >= {input_rows} for order-preserving repartition, got {spilled_rows}"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_hash_partitioning_with_spilling() -> Result<()> {
        use datafusion_execution::TaskContext;
        use datafusion_execution::runtime_env::RuntimeEnvBuilder;

        // Create input data similar to the round-robin test
        let batch1 = record_batch!(("c0", UInt32, [1, 3])).unwrap();
        let batch2 = record_batch!(("c0", UInt32, [2, 4])).unwrap();
        let batch3 = record_batch!(("c0", UInt32, [5, 7])).unwrap();
        let batch4 = record_batch!(("c0", UInt32, [6, 8])).unwrap();
        let schema = batch1.schema();

        let partition1 = vec![batch1.clone(), batch3.clone()];
        let partition2 = vec![batch2.clone(), batch4.clone()];
        let input_partitions = vec![partition1, partition2];

        // Set up context with memory limit to test hash partitioning with spilling infrastructure
        let runtime = RuntimeEnvBuilder::default()
            .with_memory_limit(1, 1.0)
            .build_arc()?;

        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        // Create physical plan with hash partitioning
        let exec = TestMemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?;
        let exec = Arc::new(exec);
        let exec = Arc::new(TestMemoryExec::update_cache(&exec));
        // Hash partition into 2 partitions by column c0
        let hash_expr = col("c0", &schema)?;
        let exec =
            RepartitionExec::try_new(exec, Partitioning::Hash(vec![hash_expr], 2))?;

        // Collect all partitions concurrently using JoinSet - this prevents deadlock
        // where the distribution channel gate closes when all output channels are full
        let mut join_set = tokio::task::JoinSet::new();
        for i in 0..exec.partitioning().partition_count() {
            let stream = exec.execute(i, Arc::clone(&task_ctx))?;
            join_set.spawn(async move {
                let mut count = 0;
                futures::pin_mut!(stream);
                while let Some(result) = stream.next().await {
                    let batch = result?;
                    count += batch.num_rows();
                }
                Ok::<usize, DataFusionError>(count)
            });
        }

        // Wait for all partitions and sum the rows
        let mut total_rows = 0;
        while let Some(result) = join_set.join_next().await {
            total_rows += result.unwrap()?;
        }

        // Verify we got all rows back
        let all_batches = [batch1, batch2, batch3, batch4];
        let expected_rows: usize = all_batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, expected_rows);

        // With a 1-byte memory limit, hash repartitioning is expected to spill;
        // a missing metric is treated as zero and fails the checks below
        let metrics = exec.metrics().unwrap();
        let spill_count = metrics.spill_count().unwrap_or(0);
        assert!(spill_count > 0, "Expected spill_count > 0, got {spill_count}");
        let spilled_bytes = metrics.spilled_bytes().unwrap_or(0);
        assert!(
            spilled_bytes > 0,
            "Expected spilled_bytes > 0, got {spilled_bytes}"
        );
        let spilled_rows = metrics.spilled_rows().unwrap_or(0);
        assert!(
            spilled_rows > 0,
            "Expected spilled_rows > 0, got {spilled_rows}"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_repartition() -> Result<()> {
        let schema = test_schema();
        let sort_exprs = sort_exprs(&schema);
        let source = sorted_memory_exec(&schema, sort_exprs);
        // output is sorted, but has only a single partition, so no need to sort
        let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(10))?
            .repartitioned(20, &Default::default())?
            .unwrap();

        // `repartitioned` updates the target partition count to 20. With a
        // single input partition no merge is needed, and the sort order is
        // still reported as maintained (maintains_sort_order=true)
        assert_plan!(exec.as_ref(), @r"
        RepartitionExec: partitioning=RoundRobinBatch(20), input_partitions=1, maintains_sort_order=true
          DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
        ");
        Ok(())
    }

    /// Schema with a single non-nullable `UInt32` column named `c0`
    fn test_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
    }

    /// Ordering on column `c0` with default sort options
    fn sort_exprs(schema: &Schema) -> LexOrdering {
        [PhysicalSortExpr {
            expr: col("c0", schema).unwrap(),
            options: SortOptions::default(),
        }]
        .into()
    }

    /// Single-partition, empty, unsorted memory source
    fn memory_exec(schema: &SchemaRef) -> Arc<dyn ExecutionPlan> {
        TestMemoryExec::try_new_exec(&[vec![]], Arc::clone(schema), None).unwrap()
    }

    /// Single-partition, empty memory source that advertises `sort_exprs` as
    /// its output ordering
    fn sorted_memory_exec(
        schema: &SchemaRef,
        sort_exprs: LexOrdering,
    ) -> Arc<dyn ExecutionPlan> {
        let exec = TestMemoryExec::try_new(&[vec![]], Arc::clone(schema), None)
            .unwrap()
            .try_with_sort_information(vec![sort_exprs])
            .unwrap();
        let exec = Arc::new(exec);
        Arc::new(TestMemoryExec::update_cache(&exec))
    }
}