1use crate::execution_plan::{CardinalityEffect, SchedulingType};
22use crate::filter_pushdown::{
23 ChildPushdownResult, FilterDescription, FilterPushdownPhase,
24 FilterPushdownPropagation,
25};
26use crate::projection::ProjectionExec;
27use crate::stream::RecordBatchStreamAdapter;
28use crate::{
29 DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SortOrderPushdownResult,
30 check_if_same_properties,
31};
32use arrow::array::RecordBatch;
33use datafusion_common::config::ConfigOptions;
34use datafusion_common::{Result, Statistics, internal_err, plan_err};
35use datafusion_common_runtime::SpawnedTask;
36use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
37use datafusion_execution::{SendableRecordBatchStream, TaskContext};
38use datafusion_physical_expr_common::metrics::{
39 ExecutionPlanMetricsSet, MetricBuilder, MetricCategory, MetricsSet,
40};
41use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
42use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
43use futures::{Stream, StreamExt, TryStreamExt};
44use pin_project_lite::pin_project;
45use std::fmt;
46use std::pin::Pin;
47use std::sync::Arc;
48use std::sync::atomic::{AtomicUsize, Ordering};
49use std::task::{Context, Poll};
50use tokio::sync::mpsc::UnboundedReceiver;
51use tokio::sync::{OwnedSemaphorePermit, Semaphore};
52
53#[derive(Debug, Clone)]
93pub struct BufferExec {
94 input: Arc<dyn ExecutionPlan>,
95 properties: Arc<PlanProperties>,
96 capacity: usize,
97 metrics: ExecutionPlanMetricsSet,
98}
99
100impl BufferExec {
101 pub fn new(input: Arc<dyn ExecutionPlan>, capacity: usize) -> Self {
103 let properties = PlanProperties::clone(input.properties())
104 .with_scheduling_type(SchedulingType::Cooperative);
105
106 Self {
107 input,
108 properties: Arc::new(properties),
109 capacity,
110 metrics: ExecutionPlanMetricsSet::new(),
111 }
112 }
113
114 pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
116 &self.input
117 }
118
119 pub fn capacity(&self) -> usize {
121 self.capacity
122 }
123
124 fn with_new_children_and_same_properties(
125 &self,
126 mut children: Vec<Arc<dyn ExecutionPlan>>,
127 ) -> Self {
128 Self {
129 input: children.swap_remove(0),
130 metrics: ExecutionPlanMetricsSet::new(),
131 ..Self::clone(self)
132 }
133 }
134}
135
136impl DisplayAs for BufferExec {
137 fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
138 match t {
139 DisplayFormatType::Default | DisplayFormatType::Verbose => {
140 write!(f, "BufferExec: capacity={}", self.capacity)
141 }
142 DisplayFormatType::TreeRender => {
143 writeln!(f, "target_batch_size={}", self.capacity)
144 }
145 }
146 }
147}
148
149impl ExecutionPlan for BufferExec {
150 fn name(&self) -> &str {
151 "BufferExec"
152 }
153
154 fn properties(&self) -> &Arc<PlanProperties> {
155 &self.properties
156 }
157
158 fn maintains_input_order(&self) -> Vec<bool> {
159 vec![true]
160 }
161
162 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
163 vec![false]
164 }
165
166 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
167 vec![&self.input]
168 }
169
170 fn with_new_children(
171 self: Arc<Self>,
172 mut children: Vec<Arc<dyn ExecutionPlan>>,
173 ) -> Result<Arc<dyn ExecutionPlan>> {
174 check_if_same_properties!(self, children);
175 if children.len() != 1 {
176 return plan_err!("BufferExec can only have one child");
177 }
178 Ok(Arc::new(Self::new(children.swap_remove(0), self.capacity)))
179 }
180
181 fn execute(
182 &self,
183 partition: usize,
184 context: Arc<TaskContext>,
185 ) -> Result<SendableRecordBatchStream> {
186 let mem_reservation = MemoryConsumer::new(format!("BufferExec[{partition}]"))
187 .register(context.memory_pool());
188 let in_stream = self.input.execute(partition, context)?;
189
190 let curr_mem_in = Arc::new(AtomicUsize::new(0));
192 let curr_mem_out = Arc::clone(&curr_mem_in);
193 let mut max_mem_in = 0;
194 let max_mem = MetricBuilder::new(&self.metrics)
195 .with_category(MetricCategory::Bytes)
196 .gauge("max_mem_used", partition);
197
198 let curr_queued_in = Arc::new(AtomicUsize::new(0));
199 let curr_queued_out = Arc::clone(&curr_queued_in);
200 let mut max_queued_in = 0;
201 let max_queued = MetricBuilder::new(&self.metrics)
202 .with_category(MetricCategory::Rows)
203 .gauge("max_queued", partition);
204
205 let in_stream = in_stream.inspect_ok(move |v| {
207 let size = v.get_array_memory_size();
208 let curr_size = curr_mem_in.fetch_add(size, Ordering::Relaxed) + size;
209 if curr_size > max_mem_in {
210 max_mem_in = curr_size;
211 max_mem.set(max_mem_in);
212 }
213
214 let curr_queued = curr_queued_in.fetch_add(1, Ordering::Relaxed) + 1;
215 if curr_queued > max_queued_in {
216 max_queued_in = curr_queued;
217 max_queued.set(max_queued_in);
218 }
219 });
220 let out_stream =
222 MemoryBufferedStream::new(in_stream, self.capacity, mem_reservation);
223 let out_stream = out_stream.inspect_ok(move |v| {
225 curr_mem_out.fetch_sub(v.get_array_memory_size(), Ordering::Relaxed);
226 curr_queued_out.fetch_sub(1, Ordering::Relaxed);
227 });
228
229 Ok(Box::pin(RecordBatchStreamAdapter::new(
230 self.schema(),
231 out_stream,
232 )))
233 }
234
235 fn metrics(&self) -> Option<MetricsSet> {
236 Some(self.metrics.clone_inner())
237 }
238
239 fn partition_statistics(&self, partition: Option<usize>) -> Result<Arc<Statistics>> {
240 self.input.partition_statistics(partition)
241 }
242
243 fn supports_limit_pushdown(&self) -> bool {
244 self.input.supports_limit_pushdown()
245 }
246
247 fn cardinality_effect(&self) -> CardinalityEffect {
248 CardinalityEffect::Equal
249 }
250
251 fn try_swapping_with_projection(
252 &self,
253 projection: &ProjectionExec,
254 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
255 match self.input.try_swapping_with_projection(projection)? {
256 Some(new_input) => Ok(Some(
257 Arc::new(self.clone()).with_new_children(vec![new_input])?,
258 )),
259 None => Ok(None),
260 }
261 }
262
263 fn gather_filters_for_pushdown(
264 &self,
265 _phase: FilterPushdownPhase,
266 parent_filters: Vec<Arc<dyn PhysicalExpr>>,
267 _config: &ConfigOptions,
268 ) -> Result<FilterDescription> {
269 FilterDescription::from_children(parent_filters, &self.children())
270 }
271
272 fn handle_child_pushdown_result(
273 &self,
274 _phase: FilterPushdownPhase,
275 child_pushdown_result: ChildPushdownResult,
276 _config: &ConfigOptions,
277 ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
278 Ok(FilterPushdownPropagation::if_all(child_pushdown_result))
279 }
280
281 fn try_pushdown_sort(
282 &self,
283 order: &[PhysicalSortExpr],
284 ) -> Result<SortOrderPushdownResult<Arc<dyn ExecutionPlan>>> {
285 self.input.try_pushdown_sort(order)?.try_map(|new_input| {
288 Ok(Arc::new(Self::new(new_input, self.capacity)) as Arc<dyn ExecutionPlan>)
289 })
290 }
291}
292
/// A message that can report how much room it takes up in a [`MemoryBufferedStream`].
pub trait SizedMessage {
    /// Returns the size of this message.
    ///
    /// The value is what gets charged against both the buffer capacity and the memory
    /// reservation, so it must be expressed in the same unit as the capacity.
    fn size(&self) -> usize;
}
297
298impl SizedMessage for RecordBatch {
299 fn size(&self) -> usize {
300 self.get_array_memory_size()
301 }
302}
303
304pin_project! {
305pub struct MemoryBufferedStream<T: SizedMessage> {
311 task: SpawnedTask<()>,
312 batch_rx: UnboundedReceiver<Result<(T, OwnedSemaphorePermit)>>,
313 memory_reservation: Arc<MemoryReservation>,
314}}
315
316impl<T: Send + SizedMessage + 'static> MemoryBufferedStream<T> {
317 pub fn new(
321 mut input: impl Stream<Item = Result<T>> + Unpin + Send + 'static,
322 capacity: usize,
323 memory_reservation: MemoryReservation,
324 ) -> Self {
325 let semaphore = Arc::new(Semaphore::new(capacity));
326 let (batch_tx, batch_rx) = tokio::sync::mpsc::unbounded_channel();
327
328 let memory_reservation = Arc::new(memory_reservation);
329 let memory_reservation_clone = Arc::clone(&memory_reservation);
330 let task = SpawnedTask::spawn(async move {
331 loop {
332 let item_or_err = tokio::select! {
337 biased;
338 _ = batch_tx.closed() => break,
339 item_or_err = input.next() => {
340 let Some(item_or_err) = item_or_err else {
341 break; };
343 item_or_err
344 }
345 };
346
347 let item = match item_or_err {
348 Ok(batch) => batch,
349 Err(err) => {
350 let _ = batch_tx.send(Err(err)); break;
352 }
353 };
354
355 let size = item.size();
356 if let Err(err) = memory_reservation.try_grow(size) {
357 let _ = batch_tx.send(Err(err)); break;
359 }
360
361 let capped_size = size.min(capacity) as u32;
365
366 let semaphore = Arc::clone(&semaphore);
367 let Ok(permit) = semaphore.acquire_many_owned(capped_size).await else {
368 let _ = batch_tx.send(internal_err!("Closed semaphore in MemoryBufferedStream. This is a bug in DataFusion, please report it!"));
369 break;
370 };
371
372 if batch_tx.send(Ok((item, permit))).is_err() {
373 break; };
375 }
376 });
377
378 Self {
379 task,
380 batch_rx,
381 memory_reservation: memory_reservation_clone,
382 }
383 }
384
385 pub fn messages_queued(&self) -> usize {
387 self.batch_rx.len()
388 }
389}
390
391impl<T: SizedMessage> Stream for MemoryBufferedStream<T> {
392 type Item = Result<T>;
393
394 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
395 let self_project = self.project();
396 match self_project.batch_rx.poll_recv(cx) {
397 Poll::Ready(Some(Ok((item, _semaphore_permit)))) => {
398 self_project.memory_reservation.shrink(item.size());
399 Poll::Ready(Some(Ok(item)))
400 }
401 Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
402 Poll::Ready(None) => Poll::Ready(None),
403 Poll::Pending => Poll::Pending,
404 }
405 }
406
407 fn size_hint(&self) -> (usize, Option<usize>) {
408 if self.batch_rx.is_closed() {
409 let len = self.batch_rx.len();
410 (len, Some(len))
411 } else {
412 (self.batch_rx.len(), None)
413 }
414 }
415}
416
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_common::{DataFusionError, assert_contains};
    use datafusion_execution::memory_pool::{
        GreedyMemoryPool, MemoryPool, UnboundedMemoryPool,
    };
    use std::error::Error;
    use std::fmt::Debug;
    use std::time::Duration;
    use tokio::time::timeout;

    /// Result type shared by the tests and their helpers.
    type TestResult<T = ()> = Result<T, Box<dyn Error>>;

    /// Time granted to the background task to buffer, and to the stream to answer a pull.
    const STEP: Duration = Duration::from_millis(1);

    /// With capacity 4 the messages of size 1 and 2 fit; the one of size 3 has to wait.
    #[tokio::test]
    async fn buffers_only_some_messages() -> TestResult {
        let (_, reservation) = memory_pool_and_reservation();

        let buffered = MemoryBufferedStream::new(sized_stream([1, 2, 3, 4]), 4, reservation);
        wait_for_buffering().await;

        assert_eq!(buffered.messages_queued(), 2);
        Ok(())
    }

    /// All four messages (total size 10) fit in capacity 10 and come out before the end.
    #[tokio::test]
    async fn yields_all_messages() -> TestResult {
        let (_, reservation) = memory_pool_and_reservation();

        let mut buffered =
            MemoryBufferedStream::new(sized_stream([1, 2, 3, 4]), 10, reservation);
        wait_for_buffering().await;
        assert_eq!(buffered.messages_queued(), 4);

        for _ in 0..4 {
            pull_ok_msg(&mut buffered).await?;
        }
        finished(&mut buffered).await?;
        Ok(())
    }

    /// A first message bigger than the capacity is still queued, and it is the only one.
    #[tokio::test]
    async fn yields_first_msg_even_if_big() -> TestResult {
        let (_, reservation) = memory_pool_and_reservation();

        let mut buffered =
            MemoryBufferedStream::new(sized_stream([25, 1, 2, 3]), 10, reservation);
        wait_for_buffering().await;

        assert_eq!(buffered.messages_queued(), 1);
        pull_ok_msg(&mut buffered).await?;
        Ok(())
    }

    /// The pool holds 7: sizes 1, 2 and 3 are reserved, the reservation for 4 fails.
    #[tokio::test]
    async fn memory_pool_kills_stream() -> TestResult {
        let (_, reservation) = bounded_memory_pool_and_reservation(7);

        let mut buffered =
            MemoryBufferedStream::new(sized_stream([1, 2, 3, 4]), 10, reservation);
        wait_for_buffering().await;

        for _ in 0..3 {
            pull_ok_msg(&mut buffered).await?;
        }
        let failure = pull_err_msg(&mut buffered).await?;

        assert_contains!(failure.to_string(), "Failed to allocate additional 4.0 B");
        Ok(())
    }

    /// Same pool of 7, but capacity 3 throttles buffering enough for consumption to free
    /// memory before the pool runs out.
    #[tokio::test]
    async fn memory_pool_does_not_kill_stream() -> TestResult {
        let (_, reservation) = bounded_memory_pool_and_reservation(7);

        let mut buffered =
            MemoryBufferedStream::new(sized_stream([1, 2, 3, 4]), 3, reservation);
        for _ in 0..4 {
            wait_for_buffering().await;
            pull_ok_msg(&mut buffered).await?;
        }

        wait_for_buffering().await;
        finished(&mut buffered).await?;
        Ok(())
    }

    /// Every message exceeds capacity 2, so they pass strictly one at a time.
    #[tokio::test]
    async fn messages_pass_even_if_all_exceed_limit() -> TestResult {
        let (_, reservation) = memory_pool_and_reservation();

        let mut buffered =
            MemoryBufferedStream::new(sized_stream([3, 3, 3, 3]), 2, reservation);
        for _ in 0..4 {
            wait_for_buffering().await;
            assert_eq!(buffered.messages_queued(), 1);
            pull_ok_msg(&mut buffered).await?;
        }

        wait_for_buffering().await;
        finished(&mut buffered).await?;
        Ok(())
    }

    /// An input error reaches the consumer after the messages that preceded it.
    #[tokio::test]
    async fn errors_get_propagated() -> TestResult {
        let input = futures::stream::iter([1, 2, 3, 4]).map(|v| {
            if v == 3 {
                return internal_err!("Error on 3");
            }
            Ok(v)
        });
        let (_, reservation) = memory_pool_and_reservation();

        let mut buffered = MemoryBufferedStream::new(input, 10, reservation);
        wait_for_buffering().await;

        pull_ok_msg(&mut buffered).await?;
        pull_ok_msg(&mut buffered).await?;
        pull_err_msg(&mut buffered).await?;

        Ok(())
    }

    /// The reservation follows consumption and drops to zero with the stream.
    #[tokio::test]
    async fn memory_gets_released_if_stream_drops() -> TestResult {
        let (pool, reservation) = memory_pool_and_reservation();

        let mut buffered =
            MemoryBufferedStream::new(sized_stream([1, 2, 3, 4]), 10, reservation);
        wait_for_buffering().await;
        assert_eq!(buffered.messages_queued(), 4);
        assert_eq!(pool.reserved(), 10);

        // Pulling the messages of size 1 and then 2 leaves 9 and then 7 bytes reserved.
        for (still_queued, still_reserved) in [(3, 9), (2, 7)] {
            pull_ok_msg(&mut buffered).await?;
            assert_eq!(buffered.messages_queued(), still_queued);
            assert_eq!(pool.reserved(), still_reserved);
        }

        drop(buffered);
        assert_eq!(pool.reserved(), 0);
        Ok(())
    }

    /// Infallible stream of `usize` messages, each one as big as its own value.
    fn sized_stream<const N: usize>(
        sizes: [usize; N],
    ) -> impl Stream<Item = Result<usize>> + Unpin + Send + 'static {
        futures::stream::iter(sizes).map(Ok::<usize, DataFusionError>)
    }

    /// Unbounded pool plus a reservation registered in it under the name `test`.
    fn memory_pool_and_reservation() -> (Arc<dyn MemoryPool>, MemoryReservation) {
        let pool: Arc<dyn MemoryPool> = Arc::new(UnboundedMemoryPool::default());
        let reservation = MemoryConsumer::new("test").register(&pool);
        (pool, reservation)
    }

    /// Greedy pool limited to `size` plus a reservation registered in it under the name
    /// `test`.
    fn bounded_memory_pool_and_reservation(
        size: usize,
    ) -> (Arc<dyn MemoryPool>, MemoryReservation) {
        let pool: Arc<dyn MemoryPool> = Arc::new(GreedyMemoryPool::new(size));
        let reservation = MemoryConsumer::new("test").register(&pool);
        (pool, reservation)
    }

    /// Yields long enough for the background task to buffer whatever it can.
    async fn wait_for_buffering() {
        tokio::time::sleep(STEP).await;
    }

    /// Pulls the next entry and expects it to be a message.
    async fn pull_ok_msg<T: SizedMessage>(
        buffered: &mut MemoryBufferedStream<T>,
    ) -> TestResult<T> {
        let next = timeout(STEP, buffered.next()).await?;
        let entry =
            next.unwrap_or_else(|| internal_err!("Stream should not have finished"));
        Ok(entry?)
    }

    /// Pulls the next entry and expects it to be an error, which is returned.
    async fn pull_err_msg<T: SizedMessage + Debug>(
        buffered: &mut MemoryBufferedStream<T>,
    ) -> TestResult<DataFusionError> {
        let next = timeout(STEP, buffered.next()).await?;
        let outcome = match next {
            Some(Ok(v)) => internal_err!(
                "Stream should not have failed, but succeeded with {v:?}"
            ),
            Some(Err(err)) => Ok(err),
            None => internal_err!("Stream should not have finished"),
        };
        Ok(outcome?)
    }

    /// Pulls once more and expects the stream to be over.
    async fn finished<T: SizedMessage>(
        buffered: &mut MemoryBufferedStream<T>,
    ) -> TestResult {
        let next = timeout(STEP, buffered.next()).await?;
        if next.is_some() {
            return Ok(internal_err!("Stream should have finished")?);
        }
        Ok(())
    }

    /// In the tests a plain `usize` is a message as big as its own value.
    impl SizedMessage for usize {
        fn size(&self) -> usize {
            *self
        }
    }
}