use crate::common::require_one_child;
use crate::execution_plans::common::scale_partitioning;
use crate::stage::Stage;
use crate::worker::WorkerConnectionPool;
use crate::worker::generated::worker as pb;
use crate::worker::generated::worker::TaskKey;
use crate::worker::generated::worker::flight_app_metadata;
use crate::{DistributedTaskContext, ExecutionTask, NetworkBoundary};
use dashmap::DashMap;
use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
use datafusion::common::{Result, plan_err};
use datafusion::error::DataFusionError;
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
use datafusion::physical_expr::Partitioning;
use datafusion::physical_expr_common::metrics::MetricsSet;
use datafusion::physical_plan::repartition::RepartitionExec;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::{
DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
};
use std::any::Any;
use std::fmt::Formatter;
use std::sync::Arc;
use uuid::Uuid;
/// Execution plan node whose `execute` reads partitions produced by the tasks of
/// `input_stage` through `worker_connections` and merges them into one stream.
#[derive(Debug, Clone)]
pub struct NetworkShuffleExec {
/// Plan properties returned by `properties()`; `try_new` copies them from the
/// input plan it is given.
pub(crate) properties: Arc<PlanProperties>,
/// The stage this node reads from. Its `plan` is exposed as this node's only
/// child, and its `tasks` length decides how many input streams `execute` opens.
pub(crate) input_stage: Stage,
/// Connection pool used by `execute` to obtain one worker connection per input
/// task; its `metrics` are what `metrics()` reports.
pub(crate) worker_connections: WorkerConnectionPool,
/// Metrics keyed by task, filled in by `execute` whenever stream metadata
/// carries a `MetricsCollection` entry with a task key.
pub(crate) metrics_collection: Arc<DashMap<TaskKey, Vec<pb::MetricsSet>>>,
}
impl NetworkShuffleExec {
    /// Creates a shuffle node that reads from `input` as a separate input stage.
    ///
    /// * `input` - plan to place in the input stage; must report hash partitioning.
    /// * `query_id`, `num` - stored verbatim in the resulting [`Stage`].
    /// * `task_count` - factor handed to `scale_partitioning` for the
    ///   `RepartitionExec` found in `input`.
    /// * `input_task_count` - number of (still unassigned) tasks created for the
    ///   input stage, and the size passed to the worker connection pool.
    ///
    /// # Errors
    ///
    /// Returns a plan error when `input`, or any node visited before a
    /// `RepartitionExec` is reached, is not hash partitioned. Errors from
    /// `require_one_child` and `RepartitionExec::try_new` are propagated.
    pub fn try_new(
        input: Arc<dyn ExecutionPlan>,
        query_id: Uuid,
        num: usize,
        task_count: usize,
        input_task_count: usize,
    ) -> Result<Self, DataFusionError> {
        // Fail fast on the root before walking the tree.
        match input.output_partitioning() {
            Partitioning::Hash(_, _) => {}
            _ => return plan_err!("NetworkShuffleExec input must be hash partitioned"),
        }

        let rewritten = Arc::clone(&input).transform_down(|node| {
            match node.as_any().downcast_ref::<RepartitionExec>() {
                // Rebuild the repartition with its partitioning scaled by
                // `task_count`, then stop the traversal.
                Some(repartition) => {
                    let child = require_one_child(repartition.children())?;
                    let partitioning =
                        scale_partitioning(repartition.partitioning(), |p| p * task_count);
                    let scaled: Arc<dyn ExecutionPlan> =
                        Arc::new(RepartitionExec::try_new(child, partitioning)?);
                    Ok(Transformed::new(scaled, true, TreeNodeRecursion::Stop))
                }
                // Hash-partitioned nodes above the repartition are left untouched.
                None if matches!(node.output_partitioning(), Partitioning::Hash(_, _)) => {
                    Ok(Transformed::no(node))
                }
                None => plan_err!(
                    "NetworkShuffleExec input must be hash partitioned, but {} is not",
                    node.name()
                ),
            }
        })?;

        // Tasks start without a URL; `input_task_count` of them are created.
        let input_stage = Stage {
            query_id,
            num,
            plan: Some(rewritten.data),
            tasks: vec![ExecutionTask { url: None }; input_task_count],
        };

        Ok(Self {
            input_stage,
            worker_connections: WorkerConnectionPool::new(input_task_count),
            // The node advertises the properties of the original, unscaled input.
            properties: Arc::clone(input.properties()),
            metrics_collection: Default::default(),
        })
    }
}
impl NetworkBoundary for NetworkShuffleExec {
    /// Borrows the stage this node reads from.
    fn input_stage(&self) -> &Stage {
        &self.input_stage
    }

