Skip to main content

datafusion_physical_plan/
stream.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Stream wrappers for physical operators
19
20use std::pin::Pin;
21use std::sync::Arc;
22use std::task::Context;
23use std::task::Poll;
24
25#[cfg(test)]
26use super::metrics::ExecutionPlanMetricsSet;
27use super::metrics::{BaselineMetrics, SplitMetrics};
28use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};
29use crate::displayable;
30use crate::spill::get_record_batch_memory_size;
31
32use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
33use datafusion_common::{Result, exec_err};
34use datafusion_common_runtime::JoinSet;
35use datafusion_execution::TaskContext;
36use datafusion_execution::memory_pool::MemoryReservation;
37
38use futures::ready;
39use futures::stream::BoxStream;
40use futures::{Future, Stream, StreamExt};
41use log::debug;
42use pin_project_lite::pin_project;
43use tokio::runtime::Handle;
44use tokio::sync::mpsc::{Receiver, Sender};
45
/// Creates a stream from a collection of producing tasks, routing panics to the stream.
///
/// Note that this is similar to [`ReceiverStream` from tokio-stream], with the differences being:
///
/// 1. Methods to spawn tasks that are tied to the builder (`spawn()` and `spawn_blocking()`).
///
/// 2. Propagates panics, whereas the `tokio` version doesn't propagate panics to the receiver.
///
/// 3. Automatically cancels any outstanding tasks when the receiver stream is dropped.
///
/// [`ReceiverStream` from tokio-stream]: https://docs.rs/tokio-stream/latest/tokio_stream/wrappers/struct.ReceiverStream.html
pub(crate) struct ReceiverStreamBuilder<O> {
    // Sending half of the channel; cloned by `tx()` and dropped in `build()`.
    tx: Sender<Result<O>>,
    // Receiving half of the channel; turned into the output stream by `build()`.
    rx: Receiver<Result<O>>,
    // Every task spawned through this builder; its results are checked by
    // the stream returned from `build()`.
    join_set: JoinSet<Result<()>>,
}
62
63impl<O: Send + 'static> ReceiverStreamBuilder<O> {
64    /// Create new channels with the specified buffer size
65    pub fn new(capacity: usize) -> Self {
66        let (tx, rx) = tokio::sync::mpsc::channel(capacity);
67
68        Self {
69            tx,
70            rx,
71            join_set: JoinSet::new(),
72        }
73    }
74
75    /// Get a handle for sending data to the output
76    pub fn tx(&self) -> Sender<Result<O>> {
77        self.tx.clone()
78    }
79
80    /// Spawn task that will be aborted if this builder (or the stream
81    /// built from it) are dropped
82    pub fn spawn<F>(&mut self, task: F)
83    where
84        F: Future<Output = Result<()>>,
85        F: Send + 'static,
86    {
87        self.join_set.spawn(task);
88    }
89
90    /// Same as [`Self::spawn`] but it spawns the task on the provided runtime
91    pub fn spawn_on<F>(&mut self, task: F, handle: &Handle)
92    where
93        F: Future<Output = Result<()>>,
94        F: Send + 'static,
95    {
96        self.join_set.spawn_on(task, handle);
97    }
98
99    /// Spawn a blocking task that will be aborted if this builder (or the stream
100    /// built from it) are dropped.
101    ///
102    /// This is often used to spawn tasks that write to the sender
103    /// retrieved from `Self::tx`.
104    pub fn spawn_blocking<F>(&mut self, f: F)
105    where
106        F: FnOnce() -> Result<()>,
107        F: Send + 'static,
108    {
109        self.join_set.spawn_blocking(f);
110    }
111
112    /// Same as [`Self::spawn_blocking`] but it spawns the blocking task on the provided runtime
113    pub fn spawn_blocking_on<F>(&mut self, f: F, handle: &Handle)
114    where
115        F: FnOnce() -> Result<()>,
116        F: Send + 'static,
117    {
118        self.join_set.spawn_blocking_on(f, handle);
119    }
120
121    /// Create a stream of all data written to `tx`
122    pub fn build(self) -> BoxStream<'static, Result<O>> {
123        let Self {
124            tx,
125            rx,
126            mut join_set,
127        } = self;
128
129        // Doesn't need tx
130        drop(tx);
131
132        // future that checks the result of the join set, and propagates panic if seen
133        let check = async move {
134            while let Some(result) = join_set.join_next().await {
135                match result {
136                    Ok(task_result) => {
137                        match task_result {
138                            // Nothing to report
139                            Ok(_) => continue,
140                            // This means a blocking task error
141                            Err(error) => return Some(Err(error)),
142                        }
143                    }
144                    // This means a tokio task error, likely a panic
145                    Err(e) => {
146                        if e.is_panic() {
147                            // resume on the main thread
148                            std::panic::resume_unwind(e.into_panic());
149                        } else {
150                            // This should only occur if the task is
151                            // cancelled, which would only occur if
152                            // the JoinSet were aborted, which in turn
153                            // would imply that the receiver has been
154                            // dropped and this code is not running
155                            return Some(exec_err!("Non Panic Task error: {e}"));
156                        }
157                    }
158                }
159            }
160            None
161        };
162
163        let check_stream = futures::stream::once(check)
164            // unwrap Option / only return the error
165            .filter_map(|item| async move { item });
166
167        // Convert the receiver into a stream
168        let rx_stream = futures::stream::unfold(rx, |mut rx| async move {
169            let next_item = rx.recv().await;
170            next_item.map(|next_item| (next_item, rx))
171        });
172
173        // Merge the streams together so whichever is ready first
174        // produces the batch
175        futures::stream::select(rx_stream, check_stream).boxed()
176    }
177}
178
/// Builder for `RecordBatchReceiverStream` that propagates errors
/// and panics correctly.
///
/// [`RecordBatchReceiverStreamBuilder`] is used to spawn one or more tasks
/// that produce [`RecordBatch`]es and send them to a single
/// `Receiver` which can improve parallelism.
///
/// This also handles propagating panics and canceling the tasks.
///
/// # Example
///
/// The following example spawns 2 tasks that will write [`RecordBatch`]es to
/// the `tx` end of the builder. After building the stream, we can receive
/// those batches by calling `.next()`.
///
/// ```
/// # use std::sync::Arc;
/// # use datafusion_common::arrow::datatypes::{Schema, Field, DataType};
/// # use datafusion_common::arrow::array::RecordBatch;
/// # use datafusion_physical_plan::stream::RecordBatchReceiverStreamBuilder;
/// # use futures::stream::StreamExt;
/// # use tokio::runtime::Builder;
/// # let rt = Builder::new_current_thread().build().unwrap();
/// #
/// # rt.block_on(async {
/// let schema = Arc::new(Schema::new(vec![Field::new("foo", DataType::Int8, false)]));
/// let mut builder = RecordBatchReceiverStreamBuilder::new(Arc::clone(&schema), 10);
///
/// // task 1
/// let tx_1 = builder.tx();
/// let schema_1 = Arc::clone(&schema);
/// builder.spawn(async move {
///     // Your task needs to send batches to the tx
///     tx_1.send(Ok(RecordBatch::new_empty(schema_1)))
///         .await
///         .unwrap();
///
///     Ok(())
/// });
///
/// // task 2
/// let tx_2 = builder.tx();
/// let schema_2 = Arc::clone(&schema);
/// builder.spawn(async move {
///     // Your task needs to send batches to the tx
///     tx_2.send(Ok(RecordBatch::new_empty(schema_2)))
///         .await
///         .unwrap();
///
///     Ok(())
/// });
///
/// let mut stream = builder.build();
/// while let Some(res_batch) = stream.next().await {
///     // `res_batch` can come from either task 1 or task 2
///
///     // do something with `res_batch`
/// }
/// # });
/// ```
pub struct RecordBatchReceiverStreamBuilder {
    // Schema attached to the stream returned by `build()`.
    schema: SchemaRef,
    // Generic builder that owns the channel and the spawned tasks.
    inner: ReceiverStreamBuilder<RecordBatch>,
}
243
244impl RecordBatchReceiverStreamBuilder {
245    /// Create new channels with the specified buffer size
246    pub fn new(schema: SchemaRef, capacity: usize) -> Self {
247        Self {
248            schema,
249            inner: ReceiverStreamBuilder::new(capacity),
250        }
251    }
252
253    /// Get a handle for sending [`RecordBatch`] to the output
254    ///
255    /// If the stream is dropped / canceled, the sender will be closed and
256    /// calling `tx().send()` will return an error. Producers should stop
257    /// producing in this case and return control.
258    pub fn tx(&self) -> Sender<Result<RecordBatch>> {
259        self.inner.tx()
260    }
261
262    /// Spawn task that will be aborted if this builder (or the stream
263    /// built from it) are dropped
264    ///
265    /// This is often used to spawn tasks that write to the sender
266    /// retrieved from [`Self::tx`], for examples, see the document
267    /// of this type.
268    pub fn spawn<F>(&mut self, task: F)
269    where
270        F: Future<Output = Result<()>>,
271        F: Send + 'static,
272    {
273        self.inner.spawn(task)
274    }
275
276    /// Same as [`Self::spawn`] but it spawns the task on the provided runtime.
277    pub fn spawn_on<F>(&mut self, task: F, handle: &Handle)
278    where
279        F: Future<Output = Result<()>>,
280        F: Send + 'static,
281    {
282        self.inner.spawn_on(task, handle)
283    }
284
285    /// Spawn a blocking task tied to the builder and stream.
286    ///
287    /// # Drop / Cancel Behavior
288    ///
289    /// If this builder (or the stream built from it) is dropped **before** the
290    /// task starts, the task is also dropped and will never start execute.
291    ///
292    /// **Note:** Once the blocking task has started, it **will not** be
293    /// forcibly stopped on drop as Rust does not allow forcing a running thread
294    /// to terminate. The task will continue running until it completes or
295    /// encounters an error.
296    ///
297    /// Users should ensure that their blocking function periodically checks for
298    /// errors calling `tx.blocking_send`. An error signals that the stream has
299    /// been dropped / cancelled and the blocking task should exit.
300    ///
301    /// This is often used to spawn tasks that write to the sender
302    /// retrieved from [`Self::tx`], for examples, see the document
303    /// of this type.
304    pub fn spawn_blocking<F>(&mut self, f: F)
305    where
306        F: FnOnce() -> Result<()>,
307        F: Send + 'static,
308    {
309        self.inner.spawn_blocking(f)
310    }
311
312    /// Same as [`Self::spawn_blocking`] but it spawns the blocking task on the provided runtime.
313    pub fn spawn_blocking_on<F>(&mut self, f: F, handle: &Handle)
314    where
315        F: FnOnce() -> Result<()>,
316        F: Send + 'static,
317    {
318        self.inner.spawn_blocking_on(f, handle)
319    }
320
321    /// Runs the `partition` of the `input` ExecutionPlan on the
322    /// tokio thread pool and writes its outputs to this stream
323    ///
324    /// If the input partition produces an error, the error will be
325    /// sent to the output stream and no further results are sent.
326    pub(crate) fn run_input(
327        &mut self,
328        input: Arc<dyn ExecutionPlan>,
329        partition: usize,
330        context: Arc<TaskContext>,
331    ) {
332        let output = self.tx();
333        let input_display = if log::log_enabled!(log::Level::Debug) {
334            displayable(input.as_ref()).one_line().to_string()
335        } else {
336            String::new()
337        };
338
339        self.inner.spawn(async move {
340            let mut stream = match input.execute(partition, context) {
341                Err(e) => {
342                    // If send fails, the plan being torn down, there
343                    // is no place to send the error and no reason to continue.
344                    output.send(Err(e)).await.ok();
345                    debug!(
346                        "Stopping execution: error executing input: {input_display}",
347                    );
348                    return Ok(());
349                }
350                Ok(stream) => stream,
351            };
352
353            // Drop the input early, as soon as we're done with it.
354            // Holding on to it can cause delays in cancelling the child plan when the query is
355            // cancelled.
356            drop(input);
357
358            // Transfer batches from inner stream to the output tx
359            // immediately.
360            while let Some(item) = stream.next().await {
361                let is_err = item.is_err();
362
363                // If send fails, plan being torn down, there is no
364                // place to send the error and no reason to continue.
365                if output.send(item).await.is_err() {
366                    debug!(
367                        "Stopping execution: output is gone, plan cancelling: {input_display}",
368                    );
369                    return Ok(());
370                }
371
372                // Stop after the first error is encountered (Don't
373                // drive all streams to completion)
374                if is_err {
375                    debug!("Stopping execution: plan returned error: {input_display}");
376                    return Ok(());
377                }
378            }
379
380            Ok(())
381        });
382    }
383
384    /// Create a stream of all [`RecordBatch`] written to `tx`
385    pub fn build(self) -> SendableRecordBatchStream {
386        Box::pin(RecordBatchStreamAdapter::new(
387            self.schema,
388            self.inner.build(),
389        ))
390    }
391}
392
/// Field-less type that only serves as a namespace for
/// [`RecordBatchReceiverStream::builder`].
#[doc(hidden)]
pub struct RecordBatchReceiverStream {}
395
impl RecordBatchReceiverStream {
    /// Create a builder with an internal buffer of `capacity` batches.
    ///
    /// Equivalent to calling [`RecordBatchReceiverStreamBuilder::new`] with
    /// the same arguments.
    pub fn builder(
        schema: SchemaRef,
        capacity: usize,
    ) -> RecordBatchReceiverStreamBuilder {
        RecordBatchReceiverStreamBuilder::new(schema, capacity)
    }
}
405
pin_project! {
    /// Combines a [`Stream`] with a [`SchemaRef`] so that the pair implements
    /// [`RecordBatchStream`]; once boxed and pinned it can be used as a
    /// [`SendableRecordBatchStream`].
    ///
    /// See [`Self::new`] for an example
    pub struct RecordBatchStreamAdapter<S> {
        // Schema reported by `RecordBatchStream::schema`.
        schema: SchemaRef,

        // Wrapped in Option so we can drop the inner stream as soon as it
        // returns `None`, releasing any upstream pipeline resources before the
        // adapter itself is dropped.
        #[pin]
        stream: Option<S>,
    }
}
421
impl<S> RecordBatchStreamAdapter<S> {
    /// Creates a new [`RecordBatchStreamAdapter`] from the provided schema and stream.
    ///
    /// Note: to obtain a [`SendableRecordBatchStream`], box and pin the
    /// result with `Box::pin`, as shown in the example below.
    ///
    /// # Example
    /// ```
    /// # use arrow::array::record_batch;
    /// # use datafusion_execution::SendableRecordBatchStream;
    /// # use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
    /// // Create stream of Result<RecordBatch>
    /// let batch = record_batch!(
    ///     ("a", Int32, [1, 2, 3]),
    ///     ("b", Float64, [Some(4.0), None, Some(5.0)])
    /// )
    /// .expect("created batch");
    /// let schema = batch.schema();
    /// let stream = futures::stream::iter(vec![Ok(batch)]);
    /// // Convert the stream to a SendableRecordBatchStream
    /// let adapter = RecordBatchStreamAdapter::new(schema, stream);
    /// // Now you can use the adapter as a SendableRecordBatchStream
    /// let batch_stream: SendableRecordBatchStream = Box::pin(adapter);
    /// // ...
    /// ```
    pub fn new(schema: SchemaRef, stream: S) -> Self {
        Self {
            schema,
            // Starts out as `Some`; `poll_next` resets it to `None` once the
            // stream is exhausted.
            stream: Some(stream),
        }
    }
}
453
// `S` carries no `Debug` bound here, so only the schema is printed and the
// wrapped stream is left out of the output.
impl<S> std::fmt::Debug for RecordBatchStreamAdapter<S> {
    /// Formats the adapter as `RecordBatchStreamAdapter { schema: .. }`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RecordBatchStreamAdapter")
            .field("schema", &self.schema)
            .finish()
    }
}
461
462impl<S> Stream for RecordBatchStreamAdapter<S>
463where
464    S: Stream<Item = Result<RecordBatch>>,
465{
466    type Item = Result<RecordBatch>;
467
468    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
469        let mut this = self.project();
470        let Some(inner) = this.stream.as_mut().as_pin_mut() else {
471            return Poll::Ready(None);
472        };
473        let item = ready!(inner.poll_next(cx));
474        if item.is_none() {
475            // Drop the inner stream in place to release its resources.
476            // SAFETY: the inner stream is dropped without moving it out of
477            // its pinned location; assigning `None` only runs the inner
478            // value's destructor in place, which is permitted for pinned
479            // values.
480            unsafe {
481                *this.stream.as_mut().get_unchecked_mut() = None;
482            }
483        }
484        Poll::Ready(item)
485    }
486
487    fn size_hint(&self) -> (usize, Option<usize>) {
488        match self.stream.as_ref() {
489            Some(stream) => stream.size_hint(),
490            None => (0, Some(0)),
491        }
492    }
493}
494
impl<S> RecordBatchStream for RecordBatchStreamAdapter<S>
where
    S: Stream<Item = Result<RecordBatch>>,
{
    /// Returns a new reference to the schema supplied at construction.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
503
/// `EmptyRecordBatchStream` can be used to create a [`RecordBatchStream`]
/// that will produce no results.
///
/// It only carries a schema; polling it reports end-of-stream right away.
pub struct EmptyRecordBatchStream {
    /// Schema wrapped by Arc
    schema: SchemaRef,
}
510
impl EmptyRecordBatchStream {
    /// Create an empty RecordBatchStream that reports the given `schema`
    pub fn new(schema: SchemaRef) -> Self {
        Self { schema }
    }
}
517
impl RecordBatchStream for EmptyRecordBatchStream {
    /// Returns a new reference to the schema supplied at construction.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
523
impl Stream for EmptyRecordBatchStream {
    type Item = Result<RecordBatch>;

    /// Always reports end-of-stream immediately; the task context is unused.
    fn poll_next(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        Poll::Ready(None)
    }
}
534
/// Stream wrapper that records `BaselineMetrics` for a particular
/// [`SendableRecordBatchStream`] (likely a partition), optionally stopping
/// after a maximum number of rows.
pub(crate) struct ObservedStream {
    // Wrapped stream; swapped for an empty stream once the row limit is hit.
    inner: SendableRecordBatchStream,
    // Metrics updated with the result of every poll.
    baseline_metrics: BaselineMetrics,
    // Maximum number of rows to emit, or `None` for no limit.
    fetch: Option<usize>,
    // Number of rows emitted so far; only tracked when `fetch` is set.
    produced: usize,
}
543
544impl ObservedStream {
545    pub fn new(
546        inner: SendableRecordBatchStream,
547        baseline_metrics: BaselineMetrics,
548        fetch: Option<usize>,
549    ) -> Self {
550        Self {
551            inner,
552            baseline_metrics,
553            fetch,
554            produced: 0,
555        }
556    }
557
558    fn limit_reached(
559        &mut self,
560        poll: Poll<Option<Result<RecordBatch>>>,
561    ) -> Poll<Option<Result<RecordBatch>>> {
562        let Some(fetch) = self.fetch else { return poll };
563
564        if self.produced >= fetch {
565            self.release_inner();
566            return Poll::Ready(None);
567        }
568
569        if let Poll::Ready(Some(Ok(batch))) = &poll {
570            if self.produced + batch.num_rows() > fetch {
571                let batch = batch.slice(0, fetch.saturating_sub(self.produced));
572                self.produced += batch.num_rows();
573                if self.produced >= fetch {
574                    self.release_inner();
575                }
576                return Poll::Ready(Some(Ok(batch)));
577            };
578            self.produced += batch.num_rows()
579        }
580        poll
581    }
582
583    /// Replace the inner stream with an [`EmptyRecordBatchStream`], dropping
584    /// the original stream so its upstream pipeline can be torn down.
585    fn release_inner(&mut self) {
586        let schema = self.inner.schema();
587        self.inner = Box::pin(EmptyRecordBatchStream::new(schema));
588    }
589}
590
impl RecordBatchStream for ObservedStream {
    /// Returns the schema of the wrapped stream.
    fn schema(&self) -> SchemaRef {
        self.inner.schema()
    }
}
596
597impl Stream for ObservedStream {
598    type Item = Result<RecordBatch>;
599
600    fn poll_next(
601        mut self: Pin<&mut Self>,
602        cx: &mut Context<'_>,
603    ) -> Poll<Option<Self::Item>> {
604        let mut poll = self.inner.poll_next_unpin(cx);
605        if self.fetch.is_some() {
606            poll = self.limit_reached(poll);
607        }
608        self.baseline_metrics.record_poll(poll)
609    }
610}
611
pin_project! {
    /// Stream wrapper that splits large [`RecordBatch`]es into smaller batches.
    ///
    /// This ensures the consumers of this stream receive batches no larger
    /// than `batch_size`, which can improve parallelism when data sources
    /// generate very large batches.
    ///
    /// # Fields
    ///
    /// - `input`: The wrapped stream that batches are pulled from
    /// - `schema`: Schema of `input`, captured at construction
    /// - `batch_size`: Maximum number of rows in any emitted batch
    /// - `metrics`: Counters updated as slices are produced
    /// - `current_batch`: The batch currently being split, if any
    /// - `offset`: Index of the next row to split from `current_batch`.
    ///   This tracks our position within the current batch being split.
    ///
    /// # Invariants
    ///
    /// - `offset` is always ≤ `current_batch.num_rows()` when `current_batch` is `Some`
    /// - When `current_batch` is `None`, `offset` is always 0
    /// - `batch_size` is always > 0 (expected of the caller; `new` does not check it)
pub struct BatchSplitStream {
        #[pin]
        input: SendableRecordBatchStream,
        schema: SchemaRef,
        batch_size: usize,
        metrics: SplitMetrics,
        current_batch: Option<RecordBatch>,
        offset: usize,
    }
}
640
641impl BatchSplitStream {
642    /// Create a new [`BatchSplitStream`]
643    pub fn new(
644        input: SendableRecordBatchStream,
645        batch_size: usize,
646        metrics: SplitMetrics,
647    ) -> Self {
648        let schema = input.schema();
649        Self {
650            input,
651            schema,
652            batch_size,
653            metrics,
654            current_batch: None,
655            offset: 0,
656        }
657    }
658
659    /// Attempt to produce the next sliced batch from the current batch.
660    ///
661    /// Returns `Some(batch)` if a slice was produced, `None` if the current batch
662    /// is exhausted and we need to poll upstream for more data.
663    fn next_sliced_batch(&mut self) -> Option<Result<RecordBatch>> {
664        let batch = self.current_batch.take()?;
665
666        // Assert slice boundary safety - offset should never exceed batch size
667        debug_assert!(
668            self.offset <= batch.num_rows(),
669            "Offset {} exceeds batch size {}",
670            self.offset,
671            batch.num_rows()
672        );
673
674        let remaining = batch.num_rows() - self.offset;
675        let to_take = remaining.min(self.batch_size);
676        let out = batch.slice(self.offset, to_take);
677
678        self.metrics.batches_split.add(1);
679        self.offset += to_take;
680        if self.offset < batch.num_rows() {
681            // More data remains in this batch, store it back
682            self.current_batch = Some(batch);
683        } else {
684            // Batch is exhausted, reset offset
685            // Note: current_batch is already None since we took it at the start
686            self.offset = 0;
687        }
688        Some(Ok(out))
689    }
690
691    /// Poll the upstream input for the next batch.
692    ///
693    /// Returns the appropriate `Poll` result based on upstream state.
694    /// Small batches are passed through directly, large batches are stored
695    /// for slicing and return the first slice immediately.
696    fn poll_upstream(
697        &mut self,
698        cx: &mut Context<'_>,
699    ) -> Poll<Option<Result<RecordBatch>>> {
700        match ready!(self.input.as_mut().poll_next(cx)) {
701            Some(Ok(batch)) => {
702                if batch.num_rows() <= self.batch_size {
703                    // Small batch, pass through directly
704                    Poll::Ready(Some(Ok(batch)))
705                } else {
706                    // Large batch, store for slicing and return first slice
707                    self.current_batch = Some(batch);
708                    // Immediately produce the first slice
709                    match self.next_sliced_batch() {
710                        Some(result) => Poll::Ready(Some(result)),
711                        None => Poll::Ready(None), // Should not happen
712                    }
713                }
714            }
715            Some(Err(e)) => Poll::Ready(Some(Err(e))),
716            None => {
717                // Release the input pipeline's resources.
718                let input_schema = self.input.schema();
719                self.input = Box::pin(EmptyRecordBatchStream::new(input_schema));
720                Poll::Ready(None)
721            }
722        }
723    }
724}
725
726impl Stream for BatchSplitStream {
727    type Item = Result<RecordBatch>;
728
729    fn poll_next(
730        mut self: Pin<&mut Self>,
731        cx: &mut Context<'_>,
732    ) -> Poll<Option<Self::Item>> {
733        // First, try to produce a slice from the current batch
734        if let Some(result) = self.next_sliced_batch() {
735            return Poll::Ready(Some(result));
736        }
737
738        // No current batch or current batch exhausted, poll upstream
739        self.poll_upstream(cx)
740    }
741}
742
impl RecordBatchStream for BatchSplitStream {
    /// Returns a new reference to the schema captured from the input stream
    /// at construction.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
748
/// A stream that holds a memory reservation for its lifetime,
/// shrinking the reservation as batches are consumed.
///
/// The original reservation must have its batch sizes calculated using
/// [`get_record_batch_memory_size`].
///
/// On error, the reservation is *NOT* freed until the stream is dropped.
pub(crate) struct ReservationStream {
    // Schema returned by `RecordBatchStream::schema`.
    schema: SchemaRef,
    // Wrapped stream; swapped for an empty stream once it is exhausted.
    inner: SendableRecordBatchStream,
    // Reservation that is shrunk per emitted batch and freed at end-of-stream.
    reservation: MemoryReservation,
}
758
impl ReservationStream {
    /// Creates a stream over `inner` that reports `schema` and releases
    /// `reservation` progressively as batches are emitted.
    pub(crate) fn new(
        schema: SchemaRef,
        inner: SendableRecordBatchStream,
        reservation: MemoryReservation,
    ) -> Self {
        Self {
            schema,
            inner,
            reservation,
        }
    }
}
772
773impl Stream for ReservationStream {
774    type Item = Result<RecordBatch>;
775
776    fn poll_next(
777        mut self: Pin<&mut Self>,
778        cx: &mut Context<'_>,
779    ) -> Poll<Option<Self::Item>> {
780        let res = self.inner.poll_next_unpin(cx);
781
782        match res {
783            Poll::Ready(res) => {
784                match res {
785                    Some(Ok(batch)) => {
786                        self.reservation
787                            .shrink(get_record_batch_memory_size(&batch));
788                        Poll::Ready(Some(Ok(batch)))
789                    }
790                    Some(Err(err)) => Poll::Ready(Some(Err(err))),
791                    None => {
792                        // Stream is done so free the reservation completely
793                        self.reservation.free();
794                        // Release the input pipeline's resources.
795                        let inner_schema = self.inner.schema();
796                        self.inner = Box::pin(EmptyRecordBatchStream::new(inner_schema));
797                        Poll::Ready(None)
798                    }
799                }
800            }
801            Poll::Pending => Poll::Pending,
802        }
803    }
804
805    fn size_hint(&self) -> (usize, Option<usize>) {
806        self.inner.size_hint()
807    }
808}
809
impl RecordBatchStream for ReservationStream {
    /// Returns a new reference to the schema supplied at construction.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
815
816#[cfg(test)]
817mod test {
818    use super::*;
819    use crate::test::exec::{
820        BlockingExec, MockExec, PanicExec, assert_strong_count_converges_to_zero,
821    };
822
823    use arrow::datatypes::{DataType, Field, Schema};
824    use datafusion_common::exec_err;
825
    /// Test schema with a single nullable `Float32` column named `a`.
    fn schema() -> SchemaRef {
        Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]))
    }
829
    #[tokio::test]
    #[should_panic(expected = "PanickingStream did panic")]
    /// A panic raised while running a `PanicExec` input must surface as a
    /// panic in the test itself.
    async fn record_batch_receiver_stream_propagates_panics() {
        let schema = schema();

        let num_partitions = 10;
        let input = PanicExec::new(Arc::clone(&schema), num_partitions);
        consume(input, 10).await
    }
839
    #[tokio::test]
    #[should_panic(expected = "PanickingStream did panic: 1")]
    /// The panic that reaches the test must be the one from partition 1,
    /// which is configured to panic earlier than partition 0.
    async fn record_batch_receiver_stream_propagates_panics_early_shutdown() {
        let schema = schema();

        // Make 2 partitions, second partition panics before the first
        let num_partitions = 2;
        let input = PanicExec::new(Arc::clone(&schema), num_partitions)
            .with_partition_panic(0, 10)
            .with_partition_panic(1, 3); // partition 1 should panic first (after 3 batches)

        // Ensure that the panic results in an early shutdown (that
        // everything stops after the first panic).

        // Since the stream reads every other batch: (0,1,0,1,0,panic)
        // so should not exceed 5 batches prior to the panic
        let max_batches = 5;
        consume(input, max_batches).await
    }
859
    #[tokio::test]
    /// Dropping the built stream must release every reference to the input
    /// plan, even though the input itself never makes progress.
    async fn record_batch_receiver_stream_drop_cancel() {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = schema();

        // Make an input that never proceeds
        let input = BlockingExec::new(Arc::clone(&schema), 1);
        let refs = input.refs();

        // Configure a RecordBatchReceiverStream to consume the input
        let mut builder = RecordBatchReceiverStream::builder(schema, 2);
        builder.run_input(Arc::new(input), 0, Arc::clone(&task_ctx));
        let stream = builder.build();

        // Input should still be present
        assert!(std::sync::Weak::strong_count(&refs) > 0);

        // Drop the stream, ensure the refs go to zero
        drop(stream);
        assert_strong_count_converges_to_zero(refs).await;
    }
881
    #[tokio::test]
    /// Ensure that if an error is received in one stream, the
    /// `RecordBatchReceiverStream` stops early and does not drive
    /// other streams to completion.
    async fn record_batch_receiver_stream_error_does_not_drive_completion() {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = schema();

        // make an input that will error twice
        let error_stream = MockExec::new(
            vec![exec_err!("Test1"), exec_err!("Test2")],
            Arc::clone(&schema),
        )
        .with_use_task(false);

        let mut builder = RecordBatchReceiverStream::builder(schema, 2);
        builder.run_input(Arc::new(error_stream), 0, Arc::clone(&task_ctx));
        let mut stream = builder.build();

        // Get the first result, which should be an error
        let first_batch = stream.next().await.unwrap();
        let first_err = first_batch.unwrap_err();
        // Compare without the backtrace so the message is stable
        assert_eq!(first_err.strip_backtrace(), "Execution error: Test1");

        // There should be no more batches produced (should not get the second error)
        assert!(stream.next().await.is_none());
    }
909
910    #[tokio::test]
911    async fn batch_split_stream_basic_functionality() {
912        use arrow::array::{Int32Array, RecordBatch};
913        use futures::stream::{self, StreamExt};
914
915        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
916
917        // Create a large batch that should be split
918        let large_batch = RecordBatch::try_new(
919            Arc::clone(&schema),
920            vec![Arc::new(Int32Array::from((0..2000).collect::<Vec<_>>()))],
921        )
922        .unwrap();
923
924        // Create a stream with the large batch
925        let input_stream = stream::iter(vec![Ok(large_batch)]);
926        let adapter = RecordBatchStreamAdapter::new(Arc::clone(&schema), input_stream);
927        let batch_stream = Box::pin(adapter) as SendableRecordBatchStream;
928
929        // Create a BatchSplitStream with batch_size = 500
930        let metrics = ExecutionPlanMetricsSet::new();
931        let split_metrics = SplitMetrics::new(&metrics, 0);
932        let mut split_stream = BatchSplitStream::new(batch_stream, 500, split_metrics);
933
934        let mut total_rows = 0;
935        let mut batch_count = 0;
936
937        while let Some(result) = split_stream.next().await {
938            let batch = result.unwrap();
939            assert!(batch.num_rows() <= 500, "Batch size should not exceed 500");
940            total_rows += batch.num_rows();
941            batch_count += 1;
942        }
943
944        assert_eq!(total_rows, 2000, "All rows should be preserved");
945        assert_eq!(batch_count, 4, "Should have 4 batches of 500 rows each");
946    }
947
948    /// Consumes all the input's partitions into a
949    /// RecordBatchReceiverStream and runs it to completion
950    ///
951    /// panic's if more than max_batches is seen,
952    async fn consume(input: PanicExec, max_batches: usize) {
953        let task_ctx = Arc::new(TaskContext::default());
954
955        let input = Arc::new(input);
956        let num_partitions = input.properties().output_partitioning().partition_count();
957
958        // Configure a RecordBatchReceiverStream to consume all the input partitions
959        let mut builder =
960            RecordBatchReceiverStream::builder(input.schema(), num_partitions);
961        for partition in 0..num_partitions {
962            builder.run_input(
963                Arc::clone(&input) as Arc<dyn ExecutionPlan>,
964                partition,
965                Arc::clone(&task_ctx),
966            );
967        }
968        let mut stream = builder.build();
969
970        // Drain the stream until it is complete, panic'ing on error
971        let mut num_batches = 0;
972        while let Some(next) = stream.next().await {
973            next.unwrap();
974            num_batches += 1;
975            assert!(
976                num_batches < max_batches,
977                "Got the limit of {num_batches} batches before seeing panic"
978            );
979        }
980    }
981
982    #[test]
983    fn record_batch_receiver_stream_builder_spawn_on_runtime() {
984        let tokio_runtime = tokio::runtime::Builder::new_multi_thread()
985            .enable_all()
986            .build()
987            .unwrap();
988
989        let mut builder =
990            RecordBatchReceiverStreamBuilder::new(Arc::new(Schema::empty()), 10);
991
992        let tx1 = builder.tx();
993        builder.spawn_on(
994            async move {
995                tx1.send(Ok(RecordBatch::new_empty(Arc::new(Schema::empty()))))
996                    .await
997                    .unwrap();
998
999                Ok(())
1000            },
1001            tokio_runtime.handle(),
1002        );
1003
1004        let tx2 = builder.tx();
1005        builder.spawn_blocking_on(
1006            move || {
1007                tx2.blocking_send(Ok(RecordBatch::new_empty(Arc::new(Schema::empty()))))
1008                    .unwrap();
1009
1010                Ok(())
1011            },
1012            tokio_runtime.handle(),
1013        );
1014
1015        let mut stream = builder.build();
1016
1017        let mut number_of_batches = 0;
1018
1019        loop {
1020            let poll = stream.poll_next_unpin(&mut Context::from_waker(
1021                futures::task::noop_waker_ref(),
1022            ));
1023
1024            match poll {
1025                Poll::Ready(None) => {
1026                    break;
1027                }
1028                Poll::Ready(Some(Ok(batch))) => {
1029                    number_of_batches += 1;
1030                    assert_eq!(batch.num_rows(), 0);
1031                }
1032                Poll::Ready(Some(Err(e))) => panic!("Unexpected error: {e}"),
1033                Poll::Pending => {
1034                    continue;
1035                }
1036            }
1037        }
1038
1039        assert_eq!(
1040            number_of_batches, 2,
1041            "Should have received exactly two empty batches"
1042        );
1043    }
1044
1045    #[tokio::test]
1046    async fn test_reservation_stream_shrinks_on_poll() {
1047        use arrow::array::Int32Array;
1048        use datafusion_execution::memory_pool::MemoryConsumer;
1049        use datafusion_execution::runtime_env::RuntimeEnvBuilder;
1050
1051        let runtime = RuntimeEnvBuilder::new()
1052            .with_memory_limit(10 * 1024 * 1024, 1.0)
1053            .build_arc()
1054            .unwrap();
1055
1056        let reservation = MemoryConsumer::new("test").register(&runtime.memory_pool);
1057
1058        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
1059
1060        // Create batches
1061        let batch1 = RecordBatch::try_new(
1062            Arc::clone(&schema),
1063            vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))],
1064        )
1065        .unwrap();
1066        let batch2 = RecordBatch::try_new(
1067            Arc::clone(&schema),
1068            vec![Arc::new(Int32Array::from(vec![6, 7, 8, 9, 10]))],
1069        )
1070        .unwrap();
1071
1072        let batch1_size = get_record_batch_memory_size(&batch1);
1073        let batch2_size = get_record_batch_memory_size(&batch2);
1074
1075        // Reserve memory upfront
1076        reservation.try_grow(batch1_size + batch2_size).unwrap();
1077        let initial_reserved = runtime.memory_pool.reserved();
1078        assert_eq!(initial_reserved, batch1_size + batch2_size);
1079
1080        // Create stream with batches
1081        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);
1082        let inner = Box::pin(RecordBatchStreamAdapter::new(Arc::clone(&schema), stream))
1083            as SendableRecordBatchStream;
1084
1085        let mut res_stream =
1086            ReservationStream::new(Arc::clone(&schema), inner, reservation);
1087
1088        // Poll first batch
1089        let result1 = res_stream.next().await;
1090        assert!(result1.is_some());
1091
1092        // Memory should be reduced by batch1_size
1093        let after_first = runtime.memory_pool.reserved();
1094        assert_eq!(after_first, batch2_size);
1095
1096        // Poll second batch
1097        let result2 = res_stream.next().await;
1098        assert!(result2.is_some());
1099
1100        // Memory should be reduced by batch2_size
1101        let after_second = runtime.memory_pool.reserved();
1102        assert_eq!(after_second, 0);
1103
1104        // Poll None (end of stream)
1105        let result3 = res_stream.next().await;
1106        assert!(result3.is_none());
1107
1108        // Memory should still be 0
1109        assert_eq!(runtime.memory_pool.reserved(), 0);
1110    }
1111
1112    #[tokio::test]
1113    async fn test_reservation_stream_error_handling() {
1114        use datafusion_execution::memory_pool::MemoryConsumer;
1115        use datafusion_execution::runtime_env::RuntimeEnvBuilder;
1116
1117        let runtime = RuntimeEnvBuilder::new()
1118            .with_memory_limit(10 * 1024 * 1024, 1.0)
1119            .build_arc()
1120            .unwrap();
1121
1122        let reservation = MemoryConsumer::new("test").register(&runtime.memory_pool);
1123
1124        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
1125
1126        reservation.try_grow(1000).unwrap();
1127        let initial = runtime.memory_pool.reserved();
1128        assert_eq!(initial, 1000);
1129
1130        // Create a stream that errors
1131        let stream = futures::stream::iter(vec![exec_err!("Test error")]);
1132        let inner = Box::pin(RecordBatchStreamAdapter::new(Arc::clone(&schema), stream))
1133            as SendableRecordBatchStream;
1134
1135        let mut res_stream =
1136            ReservationStream::new(Arc::clone(&schema), inner, reservation);
1137
1138        // Get the error
1139        let result = res_stream.next().await;
1140        assert!(result.is_some());
1141        assert!(result.unwrap().is_err());
1142
1143        // Verify reservation is NOT automatically freed on error
1144        // The reservation is only freed when poll_next returns Poll::Ready(None)
1145        // After an error, the stream may continue to hold the reservation
1146        // until it's explicitly dropped or polled to None
1147        let after_error = runtime.memory_pool.reserved();
1148        assert_eq!(
1149            after_error, 1000,
1150            "Reservation should still be held after error"
1151        );
1152
1153        // Drop the stream to free the reservation
1154        drop(res_stream);
1155
1156        // Now memory should be freed
1157        assert_eq!(
1158            runtime.memory_pool.reserved(),
1159            0,
1160            "Memory should be freed when stream is dropped"
1161        );
1162    }
1163}