1use std::pin::Pin;
21use std::sync::Arc;
22use std::task::Context;
23use std::task::Poll;
24
25#[cfg(test)]
26use super::metrics::ExecutionPlanMetricsSet;
27use super::metrics::{BaselineMetrics, SplitMetrics};
28use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};
29use crate::displayable;
30use crate::spill::get_record_batch_memory_size;
31
32use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
33use datafusion_common::{Result, exec_err};
34use datafusion_common_runtime::JoinSet;
35use datafusion_execution::TaskContext;
36use datafusion_execution::memory_pool::MemoryReservation;
37
38use futures::ready;
39use futures::stream::BoxStream;
40use futures::{Future, Stream, StreamExt};
41use log::debug;
42use pin_project_lite::pin_project;
43use tokio::runtime::Handle;
44use tokio::sync::mpsc::{Receiver, Sender};
45
46pub(crate) struct ReceiverStreamBuilder<O> {
58 tx: Sender<Result<O>>,
59 rx: Receiver<Result<O>>,
60 join_set: JoinSet<Result<()>>,
61}
62
63impl<O: Send + 'static> ReceiverStreamBuilder<O> {
64 pub fn new(capacity: usize) -> Self {
66 let (tx, rx) = tokio::sync::mpsc::channel(capacity);
67
68 Self {
69 tx,
70 rx,
71 join_set: JoinSet::new(),
72 }
73 }
74
75 pub fn tx(&self) -> Sender<Result<O>> {
77 self.tx.clone()
78 }
79
80 pub fn spawn<F>(&mut self, task: F)
83 where
84 F: Future<Output = Result<()>>,
85 F: Send + 'static,
86 {
87 self.join_set.spawn(task);
88 }
89
90 pub fn spawn_on<F>(&mut self, task: F, handle: &Handle)
92 where
93 F: Future<Output = Result<()>>,
94 F: Send + 'static,
95 {
96 self.join_set.spawn_on(task, handle);
97 }
98
99 pub fn spawn_blocking<F>(&mut self, f: F)
105 where
106 F: FnOnce() -> Result<()>,
107 F: Send + 'static,
108 {
109 self.join_set.spawn_blocking(f);
110 }
111
112 pub fn spawn_blocking_on<F>(&mut self, f: F, handle: &Handle)
114 where
115 F: FnOnce() -> Result<()>,
116 F: Send + 'static,
117 {
118 self.join_set.spawn_blocking_on(f, handle);
119 }
120
121 pub fn build(self) -> BoxStream<'static, Result<O>> {
123 let Self {
124 tx,
125 rx,
126 mut join_set,
127 } = self;
128
129 drop(tx);
131
132 let check = async move {
134 while let Some(result) = join_set.join_next().await {
135 match result {
136 Ok(task_result) => {
137 match task_result {
138 Ok(_) => continue,
140 Err(error) => return Some(Err(error)),
142 }
143 }
144 Err(e) => {
146 if e.is_panic() {
147 std::panic::resume_unwind(e.into_panic());
149 } else {
150 return Some(exec_err!("Non Panic Task error: {e}"));
156 }
157 }
158 }
159 }
160 None
161 };
162
163 let check_stream = futures::stream::once(check)
164 .filter_map(|item| async move { item });
166
167 let rx_stream = futures::stream::unfold(rx, |mut rx| async move {
169 let next_item = rx.recv().await;
170 next_item.map(|next_item| (next_item, rx))
171 });
172
173 futures::stream::select(rx_stream, check_stream).boxed()
176 }
177}
178
179pub struct RecordBatchReceiverStreamBuilder {
240 schema: SchemaRef,
241 inner: ReceiverStreamBuilder<RecordBatch>,
242}
243
244impl RecordBatchReceiverStreamBuilder {
245 pub fn new(schema: SchemaRef, capacity: usize) -> Self {
247 Self {
248 schema,
249 inner: ReceiverStreamBuilder::new(capacity),
250 }
251 }
252
253 pub fn tx(&self) -> Sender<Result<RecordBatch>> {
259 self.inner.tx()
260 }
261
262 pub fn spawn<F>(&mut self, task: F)
269 where
270 F: Future<Output = Result<()>>,
271 F: Send + 'static,
272 {
273 self.inner.spawn(task)
274 }
275
276 pub fn spawn_on<F>(&mut self, task: F, handle: &Handle)
278 where
279 F: Future<Output = Result<()>>,
280 F: Send + 'static,
281 {
282 self.inner.spawn_on(task, handle)
283 }
284
285 pub fn spawn_blocking<F>(&mut self, f: F)
305 where
306 F: FnOnce() -> Result<()>,
307 F: Send + 'static,
308 {
309 self.inner.spawn_blocking(f)
310 }
311
312 pub fn spawn_blocking_on<F>(&mut self, f: F, handle: &Handle)
314 where
315 F: FnOnce() -> Result<()>,
316 F: Send + 'static,
317 {
318 self.inner.spawn_blocking_on(f, handle)
319 }
320
321 pub(crate) fn run_input(
327 &mut self,
328 input: Arc<dyn ExecutionPlan>,
329 partition: usize,
330 context: Arc<TaskContext>,
331 ) {
332 let output = self.tx();
333 let input_display = if log::log_enabled!(log::Level::Debug) {
334 displayable(input.as_ref()).one_line().to_string()
335 } else {
336 String::new()
337 };
338
339 self.inner.spawn(async move {
340 let mut stream = match input.execute(partition, context) {
341 Err(e) => {
342 output.send(Err(e)).await.ok();
345 debug!(
346 "Stopping execution: error executing input: {input_display}",
347 );
348 return Ok(());
349 }
350 Ok(stream) => stream,
351 };
352
353 drop(input);
357
358 while let Some(item) = stream.next().await {
361 let is_err = item.is_err();
362
363 if output.send(item).await.is_err() {
366 debug!(
367 "Stopping execution: output is gone, plan cancelling: {input_display}",
368 );
369 return Ok(());
370 }
371
372 if is_err {
375 debug!("Stopping execution: plan returned error: {input_display}");
376 return Ok(());
377 }
378 }
379
380 Ok(())
381 });
382 }
383
384 pub fn build(self) -> SendableRecordBatchStream {
386 Box::pin(RecordBatchStreamAdapter::new(
387 self.schema,
388 self.inner.build(),
389 ))
390 }
391}
392
/// Field-less type whose only role in this file is to expose
/// `RecordBatchReceiverStream::builder`.
#[doc(hidden)]
pub struct RecordBatchReceiverStream {}
395
396impl RecordBatchReceiverStream {
397 pub fn builder(
399 schema: SchemaRef,
400 capacity: usize,
401 ) -> RecordBatchReceiverStreamBuilder {
402 RecordBatchReceiverStreamBuilder::new(schema, capacity)
403 }
404}
405
406pin_project! {
407 pub struct RecordBatchStreamAdapter<S> {
412 schema: SchemaRef,
413
414 #[pin]
418 stream: Option<S>,
419 }
420}
421
422impl<S> RecordBatchStreamAdapter<S> {
423 pub fn new(schema: SchemaRef, stream: S) -> Self {
447 Self {
448 schema,
449 stream: Some(stream),
450 }
451 }
452}
453
454impl<S> std::fmt::Debug for RecordBatchStreamAdapter<S> {
455 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
456 f.debug_struct("RecordBatchStreamAdapter")
457 .field("schema", &self.schema)
458 .finish()
459 }
460}
461
462impl<S> Stream for RecordBatchStreamAdapter<S>
463where
464 S: Stream<Item = Result<RecordBatch>>,
465{
466 type Item = Result<RecordBatch>;
467
468 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
469 let mut this = self.project();
470 let Some(inner) = this.stream.as_mut().as_pin_mut() else {
471 return Poll::Ready(None);
472 };
473 let item = ready!(inner.poll_next(cx));
474 if item.is_none() {
475 unsafe {
481 *this.stream.as_mut().get_unchecked_mut() = None;
482 }
483 }
484 Poll::Ready(item)
485 }
486
487 fn size_hint(&self) -> (usize, Option<usize>) {
488 match self.stream.as_ref() {
489 Some(stream) => stream.size_hint(),
490 None => (0, Some(0)),
491 }
492 }
493}
494
495impl<S> RecordBatchStream for RecordBatchStreamAdapter<S>
496where
497 S: Stream<Item = Result<RecordBatch>>,
498{
499 fn schema(&self) -> SchemaRef {
500 Arc::clone(&self.schema)
501 }
502}
503
504pub struct EmptyRecordBatchStream {
507 schema: SchemaRef,
509}
510
511impl EmptyRecordBatchStream {
512 pub fn new(schema: SchemaRef) -> Self {
514 Self { schema }
515 }
516}
517
518impl RecordBatchStream for EmptyRecordBatchStream {
519 fn schema(&self) -> SchemaRef {
520 Arc::clone(&self.schema)
521 }
522}
523
524impl Stream for EmptyRecordBatchStream {
525 type Item = Result<RecordBatch>;
526
527 fn poll_next(
528 self: Pin<&mut Self>,
529 _cx: &mut Context<'_>,
530 ) -> Poll<Option<Self::Item>> {
531 Poll::Ready(None)
532 }
533}
534
535pub(crate) struct ObservedStream {
538 inner: SendableRecordBatchStream,
539 baseline_metrics: BaselineMetrics,
540 fetch: Option<usize>,
541 produced: usize,
542}
543
544impl ObservedStream {
545 pub fn new(
546 inner: SendableRecordBatchStream,
547 baseline_metrics: BaselineMetrics,
548 fetch: Option<usize>,
549 ) -> Self {
550 Self {
551 inner,
552 baseline_metrics,
553 fetch,
554 produced: 0,
555 }
556 }
557
558 fn limit_reached(
559 &mut self,
560 poll: Poll<Option<Result<RecordBatch>>>,
561 ) -> Poll<Option<Result<RecordBatch>>> {
562 let Some(fetch) = self.fetch else { return poll };
563
564 if self.produced >= fetch {
565 self.release_inner();
566 return Poll::Ready(None);
567 }
568
569 if let Poll::Ready(Some(Ok(batch))) = &poll {
570 if self.produced + batch.num_rows() > fetch {
571 let batch = batch.slice(0, fetch.saturating_sub(self.produced));
572 self.produced += batch.num_rows();
573 if self.produced >= fetch {
574 self.release_inner();
575 }
576 return Poll::Ready(Some(Ok(batch)));
577 };
578 self.produced += batch.num_rows()
579 }
580 poll
581 }
582
583 fn release_inner(&mut self) {
586 let schema = self.inner.schema();
587 self.inner = Box::pin(EmptyRecordBatchStream::new(schema));
588 }
589}
590
591impl RecordBatchStream for ObservedStream {
592 fn schema(&self) -> SchemaRef {
593 self.inner.schema()
594 }
595}
596
597impl Stream for ObservedStream {
598 type Item = Result<RecordBatch>;
599
600 fn poll_next(
601 mut self: Pin<&mut Self>,
602 cx: &mut Context<'_>,
603 ) -> Poll<Option<Self::Item>> {
604 let mut poll = self.inner.poll_next_unpin(cx);
605 if self.fetch.is_some() {
606 poll = self.limit_reached(poll);
607 }
608 self.baseline_metrics.record_poll(poll)
609 }
610}
611
612pin_project! {
613 pub struct BatchSplitStream {
631 #[pin]
632 input: SendableRecordBatchStream,
633 schema: SchemaRef,
634 batch_size: usize,
635 metrics: SplitMetrics,
636 current_batch: Option<RecordBatch>,
637 offset: usize,
638 }
639}
640
641impl BatchSplitStream {
642 pub fn new(
644 input: SendableRecordBatchStream,
645 batch_size: usize,
646 metrics: SplitMetrics,
647 ) -> Self {
648 let schema = input.schema();
649 Self {
650 input,
651 schema,
652 batch_size,
653 metrics,
654 current_batch: None,
655 offset: 0,
656 }
657 }
658
659 fn next_sliced_batch(&mut self) -> Option<Result<RecordBatch>> {
664 let batch = self.current_batch.take()?;
665
666 debug_assert!(
668 self.offset <= batch.num_rows(),
669 "Offset {} exceeds batch size {}",
670 self.offset,
671 batch.num_rows()
672 );
673
674 let remaining = batch.num_rows() - self.offset;
675 let to_take = remaining.min(self.batch_size);
676 let out = batch.slice(self.offset, to_take);
677
678 self.metrics.batches_split.add(1);
679 self.offset += to_take;
680 if self.offset < batch.num_rows() {
681 self.current_batch = Some(batch);
683 } else {
684 self.offset = 0;
687 }
688 Some(Ok(out))
689 }
690
691 fn poll_upstream(
697 &mut self,
698 cx: &mut Context<'_>,
699 ) -> Poll<Option<Result<RecordBatch>>> {
700 match ready!(self.input.as_mut().poll_next(cx)) {
701 Some(Ok(batch)) => {
702 if batch.num_rows() <= self.batch_size {
703 Poll::Ready(Some(Ok(batch)))
705 } else {
706 self.current_batch = Some(batch);
708 match self.next_sliced_batch() {
710 Some(result) => Poll::Ready(Some(result)),
711 None => Poll::Ready(None), }
713 }
714 }
715 Some(Err(e)) => Poll::Ready(Some(Err(e))),
716 None => {
717 let input_schema = self.input.schema();
719 self.input = Box::pin(EmptyRecordBatchStream::new(input_schema));
720 Poll::Ready(None)
721 }
722 }
723 }
724}
725
726impl Stream for BatchSplitStream {
727 type Item = Result<RecordBatch>;
728
729 fn poll_next(
730 mut self: Pin<&mut Self>,
731 cx: &mut Context<'_>,
732 ) -> Poll<Option<Self::Item>> {
733 if let Some(result) = self.next_sliced_batch() {
735 return Poll::Ready(Some(result));
736 }
737
738 self.poll_upstream(cx)
740 }
741}
742
743impl RecordBatchStream for BatchSplitStream {
744 fn schema(&self) -> SchemaRef {
745 Arc::clone(&self.schema)
746 }
747}
748
749pub(crate) struct ReservationStream {
754 schema: SchemaRef,
755 inner: SendableRecordBatchStream,
756 reservation: MemoryReservation,
757}
758
759impl ReservationStream {
760 pub(crate) fn new(
761 schema: SchemaRef,
762 inner: SendableRecordBatchStream,
763 reservation: MemoryReservation,
764 ) -> Self {
765 Self {
766 schema,
767 inner,
768 reservation,
769 }
770 }
771}
772
773impl Stream for ReservationStream {
774 type Item = Result<RecordBatch>;
775
776 fn poll_next(
777 mut self: Pin<&mut Self>,
778 cx: &mut Context<'_>,
779 ) -> Poll<Option<Self::Item>> {
780 let res = self.inner.poll_next_unpin(cx);
781
782 match res {
783 Poll::Ready(res) => {
784 match res {
785 Some(Ok(batch)) => {
786 self.reservation
787 .shrink(get_record_batch_memory_size(&batch));
788 Poll::Ready(Some(Ok(batch)))
789 }
790 Some(Err(err)) => Poll::Ready(Some(Err(err))),
791 None => {
792 self.reservation.free();
794 let inner_schema = self.inner.schema();
796 self.inner = Box::pin(EmptyRecordBatchStream::new(inner_schema));
797 Poll::Ready(None)
798 }
799 }
800 }
801 Poll::Pending => Poll::Pending,
802 }
803 }
804
805 fn size_hint(&self) -> (usize, Option<usize>) {
806 self.inner.size_hint()
807 }
808}
809
810impl RecordBatchStream for ReservationStream {
811 fn schema(&self) -> SchemaRef {
812 Arc::clone(&self.schema)
813 }
814}
815
#[cfg(test)]
mod test {
    use super::*;
    use crate::test::exec::{
        BlockingExec, MockExec, PanicExec, assert_strong_count_converges_to_zero,
    };

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::exec_err;

