Skip to main content

datafusion_physical_plan/
buffer.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! [`BufferExec`] decouples production and consumption of messages by buffering the input in the
//! background up to a certain capacity.
20
21use crate::execution_plan::{CardinalityEffect, SchedulingType};
22use crate::filter_pushdown::{
23    ChildPushdownResult, FilterDescription, FilterPushdownPhase,
24    FilterPushdownPropagation,
25};
26use crate::projection::ProjectionExec;
27use crate::stream::RecordBatchStreamAdapter;
28use crate::{
29    DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SortOrderPushdownResult,
30    check_if_same_properties,
31};
32use arrow::array::RecordBatch;
33use datafusion_common::config::ConfigOptions;
34use datafusion_common::{Result, Statistics, internal_err, plan_err};
35use datafusion_common_runtime::SpawnedTask;
36use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
37use datafusion_execution::{SendableRecordBatchStream, TaskContext};
38use datafusion_physical_expr_common::metrics::{
39    ExecutionPlanMetricsSet, MetricBuilder, MetricsSet,
40};
41use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
42use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
43use futures::{Stream, StreamExt, TryStreamExt};
44use pin_project_lite::pin_project;
45use std::any::Any;
46use std::fmt;
47use std::pin::Pin;
48use std::sync::Arc;
49use std::sync::atomic::{AtomicUsize, Ordering};
50use std::task::{Context, Poll};
51use tokio::sync::mpsc::UnboundedReceiver;
52use tokio::sync::{OwnedSemaphorePermit, Semaphore};
53
54/// WARNING: EXPERIMENTAL
55///
56/// Decouples production and consumption of record batches with an internal queue per partition,
57/// eagerly filling up the capacity of the queues even before any message is requested.
58///
59/// ```text
60///             ┌───────────────────────────┐
61///             │        BufferExec         │
62///             │                           │
63///             │┌────── Partition 0 ──────┐│
64///             ││            ┌────┐ ┌────┐││       ┌────┐
65/// ──background poll────────▶│    │ │    ├┼┼───────▶    │
66///             ││            └────┘ └────┘││       └────┘
67///             │└─────────────────────────┘│
68///             │┌────── Partition 1 ──────┐│
69///             ││     ┌────┐ ┌────┐ ┌────┐││       ┌────┐
70/// ──background poll─▶│    │ │    │ │    ├┼┼───────▶    │
71///             ││     └────┘ └────┘ └────┘││       └────┘
72///             │└─────────────────────────┘│
73///             │                           │
74///             │           ...             │
75///             │                           │
76///             │┌────── Partition N ──────┐│
77///             ││                   ┌────┐││       ┌────┐
78/// ──background poll───────────────▶│    ├┼┼───────▶    │
79///             ││                   └────┘││       └────┘
80///             │└─────────────────────────┘│
81///             └───────────────────────────┘
82/// ```
83///
84/// The capacity is provided in bytes, and for each buffered record batch it will take into account
85/// the size reported by [RecordBatch::get_array_memory_size].
86///
87/// If a single record batch exceeds the maximum capacity set in the `capacity` argument, it's still
88/// allowed to pass in order to not deadlock the buffer.
89///
/// This is useful for operators that conditionally start polling one of their children only after
/// another child has finished: the buffer can perform some early work and accumulate batches in
/// memory so that they can be served immediately when requested.
#[derive(Debug, Clone)]
pub struct BufferExec {
    /// Child plan whose output partitions are buffered.
    input: Arc<dyn ExecutionPlan>,
    /// Properties copied from `input` at construction time, with the scheduling type set to
    /// [SchedulingType::Cooperative].
    properties: Arc<PlanProperties>,
    /// Buffer capacity in bytes, applied independently to each executed partition.
    capacity: usize,
    /// Metrics of this node. `execute` registers the `max_mem_used` and `max_queued` gauges
    /// in this set.
    metrics: ExecutionPlanMetricsSet,
}
100
101impl BufferExec {
102    /// Builds a new [BufferExec] with the provided capacity in bytes.
103    pub fn new(input: Arc<dyn ExecutionPlan>, capacity: usize) -> Self {
104        let properties = PlanProperties::clone(input.properties())
105            .with_scheduling_type(SchedulingType::Cooperative);
106
107        Self {
108            input,
109            properties: Arc::new(properties),
110            capacity,
111            metrics: ExecutionPlanMetricsSet::new(),
112        }
113    }
114
115    /// Returns the input [ExecutionPlan] of this [BufferExec].
116    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
117        &self.input
118    }
119
120    /// Returns the per-partition capacity in bytes for this [BufferExec].
121    pub fn capacity(&self) -> usize {
122        self.capacity
123    }
124
125    fn with_new_children_and_same_properties(
126        &self,
127        mut children: Vec<Arc<dyn ExecutionPlan>>,
128    ) -> Self {
129        Self {
130            input: children.swap_remove(0),
131            metrics: ExecutionPlanMetricsSet::new(),
132            ..Self::clone(self)
133        }
134    }
135}
136
137impl DisplayAs for BufferExec {
138    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
139        match t {
140            DisplayFormatType::Default | DisplayFormatType::Verbose => {
141                write!(f, "BufferExec: capacity={}", self.capacity)
142            }
143            DisplayFormatType::TreeRender => {
144                writeln!(f, "target_batch_size={}", self.capacity)
145            }
146        }
147    }
148}
149
impl ExecutionPlan for BufferExec {
    /// Static name of this operator.
    fn name(&self) -> &str {
        "BufferExec"
    }

