use datafusion_common::config::ConfigOptions;
use datafusion_physical_expr::PhysicalExpr;
#[cfg(datafusion_coop = "tokio_fallback")]
use futures::Future;
use std::any::Any;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use crate::execution_plan::CardinalityEffect::{self, Equal};
use crate::filter_pushdown::{
ChildPushdownResult, FilterDescription, FilterPushdownPhase,
FilterPushdownPropagation,
};
use crate::projection::ProjectionExec;
use crate::{
DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream,
SendableRecordBatchStream, SortOrderPushdownResult, check_if_same_properties,
};
use arrow::record_batch::RecordBatch;
use arrow_schema::Schema;
use datafusion_common::{Result, Statistics, assert_eq_or_internal_err};
use datafusion_execution::TaskContext;
use crate::execution_plan::SchedulingType;
use crate::stream::RecordBatchStreamAdapter;
use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
use futures::{Stream, StreamExt};
/// A stream wrapper that makes the inner [`RecordBatchStream`] participate in
/// cooperative scheduling by periodically yielding control back to the runtime.
///
/// The yielding strategy is chosen at compile time through the
/// `datafusion_coop` cfg flag; see the [`Stream`] implementation for the
/// per-strategy behavior.
pub struct CooperativeStream<T>
where
    T: RecordBatchStream + Unpin,
{
    /// The wrapped stream that produces the actual record batches.
    inner: T,
    /// Remaining number of `Ready` polls before this stream voluntarily
    /// yields; only present for the `per_stream` strategy.
    #[cfg(datafusion_coop = "per_stream")]
    budget: u8,
}
/// Number of `Ready` polls a `per_stream` cooperative stream performs before
/// it voluntarily returns `Pending` to give other tasks a chance to run.
#[cfg(datafusion_coop = "per_stream")]
const YIELD_FREQUENCY: u8 = 128;
impl<T> CooperativeStream<T>
where
    T: RecordBatchStream + Unpin,
{
    /// Creates a new `CooperativeStream` that wraps the provided stream.
    ///
    /// For the `per_stream` strategy the yield budget starts out at
    /// `YIELD_FREQUENCY`; the other strategies keep no per-stream state.
    pub fn new(inner: T) -> Self {
        Self {
            inner,
            #[cfg(datafusion_coop = "per_stream")]
            budget: YIELD_FREQUENCY,
        }
    }
}
impl<T> Stream for CooperativeStream<T>
where
    T: RecordBatchStream + Unpin,
{
    type Item = Result<RecordBatch>;

    /// Polls the inner stream while respecting the configured cooperative
    /// yielding strategy. Exactly one of the three cfg-gated bodies below is
    /// compiled in.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // Strategy "tokio" (also the default when no strategy is configured):
        // defer entirely to Tokio's task budget. `poll_proceed` returns
        // `Pending` (and arranges a wake-up) when the budget is exhausted.
        #[cfg(any(
            datafusion_coop = "tokio",
            not(any(
                datafusion_coop = "tokio_fallback",
                datafusion_coop = "per_stream"
            ))
        ))]
        {
            let coop = std::task::ready!(tokio::task::coop::poll_proceed(cx));
            let value = self.inner.poll_next_unpin(cx);
            if value.is_ready() {
                // Only charge the budget when the inner stream actually made
                // progress; a `Pending` poll should not consume budget.
                coop.made_progress();
            }
            value
        }

        // Strategy "tokio_fallback": check the budget up front and consume it
        // after the fact. Used where `poll_proceed` is not available.
        #[cfg(datafusion_coop = "tokio_fallback")]
        {
            if !tokio::task::coop::has_budget_remaining() {
                // Re-register the waker so the task is rescheduled once the
                // runtime has reset the budget.
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            let value = self.inner.poll_next_unpin(cx);
            if value.is_ready() {
                // `consume_budget` is async; poll it once here just to charge
                // one unit of budget. The result is intentionally ignored.
                let consume = tokio::task::coop::consume_budget();
                let consume_ref = std::pin::pin!(consume);
                let _ = consume_ref.poll(cx);
            }
            value
        }

        // Strategy "per_stream": runtime-independent counting. Each stream
        // carries its own budget and yields after `YIELD_FREQUENCY` ready
        // polls in a row.
        #[cfg(datafusion_coop = "per_stream")]
        {
            if self.budget == 0 {
                // Budget exhausted: refill, schedule an immediate wake-up,
                // and yield to the executor.
                self.budget = YIELD_FREQUENCY;
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            let value = { self.inner.poll_next_unpin(cx) };
            if value.is_ready() {
                self.budget -= 1;
            } else {
                // The inner stream yielded on its own; reset the budget since
                // the executor already regained control.
                self.budget = YIELD_FREQUENCY;
            }
            value
        }
    }
}
impl<T> RecordBatchStream for CooperativeStream<T>
where
    T: RecordBatchStream + Unpin,
{
    /// Delegates to the wrapped stream; cooperative yielding does not change
    /// the schema of the produced batches.
    fn schema(&self) -> Arc<Schema> {
        self.inner.schema()
    }
}
/// An [`ExecutionPlan`] wrapper that marks its subtree as cooperatively
/// scheduled: streams produced by [`CooperativeExec::execute`] are wrapped in
/// [`CooperativeStream`] so they periodically yield to the runtime.
#[derive(Debug, Clone)]
pub struct CooperativeExec {
    /// The child plan whose output streams will be made cooperative.
    input: Arc<dyn ExecutionPlan>,
    /// Cached plan properties; identical to the child's except that the
    /// scheduling type is set to `Cooperative`.
    properties: Arc<PlanProperties>,
}
impl CooperativeExec {
    /// Wraps `input` in a new `CooperativeExec`, cloning the child's plan
    /// properties and switching their scheduling type to `Cooperative`.
    pub fn new(input: Arc<dyn ExecutionPlan>) -> Self {
        let properties = Arc::new(
            PlanProperties::clone(input.properties())
                .with_scheduling_type(SchedulingType::Cooperative),
        );
        Self { input, properties }
    }

