use crate::common::require_one_child;
use crate::distributed_planner::{NetworkBoundary, ProducerHead};
use crate::stage::{LocalStage, Stage};
use crate::worker::WorkerConnectionPool;
use crate::{BroadcastExec, DistributedTaskContext};
use datafusion::common::{Result, not_impl_err, plan_err};
use datafusion::error::DataFusionError;
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
use datafusion::physical_expr_common::metrics::MetricsSet;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::{
DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
};
use std::fmt::Formatter;
use std::sync::Arc;
use uuid::Uuid;
/// Network boundary that reads from a broadcast stage.
///
/// The input stage is either a [`Stage::Local`], whose plan is executed
/// in-process, or a [`Stage::Remote`], whose partitions are fetched through
/// worker connections. Construct it with [`NetworkBroadcastExec::try_new`]
/// or `NetworkBroadcastExec::from_stage`.
#[derive(Debug, Clone)]
pub struct NetworkBroadcastExec {
    /// Properties exposed to consumers of this node. `from_stage` builds them
    /// from the input's properties, with partitioning replaced by
    /// `UnknownPartitioning` carrying the input's partition count.
    pub(crate) properties: Arc<PlanProperties>,
    /// The stage feeding this boundary: a local plan or a remote stage.
    pub(crate) input_stage: Stage,
    /// Pool of connections to the input stage's tasks, sized by
    /// `input_stage.task_count()`. `execute` uses it only for remote stages;
    /// `metrics` reports its metrics set.
    pub(crate) worker_connections: WorkerConnectionPool,
}
impl NetworkBroadcastExec {
    /// Wraps an already-built `input_stage` in a broadcast network boundary.
    ///
    /// The exposed properties are a copy of `input_properties` whose
    /// partitioning is swapped for `UnknownPartitioning` with the same
    /// partition count. The worker connection pool gets one slot per task of
    /// `input_stage`.
    pub(crate) fn from_stage(input_stage: Stage, input_properties: Arc<PlanProperties>) -> Self {
        let partition_count = input_properties.partitioning.partition_count();
        let consumer_properties = PlanProperties::clone(&input_properties)
            .with_partitioning(Partitioning::UnknownPartitioning(partition_count));
        let worker_connections = WorkerConnectionPool::new(input_stage.task_count());
        Self {
            properties: Arc::new(consumer_properties),
            input_stage,
            worker_connections,
        }
    }

    /// Creates a broadcast boundary directly over `input`.
    ///
    /// `input` is wrapped in a [`Stage::Local`] with a nil query id, stage
    /// number 0 and `tasks` set to `producer_tasks`.
    ///
    /// # Errors
    /// Returns a planning error when `input` is not a [`BroadcastExec`].
    pub fn try_new(input: Arc<dyn ExecutionPlan>, producer_tasks: usize) -> Result<Self> {
        if input.is::<BroadcastExec>() {
            let input_properties = Arc::clone(input.properties());
            let local_stage = LocalStage {
                query_id: Uuid::nil(),
                num: 0,
                plan: input,
                tasks: producer_tasks,
            };
            Ok(Self::from_stage(Stage::Local(local_stage), input_properties))
        } else {
            plan_err!("The input of a NetworkBroadcastExec can only be a BroadcastExec")
        }
    }
}
impl NetworkBoundary for NetworkBroadcastExec {
    /// Returns a copy of this node that reads from `input_stage`.
    ///
    /// The copy gets a fresh worker connection pool sized to the new stage's
    /// task count. All other fields are cloned from `self`.
    fn with_input_stage(&self, input_stage: Stage) -> Result<Arc<dyn ExecutionPlan>> {
        let worker_connections = WorkerConnectionPool::new(input_stage.task_count());
        let rebuilt = Self {
            input_stage,
            worker_connections,
            ..self.clone()
        };
        Ok(Arc::new(rebuilt))
    }

    /// The stage this boundary reads from.
    fn input_stage(&self) -> &Stage {
        &self.input_stage
    }

    /// Describes the producer side of the boundary.
    ///
    /// Each consumer task reads this node's partition count worth of
    /// partitions, so the producer exposes
    /// `partition_count * consumer_task_count` partitions in total.
    fn producer_head(&self, consumer_task_count: usize) -> ProducerHead {
        let per_consumer = self.properties.output_partitioning().partition_count();
        let output_partitions = per_consumer * consumer_task_count;
        ProducerHead::BroadcastExec { output_partitions }
    }
}
impl DisplayAs for NetworkBroadcastExec {
    /// Writes a one-line summary of this boundary.
    ///
    /// The summary contains:
    /// - the input stage number;
    /// - the partitions exposed per consumer;
    /// - the partition count of the local stage plan (0 when there is none);
    /// - the number of input tasks.
    ///
    /// The display format type is ignored.
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
        let input = &self.input_stage;
        write!(
            f,
            "[Stage {stage}] => NetworkBroadcastExec: partitions_per_consumer={consumer_partitions}, stage_partitions={stage_partitions}, input_tasks={input_tasks}",
            stage = input.num(),
            consumer_partitions = self.properties.partitioning.partition_count(),
            stage_partitions = input
                .local_plan()
                .as_ref()
                .map_or(0, |plan| plan.properties().partitioning.partition_count()),
            input_tasks = input.task_count(),
        )
    }
}
impl ExecutionPlan for NetworkBroadcastExec {
fn name(&self) -> &str {
"NetworkBroadcastExec"
}
fn properties(&self) -> &Arc<PlanProperties> {
&self.properties
}
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
match &self.input_stage.local_plan() {
Some(plan) => vec![plan],
None => vec![],
}
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
let mut self_clone = self.as_ref().clone();
match &mut self_clone.input_stage {
Stage::Local(local) => {
local.plan = require_one_child(children)?;
}
Stage::Remote(_) => not_impl_err!("NetworkBoundary cannot accept children")?,
}
Ok(Arc::new(self_clone))
}
fn execute(
&self,
partition: usize,
context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream, DataFusionError> {
let remote_stage = match &self.input_stage {
Stage::Local(local) => return local.execute(partition, context),
Stage::Remote(remote_stage) => remote_stage,
};
let task_context = DistributedTaskContext::from_ctx(&context);
let out_partitions = self.properties.partitioning.partition_count();
let off = out_partitions * task_context.task_index;
let mut streams = Vec::with_capacity(self.input_stage.task_count());
for input_task_index in 0..self.input_stage.task_count() {
let worker_connection = self.worker_connections.get_or_init_worker_connection(
remote_stage,
off..(off + self.properties.partitioning.partition_count()),
input_task_index,
self.producer_head(task_context.task_count),
&context,
)?;
let stream = worker_connection.execute(off + partition)?;
streams.push(stream);
}
Ok(Box::pin(RecordBatchStreamAdapter::new(
self.schema(),
futures::stream::select_all(streams),
)))
}
fn metrics(&self) -> Option<MetricsSet> {
Some(self.worker_connections.metrics.clone_inner())
}
}