    /// Exposes this node as [Any].
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns the properties computed in [BufferExec::new].
    fn properties(&self) -> &Arc<PlanProperties> {
        &self.properties
    }

    /// Reports that the order of the single input is maintained.
    fn maintains_input_order(&self) -> Vec<bool> {
        vec![true]
    }

    /// Reports that this node does not benefit from a particular input partitioning.
    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![false]
    }

    /// Returns the single child of this node.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    /// Rebuilds this node on top of a new child, keeping the same capacity.
    ///
    /// Returns a plan error unless exactly one child is provided.
    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // NOTE(review): the macro is defined elsewhere; presumably it returns early through
        // `with_new_children_and_same_properties` when the new children keep the same
        // properties -- confirm against the macro definition.
        check_if_same_properties!(self, children);
        if children.len() != 1 {
            return plan_err!("BufferExec can only have one child");
        }
        Ok(Arc::new(Self::new(children.swap_remove(0), self.capacity)))
    }

    /// Executes `partition` of the input and wraps it in a [MemoryBufferedStream] limited to
    /// this node's capacity.
    ///
    /// A memory consumer named `BufferExec[<partition>]` is registered in the task's memory
    /// pool and handed to the buffered stream. Two gauges are maintained per partition:
    /// `max_mem_used` (highest number of bytes in flight) and `max_queued` (highest number of
    /// batches in flight).
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        let mem_reservation = MemoryConsumer::new(format!("BufferExec[{partition}]"))
            .register(context.memory_pool());
        let in_stream = self.input.execute(partition, context)?;

        // Set up the metrics for the stream.
        // `curr_mem_in` and `curr_mem_out` are the same counter: the input side adds to it and
        // the output side subtracts from it. The running maximum is only tracked on the input
        // side, which is the only place where the counter grows.
        let curr_mem_in = Arc::new(AtomicUsize::new(0));
        let curr_mem_out = Arc::clone(&curr_mem_in);
        let mut max_mem_in = 0;
        let max_mem = MetricBuilder::new(&self.metrics).gauge("max_mem_used", partition);

        // Same scheme as above, counting batches instead of bytes.
        let curr_queued_in = Arc::new(AtomicUsize::new(0));
        let curr_queued_out = Arc::clone(&curr_queued_in);
        let mut max_queued_in = 0;
        let max_queued = MetricBuilder::new(&self.metrics).gauge("max_queued", partition);

        // Capture metrics when an element is queued on the stream.
        // This runs as soon as the input yields a batch, i.e. before the buffered stream has
        // admitted that batch into its queue.
        let in_stream = in_stream.inspect_ok(move |v| {
            let size = v.get_array_memory_size();
            // `fetch_add` returns the previous value, hence the `+ size`.
            let curr_size = curr_mem_in.fetch_add(size, Ordering::Relaxed) + size;
            if curr_size > max_mem_in {
                max_mem_in = curr_size;
                max_mem.set(max_mem_in);
            }

            let curr_queued = curr_queued_in.fetch_add(1, Ordering::Relaxed) + 1;
            if curr_queued > max_queued_in {
                max_queued_in = curr_queued;
                max_queued.set(max_queued_in);
            }
        });
        // Buffer the input.
        let out_stream =
            MemoryBufferedStream::new(in_stream, self.capacity, mem_reservation);
        // Update in the metrics that when an element gets out, some memory gets freed.
        let out_stream = out_stream.inspect_ok(move |v| {
            curr_mem_out.fetch_sub(v.get_array_memory_size(), Ordering::Relaxed);
            curr_queued_out.fetch_sub(1, Ordering::Relaxed);
        });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            self.schema(),
            out_stream,
        )))
    }

    /// Returns a snapshot of the metrics recorded by this node.
    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    /// Forwards the statistics request to the input.
    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
        self.input.partition_statistics(partition)
    }

    /// Forwards the question to the input.
    fn supports_limit_pushdown(&self) -> bool {
        self.input.supports_limit_pushdown()
    }

    /// Reports [CardinalityEffect::Equal].
    fn cardinality_effect(&self) -> CardinalityEffect {
        CardinalityEffect::Equal
    }

    /// Asks the input to swap with `projection`. If the input produces a new plan, this node is
    /// rebuilt on top of it; otherwise `None` is returned.
    fn try_swapping_with_projection(
        &self,
        projection: &ProjectionExec,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        match self.input.try_swapping_with_projection(projection)? {
            Some(new_input) => Ok(Some(
                Arc::new(self.clone()).with_new_children(vec![new_input])?,
            )),
            None => Ok(None),
        }
    }

    /// Builds the filter description from the parent filters and this node's children; no
    /// filters of its own are added.
    fn gather_filters_for_pushdown(
        &self,
        _phase: FilterPushdownPhase,
        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
        _config: &ConfigOptions,
    ) -> Result<FilterDescription> {
        FilterDescription::from_children(parent_filters, &self.children())
    }

    /// Propagates the child pushdown result through `FilterPushdownPropagation::if_all`.
    fn handle_child_pushdown_result(
        &self,
        _phase: FilterPushdownPhase,
        child_pushdown_result: ChildPushdownResult,
        _config: &ConfigOptions,
    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
        Ok(FilterPushdownPropagation::if_all(child_pushdown_result))
    }

    /// Forwards the sort pushdown request to the input, re-wrapping any new input it produces
    /// in a [BufferExec] with the same capacity.
    fn try_pushdown_sort(
        &self,
        order: &[PhysicalSortExpr],
    ) -> Result<SortOrderPushdownResult<Arc<dyn ExecutionPlan>>> {
        // BufferExec is transparent for sort ordering - it preserves order
        // Delegate to the child and wrap with a new BufferExec
        self.input.try_pushdown_sort(order)?.try_map(|new_input| {
            Ok(Arc::new(Self::new(new_input, self.capacity)) as Arc<dyn ExecutionPlan>)
        })
    }
}
293
/// Represents anything that occupies a capacity in a [MemoryBufferedStream].
pub trait SizedMessage {
    /// Returns how much capacity this message occupies. [MemoryBufferedStream] uses this value
    /// both for its capacity accounting and for growing/shrinking its memory reservation.
    fn size(&self) -> usize;
}
298
299impl SizedMessage for RecordBatch {
300    fn size(&self) -> usize {
301        self.get_array_memory_size()
302    }
303}
304
pin_project! {
/// Decouples production and consumption of messages in a stream with an internal queue, eagerly
/// filling it up to the specified maximum capacity even before any message is requested.
///
/// Allows each message to have a different size, which is taken into account for determining if
/// the queue is full or not.
pub struct MemoryBufferedStream<T: SizedMessage> {
    // Handle of the background task, spawned in `new`, that moves messages from the input
    // stream into `batch_rx`.
    task: SpawnedTask<()>,
    // Receiving end of the internal queue. Each successful message travels together with the
    // semaphore permit that accounts for its size.
    batch_rx: UnboundedReceiver<Result<(T, OwnedSemaphorePermit)>>,
    // Reservation shared with the background task: the task grows it for every message it
    // reads, and `poll_next` shrinks it for every message it hands out.
    memory_reservation: Arc<MemoryReservation>,
}}
316
317impl<T: Send + SizedMessage + 'static> MemoryBufferedStream<T> {
318    /// Builds a new [MemoryBufferedStream] with the provided capacity and event handler.
319    ///
320    /// This immediately spawns a Tokio task that will start consumption of the input stream.
321    pub fn new(
322        mut input: impl Stream<Item = Result<T>> + Unpin + Send + 'static,
323        capacity: usize,
324        memory_reservation: MemoryReservation,
325    ) -> Self {
326        let semaphore = Arc::new(Semaphore::new(capacity));
327        let (batch_tx, batch_rx) = tokio::sync::mpsc::unbounded_channel();
328
329        let memory_reservation = Arc::new(memory_reservation);
330        let memory_reservation_clone = Arc::clone(&memory_reservation);
331        let task = SpawnedTask::spawn(async move {
332            loop {
333                // Select on both the input stream and the channel being closed.
334                // By down this, we abort polling the input as soon as the consumer channel is
335                // closed. Otherwise, we would need to wait for a full new message to be available
336                // in order to consider aborting the stream
337                let item_or_err = tokio::select! {
338                    biased;
339                    _ = batch_tx.closed() => break,
340                    item_or_err = input.next() => {
341                        let Some(item_or_err) = item_or_err else {
342                            break; // stream finished
343                        };
344                        item_or_err
345                    }
346                };
347
348                let item = match item_or_err {
349                    Ok(batch) => batch,
350                    Err(err) => {
351                        let _ = batch_tx.send(Err(err)); // If there's an error it means the channel was closed, which is fine.
352                        break;
353                    }
354                };
355
356                let size = item.size();
357                if let Err(err) = memory_reservation.try_grow(size) {
358                    let _ = batch_tx.send(Err(err)); // If there's an error it means the channel was closed, which is fine.
359                    break;
360                }
361
362                // We need to cap the minimum between amount of permits and the actual size of the
363                // message. If at any point we try to acquire more permits than the capacity of the
364                // semaphore, the stream will deadlock.
365                let capped_size = size.min(capacity) as u32;
366
367                let semaphore = Arc::clone(&semaphore);
368                let Ok(permit) = semaphore.acquire_many_owned(capped_size).await else {
369                    let _ = batch_tx.send(internal_err!("Closed semaphore in MemoryBufferedStream. This is a bug in DataFusion, please report it!"));
370                    break;
371                };
372
373                if batch_tx.send(Ok((item, permit))).is_err() {
374                    break; // stream was closed
375                };
376            }
377        });
378
379        Self {
380            task,
381            batch_rx,
382            memory_reservation: memory_reservation_clone,
383        }
384    }
385
386    /// Returns the number of queued messages.
387    pub fn messages_queued(&self) -> usize {
388        self.batch_rx.len()
389    }
390}
391
392impl<T: SizedMessage> Stream for MemoryBufferedStream<T> {
393    type Item = Result<T>;
394
395    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
396        let self_project = self.project();
397        match self_project.batch_rx.poll_recv(cx) {
398            Poll::Ready(Some(Ok((item, _semaphore_permit)))) => {
399                self_project.memory_reservation.shrink(item.size());
400                Poll::Ready(Some(Ok(item)))
401            }
402            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
403            Poll::Ready(None) => Poll::Ready(None),
404            Poll::Pending => Poll::Pending,
405        }
406    }
407
408    fn size_hint(&self) -> (usize, Option<usize>) {
409        if self.batch_rx.is_closed() {
410            let len = self.batch_rx.len();
411            (len, Some(len))
412        } else {
413            (self.batch_rx.len(), None)
414        }
415    }
416}
417
// Tests for `MemoryBufferedStream`. Messages are plain `usize` values whose size is their own
// value (see the `SizedMessage for usize` impl at the bottom of this module).
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_common::{DataFusionError, assert_contains};
    use datafusion_execution::memory_pool::{
        GreedyMemoryPool, MemoryPool, UnboundedMemoryPool,
    };
    use std::error::Error;
    use std::fmt::Debug;
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::time::timeout;

    /// With capacity 4, messages 1 and 2 fit (3 in total) but message 3 does not, so only two
    /// messages get queued.
    #[tokio::test]
    async fn buffers_only_some_messages() -> Result<(), Box<dyn Error>> {
        let input = futures::stream::iter([1, 2, 3, 4]).map(Ok);
        let (_, res) = memory_pool_and_reservation();

        let buffered = MemoryBufferedStream::new(input, 4, res);
        wait_for_buffering().await;
        assert_eq!(buffered.messages_queued(), 2);
        Ok(())
    }

    /// With capacity 10 all four messages (10 in total) are queued up front and can then be
    /// pulled until the stream finishes.
    #[tokio::test]
    async fn yields_all_messages() -> Result<(), Box<dyn Error>> {
        let input = futures::stream::iter([1, 2, 3, 4]).map(Ok);
        let (_, res) = memory_pool_and_reservation();

        let mut buffered = MemoryBufferedStream::new(input, 10, res);
        wait_for_buffering().await;
        assert_eq!(buffered.messages_queued(), 4);

        pull_ok_msg(&mut buffered).await?;
        pull_ok_msg(&mut buffered).await?;
        pull_ok_msg(&mut buffered).await?;
        pull_ok_msg(&mut buffered).await?;
        finished(&mut buffered).await?;
        Ok(())
    }

    /// A first message (25) bigger than the capacity (10) is still queued and can be pulled.
    #[tokio::test]
    async fn yields_first_msg_even_if_big() -> Result<(), Box<dyn Error>> {
        let input = futures::stream::iter([25, 1, 2, 3]).map(Ok);
        let (_, res) = memory_pool_and_reservation();

        let mut buffered = MemoryBufferedStream::new(input, 10, res);
        wait_for_buffering().await;
        assert_eq!(buffered.messages_queued(), 1);
        pull_ok_msg(&mut buffered).await?;
        Ok(())
    }

    /// The memory pool (7) is smaller than the capacity (10): messages 1, 2 and 3 are reserved
    /// (6 in total), and reserving message 4 fails with an allocation error that reaches the
    /// consumer.
    #[tokio::test]
    async fn memory_pool_kills_stream() -> Result<(), Box<dyn Error>> {
        let input = futures::stream::iter([1, 2, 3, 4]).map(Ok);
        let (_, res) = bounded_memory_pool_and_reservation(7);

        let mut buffered = MemoryBufferedStream::new(input, 10, res);
        wait_for_buffering().await;

        pull_ok_msg(&mut buffered).await?;
        pull_ok_msg(&mut buffered).await?;
        pull_ok_msg(&mut buffered).await?;
        let msg = pull_err_msg(&mut buffered).await?;

        assert_contains!(msg.to_string(), "Failed to allocate additional 4.0 B");
        Ok(())
    }

    /// With a capacity (3) smaller than the memory pool (7), pulling messages one at a time
    /// lets all four messages through without any allocation error.
    #[tokio::test]
    async fn memory_pool_does_not_kill_stream() -> Result<(), Box<dyn Error>> {
        let input = futures::stream::iter([1, 2, 3, 4]).map(Ok);
        let (_, res) = bounded_memory_pool_and_reservation(7);

        let mut buffered = MemoryBufferedStream::new(input, 3, res);
        wait_for_buffering().await;
        pull_ok_msg(&mut buffered).await?;

        wait_for_buffering().await;
        pull_ok_msg(&mut buffered).await?;

        wait_for_buffering().await;
        pull_ok_msg(&mut buffered).await?;

        wait_for_buffering().await;
        pull_ok_msg(&mut buffered).await?;

        wait_for_buffering().await;
        finished(&mut buffered).await?;
        Ok(())
    }

    /// When every message (3) exceeds the capacity (2), messages still pass, one at a time.
    #[tokio::test]
    async fn messages_pass_even_if_all_exceed_limit() -> Result<(), Box<dyn Error>> {
        let input = futures::stream::iter([3, 3, 3, 3]).map(Ok);
        let (_, res) = memory_pool_and_reservation();

        let mut buffered = MemoryBufferedStream::new(input, 2, res);
        wait_for_buffering().await;
        assert_eq!(buffered.messages_queued(), 1);
        pull_ok_msg(&mut buffered).await?;

        wait_for_buffering().await;
        assert_eq!(buffered.messages_queued(), 1);
        pull_ok_msg(&mut buffered).await?;

        wait_for_buffering().await;
        assert_eq!(buffered.messages_queued(), 1);
        pull_ok_msg(&mut buffered).await?;

        wait_for_buffering().await;
        assert_eq!(buffered.messages_queued(), 1);
        pull_ok_msg(&mut buffered).await?;

        wait_for_buffering().await;
        finished(&mut buffered).await?;
        Ok(())
    }

    /// An error produced by the input reaches the consumer after the messages that preceded it.
    #[tokio::test]
    async fn errors_get_propagated() -> Result<(), Box<dyn Error>> {
        let input = futures::stream::iter([1, 2, 3, 4]).map(|v| {
            if v == 3 {
                return internal_err!("Error on 3");
            }
            Ok(v)
        });
        let (_, res) = memory_pool_and_reservation();

        let mut buffered = MemoryBufferedStream::new(input, 10, res);
        wait_for_buffering().await;

        pull_ok_msg(&mut buffered).await?;
        pull_ok_msg(&mut buffered).await?;
        pull_err_msg(&mut buffered).await?;

        Ok(())
    }

    /// The pool's reserved amount goes down as messages are pulled, and drops to zero once the
    /// stream itself is dropped.
    #[tokio::test]
    async fn memory_gets_released_if_stream_drops() -> Result<(), Box<dyn Error>> {
        let input = futures::stream::iter([1, 2, 3, 4]).map(Ok);
        let (pool, res) = memory_pool_and_reservation();

        let mut buffered = MemoryBufferedStream::new(input, 10, res);
        wait_for_buffering().await;
        assert_eq!(buffered.messages_queued(), 4);
        assert_eq!(pool.reserved(), 10);

        pull_ok_msg(&mut buffered).await?;
        assert_eq!(buffered.messages_queued(), 3);
        assert_eq!(pool.reserved(), 9);

        pull_ok_msg(&mut buffered).await?;
        assert_eq!(buffered.messages_queued(), 2);
        assert_eq!(pool.reserved(), 7);

        drop(buffered);
        assert_eq!(pool.reserved(), 0);
        Ok(())
    }

    /// Builds an unbounded memory pool together with a reservation registered on it.
    fn memory_pool_and_reservation() -> (Arc<dyn MemoryPool>, MemoryReservation) {
        let pool = Arc::new(UnboundedMemoryPool::default()) as _;
        let reservation = MemoryConsumer::new("test").register(&pool);
        (pool, reservation)
    }

    /// Builds a greedy memory pool limited to `size`, together with a reservation registered
    /// on it.
    fn bounded_memory_pool_and_reservation(
        size: usize,
    ) -> (Arc<dyn MemoryPool>, MemoryReservation) {
        let pool = Arc::new(GreedyMemoryPool::new(size)) as _;
        let reservation = MemoryConsumer::new("test").register(&pool);
        (pool, reservation)
    }

    /// Sleeps for one millisecond so that the background buffering task can make progress.
    async fn wait_for_buffering() {
        // We do not have control over the spawned task, so the best we can do is to yield some
        // cycles to the tokio runtime and let the task make progress on its own.
        tokio::time::sleep(Duration::from_millis(1)).await;
    }

    /// Pulls the next message, failing if it does not arrive within one millisecond, if the
    /// stream has finished, or if the message is an error.
    async fn pull_ok_msg<T: SizedMessage>(
        buffered: &mut MemoryBufferedStream<T>,
    ) -> Result<T, Box<dyn Error>> {
        Ok(timeout(Duration::from_millis(1), buffered.next())
            .await?
            .unwrap_or_else(|| internal_err!("Stream should not have finished"))?)
    }

    /// Pulls the next message expecting it to be an error, which is returned. Fails if nothing
    /// arrives within one millisecond, if the stream has finished, or if the message is `Ok`.
    async fn pull_err_msg<T: SizedMessage + Debug>(
        buffered: &mut MemoryBufferedStream<T>,
    ) -> Result<DataFusionError, Box<dyn Error>> {
        Ok(timeout(Duration::from_millis(1), buffered.next())
            .await?
            .map(|v| match v {
                Ok(v) => internal_err!(
                    "Stream should not have failed, but succeeded with {v:?}"
                ),
                Err(err) => Ok(err),
            })
            .unwrap_or_else(|| internal_err!("Stream should not have finished"))?)
    }

    /// Asserts that the stream reports its end within one millisecond.
    async fn finished<T: SizedMessage>(
        buffered: &mut MemoryBufferedStream<T>,
    ) -> Result<(), Box<dyn Error>> {
        match timeout(Duration::from_millis(1), buffered.next())
            .await?
            .is_none()
        {
            true => Ok(()),
            false => internal_err!("Stream should have finished")?,
        }
    }

    // Test-only message type: a `usize` occupies as much capacity as its own value.
    impl SizedMessage for usize {
        fn size(&self) -> usize {
            *self
        }
    }
}