    /// Schema with a single nullable `Float32` column named "a".
    fn schema() -> SchemaRef {
        let fields = vec![Field::new("a", DataType::Float32, true)];
        Arc::new(Schema::new(fields))
    }

    /// A panic raised inside an input must surface from the receiver stream.
    #[tokio::test]
    #[should_panic(expected = "PanickingStream did panic")]
    async fn record_batch_receiver_stream_propagates_panics() {
        let schema = schema();

        let partitions = 10;
        let plan = PanicExec::new(Arc::clone(&schema), partitions);
        consume(plan, 10).await
    }

    /// The panic of partition 1 must be the one observed when partitions are
    /// configured with different panic settings.
    #[tokio::test]
    #[should_panic(expected = "PanickingStream did panic: 1")]
    async fn record_batch_receiver_stream_propagates_panics_early_shutdown() {
        let schema = schema();

        let partitions = 2;
        let plan = PanicExec::new(Arc::clone(&schema), partitions)
            .with_partition_panic(0, 10)
            .with_partition_panic(1, 3);

        let max_batches = 5;
        consume(plan, max_batches).await
    }

    /// Dropping the stream must release the references held by the input.
    #[tokio::test]
    async fn record_batch_receiver_stream_drop_cancel() {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = schema();

        let blocking = BlockingExec::new(Arc::clone(&schema), 1);
        let refs = blocking.refs();

        let mut builder = RecordBatchReceiverStream::builder(schema, 2);
        builder.run_input(Arc::new(blocking), 0, Arc::clone(&task_ctx));
        let stream = builder.build();

        // The input is still referenced while the stream is alive.
        assert!(std::sync::Weak::strong_count(&refs) > 0);

        drop(stream);
        assert_strong_count_converges_to_zero(refs).await;
    }

