Skip to main content

datafusion_physical_plan/repartition/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! This file implements the [`RepartitionExec`]  operator, which maps N input
19//! partitions to M output partitions based on a partitioning scheme, optionally
20//! maintaining the order of the input rows in the output.
21
22use std::fmt::{Debug, Formatter};
23use std::pin::Pin;
24use std::sync::Arc;
25use std::sync::atomic::{AtomicUsize, Ordering};
26use std::task::{Context, Poll};
27use std::vec;
28
29use super::common::SharedMemoryReservation;
30use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
31use super::{
32    DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream,
33};
34use crate::coalesce::LimitedBatchCoalescer;
35use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType};
36use crate::hash_utils::create_hashes;
37use crate::metrics::{BaselineMetrics, SpillMetrics};
38use crate::projection::{ProjectionExec, all_columns, make_with_child, update_expr};
39use crate::sorts::streaming_merge::StreamingMergeBuilder;
40use crate::spill::spill_manager::SpillManager;
41use crate::spill::spill_pool::{self, SpillPoolWriter};
42use crate::stream::{EmptyRecordBatchStream, RecordBatchStreamAdapter};
43use crate::{
44    DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Statistics,
45    check_if_same_properties,
46};
47
48use arrow::array::{PrimitiveArray, RecordBatch, RecordBatchOptions};
49use arrow::compute::take_arrays;
50use arrow::datatypes::{SchemaRef, UInt32Type};
51use datafusion_common::config::ConfigOptions;
52use datafusion_common::stats::Precision;
53use datafusion_common::utils::transpose;
54use datafusion_common::{
55    ColumnStatistics, DataFusionError, HashMap, assert_or_internal_err, internal_err,
56};
57use datafusion_common::{Result, not_impl_err};
58use datafusion_common_runtime::SpawnedTask;
59use datafusion_execution::TaskContext;
60use datafusion_execution::memory_pool::MemoryConsumer;
61use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr};
62use datafusion_physical_expr_common::sort_expr::LexOrdering;
63
64use crate::filter_pushdown::{
65    ChildPushdownResult, FilterDescription, FilterPushdownPhase,
66    FilterPushdownPropagation,
67};
68use crate::joins::SeededRandomState;
69use crate::sort_pushdown::SortOrderPushdownResult;
70use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
71use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays;
72use futures::stream::Stream;
73use futures::{FutureExt, StreamExt, TryStreamExt};
74use log::trace;
75use parking_lot::Mutex;
76
77mod distributor_channels;
78use crate::repartition::distributor_channels::SendError;
79use distributor_channels::{
80    DistributionReceiver, DistributionSender, channels, partition_aware_channels,
81};
82
83/// A batch in the repartition queue - either in memory or spilled to disk.
84///
85/// This enum represents the two states a batch can be in during repartitioning.
86/// The decision to spill is made based on memory availability when sending a batch
87/// to an output partition.
88///
89/// # Batch Flow with Spilling
90///
91/// ```text
92///                      Input Stream   ◀──────┐
93///                           │                │
94///                           ▼                │
95///                    Partition Logic         │
96///                           │           `batch_size` not
97///                           ▼            reached yet
98///                    Coalesce Batch          │
99///           ┌───────────────┴────────────────┘
100///           ▼
101/// `batch_size` reached
102///           │
103///           └───────────────┐
104///                           ▼
105///                        try_grow()
106///           ┌───────────────┴────────────────┐
107///           ▼                                ▼
108/// try_grow() succeeds              try_grow() fails
109/// (Memory Available)               (Memory Pressure)
110///           │                                │
111///           ▼                                ▼
112/// RepartitionBatch::Memory         spill_writer.push_batch()
113/// (batch held in memory)           (batch written to disk)
114///           │                                │
115///           │                                ▼
116///           │                      RepartitionBatch::Spilled
117///           │                      (marker - no batch data)
118///           └──────────────┬─────────────────┘
119///                          │
120///                          ▼
121///                   Send to channel
122///                          │
123///                          ▼
124///                 Output Stream (poll)
125///                          │
126///           ┌──────────────┴────────────────┐
127///           ▼                               ▼
128/// RepartitionBatch::Memory      RepartitionBatch::Spilled
129/// Return batch immediately       Poll spill_stream (blocks)
130///           └─────────────┬─────────────────┘
131///                         │
132///                         ▼
133///                    Return batch
134///               (FIFO order preserved)
135/// ```
136///
137/// See [`RepartitionExec`] for overall architecture and [`StreamState`] for
138/// the state machine that handles reading these batches.
139#[derive(Debug)]
140enum RepartitionBatch {
141    /// Batch held in memory (counts against memory reservation)
142    Memory(RecordBatch),
143    /// Marker indicating a batch was spilled to the partition's SpillPool.
144    /// The actual batch can be retrieved by reading from the SpillPoolStream.
145    /// This variant contains no data itself - it's just a signal to the reader
146    /// to fetch the next batch from the spill stream.
147    Spilled,
148}
149
150type MaybeBatch = Option<Result<RepartitionBatch>>;
151type InputPartitionsToCurrentPartitionSender = Vec<DistributionSender<MaybeBatch>>;
152type InputPartitionsToCurrentPartitionReceiver = Vec<DistributionReceiver<MaybeBatch>>;
153
154/// Output channel with its associated memory reservation and spill writer.
155///
156/// `coalescer` is `None` for preserve-order mode, where downstream
157/// [`StreamingMergeBuilder`] performs the batching; otherwise it's a
158/// [`SharedCoalescer`] cloned from the per-partition one held by
159/// [`PartitionChannels`].
160struct OutputChannel {
161    sender: DistributionSender<MaybeBatch>,
162    reservation: SharedMemoryReservation,
163    spill_writer: SpillPoolWriter,
164    shared_coalescer: Option<SharedCoalescer>,
165}
166
167impl OutputChannel {
168    fn coalesce(&mut self, batch: RecordBatch) -> Result<Vec<RecordBatch>> {
169        match &self.shared_coalescer {
170            Some(shared) => Ok(shared.push_and_drain(batch)?),
171            None => Ok(vec![batch]),
172        }
173    }
174
175    /// Send a single batch through the channel for `partition`, applying
176    /// the memory reservation / spill-writer fallback. Removes the channel
177    /// from `self.inner` if the receiver has hung up.
178    ///
179    /// Used after [`OutputChannel::coalesce`] for performance purposes.
180    async fn send(&mut self, batch: RecordBatch) -> Result<(), SendError<MaybeBatch>> {
181        let size = batch.get_array_memory_size();
182
183        // Decide the payload outside of any await: never hold a MutexGuard
184        // across an await point.
185        let (payload, is_memory_batch) = {
186            match self.reservation.try_grow(size) {
187                Ok(_) => (Ok(RepartitionBatch::Memory(batch)), true),
188                Err(_) => match self.spill_writer.push_batch(&batch) {
189                    Ok(()) => (Ok(RepartitionBatch::Spilled), false),
190                    Err(err) => (Err(err), false),
191                },
192            }
193        };
194
195        let result = self.sender.send(Some(payload)).await;
196        if result.is_err() && is_memory_batch {
197            self.reservation.shrink(size);
198        }
199        result
200    }
201
202    async fn finalize(mut self) -> Result<()> {
203        let Some(shared) = self.shared_coalescer.take() else {
204            return Ok(());
205        };
206        for batch in shared.finalize()? {
207            // If this errored, it means that nobody is listening on the other side, which is fine
208            // and can happen in certain cases, like when a LIMIT drops the stream that listens.
209            let _ = self.send(batch).await;
210        }
211        Ok(())
212    }
213}
214
215/// A producer-side coalescer shared across all input tasks targeting a
216/// single output partition.
217///
218/// Bundles the [`LimitedBatchCoalescer`] (behind a [`Mutex`]) with the
219/// active-sender counter that tracks how many input tasks may still push
220/// into it. The last task to call [`Self::finalize`] is the one that
221/// finalizes the coalescer and ships the residual batch.
222///
223/// Cheap to [`Clone`]: both fields are [`Arc`]s.
224#[derive(Clone)]
225struct SharedCoalescer {
226    inner: Arc<Mutex<LimitedBatchCoalescer>>,
227    active_senders: Arc<AtomicUsize>,
228}
229
230impl SharedCoalescer {
231    fn new(schema: SchemaRef, target_batch_size: usize, num_senders: usize) -> Self {
232        Self {
233            inner: Arc::new(Mutex::new(LimitedBatchCoalescer::new(
234                schema,
235                target_batch_size,
236                None,
237            ))),
238            active_senders: Arc::new(AtomicUsize::new(num_senders)),
239        }
240    }
241
242    /// Push `batch` into the coalescer and drain any newly completed
243    /// batches. The mutex is held only briefly.
244    fn push_and_drain(&self, batch: RecordBatch) -> Result<Vec<RecordBatch>> {
245        let mut acc = Vec::new();
246        let mut c = self.inner.lock();
247        c.push_batch(batch)?;
248        while let Some(b) = c.next_completed_batch() {
249            acc.push(b);
250        }
251        Ok(acc)
252    }
253
254    /// Decrement the active-senders counter. If this caller was the last
255    /// sender, finalize the coalescer and return its residual batches; if
256    /// other senders are still active, return `Ok(None)`.
257    fn finalize(&self) -> Result<Vec<RecordBatch>> {
258        let was_last = self.active_senders.fetch_sub(1, Ordering::AcqRel) == 1;
259        if !was_last {
260            return Ok(vec![]);
261        }
262        let mut acc = Vec::new();
263        let mut c = self.inner.lock();
264        c.finish()?;
265        while let Some(b) = c.next_completed_batch() {
266            acc.push(b);
267        }
268        Ok(acc)
269    }
270}
271
272/// Channels and resources for a single output partition.
273///
274/// Each output partition has channels to receive data from all input partitions.
275/// To handle memory pressure, each (input, output) pair gets its own
276/// [`SpillPool`](crate::spill::spill_pool) channel via [`spill_pool::channel`].
277///
278/// # Structure
279///
280/// For an output partition receiving from N input partitions:
281/// - `tx`: N senders (one per input) for sending batches to this output
282/// - `rx`: N receivers (one per input) for receiving batches at this output
283/// - `spill_writers`: N spill writers (one per input) for writing spilled data
284/// - `spill_readers`: N spill readers (one per input) for reading spilled data
285///
286/// This 1:1 mapping between input partitions and spill channels ensures that
287/// batches from each input are processed in FIFO order, even when some batches
288/// are spilled to disk and others remain in memory.
289///
290/// See [`RepartitionExec`] for the overall N×M architecture.
291///
292/// [`spill_pool::channel`]: crate::spill::spill_pool::channel
293struct PartitionChannels {
294    /// Senders for each input partition to send data to this output partition
295    tx: InputPartitionsToCurrentPartitionSender,
296    /// Receivers for each input partition sending data to this output partition
297    rx: InputPartitionsToCurrentPartitionReceiver,
298    /// Memory reservation for this output partition
299    reservation: SharedMemoryReservation,
300    /// Shared coalescer used by all input tasks targeting this output
301    /// partition. `None` in preserve-order mode (downstream
302    /// `StreamingMergeBuilder` handles batching).
303    shared_coalescer: Option<SharedCoalescer>,
304    /// Spill writers for writing spilled data.
305    /// SpillPoolWriter is Clone, so multiple writers can share state in non-preserve-order mode.
306    spill_writers: Vec<SpillPoolWriter>,
307    /// Spill readers for reading spilled data - one per input partition (FIFO semantics).
308    /// Each (input, output) pair gets its own reader to maintain proper ordering.
309    spill_readers: Vec<SendableRecordBatchStream>,
310}
311
312struct ConsumingInputStreamsState {
313    /// Channels for sending batches from input partitions to output partitions.
314    /// Key is the partition number.
315    channels: HashMap<usize, PartitionChannels>,
316
317    /// Helper that ensures that background jobs are killed once they are no longer needed.
318    abort_helper: Arc<Vec<SpawnedTask<()>>>,
319}
320
321impl Debug for ConsumingInputStreamsState {
322    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
323        f.debug_struct("ConsumingInputStreamsState")
324            .field("num_channels", &self.channels.len())
325            .field("abort_helper", &self.abort_helper)
326            .finish()
327    }
328}
329
330/// Inner state of [`RepartitionExec`].
331#[derive(Default)]
332enum RepartitionExecState {
333    /// Not initialized yet. This is the default state stored in the RepartitionExec node
334    /// upon instantiation.
335    #[default]
336    NotInitialized,
337    /// Input streams are initialized, but they are still not being consumed. The node
338    /// transitions to this state when the arrow's RecordBatch stream is created in
339    /// RepartitionExec::execute(), but before any message is polled.
340    InputStreamsInitialized(Vec<(SendableRecordBatchStream, RepartitionMetrics)>),
341    /// The input streams are being consumed. The node transitions to this state when
342    /// the first message in the arrow's RecordBatch stream is consumed.
343    ConsumingInputStreams(ConsumingInputStreamsState),
344}
345
346impl Debug for RepartitionExecState {
347    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
348        match self {
349            RepartitionExecState::NotInitialized => write!(f, "NotInitialized"),
350            RepartitionExecState::InputStreamsInitialized(v) => {
351                write!(f, "InputStreamsInitialized({:?})", v.len())
352            }
353            RepartitionExecState::ConsumingInputStreams(v) => {
354                write!(f, "ConsumingInputStreams({v:?})")
355            }
356        }
357    }
358}
359
360impl RepartitionExecState {
361    fn ensure_input_streams_initialized(
362        &mut self,
363        input: &Arc<dyn ExecutionPlan>,
364        metrics: &ExecutionPlanMetricsSet,
365        output_partitions: usize,
366        ctx: &Arc<TaskContext>,
367    ) -> Result<()> {
368        if !matches!(self, RepartitionExecState::NotInitialized) {
369            return Ok(());
370        }
371
372        let num_input_partitions = input.output_partitioning().partition_count();
373        let mut streams_and_metrics = Vec::with_capacity(num_input_partitions);
374
375        for i in 0..num_input_partitions {
376            let metrics = RepartitionMetrics::new(i, output_partitions, metrics);
377
378            let timer = metrics.fetch_time.timer();
379            let stream = input.execute(i, Arc::clone(ctx))?;
380            timer.done();
381
382            streams_and_metrics.push((stream, metrics));
383        }
384        *self = RepartitionExecState::InputStreamsInitialized(streams_and_metrics);
385        Ok(())
386    }
387
388    #[expect(clippy::too_many_arguments)]
389    fn consume_input_streams(
390        &mut self,
391        input: &Arc<dyn ExecutionPlan>,
392        metrics: &ExecutionPlanMetricsSet,
393        partitioning: &Partitioning,
394        preserve_order: bool,
395        name: &str,
396        context: &Arc<TaskContext>,
397        spill_manager: SpillManager,
398    ) -> Result<&mut ConsumingInputStreamsState> {
399        let streams_and_metrics = match self {
400            RepartitionExecState::NotInitialized => {
401                self.ensure_input_streams_initialized(
402                    input,
403                    metrics,
404                    partitioning.partition_count(),
405                    context,
406                )?;
407                let RepartitionExecState::InputStreamsInitialized(value) = self else {
408                    // This cannot happen, as ensure_input_streams_initialized() was just called,
409                    // but the compiler does not know.
410                    return internal_err!(
411                        "Programming error: RepartitionExecState must be in the InputStreamsInitialized state after calling RepartitionExecState::ensure_input_streams_initialized"
412                    );
413                };
414                value
415            }
416            RepartitionExecState::ConsumingInputStreams(value) => return Ok(value),
417            RepartitionExecState::InputStreamsInitialized(value) => value,
418        };
419
420        let num_input_partitions = streams_and_metrics.len();
421        let num_output_partitions = partitioning.partition_count();
422
423        let spill_manager = Arc::new(spill_manager);
424
425        let (txs, rxs) = if preserve_order {
426            // Create partition-aware channels with one channel per (input, output) pair
427            // This provides backpressure while maintaining proper ordering
428            let (txs_all, rxs_all) =
429                partition_aware_channels(num_input_partitions, num_output_partitions);
430            // Take transpose of senders and receivers. `state.channels` keeps track of entries per output partition
431            let txs = transpose(txs_all);
432            let rxs = transpose(rxs_all);
433            (txs, rxs)
434        } else {
435            // Create one channel per *output* partition with backpressure
436            let (txs, rxs) = channels(num_output_partitions);
437            // Clone sender for each input partitions
438            let txs = txs
439                .into_iter()
440                .map(|item| vec![item; num_input_partitions])
441                .collect::<Vec<_>>();
442            let rxs = rxs.into_iter().map(|item| vec![item]).collect::<Vec<_>>();
443            (txs, rxs)
444        };
445
446        let mut channels = HashMap::with_capacity(txs.len());
447        for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() {
448            let reservation = Arc::new(
449                MemoryConsumer::new(format!("{name}[{partition}]"))
450                    .with_can_spill(true)
451                    .register(context.memory_pool()),
452            );
453
454            // Create spill channels based on mode:
455            // - preserve_order: one spill channel per (input, output) pair for proper FIFO ordering
456            // - non-preserve-order: one shared spill channel per output partition since all inputs
457            //   share the same receiver
458            let max_file_size = context
459                .session_config()
460                .options()
461                .execution
462                .max_spill_file_size_bytes;
463            let num_spill_channels = if preserve_order {
464                num_input_partitions
465            } else {
466                1
467            };
468            let (spill_writers, spill_readers): (Vec<_>, Vec<_>) = (0
469                ..num_spill_channels)
470                .map(|_| spill_pool::channel(max_file_size, Arc::clone(&spill_manager)))
471                .unzip();
472
473            // Coalesce on the producer side, before the channel's gate, so
474            // the consumer never sees the per-input-task small batches.
475            // Skip in preserve-order mode: each input has its own dedicated
476            // channel and `StreamingMergeBuilder` handles batching.
477            let shared_coalescer = (!preserve_order).then(|| {
478                SharedCoalescer::new(
479                    input.schema(),
480                    context.session_config().batch_size(),
481                    num_input_partitions,
482                )
483            });
484
485            channels.insert(
486                partition,
487                PartitionChannels {
488                    tx,
489                    rx,
490                    reservation,
491                    spill_readers,
492                    spill_writers,
493                    shared_coalescer,
494                },
495            );
496        }
497
498        // launch one async task per *input* partition
499        let mut spawned_tasks = Vec::with_capacity(num_input_partitions);
500        for (i, (stream, metrics)) in
501            std::mem::take(streams_and_metrics).into_iter().enumerate()
502        {
503            let txs: HashMap<_, _> = channels
504                .iter()
505                .map(|(partition, channels)| {
506                    // In preserve_order mode: each input gets its own spill writer (index i)
507                    // In non-preserve-order mode: all inputs share spill writer 0 via clone
508                    let spill_writer_idx = if preserve_order { i } else { 0 };
509                    (
510                        *partition,
511                        OutputChannel {
512                            sender: channels.tx[i].clone(),
513                            reservation: Arc::clone(&channels.reservation),
514                            spill_writer: channels.spill_writers[spill_writer_idx]
515                                .clone(),
516                            shared_coalescer: channels.shared_coalescer.clone(),
517                        },
518                    )
519                })
520                .collect();
521
522            // Extract senders for wait_for_task before moving txs
523            let senders: HashMap<_, _> = txs
524                .iter()
525                .map(|(partition, channel)| (*partition, channel.sender.clone()))
526                .collect();
527
528            let input_task = SpawnedTask::spawn(RepartitionExec::pull_from_input(
529                stream,
530                txs,
531                partitioning.clone(),
532                metrics,
533                // preserve_order depends on partition index to start from 0
534                if preserve_order { 0 } else { i },
535                num_input_partitions,
536            ));
537
538            // In a separate task, wait for each input to be done
539            // (and pass along any errors, including panic!s)
540            let wait_for_task =
541                SpawnedTask::spawn(RepartitionExec::wait_for_task(input_task, senders));
542            spawned_tasks.push(wait_for_task);
543        }
544        *self = Self::ConsumingInputStreams(ConsumingInputStreamsState {
545            channels,
546            abort_helper: Arc::new(spawned_tasks),
547        });
548        match self {
549            RepartitionExecState::ConsumingInputStreams(value) => Ok(value),
550            _ => unreachable!(),
551        }
552    }
553}
554
555/// A utility that can be used to partition batches based on [`Partitioning`]
556pub struct BatchPartitioner {
557    state: BatchPartitionerState,
558    timer: metrics::Time,
559}
560
561enum BatchPartitionerState {
562    Hash {
563        exprs: Vec<Arc<dyn PhysicalExpr>>,
564        partition_reducer: StrengthReducedU64,
565        hash_buffer: Vec<u64>,
566        indices: Vec<Vec<u32>>,
567    },
568    RoundRobin {
569        num_partitions: usize,
570        next_idx: usize,
571    },
572}
573
574/// Fixed RandomState used for hash repartitioning to ensure consistent behavior across
575/// executions and runs.
576pub const REPARTITION_RANDOM_STATE: SeededRandomState = SeededRandomState::with_seed(0);
577
/// Computes `value % divisor` without a hardware division in the hot loop,
/// for a `divisor` that stays fixed across many values.
///
/// Hash repartitioning needs a remainder for every row, and integer division
/// is comparatively slow. The divisor is therefore strength-reduced once, up
/// front:
/// - a power of two becomes a bit mask;
/// - any other divisor becomes a 128-bit fixed-point reciprocal. The high
///   half of its product with the value is the exact quotient, from which the
///   remainder follows.
///
/// Compilers apply the same invariant-divisor trick to `%` by a constant.
#[derive(Debug, Clone, Copy)]
enum StrengthReducedU64 {
    /// The divisor is a power of two; the remainder is `value & mask`.
    PowerOfTwo { mask: u64 },
    /// Any other divisor, stored together with `ceil(2^128 / divisor)`.
    Reciprocal { divisor: u64, reciprocal: u128 },
}

