datafusion_physical_plan/coop.rs
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Utilities for improved cooperative scheduling.
//!
//! # Cooperative scheduling
//!
//! A single call to `poll_next` on a top-level [`Stream`] may perform a lot of work before it
//! returns `Poll::Pending`. Think for instance of calculating an aggregation over a large dataset.
//! If a `Stream` runs for a long period of time without yielding back to the Tokio executor,
//! it can starve other tasks waiting on that executor.
//! Additionally, this prevents the query execution from being cancelled.
//!
//! To ensure that `Stream` implementations yield regularly, operators can insert explicit yield
//! points using the utilities in this module. For most operators this is **not** necessary. The
//! `Stream`s of the built-in DataFusion operators that generate (rather than manipulate)
//! `RecordBatch`es, such as `DataSourceExec`, and those that eagerly consume `RecordBatch`es
//! (for instance, `RepartitionExec`) contain yield points that will make most query `Stream`s yield
//! periodically.
//!
//! There are a couple of types of operators that _should_ insert yield points:
//! - New source operators that do not make use of Tokio resources
//! - Exchange-like operators that do not use Tokio's `Channel` implementation to pass data between
//!   tasks
//!
//! ## Adding yield points
//!
//! Yield points can be inserted manually using the facilities provided by the
//! [Tokio coop module](https://docs.rs/tokio/latest/tokio/task/coop/index.html) such as
//! [`tokio::task::coop::consume_budget`](https://docs.rs/tokio/latest/tokio/task/coop/fn.consume_budget.html).
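//!
//! For example, a hand-rolled source loop might await `consume_budget` once per produced batch.
//! This is only a minimal sketch: `next_batch` and `process` are hypothetical helpers standing in
//! for whatever produces and consumes the operator's data.
//!
//! ```ignore
//! use tokio::task::coop::consume_budget;
//!
//! async fn drive_source() {
//!     // `next_batch` is a placeholder for the operator's actual batch production.
//!     while let Some(batch) = next_batch().await {
//!         // Consume one unit of task budget; once the budget is exhausted this
//!         // await yields control back to the Tokio scheduler before resuming.
//!         consume_budget().await;
//!         process(batch);
//!     }
//! }
//! ```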
//!
//! Another option is to use the wrapper `Stream` implementation provided by this module, which
//! will consume a unit of task budget every time a `RecordBatch` is produced.
//! Wrapper `Stream`s can be created using the [`cooperative`] and [`make_cooperative`] functions.
//!
//! [`cooperative`] is a generic function that takes ownership of the wrapped [`RecordBatchStream`].
//! This function has the benefit of not requiring an additional heap allocation and can avoid
//! dynamic dispatch.
//!
//! [`make_cooperative`] is a non-generic function that wraps a [`SendableRecordBatchStream`]. This
//! can be used to wrap dynamically typed, heap allocated [`RecordBatchStream`]s.
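//!
//! The following sketch shows both entry points; `typed_stream` and `boxed_stream` are assumed
//! to be an owned `RecordBatchStream + Unpin` and a `SendableRecordBatchStream`, respectively.
//!
//! ```ignore
//! // Static dispatch, no extra allocation: the concrete stream type is known.
//! let wrapped = cooperative(typed_stream);
//!
//! // Dynamic dispatch: wraps an already boxed, dynamically typed stream.
//! let wrapped_boxed = make_cooperative(boxed_stream);
//! ```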
//!
//! ## Automatic cooperation
//!
//! The `EnsureCooperative` physical optimizer rule, which is included in the default set of
//! optimizer rules, inspects query plans for potential cooperative scheduling issues.
//! It injects the [`CooperativeExec`] wrapper `ExecutionPlan` into the query plan where necessary.
//! This `ExecutionPlan` uses [`make_cooperative`] to wrap the `Stream` of its input.
//!
//! The optimizer rule currently checks the plan for exchange-like operators and leaf operators
//! that report [`SchedulingType::NonCooperative`] in their [plan properties](ExecutionPlan::properties).
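//!
//! A custom operator whose streams already cooperate can opt out of this wrapping by declaring
//! so in its plan properties. A minimal sketch, assuming `base` is the operator's existing
//! [`PlanProperties`]:
//!
//! ```ignore
//! // Declare that the streams produced by this operator consume task budget
//! // themselves, so EnsureCooperative does not need to insert a CooperativeExec.
//! let properties = base.with_scheduling_type(SchedulingType::Cooperative);
//! ```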

use datafusion_common::config::ConfigOptions;
use datafusion_physical_expr::PhysicalExpr;
#[cfg(datafusion_coop = "tokio_fallback")]
use futures::Future;
use std::any::Any;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use crate::execution_plan::CardinalityEffect::{self, Equal};
use crate::filter_pushdown::{
    ChildPushdownResult, FilterDescription, FilterPushdownPhase,
    FilterPushdownPropagation,
};
use crate::{
    DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream,
    SendableRecordBatchStream,
};
use arrow::record_batch::RecordBatch;
use arrow_schema::Schema;
use datafusion_common::{internal_err, Result, Statistics};
use datafusion_execution::TaskContext;

use crate::execution_plan::SchedulingType;
use crate::stream::RecordBatchStreamAdapter;
use futures::{Stream, StreamExt};

/// A stream that passes record batches through unchanged while cooperating with the Tokio runtime.
/// It consumes cooperative scheduling budget for each returned [`RecordBatch`],
/// allowing other tasks to execute when the budget is exhausted.
///
/// See the [module level documentation](crate::coop) for an in-depth discussion.
pub struct CooperativeStream<T>
where
    T: RecordBatchStream + Unpin,
{
    inner: T,
    #[cfg(datafusion_coop = "per_stream")]
    budget: u8,
}

#[cfg(datafusion_coop = "per_stream")]
// Magic value that matches Tokio's task budget value
const YIELD_FREQUENCY: u8 = 128;

impl<T> CooperativeStream<T>
where
    T: RecordBatchStream + Unpin,
{
    /// Creates a new `CooperativeStream` that wraps the provided stream.
    /// The resulting stream will cooperate with the Tokio scheduler by consuming a unit of
    /// scheduling budget when the wrapped `Stream` returns a record batch.
    pub fn new(inner: T) -> Self {
        Self {
            inner,
            #[cfg(datafusion_coop = "per_stream")]
            budget: YIELD_FREQUENCY,
        }
    }
}

impl<T> Stream for CooperativeStream<T>
where
    T: RecordBatchStream + Unpin,
{
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        #[cfg(any(
            datafusion_coop = "tokio",
            not(any(
                datafusion_coop = "tokio_fallback",
                datafusion_coop = "per_stream"
            ))
        ))]
        {
            let coop = std::task::ready!(tokio::task::coop::poll_proceed(cx));
            let value = self.inner.poll_next_unpin(cx);
            if value.is_ready() {
                coop.made_progress();
            }
            value
        }

        #[cfg(datafusion_coop = "tokio_fallback")]
        {
            // This is a temporary placeholder implementation that may have slightly
            // worse performance compared to `poll_proceed`
            if !tokio::task::coop::has_budget_remaining() {
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }

            let value = self.inner.poll_next_unpin(cx);
            if value.is_ready() {
                // In contrast to `poll_proceed` we are not able to consume budget
                // before doing the work. Instead, we try to consume budget after
                // the work has been done and assume that this succeeded.
                // The poll result is ignored because we don't want to discard
                // or buffer the Ready result we got from the inner stream.
                let consume = tokio::task::coop::consume_budget();
                let consume_ref = std::pin::pin!(consume);
                let _ = consume_ref.poll(cx);
            }
            value
        }

        #[cfg(datafusion_coop = "per_stream")]
        {
            if self.budget == 0 {
                // Budget exhausted: reset it and yield to the Tokio scheduler,
                // scheduling this task to be woken again immediately.
                self.budget = YIELD_FREQUENCY;
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }

            let value = { self.inner.poll_next_unpin(cx) };

            if value.is_ready() {
                self.budget -= 1;
            } else {
                // The inner stream is yielding anyway; start over with a full budget.
                self.budget = YIELD_FREQUENCY;
            }
            value
        }
    }
}

impl<T> RecordBatchStream for CooperativeStream<T>
where
    T: RecordBatchStream + Unpin,
{
    fn schema(&self) -> Arc<Schema> {
        self.inner.schema()
    }
}

/// An execution plan decorator that enables cooperative multitasking.
/// It wraps the streams produced by its input execution plan using the [`make_cooperative`]
/// function, which makes the streams participate in Tokio cooperative scheduling.
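///
/// A minimal usage sketch, assuming `input` is an existing `Arc<dyn ExecutionPlan>`:
///
/// ```ignore
/// // Wrap the input; the resulting node reports SchedulingType::Cooperative.
/// let plan = Arc::new(CooperativeExec::new(input));
/// ```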
#[derive(Debug)]
pub struct CooperativeExec {
    input: Arc<dyn ExecutionPlan>,
    properties: PlanProperties,
}

impl CooperativeExec {
    /// Creates a new `CooperativeExec` operator that wraps the given input execution plan.
    pub fn new(input: Arc<dyn ExecutionPlan>) -> Self {
        let properties = input
            .properties()
            .clone()
            .with_scheduling_type(SchedulingType::Cooperative);

        Self { input, properties }
    }

    /// Returns a reference to the wrapped input execution plan.
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }
}

impl DisplayAs for CooperativeExec {
    fn fmt_as(
        &self,
        _t: DisplayFormatType,
        f: &mut std::fmt::Formatter<'_>,
    ) -> std::fmt::Result {
        write!(f, "CooperativeExec")
    }
}

impl ExecutionPlan for CooperativeExec {
    fn name(&self) -> &str {
        "CooperativeExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> Arc<Schema> {
        self.input.schema()
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        vec![true; self.children().len()]
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if children.len() != 1 {
            return internal_err!("CooperativeExec requires exactly one child");
        }
        Ok(Arc::new(CooperativeExec::new(children.swap_remove(0))))
    }

    fn execute(
        &self,
        partition: usize,
        task_ctx: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        let child_stream = self.input.execute(partition, task_ctx)?;
        Ok(make_cooperative(child_stream))
    }

    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
        self.input.partition_statistics(partition)
    }

    fn supports_limit_pushdown(&self) -> bool {
        true
    }

    fn cardinality_effect(&self) -> CardinalityEffect {
        Equal
    }

    fn gather_filters_for_pushdown(
        &self,
        _phase: FilterPushdownPhase,
        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
        _config: &ConfigOptions,
    ) -> Result<FilterDescription> {
        FilterDescription::from_children(parent_filters, &self.children())
    }

    fn handle_child_pushdown_result(
        &self,
        _phase: FilterPushdownPhase,
        child_pushdown_result: ChildPushdownResult,
        _config: &ConfigOptions,
    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
        Ok(FilterPushdownPropagation::if_all(child_pushdown_result))
    }
}

/// Creates a [`CooperativeStream`] wrapper around the given [`RecordBatchStream`].
/// This wrapper collaborates with the Tokio cooperative scheduler by consuming a unit of
/// scheduling budget for each returned record batch.
pub fn cooperative<T>(stream: T) -> CooperativeStream<T>
where
    T: RecordBatchStream + Unpin + Send + 'static,
{
    CooperativeStream::new(stream)
}

/// Wraps a `SendableRecordBatchStream` inside a [`CooperativeStream`] to enable cooperative
/// multitasking.
/// Since `SendableRecordBatchStream` is a `dyn RecordBatchStream`, this requires the use of
/// dynamic method dispatch.
/// When the stream type is statically known, consider using the generic [`cooperative`]
/// function to allow static method dispatch.
pub fn make_cooperative(stream: SendableRecordBatchStream) -> SendableRecordBatchStream {
    // TODO is there a more elegant way to overload cooperative
    Box::pin(cooperative(RecordBatchStreamAdapter::new(
        stream.schema(),
        stream,
    )))
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::stream::RecordBatchStreamAdapter;

    use arrow_schema::SchemaRef;

    use futures::{stream, StreamExt};

    // This is the hardcoded value Tokio uses
    const TASK_BUDGET: usize = 128;

    /// Helper: construct a SendableRecordBatchStream containing `n` empty batches
    fn make_empty_batches(n: usize) -> SendableRecordBatchStream {
        let schema: SchemaRef = Arc::new(Schema::empty());
        let schema_for_stream = Arc::clone(&schema);

        let s =
            stream::iter((0..n).map(move |_| {
                Ok(RecordBatch::new_empty(Arc::clone(&schema_for_stream)))
            }));

        Box::pin(RecordBatchStreamAdapter::new(schema, s))
    }

    #[tokio::test]
    async fn yield_less_than_threshold() -> Result<()> {
        let count = TASK_BUDGET - 10;
        let inner = make_empty_batches(count);
        let out = make_cooperative(inner).collect::<Vec<_>>().await;
        assert_eq!(out.len(), count);
        Ok(())
    }

    #[tokio::test]
    async fn yield_equal_to_threshold() -> Result<()> {
        let count = TASK_BUDGET;
        let inner = make_empty_batches(count);
        let out = make_cooperative(inner).collect::<Vec<_>>().await;
        assert_eq!(out.len(), count);
        Ok(())
    }

    #[tokio::test]
    async fn yield_more_than_threshold() -> Result<()> {
        let count = TASK_BUDGET + 20;
        let inner = make_empty_batches(count);
        let out = make_cooperative(inner).collect::<Vec<_>>().await;
        assert_eq!(out.len(), count);
        Ok(())
    }
}