use crate::common::require_one_child;
use crate::distributed_planner::ProducerHead;
use crate::execution_plans::common::scale_partitioning;
use crate::stage::{LocalStage, Stage};
use crate::worker::WorkerConnectionPool;
use crate::{DistributedTaskContext, NetworkBoundary};
use datafusion::common::{Result, not_impl_err, plan_err};
use datafusion::error::DataFusionError;
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
use datafusion::physical_expr::Partitioning;
use datafusion::physical_expr_common::metrics::MetricsSet;
use datafusion::physical_plan::repartition::RepartitionExec;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
use std::fmt::Formatter;
use std::sync::Arc;
use uuid::Uuid;
/// Execution plan node whose input is a [`Stage`] rather than a plain child plan.
///
/// When the input stage is local, execution is delegated to that stage. When it is remote,
/// `execute` opens one stream per input task through `worker_connections` and merges them
/// into a single output stream.
#[derive(Debug, Clone)]
pub struct NetworkShuffleExec {
/// Plan properties reported by `ExecutionPlan::properties`. Their partition count is also
/// used in `execute` to compute the range of partitions requested from each worker.
pub(crate) properties: Arc<PlanProperties>,
/// Stage that feeds this node; either `Stage::Local` or `Stage::Remote`.
pub(crate) input_stage: Stage,
/// Pool created from the input stage's task count. It hands out the worker connections
/// used by `execute` and holds the metrics returned by `metrics`.
pub(crate) worker_connections: WorkerConnectionPool,
}
impl NetworkShuffleExec {
    /// Builds the node around an already constructed `input_stage`.
    ///
    /// `input_properties` becomes the node's own plan properties, and the worker connection
    /// pool is created from the task count of `input_stage`.
    pub(crate) fn from_stage(input_stage: Stage, input_properties: Arc<PlanProperties>) -> Self {
        // The task count has to be read before `input_stage` is moved into the struct.
        let worker_connections = WorkerConnectionPool::new(input_stage.task_count());
        Self {
            properties: input_properties,
            input_stage,
            worker_connections,
        }
    }

    /// Creates the node on top of `input`, wrapping it in a local stage with a nil query id,
    /// stage number `0` and `producer_tasks` tasks. The node reuses the properties of `input`.
    ///
    /// # Errors
    ///
    /// Returns a plan error when `input` is not a [`RepartitionExec`], or when that
    /// [`RepartitionExec`] does not use hash partitioning.
    pub fn try_new(input: Arc<dyn ExecutionPlan>, producer_tasks: usize) -> Result<Self> {
        // Validate the shape of the input before taking ownership of it.
        match input.downcast_ref::<RepartitionExec>() {
            None => {
                return plan_err!("The input of a NetworkShuffleExec can only be a RepartitionExec");
            }
            Some(repartition)
                if !matches!(repartition.partitioning(), Partitioning::Hash(..)) =>
            {
                return plan_err!("The input of a NetworkShuffleExec must be hash partitioned");
            }
            Some(_) => {}
        }

        let properties = Arc::clone(input.properties());
        let stage = Stage::Local(LocalStage {
            query_id: Uuid::nil(),
            num: 0,
            plan: input,
            tasks: producer_tasks,
        });
        Ok(Self::from_stage(stage, properties))
    }
}
impl NetworkBoundary for NetworkShuffleExec {
    /// Returns the stage this node reads its input from.
    fn input_stage(&self) -> &Stage {
        &self.input_stage
    }

    /// Returns a copy of this node that reads from `input_stage` instead.
    ///
    /// The plan properties are carried over unchanged, while the worker connection pool is
    /// recreated from the task count of the new stage.
    fn with_input_stage(&self, input_stage: Stage) -> Result<Arc<dyn ExecutionPlan>> {
        let replaced = Self::from_stage(input_stage, Arc::clone(&self.properties));
        Ok(Arc::new(replaced))
    }

    /// Describes the head of the producer side as a `ProducerHead::RepartitionExec`.
    ///
    /// Its partitioning is this node's partitioning passed through `scale_partitioning` with
    /// a closure that multiplies the value it receives by `consumer_task_count`.
    fn producer_head(&self, consumer_task_count: usize) -> ProducerHead {
        let partitioning = scale_partitioning(&self.properties.partitioning, |current| {
            current * consumer_task_count
        });
        ProducerHead::RepartitionExec { partitioning }
    }
}
impl DisplayAs for NetworkShuffleExec {
    /// Writes a one-line description containing the input stage number, the number of output
    /// partitions and the number of input tasks. The requested display format is ignored.
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
        write!(
            f,
            "[Stage {stage}] => NetworkShuffleExec: output_partitions={partitions}, input_tasks={input_tasks}",
            input_tasks = self.input_stage.task_count(),
            partitions = self.properties.partitioning.partition_count(),
            stage = self.input_stage.num(),
        )
    }
}
impl ExecutionPlan for NetworkShuffleExec {
    /// Static name of this plan node.
    fn name(&self) -> &str {
        "NetworkShuffleExec"
    }

    /// Plan properties this node was constructed with.
    fn properties(&self) -> &Arc<PlanProperties> {
        &self.properties
    }

    /// Returns the local plan of the input stage as the only child, or no children at all
    /// when the stage exposes no local plan.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        self.input_stage.local_plan().into_iter().collect()
    }

    /// Returns a copy of this node whose local input stage runs the single plan in `children`.
    ///
    /// # Errors
    ///
    /// Returns a not-implemented error when the input stage is remote, and propagates the
    /// error from `require_one_child` for a local stage.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
        let mut updated = (*self).clone();
        // Only a local stage owns a plan that can be swapped out.
        let Stage::Local(local) = &mut updated.input_stage else {
            return not_impl_err!("NetworkBoundary cannot accept children");
        };
        local.plan = require_one_child(children)?;
        Ok(Arc::new(updated))
    }

    /// Produces the stream for `partition`.
    ///
    /// A local input stage is executed directly. For a remote stage, one stream is requested
    /// per entry in the stage's `workers` and all of them are merged with
    /// `futures::stream::select_all`.
    ///
    /// # Errors
    ///
    /// Propagates the first error raised while obtaining a worker connection or while
    /// requesting a stream from it.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream, DataFusionError> {
        let remote = match &self.input_stage {
            Stage::Remote(remote) => remote,
            Stage::Local(local) => return local.execute(partition, context),
        };

        let task_ctx = DistributedTaskContext::from_ctx(&context);
        let partitions_per_task = self.properties.partitioning.partition_count();
        // This task's partitions start at `task_index * partitions_per_task` on the producers.
        let first = partitions_per_task * task_ctx.task_index;

        let streams = (0..remote.workers.len())
            .map(|producer_task| -> Result<_, DataFusionError> {
                let connection = self.worker_connections.get_or_init_worker_connection(
                    remote,
                    first..(first + partitions_per_task),
                    producer_task,
                    self.producer_head(task_ctx.task_count),
                    &context,
                )?;
                Ok(connection.execute(first + partition)?)
            })
            .collect::<Result<Vec<_>, DataFusionError>>()?;

        let merged = futures::stream::select_all(streams);
        Ok(Box::pin(RecordBatchStreamAdapter::new(self.schema(), merged)))
    }

    /// Returns a copy of the metrics held by the worker connection pool.
    fn metrics(&self) -> Option<MetricsSet> {
        let collected = self.worker_connections.metrics.clone_inner();
        Some(collected)
    }
}