impl StrengthReducedU64 {
    /// Precompute the reduced form of `divisor`.
    ///
    /// `divisor` must be non-zero (checked by `debug_assert!` in debug builds).
    fn new(divisor: u64) -> Self {
        debug_assert!(divisor > 0);

        if !divisor.is_power_of_two() {
            // `u128::MAX / d + 1 == ceil(2^128 / d)` whenever `d` does not divide
            // `2^128`, which avoids having to represent `2^128` itself.
            let reciprocal = u128::MAX / u128::from(divisor) + 1;
            return Self::Reciprocal {
                divisor,
                reciprocal,
            };
        }
        Self::PowerOfTwo { mask: divisor - 1 }
    }

    /// Append the position of every hash in `hash_buffer` to
    /// `indices[hash % divisor]`.
    ///
    /// `indices` must have at least `divisor` entries; positions are cast to
    /// `u32` with `as`.
    fn partition_indices(self, hash_buffer: &[u64], indices: &mut [Vec<u32>]) {
        match self {
            Self::PowerOfTwo { mask } => {
                for (row, &hash) in hash_buffer.iter().enumerate() {
                    let bucket = (hash & mask) as usize;
                    indices[bucket].push(row as u32);
                }
            }
            Self::Reciprocal {
                divisor,
                reciprocal,
            } => {
                for (row, &hash) in hash_buffer.iter().enumerate() {
                    let bucket = hash - Self::quotient(hash, reciprocal) * divisor;
                    indices[bucket as usize].push(row as u32);
                }
            }
        }
    }

    /// Scalar remainder, used to check the reducer against `%` in tests.
    #[cfg(test)]
    fn remainder(self, value: u64) -> u64 {
        match self {
            Self::PowerOfTwo { mask } => value & mask,
            Self::Reciprocal {
                divisor,
                reciprocal,
            } => {
                let quotient = Self::quotient(value, reciprocal);
                value - quotient * divisor
            }
        }
    }

