datafusion_physical_plan/stream.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Stream wrappers for physical operators

use std::pin::Pin;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;

#[cfg(test)]
use super::metrics::ExecutionPlanMetricsSet;
use super::metrics::{BaselineMetrics, SplitMetrics};
use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};
use crate::displayable;

use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
use datafusion_common::{exec_err, Result};
use datafusion_common_runtime::JoinSet;
use datafusion_execution::TaskContext;

use futures::ready;
use futures::stream::BoxStream;
use futures::{Future, Stream, StreamExt};
use log::debug;
use pin_project_lite::pin_project;
use tokio::runtime::Handle;
use tokio::sync::mpsc::{Receiver, Sender};

/// Creates a stream from a collection of producing tasks, routing panics to the stream.
///
/// Note that this is similar to [`ReceiverStream` from tokio-stream], with the differences being:
///
/// 1. Methods to bound and "detach" tasks (`spawn()` and `spawn_blocking()`).
///
/// 2. Propagates panics, whereas the `tokio` version doesn't propagate panics to the receiver.
///
/// 3. Automatically cancels any outstanding tasks when the receiver stream is dropped.
///
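/// See [`RecordBatchReceiverStreamBuilder`] for an example of the typical
/// usage pattern.
///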
/// [`ReceiverStream` from tokio-stream]: https://docs.rs/tokio-stream/latest/tokio_stream/wrappers/struct.ReceiverStream.html
pub(crate) struct ReceiverStreamBuilder<O> {
    tx: Sender<Result<O>>,
    rx: Receiver<Result<O>>,
    join_set: JoinSet<Result<()>>,
}

impl<O: Send + 'static> ReceiverStreamBuilder<O> {
    /// Create new channels with the specified buffer size
    pub fn new(capacity: usize) -> Self {
        let (tx, rx) = tokio::sync::mpsc::channel(capacity);

        Self {
            tx,
            rx,
            join_set: JoinSet::new(),
        }
    }

    /// Get a handle for sending data to the output
    pub fn tx(&self) -> Sender<Result<O>> {
        self.tx.clone()
    }

    /// Spawn a task that will be aborted if this builder (or the stream
    /// built from it) is dropped
    pub fn spawn<F>(&mut self, task: F)
    where
        F: Future<Output = Result<()>>,
        F: Send + 'static,
    {
        self.join_set.spawn(task);
    }

    /// Same as [`Self::spawn`] but it spawns the task on the provided runtime
    pub fn spawn_on<F>(&mut self, task: F, handle: &Handle)
    where
        F: Future<Output = Result<()>>,
        F: Send + 'static,
    {
        self.join_set.spawn_on(task, handle);
    }

    /// Spawn a blocking task that will be aborted if this builder (or the stream
    /// built from it) is dropped.
    ///
    /// This is often used to spawn tasks that write to the sender
    /// retrieved from `Self::tx`.
    pub fn spawn_blocking<F>(&mut self, f: F)
    where
        F: FnOnce() -> Result<()>,
        F: Send + 'static,
    {
        self.join_set.spawn_blocking(f);
    }

    /// Same as [`Self::spawn_blocking`] but it spawns the blocking task on the provided runtime
    pub fn spawn_blocking_on<F>(&mut self, f: F, handle: &Handle)
    where
        F: FnOnce() -> Result<()>,
        F: Send + 'static,
    {
        self.join_set.spawn_blocking_on(f, handle);
    }

    /// Create a stream of all data written to `tx`
    pub fn build(self) -> BoxStream<'static, Result<O>> {
        let Self {
            tx,
            rx,
            mut join_set,
        } = self;

        // The builder's copy of the sender is not needed: drop it so the
        // channel closes once all spawned tasks (and their senders) finish
        drop(tx);

        // Future that checks the result of the join set, and propagates panic if seen
        let check = async move {
            while let Some(result) = join_set.join_next().await {
                match result {
                    Ok(task_result) => {
                        match task_result {
                            // Nothing to report
                            Ok(_) => continue,
                            // The task itself returned an error; forward it
                            Err(error) => return Some(Err(error)),
                        }
                    }
                    // This means a tokio task error, likely a panic
                    Err(e) => {
                        if e.is_panic() {
                            // Resume the panic in the thread polling the stream
                            std::panic::resume_unwind(e.into_panic());
                        } else {
                            // This should only occur if the task is
                            // cancelled, which would only occur if
                            // the JoinSet were aborted, which in turn
                            // would imply that the receiver has been
                            // dropped and this code is not running
                            return Some(exec_err!("Non Panic Task error: {e}"));
                        }
                    }
                }
            }
            None
        };

        let check_stream = futures::stream::once(check)
            // unwrap Option / only return the error
            .filter_map(|item| async move { item });

        // Convert the receiver into a stream
        let rx_stream = futures::stream::unfold(rx, |mut rx| async move {
            let next_item = rx.recv().await;
            next_item.map(|next_item| (next_item, rx))
        });

        // Merge the streams together so whichever is ready first
        // produces the batch
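        // Note: `select` terminates only when both inputs are exhausted;
        // `check_stream` completes once all spawned tasks have finished,
        // yielding at most one error item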
        futures::stream::select(rx_stream, check_stream).boxed()
    }
}

/// Builder for `RecordBatchReceiverStream` that propagates errors
/// and panics correctly.
///
/// [`RecordBatchReceiverStreamBuilder`] is used to spawn one or more tasks
/// that produce [`RecordBatch`]es and send them to a single
/// `Receiver`, which can improve parallelism.
///
/// This also handles propagating panics and canceling the tasks.
///
/// # Example
///
/// The following example spawns 2 tasks that will write [`RecordBatch`]es to
/// the `tx` end of the builder. After building the stream, we can receive
/// those batches by calling `.next()`:
///
/// ```
/// # use std::sync::Arc;
/// # use datafusion_common::arrow::datatypes::{Schema, Field, DataType};
/// # use datafusion_common::arrow::array::RecordBatch;
/// # use datafusion_physical_plan::stream::RecordBatchReceiverStreamBuilder;
/// # use futures::stream::StreamExt;
/// # use tokio::runtime::Builder;
/// # let rt = Builder::new_current_thread().build().unwrap();
/// #
/// # rt.block_on(async {
/// let schema = Arc::new(Schema::new(vec![Field::new("foo", DataType::Int8, false)]));
/// let mut builder = RecordBatchReceiverStreamBuilder::new(Arc::clone(&schema), 10);
///
/// // task 1
/// let tx_1 = builder.tx();
/// let schema_1 = Arc::clone(&schema);
/// builder.spawn(async move {
///     // Your task needs to send batches to the tx
///     tx_1.send(Ok(RecordBatch::new_empty(schema_1)))
///         .await
///         .unwrap();
///
///     Ok(())
/// });
///
/// // task 2
/// let tx_2 = builder.tx();
/// let schema_2 = Arc::clone(&schema);
/// builder.spawn(async move {
///     // Your task needs to send batches to the tx
///     tx_2.send(Ok(RecordBatch::new_empty(schema_2)))
///         .await
///         .unwrap();
///
///     Ok(())
/// });
///
/// let mut stream = builder.build();
/// while let Some(res_batch) = stream.next().await {
///     // `res_batch` can come from either task 1 or 2
///
///     // do something with `res_batch`
/// }
/// # });
/// ```
pub struct RecordBatchReceiverStreamBuilder {
    schema: SchemaRef,
    inner: ReceiverStreamBuilder<RecordBatch>,
}

impl RecordBatchReceiverStreamBuilder {
    /// Create new channels with the specified buffer size
    pub fn new(schema: SchemaRef, capacity: usize) -> Self {
        Self {
            schema,
            inner: ReceiverStreamBuilder::new(capacity),
        }
    }

    /// Get a handle for sending [`RecordBatch`] to the output
    ///
    /// If the stream is dropped / canceled, the sender will be closed and
    /// calling `tx().send()` will return an error. Producers should stop
    /// producing in this case and return control.
    pub fn tx(&self) -> Sender<Result<RecordBatch>> {
        self.inner.tx()
    }

    /// Spawn a task that will be aborted if this builder (or the stream
    /// built from it) is dropped
    ///
    /// This is often used to spawn tasks that write to the sender
    /// retrieved from [`Self::tx`]; for an example, see the documentation
    /// of this type.
    pub fn spawn<F>(&mut self, task: F)
    where
        F: Future<Output = Result<()>>,
        F: Send + 'static,
    {
        self.inner.spawn(task)
    }

    /// Same as [`Self::spawn`] but it spawns the task on the provided runtime.
    pub fn spawn_on<F>(&mut self, task: F, handle: &Handle)
    where
        F: Future<Output = Result<()>>,
        F: Send + 'static,
    {
        self.inner.spawn_on(task, handle)
    }

    /// Spawn a blocking task tied to the builder and stream.
    ///
    /// # Drop / Cancel Behavior
    ///
    /// If this builder (or the stream built from it) is dropped **before** the
    /// task starts, the task is also dropped and will never start executing.
    ///
    /// **Note:** Once the blocking task has started, it **will not** be
    /// forcibly stopped on drop, as Rust does not allow forcing a running thread
    /// to terminate. The task will continue running until it completes or
    /// encounters an error.
    ///
    /// Users should ensure that their blocking function periodically checks the
    /// result of `tx.blocking_send`. An error signals that the stream has
    /// been dropped / cancelled and the blocking task should exit.
    ///
    /// This is often used to spawn tasks that write to the sender
    /// retrieved from [`Self::tx`]; for an example, see the documentation
    /// of this type.
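    ///
    /// # Example
    ///
    /// A minimal sketch of a blocking producer that stops once the stream is
    /// dropped (the empty batch stands in for real work):
    ///
    /// ```
    /// # use std::sync::Arc;
    /// # use datafusion_common::arrow::datatypes::{Schema, Field, DataType};
    /// # use datafusion_common::arrow::array::RecordBatch;
    /// # use datafusion_physical_plan::stream::RecordBatchReceiverStreamBuilder;
    /// # use futures::stream::StreamExt;
    /// # let rt = tokio::runtime::Builder::new_current_thread().build().unwrap();
    /// # rt.block_on(async {
    /// let schema = Arc::new(Schema::new(vec![Field::new("foo", DataType::Int8, false)]));
    /// let mut builder = RecordBatchReceiverStreamBuilder::new(Arc::clone(&schema), 2);
    ///
    /// let tx = builder.tx();
    /// builder.spawn_blocking(move || {
    ///     loop {
    ///         // A real task would produce data here (e.g. read a file)
    ///         let batch = RecordBatch::new_empty(Arc::clone(&schema));
    ///         // An error from `blocking_send` means the stream was dropped,
    ///         // so stop producing and return
    ///         if tx.blocking_send(Ok(batch)).is_err() {
    ///             return Ok(());
    ///         }
    ///     }
    /// });
    ///
    /// let mut stream = builder.build();
    /// // Receive one batch, then cancel the producer by dropping the stream
    /// let batch = stream.next().await.unwrap().unwrap();
    /// assert_eq!(batch.num_rows(), 0);
    /// drop(stream);
    /// # });
    /// ```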
    pub fn spawn_blocking<F>(&mut self, f: F)
    where
        F: FnOnce() -> Result<()>,
        F: Send + 'static,
    {
        self.inner.spawn_blocking(f)
    }

    /// Same as [`Self::spawn_blocking`] but it spawns the blocking task on the provided runtime.
    pub fn spawn_blocking_on<F>(&mut self, f: F, handle: &Handle)
    where
        F: FnOnce() -> Result<()>,
        F: Send + 'static,
    {
        self.inner.spawn_blocking_on(f, handle)
    }

    /// Runs the `partition` of the `input` ExecutionPlan on the
    /// tokio thread pool and writes its outputs to this stream
    ///
    /// If the input partition produces an error, the error will be
    /// sent to the output stream and no further results are sent.
    pub(crate) fn run_input(
        &mut self,
        input: Arc<dyn ExecutionPlan>,
        partition: usize,
        context: Arc<TaskContext>,
    ) {
        let output = self.tx();

        self.inner.spawn(async move {
            let mut stream = match input.execute(partition, context) {
                Err(e) => {
                    // If send fails, the plan is being torn down; there
                    // is no place to send the error and no reason to continue.
                    output.send(Err(e)).await.ok();
                    debug!(
                        "Stopping execution: error executing input: {}",
                        displayable(input.as_ref()).one_line()
                    );
                    return Ok(());
                }
                Ok(stream) => stream,
            };

            // Transfer batches from inner stream to the output tx
            // immediately.
            while let Some(item) = stream.next().await {
                let is_err = item.is_err();

                // If send fails, the plan is being torn down; there is no
                // place to send the error and no reason to continue.
                if output.send(item).await.is_err() {
                    debug!(
                        "Stopping execution: output is gone, plan cancelling: {}",
                        displayable(input.as_ref()).one_line()
                    );
                    return Ok(());
                }

                // Stop after the first error is encountered (Don't
                // drive all streams to completion)
                if is_err {
                    debug!(
                        "Stopping execution: plan returned error: {}",
                        displayable(input.as_ref()).one_line()
                    );
                    return Ok(());
                }
            }

            Ok(())
        });
    }

    /// Create a stream of all [`RecordBatch`] written to `tx`
    pub fn build(self) -> SendableRecordBatchStream {
        Box::pin(RecordBatchStreamAdapter::new(
            self.schema,
            self.inner.build(),
        ))
    }
}

#[doc(hidden)]
pub struct RecordBatchReceiverStream {}

impl RecordBatchReceiverStream {
    /// Create a builder with an internal buffer of `capacity` batches.
    pub fn builder(
        schema: SchemaRef,
        capacity: usize,
    ) -> RecordBatchReceiverStreamBuilder {
        RecordBatchReceiverStreamBuilder::new(schema, capacity)
    }
}

pin_project! {
    /// Combines a [`Stream`] with a [`SchemaRef`] implementing
    /// [`SendableRecordBatchStream`] for the combination
    ///
    /// See [`Self::new`] for an example
    pub struct RecordBatchStreamAdapter<S> {
        schema: SchemaRef,

        #[pin]
        stream: S,
    }
}

impl<S> RecordBatchStreamAdapter<S> {
    /// Creates a new [`RecordBatchStreamAdapter`] from the provided schema and stream.
    ///
    /// Note: to create a [`SendableRecordBatchStream`], pin the result.
    ///
    /// # Example
    /// ```
    /// # use arrow::array::record_batch;
    /// # use datafusion_execution::SendableRecordBatchStream;
    /// # use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
    /// // Create stream of Result<RecordBatch>
    /// let batch = record_batch!(
    ///     ("a", Int32, [1, 2, 3]),
    ///     ("b", Float64, [Some(4.0), None, Some(5.0)])
    /// )
    /// .expect("created batch");
    /// let schema = batch.schema();
    /// let stream = futures::stream::iter(vec![Ok(batch)]);
    /// // Convert the stream to a SendableRecordBatchStream
    /// let adapter = RecordBatchStreamAdapter::new(schema, stream);
    /// // Now you can use the adapter as a SendableRecordBatchStream
    /// let batch_stream: SendableRecordBatchStream = Box::pin(adapter);
    /// // ...
    /// ```
    pub fn new(schema: SchemaRef, stream: S) -> Self {
        Self { schema, stream }
    }
}

impl<S> std::fmt::Debug for RecordBatchStreamAdapter<S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RecordBatchStreamAdapter")
            .field("schema", &self.schema)
            .finish()
    }
}

impl<S> Stream for RecordBatchStreamAdapter<S>
where
    S: Stream<Item = Result<RecordBatch>>,
{
    type Item = Result<RecordBatch>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.project().stream.poll_next(cx)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.stream.size_hint()
    }
}

impl<S> RecordBatchStream for RecordBatchStreamAdapter<S>
where
    S: Stream<Item = Result<RecordBatch>>,
{
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

/// `EmptyRecordBatchStream` can be used to create a [`RecordBatchStream`]
/// that will produce no results
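///
/// # Example
///
/// A minimal usage sketch:
///
/// ```
/// # use std::sync::Arc;
/// # use arrow::datatypes::Schema;
/// # use futures::StreamExt;
/// # use datafusion_physical_plan::stream::EmptyRecordBatchStream;
/// # let rt = tokio::runtime::Builder::new_current_thread().build().unwrap();
/// # rt.block_on(async {
/// let mut stream = EmptyRecordBatchStream::new(Arc::new(Schema::empty()));
/// // The stream completes immediately without yielding any batches
/// assert!(stream.next().await.is_none());
/// # });
/// ```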
pub struct EmptyRecordBatchStream {
    /// Schema wrapped by Arc
    schema: SchemaRef,
}

impl EmptyRecordBatchStream {
    /// Create an empty RecordBatchStream
    pub fn new(schema: SchemaRef) -> Self {
        Self { schema }
    }
}

impl RecordBatchStream for EmptyRecordBatchStream {
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

impl Stream for EmptyRecordBatchStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        Poll::Ready(None)
    }
}

/// Stream wrapper that records `BaselineMetrics` for a particular
/// [`SendableRecordBatchStream`] (likely a partition)
pub(crate) struct ObservedStream {
    inner: SendableRecordBatchStream,
    baseline_metrics: BaselineMetrics,
    fetch: Option<usize>,
    produced: usize,
}

impl ObservedStream {
    pub fn new(
        inner: SendableRecordBatchStream,
        baseline_metrics: BaselineMetrics,
        fetch: Option<usize>,
    ) -> Self {
        Self {
            inner,
            baseline_metrics,
            fetch,
            produced: 0,
        }
    }

    fn limit_reached(
        &mut self,
        poll: Poll<Option<Result<RecordBatch>>>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        let Some(fetch) = self.fetch else { return poll };

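        // The limit has already been reached: end the stream regardless of
        // what the inner stream just returned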
        if self.produced >= fetch {
            return Poll::Ready(None);
        }

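        // Otherwise, count the rows produced, slicing the batch down if it
        // would take the total past the `fetch` limit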
        if let Poll::Ready(Some(Ok(batch))) = &poll {
            if self.produced + batch.num_rows() > fetch {
                let batch = batch.slice(0, fetch.saturating_sub(self.produced));
                self.produced += batch.num_rows();
                return Poll::Ready(Some(Ok(batch)));
            };
            self.produced += batch.num_rows()
        }
        poll
    }
}

impl RecordBatchStream for ObservedStream {
    fn schema(&self) -> SchemaRef {
        self.inner.schema()
    }
}

impl Stream for ObservedStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let mut poll = self.inner.poll_next_unpin(cx);
        if self.fetch.is_some() {
            poll = self.limit_reached(poll);
        }
        self.baseline_metrics.record_poll(poll)
    }
}

pin_project! {
    /// Stream wrapper that splits large [`RecordBatch`]es into smaller batches.
    ///
    /// This ensures downstream operators receive batches no larger than
    /// `batch_size`, which can improve parallelism when data sources
    /// generate very large batches.
    ///
    /// # Fields
    ///
    /// - `current_batch`: The batch currently being split, if any
    /// - `offset`: Index of the next row to split from `current_batch`.
    ///   This tracks our position within the current batch being split.
    ///
    /// # Invariants
    ///
    /// - `offset` is always ≤ `current_batch.num_rows()` when `current_batch` is `Some`
    /// - When `current_batch` is `None`, `offset` is always 0
    /// - `batch_size` is always > 0
    pub struct BatchSplitStream {
        #[pin]
        input: SendableRecordBatchStream,
        schema: SchemaRef,
        batch_size: usize,
        metrics: SplitMetrics,
        current_batch: Option<RecordBatch>,
        offset: usize,
    }
}

impl BatchSplitStream {
    /// Create a new [`BatchSplitStream`]
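    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `SplitMetrics` and `ExecutionPlanMetricsSet`
    /// are publicly reachable via `datafusion_physical_plan::metrics`:
    ///
    /// ```
    /// # use std::sync::Arc;
    /// # use arrow::array::{Int32Array, RecordBatch};
    /// # use arrow::datatypes::{DataType, Field, Schema};
    /// # use futures::StreamExt;
    /// # use datafusion_execution::SendableRecordBatchStream;
    /// # use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, SplitMetrics};
    /// # use datafusion_physical_plan::stream::{BatchSplitStream, RecordBatchStreamAdapter};
    /// # let rt = tokio::runtime::Builder::new_current_thread().build().unwrap();
    /// # rt.block_on(async {
    /// let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    /// let batch = RecordBatch::try_new(
    ///     Arc::clone(&schema),
    ///     vec![Arc::new(Int32Array::from((0..10).collect::<Vec<_>>()))],
    /// )
    /// .unwrap();
    /// let input: SendableRecordBatchStream = Box::pin(RecordBatchStreamAdapter::new(
    ///     Arc::clone(&schema),
    ///     futures::stream::iter(vec![Ok(batch)]),
    /// ));
    /// // Split the 10-row input into batches of at most 4 rows
    /// let metrics = ExecutionPlanMetricsSet::new();
    /// let mut stream = BatchSplitStream::new(input, 4, SplitMetrics::new(&metrics, 0));
    /// let mut sizes = vec![];
    /// while let Some(batch) = stream.next().await {
    ///     sizes.push(batch.unwrap().num_rows());
    /// }
    /// assert_eq!(sizes, vec![4, 4, 2]);
    /// # });
    /// ```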
    pub fn new(
        input: SendableRecordBatchStream,
        batch_size: usize,
        metrics: SplitMetrics,
    ) -> Self {
        let schema = input.schema();
        Self {
            input,
            schema,
            batch_size,
            metrics,
            current_batch: None,
            offset: 0,
        }
    }

    /// Attempt to produce the next sliced batch from the current batch.
    ///
    /// Returns `Some(batch)` if a slice was produced, `None` if the current batch
    /// is exhausted and we need to poll upstream for more data.
    fn next_sliced_batch(&mut self) -> Option<Result<RecordBatch>> {
        let batch = self.current_batch.take()?;

        // Assert slice boundary safety - offset should never exceed batch size
        debug_assert!(
            self.offset <= batch.num_rows(),
            "Offset {} exceeds batch size {}",
            self.offset,
            batch.num_rows()
        );

        let remaining = batch.num_rows() - self.offset;
        let to_take = remaining.min(self.batch_size);
        let out = batch.slice(self.offset, to_take);

        self.metrics.batches_split.add(1);
        self.offset += to_take;
        if self.offset < batch.num_rows() {
            // More data remains in this batch, store it back
            self.current_batch = Some(batch);
        } else {
            // Batch is exhausted, reset offset
            // Note: current_batch is already None since we took it at the start
            self.offset = 0;
        }
        Some(Ok(out))
    }

    /// Poll the upstream input for the next batch.
    ///
    /// Returns the appropriate `Poll` result based on upstream state.
    /// Small batches are passed through directly; large batches are stored
    /// for slicing, and the first slice is returned immediately.
    fn poll_upstream(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        match ready!(self.input.as_mut().poll_next(cx)) {
            Some(Ok(batch)) => {
                if batch.num_rows() <= self.batch_size {
                    // Small batch, pass through directly
                    Poll::Ready(Some(Ok(batch)))
                } else {
                    // Large batch, store for slicing and return first slice
                    self.current_batch = Some(batch);
                    // Immediately produce the first slice
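                    // `next_sliced_batch` always returns `Some` here:
                    // `current_batch` was just set to a non-empty batch with
                    // `offset` 0, so at least one slice can be produced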
                    match self.next_sliced_batch() {
                        Some(result) => Poll::Ready(Some(result)),
                        None => Poll::Ready(None), // Should not happen
                    }
                }
            }
            Some(Err(e)) => Poll::Ready(Some(Err(e))),
            None => Poll::Ready(None),
        }
    }
}

impl Stream for BatchSplitStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // First, try to produce a slice from the current batch
        if let Some(result) = self.next_sliced_batch() {
            return Poll::Ready(Some(result));
        }

        // No current batch or current batch exhausted, poll upstream
        self.poll_upstream(cx)
    }
}

impl RecordBatchStream for BatchSplitStream {
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::test::exec::{
        assert_strong_count_converges_to_zero, BlockingExec, MockExec, PanicExec,
    };

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::exec_err;

    fn schema() -> SchemaRef {
        Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]))
    }

    #[tokio::test]
    #[should_panic(expected = "PanickingStream did panic")]
    async fn record_batch_receiver_stream_propagates_panics() {
        let schema = schema();

        let num_partitions = 10;
        let input = PanicExec::new(Arc::clone(&schema), num_partitions);
        consume(input, 10).await
    }

    #[tokio::test]
    #[should_panic(expected = "PanickingStream did panic: 1")]
    async fn record_batch_receiver_stream_propagates_panics_early_shutdown() {
        let schema = schema();

        // Make 2 partitions, second partition panics before the first
        let num_partitions = 2;
        let input = PanicExec::new(Arc::clone(&schema), num_partitions)
            .with_partition_panic(0, 10)
            .with_partition_panic(1, 3); // partition 1 should panic first (after 3 batches)

        // Ensure that the panic results in an early shutdown (that
        // everything stops after the first panic).

        // The stream reads batches from the partitions alternately: (0,1,0,1,0,panic),
        // so we should see no more than 5 batches before the panic
        let max_batches = 5;
        consume(input, max_batches).await
    }

    #[tokio::test]
    async fn record_batch_receiver_stream_drop_cancel() {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = schema();

        // Make an input that never proceeds
        let input = BlockingExec::new(Arc::clone(&schema), 1);
        let refs = input.refs();

        // Configure a RecordBatchReceiverStream to consume the input
        let mut builder = RecordBatchReceiverStream::builder(schema, 2);
        builder.run_input(Arc::new(input), 0, Arc::clone(&task_ctx));
        let stream = builder.build();

        // Input should still be present
        assert!(std::sync::Weak::strong_count(&refs) > 0);

        // Drop the stream, ensure the refs go to zero
        drop(stream);
        assert_strong_count_converges_to_zero(refs).await;
    }

    #[tokio::test]
    /// Ensure that if an error is received in one stream, the
    /// `RecordBatchReceiverStream` stops early and does not drive
    /// other streams to completion.
    async fn record_batch_receiver_stream_error_does_not_drive_completion() {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = schema();

        // make an input that will error twice
        let error_stream = MockExec::new(
            vec![exec_err!("Test1"), exec_err!("Test2")],
            Arc::clone(&schema),
        )
        .with_use_task(false);

        let mut builder = RecordBatchReceiverStream::builder(schema, 2);
        builder.run_input(Arc::new(error_stream), 0, Arc::clone(&task_ctx));
        let mut stream = builder.build();

        // Get the first result, which should be an error
        let first_batch = stream.next().await.unwrap();
        let first_err = first_batch.unwrap_err();
        assert_eq!(first_err.strip_backtrace(), "Execution error: Test1");

        // There should be no more batches produced (should not get the second error)
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn batch_split_stream_basic_functionality() {
        use arrow::array::{Int32Array, RecordBatch};
        use futures::stream::{self, StreamExt};

        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

        // Create a large batch that should be split
        let large_batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int32Array::from((0..2000).collect::<Vec<_>>()))],
        )
        .unwrap();

        // Create a stream with the large batch
        let input_stream = stream::iter(vec![Ok(large_batch)]);
        let adapter = RecordBatchStreamAdapter::new(Arc::clone(&schema), input_stream);
        let batch_stream = Box::pin(adapter) as SendableRecordBatchStream;

        // Create a BatchSplitStream with batch_size = 500
        let metrics = ExecutionPlanMetricsSet::new();
        let split_metrics = SplitMetrics::new(&metrics, 0);
        let mut split_stream = BatchSplitStream::new(batch_stream, 500, split_metrics);

        let mut total_rows = 0;
        let mut batch_count = 0;

        while let Some(result) = split_stream.next().await {
            let batch = result.unwrap();
            assert!(batch.num_rows() <= 500, "Batch size should not exceed 500");
            total_rows += batch.num_rows();
            batch_count += 1;
        }

        assert_eq!(total_rows, 2000, "All rows should be preserved");
        assert_eq!(batch_count, 4, "Should have 4 batches of 500 rows each");
    }

    /// Consumes all the input's partitions into a
    /// RecordBatchReceiverStream and runs it to completion
    ///
    /// Panics if more than `max_batches` batches are seen
    async fn consume(input: PanicExec, max_batches: usize) {
        let task_ctx = Arc::new(TaskContext::default());

        let input = Arc::new(input);
        let num_partitions = input.properties().output_partitioning().partition_count();

        // Configure a RecordBatchReceiverStream to consume all the input partitions
        let mut builder =
            RecordBatchReceiverStream::builder(input.schema(), num_partitions);
        for partition in 0..num_partitions {
            builder.run_input(
                Arc::clone(&input) as Arc<dyn ExecutionPlan>,
                partition,
                Arc::clone(&task_ctx),
            );
        }
        let mut stream = builder.build();

        // Drain the stream until it is complete, panicking on error
        let mut num_batches = 0;
        while let Some(next) = stream.next().await {
            next.unwrap();
            num_batches += 1;
            assert!(
                num_batches < max_batches,
                "Got the limit of {num_batches} batches before seeing panic"
            );
        }
    }

    #[test]
    fn record_batch_receiver_stream_builder_spawn_on_runtime() {
        let tokio_runtime = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .unwrap();

        let mut builder =
            RecordBatchReceiverStreamBuilder::new(Arc::new(Schema::empty()), 10);

        let tx1 = builder.tx();
        builder.spawn_on(
            async move {
                tx1.send(Ok(RecordBatch::new_empty(Arc::new(Schema::empty()))))
                    .await
                    .unwrap();

                Ok(())
            },
            tokio_runtime.handle(),
        );

        let tx2 = builder.tx();
        builder.spawn_blocking_on(
            move || {
                tx2.blocking_send(Ok(RecordBatch::new_empty(Arc::new(Schema::empty()))))
                    .unwrap();

                Ok(())
            },
            tokio_runtime.handle(),
        );

        let mut stream = builder.build();

        let mut number_of_batches = 0;

        loop {
            let poll = stream.poll_next_unpin(&mut Context::from_waker(
                futures::task::noop_waker_ref(),
            ));

            match poll {
                Poll::Ready(None) => {
                    break;
                }
                Poll::Ready(Some(Ok(batch))) => {
                    number_of_batches += 1;
                    assert_eq!(batch.num_rows(), 0);
                }
                Poll::Ready(Some(Err(e))) => panic!("Unexpected error: {e}"),
                Poll::Pending => {
                    continue;
                }
            }
        }

        assert_eq!(
            number_of_batches, 2,
            "Should have received exactly two empty batches"
        );
    }
}