datafusion_physical_plan/stream.rs
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Stream wrappers for physical operators

use std::pin::Pin;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;

#[cfg(test)]
use super::metrics::ExecutionPlanMetricsSet;
use super::metrics::{BaselineMetrics, SplitMetrics};
use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};
use crate::displayable;

use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
use datafusion_common::{exec_err, Result};
use datafusion_common_runtime::JoinSet;
use datafusion_execution::TaskContext;

use futures::ready;
use futures::stream::BoxStream;
use futures::{Future, Stream, StreamExt};
use log::debug;
use pin_project_lite::pin_project;
use tokio::runtime::Handle;
use tokio::sync::mpsc::{Receiver, Sender};
/// Creates a stream from a collection of producing tasks, routing panics to the stream.
///
/// Note that this is similar to [`ReceiverStream` from tokio-stream], with the differences being:
///
/// 1. Methods to bound and "detach" tasks (`spawn()` and `spawn_blocking()`).
///
/// 2. Propagates panics, whereas the `tokio` version doesn't propagate panics to the receiver.
///
/// 3. Automatically cancels any outstanding tasks when the receiver stream is dropped.
///
/// [`ReceiverStream` from tokio-stream]: https://docs.rs/tokio-stream/latest/tokio_stream/wrappers/struct.ReceiverStream.html
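///
/// # Example
///
/// A minimal sketch of typical usage (illustrative only; this type is
/// crate-private, so the example is not compiled as a doctest):
///
/// ```ignore
/// let mut builder = ReceiverStreamBuilder::<usize>::new(2);
///
/// // Each producer task gets its own sender
/// let tx = builder.tx();
/// builder.spawn(async move {
///     tx.send(Ok(1)).await.ok();
///     Ok(())
/// });
///
/// // `build` consumes the builder; the stream yields the produced
/// // values and re-raises any panic from the spawned tasks
/// let mut stream = builder.build();
/// ```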
pub(crate) struct ReceiverStreamBuilder<O> {
    tx: Sender<Result<O>>,
    rx: Receiver<Result<O>>,
    join_set: JoinSet<Result<()>>,
}
60
61impl<O: Send + 'static> ReceiverStreamBuilder<O> {
62 /// Create new channels with the specified buffer size
63 pub fn new(capacity: usize) -> Self {
64 let (tx, rx) = tokio::sync::mpsc::channel(capacity);
65
66 Self {
67 tx,
68 rx,
69 join_set: JoinSet::new(),
70 }
71 }
72
73 /// Get a handle for sending data to the output
74 pub fn tx(&self) -> Sender<Result<O>> {
75 self.tx.clone()
76 }
77
78 /// Spawn task that will be aborted if this builder (or the stream
79 /// built from it) are dropped
80 pub fn spawn<F>(&mut self, task: F)
81 where
82 F: Future<Output = Result<()>>,
83 F: Send + 'static,
84 {
85 self.join_set.spawn(task);
86 }
87
88 /// Same as [`Self::spawn`] but it spawns the task on the provided runtime
89 pub fn spawn_on<F>(&mut self, task: F, handle: &Handle)
90 where
91 F: Future<Output = Result<()>>,
92 F: Send + 'static,
93 {
94 self.join_set.spawn_on(task, handle);
95 }
96
97 /// Spawn a blocking task that will be aborted if this builder (or the stream
98 /// built from it) are dropped.
99 ///
100 /// This is often used to spawn tasks that write to the sender
101 /// retrieved from `Self::tx`.
102 pub fn spawn_blocking<F>(&mut self, f: F)
103 where
104 F: FnOnce() -> Result<()>,
105 F: Send + 'static,
106 {
107 self.join_set.spawn_blocking(f);
108 }
109
110 /// Same as [`Self::spawn_blocking`] but it spawns the blocking task on the provided runtime
111 pub fn spawn_blocking_on<F>(&mut self, f: F, handle: &Handle)
112 where
113 F: FnOnce() -> Result<()>,
114 F: Send + 'static,
115 {
116 self.join_set.spawn_blocking_on(f, handle);
117 }

    /// Create a stream of all data written to `tx`
    pub fn build(self) -> BoxStream<'static, Result<O>> {
        let Self {
            tx,
            rx,
            mut join_set,
        } = self;

        // Drop the builder's copy of `tx` so the channel closes once all
        // producer tasks have finished
        drop(tx);

        // Future that checks the result of the join set, and propagates
        // a panic if one is seen
        let check = async move {
            while let Some(result) = join_set.join_next().await {
                match result {
                    Ok(task_result) => {
                        match task_result {
                            // Nothing to report
                            Ok(_) => continue,
                            // This means a blocking task error
                            Err(error) => return Some(Err(error)),
                        }
                    }
                    // This means a tokio task error, likely a panic
                    Err(e) => {
                        if e.is_panic() {
                            // resume the panic on the main thread
                            std::panic::resume_unwind(e.into_panic());
                        } else {
                            // This should only occur if the task is
                            // cancelled, which would only occur if
                            // the JoinSet were aborted, which in turn
                            // would imply that the receiver has been
                            // dropped and this code is not running
                            return Some(exec_err!("Non Panic Task error: {e}"));
                        }
                    }
                }
            }
            None
        };

        let check_stream = futures::stream::once(check)
            // unwrap Option / only return the error
            .filter_map(|item| async move { item });

        // Convert the receiver into a stream
        let rx_stream = futures::stream::unfold(rx, |mut rx| async move {
            let next_item = rx.recv().await;
            next_item.map(|next_item| (next_item, rx))
        });

        // Merge the streams together so whichever is ready first
        // produces the batch
        futures::stream::select(rx_stream, check_stream).boxed()
    }
}

/// Builder for `RecordBatchReceiverStream` that propagates errors
/// and panics correctly.
///
/// [`RecordBatchReceiverStreamBuilder`] is used to spawn one or more tasks
/// that produce [`RecordBatch`]es and send them to a single
/// `Receiver` which can improve parallelism.
///
/// This also handles propagating panics and canceling the tasks.
///
/// # Example
///
/// The following example spawns 2 tasks that will write [`RecordBatch`]es to
/// the `tx` end of the builder. After building the stream, we can receive
/// those batches by calling `.next()`:
///
/// ```
/// # use std::sync::Arc;
/// # use datafusion_common::arrow::datatypes::{Schema, Field, DataType};
/// # use datafusion_common::arrow::array::RecordBatch;
/// # use datafusion_physical_plan::stream::RecordBatchReceiverStreamBuilder;
/// # use futures::stream::StreamExt;
/// # use tokio::runtime::Builder;
/// # let rt = Builder::new_current_thread().build().unwrap();
/// #
/// # rt.block_on(async {
/// let schema = Arc::new(Schema::new(vec![Field::new("foo", DataType::Int8, false)]));
/// let mut builder = RecordBatchReceiverStreamBuilder::new(Arc::clone(&schema), 10);
///
/// // task 1
/// let tx_1 = builder.tx();
/// let schema_1 = Arc::clone(&schema);
/// builder.spawn(async move {
///     // Your task needs to send batches to the tx
///     tx_1.send(Ok(RecordBatch::new_empty(schema_1))).await.unwrap();
///
///     Ok(())
/// });
///
/// // task 2
/// let tx_2 = builder.tx();
/// let schema_2 = Arc::clone(&schema);
/// builder.spawn(async move {
///     // Your task needs to send batches to the tx
///     tx_2.send(Ok(RecordBatch::new_empty(schema_2))).await.unwrap();
///
///     Ok(())
/// });
///
/// let mut stream = builder.build();
/// while let Some(res_batch) = stream.next().await {
///     // `res_batch` can come from either task 1 or task 2
///
///     // do something with `res_batch`
/// }
/// # });
/// ```
pub struct RecordBatchReceiverStreamBuilder {
    schema: SchemaRef,
    inner: ReceiverStreamBuilder<RecordBatch>,
}

impl RecordBatchReceiverStreamBuilder {
    /// Create new channels with the specified buffer size
    pub fn new(schema: SchemaRef, capacity: usize) -> Self {
        Self {
            schema,
            inner: ReceiverStreamBuilder::new(capacity),
        }
    }

    /// Get a handle for sending [`RecordBatch`] to the output
    ///
    /// If the stream is dropped / canceled, the sender will be closed and
    /// calling `tx().send()` will return an error. Producers should stop
    /// producing in this case and return control.
    pub fn tx(&self) -> Sender<Result<RecordBatch>> {
        self.inner.tx()
    }

    /// Spawn a task that will be aborted if this builder (or the stream
    /// built from it) is dropped
    ///
    /// This is often used to spawn tasks that write to the sender
    /// retrieved from [`Self::tx`]; for examples, see the documentation
    /// of this type.
    pub fn spawn<F>(&mut self, task: F)
    where
        F: Future<Output = Result<()>>,
        F: Send + 'static,
    {
        self.inner.spawn(task)
    }

    /// Same as [`Self::spawn`] but it spawns the task on the provided runtime.
    pub fn spawn_on<F>(&mut self, task: F, handle: &Handle)
    where
        F: Future<Output = Result<()>>,
        F: Send + 'static,
    {
        self.inner.spawn_on(task, handle)
    }

    /// Spawn a blocking task tied to the builder and stream.
    ///
    /// # Drop / Cancel Behavior
    ///
    /// If this builder (or the stream built from it) is dropped **before** the
    /// task starts, the task is also dropped and will never start executing.
    ///
    /// **Note:** Once the blocking task has started, it **will not** be
    /// forcibly stopped on drop, as Rust does not allow forcing a running
    /// thread to terminate. The task will continue running until it completes
    /// or encounters an error.
    ///
    /// Users should ensure that their blocking function periodically checks
    /// for errors when calling `tx.blocking_send`. An error signals that the
    /// stream has been dropped / cancelled and the blocking task should exit.
    ///
    /// This is often used to spawn tasks that write to the sender
    /// retrieved from [`Self::tx`]; for examples, see the documentation
    /// of this type.
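    ///
    /// A minimal sketch (illustrative only, not compiled as a doctest) of a
    /// cooperative blocking producer that exits once the stream is gone:
    ///
    /// ```ignore
    /// let tx = builder.tx();
    /// builder.spawn_blocking(move || {
    ///     loop {
    ///         let batch = produce_next_batch(); // hypothetical helper
    ///         // `blocking_send` fails once the receiver stream is dropped;
    ///         // treat that as a cancellation signal and stop working
    ///         if tx.blocking_send(Ok(batch)).is_err() {
    ///             break;
    ///         }
    ///     }
    ///     Ok(())
    /// });
    /// ```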
    pub fn spawn_blocking<F>(&mut self, f: F)
    where
        F: FnOnce() -> Result<()>,
        F: Send + 'static,
    {
        self.inner.spawn_blocking(f)
    }

    /// Same as [`Self::spawn_blocking`] but it spawns the blocking task on the provided runtime.
    pub fn spawn_blocking_on<F>(&mut self, f: F, handle: &Handle)
    where
        F: FnOnce() -> Result<()>,
        F: Send + 'static,
    {
        self.inner.spawn_blocking_on(f, handle)
    }

    /// Runs the `partition` of the `input` ExecutionPlan on the
    /// tokio thread pool and writes its outputs to this stream
    ///
    /// If the input partition produces an error, the error will be
    /// sent to the output stream and no further results are sent.
    pub(crate) fn run_input(
        &mut self,
        input: Arc<dyn ExecutionPlan>,
        partition: usize,
        context: Arc<TaskContext>,
    ) {
        let output = self.tx();

        self.inner.spawn(async move {
            let mut stream = match input.execute(partition, context) {
                Err(e) => {
                    // If send fails, the plan is being torn down: there
                    // is no place to send the error and no reason to continue.
                    output.send(Err(e)).await.ok();
                    debug!(
                        "Stopping execution: error executing input: {}",
                        displayable(input.as_ref()).one_line()
                    );
                    return Ok(());
                }
                Ok(stream) => stream,
            };

            // Transfer batches from inner stream to the output tx
            // immediately.
            while let Some(item) = stream.next().await {
                let is_err = item.is_err();

                // If send fails, the plan is being torn down: there is no
                // place to send the error and no reason to continue.
                if output.send(item).await.is_err() {
                    debug!(
                        "Stopping execution: output is gone, plan cancelling: {}",
                        displayable(input.as_ref()).one_line()
                    );
                    return Ok(());
                }

                // Stop after the first error is encountered (don't
                // drive all streams to completion)
                if is_err {
                    debug!(
                        "Stopping execution: plan returned error: {}",
                        displayable(input.as_ref()).one_line()
                    );
                    return Ok(());
                }
            }

            Ok(())
        });
    }

    /// Create a stream of all [`RecordBatch`] written to `tx`
    pub fn build(self) -> SendableRecordBatchStream {
        Box::pin(RecordBatchStreamAdapter::new(
            self.schema,
            self.inner.build(),
        ))
    }
}

#[doc(hidden)]
pub struct RecordBatchReceiverStream {}

impl RecordBatchReceiverStream {
    /// Create a builder with an internal buffer of `capacity` batches.
    pub fn builder(
        schema: SchemaRef,
        capacity: usize,
    ) -> RecordBatchReceiverStreamBuilder {
        RecordBatchReceiverStreamBuilder::new(schema, capacity)
    }
}

pin_project! {
    /// Combines a [`Stream`] with a [`SchemaRef`] implementing
    /// [`SendableRecordBatchStream`] for the combination
    ///
    /// See [`Self::new`] for an example
    pub struct RecordBatchStreamAdapter<S> {
        schema: SchemaRef,

        #[pin]
        stream: S,
    }
}

impl<S> RecordBatchStreamAdapter<S> {
    /// Creates a new [`RecordBatchStreamAdapter`] from the provided schema and stream.
    ///
    /// Note that to create a [`SendableRecordBatchStream`] you must pin the result
    ///
    /// # Example
    /// ```
    /// # use arrow::array::record_batch;
    /// # use datafusion_execution::SendableRecordBatchStream;
    /// # use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
    /// // Create stream of Result<RecordBatch>
    /// let batch = record_batch!(
    ///     ("a", Int32, [1, 2, 3]),
    ///     ("b", Float64, [Some(4.0), None, Some(5.0)])
    /// ).expect("created batch");
    /// let schema = batch.schema();
    /// let stream = futures::stream::iter(vec![Ok(batch)]);
    /// // Convert the stream to a SendableRecordBatchStream
    /// let adapter = RecordBatchStreamAdapter::new(schema, stream);
    /// // Now you can use the adapter as a SendableRecordBatchStream
    /// let batch_stream: SendableRecordBatchStream = Box::pin(adapter);
    /// // ...
    /// ```
    pub fn new(schema: SchemaRef, stream: S) -> Self {
        Self { schema, stream }
    }
}

impl<S> std::fmt::Debug for RecordBatchStreamAdapter<S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RecordBatchStreamAdapter")
            .field("schema", &self.schema)
            .finish()
    }
}

impl<S> Stream for RecordBatchStreamAdapter<S>
where
    S: Stream<Item = Result<RecordBatch>>,
{
    type Item = Result<RecordBatch>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.project().stream.poll_next(cx)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.stream.size_hint()
    }
}

impl<S> RecordBatchStream for RecordBatchStreamAdapter<S>
where
    S: Stream<Item = Result<RecordBatch>>,
{
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

/// `EmptyRecordBatchStream` can be used to create a [`RecordBatchStream`]
/// that will produce no results
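///
/// # Example
///
/// A minimal sketch of building an empty stream for an operator to return:
///
/// ```
/// # use std::sync::Arc;
/// # use arrow::datatypes::Schema;
/// # use datafusion_execution::SendableRecordBatchStream;
/// # use datafusion_physical_plan::stream::EmptyRecordBatchStream;
/// let schema = Arc::new(Schema::empty());
/// let stream: SendableRecordBatchStream =
///     Box::pin(EmptyRecordBatchStream::new(schema));
/// ```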
pub struct EmptyRecordBatchStream {
    /// Schema wrapped by Arc
    schema: SchemaRef,
}

impl EmptyRecordBatchStream {
    /// Create an empty RecordBatchStream
    pub fn new(schema: SchemaRef) -> Self {
        Self { schema }
    }
}

impl RecordBatchStream for EmptyRecordBatchStream {
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

impl Stream for EmptyRecordBatchStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        Poll::Ready(None)
    }
}

/// Stream wrapper that records `BaselineMetrics` for a particular
/// [`SendableRecordBatchStream`] (likely a partition)
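///
/// A minimal sketch of wrapping a partition stream (illustrative only;
/// this type is crate-private, so the example is not compiled as a doctest):
///
/// ```ignore
/// // Record baseline metrics for partition 0 and stop after 100 rows
/// let metrics = ExecutionPlanMetricsSet::new();
/// let observed = ObservedStream::new(
///     inner_stream,
///     BaselineMetrics::new(&metrics, 0),
///     Some(100), // fetch limit
/// );
/// ```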
pub(crate) struct ObservedStream {
    inner: SendableRecordBatchStream,
    baseline_metrics: BaselineMetrics,
    fetch: Option<usize>,
    produced: usize,
}

impl ObservedStream {
    pub fn new(
        inner: SendableRecordBatchStream,
        baseline_metrics: BaselineMetrics,
        fetch: Option<usize>,
    ) -> Self {
        Self {
            inner,
            baseline_metrics,
            fetch,
            produced: 0,
        }
    }

    fn limit_reached(
        &mut self,
        poll: Poll<Option<Result<RecordBatch>>>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        let Some(fetch) = self.fetch else { return poll };

        if self.produced >= fetch {
            return Poll::Ready(None);
        }
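        // If the next ready batch would cross the `fetch` limit, slice it
        // down so that exactly `fetch` rows are produced in total.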
        if let Poll::Ready(Some(Ok(batch))) = &poll {
            if self.produced + batch.num_rows() > fetch {
                let batch = batch.slice(0, fetch.saturating_sub(self.produced));
                self.produced += batch.num_rows();
                return Poll::Ready(Some(Ok(batch)));
            }
            self.produced += batch.num_rows();
        }
        poll
    }
}

impl RecordBatchStream for ObservedStream {
    fn schema(&self) -> SchemaRef {
        self.inner.schema()
    }
}

impl Stream for ObservedStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let mut poll = self.inner.poll_next_unpin(cx);
        if self.fetch.is_some() {
            poll = self.limit_reached(poll);
        }
        self.baseline_metrics.record_poll(poll)
    }
}

pin_project! {
    /// Stream wrapper that splits large [`RecordBatch`]es into smaller batches.
    ///
    /// This ensures downstream operators receive batches no larger than
    /// `batch_size`, which can improve parallelism when data sources
    /// generate very large batches.
    ///
    /// # Fields
    ///
    /// - `current_batch`: The batch currently being split, if any
    /// - `offset`: Index of the next row to split from `current_batch`.
    ///   This tracks our position within the current batch being split.
    ///
    /// # Invariants
    ///
    /// - `offset` is always ≤ `current_batch.num_rows()` when `current_batch` is `Some`
    /// - When `current_batch` is `None`, `offset` is always 0
    /// - `batch_size` is always > 0
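    ///
    /// # Example
    ///
    /// A minimal sketch of wrapping an input stream (illustrative only; it is
    /// not compiled as a doctest since it elides metrics setup):
    ///
    /// ```ignore
    /// // Slice any batch larger than 1024 rows into <= 1024-row chunks
    /// let split = BatchSplitStream::new(input_stream, 1024, split_metrics);
    /// ```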
    pub struct BatchSplitStream {
        #[pin]
        input: SendableRecordBatchStream,
        schema: SchemaRef,
        batch_size: usize,
        metrics: SplitMetrics,
        current_batch: Option<RecordBatch>,
        offset: usize,
    }
}

impl BatchSplitStream {
    /// Create a new [`BatchSplitStream`]
    pub fn new(
        input: SendableRecordBatchStream,
        batch_size: usize,
        metrics: SplitMetrics,
    ) -> Self {
        let schema = input.schema();
        Self {
            input,
            schema,
            batch_size,
            metrics,
            current_batch: None,
            offset: 0,
        }
    }

    /// Attempt to produce the next sliced batch from the current batch.
    ///
    /// Returns `Some(batch)` if a slice was produced, `None` if the current batch
    /// is exhausted and we need to poll upstream for more data.
    fn next_sliced_batch(&mut self) -> Option<Result<RecordBatch>> {
        let batch = self.current_batch.take()?;

        // Assert slice boundary safety - offset should never exceed batch size
        debug_assert!(
            self.offset <= batch.num_rows(),
            "Offset {} exceeds batch size {}",
            self.offset,
            batch.num_rows()
        );

        let remaining = batch.num_rows() - self.offset;
        let to_take = remaining.min(self.batch_size);
        let out = batch.slice(self.offset, to_take);

        self.metrics.batches_split.add(1);
        self.offset += to_take;
        if self.offset < batch.num_rows() {
            // More data remains in this batch, store it back
            self.current_batch = Some(batch);
        } else {
            // Batch is exhausted, reset offset
            // Note: current_batch is already None since we took it at the start
            self.offset = 0;
        }
        Some(Ok(out))
    }

    /// Poll the upstream input for the next batch.
    ///
    /// Returns the appropriate `Poll` result based on upstream state.
    /// Small batches are passed through directly, large batches are stored
    /// for slicing and return the first slice immediately.
    fn poll_upstream(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        match ready!(self.input.as_mut().poll_next(cx)) {
            Some(Ok(batch)) => {
                if batch.num_rows() <= self.batch_size {
                    // Small batch, pass through directly
                    Poll::Ready(Some(Ok(batch)))
                } else {
                    // Large batch, store for slicing and return first slice
                    self.current_batch = Some(batch);
                    // Immediately produce the first slice
                    match self.next_sliced_batch() {
                        Some(result) => Poll::Ready(Some(result)),
                        // Unreachable: `current_batch` was just set above,
                        // so a slice is always produced
                        None => Poll::Ready(None),
                    }
                }
            }
            Some(Err(e)) => Poll::Ready(Some(Err(e))),
            None => Poll::Ready(None),
        }
    }
}

impl Stream for BatchSplitStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // First, try to produce a slice from the current batch
        if let Some(result) = self.next_sliced_batch() {
            return Poll::Ready(Some(result));
        }

        // No current batch or current batch exhausted, poll upstream
        self.poll_upstream(cx)
    }
}

impl RecordBatchStream for BatchSplitStream {
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::test::exec::{
        assert_strong_count_converges_to_zero, BlockingExec, MockExec, PanicExec,
    };

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::exec_err;

    fn schema() -> SchemaRef {
        Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]))
    }

    #[tokio::test]
    #[should_panic(expected = "PanickingStream did panic")]
    async fn record_batch_receiver_stream_propagates_panics() {
        let schema = schema();

        let num_partitions = 10;
        let input = PanicExec::new(Arc::clone(&schema), num_partitions);
        consume(input, 10).await
    }

    #[tokio::test]
    #[should_panic(expected = "PanickingStream did panic: 1")]
    async fn record_batch_receiver_stream_propagates_panics_early_shutdown() {
        let schema = schema();

        // Make 2 partitions, second partition panics before the first
        let num_partitions = 2;
        let input = PanicExec::new(Arc::clone(&schema), num_partitions)
            .with_partition_panic(0, 10)
            .with_partition_panic(1, 3); // partition 1 should panic first (after 3 batches)

        // Ensure that the panic results in an early shutdown (that
        // everything stops after the first panic).

        // The stream reads batches from the partitions alternately:
        // (0,1,0,1,0,panic), so it should not see more than 5 batches
        // prior to the panic
        let max_batches = 5;
        consume(input, max_batches).await
    }

    #[tokio::test]
    async fn record_batch_receiver_stream_drop_cancel() {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = schema();

        // Make an input that never proceeds
        let input = BlockingExec::new(Arc::clone(&schema), 1);
        let refs = input.refs();

        // Configure a RecordBatchReceiverStream to consume the input
        let mut builder = RecordBatchReceiverStream::builder(schema, 2);
        builder.run_input(Arc::new(input), 0, Arc::clone(&task_ctx));
        let stream = builder.build();

        // Input should still be present
        assert!(std::sync::Weak::strong_count(&refs) > 0);

        // Drop the stream, ensure the refs go to zero
        drop(stream);
        assert_strong_count_converges_to_zero(refs).await;
    }

    #[tokio::test]
    /// Ensure that if an error is received in one stream, the
    /// `RecordBatchReceiverStream` stops early and does not drive
    /// other streams to completion.
    async fn record_batch_receiver_stream_error_does_not_drive_completion() {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = schema();

        // Make an input that will error twice
        let error_stream = MockExec::new(
            vec![exec_err!("Test1"), exec_err!("Test2")],
            Arc::clone(&schema),
        )
        .with_use_task(false);

        let mut builder = RecordBatchReceiverStream::builder(schema, 2);
        builder.run_input(Arc::new(error_stream), 0, Arc::clone(&task_ctx));
        let mut stream = builder.build();

        // Get the first result, which should be an error
        let first_batch = stream.next().await.unwrap();
        let first_err = first_batch.unwrap_err();
        assert_eq!(first_err.strip_backtrace(), "Execution error: Test1");

        // There should be no more batches produced (should not get the second error)
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn batch_split_stream_basic_functionality() {
        use arrow::array::{Int32Array, RecordBatch};
        use futures::stream::{self, StreamExt};

        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

        // Create a large batch that should be split
        let large_batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int32Array::from((0..2000).collect::<Vec<_>>()))],
        )
        .unwrap();

        // Create a stream with the large batch
        let input_stream = stream::iter(vec![Ok(large_batch)]);
        let adapter = RecordBatchStreamAdapter::new(Arc::clone(&schema), input_stream);
        let batch_stream = Box::pin(adapter) as SendableRecordBatchStream;

        // Create a BatchSplitStream with batch_size = 500
        let metrics = ExecutionPlanMetricsSet::new();
        let split_metrics = SplitMetrics::new(&metrics, 0);
        let mut split_stream = BatchSplitStream::new(batch_stream, 500, split_metrics);

        let mut total_rows = 0;
        let mut batch_count = 0;

        while let Some(result) = split_stream.next().await {
            let batch = result.unwrap();
            assert!(batch.num_rows() <= 500, "Batch size should not exceed 500");
            total_rows += batch.num_rows();
            batch_count += 1;
        }

        assert_eq!(total_rows, 2000, "All rows should be preserved");
        assert_eq!(batch_count, 4, "Should have 4 batches of 500 rows each");
    }

    /// Consumes all the input's partitions into a
    /// RecordBatchReceiverStream and runs it to completion
    ///
    /// Panics if more than `max_batches` batches are seen
    async fn consume(input: PanicExec, max_batches: usize) {
        let task_ctx = Arc::new(TaskContext::default());

        let input = Arc::new(input);
        let num_partitions = input.properties().output_partitioning().partition_count();

        // Configure a RecordBatchReceiverStream to consume all the input partitions
        let mut builder =
            RecordBatchReceiverStream::builder(input.schema(), num_partitions);
        for partition in 0..num_partitions {
            builder.run_input(
                Arc::clone(&input) as Arc<dyn ExecutionPlan>,
                partition,
                Arc::clone(&task_ctx),
            );
        }
        let mut stream = builder.build();

        // Drain the stream until it is complete, panicking on error
        let mut num_batches = 0;
        while let Some(next) = stream.next().await {
            next.unwrap();
            num_batches += 1;
            assert!(
                num_batches < max_batches,
                "Got the limit of {num_batches} batches before seeing panic"
            );
        }
    }

    #[test]
    fn record_batch_receiver_stream_builder_spawn_on_runtime() {
        let tokio_runtime = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .unwrap();

        let mut builder =
            RecordBatchReceiverStreamBuilder::new(Arc::new(Schema::empty()), 10);

        let tx1 = builder.tx();
        builder.spawn_on(
            async move {
                tx1.send(Ok(RecordBatch::new_empty(Arc::new(Schema::empty()))))
                    .await
                    .unwrap();

                Ok(())
            },
            tokio_runtime.handle(),
        );

        let tx2 = builder.tx();
        builder.spawn_blocking_on(
            move || {
                tx2.blocking_send(Ok(RecordBatch::new_empty(Arc::new(Schema::empty()))))
                    .unwrap();

                Ok(())
            },
            tokio_runtime.handle(),
        );

        let mut stream = builder.build();

        let mut number_of_batches = 0;

        loop {
            let poll = stream.poll_next_unpin(&mut Context::from_waker(
                futures::task::noop_waker_ref(),
            ));

            match poll {
                Poll::Ready(None) => {
                    break;
                }
                Poll::Ready(Some(Ok(batch))) => {
                    number_of_batches += 1;
                    assert_eq!(batch.num_rows(), 0);
                }
                Poll::Ready(Some(Err(e))) => panic!("Unexpected error: {e}"),
                Poll::Pending => {
                    continue;
                }
            }
        }

        assert_eq!(
            number_of_batches, 2,
            "Should have received exactly two empty batches"
        );
    }
}
925}