datafusion_physical_plan/
coop.rs1use datafusion_common::config::ConfigOptions;
74use datafusion_physical_expr::PhysicalExpr;
75#[cfg(datafusion_coop = "tokio_fallback")]
76use futures::Future;
77use std::pin::Pin;
78use std::sync::Arc;
79use std::task::{Context, Poll};
80
81use crate::execution_plan::CardinalityEffect::{self, Equal};
82use crate::filter_pushdown::{
83 ChildPushdownResult, FilterDescription, FilterPushdownPhase,
84 FilterPushdownPropagation,
85};
86use crate::projection::ProjectionExec;
87use crate::{
88 DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream,
89 SendableRecordBatchStream, SortOrderPushdownResult, check_if_same_properties,
90};
91use arrow::record_batch::RecordBatch;
92use arrow_schema::Schema;
93use datafusion_common::{Result, Statistics, assert_eq_or_internal_err};
94use datafusion_execution::TaskContext;
95
96use crate::execution_plan::SchedulingType;
97use crate::stream::RecordBatchStreamAdapter;
98use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
99use futures::{Stream, StreamExt};
100
101pub struct CooperativeStream<T>
107where
108 T: RecordBatchStream + Unpin,
109{
110 inner: T,
111 #[cfg(datafusion_coop = "per_stream")]
112 budget: u8,
113}
114
/// Initial (and refill) value of the per-stream budget: after this many
/// consecutive ready polls a `CooperativeStream` returns `Poll::Pending` once
/// before continuing. Only used with the `per_stream` strategy.
#[cfg(datafusion_coop = "per_stream")]
const YIELD_FREQUENCY: u8 = 128;
118
119impl<T> CooperativeStream<T>
120where
121 T: RecordBatchStream + Unpin,
122{
123 pub fn new(inner: T) -> Self {
127 Self {
128 inner,
129 #[cfg(datafusion_coop = "per_stream")]
130 budget: YIELD_FREQUENCY,
131 }
132 }
133}
134
135impl<T> Stream for CooperativeStream<T>
136where
137 T: RecordBatchStream + Unpin,
138{
139 type Item = Result<RecordBatch>;
140
141 fn poll_next(
142 mut self: Pin<&mut Self>,
143 cx: &mut Context<'_>,
144 ) -> Poll<Option<Self::Item>> {
145 #[cfg(any(
146 datafusion_coop = "tokio",
147 not(any(
148 datafusion_coop = "tokio_fallback",
149 datafusion_coop = "per_stream"
150 ))
151 ))]
152 {
153 let coop = std::task::ready!(tokio::task::coop::poll_proceed(cx));
154 let value = self.inner.poll_next_unpin(cx);
155 if value.is_ready() {
156 coop.made_progress();
157 }
158 value
159 }
160
161 #[cfg(datafusion_coop = "tokio_fallback")]
162 {
163 if !tokio::task::coop::has_budget_remaining() {
166 cx.waker().wake_by_ref();
167 return Poll::Pending;
168 }
169
170 let value = self.inner.poll_next_unpin(cx);
171 if value.is_ready() {
172 let consume = tokio::task::coop::consume_budget();
178 let consume_ref = std::pin::pin!(consume);
179 let _ = consume_ref.poll(cx);
180 }
181 value
182 }
183
184 #[cfg(datafusion_coop = "per_stream")]
185 {
186 if self.budget == 0 {
187 self.budget = YIELD_FREQUENCY;
188 cx.waker().wake_by_ref();
189 return Poll::Pending;
190 }
191
192 let value = { self.inner.poll_next_unpin(cx) };
193
194 if value.is_ready() {
195 self.budget -= 1;
196 } else {
197 self.budget = YIELD_FREQUENCY;
198 }
199 value
200 }
201 }
202}
203
204impl<T> RecordBatchStream for CooperativeStream<T>
205where
206 T: RecordBatchStream + Unpin,
207{
208 fn schema(&self) -> Arc<Schema> {
209 self.inner.schema()
210 }
211}
212
213#[derive(Debug, Clone)]
217pub struct CooperativeExec {
218 input: Arc<dyn ExecutionPlan>,
219 properties: Arc<PlanProperties>,
220}
221
222impl CooperativeExec {
223 pub fn new(input: Arc<dyn ExecutionPlan>) -> Self {
225 let properties = PlanProperties::clone(input.properties())
226 .with_scheduling_type(SchedulingType::Cooperative)
227 .into();
228
229 Self { input, properties }
230 }
231
232 pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
234 &self.input
235 }
236
237 fn with_new_children_and_same_properties(
238 &self,
239 mut children: Vec<Arc<dyn ExecutionPlan>>,
240 ) -> Self {
241 Self {
242 input: children.swap_remove(0),
243 ..Self::clone(self)
244 }
245 }
246}
247
248impl DisplayAs for CooperativeExec {
249 fn fmt_as(
250 &self,
251 _t: DisplayFormatType,
252 f: &mut std::fmt::Formatter<'_>,
253 ) -> std::fmt::Result {
254 write!(f, "CooperativeExec")
255 }
256}
257
258impl ExecutionPlan for CooperativeExec {
259 fn name(&self) -> &str {
260 "CooperativeExec"
261 }
262
263 fn schema(&self) -> Arc<Schema> {
264 self.input.schema()
265 }
266
267 fn properties(&self) -> &Arc<PlanProperties> {
268 &self.properties
269 }
270
271 fn maintains_input_order(&self) -> Vec<bool> {
272 vec![true; self.children().len()]
273 }
274
275 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
276 vec![&self.input]
277 }
278
279 fn with_new_children(
280 self: Arc<Self>,
281 mut children: Vec<Arc<dyn ExecutionPlan>>,
282 ) -> Result<Arc<dyn ExecutionPlan>> {
283 assert_eq_or_internal_err!(
284 children.len(),
285 1,
286 "CooperativeExec requires exactly one child"
287 );
288 check_if_same_properties!(self, children);
289 Ok(Arc::new(CooperativeExec::new(children.swap_remove(0))))
290 }
291
292 fn execute(
293 &self,
294 partition: usize,
295 task_ctx: Arc<TaskContext>,
296 ) -> Result<SendableRecordBatchStream> {
297 let child_stream = self.input.execute(partition, task_ctx)?;
298 Ok(make_cooperative(child_stream))
299 }
300
301 fn partition_statistics(&self, partition: Option<usize>) -> Result<Arc<Statistics>> {
302 self.input.partition_statistics(partition)
303 }
304
305 fn supports_limit_pushdown(&self) -> bool {
306 true
307 }
308
309 fn cardinality_effect(&self) -> CardinalityEffect {
310 Equal
311 }
312
313 fn try_swapping_with_projection(
314 &self,
315 projection: &ProjectionExec,
316 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
317 match self.input.try_swapping_with_projection(projection)? {
318 Some(new_input) => Ok(Some(
319 Arc::new(self.clone()).with_new_children(vec![new_input])?,
320 )),
321 None => Ok(None),
322 }
323 }
324
325 fn gather_filters_for_pushdown(
326 &self,
327 _phase: FilterPushdownPhase,
328 parent_filters: Vec<Arc<dyn PhysicalExpr>>,
329 _config: &ConfigOptions,
330 ) -> Result<FilterDescription> {
331 FilterDescription::from_children(parent_filters, &self.children())
332 }
333
334 fn handle_child_pushdown_result(
335 &self,
336 _phase: FilterPushdownPhase,
337 child_pushdown_result: ChildPushdownResult,
338 _config: &ConfigOptions,
339 ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
340 Ok(FilterPushdownPropagation::if_all(child_pushdown_result))
341 }
342
343 fn try_pushdown_sort(
344 &self,
345 order: &[PhysicalSortExpr],
346 ) -> Result<SortOrderPushdownResult<Arc<dyn ExecutionPlan>>> {
347 let child = self.input();
348
349 match child.try_pushdown_sort(order)? {
350 SortOrderPushdownResult::Exact { inner } => {
351 let new_exec = Arc::new(self.clone()).with_new_children(vec![inner])?;
352 Ok(SortOrderPushdownResult::Exact { inner: new_exec })
353 }
354 SortOrderPushdownResult::Inexact { inner } => {
355 let new_exec = Arc::new(self.clone()).with_new_children(vec![inner])?;
356 Ok(SortOrderPushdownResult::Inexact { inner: new_exec })
357 }
358 SortOrderPushdownResult::Unsupported => {
359 Ok(SortOrderPushdownResult::Unsupported)
360 }
361 }
362 }
363}
364
365pub fn cooperative<T>(stream: T) -> CooperativeStream<T>
369where
370 T: RecordBatchStream + Unpin + Send + 'static,
371{
372 CooperativeStream::new(stream)
373}
374
375pub fn make_cooperative(stream: SendableRecordBatchStream) -> SendableRecordBatchStream {
381 Box::pin(cooperative(RecordBatchStreamAdapter::new(
383 stream.schema(),
384 stream,
385 )))
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    use arrow_schema::SchemaRef;

    use futures::stream;

    /// Reference batch count the tests are sized around (below, at and above).
    const TASK_BUDGET: usize = 128;

    /// Builds a stream that yields `n` empty record batches with an empty
    /// schema.
    fn make_empty_batches(n: usize) -> SendableRecordBatchStream {
        let schema: SchemaRef = Arc::new(Schema::empty());
        let batch_schema = Arc::clone(&schema);

        let batches = stream::iter((0..n).map(move |_| -> Result<RecordBatch> {
            Ok(RecordBatch::new_empty(Arc::clone(&batch_schema)))
        }));

        Box::pin(RecordBatchStreamAdapter::new(schema, batches))
    }

    /// Drains a cooperative stream over `count` empty batches and returns how
    /// many items came out.
    async fn drain_cooperative(count: usize) -> usize {
        let source = make_empty_batches(count);
        let collected = make_cooperative(source).collect::<Vec<_>>().await;
        collected.len()
    }

    #[tokio::test]
    async fn yield_less_than_threshold() -> Result<()> {
        let count = TASK_BUDGET - 10;
        assert_eq!(drain_cooperative(count).await, count);
        Ok(())
    }

    #[tokio::test]
    async fn yield_equal_to_threshold() -> Result<()> {
        let count = TASK_BUDGET;
        assert_eq!(drain_cooperative(count).await, count);
        Ok(())
    }

    #[tokio::test]
    async fn yield_more_than_threshold() -> Result<()> {
        let count = TASK_BUDGET + 20;
        assert_eq!(drain_cooperative(count).await, count);
        Ok(())
    }
}