datafusion_physical_plan/stream.rs
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Stream wrappers for physical operators
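//!
//! This module provides:
//!
//! - [`RecordBatchReceiverStreamBuilder`]: builds a stream fed by one or more spawned tasks
//! - [`RecordBatchStreamAdapter`]: combines an existing [`Stream`] with a [`SchemaRef`]
//! - [`EmptyRecordBatchStream`]: a stream that produces no results
//! - [`BatchSplitStream`]: splits large batches into smaller ones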

use std::pin::Pin;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;

#[cfg(test)]
use super::metrics::ExecutionPlanMetricsSet;
use super::metrics::{BaselineMetrics, SplitMetrics};
use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};
use crate::displayable;

use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
use datafusion_common::{exec_err, Result};
use datafusion_common_runtime::JoinSet;
use datafusion_execution::TaskContext;

use futures::ready;
use futures::stream::BoxStream;
use futures::{Future, Stream, StreamExt};
use log::debug;
use pin_project_lite::pin_project;
use tokio::runtime::Handle;
use tokio::sync::mpsc::{Receiver, Sender};

/// Creates a stream from a collection of producing tasks, routing panics to the stream.
///
/// Note that this is similar to [`ReceiverStream` from tokio-stream], with the differences being:
///
/// 1. Methods to bound and "detach" tasks (`spawn()` and `spawn_blocking()`).
///
/// 2. Propagates panics, whereas the `tokio` version doesn't propagate panics to the receiver.
///
/// 3. Automatically cancels any outstanding tasks when the receiver stream is dropped.
///
/// [`ReceiverStream` from tokio-stream]: https://docs.rs/tokio-stream/latest/tokio_stream/wrappers/struct.ReceiverStream.html
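///
/// # Example
///
/// A minimal sketch of the intended usage (this type is crate-private, so the
/// example is not compile-tested):
///
/// ```ignore
/// let mut builder = ReceiverStreamBuilder::<u64>::new(2);
/// let tx = builder.tx();
/// builder.spawn(async move {
///     // Producers send `Result` values; an `Err` is forwarded to the stream
///     tx.send(Ok(1)).await.ok();
///     Ok(())
/// });
/// // `build` returns a stream of everything written to `tx`,
/// // and also surfaces panics from the spawned tasks
/// let mut stream = builder.build();
/// ```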
pub(crate) struct ReceiverStreamBuilder<O> {
    tx: Sender<Result<O>>,
    rx: Receiver<Result<O>>,
    join_set: JoinSet<Result<()>>,
}

impl<O: Send + 'static> ReceiverStreamBuilder<O> {
    /// Create new channels with the specified buffer size
    pub fn new(capacity: usize) -> Self {
        let (tx, rx) = tokio::sync::mpsc::channel(capacity);

        Self {
            tx,
            rx,
            join_set: JoinSet::new(),
        }
    }

    /// Get a handle for sending data to the output
    pub fn tx(&self) -> Sender<Result<O>> {
        self.tx.clone()
    }

    /// Spawn a task that will be aborted if this builder (or the stream
    /// built from it) is dropped
    pub fn spawn<F>(&mut self, task: F)
    where
        F: Future<Output = Result<()>>,
        F: Send + 'static,
    {
        self.join_set.spawn(task);
    }

    /// Same as [`Self::spawn`] but it spawns the task on the provided runtime
    pub fn spawn_on<F>(&mut self, task: F, handle: &Handle)
    where
        F: Future<Output = Result<()>>,
        F: Send + 'static,
    {
        self.join_set.spawn_on(task, handle);
    }

    /// Spawn a blocking task that will be aborted if this builder (or the stream
    /// built from it) is dropped.
    ///
    /// This is often used to spawn tasks that write to the sender
    /// retrieved from [`Self::tx`].
    pub fn spawn_blocking<F>(&mut self, f: F)
    where
        F: FnOnce() -> Result<()>,
        F: Send + 'static,
    {
        self.join_set.spawn_blocking(f);
    }

    /// Same as [`Self::spawn_blocking`] but it spawns the blocking task on the provided runtime
    pub fn spawn_blocking_on<F>(&mut self, f: F, handle: &Handle)
    where
        F: FnOnce() -> Result<()>,
        F: Send + 'static,
    {
        self.join_set.spawn_blocking_on(f, handle);
    }

    /// Create a stream of all data written to `tx`
    pub fn build(self) -> BoxStream<'static, Result<O>> {
        let Self {
            tx,
            rx,
            mut join_set,
        } = self;

        // The receiver stream doesn't need the builder's own sender
        drop(tx);

        // Future that checks the result of the join set, and propagates a panic if seen
        let check = async move {
            while let Some(result) = join_set.join_next().await {
                match result {
                    Ok(task_result) => {
                        match task_result {
                            // Nothing to report
                            Ok(_) => continue,
                            // This means a blocking task errored
                            Err(error) => return Some(Err(error)),
                        }
                    }
                    // This means a tokio task errored, likely a panic
                    Err(e) => {
                        if e.is_panic() {
                            // Resume the panic on the main thread
                            std::panic::resume_unwind(e.into_panic());
                        } else {
                            // This should only occur if the task is
                            // cancelled, which would only occur if
                            // the JoinSet were aborted, which in turn
                            // would imply that the receiver has been
                            // dropped and this code is not running
                            return Some(exec_err!("Non Panic Task error: {e}"));
                        }
                    }
                }
            }
            None
        };

        let check_stream = futures::stream::once(check)
            // unwrap Option / only return the error
            .filter_map(|item| async move { item });

        // Convert the receiver into a stream
        let rx_stream = futures::stream::unfold(rx, |mut rx| async move {
            let next_item = rx.recv().await;
            next_item.map(|next_item| (next_item, rx))
        });

        // Merge the streams together so whichever is ready first
        // produces the batch
        futures::stream::select(rx_stream, check_stream).boxed()
    }
}

/// Builder for [`RecordBatchReceiverStream`] that propagates errors
/// and panics correctly.
///
/// [`RecordBatchReceiverStreamBuilder`] is used to spawn one or more tasks
/// that produce [`RecordBatch`]es and send them to a single
/// `Receiver`, which can improve parallelism.
///
/// This also handles propagating panics and canceling the tasks.
///
/// # Example
///
/// The following example spawns 2 tasks that will write [`RecordBatch`]es to
/// the `tx` end of the builder. After building the stream, we can receive
/// those batches by calling `.next()`:
///
/// ```
/// # use std::sync::Arc;
/// # use datafusion_common::arrow::datatypes::{Schema, Field, DataType};
/// # use datafusion_common::arrow::array::RecordBatch;
/// # use datafusion_physical_plan::stream::RecordBatchReceiverStreamBuilder;
/// # use futures::stream::StreamExt;
/// # use tokio::runtime::Builder;
/// # let rt = Builder::new_current_thread().build().unwrap();
/// #
/// # rt.block_on(async {
/// let schema = Arc::new(Schema::new(vec![Field::new("foo", DataType::Int8, false)]));
/// let mut builder = RecordBatchReceiverStreamBuilder::new(Arc::clone(&schema), 10);
///
/// // task 1
/// let tx_1 = builder.tx();
/// let schema_1 = Arc::clone(&schema);
/// builder.spawn(async move {
///     // Your task needs to send batches to the tx
///     tx_1.send(Ok(RecordBatch::new_empty(schema_1)))
///         .await
///         .unwrap();
///
///     Ok(())
/// });
///
/// // task 2
/// let tx_2 = builder.tx();
/// let schema_2 = Arc::clone(&schema);
/// builder.spawn(async move {
///     // Your task needs to send batches to the tx
///     tx_2.send(Ok(RecordBatch::new_empty(schema_2)))
///         .await
///         .unwrap();
///
///     Ok(())
/// });
///
/// let mut stream = builder.build();
/// while let Some(res_batch) = stream.next().await {
///     // `res_batch` can come from either task 1 or task 2
///
///     // Do something with `res_batch`
/// }
/// # });
/// ```
pub struct RecordBatchReceiverStreamBuilder {
    schema: SchemaRef,
    inner: ReceiverStreamBuilder<RecordBatch>,
}

impl RecordBatchReceiverStreamBuilder {
    /// Create new channels with the specified buffer size
    pub fn new(schema: SchemaRef, capacity: usize) -> Self {
        Self {
            schema,
            inner: ReceiverStreamBuilder::new(capacity),
        }
    }

    /// Get a handle for sending [`RecordBatch`]es to the output
    ///
    /// If the stream is dropped / canceled, the sender will be closed and
    /// calling `tx().send()` will return an error. Producers should stop
    /// producing in this case and return control.
    pub fn tx(&self) -> Sender<Result<RecordBatch>> {
        self.inner.tx()
    }

    /// Spawn a task that will be aborted if this builder (or the stream
    /// built from it) is dropped
    ///
    /// This is often used to spawn tasks that write to the sender
    /// retrieved from [`Self::tx`]; for examples, see the documentation
    /// of this type.
    pub fn spawn<F>(&mut self, task: F)
    where
        F: Future<Output = Result<()>>,
        F: Send + 'static,
    {
        self.inner.spawn(task)
    }

    /// Same as [`Self::spawn`] but it spawns the task on the provided runtime.
    pub fn spawn_on<F>(&mut self, task: F, handle: &Handle)
    where
        F: Future<Output = Result<()>>,
        F: Send + 'static,
    {
        self.inner.spawn_on(task, handle)
    }

    /// Spawn a blocking task tied to the builder and stream.
    ///
    /// # Drop / Cancel Behavior
    ///
    /// If this builder (or the stream built from it) is dropped **before** the
    /// task starts, the task is also dropped and will never start executing.
    ///
    /// **Note:** Once the blocking task has started, it **will not** be
    /// forcibly stopped on drop, as Rust does not allow forcing a running
    /// thread to terminate. The task will continue running until it completes
    /// or encounters an error.
    ///
    /// Users should ensure that their blocking function periodically checks for
    /// errors when calling `tx.blocking_send`. An error signals that the stream
    /// has been dropped / cancelled and the blocking task should exit.
    ///
    /// This is often used to spawn tasks that write to the sender
    /// retrieved from [`Self::tx`]; for examples, see the documentation
    /// of this type.
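    ///
    /// # Example
    ///
    /// A minimal sketch (not part of the original documentation, following
    /// the guidance above) of a blocking producer that treats a failed
    /// `blocking_send` as a cancellation signal:
    ///
    /// ```no_run
    /// # use std::sync::Arc;
    /// # use datafusion_common::arrow::datatypes::{Schema, Field, DataType};
    /// # use datafusion_common::arrow::array::RecordBatch;
    /// # use datafusion_physical_plan::stream::RecordBatchReceiverStreamBuilder;
    /// # use tokio::runtime::Builder;
    /// # let rt = Builder::new_current_thread().build().unwrap();
    /// # rt.block_on(async {
    /// let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int8, false)]));
    /// let mut builder = RecordBatchReceiverStreamBuilder::new(Arc::clone(&schema), 2);
    ///
    /// let tx = builder.tx();
    /// builder.spawn_blocking(move || {
    ///     loop {
    ///         // `blocking_send` returns `Err` once the stream has been
    ///         // dropped; stop producing instead of doing further work
    ///         let batch = RecordBatch::new_empty(Arc::clone(&schema));
    ///         if tx.blocking_send(Ok(batch)).is_err() {
    ///             break;
    ///         }
    ///     }
    ///     Ok(())
    /// });
    /// # });
    /// ```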
    pub fn spawn_blocking<F>(&mut self, f: F)
    where
        F: FnOnce() -> Result<()>,
        F: Send + 'static,
    {
        self.inner.spawn_blocking(f)
    }

    /// Same as [`Self::spawn_blocking`] but it spawns the blocking task on the provided runtime.
    pub fn spawn_blocking_on<F>(&mut self, f: F, handle: &Handle)
    where
        F: FnOnce() -> Result<()>,
        F: Send + 'static,
    {
        self.inner.spawn_blocking_on(f, handle)
    }

    /// Runs the `partition` of the `input` [`ExecutionPlan`] on the
    /// tokio thread pool and writes its outputs to this stream.
    ///
    /// If the input partition produces an error, the error will be
    /// sent to the output stream and no further results are sent.
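    ///
    /// A minimal sketch of intended usage (crate-internal API; `plan` and
    /// `task_ctx` are assumed to be an `Arc<dyn ExecutionPlan>` and an
    /// `Arc<TaskContext>`, mirroring the tests below):
    ///
    /// ```ignore
    /// let mut builder = RecordBatchReceiverStream::builder(schema, 2);
    /// // Execute partition 0 of `plan`, forwarding its batches
    /// // (or its first error) into the stream
    /// builder.run_input(plan, 0, task_ctx);
    /// let stream = builder.build();
    /// ```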
    pub(crate) fn run_input(
        &mut self,
        input: Arc<dyn ExecutionPlan>,
        partition: usize,
        context: Arc<TaskContext>,
    ) {
        let output = self.tx();

        self.inner.spawn(async move {
            let mut stream = match input.execute(partition, context) {
                Err(e) => {
                    // If send fails, the plan is being torn down: there
                    // is no place to send the error and no reason to continue.
                    output.send(Err(e)).await.ok();
                    debug!(
                        "Stopping execution: error executing input: {}",
                        displayable(input.as_ref()).one_line()
                    );
                    return Ok(());
                }
                Ok(stream) => stream,
            };

            // Transfer batches from the inner stream to the output tx
            // immediately.
            while let Some(item) = stream.next().await {
                let is_err = item.is_err();

                // If send fails, the plan is being torn down: there is no
                // place to send the error and no reason to continue.
                if output.send(item).await.is_err() {
                    debug!(
                        "Stopping execution: output is gone, plan cancelling: {}",
                        displayable(input.as_ref()).one_line()
                    );
                    return Ok(());
                }

                // Stop after the first error is encountered (don't
                // drive all streams to completion)
                if is_err {
                    debug!(
                        "Stopping execution: plan returned error: {}",
                        displayable(input.as_ref()).one_line()
                    );
                    return Ok(());
                }
            }

            Ok(())
        });
    }

    /// Create a stream of all [`RecordBatch`]es written to `tx`
    pub fn build(self) -> SendableRecordBatchStream {
        Box::pin(RecordBatchStreamAdapter::new(
            self.schema,
            self.inner.build(),
        ))
    }
}

#[doc(hidden)]
pub struct RecordBatchReceiverStream {}

impl RecordBatchReceiverStream {
    /// Create a builder with an internal buffer of `capacity` batches.
    pub fn builder(
        schema: SchemaRef,
        capacity: usize,
    ) -> RecordBatchReceiverStreamBuilder {
        RecordBatchReceiverStreamBuilder::new(schema, capacity)
    }
}

pin_project! {
    /// Combines a [`Stream`] with a [`SchemaRef`] implementing
    /// [`SendableRecordBatchStream`] for the combination
    ///
    /// See [`Self::new`] for an example
    pub struct RecordBatchStreamAdapter<S> {
        schema: SchemaRef,

        #[pin]
        stream: S,
    }
}

impl<S> RecordBatchStreamAdapter<S> {
    /// Creates a new [`RecordBatchStreamAdapter`] from the provided schema and stream.
    ///
    /// Note that to create a [`SendableRecordBatchStream`] you must pin the result
    ///
    /// # Example
    /// ```
    /// # use arrow::array::record_batch;
    /// # use datafusion_execution::SendableRecordBatchStream;
    /// # use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
    /// // Create a stream of Result<RecordBatch>
    /// let batch = record_batch!(
    ///     ("a", Int32, [1, 2, 3]),
    ///     ("b", Float64, [Some(4.0), None, Some(5.0)])
    /// )
    /// .expect("created batch");
    /// let schema = batch.schema();
    /// let stream = futures::stream::iter(vec![Ok(batch)]);
    /// // Convert the stream to a SendableRecordBatchStream
    /// let adapter = RecordBatchStreamAdapter::new(schema, stream);
    /// // Now you can use the adapter as a SendableRecordBatchStream
    /// let batch_stream: SendableRecordBatchStream = Box::pin(adapter);
    /// // ...
    /// ```
    pub fn new(schema: SchemaRef, stream: S) -> Self {
        Self { schema, stream }
    }
}

impl<S> std::fmt::Debug for RecordBatchStreamAdapter<S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RecordBatchStreamAdapter")
            .field("schema", &self.schema)
            .finish()
    }
}

impl<S> Stream for RecordBatchStreamAdapter<S>
where
    S: Stream<Item = Result<RecordBatch>>,
{
    type Item = Result<RecordBatch>;

    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        self.project().stream.poll_next(cx)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.stream.size_hint()
    }
}

impl<S> RecordBatchStream for RecordBatchStreamAdapter<S>
where
    S: Stream<Item = Result<RecordBatch>>,
{
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

/// `EmptyRecordBatchStream` can be used to create a [`RecordBatchStream`]
/// that will produce no results
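///
/// # Example
///
/// A minimal usage sketch:
///
/// ```
/// # use std::sync::Arc;
/// # use arrow::datatypes::Schema;
/// # use datafusion_physical_plan::stream::EmptyRecordBatchStream;
/// // The stream reports this schema but yields no batches
/// let stream = EmptyRecordBatchStream::new(Arc::new(Schema::empty()));
/// ```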
pub struct EmptyRecordBatchStream {
    /// Schema wrapped by Arc
    schema: SchemaRef,
}

impl EmptyRecordBatchStream {
    /// Create an empty RecordBatchStream
    pub fn new(schema: SchemaRef) -> Self {
        Self { schema }
    }
}

impl RecordBatchStream for EmptyRecordBatchStream {
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

impl Stream for EmptyRecordBatchStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        Poll::Ready(None)
    }
}

/// Stream wrapper that records `BaselineMetrics` for a particular
/// [`SendableRecordBatchStream`] (likely a partition)
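///
/// When an optional `fetch` limit is set, the stream stops after producing
/// `fetch` rows, slicing the final batch if necessary. For example, with
/// `fetch = Some(10)` and 8 rows already produced, an incoming 5-row batch
/// is sliced down to its first 2 rows.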
pub(crate) struct ObservedStream {
    inner: SendableRecordBatchStream,
    baseline_metrics: BaselineMetrics,
    fetch: Option<usize>,
    produced: usize,
}

impl ObservedStream {
    pub fn new(
        inner: SendableRecordBatchStream,
        baseline_metrics: BaselineMetrics,
        fetch: Option<usize>,
    ) -> Self {
        Self {
            inner,
            baseline_metrics,
            fetch,
            produced: 0,
        }
    }

    fn limit_reached(
        &mut self,
        poll: Poll<Option<Result<RecordBatch>>>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        let Some(fetch) = self.fetch else { return poll };

        if self.produced >= fetch {
            return Poll::Ready(None);
        }

        if let Poll::Ready(Some(Ok(batch))) = &poll {
            if self.produced + batch.num_rows() > fetch {
                let batch = batch.slice(0, fetch.saturating_sub(self.produced));
                self.produced += batch.num_rows();
                return Poll::Ready(Some(Ok(batch)));
            };
            self.produced += batch.num_rows()
        }
        poll
    }
}

impl RecordBatchStream for ObservedStream {
    fn schema(&self) -> SchemaRef {
        self.inner.schema()
    }
}

impl Stream for ObservedStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let mut poll = self.inner.poll_next_unpin(cx);
        if self.fetch.is_some() {
            poll = self.limit_reached(poll);
        }
        self.baseline_metrics.record_poll(poll)
    }
}

pin_project! {
    /// Stream wrapper that splits large [`RecordBatch`]es into smaller batches.
    ///
    /// This ensures upstream operators receive batches no larger than
    /// `batch_size`, which can improve parallelism when data sources
    /// generate very large batches.
    ///
    /// # Fields
    ///
    /// - `current_batch`: The batch currently being split, if any
    /// - `offset`: Index of the next row to split from `current_batch`.
    ///   This tracks our position within the current batch being split.
    ///
    /// # Invariants
    ///
    /// - `offset` is always ≤ `current_batch.num_rows()` when `current_batch` is `Some`
    /// - When `current_batch` is `None`, `offset` is always 0
    /// - `batch_size` is always > 0
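    ///
    /// For example, a single 2,000 row input batch with `batch_size = 500` is
    /// emitted as four consecutive 500 row batches (see the tests below).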
    pub struct BatchSplitStream {
        #[pin]
        input: SendableRecordBatchStream,
        schema: SchemaRef,
        batch_size: usize,
        metrics: SplitMetrics,
        current_batch: Option<RecordBatch>,
        offset: usize,
    }
}

impl BatchSplitStream {
    /// Create a new [`BatchSplitStream`]
    pub fn new(
        input: SendableRecordBatchStream,
        batch_size: usize,
        metrics: SplitMetrics,
    ) -> Self {
        let schema = input.schema();
        Self {
            input,
            schema,
            batch_size,
            metrics,
            current_batch: None,
            offset: 0,
        }
    }

    /// Attempt to produce the next sliced batch from the current batch.
    ///
    /// Returns `Some(batch)` if a slice was produced, `None` if the current batch
    /// is exhausted and we need to poll upstream for more data.
    fn next_sliced_batch(&mut self) -> Option<Result<RecordBatch>> {
        let batch = self.current_batch.take()?;

        // Assert slice boundary safety: the offset should never exceed the
        // batch's row count
        debug_assert!(
            self.offset <= batch.num_rows(),
            "Offset {} exceeds batch row count {}",
            self.offset,
            batch.num_rows()
        );

        let remaining = batch.num_rows() - self.offset;
        let to_take = remaining.min(self.batch_size);
        let out = batch.slice(self.offset, to_take);

        self.metrics.batches_split.add(1);
        self.offset += to_take;
        if self.offset < batch.num_rows() {
            // More data remains in this batch, store it back
            self.current_batch = Some(batch);
        } else {
            // Batch is exhausted, reset offset
            // Note: current_batch is already None since we took it at the start
            self.offset = 0;
        }
        Some(Ok(out))
    }

    /// Poll the upstream input for the next batch.
    ///
    /// Returns the appropriate `Poll` result based on the upstream state.
    /// Small batches are passed through directly; large batches are stored
    /// for slicing and the first slice is returned immediately.
    fn poll_upstream(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        match ready!(self.input.as_mut().poll_next(cx)) {
            Some(Ok(batch)) => {
                if batch.num_rows() <= self.batch_size {
                    // Small batch, pass through directly
                    Poll::Ready(Some(Ok(batch)))
                } else {
                    // Large batch, store for slicing and return the first slice
                    self.current_batch = Some(batch);
                    match self.next_sliced_batch() {
                        Some(result) => Poll::Ready(Some(result)),
                        // Should not happen: `current_batch` was just set to a
                        // non-empty batch, so a slice is always produced
                        None => Poll::Ready(None),
                    }
                }
            }
            Some(Err(e)) => Poll::Ready(Some(Err(e))),
            None => Poll::Ready(None),
        }
    }
}

impl Stream for BatchSplitStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // First, try to produce a slice from the current batch
        if let Some(result) = self.next_sliced_batch() {
            return Poll::Ready(Some(result));
        }

        // No current batch or current batch exhausted, poll upstream
        self.poll_upstream(cx)
    }
}

impl RecordBatchStream for BatchSplitStream {
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::test::exec::{
        assert_strong_count_converges_to_zero, BlockingExec, MockExec, PanicExec,
    };

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::exec_err;

    fn schema() -> SchemaRef {
        Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]))
    }

    #[tokio::test]
    #[should_panic(expected = "PanickingStream did panic")]
    async fn record_batch_receiver_stream_propagates_panics() {
        let schema = schema();

        let num_partitions = 10;
        let input = PanicExec::new(Arc::clone(&schema), num_partitions);
        consume(input, 10).await
    }

    #[tokio::test]
    #[should_panic(expected = "PanickingStream did panic: 1")]
    async fn record_batch_receiver_stream_propagates_panics_early_shutdown() {
        let schema = schema();

        // Make 2 partitions, where the second partition panics before the first
        let num_partitions = 2;
        let input = PanicExec::new(Arc::clone(&schema), num_partitions)
            .with_partition_panic(0, 10)
            .with_partition_panic(1, 3); // partition 1 should panic first (after 3 batches)

        // Ensure that the panic results in an early shutdown (that
        // everything stops after the first panic).

        // Since the stream reads every other batch (0,1,0,1,0,panic),
        // it should not exceed 5 batches prior to the panic
        let max_batches = 5;
        consume(input, max_batches).await
    }

    #[tokio::test]
    async fn record_batch_receiver_stream_drop_cancel() {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = schema();

        // Make an input that never proceeds
        let input = BlockingExec::new(Arc::clone(&schema), 1);
        let refs = input.refs();

        // Configure a RecordBatchReceiverStream to consume the input
        let mut builder = RecordBatchReceiverStream::builder(schema, 2);
        builder.run_input(Arc::new(input), 0, Arc::clone(&task_ctx));
        let stream = builder.build();

        // Input should still be present
        assert!(std::sync::Weak::strong_count(&refs) > 0);

        // Drop the stream, ensure the refs go to zero
        drop(stream);
        assert_strong_count_converges_to_zero(refs).await;
    }

    #[tokio::test]
    /// Ensure that if an error is received in one stream, the
    /// `RecordBatchReceiverStream` stops early and does not drive
    /// other streams to completion.
    async fn record_batch_receiver_stream_error_does_not_drive_completion() {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = schema();

        // Make an input that will error twice
        let error_stream = MockExec::new(
            vec![exec_err!("Test1"), exec_err!("Test2")],
            Arc::clone(&schema),
        )
        .with_use_task(false);

        let mut builder = RecordBatchReceiverStream::builder(schema, 2);
        builder.run_input(Arc::new(error_stream), 0, Arc::clone(&task_ctx));
        let mut stream = builder.build();

        // Get the first result, which should be an error
        let first_batch = stream.next().await.unwrap();
        let first_err = first_batch.unwrap_err();
        assert_eq!(first_err.strip_backtrace(), "Execution error: Test1");

        // There should be no more batches produced (should not get the second error)
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn batch_split_stream_basic_functionality() {
        use arrow::array::{Int32Array, RecordBatch};
        use futures::stream::{self, StreamExt};

        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

        // Create a large batch that should be split
        let large_batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int32Array::from((0..2000).collect::<Vec<_>>()))],
        )
        .unwrap();

        // Create a stream with the large batch
        let input_stream = stream::iter(vec![Ok(large_batch)]);
        let adapter = RecordBatchStreamAdapter::new(Arc::clone(&schema), input_stream);
        let batch_stream = Box::pin(adapter) as SendableRecordBatchStream;

        // Create a BatchSplitStream with batch_size = 500
        let metrics = ExecutionPlanMetricsSet::new();
        let split_metrics = SplitMetrics::new(&metrics, 0);
        let mut split_stream = BatchSplitStream::new(batch_stream, 500, split_metrics);

        let mut total_rows = 0;
        let mut batch_count = 0;

        while let Some(result) = split_stream.next().await {
            let batch = result.unwrap();
            assert!(batch.num_rows() <= 500, "Batch size should not exceed 500");
            total_rows += batch.num_rows();
            batch_count += 1;
        }

        assert_eq!(total_rows, 2000, "All rows should be preserved");
        assert_eq!(batch_count, 4, "Should have 4 batches of 500 rows each");
    }

    /// Consumes all the input's partitions into a
    /// RecordBatchReceiverStream and runs it to completion
    ///
    /// Panics if more than `max_batches` are seen
    async fn consume(input: PanicExec, max_batches: usize) {
        let task_ctx = Arc::new(TaskContext::default());

        let input = Arc::new(input);
        let num_partitions = input.properties().output_partitioning().partition_count();

        // Configure a RecordBatchReceiverStream to consume all the input partitions
        let mut builder =
            RecordBatchReceiverStream::builder(input.schema(), num_partitions);
        for partition in 0..num_partitions {
            builder.run_input(
                Arc::clone(&input) as Arc<dyn ExecutionPlan>,
                partition,
                Arc::clone(&task_ctx),
            );
        }
        let mut stream = builder.build();

        // Drain the stream until it is complete, panicking on error
        let mut num_batches = 0;
        while let Some(next) = stream.next().await {
            next.unwrap();
            num_batches += 1;
            assert!(
                num_batches < max_batches,
                "Got the limit of {num_batches} batches before seeing panic"
            );
        }
    }

    #[test]
    fn record_batch_receiver_stream_builder_spawn_on_runtime() {
        let tokio_runtime = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .unwrap();

        let mut builder =
            RecordBatchReceiverStreamBuilder::new(Arc::new(Schema::empty()), 10);

        let tx1 = builder.tx();
        builder.spawn_on(
            async move {
                tx1.send(Ok(RecordBatch::new_empty(Arc::new(Schema::empty()))))
                    .await
                    .unwrap();

                Ok(())
            },
            tokio_runtime.handle(),
        );

        let tx2 = builder.tx();
        builder.spawn_blocking_on(
            move || {
                tx2.blocking_send(Ok(RecordBatch::new_empty(Arc::new(Schema::empty()))))
                    .unwrap();

                Ok(())
            },
            tokio_runtime.handle(),
        );

        let mut stream = builder.build();

        let mut number_of_batches = 0;

        loop {
            let poll = stream.poll_next_unpin(&mut Context::from_waker(
                futures::task::noop_waker_ref(),
            ));

            match poll {
                Poll::Ready(None) => {
                    break;
                }
                Poll::Ready(Some(Ok(batch))) => {
                    number_of_batches += 1;
                    assert_eq!(batch.num_rows(), 0);
                }
                Poll::Ready(Some(Err(e))) => panic!("Unexpected error: {e}"),
                Poll::Pending => {
                    continue;
                }
            }
        }

        assert_eq!(
            number_of_batches, 2,
            "Should have received exactly two empty batches"
        );
    }
}