    /// Return `floor(value * reciprocal / 2^128)` using only 128-bit arithmetic.
    #[inline]
    fn quotient(value: u64, reciprocal: u128) -> u64 {
        const LOW_64: u128 = u64::MAX as u128;
        let value = u128::from(value);
        // Split the reciprocal into 64-bit halves so each partial product fits in u128.
        let recip_hi = reciprocal >> 64;
        let recip_lo = reciprocal & LOW_64;
        let upper = value * recip_hi;
        let lower = value * recip_lo;
        // Carry out of the middle 64-bit column of the full 192-bit product.
        let carry = ((upper & LOW_64) + (lower >> 64)) >> 64;
        ((upper >> 64) + carry) as u64
    }
}
649
650impl BatchPartitioner {
651    /// Create a new [`BatchPartitioner`] for hash-based repartitioning.
652    ///
653    /// # Parameters
654    /// - `exprs`: Expressions used to compute the hash for each input row.
655    /// - `num_partitions`: Total number of output partitions.
656    /// - `timer`: Metric used to record time spent during repartitioning.
657    ///
658    /// The partition count is fixed for the lifetime of the partitioner, so this
659    /// precomputes a strength-reduced reducer for `hash % num_partitions`.
660    ///
661    /// # Errors
662    /// Returns an error if `num_partitions` is zero.
663    pub fn new_hash_partitioner(
664        exprs: Vec<Arc<dyn PhysicalExpr>>,
665        num_partitions: usize,
666        timer: metrics::Time,
667    ) -> Result<Self> {
668        if num_partitions == 0 {
669            return internal_err!("Hash repartition requires at least one partition");
670        }
671
672        Ok(Self {
673            state: BatchPartitionerState::Hash {
674                exprs,
675                partition_reducer: StrengthReducedU64::new(num_partitions as u64),
676                hash_buffer: vec![],
677                indices: vec![vec![]; num_partitions],
678            },
679            timer,
680        })
681    }
682
683    /// Create a new [`BatchPartitioner`] for round-robin repartitioning.
684    ///
685    /// # Parameters
686    /// - `num_partitions`: Total number of output partitions.
687    /// - `timer`: Metric used to record time spent during repartitioning.
688    /// - `input_partition`: Index of the current input partition.
689    /// - `num_input_partitions`: Total number of input partitions.
690    ///
691    /// # Notes
692    /// The starting output partition is derived from the input partition
693    /// to avoid skew when multiple input partitions are used.
694    pub fn new_round_robin_partitioner(
695        num_partitions: usize,
696        timer: metrics::Time,
697        input_partition: usize,
698        num_input_partitions: usize,
699    ) -> Self {
700        Self {
701            state: BatchPartitionerState::RoundRobin {
702                num_partitions,
703                next_idx: (input_partition * num_partitions) / num_input_partitions,
704            },
705            timer,
706        }
707    }
708    /// Create a new [`BatchPartitioner`] based on the provided [`Partitioning`] scheme.
709    ///
710    /// This is a convenience constructor that delegates to the specialized
711    /// hash or round-robin constructors depending on the partitioning variant.
712    ///
713    /// # Parameters
714    /// - `partitioning`: Partitioning scheme to apply (hash or round-robin).
715    /// - `timer`: Metric used to record time spent during repartitioning.
716    /// - `input_partition`: Index of the current input partition.
717    /// - `num_input_partitions`: Total number of input partitions.
718    ///
719    /// # Errors
720    /// Returns an error if the provided partitioning scheme is not supported,
721    /// or if hash partitioning is requested with zero output partitions.
722    pub fn try_new(
723        partitioning: Partitioning,
724        timer: metrics::Time,
725        input_partition: usize,
726        num_input_partitions: usize,
727    ) -> Result<Self> {
728        match partitioning {
729            Partitioning::Hash(exprs, num_partitions) => {
730                Self::new_hash_partitioner(exprs, num_partitions, timer)
731            }
732            Partitioning::RoundRobinBatch(num_partitions) => {
733                Ok(Self::new_round_robin_partitioner(
734                    num_partitions,
735                    timer,
736                    input_partition,
737                    num_input_partitions,
738                ))
739            }
740            other => {
741                not_impl_err!("Unsupported repartitioning scheme {other:?}")
742            }
743        }
744    }
745
746    /// Partition the provided [`RecordBatch`] into one or more partitioned [`RecordBatch`]
747    /// based on the [`Partitioning`] specified on construction
748    ///
749    /// `f` will be called for each partitioned [`RecordBatch`] with the corresponding
750    /// partition index. Any error returned by `f` will be immediately returned by this
751    /// function without attempting to publish further [`RecordBatch`]
752    ///
753    /// The time spent repartitioning, not including time spent in `f` will be recorded
754    /// to the [`metrics::Time`] provided on construction
755    pub fn partition<F>(&mut self, batch: RecordBatch, mut f: F) -> Result<()>
756    where
757        F: FnMut(usize, RecordBatch) -> Result<()>,
758    {
759        self.partition_iter(batch)?.try_for_each(|res| match res {
760            Ok((partition, batch)) => f(partition, batch),
761            Err(e) => Err(e),
762        })
763    }
764
765    /// Returns an iterator of `(partition_index, RecordBatch)` pairs for the given batch.
766    ///
767    /// This is useful for async consumers that want to separate CPU-bound partitioning
768    /// from I/O. For example, you can iterate results on the async side and send them
769    /// through a channel, while performing file I/O on a blocking task:
770    ///
771    /// ```ignore
772    /// for result in partitioner.partition_iter(batch)? {
773    ///     let (partition, batch) = result?;
774    ///     tx.send((partition, batch)).await?;
775    /// }
776    /// ```
777    ///
778    /// The sync [`partition`](Self::partition) method is implemented on top of this.
779    pub fn partition_iter(
780        &mut self,
781        batch: RecordBatch,
782    ) -> Result<impl Iterator<Item = Result<(usize, RecordBatch)>> + Send + '_> {
783        let it: Box<dyn Iterator<Item = Result<(usize, RecordBatch)>> + Send> =
784            match &mut self.state {
785                BatchPartitionerState::RoundRobin {
786                    num_partitions,
787                    next_idx,
788                } => {
789                    let idx = *next_idx;
790                    *next_idx = (*next_idx + 1) % *num_partitions;
791                    Box::new(std::iter::once(Ok((idx, batch))))
792                }
793                BatchPartitionerState::Hash {
794                    exprs,
795                    partition_reducer,
796                    hash_buffer,
797                    indices,
798                } => {
799                    // Tracking time required for distributing indexes across output partitions
800                    let timer = self.timer.timer();
801
802                    let arrays =
803                        evaluate_expressions_to_arrays(exprs.as_slice(), &batch)?;
804
805                    hash_buffer.clear();
806                    hash_buffer.resize(batch.num_rows(), 0);
807
808                    create_hashes(
809                        &arrays,
810                        REPARTITION_RANDOM_STATE.random_state(),
811                        hash_buffer,
812                    )?;
813
814                    indices.iter_mut().for_each(|v| v.clear());
815
816                    partition_reducer.partition_indices(hash_buffer, indices);
817
818                    // Finished building index-arrays for output partitions
819                    timer.done();
820
821                    let partitioned_batches =
822                        Self::partition_grouped_take(&batch, indices, &self.timer)?;
823
824                    Box::new(partitioned_batches.into_iter())
825                }
826            };
827
828        Ok(it)
829    }
830
831    // return the number of output partitions
832    fn num_partitions(&self) -> usize {
833        match &self.state {
834            BatchPartitionerState::RoundRobin { num_partitions, .. } => *num_partitions,
835            BatchPartitionerState::Hash { indices, .. } => indices.len(),
836        }
837    }
838
839    /// Build repartitioned hash output batches using one `take` per input batch.
840    ///
841    /// The hash router first fills one index vector per output partition. This method
842    /// concatenates those index vectors, performs one grouped `take_arrays`, and
843    /// then returns each output partition as a slice of the reordered batch.
844    ///
845    /// For example, given partition indices:
846    ///
847    /// ```text
848    /// partition 0: [2, 5]
849    /// partition 1: []
850    /// partition 2: [0, 3, 4]
851    /// ```
852    ///
853    /// this method takes rows in `[2, 5, 0, 3, 4]` order once, then returns
854    /// `partition 0 = slice(0, 2)` and `partition 2 = slice(2, 3)`.
855    fn partition_grouped_take(
856        batch: &RecordBatch,
857        indices: &mut [Vec<u32>],
858        timer: &metrics::Time,
859    ) -> Result<Vec<Result<(usize, RecordBatch)>>> {
860        let mut partition_ranges = Vec::with_capacity(indices.len());
861        let mut reordered_indices = Vec::with_capacity(batch.num_rows());
862
863        for (partition, p_indices) in indices.iter_mut().enumerate() {
864            if p_indices.is_empty() {
865                continue;
866            }
867
868            let start = reordered_indices.len();
869            reordered_indices.extend_from_slice(p_indices);
870            partition_ranges.push((partition, start, p_indices.len()));
871            p_indices.clear();
872        }
873
874        if reordered_indices.is_empty() {
875            return Ok(vec![]);
876        }
877
878        let batches = {
879            let _timer = timer.timer();
880            let indices_array: PrimitiveArray<UInt32Type> = reordered_indices.into();
881            let columns = take_arrays(batch.columns(), &indices_array, None)?;
882
883            let mut options = RecordBatchOptions::new();
884            options = options.with_row_count(Some(indices_array.len()));
885            let reordered_batch =
886                RecordBatch::try_new_with_options(batch.schema(), columns, &options)?;
887
888            partition_ranges
889                .into_iter()
890                .map(|(partition, start, len)| {
891                    Ok((partition, reordered_batch.slice(start, len)))
892                })
893                .collect()
894        };
895
896        Ok(batches)
897    }
898}
899
/// Maps `N` input partitions to `M` output partitions based on a
/// [`Partitioning`] scheme.
///
/// # Background
///
/// Like most other analytic query engines (DuckDB being a notable
/// exception), DataFusion uses the "Exchange Operator" based approach
/// to parallelism. This works well in practice given sufficient care in
/// implementation.
///
/// DataFusion's planner picks the target number of partitions and
/// then [`RepartitionExec`] redistributes [`RecordBatch`]es to that number
/// of output partitions.
///
/// For example, given `target_partitions=3` (trying to use 3 cores)
/// but scanning an input with 2 partitions, `RepartitionExec` can be
/// used to get 3 even streams of `RecordBatch`es:
///
///
///```text
///        ▲                  ▲                  ▲
///        │                  │                  │
///        │                  │                  │
///        │                  │                  │
/// ┌───────────────┐  ┌───────────────┐  ┌───────────────┐
/// │    GroupBy    │  │    GroupBy    │  │    GroupBy    │
/// │   (Partial)   │  │   (Partial)   │  │   (Partial)   │
/// └───────────────┘  └───────────────┘  └───────────────┘
///        ▲                  ▲                  ▲
///        └──────────────────┼──────────────────┘
///                           │
///              ┌─────────────────────────┐
///              │     RepartitionExec     │
///              │   (hash/round robin)    │
///              └─────────────────────────┘
///                         ▲   ▲
///             ┌───────────┘   └───────────┐
///             │                           │
///             │                           │
///        .─────────.                 .─────────.
///     ,─'           '─.           ,─'           '─.
///    ;      Input      :         ;      Input      :
///    :   Partition 0   ;         :   Partition 1   ;
///     ╲               ╱           ╲               ╱
///      '─.         ,─'             '─.         ,─'
///         `───────'                   `───────'
/// ```
///
/// # Error Handling
///
/// If any of the input partitions return an error, the error is propagated to
/// all output partitions and inputs are not polled again.
///
/// # Output Ordering
///
/// If more than one stream is being repartitioned, the output will be some
/// arbitrary interleaving (and thus unordered) unless
/// [`Self::with_preserve_order`] specifies otherwise.
///
/// # Batch coalescing
///
/// Repartitioning one [`RecordBatch`] implies creating multiple smaller batches, potentially
/// as many as the number of output partitions. [`RepartitionExec`] makes sure that the returned
/// batches adhere to the configured `datafusion.execution.batch_size` for efficient operations.
/// To do so, it automatically coalesces batches right after repartitioning.
///
/// For this, one shared [`LimitedBatchCoalescer`] per output partition is used:
///
/// ```text
///                         ┌───┐                           ┌───┐
///                      ┌─▶│   │────────▶.───────────.     │   │     ┌──────────────────┐
///                      │  └───┘ ┌───┐  ( Coalescer 0 )──▶ ├───┤ ───▶│     Output 0     │
///                      │┌──────▶│   │──▶`───────────'     │   │     └──────────────────┘
///                      ││       └───┘                     └───┘
/// ┌──────────────────┐ ││                                           ┌──────────────────┐
/// │BatchPartitioner 0│─┘│                                           │     Output 1     │
/// └──────────────────┘  │                                           └──────────────────┘
///                       │
/// ┌──────────────────┐  │                ...                        ┌──────────────────┐
/// │BatchPartitioner 1│──┘                                           │     Output 2     │
/// └──────────────────┘                                              └──────────────────┘
///
///                                                                   ┌──────────────────┐
///                                                                   │     Output 3     │
///                                                                   └──────────────────┘
/// ```
///
/// # Spilling Architecture
///
/// RepartitionExec uses [`SpillPool`](crate::spill::spill_pool) channels to handle
/// memory pressure during repartitioning. Each (input partition, output partition)
/// pair gets its own SpillPool channel for FIFO ordering.
///
/// ```text
/// Input Partitions (N)          Output Partitions (M)
/// ────────────────────          ─────────────────────
///
///    Input 0 ──┐                      ┌──▶ Output 0
///              │  ┌──────────────┐    │
///              ├─▶│ SpillPool    │────┤
///              │  │ [In0→Out0]   │    │
///    Input 1 ──┤  └──────────────┘    ├──▶ Output 1
///              │                      │
///              │  ┌──────────────┐    │
///              ├─▶│ SpillPool    │────┤
///              │  │ [In1→Out0]   │    │
///    Input 2 ──┤  └──────────────┘    ├──▶ Output 2
///              │                      │
///              │       ... (N×M SpillPools total)
///              │                      │
///              │  ┌──────────────┐    │
///              └─▶│ SpillPool    │────┘
///                 │ [InN→OutM]   │
///                 └──────────────┘
/// ```
///
/// Each SpillPool maintains FIFO order for its (input, output) pair.
/// See `RepartitionBatch` for details on the memory/spill decision logic.
///
/// # Footnote
///
/// The "Exchange Operator" was first described in the 1989 paper
/// [Encapsulation of parallelism in the Volcano query processing
/// system](https://dl.acm.org/doi/pdf/10.1145/93605.98720),
/// which uses the term "Exchange" for the concept of repartitioning
/// data across threads.
///
/// For more background, please also see the [Optimizing Repartitions in DataFusion] blog.
///
/// [Optimizing Repartitions in DataFusion]: https://datafusion.apache.org/blog/2025/12/15/avoid-consecutive-repartitions
#[derive(Debug, Clone)]
pub struct RepartitionExec {
    /// Input execution plan
    input: Arc<dyn ExecutionPlan>,
    /// Inner state that is initialized when the parent calls .execute() on this node
    /// and consumed as soon as the parent starts consuming this node.
    /// It sits behind `Arc<Mutex<..>>` so that the `execute` calls for all output
    /// partitions of this node operate on the same state.
    state: Arc<Mutex<RepartitionExecState>>,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Boolean flag to decide whether to preserve ordering. If true means
    /// `SortPreservingRepartitionExec`, false means `RepartitionExec`.
    preserve_order: bool,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: Arc<PlanProperties>,
}
1045
/// Timing metrics for one input partition of a [`RepartitionExec`].
///
/// The fetch and repartition timers are registered per input partition. The
/// send timers are additionally labelled with the output partition they
/// measure (see [`RepartitionMetrics::new`]).
#[derive(Debug, Clone)]
struct RepartitionMetrics {
    /// Time in nanos to execute child operator and fetch batches
    fetch_time: metrics::Time,
    /// Repartitioning elapsed time in nanos
    repartition_time: metrics::Time,
    /// Time in nanos for sending resulting batches to channels.
    ///
    /// One metric per output partition.
    send_time: Vec<metrics::Time>,
}
1057
1058impl RepartitionMetrics {
1059    pub fn new(
1060        input_partition: usize,
1061        num_output_partitions: usize,
1062        metrics: &ExecutionPlanMetricsSet,
1063    ) -> Self {
1064        // Time in nanos to execute child operator and fetch batches
1065        let fetch_time =
1066            MetricBuilder::new(metrics).subset_time("fetch_time", input_partition);
1067
1068        // Time in nanos to perform repartitioning
1069        let repartition_time =
1070            MetricBuilder::new(metrics).subset_time("repartition_time", input_partition);
1071
1072        // Time in nanos for sending resulting batches to channels
1073        let send_time = (0..num_output_partitions)
1074            .map(|output_partition| {
1075                let label =
1076                    metrics::Label::new("outputPartition", output_partition.to_string());
1077                MetricBuilder::new(metrics)
1078                    .with_label(label)
1079                    .subset_time("send_time", input_partition)
1080            })
1081            .collect();
1082
1083        Self {
1084            fetch_time,
1085            repartition_time,
1086            send_time,
1087        }
1088    }
1089}
1090
1091impl RepartitionExec {
1092    /// Input execution plan
1093    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
1094        &self.input
1095    }
1096
1097    /// Partitioning scheme to use
1098    pub fn partitioning(&self) -> &Partitioning {
1099        &self.cache.partitioning
1100    }
1101
1102    /// Get preserve_order flag of the RepartitionExec
1103    /// `true` means `SortPreservingRepartitionExec`, `false` means `RepartitionExec`
1104    pub fn preserve_order(&self) -> bool {
1105        self.preserve_order
1106    }
1107
1108    /// Get name used to display this Exec
1109    pub fn name(&self) -> &str {
1110        "RepartitionExec"
1111    }
1112
1113    fn with_new_children_and_same_properties(
1114        &self,
1115        mut children: Vec<Arc<dyn ExecutionPlan>>,
1116    ) -> Self {
1117        Self {
1118            input: children.swap_remove(0),
1119            metrics: ExecutionPlanMetricsSet::new(),
1120            state: Default::default(),
1121            ..Self::clone(self)
1122        }
1123    }
1124}
1125
1126impl DisplayAs for RepartitionExec {
1127    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
1128        let input_partition_count = self.input.output_partitioning().partition_count();
1129        match t {
1130            DisplayFormatType::Default | DisplayFormatType::Verbose => {
1131                write!(
1132                    f,
1133                    "{}: partitioning={}, input_partitions={}",
1134                    self.name(),
1135                    self.partitioning(),
1136                    input_partition_count,
1137                )?;
1138
1139                if self.preserve_order {
1140                    write!(f, ", preserve_order=true")?;
1141                } else if input_partition_count <= 1
1142                    && self.input.output_ordering().is_some()
1143                {
1144                    // Make it explicit that repartition maintains sortedness for a single input partition even
1145                    // when `preserve_sort order` is false
1146                    write!(f, ", maintains_sort_order=true")?;
1147                }
1148
1149                if let Some(sort_exprs) = self.sort_exprs() {
1150                    write!(f, ", sort_exprs={}", sort_exprs.clone())?;
1151                }
1152                Ok(())
1153            }
1154            DisplayFormatType::TreeRender => {
1155                writeln!(f, "partitioning_scheme={}", self.partitioning(),)?;
1156                let output_partition_count = self.partitioning().partition_count();
1157                let input_to_output_partition_str =
1158                    format!("{input_partition_count} -> {output_partition_count}");
1159                writeln!(
1160                    f,
1161                    "partition_count(in->out)={input_to_output_partition_str}"
1162                )?;
1163
1164                if self.preserve_order {
1165                    writeln!(f, "preserve_order={}", self.preserve_order)?;
1166                }
1167                Ok(())
1168            }
1169        }
1170    }
1171}
1172
1173impl ExecutionPlan for RepartitionExec {
1174    fn name(&self) -> &'static str {
1175        "RepartitionExec"
1176    }
1177
1178    /// Return a reference to Any that can be used for downcasting
1179    fn properties(&self) -> &Arc<PlanProperties> {
1180        &self.cache
1181    }
1182
1183    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
1184        vec![&self.input]
1185    }
1186
1187    fn with_new_children(
1188        self: Arc<Self>,
1189        mut children: Vec<Arc<dyn ExecutionPlan>>,
1190    ) -> Result<Arc<dyn ExecutionPlan>> {
1191        check_if_same_properties!(self, children);
1192        let mut repartition = RepartitionExec::try_new(
1193            children.swap_remove(0),
1194            self.partitioning().clone(),
1195        )?;
1196        if self.preserve_order {
1197            repartition = repartition.with_preserve_order();
1198        }
1199        Ok(Arc::new(repartition))
1200    }
1201
1202    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
1203        vec![matches!(self.partitioning(), Partitioning::Hash(_, _))]
1204    }
1205
1206    fn maintains_input_order(&self) -> Vec<bool> {
1207        Self::maintains_input_order_helper(self.input(), self.preserve_order)
1208    }
1209
1210    fn execute(
1211        &self,
1212        partition: usize,
1213        context: Arc<TaskContext>,
1214    ) -> Result<SendableRecordBatchStream> {
1215        trace!(
1216            "Start {}::execute for partition: {}",
1217            self.name(),
1218            partition
1219        );
1220
1221        let spill_metrics = SpillMetrics::new(&self.metrics, partition);
1222
1223        let input = Arc::clone(&self.input);
1224        let partitioning = self.partitioning().clone();
1225        let metrics = self.metrics.clone();
1226        let preserve_order = self.sort_exprs().is_some();
1227        let name = self.name().to_owned();
1228        let schema = self.schema();
1229        let schema_captured = Arc::clone(&schema);
1230
1231        let spill_manager = SpillManager::new(
1232            Arc::clone(&context.runtime_env()),
1233            spill_metrics,
1234            input.schema(),
1235        );
1236
1237        // Get existing ordering to use for merging
1238        let sort_exprs = self.sort_exprs().cloned();
1239
1240        let state = Arc::clone(&self.state);
1241        if let Some(mut state) = state.try_lock() {
1242            state.ensure_input_streams_initialized(
1243                &input,
1244                &metrics,
1245                partitioning.partition_count(),
1246                &context,
1247            )?;
1248        }
1249
1250        let num_input_partitions = input.output_partitioning().partition_count();
1251
1252        let stream = futures::stream::once(async move {
1253            // lock scope
1254            let (rx, reservation, spill_readers, abort_helper) = {
1255                // lock mutexes
1256                let mut state = state.lock();
1257                let state = state.consume_input_streams(
1258                    &input,
1259                    &metrics,
1260                    &partitioning,
1261                    preserve_order,
1262                    &name,
1263                    &context,
1264                    spill_manager.clone(),
1265                )?;
1266
1267                // now return stream for the specified *output* partition which will
1268                // read from the channel
1269                let PartitionChannels {
1270                    rx,
1271                    reservation,
1272                    spill_readers,
1273                    ..
1274                } = state
1275                    .channels
1276                    .remove(&partition)
1277                    .expect("partition not used yet");
1278
1279                (
1280                    rx,
1281                    reservation,
1282                    spill_readers,
1283                    Arc::clone(&state.abort_helper),
1284                )
1285            };
1286
1287            trace!(
1288                "Before returning stream in {name}::execute for partition: {partition}"
1289            );
1290
1291            if preserve_order {
1292                // Store streams from all the input partitions:
1293                // Each input partition gets its own spill reader to maintain proper FIFO ordering
1294                let input_streams = rx
1295                    .into_iter()
1296                    .zip(spill_readers)
1297                    .map(|(receiver, spill_stream)| {
1298                        // In preserve_order mode, each receiver corresponds to exactly one input partition
1299                        Box::pin(PerPartitionStream::new(
1300                            Arc::clone(&schema_captured),
1301                            receiver,
1302                            Arc::clone(&abort_helper),
1303                            Arc::clone(&reservation),
1304                            spill_stream,
1305                            1, // Each receiver handles one input partition
1306                            BaselineMetrics::new(&metrics, partition),
1307                        )) as SendableRecordBatchStream
1308                    })
1309                    .collect::<Vec<_>>();
1310                // Note that receiver size (`rx.len()`) and `num_input_partitions` are same.
1311
1312                // Merge streams (while preserving ordering) coming from
1313                // input partitions to this partition:
1314                let fetch = None;
1315                let merge_reservation =
1316                    MemoryConsumer::new(format!("{name}[Merge {partition}]"))
1317                        .register(context.memory_pool());
1318                StreamingMergeBuilder::new()
1319                    .with_streams(input_streams)
1320                    .with_schema(schema_captured)
1321                    .with_expressions(&sort_exprs.unwrap())
1322                    .with_metrics(BaselineMetrics::new(&metrics, partition))
1323                    .with_batch_size(context.session_config().batch_size())
1324                    .with_fetch(fetch)
1325                    .with_reservation(merge_reservation)
1326                    .with_spill_manager(spill_manager)
1327                    .build()
1328            } else {
1329                // Non-preserve-order case: single input stream, so use the first spill reader
1330                let spill_stream = spill_readers
1331                    .into_iter()
1332                    .next()
1333                    .expect("at least one spill reader should exist");
1334
1335                Ok(Box::pin(PerPartitionStream::new(
1336                    schema_captured,
1337                    rx.into_iter()
1338                        .next()
1339                        .expect("at least one receiver should exist"),
1340                    abort_helper,
1341                    reservation,
1342                    spill_stream,
1343                    num_input_partitions,
1344                    BaselineMetrics::new(&metrics, partition),
1345                )) as SendableRecordBatchStream)
1346            }
1347        })
1348        .try_flatten();
1349        let stream = RecordBatchStreamAdapter::new(schema, stream);
1350        Ok(Box::pin(stream))
1351    }
1352
1353    fn metrics(&self) -> Option<MetricsSet> {
1354        Some(self.metrics.clone_inner())
1355    }
1356
1357    fn partition_statistics(&self, partition: Option<usize>) -> Result<Arc<Statistics>> {
1358        if let Some(partition) = partition {
1359            let partition_count = self.partitioning().partition_count();
1360            if partition_count == 0 {
1361                return Ok(Arc::new(Statistics::new_unknown(&self.schema())));
1362            }
1363
1364            assert_or_internal_err!(
1365                partition < partition_count,
1366                "RepartitionExec invalid partition {} (expected less than {})",
1367                partition,
1368                partition_count
1369            );
1370
1371            let mut stats = Arc::unwrap_or_clone(self.input.partition_statistics(None)?);
1372
1373            // Distribute statistics across partitions
1374            stats.num_rows = stats
1375                .num_rows
1376                .get_value()
1377                .map(|rows| Precision::Inexact(rows / partition_count))
1378                .unwrap_or(Precision::Absent);
1379            stats.total_byte_size = stats
1380                .total_byte_size
1381                .get_value()
1382                .map(|bytes| Precision::Inexact(bytes / partition_count))
1383                .unwrap_or(Precision::Absent);
1384
1385            // Make all column stats unknown
1386            stats.column_statistics = stats
1387                .column_statistics
1388                .iter()
1389                .map(|_| ColumnStatistics::new_unknown())
1390                .collect();
1391
1392            Ok(Arc::new(stats))
1393        } else {
1394            self.input.partition_statistics(None)
1395        }
1396    }
1397
1398    fn cardinality_effect(&self) -> CardinalityEffect {
1399        CardinalityEffect::Equal
1400    }
1401
1402    fn try_swapping_with_projection(
1403        &self,
1404        projection: &ProjectionExec,
1405    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
1406        // If the projection does not narrow the schema, we should not try to push it down.
1407        if projection.expr().len() >= projection.input().schema().fields().len() {
1408            return Ok(None);
1409        }
1410
1411        // If pushdown is not beneficial or applicable, break it.
1412        if projection.benefits_from_input_partitioning()[0]
1413            || !all_columns(projection.expr())
1414        {
1415            return Ok(None);
1416        }
1417
1418        let new_projection = make_with_child(projection, self.input())?;
1419
1420        let new_partitioning = match self.partitioning() {
1421            Partitioning::Hash(partitions, size) => {
1422                let mut new_partitions = vec![];
1423                for partition in partitions {
1424                    let Some(new_partition) =
1425                        update_expr(partition, projection.expr(), false)?
1426                    else {
1427                        return Ok(None);
1428                    };
1429                    new_partitions.push(new_partition);
1430                }
1431                Partitioning::Hash(new_partitions, *size)
1432            }
1433            others => others.clone(),
1434        };
1435
1436        Ok(Some(Arc::new(RepartitionExec::try_new(
1437            new_projection,
1438            new_partitioning,
1439        )?)))
1440    }
1441
1442    fn gather_filters_for_pushdown(
1443        &self,
1444        _phase: FilterPushdownPhase,
1445        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
1446        _config: &ConfigOptions,
1447    ) -> Result<FilterDescription> {
1448        FilterDescription::from_children(parent_filters, &self.children())
1449    }
1450
1451    fn handle_child_pushdown_result(
1452        &self,
1453        _phase: FilterPushdownPhase,
1454        child_pushdown_result: ChildPushdownResult,
1455        _config: &ConfigOptions,
1456    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
1457        Ok(FilterPushdownPropagation::if_all(child_pushdown_result))
1458    }
1459
1460    fn try_pushdown_sort(
1461        &self,
1462        order: &[PhysicalSortExpr],
1463    ) -> Result<SortOrderPushdownResult<Arc<dyn ExecutionPlan>>> {
1464        // RepartitionExec only maintains input order if preserve_order is set
1465        // or if there's only one partition
1466        if !self.maintains_input_order()[0] {
1467            return Ok(SortOrderPushdownResult::Unsupported);
1468        }
1469
1470        // Delegate to the child and wrap with a new RepartitionExec
1471        self.input.try_pushdown_sort(order)?.try_map(|new_input| {
1472            let mut new_repartition =
1473                RepartitionExec::try_new(new_input, self.partitioning().clone())?;
1474            if self.preserve_order {
1475                new_repartition = new_repartition.with_preserve_order();
1476            }
1477            Ok(Arc::new(new_repartition) as Arc<dyn ExecutionPlan>)
1478        })
1479    }
1480
1481    fn repartitioned(
1482        &self,
1483        target_partitions: usize,
1484        _config: &ConfigOptions,
1485    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
1486        use Partitioning::*;
1487        let mut new_properties = PlanProperties::clone(&self.cache);
1488        new_properties.partitioning = match new_properties.partitioning {
1489            RoundRobinBatch(_) => RoundRobinBatch(target_partitions),
1490            Hash(hash, _) => Hash(hash, target_partitions),
1491            UnknownPartitioning(_) => UnknownPartitioning(target_partitions),
1492        };
1493        Ok(Some(Arc::new(Self {
1494            input: Arc::clone(&self.input),
1495            state: Arc::clone(&self.state),
1496            metrics: self.metrics.clone(),
1497            preserve_order: self.preserve_order,
1498            cache: new_properties.into(),
1499        })))
1500    }
1501}
1502
1503impl RepartitionExec {
1504    /// Create a new RepartitionExec, that produces output `partitioning`, and
1505    /// does not preserve the order of the input (see [`Self::with_preserve_order`]
1506    /// for more details)
1507    pub fn try_new(
1508        input: Arc<dyn ExecutionPlan>,
1509        partitioning: Partitioning,
1510    ) -> Result<Self> {
1511        let preserve_order = false;
1512        let cache = Self::compute_properties(&input, partitioning, preserve_order);
1513        Ok(RepartitionExec {
1514            input,
1515            state: Default::default(),
1516            metrics: ExecutionPlanMetricsSet::new(),
1517            preserve_order,
1518            cache: Arc::new(cache),
1519        })
1520    }
1521
1522    fn maintains_input_order_helper(
1523        input: &Arc<dyn ExecutionPlan>,
1524        preserve_order: bool,
1525    ) -> Vec<bool> {
1526        // We preserve ordering when repartition is order preserving variant or input partitioning is 1
1527        vec![preserve_order || input.output_partitioning().partition_count() <= 1]
1528    }
1529
1530    fn eq_properties_helper(
1531        input: &Arc<dyn ExecutionPlan>,
1532        preserve_order: bool,
1533    ) -> EquivalenceProperties {
1534        // Equivalence Properties
1535        let mut eq_properties = input.equivalence_properties().clone();
1536        // If the ordering is lost, reset the ordering equivalence class:
1537        if !Self::maintains_input_order_helper(input, preserve_order)[0] {
1538            eq_properties.clear_orderings();
1539        }
1540        // When there are more than one input partitions, they will be fused at the output.
1541        // Therefore, remove per partition constants.
1542        if input.output_partitioning().partition_count() > 1 {
1543            eq_properties.clear_per_partition_constants();
1544        }
1545        eq_properties
1546    }
1547
1548    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
1549    fn compute_properties(
1550        input: &Arc<dyn ExecutionPlan>,
1551        partitioning: Partitioning,
1552        preserve_order: bool,
1553    ) -> PlanProperties {
1554        PlanProperties::new(
1555            Self::eq_properties_helper(input, preserve_order),
1556            partitioning,
1557            input.pipeline_behavior(),
1558            input.boundedness(),
1559        )
1560        .with_scheduling_type(SchedulingType::Cooperative)
1561        .with_evaluation_type(EvaluationType::Eager)
1562    }
1563
1564    /// Specify if this repartitioning operation should preserve the order of
1565    /// rows from its input when producing output. Preserving order is more
1566    /// expensive at runtime, so should only be set if the output of this
1567    /// operator can take advantage of it.
1568    ///
1569    /// If the input is not ordered, or has only one partition, this is a no op,
1570    /// and the node remains a `RepartitionExec`.
1571    pub fn with_preserve_order(mut self) -> Self {
1572        self.preserve_order =
1573                // If the input isn't ordered, there is no ordering to preserve
1574                self.input.output_ordering().is_some() &&
1575                // if there is only one input partition, merging is not required
1576                // to maintain order
1577                self.input.output_partitioning().partition_count() > 1;
1578        let eq_properties = Self::eq_properties_helper(&self.input, self.preserve_order);
1579        Arc::make_mut(&mut self.cache).set_eq_properties(eq_properties);
1580        self
1581    }
1582
1583    /// Return the sort expressions that are used to merge
1584    fn sort_exprs(&self) -> Option<&LexOrdering> {
1585        if self.preserve_order {
1586            self.input.output_ordering()
1587        } else {
1588            None
1589        }
1590    }
1591
1592    /// Pulls data from the specified input plan, feeding it to the
1593    /// output partitions based on the desired partitioning
1594    ///
1595    /// `output_channels` holds the output sending channels for each output partition
1596    async fn pull_from_input(
1597        mut stream: SendableRecordBatchStream,
1598        mut output_channels: HashMap<usize, OutputChannel>,
1599        partitioning: Partitioning,
1600        metrics: RepartitionMetrics,
1601        input_partition: usize,
1602        num_input_partitions: usize,
1603    ) -> Result<()> {
1604        let mut partitioner = match &partitioning {
1605            Partitioning::Hash(exprs, num_partitions) => {
1606                BatchPartitioner::new_hash_partitioner(
1607                    exprs.clone(),
1608                    *num_partitions,
1609                    metrics.repartition_time.clone(),
1610                )?
1611            }
1612            Partitioning::RoundRobinBatch(num_partitions) => {
1613                BatchPartitioner::new_round_robin_partitioner(
1614                    *num_partitions,
1615                    metrics.repartition_time.clone(),
1616                    input_partition,
1617                    num_input_partitions,
1618                )
1619            }
1620            other => {
1621                return not_impl_err!("Unsupported repartitioning scheme {other:?}");
1622            }
1623        };
1624
1625        // While there are still outputs to send to, keep pulling inputs
1626        let mut batches_until_yield = partitioner.num_partitions();
1627        while !output_channels.is_empty() {
1628            // fetch the next batch
1629            let timer = metrics.fetch_time.timer();
1630            let result = stream.next().await;
1631            timer.done();
1632
1633            // Input is done
1634            let batch = match result {
1635                Some(result) => result?,
1636                None => break,
1637            };
1638
1639            // Handle empty batch
1640            if batch.num_rows() == 0 {
1641                continue;
1642            }
1643
1644            for res in partitioner.partition_iter(batch)? {
1645                let (partition, batch) = res?;
1646
1647                let timer = metrics.send_time[partition].timer();
1648                // if there is still a receiver, send to it
1649                if let Some(output_channel) = output_channels.get_mut(&partition) {
1650                    for batch in output_channel.coalesce(batch)? {
1651                        if output_channel.send(batch).await.is_err() {
1652                            // If the other end has hung up, it was an early shutdown (e.g. LIMIT)
1653                            // so ignore this channel from now on.
1654                            output_channels.remove(&partition);
1655                            break;
1656                        }
1657                    }
1658                }
1659                timer.done();
1660            }
1661
1662            // If the input stream is endless, we may spin forever and
1663            // never yield back to tokio.  See
1664            // https://github.com/apache/datafusion/issues/5278.
1665            //
1666            // However, yielding on every batch causes a bottleneck
1667            // when running with multiple cores. See
1668            // https://github.com/apache/datafusion/issues/6290
1669            //
1670            // Thus, heuristically yield after producing num_partition
1671            // batches
1672            //
1673            // In round robin this is ideal as each input will get a
1674            // new batch. In hash partitioning it may yield too often
1675            // on uneven distributions even if some partition can not
1676            // make progress, but parallelism is going to be limited
1677            // in that case anyways
1678            if batches_until_yield == 0 {
1679                tokio::task::yield_now().await;
1680                batches_until_yield = partitioner.num_partitions();
1681            } else {
1682                batches_until_yield -= 1;
1683            }
1684        }
1685
1686        // End of input for this task. For each output partition we still
1687        // have a channel to, decrement the active-senders counter; whoever
1688        // sees the count drop to zero is the last input task and must
1689        // finalize the shared coalescer and ship its residual.
1690        for (_, output_channel) in output_channels.drain() {
1691            output_channel.finalize().await?;
1692        }
1693
1694        // Spill writers will auto-finalize when dropped
1695        // No need for explicit flush
1696        Ok(())
1697    }
1698
1699    /// Waits for `input_task` which is consuming one of the inputs to
1700    /// complete. Upon each successful completion, sends a `None` to
1701    /// each of the output tx channels to signal one of the inputs is
1702    /// complete. Upon error, propagates the errors to all output tx
1703    /// channels.
1704    async fn wait_for_task(
1705        input_task: SpawnedTask<Result<()>>,
1706        txs: HashMap<usize, DistributionSender<MaybeBatch>>,
1707    ) {
1708        // wait for completion, and propagate error
1709        // note we ignore errors on send (.ok) as that means the receiver has already shutdown.
1710
1711        match input_task.join().await {
1712            // Error in joining task
1713            Err(e) => {
1714                let e = Arc::new(e);
1715
1716                for (_, tx) in txs {
1717                    let err = Err(DataFusionError::Context(
1718                        "Join Error".to_string(),
1719                        Box::new(DataFusionError::External(Box::new(Arc::clone(&e)))),
1720                    ));
1721                    tx.send(Some(err)).await.ok();
1722                }
1723            }
1724            // Error from running input task
1725            Ok(Err(e)) => {
1726                // send the same Arc'd error to all output partitions
1727                let e = Arc::new(e);
1728
1729                for (_, tx) in txs {
1730                    // wrap it because need to send error to all output partitions
1731                    let err = Err(DataFusionError::from(&e));
1732                    tx.send(Some(err)).await.ok();
1733                }
1734            }
1735            // Input task completed successfully
1736            Ok(Ok(())) => {
1737                // notify each output partition that this input partition has no more data
1738                for (_partition, tx) in txs {
1739                    tx.send(None).await.ok();
1740                }
1741            }
1742        }
1743    }
1744}
1745
/// Where a `PerPartitionStream` takes its next batch from.
///
/// Batches reach an output partition either in memory, directly through its
/// channel, or through the spill pool. In the spill case the channel carries
/// only a [`RepartitionBatch::Spilled`] marker, and the batch itself must then
/// be read from the spill stream.
///
/// # Transitions
///
/// ```text
///                    channel: Spilled marker
///   ┌───────────────┐ ─────────────────────▶ ┌────────────────┐
///   │ ReadingMemory │                        │ ReadingSpilled │
///   └───────────────┘ ◀───────────────────── └────────────────┘
///      │        ▲     spill stream: batch
///      └────────┘     (returned) or end of stream
///   channel: Memory(batch)
///   (returned, state unchanged)
/// ```
///
/// While in `ReadingSpilled` the channel is not polled at all. The spilled
/// batch is therefore handed out before anything queued behind its marker,
/// which keeps batches in FIFO order within each (input, output) pair.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum StreamState {
    /// Normal operation: the next item comes from the in-memory channel.
    ReadingMemory,
    /// A spill marker was received. Wait on the spill stream for the matching
    /// batch, and do not touch the channel until it arrives.
    ReadingSpilled,
}
1796
/// This struct converts the receiving end of one output partition's channel
/// into a stream of record batches.
///
/// Batches arrive either directly through `receiver`, or as a spill marker on
/// `receiver` whose batch is then read from `spill_stream`. Depending on the
/// mode, one or more input partitions send into the same channel.
struct PerPartitionStream {
    /// Schema wrapped by Arc
    schema: SchemaRef,

