Skip to main content

datafusion_physical_plan/repartition/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! This file implements the [`RepartitionExec`] operator, which maps N input
//! partitions to M output partitions based on a partitioning scheme, optionally
//! maintaining the order of the input rows in the output.
21
22use std::fmt::{Debug, Formatter};
23use std::pin::Pin;
24use std::sync::Arc;
25use std::task::{Context, Poll};
26use std::{any::Any, vec};
27
28use super::common::SharedMemoryReservation;
29use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
30use super::{
31    DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream,
32};
33use crate::coalesce::LimitedBatchCoalescer;
34use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType};
35use crate::hash_utils::create_hashes;
36use crate::metrics::{BaselineMetrics, SpillMetrics};
37use crate::projection::{ProjectionExec, all_columns, make_with_child, update_expr};
38use crate::sorts::streaming_merge::StreamingMergeBuilder;
39use crate::spill::spill_manager::SpillManager;
40use crate::spill::spill_pool::{self, SpillPoolWriter};
41use crate::stream::RecordBatchStreamAdapter;
42use crate::{
43    DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Statistics,
44    check_if_same_properties,
45};
46
47use arrow::array::{PrimitiveArray, RecordBatch, RecordBatchOptions};
48use arrow::compute::take_arrays;
49use arrow::datatypes::{SchemaRef, UInt32Type};
50use datafusion_common::config::ConfigOptions;
51use datafusion_common::stats::Precision;
52use datafusion_common::utils::transpose;
53use datafusion_common::{
54    ColumnStatistics, DataFusionError, HashMap, assert_or_internal_err,
55    internal_datafusion_err, internal_err,
56};
57use datafusion_common::{Result, not_impl_err};
58use datafusion_common_runtime::SpawnedTask;
59use datafusion_execution::TaskContext;
60use datafusion_execution::memory_pool::MemoryConsumer;
61use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr};
62use datafusion_physical_expr_common::sort_expr::LexOrdering;
63
64use crate::filter_pushdown::{
65    ChildPushdownResult, FilterDescription, FilterPushdownPhase,
66    FilterPushdownPropagation,
67};
68use crate::joins::SeededRandomState;
69use crate::sort_pushdown::SortOrderPushdownResult;
70use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
71use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays;
72use futures::stream::Stream;
73use futures::{FutureExt, StreamExt, TryStreamExt, ready};
74use log::trace;
75use parking_lot::Mutex;
76
77mod distributor_channels;
78use distributor_channels::{
79    DistributionReceiver, DistributionSender, channels, partition_aware_channels,
80};
81
82/// A batch in the repartition queue - either in memory or spilled to disk.
83///
84/// This enum represents the two states a batch can be in during repartitioning.
85/// The decision to spill is made based on memory availability when sending a batch
86/// to an output partition.
87///
88/// # Batch Flow with Spilling
89///
90/// ```text
91/// Input Stream ──▶ Partition Logic ──▶ try_grow()
92///                                            │
93///                            ┌───────────────┴────────────────┐
94///                            │                                │
95///                            ▼                                ▼
96///                   try_grow() succeeds            try_grow() fails
97///                   (Memory Available)              (Memory Pressure)
98///                            │                                │
99///                            ▼                                ▼
100///                  RepartitionBatch::Memory         spill_writer.push_batch()
101///                  (batch held in memory)           (batch written to disk)
102///                            │                                │
103///                            │                                ▼
104///                            │                      RepartitionBatch::Spilled
105///                            │                      (marker - no batch data)
106///                            │                                │
107///                            └────────┬───────────────────────┘
108///                                     │
109///                                     ▼
110///                              Send to channel
111///                                     │
112///                                     ▼
113///                            Output Stream (poll)
114///                                     │
115///                      ┌──────────────┴─────────────┐
116///                      │                            │
117///                      ▼                            ▼
118///         RepartitionBatch::Memory      RepartitionBatch::Spilled
119///         Return batch immediately       Poll spill_stream (blocks)
120///                      │                            │
121///                      └────────┬───────────────────┘
122///                               │
123///                               ▼
124///                          Return batch
125///                    (FIFO order preserved)
126/// ```
127///
128/// See [`RepartitionExec`] for overall architecture and [`StreamState`] for
129/// the state machine that handles reading these batches.
130#[derive(Debug)]
131enum RepartitionBatch {
132    /// Batch held in memory (counts against memory reservation)
133    Memory(RecordBatch),
134    /// Marker indicating a batch was spilled to the partition's SpillPool.
135    /// The actual batch can be retrieved by reading from the SpillPoolStream.
136    /// This variant contains no data itself - it's just a signal to the reader
137    /// to fetch the next batch from the spill stream.
138    Spilled,
139}
140
141type MaybeBatch = Option<Result<RepartitionBatch>>;
142type InputPartitionsToCurrentPartitionSender = Vec<DistributionSender<MaybeBatch>>;
143type InputPartitionsToCurrentPartitionReceiver = Vec<DistributionReceiver<MaybeBatch>>;
144
145/// Output channel with its associated memory reservation and spill writer
146struct OutputChannel {
147    sender: DistributionSender<MaybeBatch>,
148    reservation: SharedMemoryReservation,
149    spill_writer: SpillPoolWriter,
150}
151
152/// Channels and resources for a single output partition.
153///
154/// Each output partition has channels to receive data from all input partitions.
155/// To handle memory pressure, each (input, output) pair gets its own
156/// [`SpillPool`](crate::spill::spill_pool) channel via [`spill_pool::channel`].
157///
158/// # Structure
159///
160/// For an output partition receiving from N input partitions:
161/// - `tx`: N senders (one per input) for sending batches to this output
162/// - `rx`: N receivers (one per input) for receiving batches at this output
163/// - `spill_writers`: N spill writers (one per input) for writing spilled data
164/// - `spill_readers`: N spill readers (one per input) for reading spilled data
165///
166/// This 1:1 mapping between input partitions and spill channels ensures that
167/// batches from each input are processed in FIFO order, even when some batches
168/// are spilled to disk and others remain in memory.
169///
170/// See [`RepartitionExec`] for the overall N×M architecture.
171///
172/// [`spill_pool::channel`]: crate::spill::spill_pool::channel
173struct PartitionChannels {
174    /// Senders for each input partition to send data to this output partition
175    tx: InputPartitionsToCurrentPartitionSender,
176    /// Receivers for each input partition sending data to this output partition
177    rx: InputPartitionsToCurrentPartitionReceiver,
178    /// Memory reservation for this output partition
179    reservation: SharedMemoryReservation,
180    /// Spill writers for writing spilled data.
181    /// SpillPoolWriter is Clone, so multiple writers can share state in non-preserve-order mode.
182    spill_writers: Vec<SpillPoolWriter>,
183    /// Spill readers for reading spilled data - one per input partition (FIFO semantics).
184    /// Each (input, output) pair gets its own reader to maintain proper ordering.
185    spill_readers: Vec<SendableRecordBatchStream>,
186}
187
188struct ConsumingInputStreamsState {
189    /// Channels for sending batches from input partitions to output partitions.
190    /// Key is the partition number.
191    channels: HashMap<usize, PartitionChannels>,
192
193    /// Helper that ensures that background jobs are killed once they are no longer needed.
194    abort_helper: Arc<Vec<SpawnedTask<()>>>,
195}
196
197impl Debug for ConsumingInputStreamsState {
198    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
199        f.debug_struct("ConsumingInputStreamsState")
200            .field("num_channels", &self.channels.len())
201            .field("abort_helper", &self.abort_helper)
202            .finish()
203    }
204}
205
206/// Inner state of [`RepartitionExec`].
207#[derive(Default)]
208enum RepartitionExecState {
209    /// Not initialized yet. This is the default state stored in the RepartitionExec node
210    /// upon instantiation.
211    #[default]
212    NotInitialized,
213    /// Input streams are initialized, but they are still not being consumed. The node
214    /// transitions to this state when the arrow's RecordBatch stream is created in
215    /// RepartitionExec::execute(), but before any message is polled.
216    InputStreamsInitialized(Vec<(SendableRecordBatchStream, RepartitionMetrics)>),
217    /// The input streams are being consumed. The node transitions to this state when
218    /// the first message in the arrow's RecordBatch stream is consumed.
219    ConsumingInputStreams(ConsumingInputStreamsState),
220}
221
222impl Debug for RepartitionExecState {
223    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
224        match self {
225            RepartitionExecState::NotInitialized => write!(f, "NotInitialized"),
226            RepartitionExecState::InputStreamsInitialized(v) => {
227                write!(f, "InputStreamsInitialized({:?})", v.len())
228            }
229            RepartitionExecState::ConsumingInputStreams(v) => {
230                write!(f, "ConsumingInputStreams({v:?})")
231            }
232        }
233    }
234}
235
236impl RepartitionExecState {
237    fn ensure_input_streams_initialized(
238        &mut self,
239        input: &Arc<dyn ExecutionPlan>,
240        metrics: &ExecutionPlanMetricsSet,
241        output_partitions: usize,
242        ctx: &Arc<TaskContext>,
243    ) -> Result<()> {
244        if !matches!(self, RepartitionExecState::NotInitialized) {
245            return Ok(());
246        }
247
248        let num_input_partitions = input.output_partitioning().partition_count();
249        let mut streams_and_metrics = Vec::with_capacity(num_input_partitions);
250
251        for i in 0..num_input_partitions {
252            let metrics = RepartitionMetrics::new(i, output_partitions, metrics);
253
254            let timer = metrics.fetch_time.timer();
255            let stream = input.execute(i, Arc::clone(ctx))?;
256            timer.done();
257
258            streams_and_metrics.push((stream, metrics));
259        }
260        *self = RepartitionExecState::InputStreamsInitialized(streams_and_metrics);
261        Ok(())
262    }
263
264    #[expect(clippy::too_many_arguments)]
265    fn consume_input_streams(
266        &mut self,
267        input: &Arc<dyn ExecutionPlan>,
268        metrics: &ExecutionPlanMetricsSet,
269        partitioning: &Partitioning,
270        preserve_order: bool,
271        name: &str,
272        context: &Arc<TaskContext>,
273        spill_manager: SpillManager,
274    ) -> Result<&mut ConsumingInputStreamsState> {
275        let streams_and_metrics = match self {
276            RepartitionExecState::NotInitialized => {
277                self.ensure_input_streams_initialized(
278                    input,
279                    metrics,
280                    partitioning.partition_count(),
281                    context,
282                )?;
283                let RepartitionExecState::InputStreamsInitialized(value) = self else {
284                    // This cannot happen, as ensure_input_streams_initialized() was just called,
285                    // but the compiler does not know.
286                    return internal_err!(
287                        "Programming error: RepartitionExecState must be in the InputStreamsInitialized state after calling RepartitionExecState::ensure_input_streams_initialized"
288                    );
289                };
290                value
291            }
292            RepartitionExecState::ConsumingInputStreams(value) => return Ok(value),
293            RepartitionExecState::InputStreamsInitialized(value) => value,
294        };
295
296        let num_input_partitions = streams_and_metrics.len();
297        let num_output_partitions = partitioning.partition_count();
298
299        let spill_manager = Arc::new(spill_manager);
300
301        let (txs, rxs) = if preserve_order {
302            // Create partition-aware channels with one channel per (input, output) pair
303            // This provides backpressure while maintaining proper ordering
304            let (txs_all, rxs_all) =
305                partition_aware_channels(num_input_partitions, num_output_partitions);
306            // Take transpose of senders and receivers. `state.channels` keeps track of entries per output partition
307            let txs = transpose(txs_all);
308            let rxs = transpose(rxs_all);
309            (txs, rxs)
310        } else {
311            // Create one channel per *output* partition with backpressure
312            let (txs, rxs) = channels(num_output_partitions);
313            // Clone sender for each input partitions
314            let txs = txs
315                .into_iter()
316                .map(|item| vec![item; num_input_partitions])
317                .collect::<Vec<_>>();
318            let rxs = rxs.into_iter().map(|item| vec![item]).collect::<Vec<_>>();
319            (txs, rxs)
320        };
321
322        let mut channels = HashMap::with_capacity(txs.len());
323        for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() {
324            let reservation = Arc::new(Mutex::new(
325                MemoryConsumer::new(format!("{name}[{partition}]"))
326                    .with_can_spill(true)
327                    .register(context.memory_pool()),
328            ));
329
330            // Create spill channels based on mode:
331            // - preserve_order: one spill channel per (input, output) pair for proper FIFO ordering
332            // - non-preserve-order: one shared spill channel per output partition since all inputs
333            //   share the same receiver
334            let max_file_size = context
335                .session_config()
336                .options()
337                .execution
338                .max_spill_file_size_bytes;
339            let num_spill_channels = if preserve_order {
340                num_input_partitions
341            } else {
342                1
343            };
344            let (spill_writers, spill_readers): (Vec<_>, Vec<_>) = (0
345                ..num_spill_channels)
346                .map(|_| spill_pool::channel(max_file_size, Arc::clone(&spill_manager)))
347                .unzip();
348
349            channels.insert(
350                partition,
351                PartitionChannels {
352                    tx,
353                    rx,
354                    reservation,
355                    spill_readers,
356                    spill_writers,
357                },
358            );
359        }
360
361        // launch one async task per *input* partition
362        let mut spawned_tasks = Vec::with_capacity(num_input_partitions);
363        for (i, (stream, metrics)) in
364            std::mem::take(streams_and_metrics).into_iter().enumerate()
365        {
366            let txs: HashMap<_, _> = channels
367                .iter()
368                .map(|(partition, channels)| {
369                    // In preserve_order mode: each input gets its own spill writer (index i)
370                    // In non-preserve-order mode: all inputs share spill writer 0 via clone
371                    let spill_writer_idx = if preserve_order { i } else { 0 };
372                    (
373                        *partition,
374                        OutputChannel {
375                            sender: channels.tx[i].clone(),
376                            reservation: Arc::clone(&channels.reservation),
377                            spill_writer: channels.spill_writers[spill_writer_idx]
378                                .clone(),
379                        },
380                    )
381                })
382                .collect();
383
384            // Extract senders for wait_for_task before moving txs
385            let senders: HashMap<_, _> = txs
386                .iter()
387                .map(|(partition, channel)| (*partition, channel.sender.clone()))
388                .collect();
389
390            let input_task = SpawnedTask::spawn(RepartitionExec::pull_from_input(
391                stream,
392                txs,
393                partitioning.clone(),
394                metrics,
395                // preserve_order depends on partition index to start from 0
396                if preserve_order { 0 } else { i },
397                num_input_partitions,
398            ));
399
400            // In a separate task, wait for each input to be done
401            // (and pass along any errors, including panic!s)
402            let wait_for_task =
403                SpawnedTask::spawn(RepartitionExec::wait_for_task(input_task, senders));
404            spawned_tasks.push(wait_for_task);
405        }
406        *self = Self::ConsumingInputStreams(ConsumingInputStreamsState {
407            channels,
408            abort_helper: Arc::new(spawned_tasks),
409        });
410        match self {
411            RepartitionExecState::ConsumingInputStreams(value) => Ok(value),
412            _ => unreachable!(),
413        }
414    }
415}
416
417/// A utility that can be used to partition batches based on [`Partitioning`]
418pub struct BatchPartitioner {
419    state: BatchPartitionerState,
420    timer: metrics::Time,
421}
422
423enum BatchPartitionerState {
424    Hash {
425        exprs: Vec<Arc<dyn PhysicalExpr>>,
426        num_partitions: usize,
427        hash_buffer: Vec<u64>,
428        indices: Vec<Vec<u32>>,
429    },
430    RoundRobin {
431        num_partitions: usize,
432        next_idx: usize,
433    },
434}
435
436/// Fixed RandomState used for hash repartitioning to ensure consistent behavior across
437/// executions and runs.
438pub const REPARTITION_RANDOM_STATE: SeededRandomState =
439    SeededRandomState::with_seeds(0, 0, 0, 0);
440
441impl BatchPartitioner {
442    /// Create a new [`BatchPartitioner`] for hash-based repartitioning.
443    ///
444    /// # Parameters
445    /// - `exprs`: Expressions used to compute the hash for each input row.
446    /// - `num_partitions`: Total number of output partitions.
447    /// - `timer`: Metric used to record time spent during repartitioning.
448    ///
449    /// # Notes
450    /// This constructor cannot fail and performs no validation.
451    pub fn new_hash_partitioner(
452        exprs: Vec<Arc<dyn PhysicalExpr>>,
453        num_partitions: usize,
454        timer: metrics::Time,
455    ) -> Self {
456        Self {
457            state: BatchPartitionerState::Hash {
458                exprs,
459                num_partitions,
460                hash_buffer: vec![],
461                indices: vec![vec![]; num_partitions],
462            },
463            timer,
464        }
465    }
466
467    /// Create a new [`BatchPartitioner`] for round-robin repartitioning.
468    ///
469    /// # Parameters
470    /// - `num_partitions`: Total number of output partitions.
471    /// - `timer`: Metric used to record time spent during repartitioning.
472    /// - `input_partition`: Index of the current input partition.
473    /// - `num_input_partitions`: Total number of input partitions.
474    ///
475    /// # Notes
476    /// The starting output partition is derived from the input partition
477    /// to avoid skew when multiple input partitions are used.
478    pub fn new_round_robin_partitioner(
479        num_partitions: usize,
480        timer: metrics::Time,
481        input_partition: usize,
482        num_input_partitions: usize,
483    ) -> Self {
484        Self {
485            state: BatchPartitionerState::RoundRobin {
486                num_partitions,
487                next_idx: (input_partition * num_partitions) / num_input_partitions,
488            },
489            timer,
490        }
491    }
492    /// Create a new [`BatchPartitioner`] based on the provided [`Partitioning`] scheme.
493    ///
494    /// This is a convenience constructor that delegates to the specialized
495    /// hash or round-robin constructors depending on the partitioning variant.
496    ///
497    /// # Parameters
498    /// - `partitioning`: Partitioning scheme to apply (hash or round-robin).
499    /// - `timer`: Metric used to record time spent during repartitioning.
500    /// - `input_partition`: Index of the current input partition.
501    /// - `num_input_partitions`: Total number of input partitions.
502    ///
503    /// # Errors
504    /// Returns an error if the provided partitioning scheme is not supported.
505    pub fn try_new(
506        partitioning: Partitioning,
507        timer: metrics::Time,
508        input_partition: usize,
509        num_input_partitions: usize,
510    ) -> Result<Self> {
511        match partitioning {
512            Partitioning::Hash(exprs, num_partitions) => {
513                Ok(Self::new_hash_partitioner(exprs, num_partitions, timer))
514            }
515            Partitioning::RoundRobinBatch(num_partitions) => {
516                Ok(Self::new_round_robin_partitioner(
517                    num_partitions,
518                    timer,
519                    input_partition,
520                    num_input_partitions,
521                ))
522            }
523            other => {
524                not_impl_err!("Unsupported repartitioning scheme {other:?}")
525            }
526        }
527    }
528
529    /// Partition the provided [`RecordBatch`] into one or more partitioned [`RecordBatch`]
530    /// based on the [`Partitioning`] specified on construction
531    ///
532    /// `f` will be called for each partitioned [`RecordBatch`] with the corresponding
533    /// partition index. Any error returned by `f` will be immediately returned by this
534    /// function without attempting to publish further [`RecordBatch`]
535    ///
536    /// The time spent repartitioning, not including time spent in `f` will be recorded
537    /// to the [`metrics::Time`] provided on construction
538    pub fn partition<F>(&mut self, batch: RecordBatch, mut f: F) -> Result<()>
539    where
540        F: FnMut(usize, RecordBatch) -> Result<()>,
541    {
542        self.partition_iter(batch)?.try_for_each(|res| match res {
543            Ok((partition, batch)) => f(partition, batch),
544            Err(e) => Err(e),
545        })
546    }
547
548    /// Actual implementation of [`partition`](Self::partition).
549    ///
550    /// The reason this was pulled out is that we need to have a variant of `partition` that works w/ sync functions,
551    /// and one that works w/ async. Using an iterator as an intermediate representation was the best way to achieve
552    /// this (so we don't need to clone the entire implementation).
553    fn partition_iter(
554        &mut self,
555        batch: RecordBatch,
556    ) -> Result<impl Iterator<Item = Result<(usize, RecordBatch)>> + Send + '_> {
557        let it: Box<dyn Iterator<Item = Result<(usize, RecordBatch)>> + Send> =
558            match &mut self.state {
559                BatchPartitionerState::RoundRobin {
560                    num_partitions,
561                    next_idx,
562                } => {
563                    let idx = *next_idx;
564                    *next_idx = (*next_idx + 1) % *num_partitions;
565                    Box::new(std::iter::once(Ok((idx, batch))))
566                }
567                BatchPartitionerState::Hash {
568                    exprs,
569                    num_partitions: partitions,
570                    hash_buffer,
571                    indices,
572                } => {
573                    // Tracking time required for distributing indexes across output partitions
574                    let timer = self.timer.timer();
575
576                    let arrays =
577                        evaluate_expressions_to_arrays(exprs.as_slice(), &batch)?;
578
579                    hash_buffer.clear();
580                    hash_buffer.resize(batch.num_rows(), 0);
581
582                    create_hashes(
583                        &arrays,
584                        REPARTITION_RANDOM_STATE.random_state(),
585                        hash_buffer,
586                    )?;
587
588                    indices.iter_mut().for_each(|v| v.clear());
589
590                    for (index, hash) in hash_buffer.iter().enumerate() {
591                        indices[(*hash % *partitions as u64) as usize].push(index as u32);
592                    }
593
594                    // Finished building index-arrays for output partitions
595                    timer.done();
596
597                    // Borrowing partitioner timer to prevent moving `self` to closure
598                    let partitioner_timer = &self.timer;
599
600                    let mut partitioned_batches = vec![];
601                    for (partition, p_indices) in indices.iter_mut().enumerate() {
602                        if !p_indices.is_empty() {
603                            let taken_indices = std::mem::take(p_indices);
604                            let indices_array: PrimitiveArray<UInt32Type> =
605                                taken_indices.into();
606
607                            // Tracking time required for repartitioned batches construction
608                            let _timer = partitioner_timer.timer();
609
610                            // Produce batches based on indices
611                            let columns =
612                                take_arrays(batch.columns(), &indices_array, None)?;
613
614                            let mut options = RecordBatchOptions::new();
615                            options = options.with_row_count(Some(indices_array.len()));
616                            let batch = RecordBatch::try_new_with_options(
617                                batch.schema(),
618                                columns,
619                                &options,
620                            )
621                            .unwrap();
622
623                            partitioned_batches.push(Ok((partition, batch)));
624
625                            // Return the taken vec
626                            let (_, buffer, _) = indices_array.into_parts();
627                            let mut vec =
628                                buffer.into_inner().into_vec::<u32>().map_err(|e| {
629                                    internal_datafusion_err!(
630                                        "Could not convert buffer to vec: {e:?}"
631                                    )
632                                })?;
633                            vec.clear();
634                            *p_indices = vec;
635                        }
636                    }
637
638                    Box::new(partitioned_batches.into_iter())
639                }
640            };
641
642        Ok(it)
643    }
644
645    // return the number of output partitions
646    fn num_partitions(&self) -> usize {
647        match self.state {
648            BatchPartitionerState::RoundRobin { num_partitions, .. } => num_partitions,
649            BatchPartitionerState::Hash { num_partitions, .. } => num_partitions,
650        }
651    }
652}
653
654/// Maps `N` input partitions to `M` output partitions based on a
655/// [`Partitioning`] scheme.
656///
657/// # Background
658///
659/// DataFusion, like most other commercial systems, with the
660/// notable exception of DuckDB, uses the "Exchange Operator" based
661/// approach to parallelism which works well in practice given
662/// sufficient care in implementation.
663///
664/// DataFusion's planner picks the target number of partitions and
665/// then [`RepartitionExec`] redistributes [`RecordBatch`]es to that number
666/// of output partitions.
667///
/// For example, given `target_partitions=3` (trying to use 3 cores)
/// but scanning an input with 2 partitions, `RepartitionExec` can be
/// used to produce 3 evenly balanced streams of `RecordBatch`es:
///
673///```text
674///        ▲                  ▲                  ▲
675///        │                  │                  │
676///        │                  │                  │
677///        │                  │                  │
678/// ┌───────────────┐  ┌───────────────┐  ┌───────────────┐
679/// │    GroupBy    │  │    GroupBy    │  │    GroupBy    │
680/// │   (Partial)   │  │   (Partial)   │  │   (Partial)   │
681/// └───────────────┘  └───────────────┘  └───────────────┘
682///        ▲                  ▲                  ▲
683///        └──────────────────┼──────────────────┘
684///                           │
685///              ┌─────────────────────────┐
686///              │     RepartitionExec     │
687///              │   (hash/round robin)    │
688///              └─────────────────────────┘
689///                         ▲   ▲
690///             ┌───────────┘   └───────────┐
691///             │                           │
692///             │                           │
693///        .─────────.                 .─────────.
694///     ,─'           '─.           ,─'           '─.
695///    ;      Input      :         ;      Input      :
696///    :   Partition 0   ;         :   Partition 1   ;
697///     ╲               ╱           ╲               ╱
698///      '─.         ,─'             '─.         ,─'
699///         `───────'                   `───────'
700/// ```
701///
702/// # Error Handling
703///
704/// If any of the input partitions return an error, the error is propagated to
705/// all output partitions and inputs are not polled again.
706///
707/// # Output Ordering
708///
709/// If more than one stream is being repartitioned, the output will be some
710/// arbitrary interleaving (and thus unordered) unless
711/// [`Self::with_preserve_order`] specifies otherwise.
712///
713/// # Spilling Architecture
714///
715/// RepartitionExec uses [`SpillPool`](crate::spill::spill_pool) channels to handle
716/// memory pressure during repartitioning. Each (input partition, output partition)
717/// pair gets its own SpillPool channel for FIFO ordering.
718///
/// ```text
/// Input Partitions (N)          Output Partitions (M)
/// ────────────────────          ─────────────────────
///
///    Input 0 ──┐                        ┌──▶ Output 0
///              │  ┌────────────────┐    │
///              ├─▶│ SpillPool      │────┤
///              │  │ [In0→Out0]     │    │
///    Input 1 ──┤  └────────────────┘    ├──▶ Output 1
///              │                        │
///              │  ┌────────────────┐    │
///              ├─▶│ SpillPool      │────┤
///              │  │ [In1→Out0]     │    │
///    Input 2 ──┤  └────────────────┘    ├──▶ Output 2
///              │                        │
///              │       ... (N×M SpillPools total)
///              │                        │
///              │  ┌────────────────┐    │
///              └─▶│ SpillPool      │────┘
///                 │ [InN-1→OutM-1] │
///                 └────────────────┘
/// ```
///
/// Each SpillPool maintains FIFO order for its (input, output) pair.
/// See `RepartitionBatch` for details on the memory/spill decision logic.
744///
745/// # Footnote
746///
/// The "Exchange Operator" was first described in the 1989 paper
/// [Encapsulation of parallelism in the Volcano query processing
/// system](https://dl.acm.org/doi/pdf/10.1145/93605.98720),
/// which uses the term "Exchange" for the concept of repartitioning
/// data across threads.
///
/// For more background, please also see the [Optimizing Repartitions in DataFusion] blog post.
754///
755/// [Optimizing Repartitions in DataFusion]: https://datafusion.apache.org/blog/2025/12/15/avoid-consecutive-repartitions
#[derive(Debug, Clone)]
pub struct RepartitionExec {
    /// Input execution plan whose partitions are redistributed.
    input: Arc<dyn ExecutionPlan>,
    /// Inner state that is initialized when the parent calls `.execute()` on
    /// this node and consumed as soon as the parent starts consuming this node.
    ///
    /// It sits behind an `Arc<Mutex<_>>` because every `execute` call (one per
    /// output partition) locks the same state and takes only the channels of
    /// its own output partition.
    state: Arc<Mutex<RepartitionExecState>>,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Whether to preserve ordering. `true` means the node merges its input
    /// partitions in sort order (`SortPreservingRepartitionExec` behavior);
    /// `false` means a plain `RepartitionExec`.
    preserve_order: bool,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: Arc<PlanProperties>,
}
771
/// Timing metrics recorded by the task that pulls batches from a single input
/// partition and distributes them to the output channels
/// (see `RepartitionExec::pull_from_input`).
#[derive(Debug, Clone)]
struct RepartitionMetrics {
    /// Time in nanos to execute child operator and fetch batches
    fetch_time: metrics::Time,
    /// Repartitioning elapsed time in nanos
    repartition_time: metrics::Time,
    /// Time in nanos for sending resulting batches to channels.
    ///
    /// One metric per output partition, indexed by the output partition number.
    send_time: Vec<metrics::Time>,
}
783
784impl RepartitionMetrics {
785    pub fn new(
786        input_partition: usize,
787        num_output_partitions: usize,
788        metrics: &ExecutionPlanMetricsSet,
789    ) -> Self {
790        // Time in nanos to execute child operator and fetch batches
791        let fetch_time =
792            MetricBuilder::new(metrics).subset_time("fetch_time", input_partition);
793
794        // Time in nanos to perform repartitioning
795        let repartition_time =
796            MetricBuilder::new(metrics).subset_time("repartition_time", input_partition);
797
798        // Time in nanos for sending resulting batches to channels
799        let send_time = (0..num_output_partitions)
800            .map(|output_partition| {
801                let label =
802                    metrics::Label::new("outputPartition", output_partition.to_string());
803                MetricBuilder::new(metrics)
804                    .with_label(label)
805                    .subset_time("send_time", input_partition)
806            })
807            .collect();
808
809        Self {
810            fetch_time,
811            repartition_time,
812            send_time,
813        }
814    }
815}
816
817impl RepartitionExec {
818    /// Input execution plan
819    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
820        &self.input
821    }
822
823    /// Partitioning scheme to use
824    pub fn partitioning(&self) -> &Partitioning {
825        &self.cache.partitioning
826    }
827
828    /// Get preserve_order flag of the RepartitionExec
829    /// `true` means `SortPreservingRepartitionExec`, `false` means `RepartitionExec`
830    pub fn preserve_order(&self) -> bool {
831        self.preserve_order
832    }
833
834    /// Get name used to display this Exec
835    pub fn name(&self) -> &str {
836        "RepartitionExec"
837    }
838
839    fn with_new_children_and_same_properties(
840        &self,
841        mut children: Vec<Arc<dyn ExecutionPlan>>,
842    ) -> Self {
843        Self {
844            input: children.swap_remove(0),
845            metrics: ExecutionPlanMetricsSet::new(),
846            state: Default::default(),
847            ..Self::clone(self)
848        }
849    }
850}
851
852impl DisplayAs for RepartitionExec {
853    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
854        let input_partition_count = self.input.output_partitioning().partition_count();
855        match t {
856            DisplayFormatType::Default | DisplayFormatType::Verbose => {
857                write!(
858                    f,
859                    "{}: partitioning={}, input_partitions={}",
860                    self.name(),
861                    self.partitioning(),
862                    input_partition_count,
863                )?;
864
865                if self.preserve_order {
866                    write!(f, ", preserve_order=true")?;
867                } else if input_partition_count <= 1
868                    && self.input.output_ordering().is_some()
869                {
870                    // Make it explicit that repartition maintains sortedness for a single input partition even
871                    // when `preserve_sort order` is false
872                    write!(f, ", maintains_sort_order=true")?;
873                }
874
875                if let Some(sort_exprs) = self.sort_exprs() {
876                    write!(f, ", sort_exprs={}", sort_exprs.clone())?;
877                }
878                Ok(())
879            }
880            DisplayFormatType::TreeRender => {
881                writeln!(f, "partitioning_scheme={}", self.partitioning(),)?;
882                let output_partition_count = self.partitioning().partition_count();
883                let input_to_output_partition_str =
884                    format!("{input_partition_count} -> {output_partition_count}");
885                writeln!(
886                    f,
887                    "partition_count(in->out)={input_to_output_partition_str}"
888                )?;
889
890                if self.preserve_order {
891                    writeln!(f, "preserve_order={}", self.preserve_order)?;
892                }
893                Ok(())
894            }
895        }
896    }
897}
898
899impl ExecutionPlan for RepartitionExec {
900    fn name(&self) -> &'static str {
901        "RepartitionExec"
902    }
903
904    /// Return a reference to Any that can be used for downcasting
905    fn as_any(&self) -> &dyn Any {
906        self
907    }
908
909    fn properties(&self) -> &Arc<PlanProperties> {
910        &self.cache
911    }
912
913    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
914        vec![&self.input]
915    }
916
917    fn with_new_children(
918        self: Arc<Self>,
919        mut children: Vec<Arc<dyn ExecutionPlan>>,
920    ) -> Result<Arc<dyn ExecutionPlan>> {
921        check_if_same_properties!(self, children);
922        let mut repartition = RepartitionExec::try_new(
923            children.swap_remove(0),
924            self.partitioning().clone(),
925        )?;
926        if self.preserve_order {
927            repartition = repartition.with_preserve_order();
928        }
929        Ok(Arc::new(repartition))
930    }
931
932    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
933        vec![matches!(self.partitioning(), Partitioning::Hash(_, _))]
934    }
935
936    fn maintains_input_order(&self) -> Vec<bool> {
937        Self::maintains_input_order_helper(self.input(), self.preserve_order)
938    }
939
940    fn execute(
941        &self,
942        partition: usize,
943        context: Arc<TaskContext>,
944    ) -> Result<SendableRecordBatchStream> {
945        trace!(
946            "Start {}::execute for partition: {}",
947            self.name(),
948            partition
949        );
950
951        let spill_metrics = SpillMetrics::new(&self.metrics, partition);
952
953        let input = Arc::clone(&self.input);
954        let partitioning = self.partitioning().clone();
955        let metrics = self.metrics.clone();
956        let preserve_order = self.sort_exprs().is_some();
957        let name = self.name().to_owned();
958        let schema = self.schema();
959        let schema_captured = Arc::clone(&schema);
960
961        let spill_manager = SpillManager::new(
962            Arc::clone(&context.runtime_env()),
963            spill_metrics,
964            input.schema(),
965        );
966
967        // Get existing ordering to use for merging
968        let sort_exprs = self.sort_exprs().cloned();
969
970        let state = Arc::clone(&self.state);
971        if let Some(mut state) = state.try_lock() {
972            state.ensure_input_streams_initialized(
973                &input,
974                &metrics,
975                partitioning.partition_count(),
976                &context,
977            )?;
978        }
979
980        let num_input_partitions = input.output_partitioning().partition_count();
981
982        let stream = futures::stream::once(async move {
983            // lock scope
984            let (rx, reservation, spill_readers, abort_helper) = {
985                // lock mutexes
986                let mut state = state.lock();
987                let state = state.consume_input_streams(
988                    &input,
989                    &metrics,
990                    &partitioning,
991                    preserve_order,
992                    &name,
993                    &context,
994                    spill_manager.clone(),
995                )?;
996
997                // now return stream for the specified *output* partition which will
998                // read from the channel
999                let PartitionChannels {
1000                    rx,
1001                    reservation,
1002                    spill_readers,
1003                    ..
1004                } = state
1005                    .channels
1006                    .remove(&partition)
1007                    .expect("partition not used yet");
1008
1009                (
1010                    rx,
1011                    reservation,
1012                    spill_readers,
1013                    Arc::clone(&state.abort_helper),
1014                )
1015            };
1016
1017            trace!(
1018                "Before returning stream in {name}::execute for partition: {partition}"
1019            );
1020
1021            if preserve_order {
1022                // Store streams from all the input partitions:
1023                // Each input partition gets its own spill reader to maintain proper FIFO ordering
1024                let input_streams = rx
1025                    .into_iter()
1026                    .zip(spill_readers)
1027                    .map(|(receiver, spill_stream)| {
1028                        // In preserve_order mode, each receiver corresponds to exactly one input partition
1029                        Box::pin(PerPartitionStream::new(
1030                            Arc::clone(&schema_captured),
1031                            receiver,
1032                            Arc::clone(&abort_helper),
1033                            Arc::clone(&reservation),
1034                            spill_stream,
1035                            1, // Each receiver handles one input partition
1036                            BaselineMetrics::new(&metrics, partition),
1037                            None, // subsequent merge sort already does batching https://github.com/apache/datafusion/blob/e4dcf0c85611ad0bd291f03a8e03fe56d773eb16/datafusion/physical-plan/src/sorts/merge.rs#L286
1038                        )) as SendableRecordBatchStream
1039                    })
1040                    .collect::<Vec<_>>();
1041                // Note that receiver size (`rx.len()`) and `num_input_partitions` are same.
1042
1043                // Merge streams (while preserving ordering) coming from
1044                // input partitions to this partition:
1045                let fetch = None;
1046                let merge_reservation =
1047                    MemoryConsumer::new(format!("{name}[Merge {partition}]"))
1048                        .register(context.memory_pool());
1049                StreamingMergeBuilder::new()
1050                    .with_streams(input_streams)
1051                    .with_schema(schema_captured)
1052                    .with_expressions(&sort_exprs.unwrap())
1053                    .with_metrics(BaselineMetrics::new(&metrics, partition))
1054                    .with_batch_size(context.session_config().batch_size())
1055                    .with_fetch(fetch)
1056                    .with_reservation(merge_reservation)
1057                    .with_spill_manager(spill_manager)
1058                    .build()
1059            } else {
1060                // Non-preserve-order case: single input stream, so use the first spill reader
1061                let spill_stream = spill_readers
1062                    .into_iter()
1063                    .next()
1064                    .expect("at least one spill reader should exist");
1065
1066                Ok(Box::pin(PerPartitionStream::new(
1067                    schema_captured,
1068                    rx.into_iter()
1069                        .next()
1070                        .expect("at least one receiver should exist"),
1071                    abort_helper,
1072                    reservation,
1073                    spill_stream,
1074                    num_input_partitions,
1075                    BaselineMetrics::new(&metrics, partition),
1076                    Some(context.session_config().batch_size()),
1077                )) as SendableRecordBatchStream)
1078            }
1079        })
1080        .try_flatten();
1081        let stream = RecordBatchStreamAdapter::new(schema, stream);
1082        Ok(Box::pin(stream))
1083    }
1084
1085    fn metrics(&self) -> Option<MetricsSet> {
1086        Some(self.metrics.clone_inner())
1087    }
1088
1089    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
1090        if let Some(partition) = partition {
1091            let partition_count = self.partitioning().partition_count();
1092            if partition_count == 0 {
1093                return Ok(Statistics::new_unknown(&self.schema()));
1094            }
1095
1096            assert_or_internal_err!(
1097                partition < partition_count,
1098                "RepartitionExec invalid partition {} (expected less than {})",
1099                partition,
1100                partition_count
1101            );
1102
1103            let mut stats = self.input.partition_statistics(None)?;
1104
1105            // Distribute statistics across partitions
1106            stats.num_rows = stats
1107                .num_rows
1108                .get_value()
1109                .map(|rows| Precision::Inexact(rows / partition_count))
1110                .unwrap_or(Precision::Absent);
1111            stats.total_byte_size = stats
1112                .total_byte_size
1113                .get_value()
1114                .map(|bytes| Precision::Inexact(bytes / partition_count))
1115                .unwrap_or(Precision::Absent);
1116
1117            // Make all column stats unknown
1118            stats.column_statistics = stats
1119                .column_statistics
1120                .iter()
1121                .map(|_| ColumnStatistics::new_unknown())
1122                .collect();
1123
1124            Ok(stats)
1125        } else {
1126            self.input.partition_statistics(None)
1127        }
1128    }
1129
1130    fn cardinality_effect(&self) -> CardinalityEffect {
1131        CardinalityEffect::Equal
1132    }
1133
1134    fn try_swapping_with_projection(
1135        &self,
1136        projection: &ProjectionExec,
1137    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
1138        // If the projection does not narrow the schema, we should not try to push it down.
1139        if projection.expr().len() >= projection.input().schema().fields().len() {
1140            return Ok(None);
1141        }
1142
1143        // If pushdown is not beneficial or applicable, break it.
1144        if projection.benefits_from_input_partitioning()[0]
1145            || !all_columns(projection.expr())
1146        {
1147            return Ok(None);
1148        }
1149
1150        let new_projection = make_with_child(projection, self.input())?;
1151
1152        let new_partitioning = match self.partitioning() {
1153            Partitioning::Hash(partitions, size) => {
1154                let mut new_partitions = vec![];
1155                for partition in partitions {
1156                    let Some(new_partition) =
1157                        update_expr(partition, projection.expr(), false)?
1158                    else {
1159                        return Ok(None);
1160                    };
1161                    new_partitions.push(new_partition);
1162                }
1163                Partitioning::Hash(new_partitions, *size)
1164            }
1165            others => others.clone(),
1166        };
1167
1168        Ok(Some(Arc::new(RepartitionExec::try_new(
1169            new_projection,
1170            new_partitioning,
1171        )?)))
1172    }
1173
1174    fn gather_filters_for_pushdown(
1175        &self,
1176        _phase: FilterPushdownPhase,
1177        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
1178        _config: &ConfigOptions,
1179    ) -> Result<FilterDescription> {
1180        FilterDescription::from_children(parent_filters, &self.children())
1181    }
1182
1183    fn handle_child_pushdown_result(
1184        &self,
1185        _phase: FilterPushdownPhase,
1186        child_pushdown_result: ChildPushdownResult,
1187        _config: &ConfigOptions,
1188    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
1189        Ok(FilterPushdownPropagation::if_all(child_pushdown_result))
1190    }
1191
1192    fn try_pushdown_sort(
1193        &self,
1194        order: &[PhysicalSortExpr],
1195    ) -> Result<SortOrderPushdownResult<Arc<dyn ExecutionPlan>>> {
1196        // RepartitionExec only maintains input order if preserve_order is set
1197        // or if there's only one partition
1198        if !self.maintains_input_order()[0] {
1199            return Ok(SortOrderPushdownResult::Unsupported);
1200        }
1201
1202        // Delegate to the child and wrap with a new RepartitionExec
1203        self.input.try_pushdown_sort(order)?.try_map(|new_input| {
1204            let mut new_repartition =
1205                RepartitionExec::try_new(new_input, self.partitioning().clone())?;
1206            if self.preserve_order {
1207                new_repartition = new_repartition.with_preserve_order();
1208            }
1209            Ok(Arc::new(new_repartition) as Arc<dyn ExecutionPlan>)
1210        })
1211    }
1212
1213    fn repartitioned(
1214        &self,
1215        target_partitions: usize,
1216        _config: &ConfigOptions,
1217    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
1218        use Partitioning::*;
1219        let mut new_properties = PlanProperties::clone(&self.cache);
1220        new_properties.partitioning = match new_properties.partitioning {
1221            RoundRobinBatch(_) => RoundRobinBatch(target_partitions),
1222            Hash(hash, _) => Hash(hash, target_partitions),
1223            UnknownPartitioning(_) => UnknownPartitioning(target_partitions),
1224        };
1225        Ok(Some(Arc::new(Self {
1226            input: Arc::clone(&self.input),
1227            state: Arc::clone(&self.state),
1228            metrics: self.metrics.clone(),
1229            preserve_order: self.preserve_order,
1230            cache: new_properties.into(),
1231        })))
1232    }
1233}
1234
1235impl RepartitionExec {
1236    /// Create a new RepartitionExec, that produces output `partitioning`, and
1237    /// does not preserve the order of the input (see [`Self::with_preserve_order`]
1238    /// for more details)
1239    pub fn try_new(
1240        input: Arc<dyn ExecutionPlan>,
1241        partitioning: Partitioning,
1242    ) -> Result<Self> {
1243        let preserve_order = false;
1244        let cache = Self::compute_properties(&input, partitioning, preserve_order);
1245        Ok(RepartitionExec {
1246            input,
1247            state: Default::default(),
1248            metrics: ExecutionPlanMetricsSet::new(),
1249            preserve_order,
1250            cache: Arc::new(cache),
1251        })
1252    }
1253
1254    fn maintains_input_order_helper(
1255        input: &Arc<dyn ExecutionPlan>,
1256        preserve_order: bool,
1257    ) -> Vec<bool> {
1258        // We preserve ordering when repartition is order preserving variant or input partitioning is 1
1259        vec![preserve_order || input.output_partitioning().partition_count() <= 1]
1260    }
1261
1262    fn eq_properties_helper(
1263        input: &Arc<dyn ExecutionPlan>,
1264        preserve_order: bool,
1265    ) -> EquivalenceProperties {
1266        // Equivalence Properties
1267        let mut eq_properties = input.equivalence_properties().clone();
1268        // If the ordering is lost, reset the ordering equivalence class:
1269        if !Self::maintains_input_order_helper(input, preserve_order)[0] {
1270            eq_properties.clear_orderings();
1271        }
1272        // When there are more than one input partitions, they will be fused at the output.
1273        // Therefore, remove per partition constants.
1274        if input.output_partitioning().partition_count() > 1 {
1275            eq_properties.clear_per_partition_constants();
1276        }
1277        eq_properties
1278    }
1279
1280    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
1281    fn compute_properties(
1282        input: &Arc<dyn ExecutionPlan>,
1283        partitioning: Partitioning,
1284        preserve_order: bool,
1285    ) -> PlanProperties {
1286        PlanProperties::new(
1287            Self::eq_properties_helper(input, preserve_order),
1288            partitioning,
1289            input.pipeline_behavior(),
1290            input.boundedness(),
1291        )
1292        .with_scheduling_type(SchedulingType::Cooperative)
1293        .with_evaluation_type(EvaluationType::Eager)
1294    }
1295
1296    /// Specify if this repartitioning operation should preserve the order of
1297    /// rows from its input when producing output. Preserving order is more
1298    /// expensive at runtime, so should only be set if the output of this
1299    /// operator can take advantage of it.
1300    ///
1301    /// If the input is not ordered, or has only one partition, this is a no op,
1302    /// and the node remains a `RepartitionExec`.
1303    pub fn with_preserve_order(mut self) -> Self {
1304        self.preserve_order =
1305                // If the input isn't ordered, there is no ordering to preserve
1306                self.input.output_ordering().is_some() &&
1307                // if there is only one input partition, merging is not required
1308                // to maintain order
1309                self.input.output_partitioning().partition_count() > 1;
1310        let eq_properties = Self::eq_properties_helper(&self.input, self.preserve_order);
1311        Arc::make_mut(&mut self.cache).set_eq_properties(eq_properties);
1312        self
1313    }
1314
1315    /// Return the sort expressions that are used to merge
1316    fn sort_exprs(&self) -> Option<&LexOrdering> {
1317        if self.preserve_order {
1318            self.input.output_ordering()
1319        } else {
1320            None
1321        }
1322    }
1323
1324    /// Pulls data from the specified input plan, feeding it to the
1325    /// output partitions based on the desired partitioning
1326    ///
1327    /// `output_channels` holds the output sending channels for each output partition
1328    async fn pull_from_input(
1329        mut stream: SendableRecordBatchStream,
1330        mut output_channels: HashMap<usize, OutputChannel>,
1331        partitioning: Partitioning,
1332        metrics: RepartitionMetrics,
1333        input_partition: usize,
1334        num_input_partitions: usize,
1335    ) -> Result<()> {
1336        let mut partitioner = match &partitioning {
1337            Partitioning::Hash(exprs, num_partitions) => {
1338                BatchPartitioner::new_hash_partitioner(
1339                    exprs.clone(),
1340                    *num_partitions,
1341                    metrics.repartition_time.clone(),
1342                )
1343            }
1344            Partitioning::RoundRobinBatch(num_partitions) => {
1345                BatchPartitioner::new_round_robin_partitioner(
1346                    *num_partitions,
1347                    metrics.repartition_time.clone(),
1348                    input_partition,
1349                    num_input_partitions,
1350                )
1351            }
1352            other => {
1353                return not_impl_err!("Unsupported repartitioning scheme {other:?}");
1354            }
1355        };
1356
1357        // While there are still outputs to send to, keep pulling inputs
1358        let mut batches_until_yield = partitioner.num_partitions();
1359        while !output_channels.is_empty() {
1360            // fetch the next batch
1361            let timer = metrics.fetch_time.timer();
1362            let result = stream.next().await;
1363            timer.done();
1364
1365            // Input is done
1366            let batch = match result {
1367                Some(result) => result?,
1368                None => break,
1369            };
1370
1371            // Handle empty batch
1372            if batch.num_rows() == 0 {
1373                continue;
1374            }
1375
1376            for res in partitioner.partition_iter(batch)? {
1377                let (partition, batch) = res?;
1378                let size = batch.get_array_memory_size();
1379
1380                let timer = metrics.send_time[partition].timer();
1381                // if there is still a receiver, send to it
1382                if let Some(channel) = output_channels.get_mut(&partition) {
1383                    let (batch_to_send, is_memory_batch) =
1384                        match channel.reservation.lock().try_grow(size) {
1385                            Ok(_) => {
1386                                // Memory available - send in-memory batch
1387                                (RepartitionBatch::Memory(batch), true)
1388                            }
1389                            Err(_) => {
1390                                // We're memory limited - spill to SpillPool
1391                                // SpillPool handles file handle reuse and rotation
1392                                channel.spill_writer.push_batch(&batch)?;
1393                                // Send marker indicating batch was spilled
1394                                (RepartitionBatch::Spilled, false)
1395                            }
1396                        };
1397
1398                    if channel.sender.send(Some(Ok(batch_to_send))).await.is_err() {
1399                        // If the other end has hung up, it was an early shutdown (e.g. LIMIT)
1400                        // Only shrink memory if it was a memory batch
1401                        if is_memory_batch {
1402                            channel.reservation.lock().shrink(size);
1403                        }
1404                        output_channels.remove(&partition);
1405                    }
1406                }
1407                timer.done();
1408            }
1409
1410            // If the input stream is endless, we may spin forever and
1411            // never yield back to tokio.  See
1412            // https://github.com/apache/datafusion/issues/5278.
1413            //
1414            // However, yielding on every batch causes a bottleneck
1415            // when running with multiple cores. See
1416            // https://github.com/apache/datafusion/issues/6290
1417            //
1418            // Thus, heuristically yield after producing num_partition
1419            // batches
1420            //
1421            // In round robin this is ideal as each input will get a
1422            // new batch. In hash partitioning it may yield too often
1423            // on uneven distributions even if some partition can not
1424            // make progress, but parallelism is going to be limited
1425            // in that case anyways
1426            if batches_until_yield == 0 {
1427                tokio::task::yield_now().await;
1428                batches_until_yield = partitioner.num_partitions();
1429            } else {
1430                batches_until_yield -= 1;
1431            }
1432        }
1433
1434        // Spill writers will auto-finalize when dropped
1435        // No need for explicit flush
1436        Ok(())
1437    }
1438
1439    /// Waits for `input_task` which is consuming one of the inputs to
1440    /// complete. Upon each successful completion, sends a `None` to
1441    /// each of the output tx channels to signal one of the inputs is
1442    /// complete. Upon error, propagates the errors to all output tx
1443    /// channels.
1444    async fn wait_for_task(
1445        input_task: SpawnedTask<Result<()>>,
1446        txs: HashMap<usize, DistributionSender<MaybeBatch>>,
1447    ) {
1448        // wait for completion, and propagate error
1449        // note we ignore errors on send (.ok) as that means the receiver has already shutdown.
1450
1451        match input_task.join().await {
1452            // Error in joining task
1453            Err(e) => {
1454                let e = Arc::new(e);
1455
1456                for (_, tx) in txs {
1457                    let err = Err(DataFusionError::Context(
1458                        "Join Error".to_string(),
1459                        Box::new(DataFusionError::External(Box::new(Arc::clone(&e)))),
1460                    ));
1461                    tx.send(Some(err)).await.ok();
1462                }
1463            }
1464            // Error from running input task
1465            Ok(Err(e)) => {
1466                // send the same Arc'd error to all output partitions
1467                let e = Arc::new(e);
1468
1469                for (_, tx) in txs {
1470                    // wrap it because need to send error to all output partitions
1471                    let err = Err(DataFusionError::from(&e));
1472                    tx.send(Some(err)).await.ok();
1473                }
1474            }
1475            // Input task completed successfully
1476            Ok(Ok(())) => {
1477                // notify each output partition that this input partition has no more data
1478                for (_partition, tx) in txs {
1479                    tx.send(None).await.ok();
1480                }
1481            }
1482        }
1483    }
1484}
1485
/// State for tracking whether we're reading from the memory channel or the spill stream.
///
/// This state machine ensures proper ordering when batches are mixed between memory
/// and spilled storage. When a [`RepartitionBatch::Spilled`] marker is received,
/// the stream must block on the spill stream until the corresponding batch arrives.
///
/// # State Machine
///
/// ```text
///                        ┌─────────────────┐
///                   ┌───▶│  ReadingMemory  │◀───┐
///                   │    └────────┬────────┘    │
///                   │             │             │
///                   │     Poll channel          │
///                   │             │             │
///                   │  ┌──────────┼─────────────┐
///                   │  │          │             │
///                   │  ▼          ▼             │
///                   │ Memory   Spilled          │
///       Got batch   │ batch    marker           │
///       from spill  │  │          │             │
///                   │  │          ▼             │
///                   │  │  ┌──────────────────┐  │
///                   │  │  │ ReadingSpilled   │  │
///                   │  │  └────────┬─────────┘  │
///                   │  │           │            │
///                   │  │   Poll spill_stream    │
///                   │  │           │            │
///                   │  │           ▼            │
///                   │  │      Get batch         │
///                   │  │           │            │
///                   └──┴───────────┴────────────┘
///                                  │
///                                  ▼
///                           Return batch
///                     (Order preserved within
///                      (input, output) pair)
/// ```
///
/// The transition to `ReadingSpilled` blocks further channel polling to maintain
/// FIFO ordering - we cannot read the next item from the channel until the spill
/// stream provides the current batch.
///
/// # Transitions not shown in the diagram
///
/// The diagram covers only the common path. The reader also handles these cases:
///
/// * `ReadingSpilled`, spill stream ends: switch back to `ReadingMemory` and keep
///   draining the channel. No batch is yielded for that marker.
/// * `ReadingSpilled`, spill stream yields an error: the error is returned and
///   the state stays `ReadingSpilled`.
/// * `ReadingMemory`, channel yields an error: the error is returned and the
///   state stays `ReadingMemory`.
/// * `ReadingMemory`, an input partition's end-of-data marker (`None`) arrives:
///   the stream finishes once every input partition has sent one. A closed
///   channel also finishes the stream.
/// * Either source returns `Poll::Pending`: the stream returns `Poll::Pending`
///   and keeps its current state.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum StreamState {
    /// Reading from the memory channel (normal operation)
    ReadingMemory,
    /// Waiting for a spilled batch from the spill stream.
    /// Must not poll channel until spilled batch is received to preserve ordering.
    ReadingSpilled,
}
1536
1537/// This struct converts a receiver to a stream.
1538/// Receiver receives data on an SPSC channel.
1539struct PerPartitionStream {
1540    /// Schema wrapped by Arc
1541    schema: SchemaRef,
1542
1543    /// channel containing the repartitioned batches
1544    receiver: DistributionReceiver<MaybeBatch>,
1545
1546    /// Handle to ensure background tasks are killed when no longer needed.
1547    _drop_helper: Arc<Vec<SpawnedTask<()>>>,
1548
1549    /// Memory reservation.
1550    reservation: SharedMemoryReservation,
1551
1552    /// Infinite stream for reading from the spill pool
1553    spill_stream: SendableRecordBatchStream,
1554
1555    /// Internal state indicating if we are reading from memory or spill stream
1556    state: StreamState,
1557
1558    /// Number of input partitions that have not yet finished.
1559    /// In non-preserve-order mode, multiple input partitions send to the same channel,
1560    /// each sending None when complete. We must wait for all of them.
1561    remaining_partitions: usize,
1562
1563    /// Execution metrics
1564    baseline_metrics: BaselineMetrics,
1565
1566    /// None for sort preserving variant (merge sort already does coalescing)
1567    batch_coalescer: Option<LimitedBatchCoalescer>,
1568}
1569
1570impl PerPartitionStream {
1571    #[expect(clippy::too_many_arguments)]
1572    fn new(
1573        schema: SchemaRef,
1574        receiver: DistributionReceiver<MaybeBatch>,
1575        drop_helper: Arc<Vec<SpawnedTask<()>>>,
1576        reservation: SharedMemoryReservation,
1577        spill_stream: SendableRecordBatchStream,
1578        num_input_partitions: usize,
1579        baseline_metrics: BaselineMetrics,
1580        batch_size: Option<usize>,
1581    ) -> Self {
1582        let batch_coalescer =
1583            batch_size.map(|s| LimitedBatchCoalescer::new(Arc::clone(&schema), s, None));
1584        Self {
1585            schema,
1586            receiver,
1587            _drop_helper: drop_helper,
1588            reservation,
1589            spill_stream,
1590            state: StreamState::ReadingMemory,
1591            remaining_partitions: num_input_partitions,
1592            baseline_metrics,
1593            batch_coalescer,
1594        }
1595    }
1596
1597    fn poll_next_inner(
1598        self: &mut Pin<&mut Self>,
1599        cx: &mut Context<'_>,
1600    ) -> Poll<Option<Result<RecordBatch>>> {
1601        use futures::StreamExt;
1602        let cloned_time = self.baseline_metrics.elapsed_compute().clone();
1603        let _timer = cloned_time.timer();
1604
1605        loop {
1606            match self.state {
1607                StreamState::ReadingMemory => {
1608                    // Poll the memory channel for next message
1609                    let value = match self.receiver.recv().poll_unpin(cx) {
1610                        Poll::Ready(v) => v,
1611                        Poll::Pending => {
1612                            // Nothing from channel, wait
1613                            return Poll::Pending;
1614                        }
1615                    };
1616
1617                    match value {
1618                        Some(Some(v)) => match v {
1619                            Ok(RepartitionBatch::Memory(batch)) => {
1620                                // Release memory and return batch
1621                                self.reservation
1622                                    .lock()
1623                                    .shrink(batch.get_array_memory_size());
1624                                return Poll::Ready(Some(Ok(batch)));
1625                            }
1626                            Ok(RepartitionBatch::Spilled) => {
1627                                // Batch was spilled, transition to reading from spill stream
1628                                // We must block on spill stream until we get the batch
1629                                // to preserve ordering
1630                                self.state = StreamState::ReadingSpilled;
1631                                continue;
1632                            }
1633                            Err(e) => {
1634                                return Poll::Ready(Some(Err(e)));
1635                            }
1636                        },
1637                        Some(None) => {
1638                            // One input partition finished
1639                            self.remaining_partitions -= 1;
1640                            if self.remaining_partitions == 0 {
1641                                // All input partitions finished
1642                                return Poll::Ready(None);
1643                            }
1644                            // Continue to poll for more data from other partitions
1645                            continue;
1646                        }
1647                        None => {
1648                            // Channel closed unexpectedly
1649                            return Poll::Ready(None);
1650                        }
1651                    }
1652                }
1653                StreamState::ReadingSpilled => {
1654                    // Poll spill stream for the spilled batch
1655                    match self.spill_stream.poll_next_unpin(cx) {
1656                        Poll::Ready(Some(Ok(batch))) => {
1657                            self.state = StreamState::ReadingMemory;
1658                            return Poll::Ready(Some(Ok(batch)));
1659                        }
1660                        Poll::Ready(Some(Err(e))) => {
1661                            return Poll::Ready(Some(Err(e)));
1662                        }
1663                        Poll::Ready(None) => {
1664                            // Spill stream ended, keep draining the memory channel
1665                            self.state = StreamState::ReadingMemory;
1666                        }
1667                        Poll::Pending => {
1668                            // Spilled batch not ready yet, must wait
1669                            // This preserves ordering by blocking until spill data arrives
1670                            return Poll::Pending;
1671                        }
1672                    }
1673                }
1674            }
1675        }
1676    }
1677
1678    fn poll_next_and_coalesce(
1679        self: &mut Pin<&mut Self>,
1680        cx: &mut Context<'_>,
1681        coalescer: &mut LimitedBatchCoalescer,
1682    ) -> Poll<Option<Result<RecordBatch>>> {
1683        let cloned_time = self.baseline_metrics.elapsed_compute().clone();
1684        let mut completed = false;
1685
1686        loop {
1687            if let Some(batch) = coalescer.next_completed_batch() {
1688                return Poll::Ready(Some(Ok(batch)));
1689            }
1690            if completed {
1691                return Poll::Ready(None);
1692            }
1693
1694            match ready!(self.poll_next_inner(cx)) {
1695                Some(Ok(batch)) => {
1696                    let _timer = cloned_time.timer();
1697                    if let Err(err) = coalescer.push_batch(batch) {
1698                        return Poll::Ready(Some(Err(err)));
1699                    }
1700                }
1701                Some(err) => {
1702                    return Poll::Ready(Some(err));
1703                }
1704                None => {
1705                    completed = true;
1706                    let _timer = cloned_time.timer();
1707                    if let Err(err) = coalescer.finish() {
1708                        return Poll::Ready(Some(Err(err)));
1709                    }
1710                }
1711            }
1712        }
1713    }
1714}
1715
1716impl Stream for PerPartitionStream {
1717    type Item = Result<RecordBatch>;
1718
1719    fn poll_next(
1720        mut self: Pin<&mut Self>,
1721        cx: &mut Context<'_>,
1722    ) -> Poll<Option<Self::Item>> {
1723        let poll;
1724        if let Some(mut coalescer) = self.batch_coalescer.take() {
1725            poll = self.poll_next_and_coalesce(cx, &mut coalescer);
1726            self.batch_coalescer = Some(coalescer);
1727        } else {
1728            poll = self.poll_next_inner(cx);
1729        }
1730        self.baseline_metrics.record_poll(poll)
1731    }
1732}
1733
1734impl RecordBatchStream for PerPartitionStream {
1735    /// Get the schema
1736    fn schema(&self) -> SchemaRef {
1737        Arc::clone(&self.schema)
1738    }
1739}
1740
1741#[cfg(test)]
1742mod tests {
1743    use std::collections::HashSet;
1744
1745    use super::*;
1746    use crate::test::TestMemoryExec;
1747    use crate::{
1748        test::{
1749            assert_is_pending,
1750            exec::{
1751                BarrierExec, BlockingExec, ErrorExec, MockExec,
1752                assert_strong_count_converges_to_zero,
1753            },
1754        },
1755        {collect, expressions::col},
1756    };
1757
1758    use arrow::array::{ArrayRef, StringArray, UInt32Array};
1759    use arrow::datatypes::{DataType, Field, Schema};
1760    use datafusion_common::cast::as_string_array;
1761    use datafusion_common::exec_err;
1762    use datafusion_common::test_util::batches_to_sort_string;
1763    use datafusion_common_runtime::JoinSet;
1764    use datafusion_execution::config::SessionConfig;
1765    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
1766    use insta::assert_snapshot;
1767
1768    #[tokio::test]
1769    async fn one_to_many_round_robin() -> Result<()> {
1770        // define input partitions
1771        let schema = test_schema();
1772        let partition = create_vec_batches(50);
1773        let partitions = vec![partition];
1774
1775        // repartition from 1 input to 4 output
1776        let output_partitions =
1777            repartition(&schema, partitions, Partitioning::RoundRobinBatch(4)).await?;
1778
1779        assert_eq!(4, output_partitions.len());
1780        for partition in &output_partitions {
1781            assert_eq!(1, partition.len());
1782        }
1783        assert_eq!(13 * 8, output_partitions[0][0].num_rows());
1784        assert_eq!(13 * 8, output_partitions[1][0].num_rows());
1785        assert_eq!(12 * 8, output_partitions[2][0].num_rows());
1786        assert_eq!(12 * 8, output_partitions[3][0].num_rows());
1787
1788        Ok(())
1789    }
1790
1791    #[tokio::test]
1792    async fn many_to_one_round_robin() -> Result<()> {
1793        // define input partitions
1794        let schema = test_schema();
1795        let partition = create_vec_batches(50);
1796        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
1797
1798        // repartition from 3 input to 1 output
1799        let output_partitions =
1800            repartition(&schema, partitions, Partitioning::RoundRobinBatch(1)).await?;
1801
1802        assert_eq!(1, output_partitions.len());
1803        assert_eq!(150 * 8, output_partitions[0][0].num_rows());
1804
1805        Ok(())
1806    }
1807
1808    #[tokio::test]
1809    async fn many_to_many_round_robin() -> Result<()> {
1810        // define input partitions
1811        let schema = test_schema();
1812        let partition = create_vec_batches(50);
1813        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
1814
1815        // repartition from 3 input to 5 output
1816        let output_partitions =
1817            repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await?;
1818
1819        let total_rows_per_partition = 8 * 50 * 3 / 5;
1820        assert_eq!(5, output_partitions.len());
1821        for partition in output_partitions {
1822            assert_eq!(1, partition.len());
1823            assert_eq!(total_rows_per_partition, partition[0].num_rows());
1824        }
1825
1826        Ok(())
1827    }
1828
1829    #[tokio::test]
1830    async fn many_to_many_hash_partition() -> Result<()> {
1831        // define input partitions
1832        let schema = test_schema();
1833        let partition = create_vec_batches(50);
1834        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
1835
1836        let output_partitions = repartition(
1837            &schema,
1838            partitions,
1839            Partitioning::Hash(vec![col("c0", &schema)?], 8),
1840        )
1841        .await?;
1842
1843        let total_rows: usize = output_partitions
1844            .iter()
1845            .map(|x| x.iter().map(|x| x.num_rows()).sum::<usize>())
1846            .sum();
1847
1848        assert_eq!(8, output_partitions.len());
1849        assert_eq!(total_rows, 8 * 50 * 3);
1850
1851        Ok(())
1852    }
1853
1854    #[tokio::test]
1855    async fn test_repartition_with_coalescing() -> Result<()> {
1856        let schema = test_schema();
1857        // create 50 batches, each having 8 rows
1858        let partition = create_vec_batches(50);
1859        let partitions = vec![partition.clone(), partition.clone()];
1860        let partitioning = Partitioning::RoundRobinBatch(1);
1861
1862        let session_config = SessionConfig::new().with_batch_size(200);
1863        let task_ctx = TaskContext::default().with_session_config(session_config);
1864        let task_ctx = Arc::new(task_ctx);
1865
1866        // create physical plan
1867        let exec = TestMemoryExec::try_new_exec(&partitions, Arc::clone(&schema), None)?;
1868        let exec = RepartitionExec::try_new(exec, partitioning)?;
1869
1870        for i in 0..exec.partitioning().partition_count() {
1871            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
1872            while let Some(result) = stream.next().await {
1873                let batch = result?;
1874                assert_eq!(200, batch.num_rows());
1875            }
1876        }
1877        Ok(())
1878    }
1879
1880    fn test_schema() -> Arc<Schema> {
1881        Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
1882    }
1883
1884    async fn repartition(
1885        schema: &SchemaRef,
1886        input_partitions: Vec<Vec<RecordBatch>>,
1887        partitioning: Partitioning,
1888    ) -> Result<Vec<Vec<RecordBatch>>> {
1889        let task_ctx = Arc::new(TaskContext::default());
1890        // create physical plan
1891        let exec =
1892            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(schema), None)?;
1893        let exec = RepartitionExec::try_new(exec, partitioning)?;
1894
1895        // execute and collect results
1896        let mut output_partitions = vec![];
1897        for i in 0..exec.partitioning().partition_count() {
1898            // execute this *output* partition and collect all batches
1899            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
1900            let mut batches = vec![];
1901            while let Some(result) = stream.next().await {
1902                batches.push(result?);
1903            }
1904            output_partitions.push(batches);
1905        }
1906        Ok(output_partitions)
1907    }
1908
1909    #[tokio::test]
1910    async fn many_to_many_round_robin_within_tokio_task() -> Result<()> {
1911        let handle: SpawnedTask<Result<Vec<Vec<RecordBatch>>>> =
1912            SpawnedTask::spawn(async move {
1913                // define input partitions
1914                let schema = test_schema();
1915                let partition = create_vec_batches(50);
1916                let partitions =
1917                    vec![partition.clone(), partition.clone(), partition.clone()];
1918
1919                // repartition from 3 input to 5 output
1920                repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await
1921            });
1922
1923        let output_partitions = handle.join().await.unwrap().unwrap();
1924
1925        let total_rows_per_partition = 8 * 50 * 3 / 5;
1926        assert_eq!(5, output_partitions.len());
1927        for partition in output_partitions {
1928            assert_eq!(1, partition.len());
1929            assert_eq!(total_rows_per_partition, partition[0].num_rows());
1930        }
1931
1932        Ok(())
1933    }
1934
1935    #[tokio::test]
1936    async fn unsupported_partitioning() {
1937        let task_ctx = Arc::new(TaskContext::default());
1938        // have to send at least one batch through to provoke error
1939        let batch = RecordBatch::try_from_iter(vec![(
1940            "my_awesome_field",
1941            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
1942        )])
1943        .unwrap();
1944
1945        let schema = batch.schema();
1946        let input = MockExec::new(vec![Ok(batch)], schema);
1947        // This generates an error (partitioning type not supported)
1948        // but only after the plan is executed. The error should be
1949        // returned and no results produced
1950        let partitioning = Partitioning::UnknownPartitioning(1);
1951        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
1952        let output_stream = exec.execute(0, task_ctx).unwrap();
1953
1954        // Expect that an error is returned
1955        let result_string = crate::common::collect(output_stream)
1956            .await
1957            .unwrap_err()
1958            .to_string();
1959        assert!(
1960            result_string
1961                .contains("Unsupported repartitioning scheme UnknownPartitioning(1)"),
1962            "actual: {result_string}"
1963        );
1964    }
1965
1966    #[tokio::test]
1967    async fn error_for_input_exec() {
1968        // This generates an error on a call to execute. The error
1969        // should be returned and no results produced.
1970
1971        let task_ctx = Arc::new(TaskContext::default());
1972        let input = ErrorExec::new();
1973        let partitioning = Partitioning::RoundRobinBatch(1);
1974        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
1975
1976        // Expect that an error is returned
1977        let result_string = exec.execute(0, task_ctx).err().unwrap().to_string();
1978
1979        assert!(
1980            result_string.contains("ErrorExec, unsurprisingly, errored in partition 0"),
1981            "actual: {result_string}"
1982        );
1983    }
1984
1985    #[tokio::test]
1986    async fn repartition_with_error_in_stream() {
1987        let task_ctx = Arc::new(TaskContext::default());
1988        let batch = RecordBatch::try_from_iter(vec![(
1989            "my_awesome_field",
1990            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
1991        )])
1992        .unwrap();
1993
1994        // input stream returns one good batch and then one error. The
1995        // error should be returned.
1996        let err = exec_err!("bad data error");
1997
1998        let schema = batch.schema();
1999        let input = MockExec::new(vec![Ok(batch), err], schema);
2000        let partitioning = Partitioning::RoundRobinBatch(1);
2001        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
2002
2003        // Note: this should pass (the stream can be created) but the
2004        // error when the input is executed should get passed back
2005        let output_stream = exec.execute(0, task_ctx).unwrap();
2006
2007        // Expect that an error is returned
2008        let result_string = crate::common::collect(output_stream)
2009            .await
2010            .unwrap_err()
2011            .to_string();
2012        assert!(
2013            result_string.contains("bad data error"),
2014            "actual: {result_string}"
2015        );
2016    }
2017
2018    #[tokio::test]
2019    async fn repartition_with_delayed_stream() {
2020        let task_ctx = Arc::new(TaskContext::default());
2021        let batch1 = RecordBatch::try_from_iter(vec![(
2022            "my_awesome_field",
2023            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
2024        )])
2025        .unwrap();
2026
2027        let batch2 = RecordBatch::try_from_iter(vec![(
2028            "my_awesome_field",
2029            Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
2030        )])
2031        .unwrap();
2032
2033        // The mock exec doesn't return immediately (instead it
2034        // requires the input to wait at least once)
2035        let schema = batch1.schema();
2036        let expected_batches = vec![batch1.clone(), batch2.clone()];
2037        let input = MockExec::new(vec![Ok(batch1), Ok(batch2)], schema);
2038        let partitioning = Partitioning::RoundRobinBatch(1);
2039
2040        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
2041
2042        assert_snapshot!(batches_to_sort_string(&expected_batches), @r"
2043        +------------------+
2044        | my_awesome_field |
2045        +------------------+
2046        | bar              |
2047        | baz              |
2048        | foo              |
2049        | frob             |
2050        +------------------+
2051        ");
2052
2053        let output_stream = exec.execute(0, task_ctx).unwrap();
2054        let batches = crate::common::collect(output_stream).await.unwrap();
2055
2056        assert_snapshot!(batches_to_sort_string(&batches), @r"
2057        +------------------+
2058        | my_awesome_field |
2059        +------------------+
2060        | bar              |
2061        | baz              |
2062        | foo              |
2063        | frob             |
2064        +------------------+
2065        ");
2066    }
2067
2068    #[tokio::test]
2069    async fn robin_repartition_with_dropping_output_stream() {
2070        let task_ctx = Arc::new(TaskContext::default());
2071        let partitioning = Partitioning::RoundRobinBatch(2);
2072        // The barrier exec waits to be pinged
2073        // requires the input to wait at least once)
2074        let input = Arc::new(make_barrier_exec());
2075
2076        // partition into two output streams
2077        let exec = RepartitionExec::try_new(
2078            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
2079            partitioning,
2080        )
2081        .unwrap();
2082
2083        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
2084        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
2085
2086        // now, purposely drop output stream 0
2087        // *before* any outputs are produced
2088        drop(output_stream0);
2089
2090        // Now, start sending input
2091        let mut background_task = JoinSet::new();
2092        background_task.spawn(async move {
2093            input.wait().await;
2094        });
2095
2096        // output stream 1 should *not* error and have one of the input batches
2097        let batches = crate::common::collect(output_stream1).await.unwrap();
2098
2099        assert_snapshot!(batches_to_sort_string(&batches), @r"
2100        +------------------+
2101        | my_awesome_field |
2102        +------------------+
2103        | baz              |
2104        | frob             |
2105        | gar              |
2106        | goo              |
2107        +------------------+
2108        ");
2109    }
2110
2111    #[tokio::test]
2112    // As the hash results might be different on different platforms or
2113    // with different compilers, we will compare the same execution with
2114    // and without dropping the output stream.
2115    async fn hash_repartition_with_dropping_output_stream() {
2116        let task_ctx = Arc::new(TaskContext::default());
2117        let partitioning = Partitioning::Hash(
2118            vec![Arc::new(crate::expressions::Column::new(
2119                "my_awesome_field",
2120                0,
2121            ))],
2122            2,
2123        );
2124
2125        // We first collect the results without dropping the output stream.
2126        let input = Arc::new(make_barrier_exec());
2127        let exec = RepartitionExec::try_new(
2128            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
2129            partitioning.clone(),
2130        )
2131        .unwrap();
2132        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
2133        let mut background_task = JoinSet::new();
2134        background_task.spawn(async move {
2135            input.wait().await;
2136        });
2137        let batches_without_drop = crate::common::collect(output_stream1).await.unwrap();
2138
2139        // run some checks on the result
2140        let items_vec = str_batches_to_vec(&batches_without_drop);
2141        let items_set: HashSet<&str> = items_vec.iter().copied().collect();
2142        assert_eq!(items_vec.len(), items_set.len());
2143        let source_str_set: HashSet<&str> =
2144            ["foo", "bar", "frob", "baz", "goo", "gar", "grob", "gaz"]
2145                .iter()
2146                .copied()
2147                .collect();
2148        assert_eq!(items_set.difference(&source_str_set).count(), 0);
2149
2150        // Now do the same but dropping the stream before waiting for the barrier
2151        let input = Arc::new(make_barrier_exec());
2152        let exec = RepartitionExec::try_new(
2153            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
2154            partitioning,
2155        )
2156        .unwrap();
2157        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
2158        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
2159        // now, purposely drop output stream 0
2160        // *before* any outputs are produced
2161        drop(output_stream0);
2162        let mut background_task = JoinSet::new();
2163        background_task.spawn(async move {
2164            input.wait().await;
2165        });
2166        let batches_with_drop = crate::common::collect(output_stream1).await.unwrap();
2167
2168        let items_vec_with_drop = str_batches_to_vec(&batches_with_drop);
2169        let items_set_with_drop: HashSet<&str> =
2170            items_vec_with_drop.iter().copied().collect();
2171        assert_eq!(
2172            items_set_with_drop.symmetric_difference(&items_set).count(),
2173            0
2174        );
2175    }
2176
2177    fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> {
2178        batches
2179            .iter()
2180            .flat_map(|batch| {
2181                assert_eq!(batch.columns().len(), 1);
2182                let string_array = as_string_array(batch.column(0))
2183                    .expect("Unexpected type for repartitioned batch");
2184
2185                string_array
2186                    .iter()
2187                    .map(|v| v.expect("Unexpected null"))
2188                    .collect::<Vec<_>>()
2189            })
2190            .collect::<Vec<_>>()
2191    }
2192
2193    /// Create a BarrierExec that returns two partitions of two batches each
2194    fn make_barrier_exec() -> BarrierExec {
2195        let batch1 = RecordBatch::try_from_iter(vec![(
2196            "my_awesome_field",
2197            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
2198        )])
2199        .unwrap();
2200
2201        let batch2 = RecordBatch::try_from_iter(vec![(
2202            "my_awesome_field",
2203            Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
2204        )])
2205        .unwrap();
2206
2207        let batch3 = RecordBatch::try_from_iter(vec![(
2208            "my_awesome_field",
2209            Arc::new(StringArray::from(vec!["goo", "gar"])) as ArrayRef,
2210        )])
2211        .unwrap();
2212
2213        let batch4 = RecordBatch::try_from_iter(vec![(
2214            "my_awesome_field",
2215            Arc::new(StringArray::from(vec!["grob", "gaz"])) as ArrayRef,
2216        )])
2217        .unwrap();
2218
2219        // The barrier exec waits to be pinged
2220        // requires the input to wait at least once)
2221        let schema = batch1.schema();
2222        BarrierExec::new(vec![vec![batch1, batch2], vec![batch3, batch4]], schema)
2223    }
2224
2225    #[tokio::test]
2226    async fn test_drop_cancel() -> Result<()> {
2227        let task_ctx = Arc::new(TaskContext::default());
2228        let schema =
2229            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));
2230
2231        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2));
2232        let refs = blocking_exec.refs();
2233        let repartition_exec = Arc::new(RepartitionExec::try_new(
2234            blocking_exec,
2235            Partitioning::UnknownPartitioning(1),
2236        )?);
2237
2238        let fut = collect(repartition_exec, task_ctx);
2239        let mut fut = fut.boxed();
2240
2241        assert_is_pending(&mut fut);
2242        drop(fut);
2243        assert_strong_count_converges_to_zero(refs).await;
2244
2245        Ok(())
2246    }
2247
2248    #[tokio::test]
2249    async fn hash_repartition_avoid_empty_batch() -> Result<()> {
2250        let task_ctx = Arc::new(TaskContext::default());
2251        let batch = RecordBatch::try_from_iter(vec![(
2252            "a",
2253            Arc::new(StringArray::from(vec!["foo"])) as ArrayRef,
2254        )])
2255        .unwrap();
2256        let partitioning = Partitioning::Hash(
2257            vec![Arc::new(crate::expressions::Column::new("a", 0))],
2258            2,
2259        );
2260        let schema = batch.schema();
2261        let input = MockExec::new(vec![Ok(batch)], schema);
2262        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
2263        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
2264        let batch0 = crate::common::collect(output_stream0).await.unwrap();
2265        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
2266        let batch1 = crate::common::collect(output_stream1).await.unwrap();
2267        assert!(batch0.is_empty() || batch1.is_empty());
2268        Ok(())
2269    }
2270
2271    #[tokio::test]
2272    async fn repartition_with_spilling() -> Result<()> {
2273        // Test that repartition successfully spills to disk when memory is constrained
2274        let schema = test_schema();
2275        let partition = create_vec_batches(50);
2276        let input_partitions = vec![partition];
2277        let partitioning = Partitioning::RoundRobinBatch(4);
2278
2279        // Set up context with very tight memory limit to force spilling
2280        let runtime = RuntimeEnvBuilder::default()
2281            .with_memory_limit(1, 1.0)
2282            .build_arc()?;
2283
2284        let task_ctx = TaskContext::default().with_runtime(runtime);
2285        let task_ctx = Arc::new(task_ctx);
2286
2287        // create physical plan
2288        let exec =
2289            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
2290        let exec = RepartitionExec::try_new(exec, partitioning)?;
2291
2292        // Collect all partitions - should succeed by spilling to disk
2293        let mut total_rows = 0;
2294        for i in 0..exec.partitioning().partition_count() {
2295            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
2296            while let Some(result) = stream.next().await {
2297                let batch = result?;
2298                total_rows += batch.num_rows();
2299            }
2300        }
2301
2302        // Verify we got all the data (50 batches * 8 rows each)
2303        assert_eq!(total_rows, 50 * 8);
2304
2305        // Verify spilling metrics to confirm spilling actually happened
2306        let metrics = exec.metrics().unwrap();
2307        assert!(
2308            metrics.spill_count().unwrap() > 0,
2309            "Expected spill_count > 0, but got {:?}",
2310            metrics.spill_count()
2311        );
2312        println!("Spilled {} times", metrics.spill_count().unwrap());
2313        assert!(
2314            metrics.spilled_bytes().unwrap() > 0,
2315            "Expected spilled_bytes > 0, but got {:?}",
2316            metrics.spilled_bytes()
2317        );
2318        println!(
2319            "Spilled {} bytes in {} spills",
2320            metrics.spilled_bytes().unwrap(),
2321            metrics.spill_count().unwrap()
2322        );
2323        assert!(
2324            metrics.spilled_rows().unwrap() > 0,
2325            "Expected spilled_rows > 0, but got {:?}",
2326            metrics.spilled_rows()
2327        );
2328        println!("Spilled {} rows", metrics.spilled_rows().unwrap());
2329
2330        Ok(())
2331    }
2332
2333    #[tokio::test]
2334    async fn repartition_with_partial_spilling() -> Result<()> {
2335        // Test that repartition can handle partial spilling (some batches in memory, some spilled)
2336        let schema = test_schema();
2337        let partition = create_vec_batches(50);
2338        let input_partitions = vec![partition];
2339        let partitioning = Partitioning::RoundRobinBatch(4);
2340
2341        // Set up context with moderate memory limit to force partial spilling
2342        // 2KB should allow some batches in memory but force others to spill
2343        let runtime = RuntimeEnvBuilder::default()
2344            .with_memory_limit(2 * 1024, 1.0)
2345            .build_arc()?;
2346
2347        let task_ctx = TaskContext::default().with_runtime(runtime);
2348        let task_ctx = Arc::new(task_ctx);
2349
2350        // create physical plan
2351        let exec =
2352            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
2353        let exec = RepartitionExec::try_new(exec, partitioning)?;
2354
2355        // Collect all partitions - should succeed with partial spilling
2356        let mut total_rows = 0;
2357        for i in 0..exec.partitioning().partition_count() {
2358            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
2359            while let Some(result) = stream.next().await {
2360                let batch = result?;
2361                total_rows += batch.num_rows();
2362            }
2363        }
2364
2365        // Verify we got all the data (50 batches * 8 rows each)
2366        assert_eq!(total_rows, 50 * 8);
2367
2368        // Verify partial spilling metrics
2369        let metrics = exec.metrics().unwrap();
2370        let spill_count = metrics.spill_count().unwrap();
2371        let spilled_rows = metrics.spilled_rows().unwrap();
2372        let spilled_bytes = metrics.spilled_bytes().unwrap();
2373
2374        assert!(
2375            spill_count > 0,
2376            "Expected some spilling to occur, but got spill_count={spill_count}"
2377        );
2378        assert!(
2379            spilled_rows > 0 && spilled_rows < total_rows,
2380            "Expected partial spilling (0 < spilled_rows < {total_rows}), but got spilled_rows={spilled_rows}"
2381        );
2382        assert!(
2383            spilled_bytes > 0,
2384            "Expected some bytes to be spilled, but got spilled_bytes={spilled_bytes}"
2385        );
2386
2387        println!(
2388            "Partial spilling: spilled {} out of {} rows ({:.1}%) in {} spills, {} bytes",
2389            spilled_rows,
2390            total_rows,
2391            (spilled_rows as f64 / total_rows as f64) * 100.0,
2392            spill_count,
2393            spilled_bytes
2394        );
2395
2396        Ok(())
2397    }
2398
2399    #[tokio::test]
2400    async fn repartition_without_spilling() -> Result<()> {
2401        // Test that repartition does not spill when there's ample memory
2402        let schema = test_schema();
2403        let partition = create_vec_batches(50);
2404        let input_partitions = vec![partition];
2405        let partitioning = Partitioning::RoundRobinBatch(4);
2406
2407        // Set up context with generous memory limit - no spilling should occur
2408        let runtime = RuntimeEnvBuilder::default()
2409            .with_memory_limit(10 * 1024 * 1024, 1.0) // 10MB
2410            .build_arc()?;
2411
2412        let task_ctx = TaskContext::default().with_runtime(runtime);
2413        let task_ctx = Arc::new(task_ctx);
2414
2415        // create physical plan
2416        let exec =
2417            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
2418        let exec = RepartitionExec::try_new(exec, partitioning)?;
2419
2420        // Collect all partitions - should succeed without spilling
2421        let mut total_rows = 0;
2422        for i in 0..exec.partitioning().partition_count() {
2423            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
2424            while let Some(result) = stream.next().await {
2425                let batch = result?;
2426                total_rows += batch.num_rows();
2427            }
2428        }
2429
2430        // Verify we got all the data (50 batches * 8 rows each)
2431        assert_eq!(total_rows, 50 * 8);
2432
2433        // Verify no spilling occurred
2434        let metrics = exec.metrics().unwrap();
2435        assert_eq!(
2436            metrics.spill_count(),
2437            Some(0),
2438            "Expected no spilling, but got spill_count={:?}",
2439            metrics.spill_count()
2440        );
2441        assert_eq!(
2442            metrics.spilled_bytes(),
2443            Some(0),
2444            "Expected no bytes spilled, but got spilled_bytes={:?}",
2445            metrics.spilled_bytes()
2446        );
2447        assert_eq!(
2448            metrics.spilled_rows(),
2449            Some(0),
2450            "Expected no rows spilled, but got spilled_rows={:?}",
2451            metrics.spilled_rows()
2452        );
2453
2454        println!("No spilling occurred - all data processed in memory");
2455
2456        Ok(())
2457    }
2458
2459    #[tokio::test]
2460    async fn oom() -> Result<()> {
2461        use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode};
2462
2463        // Test that repartition fails with OOM when disk manager is disabled
2464        let schema = test_schema();
2465        let partition = create_vec_batches(50);
2466        let input_partitions = vec![partition];
2467        let partitioning = Partitioning::RoundRobinBatch(4);
2468
2469        // Setup context with memory limit but NO disk manager (explicitly disabled)
2470        let runtime = RuntimeEnvBuilder::default()
2471            .with_memory_limit(1, 1.0)
2472            .with_disk_manager_builder(
2473                DiskManagerBuilder::default().with_mode(DiskManagerMode::Disabled),
2474            )
2475            .build_arc()?;
2476
2477        let task_ctx = TaskContext::default().with_runtime(runtime);
2478        let task_ctx = Arc::new(task_ctx);
2479
2480        // create physical plan
2481        let exec =
2482            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
2483        let exec = RepartitionExec::try_new(exec, partitioning)?;
2484
2485        // Attempt to execute - should fail with ResourcesExhausted error
2486        for i in 0..exec.partitioning().partition_count() {
2487            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
2488            let err = stream.next().await.unwrap().unwrap_err();
2489            let err = err.find_root();
2490            assert!(
2491                matches!(err, DataFusionError::ResourcesExhausted(_)),
2492                "Wrong error type: {err}",
2493            );
2494        }
2495
2496        Ok(())
2497    }
2498
2499    /// Create vector batches
2500    fn create_vec_batches(n: usize) -> Vec<RecordBatch> {
2501        let batch = create_batch();
2502        std::iter::repeat_n(batch, n).collect()
2503    }
2504
2505    /// Create batch
2506    fn create_batch() -> RecordBatch {
2507        let schema = test_schema();
2508        RecordBatch::try_new(
2509            schema,
2510            vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]))],
2511        )
2512        .unwrap()
2513    }
2514
2515    /// Create batches with sequential values for ordering tests
2516    fn create_ordered_batches(num_batches: usize) -> Vec<RecordBatch> {
2517        let schema = test_schema();
2518        (0..num_batches)
2519            .map(|i| {
2520                let start = (i * 8) as u32;
2521                RecordBatch::try_new(
2522                    Arc::clone(&schema),
2523                    vec![Arc::new(UInt32Array::from(
2524                        (start..start + 8).collect::<Vec<_>>(),
2525                    ))],
2526                )
2527                .unwrap()
2528            })
2529            .collect()
2530    }
2531
2532    #[tokio::test]
2533    async fn test_repartition_ordering_with_spilling() -> Result<()> {
2534        // Test that repartition preserves ordering when spilling occurs
2535        // This tests the state machine fix where we must block on spill_stream
2536        // when a Spilled marker is received, rather than continuing to poll the channel
2537
2538        let schema = test_schema();
2539        // Create batches with sequential values: batch 0 has [0,1,2,3,4,5,6,7],
2540        // batch 1 has [8,9,10,11,12,13,14,15], etc.
2541        let partition = create_ordered_batches(20);
2542        let input_partitions = vec![partition];
2543
2544        // Use RoundRobinBatch to ensure predictable ordering
2545        let partitioning = Partitioning::RoundRobinBatch(2);
2546
2547        // Set up context with very tight memory limit to force spilling
2548        let runtime = RuntimeEnvBuilder::default()
2549            .with_memory_limit(1, 1.0)
2550            .build_arc()?;
2551
2552        let task_ctx = TaskContext::default().with_runtime(runtime);
2553        let task_ctx = Arc::new(task_ctx);
2554
2555        // create physical plan
2556        let exec =
2557            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
2558        let exec = RepartitionExec::try_new(exec, partitioning)?;
2559
2560        // Collect all output partitions
2561        let mut all_batches = Vec::new();
2562        for i in 0..exec.partitioning().partition_count() {
2563            let mut partition_batches = Vec::new();
2564            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
2565            while let Some(result) = stream.next().await {
2566                let batch = result?;
2567                partition_batches.push(batch);
2568            }
2569            all_batches.push(partition_batches);
2570        }
2571
2572        // Verify spilling occurred
2573        let metrics = exec.metrics().unwrap();
2574        assert!(
2575            metrics.spill_count().unwrap() > 0,
2576            "Expected spilling to occur, but spill_count = 0"
2577        );
2578
2579        // Verify ordering is preserved within each partition
2580        // With RoundRobinBatch, even batches go to partition 0, odd batches to partition 1
2581        for (partition_idx, batches) in all_batches.iter().enumerate() {
2582            let mut last_value = None;
2583            for batch in batches {
2584                let array = batch
2585                    .column(0)
2586                    .as_any()
2587                    .downcast_ref::<UInt32Array>()
2588                    .unwrap();
2589
2590                for i in 0..array.len() {
2591                    let value = array.value(i);
2592                    if let Some(last) = last_value {
2593                        assert!(
2594                            value > last,
2595                            "Ordering violated in partition {partition_idx}: {value} is not greater than {last}"
2596                        );
2597                    }
2598                    last_value = Some(value);
2599                }
2600            }
2601        }
2602
2603        Ok(())
2604    }
2605}
2606
#[cfg(test)]
mod test {
    use arrow::array::record_batch;
    use arrow::compute::SortOptions;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::assert_batches_eq;

