datafusion_physical_plan/stream.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Stream wrappers for physical operators

use std::pin::Pin;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;

#[cfg(test)]
use super::metrics::ExecutionPlanMetricsSet;
use super::metrics::{BaselineMetrics, SplitMetrics};
use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};
use crate::displayable;
use crate::spill::get_record_batch_memory_size;

use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
use datafusion_common::{Result, exec_err};
use datafusion_common_runtime::JoinSet;
use datafusion_execution::TaskContext;
use datafusion_execution::memory_pool::MemoryReservation;

use futures::ready;
use futures::stream::BoxStream;
use futures::{Future, Stream, StreamExt};
use log::debug;
use pin_project_lite::pin_project;
use tokio::runtime::Handle;
use tokio::sync::mpsc::{Receiver, Sender};

/// Creates a stream from a collection of producing tasks, routing panics to the stream.
///
/// Note that this is similar to [`ReceiverStream` from tokio-stream], with the differences being:
///
/// 1. Methods to bound and "detach" tasks (`spawn()` and `spawn_blocking()`).
///
/// 2. Propagates panics, whereas the `tokio` version doesn't propagate panics to the receiver.
///
/// 3. Automatically cancels any outstanding tasks when the receiver stream is dropped.
///
/// [`ReceiverStream` from tokio-stream]: https://docs.rs/tokio-stream/latest/tokio_stream/wrappers/struct.ReceiverStream.html
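///
/// # Example
///
/// A minimal sketch of crate-internal usage (marked `ignore` because this
/// type is not part of the public API): spawn a producing task, then drain
/// the built stream.
///
/// ```ignore
/// use futures::StreamExt;
///
/// let mut builder = ReceiverStreamBuilder::<usize>::new(2);
/// let tx = builder.tx();
/// builder.spawn(async move {
///     // An `Err` from `send` means the receiver was dropped and the
///     // task should stop producing
///     tx.send(Ok(1)).await.ok();
///     Ok(())
/// });
///
/// let mut stream = builder.build();
/// while let Some(value) = stream.next().await {
///     // Errors arrive as `Err` items; a panic in a task is resumed
///     // when the stream is polled
///     assert_eq!(value.unwrap(), 1);
/// }
/// ```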
pub(crate) struct ReceiverStreamBuilder<O> {
    tx: Sender<Result<O>>,
    rx: Receiver<Result<O>>,
    join_set: JoinSet<Result<()>>,
}

impl<O: Send + 'static> ReceiverStreamBuilder<O> {
    /// Create new channels with the specified buffer size
    pub fn new(capacity: usize) -> Self {
        let (tx, rx) = tokio::sync::mpsc::channel(capacity);

        Self {
            tx,
            rx,
            join_set: JoinSet::new(),
        }
    }

    /// Get a handle for sending data to the output
    pub fn tx(&self) -> Sender<Result<O>> {
        self.tx.clone()
    }

    /// Spawn a task that will be aborted if this builder (or the stream
    /// built from it) is dropped
    pub fn spawn<F>(&mut self, task: F)
    where
        F: Future<Output = Result<()>>,
        F: Send + 'static,
    {
        self.join_set.spawn(task);
    }

    /// Same as [`Self::spawn`] but it spawns the task on the provided runtime
    pub fn spawn_on<F>(&mut self, task: F, handle: &Handle)
    where
        F: Future<Output = Result<()>>,
        F: Send + 'static,
    {
        self.join_set.spawn_on(task, handle);
    }

    /// Spawn a blocking task that will be aborted if this builder (or the
    /// stream built from it) is dropped.
    ///
    /// This is often used to spawn tasks that write to the sender
    /// retrieved from `Self::tx`.
    pub fn spawn_blocking<F>(&mut self, f: F)
    where
        F: FnOnce() -> Result<()>,
        F: Send + 'static,
    {
        self.join_set.spawn_blocking(f);
    }

    /// Same as [`Self::spawn_blocking`] but it spawns the blocking task on the provided runtime
    pub fn spawn_blocking_on<F>(&mut self, f: F, handle: &Handle)
    where
        F: FnOnce() -> Result<()>,
        F: Send + 'static,
    {
        self.join_set.spawn_blocking_on(f, handle);
    }

    /// Create a stream of all data written to `tx`
    pub fn build(self) -> BoxStream<'static, Result<O>> {
        let Self {
            tx,
            rx,
            mut join_set,
        } = self;

        // Drop the builder's own sender so the channel closes once all
        // spawned tasks have completed
        drop(tx);

        // Future that checks the result of the join set, and propagates panic if seen
        let check = async move {
            while let Some(result) = join_set.join_next().await {
                match result {
                    Ok(task_result) => {
                        match task_result {
                            // Nothing to report
                            Ok(_) => continue,
                            // This means a blocking task error
                            Err(error) => return Some(Err(error)),
                        }
                    }
                    // This means a tokio task error, likely a panic
                    Err(e) => {
                        if e.is_panic() {
                            // resume on the main thread
                            std::panic::resume_unwind(e.into_panic());
                        } else {
                            // This should only occur if the task is
                            // cancelled, which would only occur if
                            // the JoinSet were aborted, which in turn
                            // would imply that the receiver has been
                            // dropped and this code is not running
                            return Some(exec_err!("Non Panic Task error: {e}"));
                        }
                    }
                }
            }
            None
        };

        let check_stream = futures::stream::once(check)
            // unwrap Option / only return the error
            .filter_map(|item| async move { item });

        // Convert the receiver into a stream
        let rx_stream = futures::stream::unfold(rx, |mut rx| async move {
            let next_item = rx.recv().await;
            next_item.map(|next_item| (next_item, rx))
        });

        // Merge the streams together so whichever is ready first
        // produces the batch
        futures::stream::select(rx_stream, check_stream).boxed()
    }
}

/// Builder for `RecordBatchReceiverStream` that propagates errors
/// and panics correctly.
///
/// [`RecordBatchReceiverStreamBuilder`] is used to spawn one or more tasks
/// that produce [`RecordBatch`]es and send them to a single
/// `Receiver`, which can improve parallelism.
///
/// This also handles propagating panics and canceling the tasks.
///
/// # Example
///
/// The following example spawns 2 tasks that will write [`RecordBatch`]es to
/// the `tx` end of the builder. After building the stream, we can receive
/// those batches by calling `.next()`
///
/// ```
/// # use std::sync::Arc;
/// # use datafusion_common::arrow::datatypes::{Schema, Field, DataType};
/// # use datafusion_common::arrow::array::RecordBatch;
/// # use datafusion_physical_plan::stream::RecordBatchReceiverStreamBuilder;
/// # use futures::stream::StreamExt;
/// # use tokio::runtime::Builder;
/// # let rt = Builder::new_current_thread().build().unwrap();
/// #
/// # rt.block_on(async {
/// let schema = Arc::new(Schema::new(vec![Field::new("foo", DataType::Int8, false)]));
/// let mut builder = RecordBatchReceiverStreamBuilder::new(Arc::clone(&schema), 10);
///
/// // task 1
/// let tx_1 = builder.tx();
/// let schema_1 = Arc::clone(&schema);
/// builder.spawn(async move {
///     // Your task needs to send batches to the tx
///     tx_1.send(Ok(RecordBatch::new_empty(schema_1)))
///         .await
///         .unwrap();
///
///     Ok(())
/// });
///
/// // task 2
/// let tx_2 = builder.tx();
/// let schema_2 = Arc::clone(&schema);
/// builder.spawn(async move {
///     // Your task needs to send batches to the tx
///     tx_2.send(Ok(RecordBatch::new_empty(schema_2)))
///         .await
///         .unwrap();
///
///     Ok(())
/// });
///
/// let mut stream = builder.build();
/// while let Some(res_batch) = stream.next().await {
///     // `res_batch` can come from either task 1 or task 2
///
///     // do something with `res_batch`
/// }
/// # });
/// ```
pub struct RecordBatchReceiverStreamBuilder {
    schema: SchemaRef,
    inner: ReceiverStreamBuilder<RecordBatch>,
}

impl RecordBatchReceiverStreamBuilder {
    /// Create new channels with the specified buffer size
    pub fn new(schema: SchemaRef, capacity: usize) -> Self {
        Self {
            schema,
            inner: ReceiverStreamBuilder::new(capacity),
        }
    }

    /// Get a handle for sending [`RecordBatch`] to the output
    ///
    /// If the stream is dropped / canceled, the sender will be closed and
    /// calling `tx().send()` will return an error. Producers should stop
    /// producing in this case and return control.
    pub fn tx(&self) -> Sender<Result<RecordBatch>> {
        self.inner.tx()
    }

    /// Spawn a task that will be aborted if this builder (or the stream
    /// built from it) is dropped
    ///
    /// This is often used to spawn tasks that write to the sender
    /// retrieved from [`Self::tx`]; for an example, see the documentation
    /// of this type.
    pub fn spawn<F>(&mut self, task: F)
    where
        F: Future<Output = Result<()>>,
        F: Send + 'static,
    {
        self.inner.spawn(task)
    }

    /// Same as [`Self::spawn`] but it spawns the task on the provided runtime.
    pub fn spawn_on<F>(&mut self, task: F, handle: &Handle)
    where
        F: Future<Output = Result<()>>,
        F: Send + 'static,
    {
        self.inner.spawn_on(task, handle)
    }

    /// Spawn a blocking task tied to the builder and stream.
    ///
    /// # Drop / Cancel Behavior
    ///
    /// If this builder (or the stream built from it) is dropped **before** the
    /// task starts, the task is also dropped and will never start executing.
    ///
    /// **Note:** Once the blocking task has started, it **will not** be
    /// forcibly stopped on drop, as Rust does not allow forcing a running thread
    /// to terminate. The task will continue running until it completes or
    /// encounters an error.
    ///
    /// Users should ensure that their blocking function periodically checks for
    /// errors returned by `tx.blocking_send`. An error signals that the stream has
    /// been dropped / cancelled and the blocking task should exit.
    ///
    /// This is often used to spawn tasks that write to the sender
    /// retrieved from [`Self::tx`]; for an example, see the documentation
    /// of this type.
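    ///
    /// # Example
    ///
    /// The following sketch spawns a blocking producer that stops as soon as
    /// `blocking_send` reports that the stream has been dropped:
    ///
    /// ```
    /// # use std::sync::Arc;
    /// # use datafusion_common::arrow::datatypes::{Schema, Field, DataType};
    /// # use datafusion_common::arrow::array::RecordBatch;
    /// # use datafusion_physical_plan::stream::RecordBatchReceiverStreamBuilder;
    /// # use futures::stream::StreamExt;
    /// # let rt = tokio::runtime::Builder::new_current_thread().build().unwrap();
    /// # rt.block_on(async {
    /// let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int8, false)]));
    /// let mut builder = RecordBatchReceiverStreamBuilder::new(Arc::clone(&schema), 2);
    ///
    /// let tx = builder.tx();
    /// let task_schema = Arc::clone(&schema);
    /// builder.spawn_blocking(move || {
    ///     // Keep producing until `blocking_send` fails, which signals
    ///     // that the stream (the receiver) has been dropped
    ///     while tx
    ///         .blocking_send(Ok(RecordBatch::new_empty(Arc::clone(&task_schema))))
    ///         .is_ok()
    ///     {}
    ///     Ok(())
    /// });
    ///
    /// let mut stream = builder.build();
    /// // Receive one batch, then drop the stream; the blocking task
    /// // observes the closed channel and exits
    /// let batch = stream.next().await.unwrap().unwrap();
    /// assert_eq!(batch.num_rows(), 0);
    /// drop(stream);
    /// # });
    /// ```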
    pub fn spawn_blocking<F>(&mut self, f: F)
    where
        F: FnOnce() -> Result<()>,
        F: Send + 'static,
    {
        self.inner.spawn_blocking(f)
    }

    /// Same as [`Self::spawn_blocking`] but it spawns the blocking task on the provided runtime.
    pub fn spawn_blocking_on<F>(&mut self, f: F, handle: &Handle)
    where
        F: FnOnce() -> Result<()>,
        F: Send + 'static,
    {
        self.inner.spawn_blocking_on(f, handle)
    }

    /// Runs the `partition` of the `input` ExecutionPlan on the
    /// tokio thread pool and writes its outputs to this stream
    ///
    /// If the input partition produces an error, the error will be
    /// sent to the output stream and no further results are sent.
    pub(crate) fn run_input(
        &mut self,
        input: Arc<dyn ExecutionPlan>,
        partition: usize,
        context: Arc<TaskContext>,
    ) {
        let output = self.tx();

        self.inner.spawn(async move {
            let mut stream = match input.execute(partition, context) {
                Err(e) => {
                    // If send fails, the plan is being torn down; there
                    // is no place to send the error and no reason to continue.
                    output.send(Err(e)).await.ok();
                    debug!(
                        "Stopping execution: error executing input: {}",
                        displayable(input.as_ref()).one_line()
                    );
                    return Ok(());
                }
                Ok(stream) => stream,
            };

            // Transfer batches from inner stream to the output tx
            // immediately.
            while let Some(item) = stream.next().await {
                let is_err = item.is_err();

                // If send fails, the plan is being torn down; there is no
                // place to send the error and no reason to continue.
                if output.send(item).await.is_err() {
                    debug!(
                        "Stopping execution: output is gone, plan cancelling: {}",
                        displayable(input.as_ref()).one_line()
                    );
                    return Ok(());
                }

                // Stop after the first error is encountered (don't
                // drive all streams to completion)
                if is_err {
                    debug!(
                        "Stopping execution: plan returned error: {}",
                        displayable(input.as_ref()).one_line()
                    );
                    return Ok(());
                }
            }

            Ok(())
        });
    }

    /// Create a stream of all [`RecordBatch`] written to `tx`
    pub fn build(self) -> SendableRecordBatchStream {
        Box::pin(RecordBatchStreamAdapter::new(
            self.schema,
            self.inner.build(),
        ))
    }
}

#[doc(hidden)]
pub struct RecordBatchReceiverStream {}

impl RecordBatchReceiverStream {
    /// Create a builder with an internal buffer of `capacity` batches.
    pub fn builder(
        schema: SchemaRef,
        capacity: usize,
    ) -> RecordBatchReceiverStreamBuilder {
        RecordBatchReceiverStreamBuilder::new(schema, capacity)
    }
}

pin_project! {
    /// Combines a [`Stream`] with a [`SchemaRef`] implementing
    /// [`SendableRecordBatchStream`] for the combination
    ///
    /// See [`Self::new`] for an example
    pub struct RecordBatchStreamAdapter<S> {
        schema: SchemaRef,

        #[pin]
        stream: S,
    }
}

impl<S> RecordBatchStreamAdapter<S> {
    /// Creates a new [`RecordBatchStreamAdapter`] from the provided schema and stream.
    ///
    /// Note that to create a [`SendableRecordBatchStream`] you must pin the result
    ///
    /// # Example
    /// ```
    /// # use arrow::array::record_batch;
    /// # use datafusion_execution::SendableRecordBatchStream;
    /// # use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
    /// // Create stream of Result<RecordBatch>
    /// let batch = record_batch!(
    ///     ("a", Int32, [1, 2, 3]),
    ///     ("b", Float64, [Some(4.0), None, Some(5.0)])
    /// )
    /// .expect("created batch");
    /// let schema = batch.schema();
    /// let stream = futures::stream::iter(vec![Ok(batch)]);
    /// // Convert the stream to a SendableRecordBatchStream
    /// let adapter = RecordBatchStreamAdapter::new(schema, stream);
    /// // Now you can use the adapter as a SendableRecordBatchStream
    /// let batch_stream: SendableRecordBatchStream = Box::pin(adapter);
    /// // ...
    /// ```
    pub fn new(schema: SchemaRef, stream: S) -> Self {
        Self { schema, stream }
    }
}

impl<S> std::fmt::Debug for RecordBatchStreamAdapter<S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RecordBatchStreamAdapter")
            .field("schema", &self.schema)
            .finish()
    }
}

impl<S> Stream for RecordBatchStreamAdapter<S>
where
    S: Stream<Item = Result<RecordBatch>>,
{
    type Item = Result<RecordBatch>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.project().stream.poll_next(cx)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.stream.size_hint()
    }
}

impl<S> RecordBatchStream for RecordBatchStreamAdapter<S>
where
    S: Stream<Item = Result<RecordBatch>>,
{
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

/// `EmptyRecordBatchStream` can be used to create a [`RecordBatchStream`]
/// that will produce no results
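///
/// # Example
///
/// A minimal sketch showing that the stream completes immediately:
///
/// ```
/// # use std::sync::Arc;
/// # use datafusion_common::arrow::datatypes::{Schema, Field, DataType};
/// # use datafusion_physical_plan::stream::EmptyRecordBatchStream;
/// # use futures::StreamExt;
/// # let rt = tokio::runtime::Builder::new_current_thread().build().unwrap();
/// # rt.block_on(async {
/// let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
/// let mut stream = EmptyRecordBatchStream::new(schema);
/// // The stream is immediately exhausted
/// assert!(stream.next().await.is_none());
/// # });
/// ```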
pub struct EmptyRecordBatchStream {
    /// Schema wrapped by Arc
    schema: SchemaRef,
}

impl EmptyRecordBatchStream {
    /// Create an empty RecordBatchStream
    pub fn new(schema: SchemaRef) -> Self {
        Self { schema }
    }
}

impl RecordBatchStream for EmptyRecordBatchStream {
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

impl Stream for EmptyRecordBatchStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        Poll::Ready(None)
    }
}

/// Stream wrapper that records `BaselineMetrics` for a particular
/// [`SendableRecordBatchStream`] (likely a partition)
pub(crate) struct ObservedStream {
    inner: SendableRecordBatchStream,
    baseline_metrics: BaselineMetrics,
    fetch: Option<usize>,
    produced: usize,
}

impl ObservedStream {
    pub fn new(
        inner: SendableRecordBatchStream,
        baseline_metrics: BaselineMetrics,
        fetch: Option<usize>,
    ) -> Self {
        Self {
            inner,
            baseline_metrics,
            fetch,
            produced: 0,
        }
    }

    fn limit_reached(
        &mut self,
        poll: Poll<Option<Result<RecordBatch>>>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        let Some(fetch) = self.fetch else { return poll };

        if self.produced >= fetch {
            return Poll::Ready(None);
        }

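        // If this batch would cross the limit, truncate it so that exactly
        // `fetch` rows are produced in total. For example, with fetch = 10
        // and produced = 8, only the first 2 rows of the incoming batch
        // are emitted.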
        if let Poll::Ready(Some(Ok(batch))) = &poll {
            if self.produced + batch.num_rows() > fetch {
                let batch = batch.slice(0, fetch.saturating_sub(self.produced));
                self.produced += batch.num_rows();
                return Poll::Ready(Some(Ok(batch)));
            };
            self.produced += batch.num_rows()
        }
        poll
    }
}

impl RecordBatchStream for ObservedStream {
    fn schema(&self) -> SchemaRef {
        self.inner.schema()
    }
}

impl Stream for ObservedStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let mut poll = self.inner.poll_next_unpin(cx);
        if self.fetch.is_some() {
            poll = self.limit_reached(poll);
        }
        self.baseline_metrics.record_poll(poll)
    }
}

pin_project! {
    /// Stream wrapper that splits large [`RecordBatch`]es into smaller batches.
    ///
    /// This ensures that consuming operators receive batches no larger than
    /// `batch_size`, which can improve parallelism when data sources
    /// generate very large batches.
    ///
    /// # Fields
    ///
    /// - `current_batch`: The batch currently being split, if any
    /// - `offset`: Index of the next row to split from `current_batch`.
    ///   This tracks our position within the current batch being split.
    ///
    /// # Invariants
    ///
    /// - `offset` is always ≤ `current_batch.num_rows()` when `current_batch` is `Some`
    /// - When `current_batch` is `None`, `offset` is always 0
    /// - `batch_size` is always > 0
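    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore`; the `SplitMetrics` import path and
    /// the existence of an `input` [`SendableRecordBatchStream`] are assumed):
    ///
    /// ```ignore
    /// use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, SplitMetrics};
    /// use datafusion_physical_plan::stream::BatchSplitStream;
    /// use futures::StreamExt;
    ///
    /// let metrics = ExecutionPlanMetricsSet::new();
    /// let split_metrics = SplitMetrics::new(&metrics, /* partition */ 0);
    /// // Re-chunk `input` into batches of at most 1024 rows
    /// let mut stream = BatchSplitStream::new(input, 1024, split_metrics);
    /// while let Some(batch) = stream.next().await {
    ///     assert!(batch.unwrap().num_rows() <= 1024);
    /// }
    /// ```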
    pub struct BatchSplitStream {
        #[pin]
        input: SendableRecordBatchStream,
        schema: SchemaRef,
        batch_size: usize,
        metrics: SplitMetrics,
        current_batch: Option<RecordBatch>,
        offset: usize,
    }
}

impl BatchSplitStream {
    /// Create a new [`BatchSplitStream`]
    pub fn new(
        input: SendableRecordBatchStream,
        batch_size: usize,
        metrics: SplitMetrics,
    ) -> Self {
        let schema = input.schema();
        Self {
            input,
            schema,
            batch_size,
            metrics,
            current_batch: None,
            offset: 0,
        }
    }

    /// Attempt to produce the next sliced batch from the current batch.
    ///
    /// Returns `Some(batch)` if a slice was produced, `None` if the current batch
    /// is exhausted and we need to poll upstream for more data.
    fn next_sliced_batch(&mut self) -> Option<Result<RecordBatch>> {
        let batch = self.current_batch.take()?;

        // Assert slice boundary safety - offset should never exceed batch size
        debug_assert!(
            self.offset <= batch.num_rows(),
            "Offset {} exceeds batch size {}",
            self.offset,
            batch.num_rows()
        );

        let remaining = batch.num_rows() - self.offset;
        let to_take = remaining.min(self.batch_size);
        let out = batch.slice(self.offset, to_take);

        self.metrics.batches_split.add(1);
        self.offset += to_take;
        if self.offset < batch.num_rows() {
            // More data remains in this batch, store it back
            self.current_batch = Some(batch);
        } else {
            // Batch is exhausted, reset offset
            // Note: current_batch is already None since we took it at the start
            self.offset = 0;
        }
        Some(Ok(out))
    }

    /// Poll the upstream input for the next batch.
    ///
    /// Returns the appropriate `Poll` result based on upstream state.
    /// Small batches are passed through directly; large batches are stored
    /// for slicing and the first slice is returned immediately.
    fn poll_upstream(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        match ready!(self.input.as_mut().poll_next(cx)) {
            Some(Ok(batch)) => {
                if batch.num_rows() <= self.batch_size {
                    // Small batch, pass through directly
                    Poll::Ready(Some(Ok(batch)))
                } else {
                    // Large batch, store for slicing and return first slice
                    self.current_batch = Some(batch);
                    // Immediately produce the first slice
                    match self.next_sliced_batch() {
                        Some(result) => Poll::Ready(Some(result)),
                        // Unreachable: `current_batch` was just set to `Some`
                        None => Poll::Ready(None),
                    }
                }
            }
            Some(Err(e)) => Poll::Ready(Some(Err(e))),
            None => Poll::Ready(None),
        }
    }
}

impl Stream for BatchSplitStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // First, try to produce a slice from the current batch
        if let Some(result) = self.next_sliced_batch() {
            return Poll::Ready(Some(result));
        }

        // No current batch or current batch exhausted, poll upstream
        self.poll_upstream(cx)
    }
}

impl RecordBatchStream for BatchSplitStream {
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

/// A stream that holds a memory reservation for its lifetime,
/// shrinking the reservation as batches are consumed.
///
/// The original reservation must have its batch sizes calculated using
/// [`get_record_batch_memory_size`]. On error, the reservation is *NOT*
/// freed until the stream is dropped.
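///
/// # Example
///
/// A minimal, crate-internal sketch (marked `ignore` because this type is
/// not public; `pool`, `schema`, `inner`, and `batch` are assumed to exist).
/// The reservation must be grown by [`get_record_batch_memory_size`] for each
/// batch before the stream is wrapped:
///
/// ```ignore
/// use datafusion_execution::memory_pool::MemoryConsumer;
///
/// let mut reservation = MemoryConsumer::new("example").register(&pool);
/// reservation.try_grow(get_record_batch_memory_size(&batch))?;
/// // Each polled batch shrinks the reservation by that batch's size;
/// // when the stream is exhausted the remainder is freed
/// let stream = ReservationStream::new(schema, inner, reservation);
/// ```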
pub(crate) struct ReservationStream {
    schema: SchemaRef,
    inner: SendableRecordBatchStream,
    reservation: MemoryReservation,
}

impl ReservationStream {
    pub(crate) fn new(
        schema: SchemaRef,
        inner: SendableRecordBatchStream,
        reservation: MemoryReservation,
    ) -> Self {
        Self {
            schema,
            inner,
            reservation,
        }
    }
}

impl Stream for ReservationStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let res = self.inner.poll_next_unpin(cx);

        match res {
            Poll::Ready(res) => {
                match res {
                    Some(Ok(batch)) => {
                        self.reservation
                            .shrink(get_record_batch_memory_size(&batch));
                        Poll::Ready(Some(Ok(batch)))
                    }
                    Some(Err(err)) => Poll::Ready(Some(Err(err))),
                    None => {
                        // Stream is done so free the reservation completely
                        self.reservation.free();
                        Poll::Ready(None)
                    }
                }
            }
            Poll::Pending => Poll::Pending,
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}

impl RecordBatchStream for ReservationStream {
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::test::exec::{
        BlockingExec, MockExec, PanicExec, assert_strong_count_converges_to_zero,
    };

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::exec_err;

    fn schema() -> SchemaRef {
        Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]))
    }

    #[tokio::test]
    #[should_panic(expected = "PanickingStream did panic")]
    async fn record_batch_receiver_stream_propagates_panics() {
        let schema = schema();

        let num_partitions = 10;
        let input = PanicExec::new(Arc::clone(&schema), num_partitions);
        consume(input, 10).await
    }

    #[tokio::test]
    #[should_panic(expected = "PanickingStream did panic: 1")]
    async fn record_batch_receiver_stream_propagates_panics_early_shutdown() {
        let schema = schema();

        // Make 2 partitions, second partition panics before the first
        let num_partitions = 2;
        let input = PanicExec::new(Arc::clone(&schema), num_partitions)
            .with_partition_panic(0, 10)
            .with_partition_panic(1, 3); // partition 1 should panic first (after 3 batches)

        // Ensure that the panic results in an early shutdown (that
        // everything stops after the first panic).

        // Since the stream reads every other batch: (0,1,0,1,0,panic)
        // it should not exceed 5 batches prior to the panic
        let max_batches = 5;
        consume(input, max_batches).await
    }

    #[tokio::test]
    async fn record_batch_receiver_stream_drop_cancel() {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = schema();

        // Make an input that never proceeds
        let input = BlockingExec::new(Arc::clone(&schema), 1);
        let refs = input.refs();

        // Configure a RecordBatchReceiverStream to consume the input
        let mut builder = RecordBatchReceiverStream::builder(schema, 2);
        builder.run_input(Arc::new(input), 0, Arc::clone(&task_ctx));
        let stream = builder.build();

        // Input should still be present
        assert!(std::sync::Weak::strong_count(&refs) > 0);

        // Drop the stream, ensure the refs go to zero
        drop(stream);
        assert_strong_count_converges_to_zero(refs).await;
    }

    #[tokio::test]
    /// Ensure that if an error is received in one stream, the
    /// `RecordBatchReceiverStream` stops early and does not drive
    /// other streams to completion.
    async fn record_batch_receiver_stream_error_does_not_drive_completion() {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = schema();

        // Make an input that will error twice
        let error_stream = MockExec::new(
            vec![exec_err!("Test1"), exec_err!("Test2")],
            Arc::clone(&schema),
        )
        .with_use_task(false);

        let mut builder = RecordBatchReceiverStream::builder(schema, 2);
        builder.run_input(Arc::new(error_stream), 0, Arc::clone(&task_ctx));
        let mut stream = builder.build();

        // Get the first result, which should be an error
        let first_batch = stream.next().await.unwrap();
        let first_err = first_batch.unwrap_err();
        assert_eq!(first_err.strip_backtrace(), "Execution error: Test1");

        // There should be no more batches produced (should not get the second error)
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn batch_split_stream_basic_functionality() {
        use arrow::array::{Int32Array, RecordBatch};
        use futures::stream::{self, StreamExt};

        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

        // Create a large batch that should be split
        let large_batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int32Array::from((0..2000).collect::<Vec<_>>()))],
        )
        .unwrap();

        // Create a stream with the large batch
        let input_stream = stream::iter(vec![Ok(large_batch)]);
        let adapter = RecordBatchStreamAdapter::new(Arc::clone(&schema), input_stream);
        let batch_stream = Box::pin(adapter) as SendableRecordBatchStream;

        // Create a BatchSplitStream with batch_size = 500
        let metrics = ExecutionPlanMetricsSet::new();
        let split_metrics = SplitMetrics::new(&metrics, 0);
        let mut split_stream = BatchSplitStream::new(batch_stream, 500, split_metrics);

        let mut total_rows = 0;
        let mut batch_count = 0;

        while let Some(result) = split_stream.next().await {
            let batch = result.unwrap();
            assert!(batch.num_rows() <= 500, "Batch size should not exceed 500");
            total_rows += batch.num_rows();
            batch_count += 1;
        }

        assert_eq!(total_rows, 2000, "All rows should be preserved");
        assert_eq!(batch_count, 4, "Should have 4 batches of 500 rows each");
    }

    /// Consumes all the input's partitions into a
    /// RecordBatchReceiverStream and runs it to completion
    ///
    /// Panics if more than `max_batches` batches are seen
    async fn consume(input: PanicExec, max_batches: usize) {
        let task_ctx = Arc::new(TaskContext::default());

        let input = Arc::new(input);
        let num_partitions = input.properties().output_partitioning().partition_count();

        // Configure a RecordBatchReceiverStream to consume all the input partitions
        let mut builder =
            RecordBatchReceiverStream::builder(input.schema(), num_partitions);
        for partition in 0..num_partitions {
            builder.run_input(
                Arc::clone(&input) as Arc<dyn ExecutionPlan>,
                partition,
                Arc::clone(&task_ctx),
            );
        }
        let mut stream = builder.build();

        // Drain the stream until it is complete, panicking on error
        let mut num_batches = 0;
        while let Some(next) = stream.next().await {
            next.unwrap();
            num_batches += 1;
            assert!(
                num_batches < max_batches,
                "Got the limit of {num_batches} batches before seeing panic"
            );
        }
    }

    #[test]
    fn record_batch_receiver_stream_builder_spawn_on_runtime() {
        let tokio_runtime = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .unwrap();

        let mut builder =
            RecordBatchReceiverStreamBuilder::new(Arc::new(Schema::empty()), 10);

        let tx1 = builder.tx();
        builder.spawn_on(
            async move {
                tx1.send(Ok(RecordBatch::new_empty(Arc::new(Schema::empty()))))
                    .await
                    .unwrap();

                Ok(())
            },
            tokio_runtime.handle(),
        );

        let tx2 = builder.tx();
        builder.spawn_blocking_on(
            move || {
                tx2.blocking_send(Ok(RecordBatch::new_empty(Arc::new(Schema::empty()))))
                    .unwrap();

                Ok(())
            },
            tokio_runtime.handle(),
        );

        let mut stream = builder.build();

        let mut number_of_batches = 0;

        loop {
            let poll = stream.poll_next_unpin(&mut Context::from_waker(
                futures::task::noop_waker_ref(),
            ));

            match poll {
                Poll::Ready(None) => {
                    break;
                }
                Poll::Ready(Some(Ok(batch))) => {
                    number_of_batches += 1;
                    assert_eq!(batch.num_rows(), 0);
                }
                Poll::Ready(Some(Err(e))) => panic!("Unexpected error: {e}"),
                Poll::Pending => {
                    continue;
                }
            }
        }

        assert_eq!(
            number_of_batches, 2,
            "Should have received exactly two empty batches"
        );
    }

    #[tokio::test]
    async fn test_reservation_stream_shrinks_on_poll() {
        use arrow::array::Int32Array;
        use datafusion_execution::memory_pool::MemoryConsumer;
        use datafusion_execution::runtime_env::RuntimeEnvBuilder;

        let runtime = RuntimeEnvBuilder::new()
            .with_memory_limit(10 * 1024 * 1024, 1.0)
            .build_arc()
            .unwrap();

        let mut reservation = MemoryConsumer::new("test").register(&runtime.memory_pool);

        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

        // Create batches
        let batch1 = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))],
        )
        .unwrap();
        let batch2 = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int32Array::from(vec![6, 7, 8, 9, 10]))],
        )
        .unwrap();

        let batch1_size = get_record_batch_memory_size(&batch1);
        let batch2_size = get_record_batch_memory_size(&batch2);

        // Reserve memory upfront
        reservation.try_grow(batch1_size + batch2_size).unwrap();
        let initial_reserved = runtime.memory_pool.reserved();
        assert_eq!(initial_reserved, batch1_size + batch2_size);

        // Create stream with batches
        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);
        let inner = Box::pin(RecordBatchStreamAdapter::new(Arc::clone(&schema), stream))
            as SendableRecordBatchStream;

        let mut res_stream =
            ReservationStream::new(Arc::clone(&schema), inner, reservation);

        // Poll first batch
        let result1 = res_stream.next().await;
        assert!(result1.is_some());

        // Memory should be reduced by batch1_size
        let after_first = runtime.memory_pool.reserved();
        assert_eq!(after_first, batch2_size);

        // Poll second batch
        let result2 = res_stream.next().await;
        assert!(result2.is_some());

        // Memory should be reduced by batch2_size
        let after_second = runtime.memory_pool.reserved();
        assert_eq!(after_second, 0);

        // Poll None (end of stream)
        let result3 = res_stream.next().await;
        assert!(result3.is_none());

        // Memory should still be 0
        assert_eq!(runtime.memory_pool.reserved(), 0);
    }

    #[tokio::test]
    async fn test_reservation_stream_error_handling() {
        use datafusion_execution::memory_pool::MemoryConsumer;
        use datafusion_execution::runtime_env::RuntimeEnvBuilder;

        let runtime = RuntimeEnvBuilder::new()
            .with_memory_limit(10 * 1024 * 1024, 1.0)
            .build_arc()
            .unwrap();

        let mut reservation = MemoryConsumer::new("test").register(&runtime.memory_pool);

        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

        reservation.try_grow(1000).unwrap();
        let initial = runtime.memory_pool.reserved();
        assert_eq!(initial, 1000);

        // Create a stream that errors
        let stream = futures::stream::iter(vec![exec_err!("Test error")]);
        let inner = Box::pin(RecordBatchStreamAdapter::new(Arc::clone(&schema), stream))
            as SendableRecordBatchStream;

        let mut res_stream =
            ReservationStream::new(Arc::clone(&schema), inner, reservation);

        // Get the error
        let result = res_stream.next().await;
        assert!(result.is_some());
        assert!(result.unwrap().is_err());

        // Verify the reservation is NOT automatically freed on error:
        // it is only freed when poll_next returns Poll::Ready(None).
        // After an error, the stream may continue to hold the reservation
        // until it is explicitly dropped or polled to None
        let after_error = runtime.memory_pool.reserved();
        assert_eq!(
            after_error, 1000,
            "Reservation should still be held after error"
        );

        // Drop the stream to free the reservation
        drop(res_stream);

        // Now memory should be freed
        assert_eq!(
            runtime.memory_pool.reserved(),
            0,
            "Memory should be freed when stream is dropped"
        );
    }
}