datafusion_physical_plan/
coop.rs1use datafusion_common::config::ConfigOptions;
74use datafusion_physical_expr::PhysicalExpr;
75#[cfg(datafusion_coop = "tokio_fallback")]
76use futures::Future;
77use std::any::Any;
78use std::pin::Pin;
79use std::sync::Arc;
80use std::task::{Context, Poll};
81
82use crate::execution_plan::CardinalityEffect::{self, Equal};
83use crate::filter_pushdown::{
84 ChildPushdownResult, FilterDescription, FilterPushdownPhase,
85 FilterPushdownPropagation,
86};
87use crate::projection::ProjectionExec;
88use crate::{
89 DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream,
90 SendableRecordBatchStream, SortOrderPushdownResult, check_if_same_properties,
91};
92use arrow::record_batch::RecordBatch;
93use arrow_schema::Schema;
94use datafusion_common::{Result, Statistics, assert_eq_or_internal_err};
95use datafusion_execution::TaskContext;
96
97use crate::execution_plan::SchedulingType;
98use crate::stream::RecordBatchStreamAdapter;
99use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
100use futures::{Stream, StreamExt};
101
102pub struct CooperativeStream<T>
108where
109 T: RecordBatchStream + Unpin,
110{
111 inner: T,
112 #[cfg(datafusion_coop = "per_stream")]
113 budget: u8,
114}
115
/// Number of ready results a `per_stream` [`CooperativeStream`] returns
/// before it forces a yield; also the value the budget is reset to.
///
/// NOTE(review): 128 presumably mirrors Tokio's per-task cooperative budget —
/// confirm against the Tokio documentation before changing it.
#[cfg(datafusion_coop = "per_stream")]
const YIELD_FREQUENCY: u8 = 128;
119
120impl<T> CooperativeStream<T>
121where
122 T: RecordBatchStream + Unpin,
123{
124 pub fn new(inner: T) -> Self {
128 Self {
129 inner,
130 #[cfg(datafusion_coop = "per_stream")]
131 budget: YIELD_FREQUENCY,
132 }
133 }
134}
135
136impl<T> Stream for CooperativeStream<T>
137where
138 T: RecordBatchStream + Unpin,
139{
140 type Item = Result<RecordBatch>;
141
142 fn poll_next(
143 mut self: Pin<&mut Self>,
144 cx: &mut Context<'_>,
145 ) -> Poll<Option<Self::Item>> {
146 #[cfg(any(
147 datafusion_coop = "tokio",
148 not(any(
149 datafusion_coop = "tokio_fallback",
150 datafusion_coop = "per_stream"
151 ))
152 ))]
153 {
154 let coop = std::task::ready!(tokio::task::coop::poll_proceed(cx));
155 let value = self.inner.poll_next_unpin(cx);
156 if value.is_ready() {
157 coop.made_progress();
158 }
159 value
160 }
161
162 #[cfg(datafusion_coop = "tokio_fallback")]
163 {
164 if !tokio::task::coop::has_budget_remaining() {
167 cx.waker().wake_by_ref();
168 return Poll::Pending;
169 }
170
171 let value = self.inner.poll_next_unpin(cx);
172 if value.is_ready() {
173 let consume = tokio::task::coop::consume_budget();
179 let consume_ref = std::pin::pin!(consume);
180 let _ = consume_ref.poll(cx);
181 }
182 value
183 }
184
185 #[cfg(datafusion_coop = "per_stream")]
186 {
187 if self.budget == 0 {
188 self.budget = YIELD_FREQUENCY;
189 cx.waker().wake_by_ref();
190 return Poll::Pending;
191 }
192
193 let value = { self.inner.poll_next_unpin(cx) };
194
195 if value.is_ready() {
196 self.budget -= 1;
197 } else {
198 self.budget = YIELD_FREQUENCY;
199 }
200 value
201 }
202 }
203}
204
205impl<T> RecordBatchStream for CooperativeStream<T>
206where
207 T: RecordBatchStream + Unpin,
208{
209 fn schema(&self) -> Arc<Schema> {
210 self.inner.schema()
211 }
212}
213
214#[derive(Debug, Clone)]
218pub struct CooperativeExec {
219 input: Arc<dyn ExecutionPlan>,
220 properties: Arc<PlanProperties>,
221}
222
223impl CooperativeExec {
224 pub fn new(input: Arc<dyn ExecutionPlan>) -> Self {
226 let properties = PlanProperties::clone(input.properties())
227 .with_scheduling_type(SchedulingType::Cooperative)
228 .into();
229
230 Self { input, properties }
231 }
232
233 pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
235 &self.input
236 }
237
238 fn with_new_children_and_same_properties(
239 &self,
240 mut children: Vec<Arc<dyn ExecutionPlan>>,
241 ) -> Self {
242 Self {
243 input: children.swap_remove(0),
244 ..Self::clone(self)
245 }
246 }
247}
248
249impl DisplayAs for CooperativeExec {
250 fn fmt_as(
251 &self,
252 _t: DisplayFormatType,
253 f: &mut std::fmt::Formatter<'_>,
254 ) -> std::fmt::Result {
255 write!(f, "CooperativeExec")
256 }
257}
258
259impl ExecutionPlan for CooperativeExec {
260 fn name(&self) -> &str {
261 "CooperativeExec"
262 }
263
264 fn as_any(&self) -> &dyn Any {
265 self
266 }
267
268 fn schema(&self) -> Arc<Schema> {
269 self.input.schema()
270 }
271
272 fn properties(&self) -> &Arc<PlanProperties> {
273 &self.properties
274 }
275
276 fn maintains_input_order(&self) -> Vec<bool> {
277 vec![true; self.children().len()]
278 }
279
280 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
281 vec![&self.input]
282 }
283
284 fn with_new_children(
285 self: Arc<Self>,
286 mut children: Vec<Arc<dyn ExecutionPlan>>,
287 ) -> Result<Arc<dyn ExecutionPlan>> {
288 assert_eq_or_internal_err!(
289 children.len(),
290 1,
291 "CooperativeExec requires exactly one child"
292 );
293 check_if_same_properties!(self, children);
294 Ok(Arc::new(CooperativeExec::new(children.swap_remove(0))))
295 }
296
297 fn execute(
298 &self,
299 partition: usize,
300 task_ctx: Arc<TaskContext>,
301 ) -> Result<SendableRecordBatchStream> {
302 let child_stream = self.input.execute(partition, task_ctx)?;
303 Ok(make_cooperative(child_stream))
304 }
305
306 fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
307 self.input.partition_statistics(partition)
308 }
309
310 fn supports_limit_pushdown(&self) -> bool {
311 true
312 }
313
314 fn cardinality_effect(&self) -> CardinalityEffect {
315 Equal
316 }
317
318 fn try_swapping_with_projection(
319 &self,
320 projection: &ProjectionExec,
321 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
322 match self.input.try_swapping_with_projection(projection)? {
323 Some(new_input) => Ok(Some(
324 Arc::new(self.clone()).with_new_children(vec![new_input])?,
325 )),
326 None => Ok(None),
327 }
328 }
329
330 fn gather_filters_for_pushdown(
331 &self,
332 _phase: FilterPushdownPhase,
333 parent_filters: Vec<Arc<dyn PhysicalExpr>>,
334 _config: &ConfigOptions,
335 ) -> Result<FilterDescription> {
336 FilterDescription::from_children(parent_filters, &self.children())
337 }
338
339 fn handle_child_pushdown_result(
340 &self,
341 _phase: FilterPushdownPhase,
342 child_pushdown_result: ChildPushdownResult,
343 _config: &ConfigOptions,
344 ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
345 Ok(FilterPushdownPropagation::if_all(child_pushdown_result))
346 }
347
348 fn try_pushdown_sort(
349 &self,
350 order: &[PhysicalSortExpr],
351 ) -> Result<SortOrderPushdownResult<Arc<dyn ExecutionPlan>>> {
352 let child = self.input();
353
354 match child.try_pushdown_sort(order)? {
355 SortOrderPushdownResult::Exact { inner } => {
356 let new_exec = Arc::new(self.clone()).with_new_children(vec![inner])?;
357 Ok(SortOrderPushdownResult::Exact { inner: new_exec })
358 }
359 SortOrderPushdownResult::Inexact { inner } => {
360 let new_exec = Arc::new(self.clone()).with_new_children(vec![inner])?;
361 Ok(SortOrderPushdownResult::Inexact { inner: new_exec })
362 }
363 SortOrderPushdownResult::Unsupported => {
364 Ok(SortOrderPushdownResult::Unsupported)
365 }
366 }
367 }
368}
369
370pub fn cooperative<T>(stream: T) -> CooperativeStream<T>
374where
375 T: RecordBatchStream + Unpin + Send + 'static,
376{
377 CooperativeStream::new(stream)
378}
379
380pub fn make_cooperative(stream: SendableRecordBatchStream) -> SendableRecordBatchStream {
386 Box::pin(cooperative(RecordBatchStreamAdapter::new(
388 stream.schema(),
389 stream,
390 )))
391}
392
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stream::RecordBatchStreamAdapter;

    use arrow_schema::SchemaRef;

    use futures::{StreamExt, stream};

    /// Batch count the tests treat as the yield threshold.
    /// NOTE(review): presumably mirrors Tokio's per-task budget — confirm.
    const TASK_BUDGET: usize = 128;

    /// Builds a stream that yields `n` empty batches over an empty schema.
    fn make_empty_batches(n: usize) -> SendableRecordBatchStream {
        let schema: SchemaRef = Arc::new(Schema::empty());
        let batch_schema = Arc::clone(&schema);

        let batches = (0..n)
            .map(move |_| Ok(RecordBatch::new_empty(Arc::clone(&batch_schema))));

        Box::pin(RecordBatchStreamAdapter::new(schema, stream::iter(batches)))
    }

    /// Drains a cooperative stream of `count` empty batches and returns how
    /// many items came out.
    async fn drained_len(count: usize) -> usize {
        let source = make_empty_batches(count);
        let collected = make_cooperative(source).collect::<Vec<_>>().await;
        collected.len()
    }

    /// Fewer batches than the threshold: all of them must be delivered.
    #[tokio::test]
    async fn yield_less_than_threshold() -> Result<()> {
        let expected = TASK_BUDGET - 10;
        assert_eq!(drained_len(expected).await, expected);
        Ok(())
    }

    /// Exactly the threshold: all batches must be delivered.
    #[tokio::test]
    async fn yield_equal_to_threshold() -> Result<()> {
        let expected = TASK_BUDGET;
        assert_eq!(drained_len(expected).await, expected);
        Ok(())
    }

    /// More batches than the threshold: all of them must still be delivered.
    #[tokio::test]
    async fn yield_more_than_threshold() -> Result<()> {
        let expected = TASK_BUDGET + 20;
        assert_eq!(drained_len(expected).await, expected);
        Ok(())
    }
}