datafusion_physical_plan/repartition/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! This file implements the [`RepartitionExec`] operator, which maps N input
//! partitions to M output partitions based on a partitioning scheme, optionally
//! maintaining the order of the input rows in the output.
21
22use std::fmt::{Debug, Formatter};
23use std::pin::Pin;
24use std::sync::Arc;
25use std::task::{Context, Poll};
26use std::{any::Any, vec};
27
28use super::common::SharedMemoryReservation;
29use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
30use super::{
31    DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream,
32};
33use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType};
34use crate::hash_utils::create_hashes;
35use crate::metrics::{BaselineMetrics, SpillMetrics};
36use crate::projection::{all_columns, make_with_child, update_expr, ProjectionExec};
37use crate::sorts::streaming_merge::StreamingMergeBuilder;
38use crate::spill::spill_manager::SpillManager;
39use crate::spill::spill_pool::{self, SpillPoolWriter};
40use crate::stream::RecordBatchStreamAdapter;
41use crate::{DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Statistics};
42
43use arrow::array::{PrimitiveArray, RecordBatch, RecordBatchOptions};
44use arrow::compute::take_arrays;
45use arrow::datatypes::{SchemaRef, UInt32Type};
46use datafusion_common::config::ConfigOptions;
47use datafusion_common::stats::Precision;
48use datafusion_common::utils::transpose;
49use datafusion_common::{internal_err, ColumnStatistics, HashMap};
50use datafusion_common::{not_impl_err, DataFusionError, Result};
51use datafusion_common_runtime::SpawnedTask;
52use datafusion_execution::memory_pool::MemoryConsumer;
53use datafusion_execution::TaskContext;
54use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr};
55use datafusion_physical_expr_common::sort_expr::LexOrdering;
56
57use crate::filter_pushdown::{
58    ChildPushdownResult, FilterDescription, FilterPushdownPhase,
59    FilterPushdownPropagation,
60};
61use futures::stream::Stream;
62use futures::{FutureExt, StreamExt, TryStreamExt};
63use log::trace;
64use parking_lot::Mutex;
65
66mod distributor_channels;
67use distributor_channels::{
68    channels, partition_aware_channels, DistributionReceiver, DistributionSender,
69};
70
71/// A batch in the repartition queue - either in memory or spilled to disk.
72///
73/// This enum represents the two states a batch can be in during repartitioning.
74/// The decision to spill is made based on memory availability when sending a batch
75/// to an output partition.
76///
77/// # Batch Flow with Spilling
78///
79/// ```text
80/// Input Stream ──▶ Partition Logic ──▶ try_grow()
81///                                            │
82///                            ┌───────────────┴────────────────┐
83///                            │                                │
84///                            ▼                                ▼
85///                   try_grow() succeeds            try_grow() fails
86///                   (Memory Available)              (Memory Pressure)
87///                            │                                │
88///                            ▼                                ▼
89///                  RepartitionBatch::Memory         spill_writer.push_batch()
90///                  (batch held in memory)           (batch written to disk)
91///                            │                                │
92///                            │                                ▼
93///                            │                      RepartitionBatch::Spilled
94///                            │                      (marker - no batch data)
95///                            │                                │
96///                            └────────┬───────────────────────┘
97///                                     │
98///                                     ▼
99///                              Send to channel
100///                                     │
101///                                     ▼
102///                            Output Stream (poll)
103///                                     │
104///                      ┌──────────────┴─────────────┐
105///                      │                            │
106///                      ▼                            ▼
107///         RepartitionBatch::Memory      RepartitionBatch::Spilled
108///         Return batch immediately       Poll spill_stream (blocks)
109///                      │                            │
110///                      └────────┬───────────────────┘
111///                               │
112///                               ▼
113///                          Return batch
114///                    (FIFO order preserved)
115/// ```
116///
117/// See [`RepartitionExec`] for overall architecture and [`StreamState`] for
118/// the state machine that handles reading these batches.
119#[derive(Debug)]
120enum RepartitionBatch {
121    /// Batch held in memory (counts against memory reservation)
122    Memory(RecordBatch),
123    /// Marker indicating a batch was spilled to the partition's SpillPool.
124    /// The actual batch can be retrieved by reading from the SpillPoolStream.
125    /// This variant contains no data itself - it's just a signal to the reader
126    /// to fetch the next batch from the spill stream.
127    Spilled,
128}
129
130type MaybeBatch = Option<Result<RepartitionBatch>>;
131type InputPartitionsToCurrentPartitionSender = Vec<DistributionSender<MaybeBatch>>;
132type InputPartitionsToCurrentPartitionReceiver = Vec<DistributionReceiver<MaybeBatch>>;
133
134/// Output channel with its associated memory reservation and spill writer
135struct OutputChannel {
136    sender: DistributionSender<MaybeBatch>,
137    reservation: SharedMemoryReservation,
138    spill_writer: SpillPoolWriter,
139}
140
141/// Channels and resources for a single output partition.
142///
143/// Each output partition has channels to receive data from all input partitions.
144/// To handle memory pressure, each (input, output) pair gets its own
145/// [`SpillPool`](crate::spill::spill_pool) channel via [`spill_pool::channel`].
146///
147/// # Structure
148///
149/// For an output partition receiving from N input partitions:
150/// - `tx`: N senders (one per input) for sending batches to this output
151/// - `rx`: N receivers (one per input) for receiving batches at this output
152/// - `spill_writers`: N spill writers (one per input) for writing spilled data
153/// - `spill_readers`: N spill readers (one per input) for reading spilled data
154///
155/// This 1:1 mapping between input partitions and spill channels ensures that
156/// batches from each input are processed in FIFO order, even when some batches
157/// are spilled to disk and others remain in memory.
158///
159/// See [`RepartitionExec`] for the overall N×M architecture.
160///
161/// [`spill_pool::channel`]: crate::spill::spill_pool::channel
162struct PartitionChannels {
163    /// Senders for each input partition to send data to this output partition
164    tx: InputPartitionsToCurrentPartitionSender,
165    /// Receivers for each input partition sending data to this output partition
166    rx: InputPartitionsToCurrentPartitionReceiver,
167    /// Memory reservation for this output partition
168    reservation: SharedMemoryReservation,
169    /// Spill writers for writing spilled data.
170    /// SpillPoolWriter is Clone, so multiple writers can share state in non-preserve-order mode.
171    spill_writers: Vec<SpillPoolWriter>,
172    /// Spill readers for reading spilled data - one per input partition (FIFO semantics).
173    /// Each (input, output) pair gets its own reader to maintain proper ordering.
174    spill_readers: Vec<SendableRecordBatchStream>,
175}
176
177struct ConsumingInputStreamsState {
178    /// Channels for sending batches from input partitions to output partitions.
179    /// Key is the partition number.
180    channels: HashMap<usize, PartitionChannels>,
181
182    /// Helper that ensures that background jobs are killed once they are no longer needed.
183    abort_helper: Arc<Vec<SpawnedTask<()>>>,
184}
185
186impl Debug for ConsumingInputStreamsState {
187    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
188        f.debug_struct("ConsumingInputStreamsState")
189            .field("num_channels", &self.channels.len())
190            .field("abort_helper", &self.abort_helper)
191            .finish()
192    }
193}
194
195/// Inner state of [`RepartitionExec`].
196#[derive(Default)]
197enum RepartitionExecState {
198    /// Not initialized yet. This is the default state stored in the RepartitionExec node
199    /// upon instantiation.
200    #[default]
201    NotInitialized,
202    /// Input streams are initialized, but they are still not being consumed. The node
203    /// transitions to this state when the arrow's RecordBatch stream is created in
204    /// RepartitionExec::execute(), but before any message is polled.
205    InputStreamsInitialized(Vec<(SendableRecordBatchStream, RepartitionMetrics)>),
206    /// The input streams are being consumed. The node transitions to this state when
207    /// the first message in the arrow's RecordBatch stream is consumed.
208    ConsumingInputStreams(ConsumingInputStreamsState),
209}
210
211impl Debug for RepartitionExecState {
212    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
213        match self {
214            RepartitionExecState::NotInitialized => write!(f, "NotInitialized"),
215            RepartitionExecState::InputStreamsInitialized(v) => {
216                write!(f, "InputStreamsInitialized({:?})", v.len())
217            }
218            RepartitionExecState::ConsumingInputStreams(v) => {
219                write!(f, "ConsumingInputStreams({v:?})")
220            }
221        }
222    }
223}
224
225impl RepartitionExecState {
226    fn ensure_input_streams_initialized(
227        &mut self,
228        input: Arc<dyn ExecutionPlan>,
229        metrics: ExecutionPlanMetricsSet,
230        output_partitions: usize,
231        ctx: Arc<TaskContext>,
232    ) -> Result<()> {
233        if !matches!(self, RepartitionExecState::NotInitialized) {
234            return Ok(());
235        }
236
237        let num_input_partitions = input.output_partitioning().partition_count();
238        let mut streams_and_metrics = Vec::with_capacity(num_input_partitions);
239
240        for i in 0..num_input_partitions {
241            let metrics = RepartitionMetrics::new(i, output_partitions, &metrics);
242
243            let timer = metrics.fetch_time.timer();
244            let stream = input.execute(i, Arc::clone(&ctx))?;
245            timer.done();
246
247            streams_and_metrics.push((stream, metrics));
248        }
249        *self = RepartitionExecState::InputStreamsInitialized(streams_and_metrics);
250        Ok(())
251    }
252
253    #[expect(clippy::too_many_arguments)]
254    fn consume_input_streams(
255        &mut self,
256        input: Arc<dyn ExecutionPlan>,
257        metrics: ExecutionPlanMetricsSet,
258        partitioning: Partitioning,
259        preserve_order: bool,
260        name: String,
261        context: Arc<TaskContext>,
262        spill_manager: SpillManager,
263    ) -> Result<&mut ConsumingInputStreamsState> {
264        let streams_and_metrics = match self {
265            RepartitionExecState::NotInitialized => {
266                self.ensure_input_streams_initialized(
267                    Arc::clone(&input),
268                    metrics.clone(),
269                    partitioning.partition_count(),
270                    Arc::clone(&context),
271                )?;
272                let RepartitionExecState::InputStreamsInitialized(value) = self else {
273                    // This cannot happen, as ensure_input_streams_initialized() was just called,
274                    // but the compiler does not know.
275                    return internal_err!("Programming error: RepartitionExecState must be in the InputStreamsInitialized state after calling RepartitionExecState::ensure_input_streams_initialized");
276                };
277                value
278            }
279            RepartitionExecState::ConsumingInputStreams(value) => return Ok(value),
280            RepartitionExecState::InputStreamsInitialized(value) => value,
281        };
282
283        let num_input_partitions = streams_and_metrics.len();
284        let num_output_partitions = partitioning.partition_count();
285
286        let spill_manager = Arc::new(spill_manager);
287
288        let (txs, rxs) = if preserve_order {
289            // Create partition-aware channels with one channel per (input, output) pair
290            // This provides backpressure while maintaining proper ordering
291            let (txs_all, rxs_all) =
292                partition_aware_channels(num_input_partitions, num_output_partitions);
293            // Take transpose of senders and receivers. `state.channels` keeps track of entries per output partition
294            let txs = transpose(txs_all);
295            let rxs = transpose(rxs_all);
296            (txs, rxs)
297        } else {
298            // Create one channel per *output* partition with backpressure
299            let (txs, rxs) = channels(num_output_partitions);
300            // Clone sender for each input partitions
301            let txs = txs
302                .into_iter()
303                .map(|item| vec![item; num_input_partitions])
304                .collect::<Vec<_>>();
305            let rxs = rxs.into_iter().map(|item| vec![item]).collect::<Vec<_>>();
306            (txs, rxs)
307        };
308
309        let mut channels = HashMap::with_capacity(txs.len());
310        for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() {
311            let reservation = Arc::new(Mutex::new(
312                MemoryConsumer::new(format!("{name}[{partition}]"))
313                    .with_can_spill(true)
314                    .register(context.memory_pool()),
315            ));
316
317            // Create spill channels based on mode:
318            // - preserve_order: one spill channel per (input, output) pair for proper FIFO ordering
319            // - non-preserve-order: one shared spill channel per output partition since all inputs
320            //   share the same receiver
321            let max_file_size = context
322                .session_config()
323                .options()
324                .execution
325                .max_spill_file_size_bytes;
326            let num_spill_channels = if preserve_order {
327                num_input_partitions
328            } else {
329                1
330            };
331            let (spill_writers, spill_readers): (Vec<_>, Vec<_>) = (0
332                ..num_spill_channels)
333                .map(|_| spill_pool::channel(max_file_size, Arc::clone(&spill_manager)))
334                .unzip();
335
336            channels.insert(
337                partition,
338                PartitionChannels {
339                    tx,
340                    rx,
341                    reservation,
342                    spill_readers,
343                    spill_writers,
344                },
345            );
346        }
347
348        // launch one async task per *input* partition
349        let mut spawned_tasks = Vec::with_capacity(num_input_partitions);
350        for (i, (stream, metrics)) in
351            std::mem::take(streams_and_metrics).into_iter().enumerate()
352        {
353            let txs: HashMap<_, _> = channels
354                .iter()
355                .map(|(partition, channels)| {
356                    // In preserve_order mode: each input gets its own spill writer (index i)
357                    // In non-preserve-order mode: all inputs share spill writer 0 via clone
358                    let spill_writer_idx = if preserve_order { i } else { 0 };
359                    (
360                        *partition,
361                        OutputChannel {
362                            sender: channels.tx[i].clone(),
363                            reservation: Arc::clone(&channels.reservation),
364                            spill_writer: channels.spill_writers[spill_writer_idx]
365                                .clone(),
366                        },
367                    )
368                })
369                .collect();
370
371            // Extract senders for wait_for_task before moving txs
372            let senders: HashMap<_, _> = txs
373                .iter()
374                .map(|(partition, channel)| (*partition, channel.sender.clone()))
375                .collect();
376
377            let input_task = SpawnedTask::spawn(RepartitionExec::pull_from_input(
378                stream,
379                txs,
380                partitioning.clone(),
381                metrics,
382            ));
383
384            // In a separate task, wait for each input to be done
385            // (and pass along any errors, including panic!s)
386            let wait_for_task =
387                SpawnedTask::spawn(RepartitionExec::wait_for_task(input_task, senders));
388            spawned_tasks.push(wait_for_task);
389        }
390        *self = Self::ConsumingInputStreams(ConsumingInputStreamsState {
391            channels,
392            abort_helper: Arc::new(spawned_tasks),
393        });
394        match self {
395            RepartitionExecState::ConsumingInputStreams(value) => Ok(value),
396            _ => unreachable!(),
397        }
398    }
399}
400
401/// A utility that can be used to partition batches based on [`Partitioning`]
402pub struct BatchPartitioner {
403    state: BatchPartitionerState,
404    timer: metrics::Time,
405}
406
407enum BatchPartitionerState {
408    Hash {
409        random_state: ahash::RandomState,
410        exprs: Vec<Arc<dyn PhysicalExpr>>,
411        num_partitions: usize,
412        hash_buffer: Vec<u64>,
413    },
414    RoundRobin {
415        num_partitions: usize,
416        next_idx: usize,
417    },
418}
419
420impl BatchPartitioner {
421    /// Create a new [`BatchPartitioner`] with the provided [`Partitioning`]
422    ///
423    /// The time spent repartitioning will be recorded to `timer`
424    pub fn try_new(partitioning: Partitioning, timer: metrics::Time) -> Result<Self> {
425        let state = match partitioning {
426            Partitioning::RoundRobinBatch(num_partitions) => {
427                BatchPartitionerState::RoundRobin {
428                    num_partitions,
429                    next_idx: 0,
430                }
431            }
432            Partitioning::Hash(exprs, num_partitions) => BatchPartitionerState::Hash {
433                exprs,
434                num_partitions,
435                // Use fixed random hash
436                random_state: ahash::RandomState::with_seeds(0, 0, 0, 0),
437                hash_buffer: vec![],
438            },
439            other => return not_impl_err!("Unsupported repartitioning scheme {other:?}"),
440        };
441
442        Ok(Self { state, timer })
443    }
444
445    /// Partition the provided [`RecordBatch`] into one or more partitioned [`RecordBatch`]
446    /// based on the [`Partitioning`] specified on construction
447    ///
448    /// `f` will be called for each partitioned [`RecordBatch`] with the corresponding
449    /// partition index. Any error returned by `f` will be immediately returned by this
450    /// function without attempting to publish further [`RecordBatch`]
451    ///
452    /// The time spent repartitioning, not including time spent in `f` will be recorded
453    /// to the [`metrics::Time`] provided on construction
454    pub fn partition<F>(&mut self, batch: RecordBatch, mut f: F) -> Result<()>
455    where
456        F: FnMut(usize, RecordBatch) -> Result<()>,
457    {
458        self.partition_iter(batch)?.try_for_each(|res| match res {
459            Ok((partition, batch)) => f(partition, batch),
460            Err(e) => Err(e),
461        })
462    }
463
464    /// Actual implementation of [`partition`](Self::partition).
465    ///
466    /// The reason this was pulled out is that we need to have a variant of `partition` that works w/ sync functions,
467    /// and one that works w/ async. Using an iterator as an intermediate representation was the best way to achieve
468    /// this (so we don't need to clone the entire implementation).
469    fn partition_iter(
470        &mut self,
471        batch: RecordBatch,
472    ) -> Result<impl Iterator<Item = Result<(usize, RecordBatch)>> + Send + '_> {
473        let it: Box<dyn Iterator<Item = Result<(usize, RecordBatch)>> + Send> =
474            match &mut self.state {
475                BatchPartitionerState::RoundRobin {
476                    num_partitions,
477                    next_idx,
478                } => {
479                    let idx = *next_idx;
480                    *next_idx = (*next_idx + 1) % *num_partitions;
481                    Box::new(std::iter::once(Ok((idx, batch))))
482                }
483                BatchPartitionerState::Hash {
484                    random_state,
485                    exprs,
486                    num_partitions: partitions,
487                    hash_buffer,
488                } => {
489                    // Tracking time required for distributing indexes across output partitions
490                    let timer = self.timer.timer();
491
492                    let arrays = exprs
493                        .iter()
494                        .map(|expr| expr.evaluate(&batch)?.into_array(batch.num_rows()))
495                        .collect::<Result<Vec<_>>>()?;
496
497                    hash_buffer.clear();
498                    hash_buffer.resize(batch.num_rows(), 0);
499
500                    create_hashes(&arrays, random_state, hash_buffer)?;
501
502                    let mut indices: Vec<_> = (0..*partitions)
503                        .map(|_| Vec::with_capacity(batch.num_rows()))
504                        .collect();
505
506                    for (index, hash) in hash_buffer.iter().enumerate() {
507                        indices[(*hash % *partitions as u64) as usize].push(index as u32);
508                    }
509
510                    // Finished building index-arrays for output partitions
511                    timer.done();
512
513                    // Borrowing partitioner timer to prevent moving `self` to closure
514                    let partitioner_timer = &self.timer;
515                    let it = indices
516                        .into_iter()
517                        .enumerate()
518                        .filter_map(|(partition, indices)| {
519                            let indices: PrimitiveArray<UInt32Type> = indices.into();
520                            (!indices.is_empty()).then_some((partition, indices))
521                        })
522                        .map(move |(partition, indices)| {
523                            // Tracking time required for repartitioned batches construction
524                            let _timer = partitioner_timer.timer();
525
526                            // Produce batches based on indices
527                            let columns = take_arrays(batch.columns(), &indices, None)?;
528
529                            let mut options = RecordBatchOptions::new();
530                            options = options.with_row_count(Some(indices.len()));
531                            let batch = RecordBatch::try_new_with_options(
532                                batch.schema(),
533                                columns,
534                                &options,
535                            )
536                            .unwrap();
537
538                            Ok((partition, batch))
539                        });
540
541                    Box::new(it)
542                }
543            };
544
545        Ok(it)
546    }
547
548    // return the number of output partitions
549    fn num_partitions(&self) -> usize {
550        match self.state {
551            BatchPartitionerState::RoundRobin { num_partitions, .. } => num_partitions,
552            BatchPartitionerState::Hash { num_partitions, .. } => num_partitions,
553        }
554    }
555}
556
557/// Maps `N` input partitions to `M` output partitions based on a
558/// [`Partitioning`] scheme.
559///
560/// # Background
561///
562/// DataFusion, like most other commercial systems, with the
563/// notable exception of DuckDB, uses the "Exchange Operator" based
564/// approach to parallelism which works well in practice given
565/// sufficient care in implementation.
566///
567/// DataFusion's planner picks the target number of partitions and
568/// then [`RepartitionExec`] redistributes [`RecordBatch`]es to that number
569/// of output partitions.
570///
571/// For example, given `target_partitions=3` (trying to use 3 cores)
572/// but scanning an input with 2 partitions, `RepartitionExec` can be
573/// used to get 3 even streams of `RecordBatch`es
574///
575///
576///```text
577///        ▲                  ▲                  ▲
578///        │                  │                  │
579///        │                  │                  │
580///        │                  │                  │
581/// ┌───────────────┐  ┌───────────────┐  ┌───────────────┐
582/// │    GroupBy    │  │    GroupBy    │  │    GroupBy    │
583/// │   (Partial)   │  │   (Partial)   │  │   (Partial)   │
584/// └───────────────┘  └───────────────┘  └───────────────┘
585///        ▲                  ▲                  ▲
586///        └──────────────────┼──────────────────┘
587///                           │
588///              ┌─────────────────────────┐
589///              │     RepartitionExec     │
590///              │   (hash/round robin)    │
591///              └─────────────────────────┘
592///                         ▲   ▲
593///             ┌───────────┘   └───────────┐
594///             │                           │
595///             │                           │
596///        .─────────.                 .─────────.
597///     ,─'           '─.           ,─'           '─.
598///    ;      Input      :         ;      Input      :
599///    :   Partition 0   ;         :   Partition 1   ;
600///     ╲               ╱           ╲               ╱
601///      '─.         ,─'             '─.         ,─'
602///         `───────'                   `───────'
603/// ```
604///
605/// # Error Handling
606///
607/// If any of the input partitions return an error, the error is propagated to
608/// all output partitions and inputs are not polled again.
609///
610/// # Output Ordering
611///
612/// If more than one stream is being repartitioned, the output will be some
613/// arbitrary interleaving (and thus unordered) unless
614/// [`Self::with_preserve_order`] specifies otherwise.
615///
616/// # Spilling Architecture
617///
618/// RepartitionExec uses [`SpillPool`](crate::spill::spill_pool) channels to handle
619/// memory pressure during repartitioning. Each (input partition, output partition)
620/// pair gets its own SpillPool channel for FIFO ordering.
621///
622/// ```text
623/// Input Partitions (N)          Output Partitions (M)
624/// ────────────────────          ─────────────────────
625///
626///    Input 0 ──┐                      ┌──▶ Output 0
627///              │  ┌──────────────┐    │
628///              ├─▶│ SpillPool    │────┤
629///              │  │ [In0→Out0]   │    │
630///    Input 1 ──┤  └──────────────┘    ├──▶ Output 1
631///              │                       │
632///              │  ┌──────────────┐    │
633///              ├─▶│ SpillPool    │────┤
634///              │  │ [In1→Out0]   │    │
635///    Input 2 ──┤  └──────────────┘    ├──▶ Output 2
636///              │                      │
637///              │       ... (N×M SpillPools total)
638///              │                      │
639///              │  ┌──────────────┐    │
640///              └─▶│ SpillPool    │────┘
641///                 │ [InN→OutM]   │
642///                 └──────────────┘
643///
644/// Each SpillPool maintains FIFO order for its (input, output) pair.
645/// See `RepartitionBatch` for details on the memory/spill decision logic.
646/// ```
647///
648/// # Footnote
649///
650/// The "Exchange Operator" was first described in the 1989 paper
651/// [Encapsulation of parallelism in the Volcano query processing
652/// system Paper](https://dl.acm.org/doi/pdf/10.1145/93605.98720)
653/// which uses the term "Exchange" for the concept of repartitioning
654/// data across threads.
655#[derive(Debug, Clone)]
656pub struct RepartitionExec {
657    /// Input execution plan
658    input: Arc<dyn ExecutionPlan>,
659    /// Inner state that is initialized when the parent calls .execute() on this node
660    /// and consumed as soon as the parent starts consuming this node.
661    state: Arc<Mutex<RepartitionExecState>>,
662    /// Execution metrics
663    metrics: ExecutionPlanMetricsSet,
664    /// Boolean flag to decide whether to preserve ordering. If true means
665    /// `SortPreservingRepartitionExec`, false means `RepartitionExec`.
666    preserve_order: bool,
667    /// Cache holding plan properties like equivalences, output partitioning etc.
668    cache: PlanProperties,
669}
670
671#[derive(Debug, Clone)]
672struct RepartitionMetrics {
673    /// Time in nanos to execute child operator and fetch batches
674    fetch_time: metrics::Time,
675    /// Repartitioning elapsed time in nanos
676    repartition_time: metrics::Time,
677    /// Time in nanos for sending resulting batches to channels.
678    ///
679    /// One metric per output partition.
680    send_time: Vec<metrics::Time>,
681}
682
683impl RepartitionMetrics {
684    pub fn new(
685        input_partition: usize,
686        num_output_partitions: usize,
687        metrics: &ExecutionPlanMetricsSet,
688    ) -> Self {
689        // Time in nanos to execute child operator and fetch batches
690        let fetch_time =
691            MetricBuilder::new(metrics).subset_time("fetch_time", input_partition);
692
693        // Time in nanos to perform repartitioning
694        let repartition_time =
695            MetricBuilder::new(metrics).subset_time("repartition_time", input_partition);
696
697        // Time in nanos for sending resulting batches to channels
698        let send_time = (0..num_output_partitions)
699            .map(|output_partition| {
700                let label =
701                    metrics::Label::new("outputPartition", output_partition.to_string());
702                MetricBuilder::new(metrics)
703                    .with_label(label)
704                    .subset_time("send_time", input_partition)
705            })
706            .collect();
707
708        Self {
709            fetch_time,
710            repartition_time,
711            send_time,
712        }
713    }
714}
715
716impl RepartitionExec {
717    /// Input execution plan
718    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
719        &self.input
720    }
721
722    /// Partitioning scheme to use
723    pub fn partitioning(&self) -> &Partitioning {
724        &self.cache.partitioning
725    }
726
727    /// Get preserve_order flag of the RepartitionExec
728    /// `true` means `SortPreservingRepartitionExec`, `false` means `RepartitionExec`
729    pub fn preserve_order(&self) -> bool {
730        self.preserve_order
731    }
732
733    /// Get name used to display this Exec
734    pub fn name(&self) -> &str {
735        "RepartitionExec"
736    }
737}
738
739impl DisplayAs for RepartitionExec {
740    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
741        match t {
742            DisplayFormatType::Default | DisplayFormatType::Verbose => {
743                write!(
744                    f,
745                    "{}: partitioning={}, input_partitions={}",
746                    self.name(),
747                    self.partitioning(),
748                    self.input.output_partitioning().partition_count()
749                )?;
750
751                if self.preserve_order {
752                    write!(f, ", preserve_order=true")?;
753                }
754
755                if let Some(sort_exprs) = self.sort_exprs() {
756                    write!(f, ", sort_exprs={}", sort_exprs.clone())?;
757                }
758                Ok(())
759            }
760            DisplayFormatType::TreeRender => {
761                writeln!(f, "partitioning_scheme={}", self.partitioning(),)?;
762
763                let input_partition_count =
764                    self.input.output_partitioning().partition_count();
765                let output_partition_count = self.partitioning().partition_count();
766                let input_to_output_partition_str =
767                    format!("{input_partition_count} -> {output_partition_count}");
768                writeln!(
769                    f,
770                    "partition_count(in->out)={input_to_output_partition_str}"
771                )?;
772
773                if self.preserve_order {
774                    writeln!(f, "preserve_order={}", self.preserve_order)?;
775                }
776                Ok(())
777            }
778        }
779    }
780}
781
782impl ExecutionPlan for RepartitionExec {
783    fn name(&self) -> &'static str {
784        "RepartitionExec"
785    }
786
787    /// Return a reference to Any that can be used for downcasting
788    fn as_any(&self) -> &dyn Any {
789        self
790    }
791
792    fn properties(&self) -> &PlanProperties {
793        &self.cache
794    }
795
796    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
797        vec![&self.input]
798    }
799
800    fn with_new_children(
801        self: Arc<Self>,
802        mut children: Vec<Arc<dyn ExecutionPlan>>,
803    ) -> Result<Arc<dyn ExecutionPlan>> {
804        let mut repartition = RepartitionExec::try_new(
805            children.swap_remove(0),
806            self.partitioning().clone(),
807        )?;
808        if self.preserve_order {
809            repartition = repartition.with_preserve_order();
810        }
811        Ok(Arc::new(repartition))
812    }
813
814    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
815        vec![matches!(self.partitioning(), Partitioning::Hash(_, _))]
816    }
817
818    fn maintains_input_order(&self) -> Vec<bool> {
819        Self::maintains_input_order_helper(self.input(), self.preserve_order)
820    }
821
822    fn execute(
823        &self,
824        partition: usize,
825        context: Arc<TaskContext>,
826    ) -> Result<SendableRecordBatchStream> {
827        trace!(
828            "Start {}::execute for partition: {}",
829            self.name(),
830            partition
831        );
832
833        let spill_metrics = SpillMetrics::new(&self.metrics, partition);
834
835        let input = Arc::clone(&self.input);
836        let partitioning = self.partitioning().clone();
837        let metrics = self.metrics.clone();
838        let preserve_order = self.sort_exprs().is_some();
839        let name = self.name().to_owned();
840        let schema = self.schema();
841        let schema_captured = Arc::clone(&schema);
842
843        let spill_manager = SpillManager::new(
844            Arc::clone(&context.runtime_env()),
845            spill_metrics,
846            input.schema(),
847        );
848
849        // Get existing ordering to use for merging
850        let sort_exprs = self.sort_exprs().cloned();
851
852        let state = Arc::clone(&self.state);
853        if let Some(mut state) = state.try_lock() {
854            state.ensure_input_streams_initialized(
855                Arc::clone(&input),
856                metrics.clone(),
857                partitioning.partition_count(),
858                Arc::clone(&context),
859            )?;
860        }
861
862        let num_input_partitions = input.output_partitioning().partition_count();
863
864        let stream = futures::stream::once(async move {
865            // lock scope
866            let (rx, reservation, spill_readers, abort_helper) = {
867                // lock mutexes
868                let mut state = state.lock();
869                let state = state.consume_input_streams(
870                    Arc::clone(&input),
871                    metrics.clone(),
872                    partitioning,
873                    preserve_order,
874                    name.clone(),
875                    Arc::clone(&context),
876                    spill_manager.clone(),
877                )?;
878
879                // now return stream for the specified *output* partition which will
880                // read from the channel
881                let PartitionChannels {
882                    rx,
883                    reservation,
884                    spill_readers,
885                    ..
886                } = state
887                    .channels
888                    .remove(&partition)
889                    .expect("partition not used yet");
890
891                (
892                    rx,
893                    reservation,
894                    spill_readers,
895                    Arc::clone(&state.abort_helper),
896                )
897            };
898
899            trace!(
900                "Before returning stream in {name}::execute for partition: {partition}"
901            );
902
903            if preserve_order {
904                // Store streams from all the input partitions:
905                // Each input partition gets its own spill reader to maintain proper FIFO ordering
906                let input_streams = rx
907                    .into_iter()
908                    .zip(spill_readers)
909                    .map(|(receiver, spill_stream)| {
910                        // In preserve_order mode, each receiver corresponds to exactly one input partition
911                        Box::pin(PerPartitionStream::new(
912                            Arc::clone(&schema_captured),
913                            receiver,
914                            Arc::clone(&abort_helper),
915                            Arc::clone(&reservation),
916                            spill_stream,
917                            1, // Each receiver handles one input partition
918                        )) as SendableRecordBatchStream
919                    })
920                    .collect::<Vec<_>>();
921                // Note that receiver size (`rx.len()`) and `num_input_partitions` are same.
922
923                // Merge streams (while preserving ordering) coming from
924                // input partitions to this partition:
925                let fetch = None;
926                let merge_reservation =
927                    MemoryConsumer::new(format!("{name}[Merge {partition}]"))
928                        .register(context.memory_pool());
929                StreamingMergeBuilder::new()
930                    .with_streams(input_streams)
931                    .with_schema(schema_captured)
932                    .with_expressions(&sort_exprs.unwrap())
933                    .with_metrics(BaselineMetrics::new(&metrics, partition))
934                    .with_batch_size(context.session_config().batch_size())
935                    .with_fetch(fetch)
936                    .with_reservation(merge_reservation)
937                    .with_spill_manager(spill_manager)
938                    .build()
939            } else {
940                // Non-preserve-order case: single input stream, so use the first spill reader
941                let spill_stream = spill_readers
942                    .into_iter()
943                    .next()
944                    .expect("at least one spill reader should exist");
945
946                Ok(Box::pin(PerPartitionStream::new(
947                    schema_captured,
948                    rx.into_iter()
949                        .next()
950                        .expect("at least one receiver should exist"),
951                    abort_helper,
952                    reservation,
953                    spill_stream,
954                    num_input_partitions,
955                )) as SendableRecordBatchStream)
956            }
957        })
958        .try_flatten();
959        let stream = RecordBatchStreamAdapter::new(schema, stream);
960        Ok(Box::pin(stream))
961    }
962
963    fn metrics(&self) -> Option<MetricsSet> {
964        Some(self.metrics.clone_inner())
965    }
966
967    fn statistics(&self) -> Result<Statistics> {
968        self.input.partition_statistics(None)
969    }
970
971    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
972        if let Some(partition) = partition {
973            let partition_count = self.partitioning().partition_count();
974            if partition_count == 0 {
975                return Ok(Statistics::new_unknown(&self.schema()));
976            }
977
978            if partition >= partition_count {
979                return internal_err!(
980                    "RepartitionExec invalid partition {} (expected less than {})",
981                    partition,
982                    self.partitioning().partition_count()
983                );
984            }
985
986            let mut stats = self.input.partition_statistics(None)?;
987
988            // Distribute statistics across partitions
989            stats.num_rows = stats
990                .num_rows
991                .get_value()
992                .map(|rows| Precision::Inexact(rows / partition_count))
993                .unwrap_or(Precision::Absent);
994            stats.total_byte_size = stats
995                .total_byte_size
996                .get_value()
997                .map(|bytes| Precision::Inexact(bytes / partition_count))
998                .unwrap_or(Precision::Absent);
999
1000            // Make all column stats unknown
1001            stats.column_statistics = stats
1002                .column_statistics
1003                .iter()
1004                .map(|_| ColumnStatistics::new_unknown())
1005                .collect();
1006
1007            Ok(stats)
1008        } else {
1009            self.input.partition_statistics(None)
1010        }
1011    }
1012
1013    fn cardinality_effect(&self) -> CardinalityEffect {
1014        CardinalityEffect::Equal
1015    }
1016
1017    fn try_swapping_with_projection(
1018        &self,
1019        projection: &ProjectionExec,
1020    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
1021        // If the projection does not narrow the schema, we should not try to push it down.
1022        if projection.expr().len() >= projection.input().schema().fields().len() {
1023            return Ok(None);
1024        }
1025
1026        // If pushdown is not beneficial or applicable, break it.
1027        if projection.benefits_from_input_partitioning()[0]
1028            || !all_columns(projection.expr())
1029        {
1030            return Ok(None);
1031        }
1032
1033        let new_projection = make_with_child(projection, self.input())?;
1034
1035        let new_partitioning = match self.partitioning() {
1036            Partitioning::Hash(partitions, size) => {
1037                let mut new_partitions = vec![];
1038                for partition in partitions {
1039                    let Some(new_partition) =
1040                        update_expr(partition, projection.expr(), false)?
1041                    else {
1042                        return Ok(None);
1043                    };
1044                    new_partitions.push(new_partition);
1045                }
1046                Partitioning::Hash(new_partitions, *size)
1047            }
1048            others => others.clone(),
1049        };
1050
1051        Ok(Some(Arc::new(RepartitionExec::try_new(
1052            new_projection,
1053            new_partitioning,
1054        )?)))
1055    }
1056
1057    fn gather_filters_for_pushdown(
1058        &self,
1059        _phase: FilterPushdownPhase,
1060        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
1061        _config: &ConfigOptions,
1062    ) -> Result<FilterDescription> {
1063        FilterDescription::from_children(parent_filters, &self.children())
1064    }
1065
1066    fn handle_child_pushdown_result(
1067        &self,
1068        _phase: FilterPushdownPhase,
1069        child_pushdown_result: ChildPushdownResult,
1070        _config: &ConfigOptions,
1071    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
1072        Ok(FilterPushdownPropagation::if_all(child_pushdown_result))
1073    }
1074
1075    fn repartitioned(
1076        &self,
1077        target_partitions: usize,
1078        _config: &ConfigOptions,
1079    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
1080        use Partitioning::*;
1081        let mut new_properties = self.cache.clone();
1082        new_properties.partitioning = match new_properties.partitioning {
1083            RoundRobinBatch(_) => RoundRobinBatch(target_partitions),
1084            Hash(hash, _) => Hash(hash, target_partitions),
1085            UnknownPartitioning(_) => UnknownPartitioning(target_partitions),
1086        };
1087        Ok(Some(Arc::new(Self {
1088            input: Arc::clone(&self.input),
1089            state: Arc::clone(&self.state),
1090            metrics: self.metrics.clone(),
1091            preserve_order: self.preserve_order,
1092            cache: new_properties,
1093        })))
1094    }
1095}
1096
1097impl RepartitionExec {
1098    /// Create a new RepartitionExec, that produces output `partitioning`, and
1099    /// does not preserve the order of the input (see [`Self::with_preserve_order`]
1100    /// for more details)
1101    pub fn try_new(
1102        input: Arc<dyn ExecutionPlan>,
1103        partitioning: Partitioning,
1104    ) -> Result<Self> {
1105        let preserve_order = false;
1106        let cache =
1107            Self::compute_properties(&input, partitioning.clone(), preserve_order);
1108        Ok(RepartitionExec {
1109            input,
1110            state: Default::default(),
1111            metrics: ExecutionPlanMetricsSet::new(),
1112            preserve_order,
1113            cache,
1114        })
1115    }
1116
1117    fn maintains_input_order_helper(
1118        input: &Arc<dyn ExecutionPlan>,
1119        preserve_order: bool,
1120    ) -> Vec<bool> {
1121        // We preserve ordering when repartition is order preserving variant or input partitioning is 1
1122        vec![preserve_order || input.output_partitioning().partition_count() <= 1]
1123    }
1124
1125    fn eq_properties_helper(
1126        input: &Arc<dyn ExecutionPlan>,
1127        preserve_order: bool,
1128    ) -> EquivalenceProperties {
1129        // Equivalence Properties
1130        let mut eq_properties = input.equivalence_properties().clone();
1131        // If the ordering is lost, reset the ordering equivalence class:
1132        if !Self::maintains_input_order_helper(input, preserve_order)[0] {
1133            eq_properties.clear_orderings();
1134        }
1135        // When there are more than one input partitions, they will be fused at the output.
1136        // Therefore, remove per partition constants.
1137        if input.output_partitioning().partition_count() > 1 {
1138            eq_properties.clear_per_partition_constants();
1139        }
1140        eq_properties
1141    }
1142
1143    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
1144    fn compute_properties(
1145        input: &Arc<dyn ExecutionPlan>,
1146        partitioning: Partitioning,
1147        preserve_order: bool,
1148    ) -> PlanProperties {
1149        PlanProperties::new(
1150            Self::eq_properties_helper(input, preserve_order),
1151            partitioning,
1152            input.pipeline_behavior(),
1153            input.boundedness(),
1154        )
1155        .with_scheduling_type(SchedulingType::Cooperative)
1156        .with_evaluation_type(EvaluationType::Eager)
1157    }
1158
1159    /// Specify if this repartitioning operation should preserve the order of
1160    /// rows from its input when producing output. Preserving order is more
1161    /// expensive at runtime, so should only be set if the output of this
1162    /// operator can take advantage of it.
1163    ///
1164    /// If the input is not ordered, or has only one partition, this is a no op,
1165    /// and the node remains a `RepartitionExec`.
1166    pub fn with_preserve_order(mut self) -> Self {
1167        self.preserve_order =
1168                // If the input isn't ordered, there is no ordering to preserve
1169                self.input.output_ordering().is_some() &&
1170                // if there is only one input partition, merging is not required
1171                // to maintain order
1172                self.input.output_partitioning().partition_count() > 1;
1173        let eq_properties = Self::eq_properties_helper(&self.input, self.preserve_order);
1174        self.cache = self.cache.with_eq_properties(eq_properties);
1175        self
1176    }
1177
1178    /// Return the sort expressions that are used to merge
1179    fn sort_exprs(&self) -> Option<&LexOrdering> {
1180        if self.preserve_order {
1181            self.input.output_ordering()
1182        } else {
1183            None
1184        }
1185    }
1186
1187    /// Pulls data from the specified input plan, feeding it to the
1188    /// output partitions based on the desired partitioning
1189    ///
1190    /// `output_channels` holds the output sending channels for each output partition
1191    async fn pull_from_input(
1192        mut stream: SendableRecordBatchStream,
1193        mut output_channels: HashMap<usize, OutputChannel>,
1194        partitioning: Partitioning,
1195        metrics: RepartitionMetrics,
1196    ) -> Result<()> {
1197        let mut partitioner =
1198            BatchPartitioner::try_new(partitioning, metrics.repartition_time.clone())?;
1199
1200        // While there are still outputs to send to, keep pulling inputs
1201        let mut batches_until_yield = partitioner.num_partitions();
1202        while !output_channels.is_empty() {
1203            // fetch the next batch
1204            let timer = metrics.fetch_time.timer();
1205            let result = stream.next().await;
1206            timer.done();
1207
1208            // Input is done
1209            let batch = match result {
1210                Some(result) => result?,
1211                None => break,
1212            };
1213
1214            // Handle empty batch
1215            if batch.num_rows() == 0 {
1216                continue;
1217            }
1218
1219            for res in partitioner.partition_iter(batch)? {
1220                let (partition, batch) = res?;
1221                let size = batch.get_array_memory_size();
1222
1223                let timer = metrics.send_time[partition].timer();
1224                // if there is still a receiver, send to it
1225                if let Some(channel) = output_channels.get_mut(&partition) {
1226                    let (batch_to_send, is_memory_batch) =
1227                        match channel.reservation.lock().try_grow(size) {
1228                            Ok(_) => {
1229                                // Memory available - send in-memory batch
1230                                (RepartitionBatch::Memory(batch), true)
1231                            }
1232                            Err(_) => {
1233                                // We're memory limited - spill to SpillPool
1234                                // SpillPool handles file handle reuse and rotation
1235                                channel.spill_writer.push_batch(&batch)?;
1236                                // Send marker indicating batch was spilled
1237                                (RepartitionBatch::Spilled, false)
1238                            }
1239                        };
1240
1241                    if channel.sender.send(Some(Ok(batch_to_send))).await.is_err() {
1242                        // If the other end has hung up, it was an early shutdown (e.g. LIMIT)
1243                        // Only shrink memory if it was a memory batch
1244                        if is_memory_batch {
1245                            channel.reservation.lock().shrink(size);
1246                        }
1247                        output_channels.remove(&partition);
1248                    }
1249                }
1250                timer.done();
1251            }
1252
1253            // If the input stream is endless, we may spin forever and
1254            // never yield back to tokio.  See
1255            // https://github.com/apache/datafusion/issues/5278.
1256            //
1257            // However, yielding on every batch causes a bottleneck
1258            // when running with multiple cores. See
1259            // https://github.com/apache/datafusion/issues/6290
1260            //
1261            // Thus, heuristically yield after producing num_partition
1262            // batches
1263            //
1264            // In round robin this is ideal as each input will get a
1265            // new batch. In hash partitioning it may yield too often
1266            // on uneven distributions even if some partition can not
1267            // make progress, but parallelism is going to be limited
1268            // in that case anyways
1269            if batches_until_yield == 0 {
1270                tokio::task::yield_now().await;
1271                batches_until_yield = partitioner.num_partitions();
1272            } else {
1273                batches_until_yield -= 1;
1274            }
1275        }
1276
1277        // Spill writers will auto-finalize when dropped
1278        // No need for explicit flush
1279        Ok(())
1280    }
1281
1282    /// Waits for `input_task` which is consuming one of the inputs to
1283    /// complete. Upon each successful completion, sends a `None` to
1284    /// each of the output tx channels to signal one of the inputs is
1285    /// complete. Upon error, propagates the errors to all output tx
1286    /// channels.
1287    async fn wait_for_task(
1288        input_task: SpawnedTask<Result<()>>,
1289        txs: HashMap<usize, DistributionSender<MaybeBatch>>,
1290    ) {
1291        // wait for completion, and propagate error
1292        // note we ignore errors on send (.ok) as that means the receiver has already shutdown.
1293
1294        match input_task.join().await {
1295            // Error in joining task
1296            Err(e) => {
1297                let e = Arc::new(e);
1298
1299                for (_, tx) in txs {
1300                    let err = Err(DataFusionError::Context(
1301                        "Join Error".to_string(),
1302                        Box::new(DataFusionError::External(Box::new(Arc::clone(&e)))),
1303                    ));
1304                    tx.send(Some(err)).await.ok();
1305                }
1306            }
1307            // Error from running input task
1308            Ok(Err(e)) => {
1309                // send the same Arc'd error to all output partitions
1310                let e = Arc::new(e);
1311
1312                for (_, tx) in txs {
1313                    // wrap it because need to send error to all output partitions
1314                    let err = Err(DataFusionError::from(&e));
1315                    tx.send(Some(err)).await.ok();
1316                }
1317            }
1318            // Input task completed successfully
1319            Ok(Ok(())) => {
1320                // notify each output partition that this input partition has no more data
1321                for (_partition, tx) in txs {
1322                    tx.send(None).await.ok();
1323                }
1324            }
1325        }
1326    }
1327}
1328
/// The source a [`PerPartitionStream`] is currently reading from.
///
/// A batch for an (input, output) pair travels through one of two routes:
/// * the in-memory channel, as [`RepartitionBatch::Memory`];
/// * the spill pool, in which case a [`RepartitionBatch::Spilled`] placeholder
///   is sent through the channel instead.
///
/// The placeholder marks, at exactly that position in the sequence, that the
/// next batch must be taken from the spill stream.
///
/// # Transitions
///
/// ```text
///   ReadingMemory  -- channel yields Memory(batch) --> emit batch, stay
///   ReadingMemory  -- channel yields Spilled -------> ReadingSpilled
///   ReadingSpilled -- spill stream yields batch ----> emit batch, ReadingMemory
///   ReadingSpilled -- spill stream yields error ----> emit error, stay
///   ReadingSpilled -- spill stream ends ------------> ReadingMemory
///   ReadingSpilled -- spill stream pending ---------> stay (channel not polled)
/// ```
///
/// The output stays FIFO because the stream remains in `ReadingSpilled`
/// while the spill stream is pending: the channel is not polled again until
/// the spilled batch has been produced.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum StreamState {
    /// Normal operation: the next item comes from the in-memory channel.
    ReadingMemory,
    /// A spill marker was seen; the next batch must come from the spill
    /// stream before the channel may be polled again.
    ReadingSpilled,
}
1379
1380/// This struct converts a receiver to a stream.
1381/// Receiver receives data on an SPSC channel.
1382struct PerPartitionStream {
1383    /// Schema wrapped by Arc
1384    schema: SchemaRef,
1385
1386    /// channel containing the repartitioned batches
1387    receiver: DistributionReceiver<MaybeBatch>,
1388
1389    /// Handle to ensure background tasks are killed when no longer needed.
1390    _drop_helper: Arc<Vec<SpawnedTask<()>>>,
1391
1392    /// Memory reservation.
1393    reservation: SharedMemoryReservation,
1394
1395    /// Infinite stream for reading from the spill pool
1396    spill_stream: SendableRecordBatchStream,
1397
1398    /// Internal state indicating if we are reading from memory or spill stream
1399    state: StreamState,
1400
1401    /// Number of input partitions that have not yet finished.
1402    /// In non-preserve-order mode, multiple input partitions send to the same channel,
1403    /// each sending None when complete. We must wait for all of them.
1404    remaining_partitions: usize,
1405}
1406
1407impl PerPartitionStream {
1408    fn new(
1409        schema: SchemaRef,
1410        receiver: DistributionReceiver<MaybeBatch>,
1411        drop_helper: Arc<Vec<SpawnedTask<()>>>,
1412        reservation: SharedMemoryReservation,
1413        spill_stream: SendableRecordBatchStream,
1414        num_input_partitions: usize,
1415    ) -> Self {
1416        Self {
1417            schema,
1418            receiver,
1419            _drop_helper: drop_helper,
1420            reservation,
1421            spill_stream,
1422            state: StreamState::ReadingMemory,
1423            remaining_partitions: num_input_partitions,
1424        }
1425    }
1426}
1427
impl Stream for PerPartitionStream {
    type Item = Result<RecordBatch>;