    use super::*;
    use crate::test::TestMemoryExec;
    use crate::union::UnionExec;

    use datafusion_physical_expr::expressions::col;
    use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};

    /// Asserts that the indented display of a plan matches an inline snapshot.
    ///
    /// `$PLAN`: the plan to format (passed to `crate::displayable`)
    /// `$EXPECTED`: the expected indented plan text, as an `insta` inline snapshot
    macro_rules! assert_plan {
        ($PLAN: expr,  @ $EXPECTED: expr) => {
            let formatted = crate::displayable($PLAN).indent(true).to_string();

            insta::assert_snapshot!(
                formatted,
                @$EXPECTED
            );
        };
    }

    #[tokio::test]
    async fn test_preserve_order() -> Result<()> {
        let schema = test_schema();
        let sort_exprs = sort_exprs(&schema);
        let source1 = sorted_memory_exec(&schema, sort_exprs.clone());
        let source2 = sorted_memory_exec(&schema, sort_exprs);
        // output has multiple partitions, and is sorted
        let union = UnionExec::try_new(vec![source1, source2])?;
        let exec = RepartitionExec::try_new(union, Partitioning::RoundRobinBatch(10))?
            .with_preserve_order();

        // Repartition should preserve order
        assert_plan!(&exec, @r"
        RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true, sort_exprs=c0@0 ASC
          UnionExec
            DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
            DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
        ");
        Ok(())
    }

    #[tokio::test]
    async fn test_preserve_order_one_partition() -> Result<()> {
        let schema = test_schema();
        let sort_exprs = sort_exprs(&schema);
        let source = sorted_memory_exec(&schema, sort_exprs);
        // output is sorted, but has only a single partition, so no need to sort
        let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(10))?
            .with_preserve_order();

        // Repartition should not switch to order-preserving (merging) mode: with a
        // single sorted input partition the sort order is maintained as-is
        assert_plan!(&exec, @r"
        RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true
          DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
        ");

        Ok(())
    }

    #[tokio::test]
    async fn test_preserve_order_input_not_sorted() -> Result<()> {
        let schema = test_schema();
        let source1 = memory_exec(&schema);
        let source2 = memory_exec(&schema);
        // output has multiple partitions, but is not sorted
        let union = UnionExec::try_new(vec![source1, source2])?;
        let exec = RepartitionExec::try_new(union, Partitioning::RoundRobinBatch(10))?
            .with_preserve_order();

        // Repartition should not preserve order, as there is no order to preserve
        assert_plan!(&exec, @r"
        RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2
          UnionExec
            DataSourceExec: partitions=1, partition_sizes=[0]
            DataSourceExec: partitions=1, partition_sizes=[0]
        ");
        Ok(())
    }

    #[tokio::test]
    async fn test_preserve_order_with_spilling() -> Result<()> {
        use datafusion_execution::TaskContext;
        use datafusion_execution::runtime_env::RuntimeEnvBuilder;

        // Create sorted input data across multiple partitions
        // Partition1: [1,3], [5,7], [9,11]
        // Partition2: [2,4], [6,8], [10,12]
        let batch1 = record_batch!(("c0", UInt32, [1, 3])).unwrap();
        let batch2 = record_batch!(("c0", UInt32, [2, 4])).unwrap();
        let batch3 = record_batch!(("c0", UInt32, [5, 7])).unwrap();
        let batch4 = record_batch!(("c0", UInt32, [6, 8])).unwrap();
        let batch5 = record_batch!(("c0", UInt32, [9, 11])).unwrap();
        let batch6 = record_batch!(("c0", UInt32, [10, 12])).unwrap();
        let schema = batch1.schema();
        let sort_exprs = LexOrdering::new([PhysicalSortExpr {
            expr: col("c0", &schema).unwrap(),
            options: SortOptions::default().asc(),
        }])
        .unwrap();
        let partition1 = vec![batch1.clone(), batch3.clone(), batch5.clone()];
        let partition2 = vec![batch2.clone(), batch4.clone(), batch6.clone()];
        let input_partitions = vec![partition1, partition2];

        // Set up context with tight memory limit to force spilling
        // Sorting needs some non-spillable memory, so 64 bytes should force spilling while still allowing the query to complete
        let runtime = RuntimeEnvBuilder::default()
            .with_memory_limit(64, 1.0)
            .build_arc()?;

        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        // Create physical plan with order preservation
        let exec = TestMemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?
            .try_with_sort_information(vec![sort_exprs.clone(), sort_exprs])?;
        let exec = Arc::new(exec);
        let exec = Arc::new(TestMemoryExec::update_cache(&exec));
        // Repartition into 3 partitions with order preservation
        // We expect 1 batch per output partition after repartitioning
        let exec = RepartitionExec::try_new(exec, Partitioning::RoundRobinBatch(3))?
            .with_preserve_order();

        let mut batches = vec![];

        // Collect all partitions - should succeed by spilling to disk
        for i in 0..exec.partitioning().partition_count() {
            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
            while let Some(result) = stream.next().await {
                let batch = result?;
                batches.push(batch);
            }
        }

        #[rustfmt::skip]
        let expected = [
            [
                "+----+",
                "| c0 |",
                "+----+",
                "| 1  |",
                "| 2  |",
                "| 3  |",
                "| 4  |",
                "+----+",
            ],
            [
                "+----+",
                "| c0 |",
                "+----+",
                "| 5  |",
                "| 6  |",
                "| 7  |",
                "| 8  |",
                "+----+",
            ],
            [
                "+----+",
                "| c0 |",
                "+----+",
                "| 9  |",
                "| 10 |",
                "| 11 |",
                "| 12 |",
                "+----+",
            ],
        ];

        // `zip` stops at the shorter side, so check the counts first to make
        // sure no output batch goes unchecked (and none is missing).
        assert_eq!(
            batches.len(),
            expected.len(),
            "Expected one output batch per output partition"
        );
        for (batch, expected) in batches.iter().zip(expected.iter()) {
            assert_batches_eq!(expected, std::slice::from_ref(batch));
        }

        // We should have spilled ~ all of the data.
        // - We spill data during the repartitioning phase
        // - We may also spill during the final merge sort
        let all_batches = [batch1, batch2, batch3, batch4, batch5, batch6];
        let input_bytes: usize = all_batches
            .iter()
            .map(|b| b.get_array_memory_size())
            .sum();
        let input_rows: usize = all_batches.iter().map(|b| b.num_rows()).sum();
        let metrics = exec.metrics().unwrap();
        assert!(
            metrics.spill_count().unwrap() > input_partitions.len(),
            "Expected spill_count > {} for order-preserving repartition, but got {:?}",
            input_partitions.len(),
            metrics.spill_count()
        );
        assert!(
            metrics.spilled_bytes().unwrap() > input_bytes,
            "Expected spilled_bytes > {input_bytes} for order-preserving repartition, got {}",
            metrics.spilled_bytes().unwrap()
        );
        assert!(
            metrics.spilled_rows().unwrap() >= input_rows,
            "Expected spilled_rows >= {input_rows} for order-preserving repartition, got {}",
            metrics.spilled_rows().unwrap()
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_hash_partitioning_with_spilling() -> Result<()> {
        use datafusion_execution::TaskContext;
        use datafusion_execution::runtime_env::RuntimeEnvBuilder;

        // Create input data similar to the round-robin test
        let batch1 = record_batch!(("c0", UInt32, [1, 3])).unwrap();
        let batch2 = record_batch!(("c0", UInt32, [2, 4])).unwrap();
        let batch3 = record_batch!(("c0", UInt32, [5, 7])).unwrap();
        let batch4 = record_batch!(("c0", UInt32, [6, 8])).unwrap();
        let schema = batch1.schema();

        let partition1 = vec![batch1.clone(), batch3.clone()];
        let partition2 = vec![batch2.clone(), batch4.clone()];
        let input_partitions = vec![partition1, partition2];

        // A 1-byte memory limit leaves no room to buffer batches in memory, so
        // hash partitioning has to go through the spilling path
        let runtime = RuntimeEnvBuilder::default()
            .with_memory_limit(1, 1.0)
            .build_arc()?;

        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        // Create physical plan with hash partitioning
        let exec = TestMemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?;
        let exec = Arc::new(exec);
        let exec = Arc::new(TestMemoryExec::update_cache(&exec));
        // Hash partition into 2 partitions by column c0
        let hash_expr = col("c0", &schema)?;
        let exec =
            RepartitionExec::try_new(exec, Partitioning::Hash(vec![hash_expr], 2))?;

        // Collect all partitions concurrently using JoinSet - this prevents deadlock
        // where the distribution channel gate closes when all output channels are full
        let mut join_set = tokio::task::JoinSet::new();
        for i in 0..exec.partitioning().partition_count() {
            let stream = exec.execute(i, Arc::clone(&task_ctx))?;
            join_set.spawn(async move {
                let mut count = 0;
                futures::pin_mut!(stream);
                while let Some(result) = stream.next().await {
                    let batch = result?;
                    count += batch.num_rows();
                }
                Ok::<usize, DataFusionError>(count)
            });
        }

        // Wait for all partitions and sum the rows
        let mut total_rows = 0;
        while let Some(result) = join_set.join_next().await {
            total_rows += result.unwrap()?;
        }

        // Verify we got all rows back
        let all_batches = [batch1, batch2, batch3, batch4];
        let expected_rows: usize = all_batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, expected_rows);

        // The 1-byte memory limit forces buffered batches to disk, so every spill
        // metric must be present and non-zero (a missing metric counts as 0).
        let metrics = exec.metrics().unwrap();
        let spill_count = metrics.spill_count().unwrap_or(0);
        assert!(spill_count > 0, "Expected spill_count > 0, got {spill_count}");
        let spilled_bytes = metrics.spilled_bytes().unwrap_or(0);
        assert!(
            spilled_bytes > 0,
            "Expected spilled_bytes > 0, got {spilled_bytes}"
        );
        let spilled_rows = metrics.spilled_rows().unwrap_or(0);
        assert!(spilled_rows > 0, "Expected spilled_rows > 0, got {spilled_rows}");

        Ok(())
    }

    #[tokio::test]
    async fn test_repartition() -> Result<()> {
        let schema = test_schema();
        let sort_exprs = sort_exprs(&schema);
        let source = sorted_memory_exec(&schema, sort_exprs);
        // output is sorted, but has only a single partition, so no need to sort
        let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(10))?
            .repartitioned(20, &Default::default())?
            .unwrap();

        // The re-partitioned plan should not switch to order-preserving (merging)
        // mode: its single sorted input partition keeps the sort order as-is
        assert_plan!(exec.as_ref(), @r"
        RepartitionExec: partitioning=RoundRobinBatch(20), input_partitions=1, maintains_sort_order=true
          DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
        ");
        Ok(())
    }

    fn test_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
    }

    fn sort_exprs(schema: &Schema) -> LexOrdering {
        [PhysicalSortExpr {
            expr: col("c0", schema).unwrap(),
            options: SortOptions::default(),
        }]
        .into()
    }

    fn memory_exec(schema: &SchemaRef) -> Arc<dyn ExecutionPlan> {
        TestMemoryExec::try_new_exec(&[vec![]], Arc::clone(schema), None).unwrap()
    }

    fn sorted_memory_exec(
        schema: &SchemaRef,
        sort_exprs: LexOrdering,
    ) -> Arc<dyn ExecutionPlan> {
        let exec = TestMemoryExec::try_new(&[vec![]], Arc::clone(schema), None)
            .unwrap()
            .try_with_sort_information(vec![sort_exprs])
            .unwrap();
        let exec = Arc::new(exec);
        Arc::new(TestMemoryExec::update_cache(&exec))
    }
}