    /// Returns a copy of this node whose input stage is replaced by `input_stage`;
    /// every other field is cloned from `self`.
    fn with_input_stage(&self, input_stage: Stage) -> Result<Arc<dyn ExecutionPlan>> {
        Ok(Arc::new(Self {
            input_stage,
            ..self.clone()
        }))
    }
}
impl DisplayAs for NetworkShuffleExec {
    /// Writes a one-line summary with the input stage number, this node's output
    /// partition count and the number of input tasks. The format type is ignored.
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
        write!(
            f,
            "[Stage {stage}] => NetworkShuffleExec: output_partitions={partitions}, input_tasks={input_tasks}",
            stage = self.input_stage.num,
            partitions = self.properties.partitioning.partition_count(),
            input_tasks = self.input_stage.tasks.len(),
        )
    }
}
impl ExecutionPlan for NetworkShuffleExec {
    /// Static display name of this node.
    fn name(&self) -> &str {
        "NetworkShuffleExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &Arc<PlanProperties> {
        &self.properties
    }

    /// The input stage's plan, if present, is the only child.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        self.input_stage.plan.iter().collect()
    }

    /// Returns a copy of this node with the input stage's plan replaced by the
    /// single plan in `children`; `require_one_child` errors are propagated.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
        let mut updated = (*self).clone();
        updated.input_stage.plan = Some(require_one_child(children)?);
        Ok(Arc::new(updated))
    }

    /// Opens one stream per input task for `partition` and merges them.
    ///
    /// The partition requested from the input stage is offset by
    /// `partition_count * task_index`, where `task_index` comes from the
    /// `DistributedTaskContext` found in `context`. Metrics that arrive as
    /// stream metadata are stored in `metrics_collection`.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream, DataFusionError> {
        let partitions_per_task = self.properties.partitioning.partition_count();
        // First input-stage partition that belongs to the task running this node.
        let first_partition =
            partitions_per_task * DistributedTaskContext::from_ctx(&context).task_index;

        let input_task_count = self.input_stage.tasks.len();
        let mut streams = Vec::with_capacity(input_task_count);
        for input_task_index in 0..input_task_count {
            let connection = self.worker_connections.get_or_init_worker_connection(
                &self.input_stage,
                first_partition..(first_partition + partitions_per_task),
                input_task_index,
                &context,
            )?;

            // The callback is `move`, so each stream gets its own handle to the map.
            let collected = Arc::clone(&self.metrics_collection);
            let stream = connection.stream_partition(first_partition + partition, move |meta| {
                let Some(flight_app_metadata::Content::MetricsCollection(batch)) = meta.content
                else {
                    return;
                };
                for task in batch.tasks {
                    // Entries without a task key cannot be stored and are skipped.
                    if let Some(key) = task.task_key {
                        collected.insert(key, task.metrics);
                    }
                }
            })?;
            streams.push(stream);
        }

        let merged = futures::stream::select_all(streams);
        Ok(Box::pin(RecordBatchStreamAdapter::new(
            self.schema(),
            merged,
        )))
    }

    /// Reports the metrics held by the worker connection pool.
    fn metrics(&self) -> Option<MetricsSet> {
        let pool_metrics = self.worker_connections.metrics.clone_inner();
        Some(pool_metrics)
    }
}