    /// Drives the [`StreamState`] machine until it can return an item or must
    /// wait.
    ///
    /// * In `ReadingMemory`, the channel is polled.
    ///   - An in-memory batch releases its reservation and is returned.
    ///   - A spill marker switches to `ReadingSpilled`.
    ///   - A completion signal (`Some(None)`) decrements
    ///     `remaining_partitions`; the stream ends once it reaches zero.
    /// * In `ReadingSpilled`, only the spill stream is polled, so ordering is
    ///   kept.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        use futures::StreamExt;

        loop {
            match self.state {
                StreamState::ReadingMemory => {
                    // Poll the memory channel for next message
                    let value = match self.receiver.recv().poll_unpin(cx) {
                        Poll::Ready(v) => v,
                        Poll::Pending => {
                            // Nothing from channel, wait
                            return Poll::Pending;
                        }
                    };

                    // Outer `None`: the channel itself is closed.
                    // `Some(None)`: one feeding input partition finished.
                    // `Some(Some(..))`: a batch, a spill marker, or an error.
                    match value {
                        Some(Some(v)) => match v {
                            Ok(RepartitionBatch::Memory(batch)) => {
                                // Release memory and return batch
                                // (the size is taken from the same batch, so it
                                // matches what the producer reserved for it)
                                self.reservation
                                    .lock()
                                    .shrink(batch.get_array_memory_size());
                                return Poll::Ready(Some(Ok(batch)));
                            }
                            Ok(RepartitionBatch::Spilled) => {
                                // Batch was spilled, transition to reading from spill stream
                                // We must block on spill stream until we get the batch
                                // to preserve ordering
                                self.state = StreamState::ReadingSpilled;
                                continue;
                            }
                            Err(e) => {
                                return Poll::Ready(Some(Err(e)));
                            }
                        },
                        Some(None) => {
                            // One input partition finished
                            self.remaining_partitions -= 1;
                            if self.remaining_partitions == 0 {
                                // All input partitions finished
                                return Poll::Ready(None);
                            }
                            // Continue to poll for more data from other partitions
                            continue;
                        }
                        None => {
                            // Channel closed unexpectedly
                            return Poll::Ready(None);
                        }
                    }
                }
                StreamState::ReadingSpilled => {
                    // Poll spill stream for the spilled batch
                    match self.spill_stream.poll_next_unpin(cx) {
                        Poll::Ready(Some(Ok(batch))) => {
                            self.state = StreamState::ReadingMemory;
                            return Poll::Ready(Some(Ok(batch)));
                        }
                        Poll::Ready(Some(Err(e))) => {
                            // State is left as `ReadingSpilled`, so the next
                            // poll retries the spill stream.
                            return Poll::Ready(Some(Err(e)));
                        }
                        Poll::Ready(None) => {
                            // Spill stream ended, keep draining the memory channel
                            self.state = StreamState::ReadingMemory;
                        }
                        Poll::Pending => {
                            // Spilled batch not ready yet, must wait
                            // This preserves ordering by blocking until spill data arrives
                            return Poll::Pending;
                        }
                    }
                }
            }
        }
    }
}
1510
1511impl RecordBatchStream for PerPartitionStream {
1512    /// Get the schema
1513    fn schema(&self) -> SchemaRef {
1514        Arc::clone(&self.schema)
1515    }
1516}
1517
1518#[cfg(test)]
1519mod tests {
1520    use std::collections::HashSet;
1521
1522    use super::*;
1523    use crate::test::TestMemoryExec;
1524    use crate::{
1525        test::{
1526            assert_is_pending,
1527            exec::{
1528                assert_strong_count_converges_to_zero, BarrierExec, BlockingExec,
1529                ErrorExec, MockExec,
1530            },
1531        },
1532        {collect, expressions::col},
1533    };
1534
1535    use arrow::array::{ArrayRef, StringArray, UInt32Array};
1536    use arrow::datatypes::{DataType, Field, Schema};
1537    use datafusion_common::cast::as_string_array;
1538    use datafusion_common::exec_err;
1539    use datafusion_common::test_util::batches_to_sort_string;
1540    use datafusion_common_runtime::JoinSet;
1541    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
1542    use insta::assert_snapshot;
1543    use itertools::Itertools;
1544
1545    #[tokio::test]
1546    async fn one_to_many_round_robin() -> Result<()> {
1547        // define input partitions
1548        let schema = test_schema();
1549        let partition = create_vec_batches(50);
1550        let partitions = vec![partition];
1551
1552        // repartition from 1 input to 4 output
1553        let output_partitions =
1554            repartition(&schema, partitions, Partitioning::RoundRobinBatch(4)).await?;
1555
1556        assert_eq!(4, output_partitions.len());
1557        assert_eq!(13, output_partitions[0].len());
1558        assert_eq!(13, output_partitions[1].len());
1559        assert_eq!(12, output_partitions[2].len());
1560        assert_eq!(12, output_partitions[3].len());
1561
1562        Ok(())
1563    }
1564
1565    #[tokio::test]
1566    async fn many_to_one_round_robin() -> Result<()> {
1567        // define input partitions
1568        let schema = test_schema();
1569        let partition = create_vec_batches(50);
1570        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
1571
1572        // repartition from 3 input to 1 output
1573        let output_partitions =
1574            repartition(&schema, partitions, Partitioning::RoundRobinBatch(1)).await?;
1575
1576        assert_eq!(1, output_partitions.len());
1577        assert_eq!(150, output_partitions[0].len());
1578
1579        Ok(())
1580    }
1581
1582    #[tokio::test]
1583    async fn many_to_many_round_robin() -> Result<()> {
1584        // define input partitions
1585        let schema = test_schema();
1586        let partition = create_vec_batches(50);
1587        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
1588
1589        // repartition from 3 input to 5 output
1590        let output_partitions =
1591            repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await?;
1592
1593        assert_eq!(5, output_partitions.len());
1594        assert_eq!(30, output_partitions[0].len());
1595        assert_eq!(30, output_partitions[1].len());
1596        assert_eq!(30, output_partitions[2].len());
1597        assert_eq!(30, output_partitions[3].len());
1598        assert_eq!(30, output_partitions[4].len());
1599
1600        Ok(())
1601    }
1602
1603    #[tokio::test]
1604    async fn many_to_many_hash_partition() -> Result<()> {
1605        // define input partitions
1606        let schema = test_schema();
1607        let partition = create_vec_batches(50);
1608        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
1609
1610        let output_partitions = repartition(
1611            &schema,
1612            partitions,
1613            Partitioning::Hash(vec![col("c0", &schema)?], 8),
1614        )
1615        .await?;
1616
1617        let total_rows: usize = output_partitions
1618            .iter()
1619            .map(|x| x.iter().map(|x| x.num_rows()).sum::<usize>())
1620            .sum();
1621
1622        assert_eq!(8, output_partitions.len());
1623        assert_eq!(total_rows, 8 * 50 * 3);
1624
1625        Ok(())
1626    }
1627
1628    fn test_schema() -> Arc<Schema> {
1629        Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
1630    }
1631
1632    async fn repartition(
1633        schema: &SchemaRef,
1634        input_partitions: Vec<Vec<RecordBatch>>,
1635        partitioning: Partitioning,
1636    ) -> Result<Vec<Vec<RecordBatch>>> {
1637        let task_ctx = Arc::new(TaskContext::default());
1638        // create physical plan
1639        let exec =
1640            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(schema), None)?;
1641        let exec = RepartitionExec::try_new(exec, partitioning)?;
1642
1643        // execute and collect results
1644        let mut output_partitions = vec![];
1645        for i in 0..exec.partitioning().partition_count() {
1646            // execute this *output* partition and collect all batches
1647            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
1648            let mut batches = vec![];
1649            while let Some(result) = stream.next().await {
1650                batches.push(result?);
1651            }
1652            output_partitions.push(batches);
1653        }
1654        Ok(output_partitions)
1655    }
1656
1657    #[tokio::test]
1658    async fn many_to_many_round_robin_within_tokio_task() -> Result<()> {
1659        let handle: SpawnedTask<Result<Vec<Vec<RecordBatch>>>> =
1660            SpawnedTask::spawn(async move {
1661                // define input partitions
1662                let schema = test_schema();
1663                let partition = create_vec_batches(50);
1664                let partitions =
1665                    vec![partition.clone(), partition.clone(), partition.clone()];
1666
1667                // repartition from 3 input to 5 output
1668                repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await
1669            });
1670
1671        let output_partitions = handle.join().await.unwrap().unwrap();
1672
1673        assert_eq!(5, output_partitions.len());
1674        assert_eq!(30, output_partitions[0].len());
1675        assert_eq!(30, output_partitions[1].len());
1676        assert_eq!(30, output_partitions[2].len());
1677        assert_eq!(30, output_partitions[3].len());
1678        assert_eq!(30, output_partitions[4].len());
1679
1680        Ok(())
1681    }
1682
1683    #[tokio::test]
1684    async fn unsupported_partitioning() {
1685        let task_ctx = Arc::new(TaskContext::default());
1686        // have to send at least one batch through to provoke error
1687        let batch = RecordBatch::try_from_iter(vec![(
1688            "my_awesome_field",
1689            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
1690        )])
1691        .unwrap();
1692
1693        let schema = batch.schema();
1694        let input = MockExec::new(vec![Ok(batch)], schema);
1695        // This generates an error (partitioning type not supported)
1696        // but only after the plan is executed. The error should be
1697        // returned and no results produced
1698        let partitioning = Partitioning::UnknownPartitioning(1);
1699        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
1700        let output_stream = exec.execute(0, task_ctx).unwrap();
1701
1702        // Expect that an error is returned
1703        let result_string = crate::common::collect(output_stream)
1704            .await
1705            .unwrap_err()
1706            .to_string();
1707        assert!(
1708            result_string
1709                .contains("Unsupported repartitioning scheme UnknownPartitioning(1)"),
1710            "actual: {result_string}"
1711        );
1712    }
1713
1714    #[tokio::test]
1715    async fn error_for_input_exec() {
1716        // This generates an error on a call to execute. The error
1717        // should be returned and no results produced.
1718
1719        let task_ctx = Arc::new(TaskContext::default());
1720        let input = ErrorExec::new();
1721        let partitioning = Partitioning::RoundRobinBatch(1);
1722        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
1723
1724        // Expect that an error is returned
1725        let result_string = exec.execute(0, task_ctx).err().unwrap().to_string();
1726
1727        assert!(
1728            result_string.contains("ErrorExec, unsurprisingly, errored in partition 0"),
1729            "actual: {result_string}"
1730        );
1731    }
1732
1733    #[tokio::test]
1734    async fn repartition_with_error_in_stream() {
1735        let task_ctx = Arc::new(TaskContext::default());
1736        let batch = RecordBatch::try_from_iter(vec![(
1737            "my_awesome_field",
1738            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
1739        )])
1740        .unwrap();
1741
1742        // input stream returns one good batch and then one error. The
1743        // error should be returned.
1744        let err = exec_err!("bad data error");
1745
1746        let schema = batch.schema();
1747        let input = MockExec::new(vec![Ok(batch), err], schema);
1748        let partitioning = Partitioning::RoundRobinBatch(1);
1749        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
1750
1751        // Note: this should pass (the stream can be created) but the
1752        // error when the input is executed should get passed back
1753        let output_stream = exec.execute(0, task_ctx).unwrap();
1754
1755        // Expect that an error is returned
1756        let result_string = crate::common::collect(output_stream)
1757            .await
1758            .unwrap_err()
1759            .to_string();
1760        assert!(
1761            result_string.contains("bad data error"),
1762            "actual: {result_string}"
1763        );
1764    }
1765
1766    #[tokio::test]
1767    async fn repartition_with_delayed_stream() {
1768        let task_ctx = Arc::new(TaskContext::default());
1769        let batch1 = RecordBatch::try_from_iter(vec![(
1770            "my_awesome_field",
1771            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
1772        )])
1773        .unwrap();
1774
1775        let batch2 = RecordBatch::try_from_iter(vec![(
1776            "my_awesome_field",
1777            Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
1778        )])
1779        .unwrap();
1780
1781        // The mock exec doesn't return immediately (instead it
1782        // requires the input to wait at least once)
1783        let schema = batch1.schema();
1784        let expected_batches = vec![batch1.clone(), batch2.clone()];
1785        let input = MockExec::new(vec![Ok(batch1), Ok(batch2)], schema);
1786        let partitioning = Partitioning::RoundRobinBatch(1);
1787
1788        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
1789
1790        assert_snapshot!(batches_to_sort_string(&expected_batches), @r"
1791        +------------------+
1792        | my_awesome_field |
1793        +------------------+
1794        | bar              |
1795        | baz              |
1796        | foo              |
1797        | frob             |
1798        +------------------+
1799        ");
1800
1801        let output_stream = exec.execute(0, task_ctx).unwrap();
1802        let batches = crate::common::collect(output_stream).await.unwrap();
1803
1804        assert_snapshot!(batches_to_sort_string(&batches), @r"
1805        +------------------+
1806        | my_awesome_field |
1807        +------------------+
1808        | bar              |
1809        | baz              |
1810        | foo              |
1811        | frob             |
1812        +------------------+
1813        ");
1814    }
1815
1816    #[tokio::test]
1817    async fn robin_repartition_with_dropping_output_stream() {
1818        let task_ctx = Arc::new(TaskContext::default());
1819        let partitioning = Partitioning::RoundRobinBatch(2);
1820        // The barrier exec waits to be pinged
1821        // requires the input to wait at least once)
1822        let input = Arc::new(make_barrier_exec());
1823
1824        // partition into two output streams
1825        let exec = RepartitionExec::try_new(
1826            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
1827            partitioning,
1828        )
1829        .unwrap();
1830
1831        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
1832        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
1833
1834        // now, purposely drop output stream 0
1835        // *before* any outputs are produced
1836        drop(output_stream0);
1837
1838        // Now, start sending input
1839        let mut background_task = JoinSet::new();
1840        background_task.spawn(async move {
1841            input.wait().await;
1842        });
1843
1844        // output stream 1 should *not* error and have one of the input batches
1845        let batches = crate::common::collect(output_stream1).await.unwrap();
1846
1847        assert_snapshot!(batches_to_sort_string(&batches), @r#"
1848            +------------------+
1849            | my_awesome_field |
1850            +------------------+
1851            | baz              |
1852            | frob             |
1853            | gaz              |
1854            | grob             |
1855            +------------------+
1856            "#);
1857    }
1858
1859    #[tokio::test]
1860    // As the hash results might be different on different platforms or
1861    // with different compilers, we will compare the same execution with
1862    // and without dropping the output stream.
1863    async fn hash_repartition_with_dropping_output_stream() {
1864        let task_ctx = Arc::new(TaskContext::default());
1865        let partitioning = Partitioning::Hash(
1866            vec![Arc::new(crate::expressions::Column::new(
1867                "my_awesome_field",
1868                0,
1869            ))],
1870            2,
1871        );
1872
1873        // We first collect the results without dropping the output stream.
1874        let input = Arc::new(make_barrier_exec());
1875        let exec = RepartitionExec::try_new(
1876            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
1877            partitioning.clone(),
1878        )
1879        .unwrap();
1880        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
1881        let mut background_task = JoinSet::new();
1882        background_task.spawn(async move {
1883            input.wait().await;
1884        });
1885        let batches_without_drop = crate::common::collect(output_stream1).await.unwrap();
1886
1887        // run some checks on the result
1888        let items_vec = str_batches_to_vec(&batches_without_drop);
1889        let items_set: HashSet<&str> = items_vec.iter().copied().collect();
1890        assert_eq!(items_vec.len(), items_set.len());
1891        let source_str_set: HashSet<&str> =
1892            ["foo", "bar", "frob", "baz", "goo", "gar", "grob", "gaz"]
1893                .iter()
1894                .copied()
1895                .collect();
1896        assert_eq!(items_set.difference(&source_str_set).count(), 0);
1897
1898        // Now do the same but dropping the stream before waiting for the barrier
1899        let input = Arc::new(make_barrier_exec());
1900        let exec = RepartitionExec::try_new(
1901            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
1902            partitioning,
1903        )
1904        .unwrap();
1905        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
1906        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
1907        // now, purposely drop output stream 0
1908        // *before* any outputs are produced
1909        drop(output_stream0);
1910        let mut background_task = JoinSet::new();
1911        background_task.spawn(async move {
1912            input.wait().await;
1913        });
1914        let batches_with_drop = crate::common::collect(output_stream1).await.unwrap();
1915
1916        fn sort(batch: Vec<RecordBatch>) -> Vec<RecordBatch> {
1917            batch
1918                .into_iter()
1919                .sorted_by_key(|b| format!("{b:?}"))
1920                .collect()
1921        }
1922
1923        assert_eq!(sort(batches_without_drop), sort(batches_with_drop));
1924    }
1925
1926    fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> {
1927        batches
1928            .iter()
1929            .flat_map(|batch| {
1930                assert_eq!(batch.columns().len(), 1);
1931                let string_array = as_string_array(batch.column(0))
1932                    .expect("Unexpected type for repartitioned batch");
1933
1934                string_array
1935                    .iter()
1936                    .map(|v| v.expect("Unexpected null"))
1937                    .collect::<Vec<_>>()
1938            })
1939            .collect::<Vec<_>>()
1940    }
1941
1942    /// Create a BarrierExec that returns two partitions of two batches each
1943    fn make_barrier_exec() -> BarrierExec {
1944        let batch1 = RecordBatch::try_from_iter(vec![(
1945            "my_awesome_field",
1946            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
1947        )])
1948        .unwrap();
1949
1950        let batch2 = RecordBatch::try_from_iter(vec![(
1951            "my_awesome_field",
1952            Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
1953        )])
1954        .unwrap();
1955
1956        let batch3 = RecordBatch::try_from_iter(vec![(
1957            "my_awesome_field",
1958            Arc::new(StringArray::from(vec!["goo", "gar"])) as ArrayRef,
1959        )])
1960        .unwrap();
1961
1962        let batch4 = RecordBatch::try_from_iter(vec![(
1963            "my_awesome_field",
1964            Arc::new(StringArray::from(vec!["grob", "gaz"])) as ArrayRef,
1965        )])
1966        .unwrap();
1967
1968        // The barrier exec waits to be pinged
1969        // requires the input to wait at least once)
1970        let schema = batch1.schema();
1971        BarrierExec::new(vec![vec![batch1, batch2], vec![batch3, batch4]], schema)
1972    }
1973
1974    #[tokio::test]
1975    async fn test_drop_cancel() -> Result<()> {
1976        let task_ctx = Arc::new(TaskContext::default());
1977        let schema =
1978            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));
1979
1980        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2));
1981        let refs = blocking_exec.refs();
1982        let repartition_exec = Arc::new(RepartitionExec::try_new(
1983            blocking_exec,
1984            Partitioning::UnknownPartitioning(1),
1985        )?);
1986
1987        let fut = collect(repartition_exec, task_ctx);
1988        let mut fut = fut.boxed();
1989
1990        assert_is_pending(&mut fut);
1991        drop(fut);
1992        assert_strong_count_converges_to_zero(refs).await;
1993
1994        Ok(())
1995    }
1996
1997    #[tokio::test]
1998    async fn hash_repartition_avoid_empty_batch() -> Result<()> {
1999        let task_ctx = Arc::new(TaskContext::default());
2000        let batch = RecordBatch::try_from_iter(vec![(
2001            "a",
2002            Arc::new(StringArray::from(vec!["foo"])) as ArrayRef,
2003        )])
2004        .unwrap();
2005        let partitioning = Partitioning::Hash(
2006            vec![Arc::new(crate::expressions::Column::new("a", 0))],
2007            2,
2008        );
2009        let schema = batch.schema();
2010        let input = MockExec::new(vec![Ok(batch)], schema);
2011        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
2012        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
2013        let batch0 = crate::common::collect(output_stream0).await.unwrap();
2014        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
2015        let batch1 = crate::common::collect(output_stream1).await.unwrap();
2016        assert!(batch0.is_empty() || batch1.is_empty());
2017        Ok(())
2018    }
2019
2020    #[tokio::test]
2021    async fn repartition_with_spilling() -> Result<()> {
2022        // Test that repartition successfully spills to disk when memory is constrained
2023        let schema = test_schema();
2024        let partition = create_vec_batches(50);
2025        let input_partitions = vec![partition];
2026        let partitioning = Partitioning::RoundRobinBatch(4);
2027
2028        // Set up context with very tight memory limit to force spilling
2029        let runtime = RuntimeEnvBuilder::default()
2030            .with_memory_limit(1, 1.0)
2031            .build_arc()?;
2032
2033        let task_ctx = TaskContext::default().with_runtime(runtime);
2034        let task_ctx = Arc::new(task_ctx);
2035
2036        // create physical plan
2037        let exec =
2038            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
2039        let exec = RepartitionExec::try_new(exec, partitioning)?;
2040
2041        // Collect all partitions - should succeed by spilling to disk
2042        let mut total_rows = 0;
2043        for i in 0..exec.partitioning().partition_count() {
2044            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
2045            while let Some(result) = stream.next().await {
2046                let batch = result?;
2047                total_rows += batch.num_rows();
2048            }
2049        }
2050
2051        // Verify we got all the data (50 batches * 8 rows each)
2052        assert_eq!(total_rows, 50 * 8);
2053
2054        // Verify spilling metrics to confirm spilling actually happened
2055        let metrics = exec.metrics().unwrap();
2056        assert!(
2057            metrics.spill_count().unwrap() > 0,
2058            "Expected spill_count > 0, but got {:?}",
2059            metrics.spill_count()
2060        );
2061        println!("Spilled {} times", metrics.spill_count().unwrap());
2062        assert!(
2063            metrics.spilled_bytes().unwrap() > 0,
2064            "Expected spilled_bytes > 0, but got {:?}",
2065            metrics.spilled_bytes()
2066        );
2067        println!(
2068            "Spilled {} bytes in {} spills",
2069            metrics.spilled_bytes().unwrap(),
2070            metrics.spill_count().unwrap()
2071        );
2072        assert!(
2073            metrics.spilled_rows().unwrap() > 0,
2074            "Expected spilled_rows > 0, but got {:?}",
2075            metrics.spilled_rows()
2076        );
2077        println!("Spilled {} rows", metrics.spilled_rows().unwrap());
2078
2079        Ok(())
2080    }
2081
2082    #[tokio::test]
2083    async fn repartition_with_partial_spilling() -> Result<()> {
2084        // Test that repartition can handle partial spilling (some batches in memory, some spilled)
2085        let schema = test_schema();
2086        let partition = create_vec_batches(50);
2087        let input_partitions = vec![partition];
2088        let partitioning = Partitioning::RoundRobinBatch(4);
2089
2090        // Set up context with moderate memory limit to force partial spilling
2091        // 2KB should allow some batches in memory but force others to spill
2092        let runtime = RuntimeEnvBuilder::default()
2093            .with_memory_limit(2 * 1024, 1.0)
2094            .build_arc()?;
2095
2096        let task_ctx = TaskContext::default().with_runtime(runtime);
2097        let task_ctx = Arc::new(task_ctx);
2098
2099        // create physical plan
2100        let exec =
2101            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
2102        let exec = RepartitionExec::try_new(exec, partitioning)?;
2103
2104        // Collect all partitions - should succeed with partial spilling
2105        let mut total_rows = 0;
2106        for i in 0..exec.partitioning().partition_count() {
2107            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
2108            while let Some(result) = stream.next().await {
2109                let batch = result?;
2110                total_rows += batch.num_rows();
2111            }
2112        }
2113
2114        // Verify we got all the data (50 batches * 8 rows each)
2115        assert_eq!(total_rows, 50 * 8);
2116
2117        // Verify partial spilling metrics
2118        let metrics = exec.metrics().unwrap();
2119        let spill_count = metrics.spill_count().unwrap();
2120        let spilled_rows = metrics.spilled_rows().unwrap();
2121        let spilled_bytes = metrics.spilled_bytes().unwrap();
2122
2123        assert!(
2124            spill_count > 0,
2125            "Expected some spilling to occur, but got spill_count={spill_count}"
2126        );
2127        assert!(
2128            spilled_rows > 0 && spilled_rows < total_rows,
2129            "Expected partial spilling (0 < spilled_rows < {total_rows}), but got spilled_rows={spilled_rows}"
2130        );
2131        assert!(
2132            spilled_bytes > 0,
2133            "Expected some bytes to be spilled, but got spilled_bytes={spilled_bytes}"
2134        );
2135
2136        println!(
2137            "Partial spilling: spilled {} out of {} rows ({:.1}%) in {} spills, {} bytes",
2138            spilled_rows,
2139            total_rows,
2140            (spilled_rows as f64 / total_rows as f64) * 100.0,
2141            spill_count,
2142            spilled_bytes
2143        );
2144
2145        Ok(())
2146    }
2147
2148    #[tokio::test]
2149    async fn repartition_without_spilling() -> Result<()> {
2150        // Test that repartition does not spill when there's ample memory
2151        let schema = test_schema();
2152        let partition = create_vec_batches(50);
2153        let input_partitions = vec![partition];
2154        let partitioning = Partitioning::RoundRobinBatch(4);
2155
2156        // Set up context with generous memory limit - no spilling should occur
2157        let runtime = RuntimeEnvBuilder::default()
2158            .with_memory_limit(10 * 1024 * 1024, 1.0) // 10MB
2159            .build_arc()?;
2160
2161        let task_ctx = TaskContext::default().with_runtime(runtime);
2162        let task_ctx = Arc::new(task_ctx);
2163
2164        // create physical plan
2165        let exec =
2166            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
2167        let exec = RepartitionExec::try_new(exec, partitioning)?;
2168
2169        // Collect all partitions - should succeed without spilling
2170        let mut total_rows = 0;
2171        for i in 0..exec.partitioning().partition_count() {
2172            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
2173            while let Some(result) = stream.next().await {
2174                let batch = result?;
2175                total_rows += batch.num_rows();
2176            }
2177        }
2178
2179        // Verify we got all the data (50 batches * 8 rows each)
2180        assert_eq!(total_rows, 50 * 8);
2181
2182        // Verify no spilling occurred
2183        let metrics = exec.metrics().unwrap();
2184        assert_eq!(
2185            metrics.spill_count(),
2186            Some(0),
2187            "Expected no spilling, but got spill_count={:?}",
2188            metrics.spill_count()
2189        );
2190        assert_eq!(
2191            metrics.spilled_bytes(),
2192            Some(0),
2193            "Expected no bytes spilled, but got spilled_bytes={:?}",
2194            metrics.spilled_bytes()
2195        );
2196        assert_eq!(
2197            metrics.spilled_rows(),
2198            Some(0),
2199            "Expected no rows spilled, but got spilled_rows={:?}",
2200            metrics.spilled_rows()
2201        );
2202
2203        println!("No spilling occurred - all data processed in memory");
2204
2205        Ok(())
2206    }
2207
2208    #[tokio::test]
2209    async fn oom() -> Result<()> {
2210        use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode};
2211
2212        // Test that repartition fails with OOM when disk manager is disabled
2213        let schema = test_schema();
2214        let partition = create_vec_batches(50);
2215        let input_partitions = vec![partition];
2216        let partitioning = Partitioning::RoundRobinBatch(4);
2217
2218        // Setup context with memory limit but NO disk manager (explicitly disabled)
2219        let runtime = RuntimeEnvBuilder::default()
2220            .with_memory_limit(1, 1.0)
2221            .with_disk_manager_builder(
2222                DiskManagerBuilder::default().with_mode(DiskManagerMode::Disabled),
2223            )
2224            .build_arc()?;
2225
2226        let task_ctx = TaskContext::default().with_runtime(runtime);
2227        let task_ctx = Arc::new(task_ctx);
2228
2229        // create physical plan
2230        let exec =
2231            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
2232        let exec = RepartitionExec::try_new(exec, partitioning)?;
2233
2234        // Attempt to execute - should fail with ResourcesExhausted error
2235        for i in 0..exec.partitioning().partition_count() {
2236            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
2237            let err = stream.next().await.unwrap().unwrap_err();
2238            let err = err.find_root();
2239            assert!(
2240                matches!(err, DataFusionError::ResourcesExhausted(_)),
2241                "Wrong error type: {err}",
2242            );
2243        }
2244
2245        Ok(())
2246    }
2247
2248    /// Create vector batches
2249    fn create_vec_batches(n: usize) -> Vec<RecordBatch> {
2250        let batch = create_batch();
2251        (0..n).map(|_| batch.clone()).collect()
2252    }
2253
2254    /// Create batch
2255    fn create_batch() -> RecordBatch {
2256        let schema = test_schema();
2257        RecordBatch::try_new(
2258            schema,
2259            vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]))],
2260        )
2261        .unwrap()
2262    }
2263
2264    /// Create batches with sequential values for ordering tests
2265    fn create_ordered_batches(num_batches: usize) -> Vec<RecordBatch> {
2266        let schema = test_schema();
2267        (0..num_batches)
2268            .map(|i| {
2269                let start = (i * 8) as u32;
2270                RecordBatch::try_new(
2271                    Arc::clone(&schema),
2272                    vec![Arc::new(UInt32Array::from(
2273                        (start..start + 8).collect::<Vec<_>>(),
2274                    ))],
2275                )
2276                .unwrap()
2277            })
2278            .collect()
2279    }
2280
2281    #[tokio::test]
2282    async fn test_repartition_ordering_with_spilling() -> Result<()> {
2283        // Test that repartition preserves ordering when spilling occurs
2284        // This tests the state machine fix where we must block on spill_stream
2285        // when a Spilled marker is received, rather than continuing to poll the channel
2286
2287        let schema = test_schema();
2288        // Create batches with sequential values: batch 0 has [0,1,2,3,4,5,6,7],
2289        // batch 1 has [8,9,10,11,12,13,14,15], etc.
2290        let partition = create_ordered_batches(20);
2291        let input_partitions = vec![partition];
2292
2293        // Use RoundRobinBatch to ensure predictable ordering
2294        let partitioning = Partitioning::RoundRobinBatch(2);
2295
2296        // Set up context with very tight memory limit to force spilling
2297        let runtime = RuntimeEnvBuilder::default()
2298            .with_memory_limit(1, 1.0)
2299            .build_arc()?;
2300
2301        let task_ctx = TaskContext::default().with_runtime(runtime);
2302        let task_ctx = Arc::new(task_ctx);
2303
2304        // create physical plan
2305        let exec =
2306            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
2307        let exec = RepartitionExec::try_new(exec, partitioning)?;
2308
2309        // Collect all output partitions
2310        let mut all_batches = Vec::new();
2311        for i in 0..exec.partitioning().partition_count() {
2312            let mut partition_batches = Vec::new();
2313            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
2314            while let Some(result) = stream.next().await {
2315                let batch = result?;
2316                partition_batches.push(batch);
2317            }
2318            all_batches.push(partition_batches);
2319        }
2320
2321        // Verify spilling occurred
2322        let metrics = exec.metrics().unwrap();
2323        assert!(
2324            metrics.spill_count().unwrap() > 0,
2325            "Expected spilling to occur, but spill_count = 0"
2326        );
2327
2328        // Verify ordering is preserved within each partition
2329        // With RoundRobinBatch, even batches go to partition 0, odd batches to partition 1
2330        for (partition_idx, batches) in all_batches.iter().enumerate() {
2331            let mut last_value = None;
2332            for batch in batches {
2333                let array = batch
2334                    .column(0)
2335                    .as_any()
2336                    .downcast_ref::<UInt32Array>()
2337                    .unwrap();
2338
2339                for i in 0..array.len() {
2340                    let value = array.value(i);
2341                    if let Some(last) = last_value {
2342                        assert!(
2343                            value > last,
2344                            "Ordering violated in partition {partition_idx}: {value} is not greater than {last}"
2345                        );
2346                    }
2347                    last_value = Some(value);
2348                }
2349            }
2350        }
2351
2352        Ok(())
2353    }
2354}
2355
#[cfg(test)]
mod test {
    use arrow::array::record_batch;
    use arrow::compute::SortOptions;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::assert_batches_eq;

