// datafusion_physical_plan/coop.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Utilities for improved cooperative scheduling.
19//!
20//! # Cooperative scheduling
21//!
22//! A single call to `poll_next` on a top-level [`Stream`] may potentially perform a lot of work
23//! before it returns a `Poll::Pending`. Think for instance of calculating an aggregation over a
24//! large dataset.
25//! If a `Stream` runs for a long period of time without yielding back to the Tokio executor,
26//! it can starve other tasks waiting on that executor to execute them.
27//! Additionally, this prevents the query execution from being cancelled.
28//!
29//! To ensure that `Stream` implementations yield regularly, operators can insert explicit yield
30//! points using the utilities in this module. For most operators this is **not** necessary. The
31//! `Stream`s of the built-in DataFusion operators that generate (rather than manipulate)
32//! `RecordBatch`es such as `DataSourceExec` and those that eagerly consume `RecordBatch`es
33//! (for instance, `RepartitionExec`) contain yield points that will make most query `Stream`s yield
34//! periodically.
35//!
36//! There are a couple of types of operators that _should_ insert yield points:
37//! - New source operators that do not make use of Tokio resources
38//! - Exchange like operators that do not use Tokio's `Channel` implementation to pass data between
39//!   tasks
40//!
41//! ## Adding yield points
42//!
43//! Yield points can be inserted manually using the facilities provided by the
44//! [Tokio coop module](https://docs.rs/tokio/latest/tokio/task/coop/index.html) such as
45//! [`tokio::task::coop::consume_budget`](https://docs.rs/tokio/latest/tokio/task/coop/fn.consume_budget.html).
46//!
47//! Another option is to use the wrapper `Stream` implementation provided by this module which will
48//! consume a unit of task budget every time a `RecordBatch` is produced.
49//! Wrapper `Stream`s can be created using the [`cooperative`] and [`make_cooperative`] functions.
50//!
51//! [`cooperative`] is a generic function that takes ownership of the wrapped [`RecordBatchStream`].
52//! This function has the benefit of not requiring an additional heap allocation and can avoid
53//! dynamic dispatch.
54//!
55//! [`make_cooperative`] is a non-generic function that wraps a [`SendableRecordBatchStream`]. This
56//! can be used to wrap dynamically typed, heap allocated [`RecordBatchStream`]s.
57//!
58//! ## Automatic cooperation
59//!
60//! The `EnsureCooperative` physical optimizer rule, which is included in the default set of
61//! optimizer rules, inspects query plans for potential cooperative scheduling issues.
62//! It injects the [`CooperativeExec`] wrapper `ExecutionPlan` into the query plan where necessary.
63//! This `ExecutionPlan` uses [`make_cooperative`] to wrap the `Stream` of its input.
64//!
//! The optimizer rule currently checks the plan for exchange-like operators and leaf operators
66//! that report [`SchedulingType::NonCooperative`] in their [plan properties](ExecutionPlan::properties).
67
68use datafusion_common::config::ConfigOptions;
69use datafusion_physical_expr::PhysicalExpr;
70#[cfg(datafusion_coop = "tokio_fallback")]
71use futures::Future;
72use std::any::Any;
73use std::pin::Pin;
74use std::sync::Arc;
75use std::task::{Context, Poll};
76
77use crate::execution_plan::CardinalityEffect::{self, Equal};
78use crate::filter_pushdown::{
79    ChildPushdownResult, FilterDescription, FilterPushdownPhase,
80    FilterPushdownPropagation,
81};
82use crate::projection::ProjectionExec;
83use crate::{
84    DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream,
85    SendableRecordBatchStream, SortOrderPushdownResult,
86};
87use arrow::record_batch::RecordBatch;
88use arrow_schema::Schema;
89use datafusion_common::{Result, Statistics, assert_eq_or_internal_err};
90use datafusion_execution::TaskContext;
91
92use crate::execution_plan::SchedulingType;
93use crate::stream::RecordBatchStreamAdapter;
94use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
95use futures::{Stream, StreamExt};
96
/// A stream that passes record batches through unchanged while cooperating with the Tokio runtime.
/// It consumes cooperative scheduling budget for each returned [`RecordBatch`],
/// allowing other tasks to execute when the budget is exhausted.
///
/// See the [module level documentation](crate::coop) for an in-depth discussion.
pub struct CooperativeStream<T>
where
    T: RecordBatchStream + Unpin,
{
    /// The wrapped stream; batches are passed through unchanged.
    inner: T,
    /// Remaining number of batches that may be returned before this stream
    /// forces a yield. Only present when the `per_stream` cooperation
    /// strategy is selected at compile time via the `datafusion_coop` cfg.
    #[cfg(datafusion_coop = "per_stream")]
    budget: u8,
}
110
#[cfg(datafusion_coop = "per_stream")]
// Magic value that matches Tokio's task budget value.
// Using the same value makes the `per_stream` strategy yield with roughly the
// same frequency as Tokio's own cooperative budget would.
const YIELD_FREQUENCY: u8 = 128;
114
impl<T> CooperativeStream<T>
where
    T: RecordBatchStream + Unpin,
{
    /// Creates a new `CooperativeStream` that wraps the provided stream.
    /// The resulting stream will cooperate with the Tokio scheduler by consuming a unit of
    /// scheduling budget when the wrapped `Stream` returns a record batch.
    pub fn new(inner: T) -> Self {
        Self {
            inner,
            // Start with a full budget; it is decremented once per returned
            // batch when the `per_stream` strategy is compiled in.
            #[cfg(datafusion_coop = "per_stream")]
            budget: YIELD_FREQUENCY,
        }
    }
}
130
impl<T> Stream for CooperativeStream<T>
where
    T: RecordBatchStream + Unpin,
{
    type Item = Result<RecordBatch>;

    /// Polls the inner stream while charging one unit of cooperative budget
    /// per ready item. Exactly one of the three cfg-gated strategies below is
    /// compiled in, selected by the `datafusion_coop` cfg value.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // Strategy "tokio" (also the default when no strategy cfg is set):
        // use Tokio's `poll_proceed`, which returns `Pending` when the task
        // budget is exhausted. Budget is only consumed (`made_progress`) when
        // the inner stream actually produced a ready result.
        #[cfg(any(
            datafusion_coop = "tokio",
            not(any(
                datafusion_coop = "tokio_fallback",
                datafusion_coop = "per_stream"
            ))
        ))]
        {
            let coop = std::task::ready!(tokio::task::coop::poll_proceed(cx));
            let value = self.inner.poll_next_unpin(cx);
            if value.is_ready() {
                coop.made_progress();
            }
            value
        }

        // Strategy "tokio_fallback": approximate the above using only the
        // stable budget-query/consume APIs.
        #[cfg(datafusion_coop = "tokio_fallback")]
        {
            // This is a temporary placeholder implementation that may have slightly
            // worse performance compared to `poll_proceed`
            if !tokio::task::coop::has_budget_remaining() {
                // Budget exhausted: self-wake so the task gets rescheduled,
                // then yield to the executor.
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }

            let value = self.inner.poll_next_unpin(cx);
            if value.is_ready() {
                // In contrast to `poll_proceed` we are not able to consume
                // budget before proceeding to do work. Instead, we try to consume budget
                // after the work has been done and just assume that that succeeded.
                // The poll result is ignored because we don't want to discard
                // or buffer the Ready result we got from the inner stream.
                let consume = tokio::task::coop::consume_budget();
                let consume_ref = std::pin::pin!(consume);
                let _ = consume_ref.poll(cx);
            }
            value
        }

        // Strategy "per_stream": keep a private counter per stream instead of
        // using Tokio's coop API at all.
        #[cfg(datafusion_coop = "per_stream")]
        {
            if self.budget == 0 {
                // Budget spent: refill, self-wake, and yield to the executor.
                self.budget = YIELD_FREQUENCY;
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }

            let value = { self.inner.poll_next_unpin(cx) };

            if value.is_ready() {
                self.budget -= 1;
            } else {
                // The inner stream yielded on its own; reset the budget since
                // the task returned control to the executor anyway.
                self.budget = YIELD_FREQUENCY;
            }
            value
        }
    }
}
199
impl<T> RecordBatchStream for CooperativeStream<T>
where
    T: RecordBatchStream + Unpin,
{
    /// Delegates to the wrapped stream; wrapping does not change the schema.
    fn schema(&self) -> Arc<Schema> {
        self.inner.schema()
    }
}
208
/// An execution plan decorator that enables cooperative multitasking.
/// It wraps the streams produced by its input execution plan using the [`make_cooperative`] function,
/// which makes the stream participate in Tokio cooperative scheduling.
#[derive(Debug, Clone)]
pub struct CooperativeExec {
    /// The wrapped child plan whose streams will be made cooperative.
    input: Arc<dyn ExecutionPlan>,
    /// The input's plan properties with the scheduling type overridden to
    /// `Cooperative` (set up in [`CooperativeExec::new`]).
    properties: PlanProperties,
}
217
218impl CooperativeExec {
219    /// Creates a new `CooperativeExec` operator that wraps the given input execution plan.
220    pub fn new(input: Arc<dyn ExecutionPlan>) -> Self {
221        let properties = input
222            .properties()
223            .clone()
224            .with_scheduling_type(SchedulingType::Cooperative);
225
226        Self { input, properties }
227    }
228
229    /// Returns a reference to the wrapped input execution plan.
230    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
231        &self.input
232    }
233}
234
235impl DisplayAs for CooperativeExec {
236    fn fmt_as(
237        &self,
238        _t: DisplayFormatType,
239        f: &mut std::fmt::Formatter<'_>,
240    ) -> std::fmt::Result {
241        write!(f, "CooperativeExec")
242    }
243}
244
245impl ExecutionPlan for CooperativeExec {
246    fn name(&self) -> &str {
247        "CooperativeExec"
248    }
249
250    fn as_any(&self) -> &dyn Any {
251        self
252    }
253
254    fn schema(&self) -> Arc<Schema> {
255        self.input.schema()
256    }
257
258    fn properties(&self) -> &PlanProperties {
259        &self.properties
260    }
261
262    fn maintains_input_order(&self) -> Vec<bool> {
263        vec![true; self.children().len()]
264    }
265
266    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
267        vec![&self.input]
268    }
269
270    fn with_new_children(
271        self: Arc<Self>,
272        mut children: Vec<Arc<dyn ExecutionPlan>>,
273    ) -> Result<Arc<dyn ExecutionPlan>> {
274        assert_eq_or_internal_err!(
275            children.len(),
276            1,
277            "CooperativeExec requires exactly one child"
278        );
279        Ok(Arc::new(CooperativeExec::new(children.swap_remove(0))))
280    }
281
282    fn execute(
283        &self,
284        partition: usize,
285        task_ctx: Arc<TaskContext>,
286    ) -> Result<SendableRecordBatchStream> {
287        let child_stream = self.input.execute(partition, task_ctx)?;
288        Ok(make_cooperative(child_stream))
289    }
290
291    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
292        self.input.partition_statistics(partition)
293    }
294
295    fn supports_limit_pushdown(&self) -> bool {
296        true
297    }
298
299    fn cardinality_effect(&self) -> CardinalityEffect {
300        Equal
301    }
302
303    fn try_swapping_with_projection(
304        &self,
305        projection: &ProjectionExec,
306    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
307        match self.input.try_swapping_with_projection(projection)? {
308            Some(new_input) => Ok(Some(
309                Arc::new(self.clone()).with_new_children(vec![new_input])?,
310            )),
311            None => Ok(None),
312        }
313    }
314
315    fn gather_filters_for_pushdown(
316        &self,
317        _phase: FilterPushdownPhase,
318        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
319        _config: &ConfigOptions,
320    ) -> Result<FilterDescription> {
321        FilterDescription::from_children(parent_filters, &self.children())
322    }
323
324    fn handle_child_pushdown_result(
325        &self,
326        _phase: FilterPushdownPhase,
327        child_pushdown_result: ChildPushdownResult,
328        _config: &ConfigOptions,
329    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
330        Ok(FilterPushdownPropagation::if_all(child_pushdown_result))
331    }
332
333    fn try_pushdown_sort(
334        &self,
335        order: &[PhysicalSortExpr],
336    ) -> Result<SortOrderPushdownResult<Arc<dyn ExecutionPlan>>> {
337        let child = self.input();
338
339        match child.try_pushdown_sort(order)? {
340            SortOrderPushdownResult::Exact { inner } => {
341                let new_exec = Arc::new(self.clone()).with_new_children(vec![inner])?;
342                Ok(SortOrderPushdownResult::Exact { inner: new_exec })
343            }
344            SortOrderPushdownResult::Inexact { inner } => {
345                let new_exec = Arc::new(self.clone()).with_new_children(vec![inner])?;
346                Ok(SortOrderPushdownResult::Inexact { inner: new_exec })
347            }
348            SortOrderPushdownResult::Unsupported => {
349                Ok(SortOrderPushdownResult::Unsupported)
350            }
351        }
352    }
353}
354
355/// Creates a [`CooperativeStream`] wrapper around the given [`RecordBatchStream`].
356/// This wrapper collaborates with the Tokio cooperative scheduler by consuming a unit of
357/// scheduling budget for each returned record batch.
358pub fn cooperative<T>(stream: T) -> CooperativeStream<T>
359where
360    T: RecordBatchStream + Unpin + Send + 'static,
361{
362    CooperativeStream::new(stream)
363}
364
365/// Wraps a `SendableRecordBatchStream` inside a [`CooperativeStream`] to enable cooperative multitasking.
366/// Since `SendableRecordBatchStream` is a `dyn RecordBatchStream` this requires the use of dynamic
367/// method dispatch.
368/// When the stream type is statically known, consider use the generic [`cooperative`] function
369/// to allow static method dispatch.
370pub fn make_cooperative(stream: SendableRecordBatchStream) -> SendableRecordBatchStream {
371    // TODO is there a more elegant way to overload cooperative
372    Box::pin(cooperative(RecordBatchStreamAdapter::new(
373        stream.schema(),
374        stream,
375    )))
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stream::RecordBatchStreamAdapter;

    use arrow_schema::SchemaRef;

    use futures::{StreamExt, stream};

    // This is the hardcoded value Tokio uses
    const TASK_BUDGET: usize = 128;

    /// Helper: construct a SendableRecordBatchStream containing `n` empty batches
    fn make_empty_batches(n: usize) -> SendableRecordBatchStream {
        let schema: SchemaRef = Arc::new(Schema::empty());
        let batch_schema = Arc::clone(&schema);

        let batches = stream::iter(
            std::iter::repeat_with(move || {
                Ok(RecordBatch::new_empty(Arc::clone(&batch_schema)))
            })
            .take(n),
        );

        Box::pin(RecordBatchStreamAdapter::new(schema, batches))
    }

    /// Drives a cooperative wrapper over `count` batches and verifies that
    /// every batch makes it through, regardless of how often the wrapper
    /// yields along the way.
    async fn assert_all_batches_pass(count: usize) -> Result<()> {
        let wrapped = make_cooperative(make_empty_batches(count));
        let collected = wrapped.collect::<Vec<_>>().await;
        assert_eq!(collected.len(), count);
        Ok(())
    }

    #[tokio::test]
    async fn yield_less_than_threshold() -> Result<()> {
        assert_all_batches_pass(TASK_BUDGET - 10).await
    }

    #[tokio::test]
    async fn yield_equal_to_threshold() -> Result<()> {
        assert_all_batches_pass(TASK_BUDGET).await
    }

    #[tokio::test]
    async fn yield_more_than_threshold() -> Result<()> {
        assert_all_batches_pass(TASK_BUDGET + 20).await
    }
}
429}