    /// After the first error no further items may be produced, even though
    /// the input holds a second error.
    #[tokio::test]
    async fn record_batch_receiver_stream_error_does_not_drive_completion() {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = schema();

        let failing_input = MockExec::new(
            vec![exec_err!("Test1"), exec_err!("Test2")],
            Arc::clone(&schema),
        )
        .with_use_task(false);

        let mut builder = RecordBatchReceiverStream::builder(schema, 2);
        builder.run_input(Arc::new(failing_input), 0, Arc::clone(&task_ctx));
        let mut stream = builder.build();

        let first_item = stream.next().await.unwrap();
        let first_err = first_item.unwrap_err();
        assert_eq!(first_err.strip_backtrace(), "Execution error: Test1");

        // The second error must never be observed.
        assert!(stream.next().await.is_none());
    }

    /// A 2000-row batch split at 500 rows yields four batches and loses no rows.
    #[tokio::test]
    async fn batch_split_stream_basic_functionality() {
        use arrow::array::{Int32Array, RecordBatch};
        use futures::stream::{self, StreamExt};

        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

        let values = (0..2000).collect::<Vec<_>>();
        let large_batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int32Array::from(values))],
        )
        .unwrap();

        let source = stream::iter(vec![Ok(large_batch)]);
        let adapter = RecordBatchStreamAdapter::new(Arc::clone(&schema), source);
        let batch_stream = Box::pin(adapter) as SendableRecordBatchStream;

        let metrics = ExecutionPlanMetricsSet::new();
        let split_metrics = SplitMetrics::new(&metrics, 0);
        let mut split_stream = BatchSplitStream::new(batch_stream, 500, split_metrics);

        let mut total_rows = 0;
        let mut batch_count = 0;

        while let Some(result) = split_stream.next().await {
            let batch = result.unwrap();
            let rows = batch.num_rows();
            assert!(batch.num_rows() <= 500, "Batch size should not exceed 500");
            total_rows += rows;
            batch_count += 1;
        }

        assert_eq!(total_rows, 2000, "All rows should be preserved");
        assert_eq!(batch_count, 4, "Should have 4 batches of 500 rows each");
    }

    /// Runs every partition of `input` through a receiver stream and drains
    /// it, failing if `max_batches` batches are seen before a panic occurs.
    async fn consume(input: PanicExec, max_batches: usize) {
        let task_ctx = Arc::new(TaskContext::default());

        let input = Arc::new(input);
        let num_partitions = input.properties().output_partitioning().partition_count();

        // One channel slot per partition.
        let mut builder =
            RecordBatchReceiverStream::builder(input.schema(), num_partitions);
        for partition in 0..num_partitions {
            let plan = Arc::clone(&input) as Arc<dyn ExecutionPlan>;
            builder.run_input(plan, partition, Arc::clone(&task_ctx));
        }
        let mut stream = builder.build();

        let mut num_batches = 0;
        while let Some(next) = stream.next().await {
            next.unwrap();
            num_batches += 1;
            assert!(
                num_batches < max_batches,
                "Got the limit of {num_batches} batches before seeing panic"
            );
        }
    }

    /// Tasks spawned on an explicit runtime handle must deliver their batches
    /// even though the stream is polled from outside that runtime.
    #[test]
    fn record_batch_receiver_stream_builder_spawn_on_runtime() {
        let tokio_runtime = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .unwrap();

        let mut builder =
            RecordBatchReceiverStreamBuilder::new(Arc::new(Schema::empty()), 10);

        let async_tx = builder.tx();
        builder.spawn_on(
            async move {
                let empty = RecordBatch::new_empty(Arc::new(Schema::empty()));
                async_tx.send(Ok(empty)).await.unwrap();

                Ok(())
            },
            tokio_runtime.handle(),
        );

        let blocking_tx = builder.tx();
        builder.spawn_blocking_on(
            move || {
                let empty = RecordBatch::new_empty(Arc::new(Schema::empty()));
                blocking_tx.blocking_send(Ok(empty)).unwrap();

                Ok(())
            },
            tokio_runtime.handle(),
        );

        let mut stream = builder.build();

        let mut number_of_batches = 0;

        // Busy-poll with a no-op waker until the stream reports its end.
        let mut cx = Context::from_waker(futures::task::noop_waker_ref());
        loop {
            match stream.poll_next_unpin(&mut cx) {
                Poll::Pending => continue,
                Poll::Ready(None) => break,
                Poll::Ready(Some(Err(e))) => panic!("Unexpected error: {e}"),
                Poll::Ready(Some(Ok(batch))) => {
                    number_of_batches += 1;
                    assert_eq!(batch.num_rows(), 0);
                }
            }
        }

        assert_eq!(
            number_of_batches, 2,
            "Should have received exactly two empty batches"
        );
    }

    /// Each emitted batch must shrink the reservation by that batch's size.
    #[tokio::test]
    async fn test_reservation_stream_shrinks_on_poll() {
        use arrow::array::Int32Array;
        use datafusion_execution::memory_pool::MemoryConsumer;
        use datafusion_execution::runtime_env::RuntimeEnvBuilder;

        let runtime = RuntimeEnvBuilder::new()
            .with_memory_limit(10 * 1024 * 1024, 1.0)
            .build_arc()
            .unwrap();

        let reservation = MemoryConsumer::new("test").register(&runtime.memory_pool);

        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

        let batch1 = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))],
        )
        .unwrap();
        let batch2 = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int32Array::from(vec![6, 7, 8, 9, 10]))],
        )
        .unwrap();

        let batch1_size = get_record_batch_memory_size(&batch1);
        let batch2_size = get_record_batch_memory_size(&batch2);

        // Reserve exactly what the two batches account for.
        reservation.try_grow(batch1_size + batch2_size).unwrap();
        let initial_reserved = runtime.memory_pool.reserved();
        assert_eq!(initial_reserved, batch1_size + batch2_size);

        let source = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);
        let inner = Box::pin(RecordBatchStreamAdapter::new(Arc::clone(&schema), source))
            as SendableRecordBatchStream;

        let mut res_stream =
            ReservationStream::new(Arc::clone(&schema), inner, reservation);

        // First batch: only the second batch's share stays reserved.
        let first = res_stream.next().await;
        assert!(first.is_some());

        let after_first = runtime.memory_pool.reserved();
        assert_eq!(after_first, batch2_size);

        // Second batch: nothing stays reserved.
        let second = res_stream.next().await;
        assert!(second.is_some());

        let after_second = runtime.memory_pool.reserved();
        assert_eq!(after_second, 0);

        // End of stream: the reservation stays at zero.
        let third = res_stream.next().await;
        assert!(third.is_none());

        assert_eq!(runtime.memory_pool.reserved(), 0);
    }

    /// An error must leave the reservation untouched until the stream is
    /// dropped.
    #[tokio::test]
    async fn test_reservation_stream_error_handling() {
        use datafusion_execution::memory_pool::MemoryConsumer;
        use datafusion_execution::runtime_env::RuntimeEnvBuilder;

        let runtime = RuntimeEnvBuilder::new()
            .with_memory_limit(10 * 1024 * 1024, 1.0)
            .build_arc()
            .unwrap();

        let reservation = MemoryConsumer::new("test").register(&runtime.memory_pool);

        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

        reservation.try_grow(1000).unwrap();
        let initial = runtime.memory_pool.reserved();
        assert_eq!(initial, 1000);

        let source = futures::stream::iter(vec![exec_err!("Test error")]);
        let inner = Box::pin(RecordBatchStreamAdapter::new(Arc::clone(&schema), source))
            as SendableRecordBatchStream;

        let mut res_stream =
            ReservationStream::new(Arc::clone(&schema), inner, reservation);

        let result = res_stream.next().await;
        assert!(result.is_some());
        assert!(result.unwrap().is_err());

        let after_error = runtime.memory_pool.reserved();
        assert_eq!(
            after_error, 1000,
            "Reservation should still be held after error"
        );

        drop(res_stream);

        assert_eq!(
            runtime.memory_pool.reserved(),
            0,
            "Memory should be freed when stream is dropped"
        );
    }
}