    use super::*;
    use crate::test::TestMemoryExec;
    use crate::union::UnionExec;

    use datafusion_physical_expr::expressions::col;
    use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};

    /// Asserts that the indented display of a plan matches an inline snapshot.
    ///
    /// `$PLAN`: a reference to the plan to render
    /// `$EXPECTED`: the expected indented rendering, written as an `insta`
    /// inline snapshot after the `@` marker
    macro_rules! assert_plan {
        ($PLAN: expr, @ $EXPECTED: expr) => {
            let formatted = crate::displayable($PLAN).indent(true).to_string();

            insta::assert_snapshot!(
                formatted,
                @$EXPECTED
            );
        };
    }

    #[tokio::test]
    async fn test_preserve_order() -> Result<()> {
        let schema = test_schema();
        let sort_exprs = sort_exprs(&schema);
        let source1 = sorted_memory_exec(&schema, sort_exprs.clone());
        let source2 = sorted_memory_exec(&schema, sort_exprs);
        // output has multiple partitions, and is sorted
        let union = UnionExec::try_new(vec![source1, source2])?;
        let exec = RepartitionExec::try_new(union, Partitioning::RoundRobinBatch(10))?
            .with_preserve_order();

        // Repartition should preserve order
        assert_plan!(&exec, @r"
        RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true, sort_exprs=c0@0 ASC
          UnionExec
            DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
            DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
        ");
        Ok(())
    }

    #[tokio::test]
    async fn test_preserve_order_one_partition() -> Result<()> {
        let schema = test_schema();
        let sort_exprs = sort_exprs(&schema);
        let source = sorted_memory_exec(&schema, sort_exprs);
        // output is sorted, but has only a single partition, so no need to sort
        let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(10))?
            .with_preserve_order();

        // Repartition should not preserve order
        assert_plan!(&exec, @r"
        RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1
          DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
        ");

        Ok(())
    }

    #[tokio::test]
    async fn test_preserve_order_input_not_sorted() -> Result<()> {
        let schema = test_schema();
        let source1 = memory_exec(&schema);
        let source2 = memory_exec(&schema);
        // output has multiple partitions, but is not sorted
        let union = UnionExec::try_new(vec![source1, source2])?;
        let exec = RepartitionExec::try_new(union, Partitioning::RoundRobinBatch(10))?
            .with_preserve_order();

        // Repartition should not preserve order, as there is no order to preserve
        assert_plan!(&exec, @r"
        RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2
          UnionExec
            DataSourceExec: partitions=1, partition_sizes=[0]
            DataSourceExec: partitions=1, partition_sizes=[0]
        ");
        Ok(())
    }

    #[tokio::test]
    async fn test_preserve_order_with_spilling() -> Result<()> {
        use datafusion_execution::runtime_env::RuntimeEnvBuilder;
        use datafusion_execution::TaskContext;

        // Create sorted input data across multiple partitions
        // Partition1: [1,3], [5,7], [9,11]
        // Partition2: [2,4], [6,8], [10,12]
        let batch1 = record_batch!(("c0", UInt32, [1, 3])).unwrap();
        let batch2 = record_batch!(("c0", UInt32, [2, 4])).unwrap();
        let batch3 = record_batch!(("c0", UInt32, [5, 7])).unwrap();
        let batch4 = record_batch!(("c0", UInt32, [6, 8])).unwrap();
        let batch5 = record_batch!(("c0", UInt32, [9, 11])).unwrap();
        let batch6 = record_batch!(("c0", UInt32, [10, 12])).unwrap();
        let schema = batch1.schema();
        let sort_exprs = LexOrdering::new([PhysicalSortExpr {
            expr: col("c0", &schema).unwrap(),
            options: SortOptions::default().asc(),
        }])
        .unwrap();
        let partition1 = vec![batch1.clone(), batch3.clone(), batch5.clone()];
        let partition2 = vec![batch2.clone(), batch4.clone(), batch6.clone()];
        let input_partitions = vec![partition1, partition2];

        // Set up context with tight memory limit to force spilling.
        // Sorting needs some non-spillable memory, so 64 bytes should force spilling
        // while still allowing the query to complete.
        let runtime = RuntimeEnvBuilder::default()
            .with_memory_limit(64, 1.0)
            .build_arc()?;

        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        // Create physical plan with order preservation
        let exec = TestMemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?
            .try_with_sort_information(vec![sort_exprs.clone(), sort_exprs])?;
        let exec = Arc::new(TestMemoryExec::update_cache(Arc::new(exec)));
        // Repartition into 3 partitions with order preservation
        // We expect 1 batch per output partition after repartitioning
        let exec = RepartitionExec::try_new(exec, Partitioning::RoundRobinBatch(3))?
            .with_preserve_order();

        let mut batches = vec![];

        // Collect all partitions - should succeed by spilling to disk
        for i in 0..exec.partitioning().partition_count() {
            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
            while let Some(result) = stream.next().await {
                let batch = result?;
                batches.push(batch);
            }
        }

        #[rustfmt::skip]
        let expected = [
            [
                "+----+",
                "| c0 |",
                "+----+",
                "| 1  |",
                "| 2  |",
                "| 3  |",
                "| 4  |",
                "+----+",
            ],
            [
                "+----+",
                "| c0 |",
                "+----+",
                "| 5  |",
                "| 6  |",
                "| 7  |",
                "| 8  |",
                "+----+",
            ],
            [
                "+----+",
                "| c0 |",
                "+----+",
                "| 9  |",
                "| 10 |",
                "| 11 |",
                "| 12 |",
                "+----+",
            ],
        ];

        // Check the batch count first: `zip` stops at the shorter side, so
        // missing output batches would otherwise pass unnoticed.
        assert_eq!(
            batches.len(),
            expected.len(),
            "Expected one output batch per output partition"
        );
        for (batch, expected_lines) in batches.iter().zip(expected.iter()) {
            assert_batches_eq!(expected_lines, std::slice::from_ref(batch));
        }

        // We should have spilled ~ all of the data.
        // - We spill data during the repartitioning phase
        // - We may also spill during the final merge sort
        let all_batches = [batch1, batch2, batch3, batch4, batch5, batch6];
        let input_bytes: usize = all_batches
            .iter()
            .map(|b| b.get_array_memory_size())
            .sum();
        let input_rows: usize = all_batches.iter().map(|b| b.num_rows()).sum();
        let metrics = exec.metrics().unwrap();
        assert!(
            metrics.spill_count().unwrap() > input_partitions.len(),
            "Expected spill_count > {} for order-preserving repartition, but got {:?}",
            input_partitions.len(),
            metrics.spill_count()
        );
        assert!(
            metrics.spilled_bytes().unwrap() > input_bytes,
            "Expected spilled_bytes > {input_bytes} for order-preserving repartition, got {}",
            metrics.spilled_bytes().unwrap()
        );
        assert!(
            metrics.spilled_rows().unwrap() >= input_rows,
            "Expected spilled_rows >= {input_rows} for order-preserving repartition, got {}",
            metrics.spilled_rows().unwrap()
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_hash_partitioning_with_spilling() -> Result<()> {
        use datafusion_execution::runtime_env::RuntimeEnvBuilder;
        use datafusion_execution::TaskContext;

        // Create input data similar to the round-robin test
        let batch1 = record_batch!(("c0", UInt32, [1, 3])).unwrap();
        let batch2 = record_batch!(("c0", UInt32, [2, 4])).unwrap();
        let batch3 = record_batch!(("c0", UInt32, [5, 7])).unwrap();
        let batch4 = record_batch!(("c0", UInt32, [6, 8])).unwrap();
        let schema = batch1.schema();

        let partition1 = vec![batch1.clone(), batch3.clone()];
        let partition2 = vec![batch2.clone(), batch4.clone()];
        let input_partitions = vec![partition1, partition2];

        // Set up context with memory limit to test hash partitioning with spilling infrastructure
        let runtime = RuntimeEnvBuilder::default()
            .with_memory_limit(1, 1.0)
            .build_arc()?;

        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        // Create physical plan with hash partitioning
        let exec = TestMemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?;
        let exec = Arc::new(TestMemoryExec::update_cache(Arc::new(exec)));
        // Hash partition into 2 partitions by column c0
        let hash_expr = col("c0", &schema)?;
        let exec =
            RepartitionExec::try_new(exec, Partitioning::Hash(vec![hash_expr], 2))?;

        // Collect all partitions concurrently using JoinSet - this prevents deadlock
        // where the distribution channel gate closes when all output channels are full
        let mut join_set = tokio::task::JoinSet::new();
        for i in 0..exec.partitioning().partition_count() {
            let stream = exec.execute(i, Arc::clone(&task_ctx))?;
            join_set.spawn(async move {
                let mut count = 0;
                futures::pin_mut!(stream);
                while let Some(result) = stream.next().await {
                    let batch = result?;
                    count += batch.num_rows();
                }
                Ok::<usize, DataFusionError>(count)
            });
        }

        // Wait for all partitions and sum the rows
        let mut total_rows = 0;
        while let Some(result) = join_set.join_next().await {
            total_rows += result.expect("partition task panicked or was cancelled")?;
        }

        // Verify we got all rows back
        let all_batches = [batch1, batch2, batch3, batch4];
        let expected_rows: usize = all_batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, expected_rows);

        // Verify spilling happened. The 1-byte memory limit above leaves no room
        // to buffer batches in memory, so each spill metric must be non-zero.
        let metrics = exec.metrics().unwrap();
        let spill_count = metrics.spill_count().unwrap_or(0);
        assert!(spill_count > 0, "Expected spilling to occur, but spill_count = 0");
        let spilled_bytes = metrics.spilled_bytes().unwrap_or(0);
        assert!(spilled_bytes > 0, "Expected spilled_bytes > 0, got 0");
        let spilled_rows = metrics.spilled_rows().unwrap_or(0);
        assert!(spilled_rows > 0, "Expected spilled_rows > 0, got 0");

        Ok(())
    }

    #[tokio::test]
    async fn test_repartition() -> Result<()> {
        let schema = test_schema();
        let sort_exprs = sort_exprs(&schema);
        let source = sorted_memory_exec(&schema, sort_exprs);
        // output is sorted, but has only a single partition, so no need to sort
        let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(10))?
            .repartitioned(20, &Default::default())?
            .unwrap();

        // Repartition should not preserve order
        assert_plan!(exec.as_ref(), @r"
        RepartitionExec: partitioning=RoundRobinBatch(20), input_partitions=1
          DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
        ");
        Ok(())
    }

    /// Single non-nullable `UInt32` column named `c0`.
    fn test_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
    }

    /// Ordering on `c0` using the default `SortOptions`.
    fn sort_exprs(schema: &Schema) -> LexOrdering {
        [PhysicalSortExpr {
            expr: col("c0", schema).unwrap(),
            options: SortOptions::default(),
        }]
        .into()
    }

    /// Single-partition in-memory source with no batches and no declared ordering.
    fn memory_exec(schema: &SchemaRef) -> Arc<dyn ExecutionPlan> {
        TestMemoryExec::try_new_exec(&[vec![]], Arc::clone(schema), None).unwrap()
    }

    /// Single-partition in-memory source with no batches that declares
    /// `sort_exprs` as its output ordering.
    fn sorted_memory_exec(
        schema: &SchemaRef,
        sort_exprs: LexOrdering,
    ) -> Arc<dyn ExecutionPlan> {
        Arc::new(TestMemoryExec::update_cache(Arc::new(
            TestMemoryExec::try_new(&[vec![]], Arc::clone(schema), None)
                .unwrap()
                .try_with_sort_information(vec![sort_exprs])
                .unwrap(),
        )))
    }
}