datafusion_physical_plan/
coop.rs1use datafusion_common::config::ConfigOptions;
69use datafusion_physical_expr::PhysicalExpr;
70#[cfg(datafusion_coop = "tokio_fallback")]
71use futures::Future;
72use std::any::Any;
73use std::pin::Pin;
74use std::sync::Arc;
75use std::task::{Context, Poll};
76
77use crate::execution_plan::CardinalityEffect::{self, Equal};
78use crate::filter_pushdown::{
79 ChildPushdownResult, FilterDescription, FilterPushdownPhase,
80 FilterPushdownPropagation,
81};
82use crate::projection::ProjectionExec;
83use crate::{
84 DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream,
85 SendableRecordBatchStream, SortOrderPushdownResult,
86};
87use arrow::record_batch::RecordBatch;
88use arrow_schema::Schema;
89use datafusion_common::{Result, Statistics, assert_eq_or_internal_err};
90use datafusion_execution::TaskContext;
91
92use crate::execution_plan::SchedulingType;
93use crate::stream::RecordBatchStreamAdapter;
94use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
95use futures::{Stream, StreamExt};
96
/// Wraps a [`RecordBatchStream`] so that it participates in cooperative
/// scheduling: polling periodically yields control back to the runtime
/// instead of monopolizing the task.
///
/// The yielding strategy is chosen at compile time via the `datafusion_coop`
/// cfg flag; only the `per_stream` strategy needs per-instance state.
pub struct CooperativeStream<T>
where
    T: RecordBatchStream + Unpin,
{
    // The wrapped stream; batches are passed through unchanged.
    inner: T,
    // Remaining ready polls before this stream voluntarily yields.
    // Only compiled in for the `per_stream` strategy.
    #[cfg(datafusion_coop = "per_stream")]
    budget: u8,
}
110
// Number of ready polls a `CooperativeStream` performs before voluntarily
// yielding when the `per_stream` strategy is compiled in.
#[cfg(datafusion_coop = "per_stream")]
const YIELD_FREQUENCY: u8 = 128;
114
impl<T> CooperativeStream<T>
where
    T: RecordBatchStream + Unpin,
{
    /// Creates a new `CooperativeStream` that wraps `inner`.
    ///
    /// Under the `per_stream` strategy the yield budget starts out full at
    /// `YIELD_FREQUENCY`.
    pub fn new(inner: T) -> Self {
        Self {
            inner,
            #[cfg(datafusion_coop = "per_stream")]
            budget: YIELD_FREQUENCY,
        }
    }
}
130
/// `Stream` implementation that interleaves cooperative yielding with polls
/// of the inner stream. Exactly one of the three `cfg` blocks below is
/// compiled in, selected by the `datafusion_coop` configuration flag.
impl<T> Stream for CooperativeStream<T>
where
    T: RecordBatchStream + Unpin,
{
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // Default strategy (also selected by `datafusion_coop = "tokio"`):
        // rely on Tokio's per-task budget. `poll_proceed` yields `Pending`
        // when the budget is exhausted, and `ready!` propagates that, so the
        // task is forced back to the scheduler before polling the inner
        // stream again.
        #[cfg(any(
            datafusion_coop = "tokio",
            not(any(
                datafusion_coop = "tokio_fallback",
                datafusion_coop = "per_stream"
            ))
        ))]
        {
            let coop = std::task::ready!(tokio::task::coop::poll_proceed(cx));
            let value = self.inner.poll_next_unpin(cx);
            if value.is_ready() {
                // Only charge the budget when the inner poll actually made
                // progress; a `Pending` poll consumed nothing.
                coop.made_progress();
            }
            value
        }

        // Fallback strategy: check the remaining Tokio budget up front; if it
        // is spent, self-wake and return `Pending` so the task yields but is
        // immediately rescheduled.
        #[cfg(datafusion_coop = "tokio_fallback")]
        {
            if !tokio::task::coop::has_budget_remaining() {
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }

            let value = self.inner.poll_next_unpin(cx);
            if value.is_ready() {
                // `consume_budget` is an async fn; poll it once here just to
                // register one unit of budget consumption. Its result is
                // deliberately ignored — exhaustion is detected by the
                // `has_budget_remaining` check on the next call.
                let consume = tokio::task::coop::consume_budget();
                let consume_ref = std::pin::pin!(consume);
                let _ = consume_ref.poll(cx);
            }
            value
        }

        // Per-stream strategy: maintain a local countdown independent of the
        // Tokio runtime.
        #[cfg(datafusion_coop = "per_stream")]
        {
            if self.budget == 0 {
                // Budget spent: refill it, request an immediate re-poll, and
                // yield control to the executor.
                self.budget = YIELD_FREQUENCY;
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }

            let value = { self.inner.poll_next_unpin(cx) };

            if value.is_ready() {
                self.budget -= 1;
            } else {
                // The inner stream yielded on its own; restore the full
                // budget rather than penalizing the next run of polls.
                self.budget = YIELD_FREQUENCY;
            }
            value
        }
    }
}
199
200impl<T> RecordBatchStream for CooperativeStream<T>
201where
202 T: RecordBatchStream + Unpin,
203{
204 fn schema(&self) -> Arc<Schema> {
205 self.inner.schema()
206 }
207}
208
/// Execution plan node that wraps its child so every output stream yields
/// cooperatively (see [`make_cooperative`]).
#[derive(Debug, Clone)]
pub struct CooperativeExec {
    // The wrapped child plan.
    input: Arc<dyn ExecutionPlan>,
    // The child's plan properties with the scheduling type switched to
    // `Cooperative`; cached at construction time.
    properties: PlanProperties,
}
217
218impl CooperativeExec {
219 pub fn new(input: Arc<dyn ExecutionPlan>) -> Self {
221 let properties = input
222 .properties()
223 .clone()
224 .with_scheduling_type(SchedulingType::Cooperative);
225
226 Self { input, properties }
227 }
228
229 pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
231 &self.input
232 }
233}
234
235impl DisplayAs for CooperativeExec {
236 fn fmt_as(
237 &self,
238 _t: DisplayFormatType,
239 f: &mut std::fmt::Formatter<'_>,
240 ) -> std::fmt::Result {
241 write!(f, "CooperativeExec")
242 }
243}
244
245impl ExecutionPlan for CooperativeExec {
246 fn name(&self) -> &str {
247 "CooperativeExec"
248 }
249
250 fn as_any(&self) -> &dyn Any {
251 self
252 }
253
254 fn schema(&self) -> Arc<Schema> {
255 self.input.schema()
256 }
257
258 fn properties(&self) -> &PlanProperties {
259 &self.properties
260 }
261
262 fn maintains_input_order(&self) -> Vec<bool> {
263 vec![true; self.children().len()]
264 }
265
266 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
267 vec![&self.input]
268 }
269
270 fn with_new_children(
271 self: Arc<Self>,
272 mut children: Vec<Arc<dyn ExecutionPlan>>,
273 ) -> Result<Arc<dyn ExecutionPlan>> {
274 assert_eq_or_internal_err!(
275 children.len(),
276 1,
277 "CooperativeExec requires exactly one child"
278 );
279 Ok(Arc::new(CooperativeExec::new(children.swap_remove(0))))
280 }
281
282 fn execute(
283 &self,
284 partition: usize,
285 task_ctx: Arc<TaskContext>,
286 ) -> Result<SendableRecordBatchStream> {
287 let child_stream = self.input.execute(partition, task_ctx)?;
288 Ok(make_cooperative(child_stream))
289 }
290
291 fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
292 self.input.partition_statistics(partition)
293 }
294
295 fn supports_limit_pushdown(&self) -> bool {
296 true
297 }
298
299 fn cardinality_effect(&self) -> CardinalityEffect {
300 Equal
301 }
302
303 fn try_swapping_with_projection(
304 &self,
305 projection: &ProjectionExec,
306 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
307 match self.input.try_swapping_with_projection(projection)? {
308 Some(new_input) => Ok(Some(
309 Arc::new(self.clone()).with_new_children(vec![new_input])?,
310 )),
311 None => Ok(None),
312 }
313 }
314
315 fn gather_filters_for_pushdown(
316 &self,
317 _phase: FilterPushdownPhase,
318 parent_filters: Vec<Arc<dyn PhysicalExpr>>,
319 _config: &ConfigOptions,
320 ) -> Result<FilterDescription> {
321 FilterDescription::from_children(parent_filters, &self.children())
322 }
323
324 fn handle_child_pushdown_result(
325 &self,
326 _phase: FilterPushdownPhase,
327 child_pushdown_result: ChildPushdownResult,
328 _config: &ConfigOptions,
329 ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
330 Ok(FilterPushdownPropagation::if_all(child_pushdown_result))
331 }
332
333 fn try_pushdown_sort(
334 &self,
335 order: &[PhysicalSortExpr],
336 ) -> Result<SortOrderPushdownResult<Arc<dyn ExecutionPlan>>> {
337 let child = self.input();
338
339 match child.try_pushdown_sort(order)? {
340 SortOrderPushdownResult::Exact { inner } => {
341 let new_exec = Arc::new(self.clone()).with_new_children(vec![inner])?;
342 Ok(SortOrderPushdownResult::Exact { inner: new_exec })
343 }
344 SortOrderPushdownResult::Inexact { inner } => {
345 let new_exec = Arc::new(self.clone()).with_new_children(vec![inner])?;
346 Ok(SortOrderPushdownResult::Inexact { inner: new_exec })
347 }
348 SortOrderPushdownResult::Unsupported => {
349 Ok(SortOrderPushdownResult::Unsupported)
350 }
351 }
352 }
353}
354
/// Wraps `stream` in a [`CooperativeStream`] so that it periodically yields
/// control back to the runtime while being polled.
pub fn cooperative<T>(stream: T) -> CooperativeStream<T>
where
    T: RecordBatchStream + Unpin + Send + 'static,
{
    CooperativeStream::new(stream)
}
364
365pub fn make_cooperative(stream: SendableRecordBatchStream) -> SendableRecordBatchStream {
371 Box::pin(cooperative(RecordBatchStreamAdapter::new(
373 stream.schema(),
374 stream,
375 )))
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stream::RecordBatchStreamAdapter;

    use arrow_schema::SchemaRef;

    use futures::{StreamExt, stream};

    /// Tokio's default per-task budget; the tests probe behavior below, at,
    /// and above this yield threshold.
    const TASK_BUDGET: usize = 128;

    /// Builds a stream that produces `n` empty record batches.
    fn make_empty_batches(n: usize) -> SendableRecordBatchStream {
        let schema: SchemaRef = Arc::new(Schema::empty());
        let schema_for_stream = Arc::clone(&schema);

        let s =
            stream::iter((0..n).map(move |_| {
                Ok(RecordBatch::new_empty(Arc::clone(&schema_for_stream)))
            }));

        Box::pin(RecordBatchStreamAdapter::new(schema, s))
    }

    /// Drives a cooperative stream of `count` batches to completion and
    /// asserts that every batch is delivered despite any forced yields.
    /// Shared by the three threshold tests below (previously duplicated).
    async fn assert_all_batches_delivered(count: usize) -> Result<()> {
        let inner = make_empty_batches(count);
        let out = make_cooperative(inner).collect::<Vec<_>>().await;
        assert_eq!(out.len(), count);
        Ok(())
    }

    #[tokio::test]
    async fn yield_less_than_threshold() -> Result<()> {
        assert_all_batches_delivered(TASK_BUDGET - 10).await
    }

    #[tokio::test]
    async fn yield_equal_to_threshold() -> Result<()> {
        assert_all_batches_delivered(TASK_BUDGET).await
    }

    #[tokio::test]
    async fn yield_more_than_threshold() -> Result<()> {
        assert_all_batches_delivered(TASK_BUDGET + 20).await
    }
}