1use crate::execution_plan::{CardinalityEffect, SchedulingType};
22use crate::filter_pushdown::{
23 ChildPushdownResult, FilterDescription, FilterPushdownPhase,
24 FilterPushdownPropagation,
25};
26use crate::projection::ProjectionExec;
27use crate::stream::RecordBatchStreamAdapter;
28use crate::{
29 DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SortOrderPushdownResult,
30 check_if_same_properties,
31};
32use arrow::array::RecordBatch;
33use datafusion_common::config::ConfigOptions;
34use datafusion_common::{Result, Statistics, internal_err, plan_err};
35use datafusion_common_runtime::SpawnedTask;
36use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
37use datafusion_execution::{SendableRecordBatchStream, TaskContext};
38use datafusion_physical_expr_common::metrics::{
39 ExecutionPlanMetricsSet, MetricBuilder, MetricsSet,
40};
41use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
42use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
43use futures::{Stream, StreamExt, TryStreamExt};
44use pin_project_lite::pin_project;
45use std::any::Any;
46use std::fmt;
47use std::pin::Pin;
48use std::sync::Arc;
49use std::sync::atomic::{AtomicUsize, Ordering};
50use std::task::{Context, Poll};
51use tokio::sync::mpsc::UnboundedReceiver;
52use tokio::sync::{OwnedSemaphorePermit, Semaphore};
53
#[derive(Debug, Clone)]
/// Execution plan that decouples its single input from its consumer: the
/// input is polled eagerly on a background task and the produced batches are
/// buffered until the consumer pulls them.
///
/// The amount buffered is bounded by `capacity`, measured with
/// `RecordBatch::get_array_memory_size`. Each batch holds `min(size, capacity)`
/// of that budget until it is yielded downstream, so a single batch larger than
/// `capacity` can still pass through on its own. Buffered batches are also
/// accounted in a memory reservation registered with the task's memory pool,
/// so the pool may refuse (and fail the stream) when memory runs out.
///
/// Per-partition gauges `max_mem_used` and `max_queued` record the peak bytes
/// and batch count seen between the input producing a batch and this node
/// yielding it.
pub struct BufferExec {
    /// Child plan whose output is buffered.
    input: Arc<dyn ExecutionPlan>,
    /// Cached plan properties: those of `input`, with cooperative scheduling.
    properties: Arc<PlanProperties>,
    /// Buffer budget handed to `MemoryBufferedStream` for every partition.
    capacity: usize,
    /// Metrics registered by `execute` (`max_mem_used`, `max_queued`).
    metrics: ExecutionPlanMetricsSet,
}
100
101impl BufferExec {
102 pub fn new(input: Arc<dyn ExecutionPlan>, capacity: usize) -> Self {
104 let properties = PlanProperties::clone(input.properties())
105 .with_scheduling_type(SchedulingType::Cooperative);
106
107 Self {
108 input,
109 properties: Arc::new(properties),
110 capacity,
111 metrics: ExecutionPlanMetricsSet::new(),
112 }
113 }
114
115 pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
117 &self.input
118 }
119
120 pub fn capacity(&self) -> usize {
122 self.capacity
123 }
124
125 fn with_new_children_and_same_properties(
126 &self,
127 mut children: Vec<Arc<dyn ExecutionPlan>>,
128 ) -> Self {
129 Self {
130 input: children.swap_remove(0),
131 metrics: ExecutionPlanMetricsSet::new(),
132 ..Self::clone(self)
133 }
134 }
135}
136
137impl DisplayAs for BufferExec {
138 fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
139 match t {
140 DisplayFormatType::Default | DisplayFormatType::Verbose => {
141 write!(f, "BufferExec: capacity={}", self.capacity)
142 }
143 DisplayFormatType::TreeRender => {
144 writeln!(f, "target_batch_size={}", self.capacity)
145 }
146 }
147 }
148}
149
150impl ExecutionPlan for BufferExec {
151 fn name(&self) -> &str {
152 "BufferExec"
153 }
154
155 fn as_any(&self) -> &dyn Any {
156 self
157 }
158
159 fn properties(&self) -> &Arc<PlanProperties> {
160 &self.properties
161 }
162
163 fn maintains_input_order(&self) -> Vec<bool> {
164 vec![true]
165 }
166
167 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
168 vec![false]
169 }
170
171 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
172 vec![&self.input]
173 }
174
175 fn with_new_children(
176 self: Arc<Self>,
177 mut children: Vec<Arc<dyn ExecutionPlan>>,
178 ) -> Result<Arc<dyn ExecutionPlan>> {
179 check_if_same_properties!(self, children);
180 if children.len() != 1 {
181 return plan_err!("BufferExec can only have one child");
182 }
183 Ok(Arc::new(Self::new(children.swap_remove(0), self.capacity)))
184 }
185
186 fn execute(
187 &self,
188 partition: usize,
189 context: Arc<TaskContext>,
190 ) -> Result<SendableRecordBatchStream> {
191 let mem_reservation = MemoryConsumer::new(format!("BufferExec[{partition}]"))
192 .register(context.memory_pool());
193 let in_stream = self.input.execute(partition, context)?;
194
195 let curr_mem_in = Arc::new(AtomicUsize::new(0));
197 let curr_mem_out = Arc::clone(&curr_mem_in);
198 let mut max_mem_in = 0;
199 let max_mem = MetricBuilder::new(&self.metrics).gauge("max_mem_used", partition);
200
201 let curr_queued_in = Arc::new(AtomicUsize::new(0));
202 let curr_queued_out = Arc::clone(&curr_queued_in);
203 let mut max_queued_in = 0;
204 let max_queued = MetricBuilder::new(&self.metrics).gauge("max_queued", partition);
205
206 let in_stream = in_stream.inspect_ok(move |v| {
208 let size = v.get_array_memory_size();
209 let curr_size = curr_mem_in.fetch_add(size, Ordering::Relaxed) + size;
210 if curr_size > max_mem_in {
211 max_mem_in = curr_size;
212 max_mem.set(max_mem_in);
213 }
214
215 let curr_queued = curr_queued_in.fetch_add(1, Ordering::Relaxed) + 1;
216 if curr_queued > max_queued_in {
217 max_queued_in = curr_queued;
218 max_queued.set(max_queued_in);
219 }
220 });
221 let out_stream =
223 MemoryBufferedStream::new(in_stream, self.capacity, mem_reservation);
224 let out_stream = out_stream.inspect_ok(move |v| {
226 curr_mem_out.fetch_sub(v.get_array_memory_size(), Ordering::Relaxed);
227 curr_queued_out.fetch_sub(1, Ordering::Relaxed);
228 });
229
230 Ok(Box::pin(RecordBatchStreamAdapter::new(
231 self.schema(),
232 out_stream,
233 )))
234 }
235
236 fn metrics(&self) -> Option<MetricsSet> {
237 Some(self.metrics.clone_inner())
238 }
239
240 fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
241 self.input.partition_statistics(partition)
242 }
243
244 fn supports_limit_pushdown(&self) -> bool {
245 self.input.supports_limit_pushdown()
246 }
247
248 fn cardinality_effect(&self) -> CardinalityEffect {
249 CardinalityEffect::Equal
250 }
251
252 fn try_swapping_with_projection(
253 &self,
254 projection: &ProjectionExec,
255 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
256 match self.input.try_swapping_with_projection(projection)? {
257 Some(new_input) => Ok(Some(
258 Arc::new(self.clone()).with_new_children(vec![new_input])?,
259 )),
260 None => Ok(None),
261 }
262 }
263
264 fn gather_filters_for_pushdown(
265 &self,
266 _phase: FilterPushdownPhase,
267 parent_filters: Vec<Arc<dyn PhysicalExpr>>,
268 _config: &ConfigOptions,
269 ) -> Result<FilterDescription> {
270 FilterDescription::from_children(parent_filters, &self.children())
271 }
272
273 fn handle_child_pushdown_result(
274 &self,
275 _phase: FilterPushdownPhase,
276 child_pushdown_result: ChildPushdownResult,
277 _config: &ConfigOptions,
278 ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
279 Ok(FilterPushdownPropagation::if_all(child_pushdown_result))
280 }
281
282 fn try_pushdown_sort(
283 &self,
284 order: &[PhysicalSortExpr],
285 ) -> Result<SortOrderPushdownResult<Arc<dyn ExecutionPlan>>> {
286 self.input.try_pushdown_sort(order)?.try_map(|new_input| {
289 Ok(Arc::new(Self::new(new_input, self.capacity)) as Arc<dyn ExecutionPlan>)
290 })
291 }
292}
293
/// A message that can report its in-memory footprint.
///
/// `MemoryBufferedStream` uses [`SizedMessage::size`] both to grow and shrink
/// its memory reservation and to decide how many buffer-capacity permits a
/// buffered message holds.
pub trait SizedMessage {
    /// Footprint of this message. The value is passed directly to the memory
    /// reservation's `try_grow`/`shrink`, so it should be in bytes.
    fn size(&self) -> usize;
}
298
299impl SizedMessage for RecordBatch {
300 fn size(&self) -> usize {
301 self.get_array_memory_size()
302 }
303}
304
pin_project! {
    /// A stream that eagerly drains an input stream on a background task and
    /// buffers the produced messages until the consumer polls for them.
    ///
    /// Buffering is bounded by a permit budget (`capacity`, in units of
    /// [`SizedMessage::size`]) and by a memory reservation that must grow for
    /// every buffered message. Errors from the input, or from the reservation,
    /// are yielded as the final item.
    pub struct MemoryBufferedStream<T: SizedMessage> {
        // Producer task; held here so its lifetime is tied to this stream
        // (presumably `SpawnedTask` aborts the task when dropped — confirm).
        task: SpawnedTask<()>,
        // Buffered messages, each paired with the permits it occupies until yielded.
        batch_rx: UnboundedReceiver<Result<(T, OwnedSemaphorePermit)>>,
        // Shared with the producer task: grown before a message is buffered and
        // shrunk in `poll_next` when the message is handed out.
        memory_reservation: Arc<MemoryReservation>,
    }
}
316
317impl<T: Send + SizedMessage + 'static> MemoryBufferedStream<T> {
318 pub fn new(
322 mut input: impl Stream<Item = Result<T>> + Unpin + Send + 'static,
323 capacity: usize,
324 memory_reservation: MemoryReservation,
325 ) -> Self {
326 let semaphore = Arc::new(Semaphore::new(capacity));
327 let (batch_tx, batch_rx) = tokio::sync::mpsc::unbounded_channel();
328
329 let memory_reservation = Arc::new(memory_reservation);
330 let memory_reservation_clone = Arc::clone(&memory_reservation);
331 let task = SpawnedTask::spawn(async move {
332 loop {
333 let item_or_err = tokio::select! {
338 biased;
339 _ = batch_tx.closed() => break,
340 item_or_err = input.next() => {
341 let Some(item_or_err) = item_or_err else {
342 break; };
344 item_or_err
345 }
346 };
347
348 let item = match item_or_err {
349 Ok(batch) => batch,
350 Err(err) => {
351 let _ = batch_tx.send(Err(err)); break;
353 }
354 };
355
356 let size = item.size();
357 if let Err(err) = memory_reservation.try_grow(size) {
358 let _ = batch_tx.send(Err(err)); break;
360 }
361
362 let capped_size = size.min(capacity) as u32;
366
367 let semaphore = Arc::clone(&semaphore);
368 let Ok(permit) = semaphore.acquire_many_owned(capped_size).await else {
369 let _ = batch_tx.send(internal_err!("Closed semaphore in MemoryBufferedStream. This is a bug in DataFusion, please report it!"));
370 break;
371 };
372
373 if batch_tx.send(Ok((item, permit))).is_err() {
374 break; };
376 }
377 });
378
379 Self {
380 task,
381 batch_rx,
382 memory_reservation: memory_reservation_clone,
383 }
384 }
385
386 pub fn messages_queued(&self) -> usize {
388 self.batch_rx.len()
389 }
390}
391
392impl<T: SizedMessage> Stream for MemoryBufferedStream<T> {
393 type Item = Result<T>;
394
395 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
396 let self_project = self.project();
397 match self_project.batch_rx.poll_recv(cx) {
398 Poll::Ready(Some(Ok((item, _semaphore_permit)))) => {
399 self_project.memory_reservation.shrink(item.size());
400 Poll::Ready(Some(Ok(item)))
401 }
402 Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
403 Poll::Ready(None) => Poll::Ready(None),
404 Poll::Pending => Poll::Pending,
405 }
406 }
407
408 fn size_hint(&self) -> (usize, Option<usize>) {
409 if self.batch_rx.is_closed() {
410 let len = self.batch_rx.len();
411 (len, Some(len))
412 } else {
413 (self.batch_rx.len(), None)
414 }
415 }
416}
417
/// Tests for `MemoryBufferedStream`.
///
/// Messages are plain `usize` values whose `size()` is the value itself (see
/// the `SizedMessage for usize` impl at the bottom), so capacities and memory
/// limits can be expressed directly in the inputs.
///
/// NOTE(review): these tests are timing-based. They rely on a 1ms sleep
/// (`wait_for_buffering`) being enough for the spawned producer task to run,
/// and on a 1ms timeout when pulling. They may be flaky on a heavily loaded
/// machine.
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_common::{DataFusionError, assert_contains};
    use datafusion_execution::memory_pool::{
        GreedyMemoryPool, MemoryPool, UnboundedMemoryPool,
    };
    use std::error::Error;
    use std::fmt::Debug;
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::time::timeout;

    #[tokio::test]
    async fn buffers_only_some_messages() -> Result<(), Box<dyn Error>> {
        let input = futures::stream::iter([1, 2, 3, 4]).map(Ok);
        let (_, res) = memory_pool_and_reservation();

        // Capacity 4: messages 1 and 2 hold 3 permits; message 3 needs 3 more
        // but only 1 is free, so the producer blocks with two messages queued.
        let buffered = MemoryBufferedStream::new(input, 4, res);
        wait_for_buffering().await;
        assert_eq!(buffered.messages_queued(), 2);
        Ok(())
    }

    #[tokio::test]
    async fn yields_all_messages() -> Result<(), Box<dyn Error>> {
        let input = futures::stream::iter([1, 2, 3, 4]).map(Ok);
        let (_, res) = memory_pool_and_reservation();

        // Total size 10 fits in capacity 10, so everything is buffered up front.
        let mut buffered = MemoryBufferedStream::new(input, 10, res);
        wait_for_buffering().await;
        assert_eq!(buffered.messages_queued(), 4);

        pull_ok_msg(&mut buffered).await?;
        pull_ok_msg(&mut buffered).await?;
        pull_ok_msg(&mut buffered).await?;
        pull_ok_msg(&mut buffered).await?;
        finished(&mut buffered).await?;
        Ok(())
    }

    #[tokio::test]
    async fn yields_first_msg_even_if_big() -> Result<(), Box<dyn Error>> {
        let input = futures::stream::iter([25, 1, 2, 3]).map(Ok);
        let (_, res) = memory_pool_and_reservation();

        // Message 25 exceeds capacity 10; its permits are capped at 10, so it
        // is still buffered, but it takes the whole budget on its own.
        let mut buffered = MemoryBufferedStream::new(input, 10, res);
        wait_for_buffering().await;
        assert_eq!(buffered.messages_queued(), 1);
        pull_ok_msg(&mut buffered).await?;
        Ok(())
    }

    #[tokio::test]
    async fn memory_pool_kills_stream() -> Result<(), Box<dyn Error>> {
        let input = futures::stream::iter([1, 2, 3, 4]).map(Ok);
        let (_, res) = bounded_memory_pool_and_reservation(7);

        // 1 + 2 + 3 = 6 bytes fit in the 7-byte pool; growing by 4 more fails,
        // and that error is forwarded as the fourth item.
        let mut buffered = MemoryBufferedStream::new(input, 10, res);
        wait_for_buffering().await;

        pull_ok_msg(&mut buffered).await?;
        pull_ok_msg(&mut buffered).await?;
        pull_ok_msg(&mut buffered).await?;
        let msg = pull_err_msg(&mut buffered).await?;

        assert_contains!(msg.to_string(), "Failed to allocate additional 4.0 B");
        Ok(())
    }

    #[tokio::test]
    async fn memory_pool_does_not_kill_stream() -> Result<(), Box<dyn Error>> {
        let input = futures::stream::iter([1, 2, 3, 4]).map(Ok);
        let (_, res) = bounded_memory_pool_and_reservation(7);

        // Capacity 3 throttles the producer so the reservation peaks at 7
        // bytes (message 3 queued plus message 4 waiting for permits), which
        // the 7-byte pool still admits.
        let mut buffered = MemoryBufferedStream::new(input, 3, res);
        wait_for_buffering().await;
        pull_ok_msg(&mut buffered).await?;

        wait_for_buffering().await;
        pull_ok_msg(&mut buffered).await?;

        wait_for_buffering().await;
        pull_ok_msg(&mut buffered).await?;

        wait_for_buffering().await;
        pull_ok_msg(&mut buffered).await?;

        wait_for_buffering().await;
        finished(&mut buffered).await?;
        Ok(())
    }

    #[tokio::test]
    async fn messages_pass_even_if_all_exceed_limit() -> Result<(), Box<dyn Error>> {
        let input = futures::stream::iter([3, 3, 3, 3]).map(Ok);
        let (_, res) = memory_pool_and_reservation();

        // Every message (3) exceeds capacity 2, so each one takes all permits:
        // exactly one message is queued at a time, and none get stuck.
        let mut buffered = MemoryBufferedStream::new(input, 2, res);
        wait_for_buffering().await;
        assert_eq!(buffered.messages_queued(), 1);
        pull_ok_msg(&mut buffered).await?;

        wait_for_buffering().await;
        assert_eq!(buffered.messages_queued(), 1);
        pull_ok_msg(&mut buffered).await?;

        wait_for_buffering().await;
        assert_eq!(buffered.messages_queued(), 1);
        pull_ok_msg(&mut buffered).await?;

        wait_for_buffering().await;
        assert_eq!(buffered.messages_queued(), 1);
        pull_ok_msg(&mut buffered).await?;

        wait_for_buffering().await;
        finished(&mut buffered).await?;
        Ok(())
    }

    #[tokio::test]
    async fn errors_get_propagated() -> Result<(), Box<dyn Error>> {
        // The input fails on its third element.
        let input = futures::stream::iter([1, 2, 3, 4]).map(|v| {
            if v == 3 {
                return internal_err!("Error on 3");
            }
            Ok(v)
        });
        let (_, res) = memory_pool_and_reservation();

        let mut buffered = MemoryBufferedStream::new(input, 10, res);
        wait_for_buffering().await;

        pull_ok_msg(&mut buffered).await?;
        pull_ok_msg(&mut buffered).await?;
        pull_err_msg(&mut buffered).await?;

        Ok(())
    }

    #[tokio::test]
    async fn memory_gets_released_if_stream_drops() -> Result<(), Box<dyn Error>> {
        let input = futures::stream::iter([1, 2, 3, 4]).map(Ok);
        let (pool, res) = memory_pool_and_reservation();

        // All four messages (1 + 2 + 3 + 4 = 10 bytes) are reserved while buffered.
        let mut buffered = MemoryBufferedStream::new(input, 10, res);
        wait_for_buffering().await;
        assert_eq!(buffered.messages_queued(), 4);
        assert_eq!(pool.reserved(), 10);

        // Each yielded message shrinks the reservation by its own size.
        pull_ok_msg(&mut buffered).await?;
        assert_eq!(buffered.messages_queued(), 3);
        assert_eq!(pool.reserved(), 9);

        pull_ok_msg(&mut buffered).await?;
        assert_eq!(buffered.messages_queued(), 2);
        assert_eq!(pool.reserved(), 7);

        // Dropping the stream must give back what the still-queued messages held.
        drop(buffered);
        assert_eq!(pool.reserved(), 0);
        Ok(())
    }

    /// Returns an unbounded pool and a reservation registered with it.
    fn memory_pool_and_reservation() -> (Arc<dyn MemoryPool>, MemoryReservation) {
        let pool = Arc::new(UnboundedMemoryPool::default()) as _;
        let reservation = MemoryConsumer::new("test").register(&pool);
        (pool, reservation)
    }

    /// Returns a `GreedyMemoryPool` limited to `size` bytes and a reservation
    /// registered with it.
    fn bounded_memory_pool_and_reservation(
        size: usize,
    ) -> (Arc<dyn MemoryPool>, MemoryReservation) {
        let pool = Arc::new(GreedyMemoryPool::new(size)) as _;
        let reservation = MemoryConsumer::new("test").register(&pool);
        (pool, reservation)
    }

    /// Sleeps briefly so the spawned producer task gets a chance to run.
    /// Assumes 1ms is enough for it to fill the buffer.
    async fn wait_for_buffering() {
        tokio::time::sleep(Duration::from_millis(1)).await;
    }

    /// Pulls the next item. Fails if it does not arrive within 1ms, if the
    /// stream has ended, or if the item is an error.
    async fn pull_ok_msg<T: SizedMessage>(
        buffered: &mut MemoryBufferedStream<T>,
    ) -> Result<T, Box<dyn Error>> {
        Ok(timeout(Duration::from_millis(1), buffered.next())
            .await?
            .unwrap_or_else(|| internal_err!("Stream should not have finished"))?)
    }

    /// Pulls the next item, expecting it to be an error, and returns that error.
    /// Fails on timeout, end of stream, or a successful item.
    async fn pull_err_msg<T: SizedMessage + Debug>(
        buffered: &mut MemoryBufferedStream<T>,
    ) -> Result<DataFusionError, Box<dyn Error>> {
        Ok(timeout(Duration::from_millis(1), buffered.next())
            .await?
            .map(|v| match v {
                Ok(v) => internal_err!(
                    "Stream should not have failed, but succeeded with {v:?}"
                ),
                Err(err) => Ok(err),
            })
            .unwrap_or_else(|| internal_err!("Stream should not have finished"))?)
    }

    /// Asserts that the stream has ended (yields `None` within 1ms).
    async fn finished<T: SizedMessage>(
        buffered: &mut MemoryBufferedStream<T>,
    ) -> Result<(), Box<dyn Error>> {
        match timeout(Duration::from_millis(1), buffered.next())
            .await?
            .is_none()
        {
            true => Ok(()),
            false => internal_err!("Stream should have finished")?,
        }
    }

    // Test messages are sized by their own value.
    impl SizedMessage for usize {
        fn size(&self) -> usize {
            *self
        }
    }
}