// datafusion_physical_plan/coop.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Utilities for improved cooperative scheduling.
19//!
20//! # Cooperative scheduling
21//!
22//! A single call to `poll_next` on a top-level [`Stream`] may potentially perform a lot of work
23//! before it returns a `Poll::Pending`. Think for instance of calculating an aggregation over a
24//! large dataset.
25//!
26//! If a `Stream` runs for a long period of time without yielding back to the Tokio executor,
27//! it can starve other tasks waiting on that executor to execute them.
28//! Additionally, this prevents the query execution from being cancelled.
29//!
30//! For more background, please also see the [Using Rust async for Query Execution and Cancelling Long-Running Queries blog]
31//!
32//! [Using Rust async for Query Execution and Cancelling Long-Running Queries blog]: https://datafusion.apache.org/blog/2025/06/30/cancellation
33//!
34//! To ensure that `Stream` implementations yield regularly, operators can insert explicit yield
35//! points using the utilities in this module. For most operators this is **not** necessary. The
36//! `Stream`s of the built-in DataFusion operators that generate (rather than manipulate)
37//! `RecordBatch`es such as `DataSourceExec` and those that eagerly consume `RecordBatch`es
38//! (for instance, `RepartitionExec`) contain yield points that will make most query `Stream`s yield
39//! periodically.
40//!
41//! There are a couple of types of operators that _should_ insert yield points:
42//! - New source operators that do not make use of Tokio resources
43//! - Exchange like operators that do not use Tokio's `Channel` implementation to pass data between
44//!   tasks
45//!
46//! ## Adding yield points
47//!
48//! Yield points can be inserted manually using the facilities provided by the
49//! [Tokio coop module](https://docs.rs/tokio/latest/tokio/task/coop/index.html) such as
50//! [`tokio::task::coop::consume_budget`](https://docs.rs/tokio/latest/tokio/task/coop/fn.consume_budget.html).
51//!
52//! Another option is to use the wrapper `Stream` implementation provided by this module which will
53//! consume a unit of task budget every time a `RecordBatch` is produced.
54//! Wrapper `Stream`s can be created using the [`cooperative`] and [`make_cooperative`] functions.
55//!
56//! [`cooperative`] is a generic function that takes ownership of the wrapped [`RecordBatchStream`].
57//! This function has the benefit of not requiring an additional heap allocation and can avoid
58//! dynamic dispatch.
59//!
60//! [`make_cooperative`] is a non-generic function that wraps a [`SendableRecordBatchStream`]. This
61//! can be used to wrap dynamically typed, heap allocated [`RecordBatchStream`]s.
62//!
63//! ## Automatic cooperation
64//!
65//! The `EnsureCooperative` physical optimizer rule, which is included in the default set of
66//! optimizer rules, inspects query plans for potential cooperative scheduling issues.
67//! It injects the [`CooperativeExec`] wrapper `ExecutionPlan` into the query plan where necessary.
68//! This `ExecutionPlan` uses [`make_cooperative`] to wrap the `Stream` of its input.
69//!
//! The optimizer rule currently checks the plan for exchange-like operators and leaf operators
71//! that report [`SchedulingType::NonCooperative`] in their [plan properties](ExecutionPlan::properties).
72
73use datafusion_common::config::ConfigOptions;
74use datafusion_physical_expr::PhysicalExpr;
75#[cfg(datafusion_coop = "tokio_fallback")]
76use futures::Future;
77use std::any::Any;
78use std::pin::Pin;
79use std::sync::Arc;
80use std::task::{Context, Poll};
81
82use crate::execution_plan::CardinalityEffect::{self, Equal};
83use crate::filter_pushdown::{
84    ChildPushdownResult, FilterDescription, FilterPushdownPhase,
85    FilterPushdownPropagation,
86};
87use crate::projection::ProjectionExec;
88use crate::{
89    DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream,
90    SendableRecordBatchStream, SortOrderPushdownResult, check_if_same_properties,
91};
92use arrow::record_batch::RecordBatch;
93use arrow_schema::Schema;
94use datafusion_common::{Result, Statistics, assert_eq_or_internal_err};
95use datafusion_execution::TaskContext;
96
97use crate::execution_plan::SchedulingType;
98use crate::stream::RecordBatchStreamAdapter;
99use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
100use futures::{Stream, StreamExt};
101
/// A stream that passes record batches through unchanged while cooperating with the Tokio runtime.
/// It consumes cooperative scheduling budget for each returned [`RecordBatch`],
/// allowing other tasks to execute when the budget is exhausted.
///
/// See the [module level documentation](crate::coop) for an in-depth discussion.
pub struct CooperativeStream<T>
where
    T: RecordBatchStream + Unpin,
{
    // The wrapped stream whose batches are passed through unchanged.
    inner: T,
    // Remaining per-stream poll budget. Only present when the `datafusion_coop`
    // cfg is `per_stream`; the other modes rely on Tokio's task-level budget.
    #[cfg(datafusion_coop = "per_stream")]
    budget: u8,
}
115
#[cfg(datafusion_coop = "per_stream")]
// Magic value that matches Tokio's task budget value.
// In `per_stream` mode the stream yields after this many `Ready` polls.
const YIELD_FREQUENCY: u8 = 128;
119
impl<T> CooperativeStream<T>
where
    T: RecordBatchStream + Unpin,
{
    /// Creates a new `CooperativeStream` that wraps the provided stream.
    /// The resulting stream will cooperate with the Tokio scheduler by consuming a unit of
    /// scheduling budget when the wrapped `Stream` returns a record batch.
    pub fn new(inner: T) -> Self {
        Self {
            inner,
            // Start with a full per-stream budget (only compiled in `per_stream` mode).
            #[cfg(datafusion_coop = "per_stream")]
            budget: YIELD_FREQUENCY,
        }
    }
}
135
impl<T> Stream for CooperativeStream<T>
where
    T: RecordBatchStream + Unpin,
{
    type Item = Result<RecordBatch>;

    /// Polls the inner stream while consuming cooperative scheduling budget
    /// for each item produced. Exactly one of the three cfg-gated blocks
    /// below is compiled in, selected by the `datafusion_coop` cfg flag.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // Default strategy (also selected explicitly via `datafusion_coop = "tokio"`):
        // reserve budget with Tokio's `poll_proceed` *before* doing work, and only
        // mark the budget as consumed when the inner stream actually made progress.
        #[cfg(any(
            datafusion_coop = "tokio",
            not(any(
                datafusion_coop = "tokio_fallback",
                datafusion_coop = "per_stream"
            ))
        ))]
        {
            let coop = std::task::ready!(tokio::task::coop::poll_proceed(cx));
            let value = self.inner.poll_next_unpin(cx);
            if value.is_ready() {
                coop.made_progress();
            }
            value
        }

        #[cfg(datafusion_coop = "tokio_fallback")]
        {
            // This is a temporary placeholder implementation that may have slightly
            // worse performance compared to `poll_proceed`
            if !tokio::task::coop::has_budget_remaining() {
                // Budget exhausted: self-wake so the task gets rescheduled, then yield.
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }

            let value = self.inner.poll_next_unpin(cx);
            if value.is_ready() {
                // In contrast to `poll_proceed` we are not able to consume
                // budget before proceeding to do work. Instead, we try to consume budget
                // after the work has been done and just assume that that succeeded.
                // The poll result is ignored because we don't want to discard
                // or buffer the Ready result we got from the inner stream.
                let consume = tokio::task::coop::consume_budget();
                let consume_ref = std::pin::pin!(consume);
                let _ = consume_ref.poll(cx);
            }
            value
        }

        #[cfg(datafusion_coop = "per_stream")]
        {
            // Tokio-independent strategy: keep a per-stream countdown and force a
            // yield once it reaches zero.
            if self.budget == 0 {
                // Budget exhausted: reset it, self-wake, and yield to the executor.
                self.budget = YIELD_FREQUENCY;
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }

            let value = { self.inner.poll_next_unpin(cx) };

            if value.is_ready() {
                self.budget -= 1;
            } else {
                // The inner stream is pending, so the task will be suspended anyway;
                // restore the full budget for the next wake-up.
                self.budget = YIELD_FREQUENCY;
            }
            value
        }
    }
}
204
impl<T> RecordBatchStream for CooperativeStream<T>
where
    T: RecordBatchStream + Unpin,
{
    /// Delegates to the wrapped stream; wrapping does not change the schema.
    fn schema(&self) -> Arc<Schema> {
        self.inner.schema()
    }
}
213
/// An execution plan decorator that enables cooperative multitasking.
/// It wraps the streams produced by its input execution plan using the [`make_cooperative`] function,
/// which makes the stream participate in Tokio cooperative scheduling.
#[derive(Debug, Clone)]
pub struct CooperativeExec {
    // The wrapped child execution plan.
    input: Arc<dyn ExecutionPlan>,
    // Cached plan properties: a copy of the input's properties with the
    // scheduling type set to `Cooperative` (see `CooperativeExec::new`).
    properties: Arc<PlanProperties>,
}
222
223impl CooperativeExec {
224    /// Creates a new `CooperativeExec` operator that wraps the given input execution plan.
225    pub fn new(input: Arc<dyn ExecutionPlan>) -> Self {
226        let properties = PlanProperties::clone(input.properties())
227            .with_scheduling_type(SchedulingType::Cooperative)
228            .into();
229
230        Self { input, properties }
231    }
232
233    /// Returns a reference to the wrapped input execution plan.
234    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
235        &self.input
236    }
237
238    fn with_new_children_and_same_properties(
239        &self,
240        mut children: Vec<Arc<dyn ExecutionPlan>>,
241    ) -> Self {
242        Self {
243            input: children.swap_remove(0),
244            ..Self::clone(self)
245        }
246    }
247}
248
249impl DisplayAs for CooperativeExec {
250    fn fmt_as(
251        &self,
252        _t: DisplayFormatType,
253        f: &mut std::fmt::Formatter<'_>,
254    ) -> std::fmt::Result {
255        write!(f, "CooperativeExec")
256    }
257}
258
259impl ExecutionPlan for CooperativeExec {
260    fn name(&self) -> &str {
261        "CooperativeExec"
262    }
263
264    fn as_any(&self) -> &dyn Any {
265        self
266    }
267
268    fn schema(&self) -> Arc<Schema> {
269        self.input.schema()
270    }
271
272    fn properties(&self) -> &Arc<PlanProperties> {
273        &self.properties
274    }
275
276    fn maintains_input_order(&self) -> Vec<bool> {
277        vec![true; self.children().len()]
278    }
279
280    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
281        vec![&self.input]
282    }
283
284    fn with_new_children(
285        self: Arc<Self>,
286        mut children: Vec<Arc<dyn ExecutionPlan>>,
287    ) -> Result<Arc<dyn ExecutionPlan>> {
288        assert_eq_or_internal_err!(
289            children.len(),
290            1,
291            "CooperativeExec requires exactly one child"
292        );
293        check_if_same_properties!(self, children);
294        Ok(Arc::new(CooperativeExec::new(children.swap_remove(0))))
295    }
296
297    fn execute(
298        &self,
299        partition: usize,
300        task_ctx: Arc<TaskContext>,
301    ) -> Result<SendableRecordBatchStream> {
302        let child_stream = self.input.execute(partition, task_ctx)?;
303        Ok(make_cooperative(child_stream))
304    }
305
306    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
307        self.input.partition_statistics(partition)
308    }
309
310    fn supports_limit_pushdown(&self) -> bool {
311        true
312    }
313
314    fn cardinality_effect(&self) -> CardinalityEffect {
315        Equal
316    }
317
318    fn try_swapping_with_projection(
319        &self,
320        projection: &ProjectionExec,
321    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
322        match self.input.try_swapping_with_projection(projection)? {
323            Some(new_input) => Ok(Some(
324                Arc::new(self.clone()).with_new_children(vec![new_input])?,
325            )),
326            None => Ok(None),
327        }
328    }
329
330    fn gather_filters_for_pushdown(
331        &self,
332        _phase: FilterPushdownPhase,
333        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
334        _config: &ConfigOptions,
335    ) -> Result<FilterDescription> {
336        FilterDescription::from_children(parent_filters, &self.children())
337    }
338
339    fn handle_child_pushdown_result(
340        &self,
341        _phase: FilterPushdownPhase,
342        child_pushdown_result: ChildPushdownResult,
343        _config: &ConfigOptions,
344    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
345        Ok(FilterPushdownPropagation::if_all(child_pushdown_result))
346    }
347
348    fn try_pushdown_sort(
349        &self,
350        order: &[PhysicalSortExpr],
351    ) -> Result<SortOrderPushdownResult<Arc<dyn ExecutionPlan>>> {
352        let child = self.input();
353
354        match child.try_pushdown_sort(order)? {
355            SortOrderPushdownResult::Exact { inner } => {
356                let new_exec = Arc::new(self.clone()).with_new_children(vec![inner])?;
357                Ok(SortOrderPushdownResult::Exact { inner: new_exec })
358            }
359            SortOrderPushdownResult::Inexact { inner } => {
360                let new_exec = Arc::new(self.clone()).with_new_children(vec![inner])?;
361                Ok(SortOrderPushdownResult::Inexact { inner: new_exec })
362            }
363            SortOrderPushdownResult::Unsupported => {
364                Ok(SortOrderPushdownResult::Unsupported)
365            }
366        }
367    }
368}
369
/// Creates a [`CooperativeStream`] wrapper around the given [`RecordBatchStream`].
/// This wrapper collaborates with the Tokio cooperative scheduler by consuming a unit of
/// scheduling budget for each returned record batch.
///
/// Because the stream type is statically known, this avoids the heap allocation and
/// dynamic dispatch required by [`make_cooperative`].
pub fn cooperative<T>(stream: T) -> CooperativeStream<T>
where
    T: RecordBatchStream + Unpin + Send + 'static,
{
    CooperativeStream::new(stream)
}
379
380/// Wraps a `SendableRecordBatchStream` inside a [`CooperativeStream`] to enable cooperative multitasking.
381/// Since `SendableRecordBatchStream` is a `dyn RecordBatchStream` this requires the use of dynamic
382/// method dispatch.
383/// When the stream type is statically known, consider use the generic [`cooperative`] function
384/// to allow static method dispatch.
385pub fn make_cooperative(stream: SendableRecordBatchStream) -> SendableRecordBatchStream {
386    // TODO is there a more elegant way to overload cooperative
387    Box::pin(cooperative(RecordBatchStreamAdapter::new(
388        stream.schema(),
389        stream,
390    )))
391}
392
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stream::RecordBatchStreamAdapter;

    use arrow_schema::SchemaRef;

    use futures::{StreamExt, stream};

    // This is the hardcoded value Tokio uses
    const TASK_BUDGET: usize = 128;

    /// Helper: construct a SendableRecordBatchStream containing `n` empty batches
    fn make_empty_batches(n: usize) -> SendableRecordBatchStream {
        let schema: SchemaRef = Arc::new(Schema::empty());
        let schema_for_stream = Arc::clone(&schema);

        let s = stream::iter((0..n).map(move |_| {
            Ok(RecordBatch::new_empty(Arc::clone(&schema_for_stream)))
        }));

        Box::pin(RecordBatchStreamAdapter::new(schema, s))
    }

    /// Helper: drive a cooperative stream of `count` batches to completion and
    /// assert that every batch is produced, even when yields happen in between.
    async fn assert_drains_fully(count: usize) -> Result<()> {
        let inner = make_empty_batches(count);
        let out = make_cooperative(inner).collect::<Vec<_>>().await;
        assert_eq!(out.len(), count);
        Ok(())
    }

    // Fewer batches than the task budget: no forced yield should occur.
    #[tokio::test]
    async fn yield_less_than_threshold() -> Result<()> {
        assert_drains_fully(TASK_BUDGET - 10).await
    }

    // Exactly the task budget worth of batches.
    #[tokio::test]
    async fn yield_equal_to_threshold() -> Result<()> {
        assert_drains_fully(TASK_BUDGET).await
    }

    // More batches than the budget: the stream must yield and then resume.
    #[tokio::test]
    async fn yield_more_than_threshold() -> Result<()> {
        assert_drains_fully(TASK_BUDGET + 20).await
    }
}
444}