    /// Channel carrying the repartitioned batches (or spill markers), plus one
    /// `None` end-of-input message per sending input partition.
    receiver: DistributionReceiver<MaybeBatch>,

    /// Handle to ensure background tasks are killed when no longer needed.
    _drop_helper: Arc<Vec<SpawnedTask<()>>>,

    /// Memory reservation; shrunk by each in-memory batch's size as that batch
    /// is returned from the stream.
    reservation: SharedMemoryReservation,

    /// Stream of spilled batches for this partition. It is polled only after a
    /// spill marker is received; once it reports end-of-stream it is replaced
    /// by an empty stream.
    spill_stream: SendableRecordBatchStream,

    /// Internal state indicating if we are reading from memory or spill stream
    state: StreamState,

    /// Number of input partitions that have not yet finished.
    /// In non-preserve-order mode, multiple input partitions send to the same channel,
    /// each sending None when complete. We must wait for all of them.
    remaining_partitions: usize,

    /// Execution metrics
    baseline_metrics: BaselineMetrics,
}
1826
1827impl PerPartitionStream {
1828    fn new(
1829        schema: SchemaRef,
1830        receiver: DistributionReceiver<MaybeBatch>,
1831        drop_helper: Arc<Vec<SpawnedTask<()>>>,
1832        reservation: SharedMemoryReservation,
1833        spill_stream: SendableRecordBatchStream,
1834        num_input_partitions: usize,
1835        baseline_metrics: BaselineMetrics,
1836    ) -> Self {
1837        Self {
1838            schema,
1839            receiver,
1840            _drop_helper: drop_helper,
1841            reservation,
1842            spill_stream,
1843            state: StreamState::ReadingMemory,
1844            remaining_partitions: num_input_partitions,
1845            baseline_metrics,
1846        }
1847    }
1848
1849    fn poll_next_inner(
1850        self: &mut Pin<&mut Self>,
1851        cx: &mut Context<'_>,
1852    ) -> Poll<Option<Result<RecordBatch>>> {
1853        use futures::StreamExt;
1854        let cloned_time = self.baseline_metrics.elapsed_compute().clone();
1855        let _timer = cloned_time.timer();
1856
1857        loop {
1858            match self.state {
1859                StreamState::ReadingMemory => {
1860                    // Poll the memory channel for next message
1861                    let value = match self.receiver.recv().poll_unpin(cx) {
1862                        Poll::Ready(v) => v,
1863                        Poll::Pending => {
1864                            // Nothing from channel, wait
1865                            return Poll::Pending;
1866                        }
1867                    };
1868
1869                    match value {
1870                        Some(Some(v)) => match v {
1871                            Ok(RepartitionBatch::Memory(batch)) => {
1872                                // Release memory and return batch
1873                                self.reservation.shrink(batch.get_array_memory_size());
1874                                return Poll::Ready(Some(Ok(batch)));
1875                            }
1876                            Ok(RepartitionBatch::Spilled) => {
1877                                // Batch was spilled, transition to reading from spill stream
1878                                // We must block on spill stream until we get the batch
1879                                // to preserve ordering
1880                                self.state = StreamState::ReadingSpilled;
1881                                continue;
1882                            }
1883                            Err(e) => {
1884                                return Poll::Ready(Some(Err(e)));
1885                            }
1886                        },
1887                        Some(None) => {
1888                            // One input partition finished
1889                            self.remaining_partitions -= 1;
1890                            if self.remaining_partitions == 0 {
1891                                // All input partitions finished
1892                                return Poll::Ready(None);
1893                            }
1894                            // Continue to poll for more data from other partitions
1895                            continue;
1896                        }
1897                        None => {
1898                            // Channel closed unexpectedly
1899                            return Poll::Ready(None);
1900                        }
1901                    }
1902                }
1903                StreamState::ReadingSpilled => {
1904                    // Poll spill stream for the spilled batch
1905                    match self.spill_stream.poll_next_unpin(cx) {
1906                        Poll::Ready(Some(Ok(batch))) => {
1907                            self.state = StreamState::ReadingMemory;
1908                            return Poll::Ready(Some(Ok(batch)));
1909                        }
1910                        Poll::Ready(Some(Err(e))) => {
1911                            return Poll::Ready(Some(Err(e)));
1912                        }
1913                        Poll::Ready(None) => {
1914                            // Spill stream ended — release its resources before
1915                            // we go back to draining the memory channel.
1916                            let spill_schema = self.spill_stream.schema();
1917                            self.spill_stream =
1918                                Box::pin(EmptyRecordBatchStream::new(spill_schema));
1919                            self.state = StreamState::ReadingMemory;
1920                        }
1921                        Poll::Pending => {
1922                            // Spilled batch not ready yet, must wait
1923                            // This preserves ordering by blocking until spill data arrives
1924                            return Poll::Pending;
1925                        }
1926                    }
1927                }
1928            }
1929        }
1930    }
1931}
1932
1933impl Stream for PerPartitionStream {
1934    type Item = Result<RecordBatch>;
1935
1936    fn poll_next(
1937        mut self: Pin<&mut Self>,
1938        cx: &mut Context<'_>,
1939    ) -> Poll<Option<Self::Item>> {
1940        let poll = self.poll_next_inner(cx);
1941        self.baseline_metrics.record_poll(poll)
1942    }
1943}
1944
1945impl RecordBatchStream for PerPartitionStream {
1946    /// Get the schema
1947    fn schema(&self) -> SchemaRef {
1948        Arc::clone(&self.schema)
1949    }
1950}
1951
1952#[cfg(test)]
1953mod tests {
1954    use std::collections::HashSet;
1955
1956    use super::*;
1957    use crate::test::TestMemoryExec;
1958    use crate::{
1959        test::{
1960            assert_is_pending,
1961            exec::{
1962                BarrierExec, BlockingExec, ErrorExec, MockExec,
1963                assert_strong_count_converges_to_zero,
1964            },
1965        },
1966        {collect, expressions::col},
1967    };
1968
1969    use arrow::array::{ArrayRef, StringArray, UInt32Array};
1970    use arrow::datatypes::{DataType, Field, Schema};
1971    use datafusion_common::cast::as_string_array;
1972    use datafusion_common::exec_err;
1973    use datafusion_common::test_util::batches_to_sort_string;
1974    use datafusion_common_runtime::JoinSet;
1975    use datafusion_execution::config::SessionConfig;
1976    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
1977    use insta::assert_snapshot;
1978
1979    #[test]
1980    fn strength_reduced_u64_remainder_matches_modulo() {
1981        let divisors = [
1982            1,
1983            2,
1984            3,
1985            4,
1986            5,
1987            7,
1988            8,
1989            10,
1990            16,
1991            31,
1992            32,
1993            63,
1994            64,
1995            65,
1996            97,
1997            u64::from(u32::MAX),
1998            u64::from(u32::MAX) + 1,
1999            1_u64 << 32,
2000            (1_u64 << 63) - 1,
2001            1_u64 << 63,
2002            u64::MAX - 1,
2003            u64::MAX,
2004        ];
2005        let values = [
2006            0,
2007            1,
2008            2,
2009            3,
2010            4,
2011            5,
2012            31,
2013            32,
2014            33,
2015            63,
2016            64,
2017            65,
2018            u64::from(u32::MAX) - 1,
2019            u64::from(u32::MAX),
2020            u64::from(u32::MAX) + 1,
2021            (1_u64 << 32) - 1,
2022            1_u64 << 32,
2023            (1_u64 << 32) + 1,
2024            (1_u64 << 63) - 1,
2025            1_u64 << 63,
2026            (1_u64 << 63) + 1,
2027            u64::MAX - 1,
2028            u64::MAX,
2029        ];
2030
2031        for divisor in divisors {
2032            let reducer = StrengthReducedU64::new(divisor);
2033            for value in values {
2034                assert_eq!(
2035                    reducer.remainder(value),
2036                    value % divisor,
2037                    "value={value} divisor={divisor}"
2038                );
2039            }
2040
2041            let mut value = 0x1234_5678_9abc_def0 ^ divisor;
2042            for _ in 0..10_000 {
2043                value = value
2044                    .wrapping_mul(6_364_136_223_846_793_005)
2045                    .wrapping_add(1_442_695_040_888_963_407);
2046                assert_eq!(
2047                    reducer.remainder(value),
2048                    value % divisor,
2049                    "value={value} divisor={divisor}"
2050                );
2051            }
2052        }
2053    }
2054
2055    #[test]
2056    fn hash_partitioner_requires_nonzero_partitions() {
2057        let metrics = ExecutionPlanMetricsSet::new();
2058        let timer = MetricBuilder::new(&metrics).subset_time("test", 0);
2059
2060        let err = BatchPartitioner::new_hash_partitioner(vec![], 0, timer)
2061            .err()
2062            .expect("zero hash partitions should fail")
2063            .to_string();
2064
2065        assert!(
2066            err.contains("Hash repartition requires at least one partition"),
2067            "actual: {err}"
2068        );
2069    }
2070
2071    #[tokio::test]
2072    async fn one_to_many_round_robin() -> Result<()> {
2073        // define input partitions
2074        let schema = test_schema();
2075        let partition = create_vec_batches(50);
2076        let partitions = vec![partition];
2077
2078        // repartition from 1 input to 4 output
2079        let output_partitions =
2080            repartition(&schema, partitions, Partitioning::RoundRobinBatch(4)).await?;
2081
2082        assert_eq!(4, output_partitions.len());
2083        for partition in &output_partitions {
2084            assert_eq!(1, partition.len());
2085        }
2086        assert_eq!(13 * 8, output_partitions[0][0].num_rows());
2087        assert_eq!(13 * 8, output_partitions[1][0].num_rows());
2088        assert_eq!(12 * 8, output_partitions[2][0].num_rows());
2089        assert_eq!(12 * 8, output_partitions[3][0].num_rows());
2090
2091        Ok(())
2092    }
2093
2094    #[tokio::test]
2095    async fn many_to_one_round_robin() -> Result<()> {
2096        // define input partitions
2097        let schema = test_schema();
2098        let partition = create_vec_batches(50);
2099        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
2100
2101        // repartition from 3 input to 1 output
2102        let output_partitions =
2103            repartition(&schema, partitions, Partitioning::RoundRobinBatch(1)).await?;
2104
2105        assert_eq!(1, output_partitions.len());
2106        assert_eq!(150 * 8, output_partitions[0][0].num_rows());
2107
2108        Ok(())
2109    }
2110
2111    #[tokio::test]
2112    async fn many_to_many_round_robin() -> Result<()> {
2113        // define input partitions
2114        let schema = test_schema();
2115        let partition = create_vec_batches(50);
2116        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
2117
2118        // repartition from 3 input to 5 output
2119        let output_partitions =
2120            repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await?;
2121
2122        let total_rows_per_partition = 8 * 50 * 3 / 5;
2123        assert_eq!(5, output_partitions.len());
2124        for partition in output_partitions {
2125            assert_eq!(1, partition.len());
2126            assert_eq!(total_rows_per_partition, partition[0].num_rows());
2127        }
2128
2129        Ok(())
2130    }
2131
2132    #[tokio::test]
2133    async fn many_to_many_hash_partition() -> Result<()> {
2134        // define input partitions
2135        let schema = test_schema();
2136        let partition = create_vec_batches(50);
2137        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
2138
2139        let output_partitions = repartition(
2140            &schema,
2141            partitions,
2142            Partitioning::Hash(vec![col("c0", &schema)?], 8),
2143        )
2144        .await?;
2145
2146        let total_rows: usize = output_partitions
2147            .iter()
2148            .map(|x| x.iter().map(|x| x.num_rows()).sum::<usize>())
2149            .sum();
2150
2151        assert_eq!(8, output_partitions.len());
2152        assert_eq!(total_rows, 8 * 50 * 3);
2153
2154        Ok(())
2155    }
2156
2157    #[tokio::test]
2158    async fn test_repartition_with_coalescing() -> Result<()> {
2159        let schema = test_schema();
2160        // create 50 batches, each having 8 rows
2161        let partition = create_vec_batches(50);
2162        let partitions = vec![partition.clone(), partition.clone()];
2163        let partitioning = Partitioning::RoundRobinBatch(1);
2164
2165        let session_config = SessionConfig::new().with_batch_size(200);
2166        let task_ctx = TaskContext::default().with_session_config(session_config);
2167        let task_ctx = Arc::new(task_ctx);
2168
2169        // create physical plan
2170        let exec = TestMemoryExec::try_new_exec(&partitions, Arc::clone(&schema), None)?;
2171        let exec = RepartitionExec::try_new(exec, partitioning)?;
2172
2173        for i in 0..exec.partitioning().partition_count() {
2174            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
2175            while let Some(result) = stream.next().await {
2176                let batch = result?;
2177                assert_eq!(200, batch.num_rows());
2178            }
2179        }
2180        Ok(())
2181    }
2182
2183    fn test_schema() -> Arc<Schema> {
2184        Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
2185    }
2186
2187    async fn repartition(
2188        schema: &SchemaRef,
2189        input_partitions: Vec<Vec<RecordBatch>>,
2190        partitioning: Partitioning,
2191    ) -> Result<Vec<Vec<RecordBatch>>> {
2192        let task_ctx = Arc::new(TaskContext::default());
2193        // create physical plan
2194        let exec =
2195            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(schema), None)?;
2196        let exec = RepartitionExec::try_new(exec, partitioning)?;
2197
2198        // execute and collect results
2199        let mut output_partitions = vec![];
2200        for i in 0..exec.partitioning().partition_count() {
2201            // execute this *output* partition and collect all batches
2202            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
2203            let mut batches = vec![];
2204            while let Some(result) = stream.next().await {
2205                batches.push(result?);
2206            }
2207            output_partitions.push(batches);
2208        }
2209        Ok(output_partitions)
2210    }
2211
2212    #[tokio::test]
2213    async fn many_to_many_round_robin_within_tokio_task() -> Result<()> {
2214        let handle: SpawnedTask<Result<Vec<Vec<RecordBatch>>>> =
2215            SpawnedTask::spawn(async move {
2216                // define input partitions
2217                let schema = test_schema();
2218                let partition = create_vec_batches(50);
2219                let partitions =
2220                    vec![partition.clone(), partition.clone(), partition.clone()];
2221
2222                // repartition from 3 input to 5 output
2223                repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await
2224            });
2225
2226        let output_partitions = handle.join().await.unwrap().unwrap();
2227
2228        let total_rows_per_partition = 8 * 50 * 3 / 5;
2229        assert_eq!(5, output_partitions.len());
2230        for partition in output_partitions {
2231            assert_eq!(1, partition.len());
2232            assert_eq!(total_rows_per_partition, partition[0].num_rows());
2233        }
2234
2235        Ok(())
2236    }
2237
2238    #[tokio::test]
2239    async fn unsupported_partitioning() {
2240        let task_ctx = Arc::new(TaskContext::default());
2241        // have to send at least one batch through to provoke error
2242        let batch = RecordBatch::try_from_iter(vec![(
2243            "my_awesome_field",
2244            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
2245        )])
2246        .unwrap();
2247
2248        let schema = batch.schema();
2249        let input = MockExec::new(vec![Ok(batch)], schema);
2250        // This generates an error (partitioning type not supported)
2251        // but only after the plan is executed. The error should be
2252        // returned and no results produced
2253        let partitioning = Partitioning::UnknownPartitioning(1);
2254        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
2255        let output_stream = exec.execute(0, task_ctx).unwrap();
2256
2257        // Expect that an error is returned
2258        let result_string = crate::common::collect(output_stream)
2259            .await
2260            .unwrap_err()
2261            .to_string();
2262        assert!(
2263            result_string
2264                .contains("Unsupported repartitioning scheme UnknownPartitioning(1)"),
2265            "actual: {result_string}"
2266        );
2267    }
2268
2269    #[tokio::test]
2270    async fn error_for_input_exec() {
2271        // This generates an error on a call to execute. The error
2272        // should be returned and no results produced.
2273
2274        let task_ctx = Arc::new(TaskContext::default());
2275        let input = ErrorExec::new();
2276        let partitioning = Partitioning::RoundRobinBatch(1);
2277        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
2278
2279        // Expect that an error is returned
2280        let result_string = exec.execute(0, task_ctx).err().unwrap().to_string();
2281
2282        assert!(
2283            result_string.contains("ErrorExec, unsurprisingly, errored in partition 0"),
2284            "actual: {result_string}"
2285        );
2286    }
2287
2288    #[tokio::test]
2289    async fn repartition_with_error_in_stream() {
2290        let task_ctx = Arc::new(TaskContext::default());
2291        let batch = RecordBatch::try_from_iter(vec![(
2292            "my_awesome_field",
2293            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
2294        )])
2295        .unwrap();
2296
2297        // input stream returns one good batch and then one error. The
2298        // error should be returned.
2299        let err = exec_err!("bad data error");
2300
2301        let schema = batch.schema();
2302        let input = MockExec::new(vec![Ok(batch), err], schema);
2303        let partitioning = Partitioning::RoundRobinBatch(1);
2304        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
2305
2306        // Note: this should pass (the stream can be created) but the
2307        // error when the input is executed should get passed back
2308        let output_stream = exec.execute(0, task_ctx).unwrap();
2309
2310        // Expect that an error is returned
2311        let result_string = crate::common::collect(output_stream)
2312            .await
2313            .unwrap_err()
2314            .to_string();
2315        assert!(
2316            result_string.contains("bad data error"),
2317            "actual: {result_string}"
2318        );
2319    }
2320
2321    #[tokio::test]
2322    async fn repartition_with_delayed_stream() {
2323        let task_ctx = Arc::new(TaskContext::default());
2324        let batch1 = RecordBatch::try_from_iter(vec![(
2325            "my_awesome_field",
2326            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
2327        )])
2328        .unwrap();
2329
2330        let batch2 = RecordBatch::try_from_iter(vec![(
2331            "my_awesome_field",
2332            Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
2333        )])
2334        .unwrap();
2335
2336        // The mock exec doesn't return immediately (instead it
2337        // requires the input to wait at least once)
2338        let schema = batch1.schema();
2339        let expected_batches = vec![batch1.clone(), batch2.clone()];
2340        let input = MockExec::new(vec![Ok(batch1), Ok(batch2)], schema);
2341        let partitioning = Partitioning::RoundRobinBatch(1);
2342
2343        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
2344
2345        assert_snapshot!(batches_to_sort_string(&expected_batches), @r"
2346        +------------------+
2347        | my_awesome_field |
2348        +------------------+
2349        | bar              |
2350        | baz              |
2351        | foo              |
2352        | frob             |
2353        +------------------+
2354        ");
2355
2356        let output_stream = exec.execute(0, task_ctx).unwrap();
2357        let batches = crate::common::collect(output_stream).await.unwrap();
2358
2359        assert_snapshot!(batches_to_sort_string(&batches), @r"
2360        +------------------+
2361        | my_awesome_field |
2362        +------------------+
2363        | bar              |
2364        | baz              |
2365        | foo              |
2366        | frob             |
2367        +------------------+
2368        ");
2369    }
2370
2371    #[tokio::test]
2372    async fn robin_repartition_with_dropping_output_stream() {
2373        let task_ctx = Arc::new(TaskContext::default());
2374        let partitioning = Partitioning::RoundRobinBatch(2);
2375        // The barrier exec waits to be pinged
2376        // requires the input to wait at least once)
2377        let input = Arc::new(make_barrier_exec());
2378
2379        // partition into two output streams
2380        let exec = RepartitionExec::try_new(
2381            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
2382            partitioning,
2383        )
2384        .unwrap();
2385
2386        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
2387        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
2388
2389        // now, purposely drop output stream 0
2390        // *before* any outputs are produced
2391        drop(output_stream0);
2392
2393        // Now, start sending input
2394        let mut background_task = JoinSet::new();
2395        background_task.spawn(async move {
2396            input.wait().await;
2397        });
2398
2399        // output stream 1 should *not* error and have one of the input batches
2400        let batches = crate::common::collect(output_stream1).await.unwrap();
2401
2402        assert_snapshot!(batches_to_sort_string(&batches), @r"
2403        +------------------+
2404        | my_awesome_field |
2405        +------------------+
2406        | baz              |
2407        | frob             |
2408        | gar              |
2409        | goo              |
2410        +------------------+
2411        ");
2412    }
2413
2414    #[tokio::test]
2415    // As the hash results might be different on different platforms or
2416    // with different compilers, we will compare the same execution with
2417    // and without dropping the output stream.
2418    async fn hash_repartition_with_dropping_output_stream() {
2419        let task_ctx = Arc::new(TaskContext::default());
2420        let partitioning = Partitioning::Hash(
2421            vec![Arc::new(crate::expressions::Column::new(
2422                "my_awesome_field",
2423                0,
2424            ))],
2425            2,
2426        );
2427
2428        // We first collect the results without dropping the output stream.
2429        let input = Arc::new(make_barrier_exec());
2430        let exec = RepartitionExec::try_new(
2431            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
2432            partitioning.clone(),
2433        )
2434        .unwrap();
2435        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
2436        let mut background_task = JoinSet::new();
2437        background_task.spawn(async move {
2438            input.wait().await;
2439        });
2440        let batches_without_drop = crate::common::collect(output_stream1).await.unwrap();
2441
2442        // run some checks on the result
2443        let items_vec = str_batches_to_vec(&batches_without_drop);
2444        let items_set: HashSet<&str> = items_vec.iter().copied().collect();
2445        assert_eq!(items_vec.len(), items_set.len());
2446        let source_str_set: HashSet<&str> =
2447            ["foo", "bar", "frob", "baz", "goo", "gar", "grob", "gaz"]
2448                .iter()
2449                .copied()
2450                .collect();
2451        assert_eq!(items_set.difference(&source_str_set).count(), 0);
2452
2453        // Now do the same but dropping the stream before waiting for the barrier
2454        let input = Arc::new(make_barrier_exec());
2455        let exec = RepartitionExec::try_new(
2456            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
2457            partitioning,
2458        )
2459        .unwrap();
2460        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
2461        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
2462        // now, purposely drop output stream 0
2463        // *before* any outputs are produced
2464        drop(output_stream0);
2465        let mut background_task = JoinSet::new();
2466        background_task.spawn(async move {
2467            input.wait().await;
2468        });
2469        let batches_with_drop = crate::common::collect(output_stream1).await.unwrap();
2470
2471        let items_vec_with_drop = str_batches_to_vec(&batches_with_drop);
2472        let items_set_with_drop: HashSet<&str> =
2473            items_vec_with_drop.iter().copied().collect();
2474        assert_eq!(
2475            items_set_with_drop.symmetric_difference(&items_set).count(),
2476            0
2477        );
2478    }
2479
2480    fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> {
2481        batches
2482            .iter()
2483            .flat_map(|batch| {
2484                assert_eq!(batch.columns().len(), 1);
2485                let string_array = as_string_array(batch.column(0))
2486                    .expect("Unexpected type for repartitioned batch");
2487
2488                string_array
2489                    .iter()
2490                    .map(|v| v.expect("Unexpected null"))
2491                    .collect::<Vec<_>>()
2492            })
2493            .collect::<Vec<_>>()
2494    }
2495
2496    /// Create a BarrierExec that returns two partitions of two batches each
2497    fn make_barrier_exec() -> BarrierExec {
2498        let batch1 = RecordBatch::try_from_iter(vec![(
2499            "my_awesome_field",
2500            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
2501        )])
2502        .unwrap();
2503
2504        let batch2 = RecordBatch::try_from_iter(vec![(
2505            "my_awesome_field",
2506            Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
2507        )])
2508        .unwrap();
2509
2510        let batch3 = RecordBatch::try_from_iter(vec![(
2511            "my_awesome_field",
2512            Arc::new(StringArray::from(vec!["goo", "gar"])) as ArrayRef,
2513        )])
2514        .unwrap();
2515
2516        let batch4 = RecordBatch::try_from_iter(vec![(
2517            "my_awesome_field",
2518            Arc::new(StringArray::from(vec!["grob", "gaz"])) as ArrayRef,
2519        )])
2520        .unwrap();
2521
2522        // The barrier exec waits to be pinged
2523        // requires the input to wait at least once)
2524        let schema = batch1.schema();
2525        BarrierExec::new(vec![vec![batch1, batch2], vec![batch3, batch4]], schema)
2526    }
2527
2528    #[tokio::test]
2529    async fn test_drop_cancel() -> Result<()> {
2530        let task_ctx = Arc::new(TaskContext::default());
2531        let schema =
2532            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));
2533
2534        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2));
2535        let refs = blocking_exec.refs();
2536        let repartition_exec = Arc::new(RepartitionExec::try_new(
2537            blocking_exec,
2538            Partitioning::UnknownPartitioning(1),
2539        )?);
2540
2541        let fut = collect(repartition_exec, task_ctx);
2542        let mut fut = fut.boxed();
2543
2544        assert_is_pending(&mut fut);
2545        drop(fut);
2546        assert_strong_count_converges_to_zero(refs).await;
2547
2548        Ok(())
2549    }
2550
2551    #[tokio::test]
2552    async fn hash_repartition_avoid_empty_batch() -> Result<()> {
2553        let task_ctx = Arc::new(TaskContext::default());
2554        let batch = RecordBatch::try_from_iter(vec![(
2555            "a",
2556            Arc::new(StringArray::from(vec!["foo"])) as ArrayRef,
2557        )])
2558        .unwrap();
2559        let partitioning = Partitioning::Hash(
2560            vec![Arc::new(crate::expressions::Column::new("a", 0))],
2561            2,
2562        );
2563        let schema = batch.schema();
2564        let input = MockExec::new(vec![Ok(batch)], schema);
2565        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
2566        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
2567        let batch0 = crate::common::collect(output_stream0).await.unwrap();
2568        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
2569        let batch1 = crate::common::collect(output_stream1).await.unwrap();
2570        assert!(batch0.is_empty() || batch1.is_empty());
2571        Ok(())
2572    }
2573
2574    #[tokio::test]
2575    async fn repartition_with_spilling() -> Result<()> {
2576        // Test that repartition successfully spills to disk when memory is constrained
2577        let schema = test_schema();
2578        let partition = create_vec_batches(50);
2579        let input_partitions = vec![partition];
2580        let partitioning = Partitioning::RoundRobinBatch(4);
2581
2582        // Set up context with very tight memory limit to force spilling
2583        let runtime = RuntimeEnvBuilder::default()
2584            .with_memory_limit(1, 1.0)
2585            .build_arc()?;
2586
2587        let task_ctx = TaskContext::default().with_runtime(runtime);
2588        let task_ctx = Arc::new(task_ctx);
2589
2590        // create physical plan
2591        let exec =
2592            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
2593        let exec = RepartitionExec::try_new(exec, partitioning)?;
2594
2595        // Collect all partitions - should succeed by spilling to disk
2596        let mut total_rows = 0;
2597        for i in 0..exec.partitioning().partition_count() {
2598            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
2599            while let Some(result) = stream.next().await {
2600                let batch = result?;
2601                total_rows += batch.num_rows();
2602            }
2603        }
2604
2605        // Verify we got all the data (50 batches * 8 rows each)
2606        assert_eq!(total_rows, 50 * 8);
2607
2608        // Verify spilling metrics to confirm spilling actually happened
2609        let metrics = exec.metrics().unwrap();
2610        assert!(
2611            metrics.spill_count().unwrap() > 0,
2612            "Expected spill_count > 0, but got {:?}",
2613            metrics.spill_count()
2614        );
2615        println!("Spilled {} times", metrics.spill_count().unwrap());
2616        assert!(
2617            metrics.spilled_bytes().unwrap() > 0,
2618            "Expected spilled_bytes > 0, but got {:?}",
2619            metrics.spilled_bytes()
2620        );
2621        println!(
2622            "Spilled {} bytes in {} spills",
2623            metrics.spilled_bytes().unwrap(),
2624            metrics.spill_count().unwrap()
2625        );
2626        assert!(
2627            metrics.spilled_rows().unwrap() > 0,
2628            "Expected spilled_rows > 0, but got {:?}",
2629            metrics.spilled_rows()
2630        );
2631        println!("Spilled {} rows", metrics.spilled_rows().unwrap());
2632
2633        Ok(())
2634    }
2635
2636    #[tokio::test]
2637    async fn repartition_with_partial_spilling() -> Result<()> {
2638        // Test that repartition can handle partial spilling (some batches in memory, some spilled)
2639        let schema = test_schema();
2640        let partition = create_vec_batches(50);
2641        let input_partitions = vec![partition];
2642        let partitioning = Partitioning::RoundRobinBatch(4);
2643
2644        // With `batch_size = 1024` and a single UInt32 column, each
2645        // coalesced residual is ~4 KiB. An 8 KiB pool fits one and forces
2646        // the rest to spill.
2647        let runtime = RuntimeEnvBuilder::default()
2648            .with_memory_limit(8 * 1024, 1.0)
2649            .build_arc()?;
2650
2651        let session_config = SessionConfig::new().with_batch_size(1024);
2652        let task_ctx = TaskContext::default()
2653            .with_runtime(runtime)
2654            .with_session_config(session_config);
2655        let task_ctx = Arc::new(task_ctx);
2656
2657        // create physical plan
2658        let exec =
2659            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
2660        let exec = RepartitionExec::try_new(exec, partitioning)?;
2661
2662        // Collect all partitions - should succeed with partial spilling
2663        let mut total_rows = 0;
2664        for i in 0..exec.partitioning().partition_count() {
2665            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
2666            while let Some(result) = stream.next().await {
2667                let batch = result?;
2668                total_rows += batch.num_rows();
2669            }
2670        }
2671
2672        // Verify we got all the data (50 batches * 8 rows each)
2673        assert_eq!(total_rows, 50 * 8);
2674
2675        // Verify partial spilling metrics
2676        let metrics = exec.metrics().unwrap();
2677        let spill_count = metrics.spill_count().unwrap();
2678        let spilled_rows = metrics.spilled_rows().unwrap();
2679        let spilled_bytes = metrics.spilled_bytes().unwrap();
2680
2681        assert!(
2682            spill_count > 0,
2683            "Expected some spilling to occur, but got spill_count={spill_count}"
2684        );
2685        assert!(
2686            spilled_rows > 0 && spilled_rows < total_rows,
2687            "Expected partial spilling (0 < spilled_rows < {total_rows}), but got spilled_rows={spilled_rows}"
2688        );
2689        assert!(
2690            spilled_bytes > 0,
2691            "Expected some bytes to be spilled, but got spilled_bytes={spilled_bytes}"
2692        );
2693
2694        println!(
2695            "Partial spilling: spilled {} out of {} rows ({:.1}%) in {} spills, {} bytes",
2696            spilled_rows,
2697            total_rows,
2698            (spilled_rows as f64 / total_rows as f64) * 100.0,
2699            spill_count,
2700            spilled_bytes
2701        );
2702
2703        Ok(())
2704    }
2705
2706    #[tokio::test]
2707    async fn repartition_without_spilling() -> Result<()> {
2708        // Test that repartition does not spill when there's ample memory
2709        let schema = test_schema();
2710        let partition = create_vec_batches(50);
2711        let input_partitions = vec![partition];
2712        let partitioning = Partitioning::RoundRobinBatch(4);
2713
2714        // Set up context with generous memory limit - no spilling should occur
2715        let runtime = RuntimeEnvBuilder::default()
2716            .with_memory_limit(10 * 1024 * 1024, 1.0) // 10MB
2717            .build_arc()?;
2718
2719        let task_ctx = TaskContext::default().with_runtime(runtime);
2720        let task_ctx = Arc::new(task_ctx);
2721
2722        // create physical plan
2723        let exec =
2724            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
2725        let exec = RepartitionExec::try_new(exec, partitioning)?;
2726
2727        // Collect all partitions - should succeed without spilling
2728        let mut total_rows = 0;
2729        for i in 0..exec.partitioning().partition_count() {
2730            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
2731            while let Some(result) = stream.next().await {
2732                let batch = result?;
2733                total_rows += batch.num_rows();
2734            }
2735        }
2736
2737        // Verify we got all the data (50 batches * 8 rows each)
2738        assert_eq!(total_rows, 50 * 8);
2739
2740        // Verify no spilling occurred
2741        let metrics = exec.metrics().unwrap();
2742        assert_eq!(
2743            metrics.spill_count(),
2744            Some(0),
2745            "Expected no spilling, but got spill_count={:?}",
2746            metrics.spill_count()
2747        );
2748        assert_eq!(
2749            metrics.spilled_bytes(),
2750            Some(0),
2751            "Expected no bytes spilled, but got spilled_bytes={:?}",
2752            metrics.spilled_bytes()
2753        );
2754        assert_eq!(
2755            metrics.spilled_rows(),
2756            Some(0),
2757            "Expected no rows spilled, but got spilled_rows={:?}",
2758            metrics.spilled_rows()
2759        );
2760
2761        println!("No spilling occurred - all data processed in memory");
2762
2763        Ok(())
2764    }
2765
2766    #[tokio::test]
2767    async fn oom() -> Result<()> {
2768        use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode};
2769
2770        // Test that repartition fails with OOM when disk manager is disabled
2771        let schema = test_schema();
2772        let partition = create_vec_batches(50);
2773        let input_partitions = vec![partition];
2774        let partitioning = Partitioning::RoundRobinBatch(4);
2775
2776        // Setup context with memory limit but NO disk manager (explicitly disabled)
2777        let runtime = RuntimeEnvBuilder::default()
2778            .with_memory_limit(1, 1.0)
2779            .with_disk_manager_builder(
2780                DiskManagerBuilder::default().with_mode(DiskManagerMode::Disabled),
2781            )
2782            .build_arc()?;
2783
2784        let task_ctx = TaskContext::default().with_runtime(runtime);
2785        let task_ctx = Arc::new(task_ctx);
2786
2787        // create physical plan
2788        let exec =
2789            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
2790        let exec = RepartitionExec::try_new(exec, partitioning)?;
2791
2792        // Attempt to execute - should fail with ResourcesExhausted error
2793        for i in 0..exec.partitioning().partition_count() {
2794            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
2795            let err = stream.next().await.unwrap().unwrap_err();
2796            let err = err.find_root();
2797            assert!(
2798                matches!(err, DataFusionError::ResourcesExhausted(_)),
2799                "Wrong error type: {err}",
2800            );
2801        }
2802
2803        Ok(())
2804    }
2805
2806    /// Create vector batches
2807    fn create_vec_batches(n: usize) -> Vec<RecordBatch> {
2808        let batch = create_batch();
2809        std::iter::repeat_n(batch, n).collect()
2810    }
2811
2812    /// Create batch
2813    fn create_batch() -> RecordBatch {
2814        let schema = test_schema();
2815        RecordBatch::try_new(
2816            schema,
2817            vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]))],
2818        )
2819        .unwrap()
2820    }
2821
2822    /// Create batches with sequential values for ordering tests
2823    fn create_ordered_batches(num_batches: usize) -> Vec<RecordBatch> {
2824        let schema = test_schema();
2825        (0..num_batches)
2826            .map(|i| {
2827                let start = (i * 8) as u32;
2828                RecordBatch::try_new(
2829                    Arc::clone(&schema),
2830                    vec![Arc::new(UInt32Array::from(
2831                        (start..start + 8).collect::<Vec<_>>(),
2832                    ))],
2833                )
2834                .unwrap()
2835            })
2836            .collect()
2837    }
2838
2839    #[tokio::test]
2840    async fn test_repartition_ordering_with_spilling() -> Result<()> {
2841        // Test that repartition preserves ordering when spilling occurs
2842        // This tests the state machine fix where we must block on spill_stream
2843        // when a Spilled marker is received, rather than continuing to poll the channel
2844
2845        let schema = test_schema();
2846        // Create batches with sequential values: batch 0 has [0,1,2,3,4,5,6,7],
2847        // batch 1 has [8,9,10,11,12,13,14,15], etc.
2848        let partition = create_ordered_batches(20);
2849        let input_partitions = vec![partition];
2850
2851        // Use RoundRobinBatch to ensure predictable ordering
2852        let partitioning = Partitioning::RoundRobinBatch(2);
2853
2854        // Set up context with very tight memory limit to force spilling
2855        let runtime = RuntimeEnvBuilder::default()
2856            .with_memory_limit(1, 1.0)
2857            .build_arc()?;
2858
2859        let task_ctx = TaskContext::default().with_runtime(runtime);
2860        let task_ctx = Arc::new(task_ctx);
2861
2862        // create physical plan
2863        let exec =
2864            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
2865        let exec = RepartitionExec::try_new(exec, partitioning)?;
2866
2867        // Collect all output partitions
2868        let mut all_batches = Vec::new();
2869        for i in 0..exec.partitioning().partition_count() {
2870            let mut partition_batches = Vec::new();
2871            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
2872            while let Some(result) = stream.next().await {
2873                let batch = result?;
2874                partition_batches.push(batch);
2875            }
2876            all_batches.push(partition_batches);
2877        }
2878
2879        // Verify spilling occurred
2880        let metrics = exec.metrics().unwrap();
2881        assert!(
2882            metrics.spill_count().unwrap() > 0,
2883            "Expected spilling to occur, but spill_count = 0"
2884        );
2885
2886        // Verify ordering is preserved within each partition
2887        // With RoundRobinBatch, even batches go to partition 0, odd batches to partition 1
2888        for (partition_idx, batches) in all_batches.iter().enumerate() {
2889            let mut last_value = None;
2890            for batch in batches {
2891                let array = batch
2892                    .column(0)
2893                    .as_any()
2894                    .downcast_ref::<UInt32Array>()
2895                    .unwrap();
2896
2897                for i in 0..array.len() {
2898                    let value = array.value(i);
2899                    if let Some(last) = last_value {
2900                        assert!(
2901                            value > last,
2902                            "Ordering violated in partition {partition_idx}: {value} is not greater than {last}"
2903                        );
2904                    }
2905                    last_value = Some(value);
2906                }
2907            }
2908        }
2909
2910        Ok(())
2911    }
2912}
2913
2914#[cfg(test)]
2915mod test {
2916    use arrow::array::record_batch;
2917    use arrow::compute::SortOptions;
2918    use arrow::datatypes::{DataType, Field, Schema};
2919    use datafusion_common::assert_batches_eq;
2920
2921    use super::*;
2922    use crate::test::TestMemoryExec;
2923    use crate::union::UnionExec;
2924
2925    use datafusion_physical_expr::expressions::col;
2926
2927    /// Asserts that the plan is as expected
2928    ///
2929    /// `$EXPECTED_PLAN_LINES`: input plan
2930    /// `$PLAN`: the plan to optimized
2931    macro_rules! assert_plan {
2932        ($PLAN: expr,  @ $EXPECTED: expr) => {
2933            let formatted = crate::displayable($PLAN).indent(true).to_string();
2934
2935            insta::assert_snapshot!(
2936                formatted,
2937                @$EXPECTED
2938            );
2939        };
2940    }
2941
2942    #[tokio::test]
2943    async fn test_preserve_order() -> Result<()> {
2944        let schema = test_schema();
2945        let sort_exprs = sort_exprs(&schema);
2946        let source1 = sorted_memory_exec(&schema, sort_exprs.clone());
2947        let source2 = sorted_memory_exec(&schema, sort_exprs);
2948        // output has multiple partitions, and is sorted
2949        let union = UnionExec::try_new(vec![source1, source2])?;
2950        let exec = RepartitionExec::try_new(union, Partitioning::RoundRobinBatch(10))?
2951            .with_preserve_order();
2952
2953        // Repartition should preserve order
2954        assert_plan!(&exec, @r"
2955        RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true, sort_exprs=c0@0 ASC
2956          UnionExec
2957            DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
2958            DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
2959        ");
2960        Ok(())
2961    }
2962
2963    #[tokio::test]
2964    async fn test_preserve_order_one_partition() -> Result<()> {
2965        let schema = test_schema();
2966        let sort_exprs = sort_exprs(&schema);
2967        let source = sorted_memory_exec(&schema, sort_exprs);
2968        // output is sorted, but has only a single partition, so no need to sort
2969        let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(10))?
2970            .with_preserve_order();
2971
2972        // Repartition should not preserve order
2973        assert_plan!(&exec, @r"
2974        RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true
2975          DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
2976        ");
2977
2978        Ok(())
2979    }
2980
2981    #[tokio::test]
2982    async fn test_preserve_order_input_not_sorted() -> Result<()> {
2983        let schema = test_schema();
2984        let source1 = memory_exec(&schema);
2985        let source2 = memory_exec(&schema);
2986        // output has multiple partitions, but is not sorted
2987        let union = UnionExec::try_new(vec![source1, source2])?;
2988        let exec = RepartitionExec::try_new(union, Partitioning::RoundRobinBatch(10))?
2989            .with_preserve_order();
2990
2991        // Repartition should not preserve order, as there is no order to preserve
2992        assert_plan!(&exec, @r"
2993        RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2
2994          UnionExec
2995            DataSourceExec: partitions=1, partition_sizes=[0]
2996            DataSourceExec: partitions=1, partition_sizes=[0]
2997        ");
2998        Ok(())
2999    }
3000
3001    #[tokio::test]
3002    async fn test_preserve_order_with_spilling() -> Result<()> {
3003        use datafusion_execution::runtime_env::RuntimeEnvBuilder;
3004
3005        // Create sorted input data across multiple partitions
3006        // Partition1: [1,3], [5,7], [9,11]
3007        // Partition2: [2,4], [6,8], [10,12]
3008        let batch1 = record_batch!(("c0", UInt32, [1, 3])).unwrap();
3009        let batch2 = record_batch!(("c0", UInt32, [2, 4])).unwrap();
3010        let batch3 = record_batch!(("c0", UInt32, [5, 7])).unwrap();
3011        let batch4 = record_batch!(("c0", UInt32, [6, 8])).unwrap();
3012        let batch5 = record_batch!(("c0", UInt32, [9, 11])).unwrap();
3013        let batch6 = record_batch!(("c0", UInt32, [10, 12])).unwrap();
3014        let schema = batch1.schema();
3015        let sort_exprs = LexOrdering::new([PhysicalSortExpr {
3016            expr: col("c0", &schema).unwrap(),
3017            options: SortOptions::default().asc(),
3018        }])
3019        .unwrap();
3020        let partition1 = vec![batch1.clone(), batch3.clone(), batch5.clone()];
3021        let partition2 = vec![batch2.clone(), batch4.clone(), batch6.clone()];
3022        let input_partitions = vec![partition1, partition2];
3023
3024        // Set up context with tight memory limit to force spilling
3025        // Sorting needs some non-spillable memory, so 64 bytes should force spilling while still allowing the query to complete
3026        let runtime = RuntimeEnvBuilder::default()
3027            .with_memory_limit(64, 1.0)
3028            .build_arc()?;
3029
3030        let task_ctx = TaskContext::default().with_runtime(runtime);
3031        let task_ctx = Arc::new(task_ctx);
3032
3033        // Create physical plan with order preservation
3034        let exec = TestMemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?
3035            .try_with_sort_information(vec![sort_exprs.clone(), sort_exprs])?;
3036        let exec = Arc::new(exec);
3037        let exec = Arc::new(TestMemoryExec::update_cache(&exec));
3038        // Repartition into 3 partitions with order preservation
3039        // We expect 1 batch per output partition after repartitioning
3040        let exec = RepartitionExec::try_new(exec, Partitioning::RoundRobinBatch(3))?
3041            .with_preserve_order();
3042
3043        let mut batches = vec![];
3044
3045        // Collect all partitions - should succeed by spilling to disk
3046        for i in 0..exec.partitioning().partition_count() {
3047            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
3048            while let Some(result) = stream.next().await {
3049                let batch = result?;
3050                batches.push(batch);
3051            }
3052        }
3053
3054        #[rustfmt::skip]
3055        let expected = [
3056            [
3057                "+----+",
3058                "| c0 |",
3059                "+----+",
3060                "| 1  |",
3061                "| 2  |",
3062                "| 3  |",
3063                "| 4  |",
3064                "+----+",
3065            ],
3066            [
3067                "+----+",
3068                "| c0 |",
3069                "+----+",
3070                "| 5  |",
3071                "| 6  |",
3072                "| 7  |",
3073                "| 8  |",
3074                "+----+",
3075            ],
3076            [
3077                "+----+",
3078                "| c0 |",
3079                "+----+",
3080                "| 9  |",
3081                "| 10 |",
3082                "| 11 |",
3083                "| 12 |",
3084                "+----+",
3085            ],
3086        ];
3087
3088        for (batch, expected) in batches.iter().zip(expected.iter()) {
3089            assert_batches_eq!(expected, std::slice::from_ref(batch));
3090        }
3091
3092        // We should have spilled ~ all of the data.
3093        // - We spill data during the repartitioning phase
3094        // - We may also spill during the final merge sort
3095        let all_batches = [batch1, batch2, batch3, batch4, batch5, batch6];
3096        let metrics = exec.metrics().unwrap();
3097        assert!(
3098            metrics.spill_count().unwrap() > input_partitions.len(),
3099            "Expected spill_count > {} for order-preserving repartition, but got {:?}",
3100            input_partitions.len(),
3101            metrics.spill_count()
3102        );
3103        assert!(
3104            metrics.spilled_bytes().unwrap()
3105                > all_batches
3106                    .iter()
3107                    .map(|b| b.get_array_memory_size())
3108                    .sum::<usize>(),
3109            "Expected spilled_bytes > {} for order-preserving repartition, got {}",
3110            all_batches
3111                .iter()
3112                .map(|b| b.get_array_memory_size())
3113                .sum::<usize>(),
3114            metrics.spilled_bytes().unwrap()
3115        );
3116        assert!(
3117            metrics.spilled_rows().unwrap()
3118                >= all_batches.iter().map(|b| b.num_rows()).sum::<usize>(),
3119            "Expected spilled_rows > {} for order-preserving repartition, got {}",
3120            all_batches.iter().map(|b| b.num_rows()).sum::<usize>(),
3121            metrics.spilled_rows().unwrap()
3122        );
3123
3124        Ok(())
3125    }
3126
3127    #[tokio::test]
3128    async fn test_hash_partitioning_with_spilling() -> Result<()> {
3129        use datafusion_execution::runtime_env::RuntimeEnvBuilder;
3130
3131        // Create input data similar to the round-robin test
3132        let batch1 = record_batch!(("c0", UInt32, [1, 3])).unwrap();
3133        let batch2 = record_batch!(("c0", UInt32, [2, 4])).unwrap();
3134        let batch3 = record_batch!(("c0", UInt32, [5, 7])).unwrap();
3135        let batch4 = record_batch!(("c0", UInt32, [6, 8])).unwrap();
3136        let schema = batch1.schema();
3137
3138        let partition1 = vec![batch1.clone(), batch3.clone()];
3139        let partition2 = vec![batch2.clone(), batch4.clone()];
3140        let input_partitions = vec![partition1, partition2];
3141
3142        // Set up context with memory limit to test hash partitioning with spilling infrastructure
3143        let runtime = RuntimeEnvBuilder::default()
3144            .with_memory_limit(1, 1.0)
3145            .build_arc()?;
3146
3147        let task_ctx = TaskContext::default().with_runtime(runtime);
3148        let task_ctx = Arc::new(task_ctx);
3149
3150        // Create physical plan with hash partitioning
3151        let exec = TestMemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?;
3152        let exec = Arc::new(exec);
3153        let exec = Arc::new(TestMemoryExec::update_cache(&exec));
3154        // Hash partition into 2 partitions by column c0
3155        let hash_expr = col("c0", &schema)?;
3156        let exec =
3157            RepartitionExec::try_new(exec, Partitioning::Hash(vec![hash_expr], 2))?;
3158
3159        // Collect all partitions concurrently using JoinSet - this prevents deadlock
3160        // where the distribution channel gate closes when all output channels are full
3161        let mut join_set = tokio::task::JoinSet::new();
3162        for i in 0..exec.partitioning().partition_count() {
3163            let stream = exec.execute(i, Arc::clone(&task_ctx))?;
3164            join_set.spawn(async move {
3165                let mut count = 0;
3166                futures::pin_mut!(stream);
3167                while let Some(result) = stream.next().await {
3168                    let batch = result?;
3169                    count += batch.num_rows();
3170                }
3171                Ok::<usize, DataFusionError>(count)
3172            });
3173        }
3174
3175        // Wait for all partitions and sum the rows
3176        let mut total_rows = 0;
3177        while let Some(result) = join_set.join_next().await {
3178            total_rows += result.unwrap()?;
3179        }
3180
3181        // Verify we got all rows back
3182        let all_batches = [batch1, batch2, batch3, batch4];
3183        let expected_rows: usize = all_batches.iter().map(|b| b.num_rows()).sum();
3184        assert_eq!(total_rows, expected_rows);
3185
3186        // Verify metrics are available
3187        let metrics = exec.metrics().unwrap();
3188        // Just verify the metrics can be retrieved (spilling may or may not occur)
3189        let spill_count = metrics.spill_count().unwrap_or(0);
3190        assert!(spill_count > 0);
3191        let spilled_bytes = metrics.spilled_bytes().unwrap_or(0);
3192        assert!(spilled_bytes > 0);
3193        let spilled_rows = metrics.spilled_rows().unwrap_or(0);
3194        assert!(spilled_rows > 0);
3195
3196        Ok(())
3197    }
3198
3199    #[tokio::test]
3200    async fn test_repartition() -> Result<()> {
3201        let schema = test_schema();
3202        let sort_exprs = sort_exprs(&schema);
3203        let source = sorted_memory_exec(&schema, sort_exprs);
3204        // output is sorted, but has only a single partition, so no need to sort
3205        let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(10))?
3206            .repartitioned(20, &Default::default())?
3207            .unwrap();
3208
3209        // Repartition should not preserve order
3210        assert_plan!(exec.as_ref(), @r"
3211        RepartitionExec: partitioning=RoundRobinBatch(20), input_partitions=1, maintains_sort_order=true
3212          DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
3213        ");
3214        Ok(())
3215    }
3216
3217    fn test_schema() -> Arc<Schema> {
3218        Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
3219    }
3220
3221    fn sort_exprs(schema: &Schema) -> LexOrdering {
3222        [PhysicalSortExpr {
3223            expr: col("c0", schema).unwrap(),
3224            options: SortOptions::default(),
3225        }]
3226        .into()
3227    }
3228
3229    fn memory_exec(schema: &SchemaRef) -> Arc<dyn ExecutionPlan> {
3230        TestMemoryExec::try_new_exec(&[vec![]], Arc::clone(schema), None).unwrap()
3231    }
3232
3233    fn sorted_memory_exec(
3234        schema: &SchemaRef,
3235        sort_exprs: LexOrdering,
3236    ) -> Arc<dyn ExecutionPlan> {
3237        let exec = TestMemoryExec::try_new(&[vec![]], Arc::clone(schema), None)
3238            .unwrap()
3239            .try_with_sort_information(vec![sort_exprs])
3240            .unwrap();
3241        let exec = Arc::new(exec);
3242        Arc::new(TestMemoryExec::update_cache(&exec))
3243    }
3244}