#[cfg(all(feature = "integration", test))]
mod tests {
use datafusion::arrow::util::pretty::pretty_format_batches;
use datafusion::common::tree_node::{Transformed, TreeNode};
use datafusion::error::DataFusionError;
use datafusion::execution::{SendableRecordBatchStream, SessionState, TaskContext};
use datafusion::physical_expr::EquivalenceProperties;
use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::{
DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
execute_stream,
};
use datafusion_distributed::test_utils::localhost::start_localhost_context;
use datafusion_distributed::test_utils::parquet::register_parquet_tables;
use datafusion_distributed::{
DistributedExt, WorkerQueryContext, assert_snapshot, display_plan_ascii,
};
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
use datafusion_proto::protobuf::proto_error;
use futures::TryStreamExt;
use prost::Message;
use std::fmt::Formatter;
use std::sync::{Arc, RwLock};
use tokio::task::JoinHandle;
use tokio_stream::StreamExt;
use tokio_stream::wrappers::ReceiverStream;
#[tokio::test]
async fn stateful_execution_plan() -> Result<(), Box<dyn std::error::Error>> {
async fn build_state(ctx: WorkerQueryContext) -> Result<SessionState, DataFusionError> {
Ok(ctx
.builder
.with_distributed_user_codec(PassThroughExecCodec)
.build())
}
let (mut ctx_distributed, _guard, _) = start_localhost_context(3, build_state).await;
ctx_distributed.set_distributed_user_codec(PassThroughExecCodec);
register_parquet_tables(&ctx_distributed).await?;
let query = r#"SELECT "MinTemp", "RainToday" FROM weather WHERE "MinTemp" > 20.0 ORDER BY "MinTemp" DESC"#;
let df_distributed = ctx_distributed.sql(query).await?;
let plan = df_distributed.create_physical_plan().await?;
let transformed = plan.transform_up(|plan| {
if plan.children().is_empty() {
return Ok(Transformed::yes(Arc::new(StatefulPassThroughExec::new(
plan,
))));
}
Ok(Transformed::no(plan))
})?;
let plan = transformed.data;
let plan_str = display_plan_ascii(plan.as_ref(), false);
assert_snapshot!(plan_str,
@r"
┌───── DistributedExec ── Tasks: t0:[p0]
│ SortPreservingMergeExec: [MinTemp@0 DESC]
│ [Stage 1] => NetworkCoalesceExec: output_partitions=9, input_tasks=3
└──────────────────────────────────────────────────
┌───── Stage 1 ── Tasks: t0:[p0..p2] t1:[p3..p5] t2:[p6..p8]
│ SortExec: expr=[MinTemp@0 DESC], preserve_partitioning=[true]
│ FilterExec: MinTemp@0 > 20
│ StatefulPassThroughExec
│ DistributedLeafExec:
│ t0: DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet:<int>..<int>], [/testdata/weather/result-000000.parquet:<int>..<int>, /testdata/weather/result-000001.parquet:<int>..<int>], [/testdata/weather/result-000002.parquet:<int>..<int>]]}, projection=[MinTemp, RainToday], file_type=parquet, predicate=MinTemp@0 > 20, pruning_predicate=MinTemp_null_count@1 != row_count@2 AND MinTemp_max@0 > 20, required_guarantees=[]
│ t1: DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet:<int>..<int>], [/testdata/weather/result-000001.parquet:<int>..<int>], [/testdata/weather/result-000002.parquet:<int>..<int>]]}, projection=[MinTemp, RainToday], file_type=parquet, predicate=MinTemp@0 > 20, pruning_predicate=MinTemp_null_count@1 != row_count@2 AND MinTemp_max@0 > 20, required_guarantees=[]
│ t2: DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet:<int>..<int>], [/testdata/weather/result-000001.parquet:<int>..<int>, /testdata/weather/result-000002.parquet:<int>..<int>], [/testdata/weather/result-000002.parquet:<int>..<int>]]}, projection=[MinTemp, RainToday], file_type=parquet, predicate=MinTemp@0 > 20, pruning_predicate=MinTemp_null_count@1 != row_count@2 AND MinTemp_max@0 > 20, required_guarantees=[]
└──────────────────────────────────────────────────
",
);
let batches_distributed = pretty_format_batches(
&execute_stream(plan, ctx_distributed.task_ctx())?
.try_collect::<Vec<_>>()
.await?,
)?;
assert!(!batches_distributed.to_string().is_empty());
Ok(())
}
#[derive(Debug)]
pub struct StatefulPassThroughExec {
plan_properties: Arc<PlanProperties>,
child: Arc<dyn ExecutionPlan>,
task: RwLock<Option<JoinHandle<()>>>,
}
impl StatefulPassThroughExec {
fn new(child: Arc<dyn ExecutionPlan>) -> Self {
let plan_properties = Arc::new(PlanProperties::new(
EquivalenceProperties::new(child.schema()),
child.output_partitioning().clone(),
EmissionType::Incremental,
Boundedness::Bounded,
));
Self {
plan_properties,
child,
task: RwLock::new(None),
}
}
}
impl DisplayAs for StatefulPassThroughExec {
fn fmt_as(&self, _: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
write!(f, "StatefulPassThroughExec")
}
}
impl ExecutionPlan for StatefulPassThroughExec {
fn name(&self) -> &str {
"StatefulPassThroughExec"
}
fn properties(&self) -> &Arc<PlanProperties> {
&self.plan_properties
}
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![&self.child]
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
Ok(Arc::new(StatefulPassThroughExec::new(children[0].clone())))
}
fn execute(
&self,
partition: usize,
context: Arc<TaskContext>,
) -> datafusion::common::Result<SendableRecordBatchStream> {
let (tx, rx) = tokio::sync::mpsc::channel(1);
let mut stream = self.child.execute(partition, context)?;
#[allow(clippy::disallowed_methods)]
let handle = tokio::spawn(async move {
while let Some(batch) = stream.next().await {
if tx.send(batch).await.is_err() {
return;
}
}
});
self.task.write().unwrap().replace(handle);
Ok(Box::pin(RecordBatchStreamAdapter::new(
self.schema(),
ReceiverStream::new(rx),
)))
}
}
#[derive(Debug)]
struct PassThroughExecCodec;
#[derive(Clone, PartialEq, ::prost::Message)]
struct PassThroughExecProto {
}
impl PhysicalExtensionCodec for PassThroughExecCodec {
fn try_decode(
&self,
buf: &[u8],
inputs: &[Arc<dyn ExecutionPlan>],
_ctx: &TaskContext,
) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
let _node =
PassThroughExecProto::decode(buf).map_err(|err| proto_error(format!("{err}")))?;
if inputs.len() != 1 {
return Err(proto_error(format!(
"StatefulPassThroughExec expects exactly one child, got {}",
inputs.len()
)));
}
Ok(Arc::new(StatefulPassThroughExec::new(inputs[0].clone())))
}
fn try_encode(
&self,
node: Arc<dyn ExecutionPlan>,
buf: &mut Vec<u8>,
) -> datafusion::common::Result<()> {
let Some(_plan) = node.downcast_ref::<StatefulPassThroughExec>() else {
return Err(proto_error(format!(
"Expected plan to be of type StatefulPassThroughExec, but was {}",
node.name()
)));
};
PassThroughExecProto {}
.encode(buf)
.map_err(|err| proto_error(format!("{err}")))
}
}
}