datafusion_physical_plan/stream.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Stream wrappers for physical operators

use std::pin::Pin;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;

#[cfg(test)]
use super::metrics::ExecutionPlanMetricsSet;
use super::metrics::{BaselineMetrics, SplitMetrics};
use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};
use crate::displayable;

use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
use datafusion_common::{exec_err, Result};
use datafusion_common_runtime::JoinSet;
use datafusion_execution::TaskContext;

use futures::ready;
use futures::stream::BoxStream;
use futures::{Future, Stream, StreamExt};
use log::debug;
use pin_project_lite::pin_project;
use tokio::runtime::Handle;
use tokio::sync::mpsc::{Receiver, Sender};

/// Creates a stream from a collection of producing tasks, routing panics to the stream.
///
/// Note that this is similar to [`ReceiverStream` from tokio-stream], with the differences being:
///
/// 1. Methods to bound and "detach" tasks (`spawn()` and `spawn_blocking()`).
///
/// 2. Propagates panics, whereas the `tokio` version doesn't propagate panics to the receiver.
///
/// 3. Automatically cancels any outstanding tasks when the receiver stream is dropped.
///
/// [`ReceiverStream` from tokio-stream]: https://docs.rs/tokio-stream/latest/tokio_stream/wrappers/struct.ReceiverStream.html
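///
/// # Example (sketch)
///
/// A minimal sketch of the intended usage; this type is crate-private, so the
/// example is illustrative rather than a compiled doctest:
///
/// ```ignore
/// let mut builder = ReceiverStreamBuilder::<u64>::new(2);
/// let tx = builder.tx();
/// builder.spawn(async move {
///     // A send error means the stream was dropped; stop producing
///     tx.send(Ok(42)).await.ok();
///     Ok(())
/// });
/// // The stream yields `Ok(42)` and then completes; a panic in the task
/// // would instead be resumed in the thread polling the stream
/// let stream = builder.build();
/// ```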
pub(crate) struct ReceiverStreamBuilder<O> {
    tx: Sender<Result<O>>,
    rx: Receiver<Result<O>>,
    join_set: JoinSet<Result<()>>,
}

impl<O: Send + 'static> ReceiverStreamBuilder<O> {
    /// Create new channels with the specified buffer size
    pub fn new(capacity: usize) -> Self {
        let (tx, rx) = tokio::sync::mpsc::channel(capacity);

        Self {
            tx,
            rx,
            join_set: JoinSet::new(),
        }
    }

    /// Get a handle for sending data to the output
    pub fn tx(&self) -> Sender<Result<O>> {
        self.tx.clone()
    }

    /// Spawn a task that will be aborted if this builder (or the stream
    /// built from it) is dropped
    pub fn spawn<F>(&mut self, task: F)
    where
        F: Future<Output = Result<()>>,
        F: Send + 'static,
    {
        self.join_set.spawn(task);
    }

    /// Same as [`Self::spawn`] but it spawns the task on the provided runtime
    pub fn spawn_on<F>(&mut self, task: F, handle: &Handle)
    where
        F: Future<Output = Result<()>>,
        F: Send + 'static,
    {
        self.join_set.spawn_on(task, handle);
    }

    /// Spawn a blocking task that will be aborted if this builder (or the stream
    /// built from it) is dropped.
    ///
    /// This is often used to spawn tasks that write to the sender
    /// retrieved from [`Self::tx`].
    pub fn spawn_blocking<F>(&mut self, f: F)
    where
        F: FnOnce() -> Result<()>,
        F: Send + 'static,
    {
        self.join_set.spawn_blocking(f);
    }

    /// Same as [`Self::spawn_blocking`] but it spawns the blocking task on the provided runtime
    pub fn spawn_blocking_on<F>(&mut self, f: F, handle: &Handle)
    where
        F: FnOnce() -> Result<()>,
        F: Send + 'static,
    {
        self.join_set.spawn_blocking_on(f, handle);
    }

    /// Create a stream of all data written to `tx`
    pub fn build(self) -> BoxStream<'static, Result<O>> {
        let Self {
            tx,
            rx,
            mut join_set,
        } = self;

        // The stream does not need the builder's sender half; drop it so
        // the channel closes once all spawned tasks have finished
        drop(tx);

        // Future that checks the results of the join set and propagates
        // a panic if one is seen
        let check = async move {
            while let Some(result) = join_set.join_next().await {
                match result {
                    Ok(task_result) => {
                        match task_result {
                            // Nothing to report
                            Ok(_) => continue,
                            // This means a blocking task error
                            Err(error) => return Some(Err(error)),
                        }
                    }
                    // This means a tokio task error, likely a panic
                    Err(e) => {
                        if e.is_panic() {
                            // resume on the main thread
                            std::panic::resume_unwind(e.into_panic());
                        } else {
                            // This should only occur if the task is
                            // cancelled, which would only occur if
                            // the JoinSet were aborted, which in turn
                            // would imply that the receiver has been
                            // dropped and this code is not running
                            return Some(exec_err!("Non Panic Task error: {e}"));
                        }
                    }
                }
            }
            None
        };

        let check_stream = futures::stream::once(check)
            // unwrap the Option / only return the error
            .filter_map(|item| async move { item });

        // Convert the receiver into a stream
        let rx_stream = futures::stream::unfold(rx, |mut rx| async move {
            let next_item = rx.recv().await;
            next_item.map(|next_item| (next_item, rx))
        });

        // Merge the streams together so whichever is ready first
        // produces the batch
        futures::stream::select(rx_stream, check_stream).boxed()
    }
}

/// Builder for `RecordBatchReceiverStream` that propagates errors
/// and panics correctly.
///
/// [`RecordBatchReceiverStreamBuilder`] is used to spawn one or more tasks
/// that produce [`RecordBatch`]es and send them to a single
/// `Receiver`, which can improve parallelism.
///
/// This also handles propagating panics and canceling the tasks.
///
/// # Example
///
/// The following example spawns 2 tasks that will write [`RecordBatch`]es to
/// the `tx` end of the builder. After building the stream, we can receive
/// those batches by calling `.next()`:
///
/// ```
/// # use std::sync::Arc;
/// # use datafusion_common::arrow::datatypes::{Schema, Field, DataType};
/// # use datafusion_common::arrow::array::RecordBatch;
/// # use datafusion_physical_plan::stream::RecordBatchReceiverStreamBuilder;
/// # use futures::stream::StreamExt;
/// # use tokio::runtime::Builder;
/// # let rt = Builder::new_current_thread().build().unwrap();
/// #
/// # rt.block_on(async {
/// let schema = Arc::new(Schema::new(vec![Field::new("foo", DataType::Int8, false)]));
/// let mut builder = RecordBatchReceiverStreamBuilder::new(Arc::clone(&schema), 10);
///
/// // task 1
/// let tx_1 = builder.tx();
/// let schema_1 = Arc::clone(&schema);
/// builder.spawn(async move {
///     // Your task needs to send batches to the tx
///     tx_1.send(Ok(RecordBatch::new_empty(schema_1))).await.unwrap();
///
///     Ok(())
/// });
///
/// // task 2
/// let tx_2 = builder.tx();
/// let schema_2 = Arc::clone(&schema);
/// builder.spawn(async move {
///     // Your task needs to send batches to the tx
///     tx_2.send(Ok(RecordBatch::new_empty(schema_2))).await.unwrap();
///
///     Ok(())
/// });
///
/// let mut stream = builder.build();
/// while let Some(res_batch) = stream.next().await {
///     // `res_batch` can come from either task 1 or task 2
///
///     // do something with `res_batch`
/// }
/// # });
/// ```
pub struct RecordBatchReceiverStreamBuilder {
    schema: SchemaRef,
    inner: ReceiverStreamBuilder<RecordBatch>,
}

impl RecordBatchReceiverStreamBuilder {
    /// Create new channels with the specified buffer size
    pub fn new(schema: SchemaRef, capacity: usize) -> Self {
        Self {
            schema,
            inner: ReceiverStreamBuilder::new(capacity),
        }
    }

    /// Get a handle for sending [`RecordBatch`] to the output
    ///
    /// If the stream is dropped / canceled, the sender will be closed and
    /// calling `tx().send()` will return an error. Producers should stop
    /// producing in this case and return control.
    pub fn tx(&self) -> Sender<Result<RecordBatch>> {
        self.inner.tx()
    }

    /// Spawn a task that will be aborted if this builder (or the stream
    /// built from it) is dropped
    ///
    /// This is often used to spawn tasks that write to the sender
    /// retrieved from [`Self::tx`]; for an example, see the documentation
    /// of this type.
    pub fn spawn<F>(&mut self, task: F)
    where
        F: Future<Output = Result<()>>,
        F: Send + 'static,
    {
        self.inner.spawn(task)
    }

    /// Same as [`Self::spawn`] but it spawns the task on the provided runtime.
    pub fn spawn_on<F>(&mut self, task: F, handle: &Handle)
    where
        F: Future<Output = Result<()>>,
        F: Send + 'static,
    {
        self.inner.spawn_on(task, handle)
    }

    /// Spawn a blocking task tied to the builder and stream.
    ///
    /// # Drop / Cancel Behavior
    ///
    /// If this builder (or the stream built from it) is dropped **before** the
    /// task starts, the task is also dropped and will never start executing.
    ///
    /// **Note:** Once the blocking task has started, it **will not** be
    /// forcibly stopped on drop, as Rust does not allow forcing a running
    /// thread to terminate. The task will continue running until it completes
    /// or encounters an error.
    ///
    /// Users should ensure that their blocking function periodically checks
    /// for errors when calling `tx.blocking_send`. An error signals that the
    /// stream has been dropped / cancelled and the blocking task should exit.
    ///
    /// This is often used to spawn tasks that write to the sender
    /// retrieved from [`Self::tx`]; for an example, see the documentation
    /// of this type and the sketch below.
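    ///
    /// # Example (sketch)
    ///
    /// A blocking producer that stops when the stream is dropped; the
    /// `produce_batch` helper is hypothetical, so this is not a compiled
    /// doctest:
    ///
    /// ```ignore
    /// let tx = builder.tx();
    /// builder.spawn_blocking(move || {
    ///     loop {
    ///         let batch = produce_batch()?; // hypothetical producer
    ///         // A send error means the receiver (stream) is gone: exit
    ///         // instead of producing further batches
    ///         if tx.blocking_send(Ok(batch)).is_err() {
    ///             break;
    ///         }
    ///     }
    ///     Ok(())
    /// });
    /// ```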
    pub fn spawn_blocking<F>(&mut self, f: F)
    where
        F: FnOnce() -> Result<()>,
        F: Send + 'static,
    {
        self.inner.spawn_blocking(f)
    }

    /// Same as [`Self::spawn_blocking`] but it spawns the blocking task on the provided runtime.
    pub fn spawn_blocking_on<F>(&mut self, f: F, handle: &Handle)
    where
        F: FnOnce() -> Result<()>,
        F: Send + 'static,
    {
        self.inner.spawn_blocking_on(f, handle)
    }

    /// Runs partition `partition` of the `input` [`ExecutionPlan`] on the
    /// tokio thread pool and writes its outputs to this stream
    ///
    /// If the input partition produces an error, the error will be
    /// sent to the output stream and no further results are sent.
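    ///
    /// # Example (sketch)
    ///
    /// A sketch of typical usage, mirroring this module's tests (`schema`,
    /// `input`, `num_partitions`, and `task_ctx` are illustrative names):
    ///
    /// ```ignore
    /// let mut builder = RecordBatchReceiverStream::builder(schema, num_partitions);
    /// for partition in 0..num_partitions {
    ///     builder.run_input(Arc::clone(&input), partition, Arc::clone(&task_ctx));
    /// }
    /// let stream = builder.build();
    /// ```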
    pub(crate) fn run_input(
        &mut self,
        input: Arc<dyn ExecutionPlan>,
        partition: usize,
        context: Arc<TaskContext>,
    ) {
        let output = self.tx();

        self.inner.spawn(async move {
            let mut stream = match input.execute(partition, context) {
                Err(e) => {
                    // If send fails, the plan is being torn down; there
                    // is no place to send the error and no reason to continue.
                    output.send(Err(e)).await.ok();
                    debug!(
                        "Stopping execution: error executing input: {}",
                        displayable(input.as_ref()).one_line()
                    );
                    return Ok(());
                }
                Ok(stream) => stream,
            };

            // Transfer batches from the inner stream to the output tx
            // immediately.
            while let Some(item) = stream.next().await {
                let is_err = item.is_err();

                // If send fails, the plan is being torn down; there is no
                // place to send the error and no reason to continue.
                if output.send(item).await.is_err() {
                    debug!(
                        "Stopping execution: output is gone, plan cancelling: {}",
                        displayable(input.as_ref()).one_line()
                    );
                    return Ok(());
                }

                // Stop after the first error is encountered (don't
                // drive all streams to completion)
                if is_err {
                    debug!(
                        "Stopping execution: plan returned error: {}",
                        displayable(input.as_ref()).one_line()
                    );
                    return Ok(());
                }
            }

            Ok(())
        });
    }

    /// Create a stream of all [`RecordBatch`]es written to `tx`
    pub fn build(self) -> SendableRecordBatchStream {
        Box::pin(RecordBatchStreamAdapter::new(
            self.schema,
            self.inner.build(),
        ))
    }
}

#[doc(hidden)]
pub struct RecordBatchReceiverStream {}

impl RecordBatchReceiverStream {
    /// Create a builder with an internal buffer of `capacity` batches.
    pub fn builder(
        schema: SchemaRef,
        capacity: usize,
    ) -> RecordBatchReceiverStreamBuilder {
        RecordBatchReceiverStreamBuilder::new(schema, capacity)
    }
}

pin_project! {
    /// Combines a [`Stream`] with a [`SchemaRef`], implementing
    /// [`SendableRecordBatchStream`] for the combination
    ///
    /// See [`Self::new`] for an example
    pub struct RecordBatchStreamAdapter<S> {
        schema: SchemaRef,

        #[pin]
        stream: S,
    }
}

impl<S> RecordBatchStreamAdapter<S> {
    /// Creates a new [`RecordBatchStreamAdapter`] from the provided schema and stream.
    ///
    /// Note that to create a [`SendableRecordBatchStream`] you must pin the result
    ///
    /// # Example
    /// ```
    /// # use arrow::array::record_batch;
    /// # use datafusion_execution::SendableRecordBatchStream;
    /// # use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
    /// // Create a stream of Result<RecordBatch>
    /// let batch = record_batch!(
    ///   ("a", Int32, [1, 2, 3]),
    ///   ("b", Float64, [Some(4.0), None, Some(5.0)])
    /// ).expect("created batch");
    /// let schema = batch.schema();
    /// let stream = futures::stream::iter(vec![Ok(batch)]);
    /// // Convert the stream to a SendableRecordBatchStream
    /// let adapter = RecordBatchStreamAdapter::new(schema, stream);
    /// // Now you can use the adapter as a SendableRecordBatchStream
    /// let batch_stream: SendableRecordBatchStream = Box::pin(adapter);
    /// // ...
    /// ```
    pub fn new(schema: SchemaRef, stream: S) -> Self {
        Self { schema, stream }
    }
}

impl<S> std::fmt::Debug for RecordBatchStreamAdapter<S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RecordBatchStreamAdapter")
            .field("schema", &self.schema)
            .finish()
    }
}

impl<S> Stream for RecordBatchStreamAdapter<S>
where
    S: Stream<Item = Result<RecordBatch>>,
{
    type Item = Result<RecordBatch>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.project().stream.poll_next(cx)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.stream.size_hint()
    }
}

impl<S> RecordBatchStream for RecordBatchStreamAdapter<S>
where
    S: Stream<Item = Result<RecordBatch>>,
{
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

/// `EmptyRecordBatchStream` can be used to create a [`RecordBatchStream`]
/// that will produce no results
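///
/// # Example
///
/// A minimal sketch, following the doctest style used elsewhere in this
/// module:
///
/// ```
/// # use std::sync::Arc;
/// # use arrow::datatypes::Schema;
/// # use datafusion_physical_plan::stream::EmptyRecordBatchStream;
/// # use futures::StreamExt;
/// # use tokio::runtime::Builder;
/// # let rt = Builder::new_current_thread().build().unwrap();
/// # rt.block_on(async {
/// let schema = Arc::new(Schema::empty());
/// let mut stream = EmptyRecordBatchStream::new(schema);
/// // The stream completes immediately without yielding any batches
/// assert!(stream.next().await.is_none());
/// # });
/// ```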
pub struct EmptyRecordBatchStream {
    /// Schema wrapped by Arc
    schema: SchemaRef,
}

impl EmptyRecordBatchStream {
    /// Create an empty RecordBatchStream
    pub fn new(schema: SchemaRef) -> Self {
        Self { schema }
    }
}

impl RecordBatchStream for EmptyRecordBatchStream {
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

impl Stream for EmptyRecordBatchStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        Poll::Ready(None)
    }
}

/// Stream wrapper that records [`BaselineMetrics`] for a particular
/// [`SendableRecordBatchStream`] (likely a partition)
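///
/// If `fetch` is set, output is truncated to at most `fetch` rows. For
/// example, with `fetch = Some(10)` and incoming batches of 6 rows each, this
/// wrapper yields the first 6-row batch, then a 4-row slice of the second
/// batch, and then reports completion.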
pub(crate) struct ObservedStream {
    inner: SendableRecordBatchStream,
    baseline_metrics: BaselineMetrics,
    fetch: Option<usize>,
    produced: usize,
}

impl ObservedStream {
    pub fn new(
        inner: SendableRecordBatchStream,
        baseline_metrics: BaselineMetrics,
        fetch: Option<usize>,
    ) -> Self {
        Self {
            inner,
            baseline_metrics,
            fetch,
            produced: 0,
        }
    }

    fn limit_reached(
        &mut self,
        poll: Poll<Option<Result<RecordBatch>>>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        let Some(fetch) = self.fetch else { return poll };

        if self.produced >= fetch {
            return Poll::Ready(None);
        }

        if let Poll::Ready(Some(Ok(batch))) = &poll {
            if self.produced + batch.num_rows() > fetch {
                let batch = batch.slice(0, fetch.saturating_sub(self.produced));
                self.produced += batch.num_rows();
                return Poll::Ready(Some(Ok(batch)));
            };
            self.produced += batch.num_rows()
        }
        poll
    }
}

impl RecordBatchStream for ObservedStream {
    fn schema(&self) -> SchemaRef {
        self.inner.schema()
    }
}

impl Stream for ObservedStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let mut poll = self.inner.poll_next_unpin(cx);
        if self.fetch.is_some() {
            poll = self.limit_reached(poll);
        }
        self.baseline_metrics.record_poll(poll)
    }
}

pin_project! {
    /// Stream wrapper that splits large [`RecordBatch`]es into smaller batches.
    ///
    /// This ensures upstream operators receive batches no larger than
    /// `batch_size`, which can improve parallelism when data sources
    /// generate very large batches.
    ///
    /// # Fields
    ///
    /// - `current_batch`: The batch currently being split, if any
    /// - `offset`: Index of the next row to split from `current_batch`.
    ///   This tracks our position within the current batch being split.
    ///
    /// # Invariants
    ///
    /// - `offset` is always ≤ `current_batch.num_rows()` when `current_batch` is `Some`
    /// - When `current_batch` is `None`, `offset` is always 0
    /// - `batch_size` is always > 0
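    ///
    /// # Example (sketch)
    ///
    /// Illustrative only, since constructing [`SplitMetrics`] requires an
    /// operator's metrics set (as in this module's tests):
    ///
    /// ```ignore
    /// // `input` yields one 2000-row batch; with batch_size = 500 the
    /// // wrapper instead yields four batches of 500 rows each
    /// let stream = BatchSplitStream::new(input, 500, split_metrics);
    /// ```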
    pub struct BatchSplitStream {
        #[pin]
        input: SendableRecordBatchStream,
        schema: SchemaRef,
        batch_size: usize,
        metrics: SplitMetrics,
        current_batch: Option<RecordBatch>,
        offset: usize,
    }
}

impl BatchSplitStream {
    /// Create a new [`BatchSplitStream`]
    pub fn new(
        input: SendableRecordBatchStream,
        batch_size: usize,
        metrics: SplitMetrics,
    ) -> Self {
        let schema = input.schema();
        Self {
            input,
            schema,
            batch_size,
            metrics,
            current_batch: None,
            offset: 0,
        }
    }

    /// Attempt to produce the next sliced batch from the current batch.
    ///
    /// Returns `Some(batch)` if a slice was produced, `None` if the current batch
    /// is exhausted and we need to poll upstream for more data.
    fn next_sliced_batch(&mut self) -> Option<Result<RecordBatch>> {
        let batch = self.current_batch.take()?;

        // Assert slice boundary safety - the offset should never exceed the
        // number of rows in the batch
        debug_assert!(
            self.offset <= batch.num_rows(),
            "Offset {} exceeds batch size {}",
            self.offset,
            batch.num_rows()
        );

        let remaining = batch.num_rows() - self.offset;
        let to_take = remaining.min(self.batch_size);
        let out = batch.slice(self.offset, to_take);

        self.metrics.batches_split.add(1);
        self.offset += to_take;
        if self.offset < batch.num_rows() {
            // More data remains in this batch, store it back
            self.current_batch = Some(batch);
        } else {
            // Batch is exhausted, reset offset
            // Note: current_batch is already None since we took it at the start
            self.offset = 0;
        }
        Some(Ok(out))
    }

    /// Poll the upstream input for the next batch.
    ///
    /// Returns the appropriate `Poll` result based on upstream state.
    /// Small batches are passed through directly; large batches are stored
    /// for slicing and the first slice is returned immediately.
    fn poll_upstream(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        match ready!(self.input.as_mut().poll_next(cx)) {
            Some(Ok(batch)) => {
                if batch.num_rows() <= self.batch_size {
                    // Small batch, pass through directly
                    Poll::Ready(Some(Ok(batch)))
                } else {
                    // Large batch, store for slicing and return the first slice
                    self.current_batch = Some(batch);
                    // Immediately produce the first slice
                    match self.next_sliced_batch() {
                        Some(result) => Poll::Ready(Some(result)),
                        // Unreachable in practice: `current_batch` was just
                        // populated with a non-empty batch
                        None => Poll::Ready(None),
                    }
                }
            }
            Some(Err(e)) => Poll::Ready(Some(Err(e))),
            None => Poll::Ready(None),
        }
    }
}

impl Stream for BatchSplitStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // First, try to produce a slice from the current batch
        if let Some(result) = self.next_sliced_batch() {
            return Poll::Ready(Some(result));
        }

        // No current batch or current batch exhausted, poll upstream
        self.poll_upstream(cx)
    }
}

impl RecordBatchStream for BatchSplitStream {
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::test::exec::{
        assert_strong_count_converges_to_zero, BlockingExec, MockExec, PanicExec,
    };

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::exec_err;

    fn schema() -> SchemaRef {
        Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]))
    }

    #[tokio::test]
    #[should_panic(expected = "PanickingStream did panic")]
    async fn record_batch_receiver_stream_propagates_panics() {
        let schema = schema();

        let num_partitions = 10;
        let input = PanicExec::new(Arc::clone(&schema), num_partitions);
        consume(input, 10).await
    }

    #[tokio::test]
    #[should_panic(expected = "PanickingStream did panic: 1")]
    async fn record_batch_receiver_stream_propagates_panics_early_shutdown() {
        let schema = schema();

        // Make 2 partitions, where the second partition panics before the first
        let num_partitions = 2;
        let input = PanicExec::new(Arc::clone(&schema), num_partitions)
            .with_partition_panic(0, 10)
            .with_partition_panic(1, 3); // partition 1 should panic first (after 3 batches)

        // Ensure that the panic results in an early shutdown (that
        // everything stops after the first panic).

        // Since the stream reads every other batch: (0,1,0,1,0,panic),
        // we should not see more than 5 batches prior to the panic
        let max_batches = 5;
        consume(input, max_batches).await
    }

    #[tokio::test]
    async fn record_batch_receiver_stream_drop_cancel() {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = schema();

        // Make an input that never proceeds
        let input = BlockingExec::new(Arc::clone(&schema), 1);
        let refs = input.refs();

        // Configure a RecordBatchReceiverStream to consume the input
        let mut builder = RecordBatchReceiverStream::builder(schema, 2);
        builder.run_input(Arc::new(input), 0, Arc::clone(&task_ctx));
        let stream = builder.build();

        // Input should still be present
        assert!(std::sync::Weak::strong_count(&refs) > 0);

        // Drop the stream, ensure the refs go to zero
        drop(stream);
        assert_strong_count_converges_to_zero(refs).await;
    }

    #[tokio::test]
    /// Ensure that if an error is received in one stream, the
    /// `RecordBatchReceiverStream` stops early and does not drive
    /// other streams to completion.
    async fn record_batch_receiver_stream_error_does_not_drive_completion() {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = schema();

        // Make an input that will error twice
        let error_stream = MockExec::new(
            vec![exec_err!("Test1"), exec_err!("Test2")],
            Arc::clone(&schema),
        )
        .with_use_task(false);

        let mut builder = RecordBatchReceiverStream::builder(schema, 2);
        builder.run_input(Arc::new(error_stream), 0, Arc::clone(&task_ctx));
        let mut stream = builder.build();

        // Get the first result, which should be an error
        let first_batch = stream.next().await.unwrap();
        let first_err = first_batch.unwrap_err();
        assert_eq!(first_err.strip_backtrace(), "Execution error: Test1");

        // There should be no more batches produced (should not get the second error)
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn batch_split_stream_basic_functionality() {
        use arrow::array::{Int32Array, RecordBatch};
        use futures::stream::{self, StreamExt};

        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

        // Create a large batch that should be split
        let large_batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int32Array::from((0..2000).collect::<Vec<_>>()))],
        )
        .unwrap();

        // Create a stream with the large batch
        let input_stream = stream::iter(vec![Ok(large_batch)]);
        let adapter = RecordBatchStreamAdapter::new(Arc::clone(&schema), input_stream);
        let batch_stream = Box::pin(adapter) as SendableRecordBatchStream;

        // Create a BatchSplitStream with batch_size = 500
        let metrics = ExecutionPlanMetricsSet::new();
        let split_metrics = SplitMetrics::new(&metrics, 0);
        let mut split_stream = BatchSplitStream::new(batch_stream, 500, split_metrics);

        let mut total_rows = 0;
        let mut batch_count = 0;

        while let Some(result) = split_stream.next().await {
            let batch = result.unwrap();
            assert!(batch.num_rows() <= 500, "Batch size should not exceed 500");
            total_rows += batch.num_rows();
            batch_count += 1;
        }

        assert_eq!(total_rows, 2000, "All rows should be preserved");
        assert_eq!(batch_count, 4, "Should have 4 batches of 500 rows each");
    }

    /// Consumes all the input's partitions into a
    /// RecordBatchReceiverStream and runs it to completion
    ///
    /// Panics if more than `max_batches` batches are seen
    async fn consume(input: PanicExec, max_batches: usize) {
        let task_ctx = Arc::new(TaskContext::default());

        let input = Arc::new(input);
        let num_partitions = input.properties().output_partitioning().partition_count();

        // Configure a RecordBatchReceiverStream to consume all the input partitions
        let mut builder =
            RecordBatchReceiverStream::builder(input.schema(), num_partitions);
        for partition in 0..num_partitions {
            builder.run_input(
                Arc::clone(&input) as Arc<dyn ExecutionPlan>,
                partition,
                Arc::clone(&task_ctx),
            );
        }
        let mut stream = builder.build();

        // Drain the stream until it is complete, panicking on error
        let mut num_batches = 0;
        while let Some(next) = stream.next().await {
            next.unwrap();
            num_batches += 1;
            assert!(
                num_batches < max_batches,
                "Got the limit of {num_batches} batches before seeing panic"
            );
        }
    }

    #[test]
    fn record_batch_receiver_stream_builder_spawn_on_runtime() {
        let tokio_runtime = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .unwrap();

        let mut builder =
            RecordBatchReceiverStreamBuilder::new(Arc::new(Schema::empty()), 10);

        let tx1 = builder.tx();
        builder.spawn_on(
            async move {
                tx1.send(Ok(RecordBatch::new_empty(Arc::new(Schema::empty()))))
                    .await
                    .unwrap();

                Ok(())
            },
            tokio_runtime.handle(),
        );

        let tx2 = builder.tx();
        builder.spawn_blocking_on(
            move || {
                tx2.blocking_send(Ok(RecordBatch::new_empty(Arc::new(Schema::empty()))))
                    .unwrap();

                Ok(())
            },
            tokio_runtime.handle(),
        );

        let mut stream = builder.build();

        let mut number_of_batches = 0;

        loop {
            let poll = stream.poll_next_unpin(&mut Context::from_waker(
                futures::task::noop_waker_ref(),
            ));

            match poll {
                Poll::Ready(None) => {
                    break;
                }
                Poll::Ready(Some(Ok(batch))) => {
                    number_of_batches += 1;
                    assert_eq!(batch.num_rows(), 0);
                }
                Poll::Ready(Some(Err(e))) => panic!("Unexpected error: {e}"),
                Poll::Pending => {
                    continue;
                }
            }
        }

        assert_eq!(
            number_of_batches, 2,
            "Should have received exactly two empty batches"
        );
    }
}
925}