    /// Returns the wrapped child plan.
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

    /// Replaces the child plan while reusing this node's already-computed
    /// plan properties (the caller guarantees they still apply).
    fn with_new_children_and_same_properties(
        &self,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Self {
        Self {
            input: children.swap_remove(0),
            properties: Arc::clone(&self.properties),
        }
    }
}
impl DisplayAs for CooperativeExec {
    /// Every display format renders the same short node label, since this
    /// node carries no configuration worth printing.
    fn fmt_as(
        &self,
        _t: DisplayFormatType,
        f: &mut std::fmt::Formatter<'_>,
    ) -> std::fmt::Result {
        f.write_str("CooperativeExec")
    }
}
impl ExecutionPlan for CooperativeExec {
fn name(&self) -> &str {
"CooperativeExec"
}
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> Arc<Schema> {
self.input.schema()
}
fn properties(&self) -> &Arc<PlanProperties> {
&self.properties
}
fn maintains_input_order(&self) -> Vec<bool> {
vec![true; self.children().len()]
}
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![&self.input]
}
fn with_new_children(
self: Arc<Self>,
mut children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
assert_eq_or_internal_err!(
children.len(),
1,
"CooperativeExec requires exactly one child"
);
check_if_same_properties!(self, children);
Ok(Arc::new(CooperativeExec::new(children.swap_remove(0))))
}
fn execute(
&self,
partition: usize,
task_ctx: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
let child_stream = self.input.execute(partition, task_ctx)?;
Ok(make_cooperative(child_stream))
}
fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
self.input.partition_statistics(partition)
}
fn supports_limit_pushdown(&self) -> bool {
true
}
fn cardinality_effect(&self) -> CardinalityEffect {
Equal
}
fn try_swapping_with_projection(
&self,
projection: &ProjectionExec,
) -> Result<Option<Arc<dyn ExecutionPlan>>> {
match self.input.try_swapping_with_projection(projection)? {
Some(new_input) => Ok(Some(
Arc::new(self.clone()).with_new_children(vec![new_input])?,
)),
None => Ok(None),
}
}
fn gather_filters_for_pushdown(
&self,
_phase: FilterPushdownPhase,
parent_filters: Vec<Arc<dyn PhysicalExpr>>,
_config: &ConfigOptions,
) -> Result<FilterDescription> {
FilterDescription::from_children(parent_filters, &self.children())
}
fn handle_child_pushdown_result(
&self,
_phase: FilterPushdownPhase,
child_pushdown_result: ChildPushdownResult,
_config: &ConfigOptions,
) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
Ok(FilterPushdownPropagation::if_all(child_pushdown_result))
}
fn try_pushdown_sort(
&self,
order: &[PhysicalSortExpr],
) -> Result<SortOrderPushdownResult<Arc<dyn ExecutionPlan>>> {
let child = self.input();
match child.try_pushdown_sort(order)? {
SortOrderPushdownResult::Exact { inner } => {
let new_exec = Arc::new(self.clone()).with_new_children(vec![inner])?;
Ok(SortOrderPushdownResult::Exact { inner: new_exec })
}
SortOrderPushdownResult::Inexact { inner } => {
let new_exec = Arc::new(self.clone()).with_new_children(vec![inner])?;
Ok(SortOrderPushdownResult::Inexact { inner: new_exec })
}
SortOrderPushdownResult::Unsupported => {
Ok(SortOrderPushdownResult::Unsupported)
}
}
}
}
/// Wraps `stream` in a [`CooperativeStream`] so it periodically yields back
/// to the runtime, while preserving the concrete stream type.
///
/// Use [`make_cooperative`] instead when starting from a boxed
/// [`SendableRecordBatchStream`].
pub fn cooperative<T>(stream: T) -> CooperativeStream<T>
where
    T: RecordBatchStream + Unpin + Send + 'static,
{
    CooperativeStream::new(stream)
}
/// Converts a boxed [`SendableRecordBatchStream`] into a boxed cooperative
/// stream that yields control to the runtime periodically.
pub fn make_cooperative(stream: SendableRecordBatchStream) -> SendableRecordBatchStream {
    // The adapter provides the `Unpin` bound `CooperativeStream` requires;
    // the schema is captured before the stream is moved into it.
    let schema = stream.schema();
    let adapter = RecordBatchStreamAdapter::new(schema, stream);
    Box::pin(cooperative(adapter))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stream::RecordBatchStreamAdapter;
    use arrow_schema::SchemaRef;
    use futures::{StreamExt, stream};

    /// Mirrors Tokio's default task budget of 128 polls.
    const TASK_BUDGET: usize = 128;

    /// Builds a stream producing `n` empty record batches that all share a
    /// single empty schema.
    fn make_empty_batches(n: usize) -> SendableRecordBatchStream {
        let schema: SchemaRef = Arc::new(Schema::empty());
        let batches = {
            let schema = Arc::clone(&schema);
            stream::iter(
                (0..n).map(move |_| Ok(RecordBatch::new_empty(Arc::clone(&schema)))),
            )
        };
        Box::pin(RecordBatchStreamAdapter::new(schema, batches))
    }

    /// Drives a cooperative stream of `count` batches to completion and
    /// checks that no batches were lost or duplicated across yield points.
    async fn assert_all_batches_delivered(count: usize) -> Result<()> {
        let collected = make_cooperative(make_empty_batches(count))
            .collect::<Vec<_>>()
            .await;
        assert_eq!(collected.len(), count);
        Ok(())
    }

    #[tokio::test]
    async fn yield_less_than_threshold() -> Result<()> {
        assert_all_batches_delivered(TASK_BUDGET - 10).await
    }

    #[tokio::test]
    async fn yield_equal_to_threshold() -> Result<()> {
        assert_all_batches_delivered(TASK_BUDGET).await
    }

    #[tokio::test]
    async fn yield_more_than_threshold() -> Result<()> {
        assert_all_batches_delivered(TASK_BUDGET + 20).await
    }
}