Skip to main content

datafusion_physical_plan/
buffer.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! [`BufferExec`] decouples the production and consumption of messages by buffering the input in
//! the background up to a certain capacity.
20
21use crate::execution_plan::{CardinalityEffect, SchedulingType};
22use crate::filter_pushdown::{
23    ChildPushdownResult, FilterDescription, FilterPushdownPhase,
24    FilterPushdownPropagation,
25};
26use crate::projection::ProjectionExec;
27use crate::stream::RecordBatchStreamAdapter;
28use crate::{
29    DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SortOrderPushdownResult,
30    check_if_same_properties,
31};
32use arrow::array::RecordBatch;
33use datafusion_common::config::ConfigOptions;
34use datafusion_common::{Result, Statistics, internal_err, plan_err};
35use datafusion_common_runtime::SpawnedTask;
36use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
37use datafusion_execution::{SendableRecordBatchStream, TaskContext};
38use datafusion_physical_expr_common::metrics::{
39    ExecutionPlanMetricsSet, MetricBuilder, MetricCategory, MetricsSet,
40};
41use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
42use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
43use futures::{Stream, StreamExt, TryStreamExt};
44use pin_project_lite::pin_project;
45use std::fmt;
46use std::pin::Pin;
47use std::sync::Arc;
48use std::sync::atomic::{AtomicUsize, Ordering};
49use std::task::{Context, Poll};
50use tokio::sync::mpsc::UnboundedReceiver;
51use tokio::sync::{OwnedSemaphorePermit, Semaphore};
52
53/// WARNING: EXPERIMENTAL
54///
55/// Decouples production and consumption of record batches with an internal queue per partition,
56/// eagerly filling up the capacity of the queues even before any message is requested.
57///
58/// ```text
59///             ┌───────────────────────────┐
60///             │        BufferExec         │
61///             │                           │
62///             │┌────── Partition 0 ──────┐│
63///             ││            ┌────┐ ┌────┐││       ┌────┐
64/// ──background poll────────▶│    │ │    ├┼┼───────▶    │
65///             ││            └────┘ └────┘││       └────┘
66///             │└─────────────────────────┘│
67///             │┌────── Partition 1 ──────┐│
68///             ││     ┌────┐ ┌────┐ ┌────┐││       ┌────┐
69/// ──background poll─▶│    │ │    │ │    ├┼┼───────▶    │
70///             ││     └────┘ └────┘ └────┘││       └────┘
71///             │└─────────────────────────┘│
72///             │                           │
73///             │           ...             │
74///             │                           │
75///             │┌────── Partition N ──────┐│
76///             ││                   ┌────┐││       ┌────┐
77/// ──background poll───────────────▶│    ├┼┼───────▶    │
78///             ││                   └────┘││       └────┘
79///             │└─────────────────────────┘│
80///             └───────────────────────────┘
81/// ```
82///
83/// The capacity is provided in bytes, and for each buffered record batch it will take into account
84/// the size reported by [RecordBatch::get_array_memory_size].
85///
86/// If a single record batch exceeds the maximum capacity set in the `capacity` argument, it's still
87/// allowed to pass in order to not deadlock the buffer.
88///
89/// This is useful for operators that conditionally start polling one of their children only after
90/// other child has finished, allowing to perform some early work and accumulating batches in
91/// memory so that they can be served immediately when requested.
92#[derive(Debug, Clone)]
93pub struct BufferExec {
94    input: Arc<dyn ExecutionPlan>,
95    properties: Arc<PlanProperties>,
96    capacity: usize,
97    metrics: ExecutionPlanMetricsSet,
98}
99
100impl BufferExec {
101    /// Builds a new [BufferExec] with the provided capacity in bytes.
102    pub fn new(input: Arc<dyn ExecutionPlan>, capacity: usize) -> Self {
103        let properties = PlanProperties::clone(input.properties())
104            .with_scheduling_type(SchedulingType::Cooperative);
105
106        Self {
107            input,
108            properties: Arc::new(properties),
109            capacity,
110            metrics: ExecutionPlanMetricsSet::new(),
111        }
112    }
113
114    /// Returns the input [ExecutionPlan] of this [BufferExec].
115    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
116        &self.input
117    }
118
119    /// Returns the per-partition capacity in bytes for this [BufferExec].
120    pub fn capacity(&self) -> usize {
121        self.capacity
122    }
123
124    fn with_new_children_and_same_properties(
125        &self,
126        mut children: Vec<Arc<dyn ExecutionPlan>>,
127    ) -> Self {
128        Self {
129            input: children.swap_remove(0),
130            metrics: ExecutionPlanMetricsSet::new(),
131            ..Self::clone(self)
132        }
133    }
134}
135
136impl DisplayAs for BufferExec {
137    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
138        match t {
139            DisplayFormatType::Default | DisplayFormatType::Verbose => {
140                write!(f, "BufferExec: capacity={}", self.capacity)
141            }
142            DisplayFormatType::TreeRender => {
143                writeln!(f, "target_batch_size={}", self.capacity)
144            }
145        }
146    }
147}
148
149impl ExecutionPlan for BufferExec {
150    fn name(&self) -> &str {
151        "BufferExec"
152    }
153
154    fn properties(&self) -> &Arc<PlanProperties> {
155        &self.properties
156    }
157
158    fn maintains_input_order(&self) -> Vec<bool> {
159        vec![true]
160    }
161
162    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
163        vec![false]
164    }
165
166    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
167        vec![&self.input]
168    }
169
170    fn with_new_children(
171        self: Arc<Self>,
172        mut children: Vec<Arc<dyn ExecutionPlan>>,
173    ) -> Result<Arc<dyn ExecutionPlan>> {
174        check_if_same_properties!(self, children);
175        if children.len() != 1 {
176            return plan_err!("BufferExec can only have one child");
177        }
178        Ok(Arc::new(Self::new(children.swap_remove(0), self.capacity)))
179    }
180
181    fn execute(
182        &self,
183        partition: usize,
184        context: Arc<TaskContext>,
185    ) -> Result<SendableRecordBatchStream> {
186        let mem_reservation = MemoryConsumer::new(format!("BufferExec[{partition}]"))
187            .register(context.memory_pool());
188        let in_stream = self.input.execute(partition, context)?;
189
190        // Set up the metrics for the stream.
191        let curr_mem_in = Arc::new(AtomicUsize::new(0));
192        let curr_mem_out = Arc::clone(&curr_mem_in);
193        let mut max_mem_in = 0;
194        let max_mem = MetricBuilder::new(&self.metrics)
195            .with_category(MetricCategory::Bytes)
196            .gauge("max_mem_used", partition);
197
198        let curr_queued_in = Arc::new(AtomicUsize::new(0));
199        let curr_queued_out = Arc::clone(&curr_queued_in);
200        let mut max_queued_in = 0;
201        let max_queued = MetricBuilder::new(&self.metrics)
202            .with_category(MetricCategory::Rows)
203            .gauge("max_queued", partition);
204
205        // Capture metrics when an element is queued on the stream.
206        let in_stream = in_stream.inspect_ok(move |v| {
207            let size = v.get_array_memory_size();
208            let curr_size = curr_mem_in.fetch_add(size, Ordering::Relaxed) + size;
209            if curr_size > max_mem_in {
210                max_mem_in = curr_size;
211                max_mem.set(max_mem_in);
212            }
213
214            let curr_queued = curr_queued_in.fetch_add(1, Ordering::Relaxed) + 1;
215            if curr_queued > max_queued_in {
216                max_queued_in = curr_queued;
217                max_queued.set(max_queued_in);
218            }
219        });
220        // Buffer the input.
221        let out_stream =
222            MemoryBufferedStream::new(in_stream, self.capacity, mem_reservation);
223        // Update in the metrics that when an element gets out, some memory gets freed.
224        let out_stream = out_stream.inspect_ok(move |v| {
225            curr_mem_out.fetch_sub(v.get_array_memory_size(), Ordering::Relaxed);
226            curr_queued_out.fetch_sub(1, Ordering::Relaxed);
227        });
228
229        Ok(Box::pin(RecordBatchStreamAdapter::new(
230            self.schema(),
231            out_stream,
232        )))
233    }
234
235    fn metrics(&self) -> Option<MetricsSet> {
236        Some(self.metrics.clone_inner())
237    }
238
239    fn partition_statistics(&self, partition: Option<usize>) -> Result<Arc<Statistics>> {
240        self.input.partition_statistics(partition)
241    }
242
243    fn supports_limit_pushdown(&self) -> bool {
244        self.input.supports_limit_pushdown()
245    }
246
247    fn cardinality_effect(&self) -> CardinalityEffect {
248        CardinalityEffect::Equal
249    }
250
251    fn try_swapping_with_projection(
252        &self,
253        projection: &ProjectionExec,
254    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
255        match self.input.try_swapping_with_projection(projection)? {
256            Some(new_input) => Ok(Some(
257                Arc::new(self.clone()).with_new_children(vec![new_input])?,
258            )),
259            None => Ok(None),
260        }
261    }
262
263    fn gather_filters_for_pushdown(
264        &self,
265        _phase: FilterPushdownPhase,
266        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
267        _config: &ConfigOptions,
268    ) -> Result<FilterDescription> {
269        FilterDescription::from_children(parent_filters, &self.children())
270    }
271
272    fn handle_child_pushdown_result(
273        &self,
274        _phase: FilterPushdownPhase,
275        child_pushdown_result: ChildPushdownResult,
276        _config: &ConfigOptions,
277    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
278        Ok(FilterPushdownPropagation::if_all(child_pushdown_result))
279    }
280
281    fn try_pushdown_sort(
282        &self,
283        order: &[PhysicalSortExpr],
284    ) -> Result<SortOrderPushdownResult<Arc<dyn ExecutionPlan>>> {
285        // CoalesceBatchesExec is transparent for sort ordering - it preserves order
286        // Delegate to the child and wrap with a new CoalesceBatchesExec
287        self.input.try_pushdown_sort(order)?.try_map(|new_input| {
288            Ok(Arc::new(Self::new(new_input, self.capacity)) as Arc<dyn ExecutionPlan>)
289        })
290    }
291}
292
293/// Represents anything that occupies a capacity in a [MemoryBufferedStream].
294pub trait SizedMessage {
295    fn size(&self) -> usize;
296}
297
298impl SizedMessage for RecordBatch {
299    fn size(&self) -> usize {
300        self.get_array_memory_size()
301    }
302}
303
304pin_project! {
305/// Decouples production and consumption of messages in a stream with an internal queue, eagerly
306/// filling it up to the specified maximum capacity even before any message is requested.
307///
308/// Allows each message to have a different size, which is taken into account for determining if
309/// the queue is full or not.
310pub struct MemoryBufferedStream<T: SizedMessage> {
311    task: SpawnedTask<()>,
312    batch_rx: UnboundedReceiver<Result<(T, OwnedSemaphorePermit)>>,
313    memory_reservation: Arc<MemoryReservation>,
314}}
315
316impl<T: Send + SizedMessage + 'static> MemoryBufferedStream<T> {
317    /// Builds a new [MemoryBufferedStream] with the provided capacity and event handler.
318    ///
319    /// This immediately spawns a Tokio task that will start consumption of the input stream.
320    pub fn new(
321        mut input: impl Stream<Item = Result<T>> + Unpin + Send + 'static,
322        capacity: usize,
323        memory_reservation: MemoryReservation,
324    ) -> Self {
325        let semaphore = Arc::new(Semaphore::new(capacity));
326        let (batch_tx, batch_rx) = tokio::sync::mpsc::unbounded_channel();
327
328        let memory_reservation = Arc::new(memory_reservation);
329        let memory_reservation_clone = Arc::clone(&memory_reservation);
330        let task = SpawnedTask::spawn(async move {
331            loop {
332                // Select on both the input stream and the channel being closed.
333                // By down this, we abort polling the input as soon as the consumer channel is
334                // closed. Otherwise, we would need to wait for a full new message to be available
335                // in order to consider aborting the stream
336                let item_or_err = tokio::select! {
337                    biased;
338                    _ = batch_tx.closed() => break,
339                    item_or_err = input.next() => {
340                        let Some(item_or_err) = item_or_err else {
341                            break; // stream finished
342                        };
343                        item_or_err
344                    }
345                };
346
347                let item = match item_or_err {
348                    Ok(batch) => batch,
349                    Err(err) => {
350                        let _ = batch_tx.send(Err(err)); // If there's an error it means the channel was closed, which is fine.
351                        break;
352                    }
353                };
354
355                let size = item.size();
356                if let Err(err) = memory_reservation.try_grow(size) {
357                    let _ = batch_tx.send(Err(err)); // If there's an error it means the channel was closed, which is fine.
358                    break;
359                }
360
361                // We need to cap the minimum between amount of permits and the actual size of the
362                // message. If at any point we try to acquire more permits than the capacity of the
363                // semaphore, the stream will deadlock.
364                let capped_size = size.min(capacity) as u32;
365
366                let semaphore = Arc::clone(&semaphore);
367                let Ok(permit) = semaphore.acquire_many_owned(capped_size).await else {
368                    let _ = batch_tx.send(internal_err!("Closed semaphore in MemoryBufferedStream. This is a bug in DataFusion, please report it!"));
369                    break;
370                };
371
372                if batch_tx.send(Ok((item, permit))).is_err() {
373                    break; // stream was closed
374                };
375            }
376        });
377
378        Self {
379            task,
380            batch_rx,
381            memory_reservation: memory_reservation_clone,
382        }
383    }
384
385    /// Returns the number of queued messages.
386    pub fn messages_queued(&self) -> usize {
387        self.batch_rx.len()
388    }
389}
390
391impl<T: SizedMessage> Stream for MemoryBufferedStream<T> {
392    type Item = Result<T>;
393
394    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
395        let self_project = self.project();
396        match self_project.batch_rx.poll_recv(cx) {
397            Poll::Ready(Some(Ok((item, _semaphore_permit)))) => {
398                self_project.memory_reservation.shrink(item.size());
399                Poll::Ready(Some(Ok(item)))
400            }
401            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
402            Poll::Ready(None) => Poll::Ready(None),
403            Poll::Pending => Poll::Pending,
404        }
405    }
406
407    fn size_hint(&self) -> (usize, Option<usize>) {
408        if self.batch_rx.is_closed() {
409            let len = self.batch_rx.len();
410            (len, Some(len))
411        } else {
412            (self.batch_rx.len(), None)
413        }
414    }
415}
416
417#[cfg(test)]
418mod tests {
419    use super::*;
420    use datafusion_common::{DataFusionError, assert_contains};
421    use datafusion_execution::memory_pool::{
422        GreedyMemoryPool, MemoryPool, UnboundedMemoryPool,
423    };
424    use std::error::Error;
425    use std::fmt::Debug;
426    use std::time::Duration;
427    use tokio::time::timeout;
428
429    #[tokio::test]
430    async fn buffers_only_some_messages() -> Result<(), Box<dyn Error>> {
431        let input = futures::stream::iter([1, 2, 3, 4]).map(Ok);
432        let (_, res) = memory_pool_and_reservation();
433
434        let buffered = MemoryBufferedStream::new(input, 4, res);
435        wait_for_buffering().await;
436        assert_eq!(buffered.messages_queued(), 2);
437        Ok(())
438    }
439
440    #[tokio::test]
441    async fn yields_all_messages() -> Result<(), Box<dyn Error>> {
442        let input = futures::stream::iter([1, 2, 3, 4]).map(Ok);
443        let (_, res) = memory_pool_and_reservation();
444
445        let mut buffered = MemoryBufferedStream::new(input, 10, res);
446        wait_for_buffering().await;
447        assert_eq!(buffered.messages_queued(), 4);
448
449        pull_ok_msg(&mut buffered).await?;
450        pull_ok_msg(&mut buffered).await?;
451        pull_ok_msg(&mut buffered).await?;
452        pull_ok_msg(&mut buffered).await?;
453        finished(&mut buffered).await?;
454        Ok(())
455    }
456
457    #[tokio::test]
458    async fn yields_first_msg_even_if_big() -> Result<(), Box<dyn Error>> {
459        let input = futures::stream::iter([25, 1, 2, 3]).map(Ok);
460        let (_, res) = memory_pool_and_reservation();
461
462        let mut buffered = MemoryBufferedStream::new(input, 10, res);
463        wait_for_buffering().await;
464        assert_eq!(buffered.messages_queued(), 1);
465        pull_ok_msg(&mut buffered).await?;
466        Ok(())
467    }
468
469    #[tokio::test]
470    async fn memory_pool_kills_stream() -> Result<(), Box<dyn Error>> {
471        let input = futures::stream::iter([1, 2, 3, 4]).map(Ok);
472        let (_, res) = bounded_memory_pool_and_reservation(7);
473
474        let mut buffered = MemoryBufferedStream::new(input, 10, res);
475        wait_for_buffering().await;
476
477        pull_ok_msg(&mut buffered).await?;
478        pull_ok_msg(&mut buffered).await?;
479        pull_ok_msg(&mut buffered).await?;
480        let msg = pull_err_msg(&mut buffered).await?;
481
482        assert_contains!(msg.to_string(), "Failed to allocate additional 4.0 B");
483        Ok(())
484    }
485
486    #[tokio::test]
487    async fn memory_pool_does_not_kill_stream() -> Result<(), Box<dyn Error>> {
488        let input = futures::stream::iter([1, 2, 3, 4]).map(Ok);
489        let (_, res) = bounded_memory_pool_and_reservation(7);
490
491        let mut buffered = MemoryBufferedStream::new(input, 3, res);
492        wait_for_buffering().await;
493        pull_ok_msg(&mut buffered).await?;
494
495        wait_for_buffering().await;
496        pull_ok_msg(&mut buffered).await?;
497
498        wait_for_buffering().await;
499        pull_ok_msg(&mut buffered).await?;
500
501        wait_for_buffering().await;
502        pull_ok_msg(&mut buffered).await?;
503
504        wait_for_buffering().await;
505        finished(&mut buffered).await?;
506        Ok(())
507    }
508
509    #[tokio::test]
510    async fn messages_pass_even_if_all_exceed_limit() -> Result<(), Box<dyn Error>> {
511        let input = futures::stream::iter([3, 3, 3, 3]).map(Ok);
512        let (_, res) = memory_pool_and_reservation();
513
514        let mut buffered = MemoryBufferedStream::new(input, 2, res);
515        wait_for_buffering().await;
516        assert_eq!(buffered.messages_queued(), 1);
517        pull_ok_msg(&mut buffered).await?;
518
519        wait_for_buffering().await;
520        assert_eq!(buffered.messages_queued(), 1);
521        pull_ok_msg(&mut buffered).await?;
522
523        wait_for_buffering().await;
524        assert_eq!(buffered.messages_queued(), 1);
525        pull_ok_msg(&mut buffered).await?;
526
527        wait_for_buffering().await;
528        assert_eq!(buffered.messages_queued(), 1);
529        pull_ok_msg(&mut buffered).await?;
530
531        wait_for_buffering().await;
532        finished(&mut buffered).await?;
533        Ok(())
534    }
535
536    #[tokio::test]
537    async fn errors_get_propagated() -> Result<(), Box<dyn Error>> {
538        let input = futures::stream::iter([1, 2, 3, 4]).map(|v| {
539            if v == 3 {
540                return internal_err!("Error on 3");
541            }
542            Ok(v)
543        });
544        let (_, res) = memory_pool_and_reservation();
545
546        let mut buffered = MemoryBufferedStream::new(input, 10, res);
547        wait_for_buffering().await;
548
549        pull_ok_msg(&mut buffered).await?;
550        pull_ok_msg(&mut buffered).await?;
551        pull_err_msg(&mut buffered).await?;
552
553        Ok(())
554    }
555
556    #[tokio::test]
557    async fn memory_gets_released_if_stream_drops() -> Result<(), Box<dyn Error>> {
558        let input = futures::stream::iter([1, 2, 3, 4]).map(Ok);
559        let (pool, res) = memory_pool_and_reservation();
560
561        let mut buffered = MemoryBufferedStream::new(input, 10, res);
562        wait_for_buffering().await;
563        assert_eq!(buffered.messages_queued(), 4);
564        assert_eq!(pool.reserved(), 10);
565
566        pull_ok_msg(&mut buffered).await?;
567        assert_eq!(buffered.messages_queued(), 3);
568        assert_eq!(pool.reserved(), 9);
569
570        pull_ok_msg(&mut buffered).await?;
571        assert_eq!(buffered.messages_queued(), 2);
572        assert_eq!(pool.reserved(), 7);
573
574        drop(buffered);
575        assert_eq!(pool.reserved(), 0);
576        Ok(())
577    }
578
579    fn memory_pool_and_reservation() -> (Arc<dyn MemoryPool>, MemoryReservation) {
580        let pool = Arc::new(UnboundedMemoryPool::default()) as _;
581        let reservation = MemoryConsumer::new("test").register(&pool);
582        (pool, reservation)
583    }
584
585    fn bounded_memory_pool_and_reservation(
586        size: usize,
587    ) -> (Arc<dyn MemoryPool>, MemoryReservation) {
588        let pool = Arc::new(GreedyMemoryPool::new(size)) as _;
589        let reservation = MemoryConsumer::new("test").register(&pool);
590        (pool, reservation)
591    }
592
593    async fn wait_for_buffering() {
594        // We do not have control over the spawned task, so the best we can do is to yield some
595        // cycles to the tokio runtime and let the task make progress on its own.
596        tokio::time::sleep(Duration::from_millis(1)).await;
597    }
598
599    async fn pull_ok_msg<T: SizedMessage>(
600        buffered: &mut MemoryBufferedStream<T>,
601    ) -> Result<T, Box<dyn Error>> {
602        Ok(timeout(Duration::from_millis(1), buffered.next())
603            .await?
604            .unwrap_or_else(|| internal_err!("Stream should not have finished"))?)
605    }
606
607    async fn pull_err_msg<T: SizedMessage + Debug>(
608        buffered: &mut MemoryBufferedStream<T>,
609    ) -> Result<DataFusionError, Box<dyn Error>> {
610        Ok(timeout(Duration::from_millis(1), buffered.next())
611            .await?
612            .map(|v| match v {
613                Ok(v) => internal_err!(
614                    "Stream should not have failed, but succeeded with {v:?}"
615                ),
616                Err(err) => Ok(err),
617            })
618            .unwrap_or_else(|| internal_err!("Stream should not have finished"))?)
619    }
620
621    async fn finished<T: SizedMessage>(
622        buffered: &mut MemoryBufferedStream<T>,
623    ) -> Result<(), Box<dyn Error>> {
624        match timeout(Duration::from_millis(1), buffered.next())
625            .await?
626            .is_none()
627        {
628            true => Ok(()),
629            false => internal_err!("Stream should have finished")?,
630        }
631    }
632
633    impl SizedMessage for usize {
634        fn size(&self) -> usize {
635            *self
636        }
637    }
638}