use crate::{BroadcastExec, NetworkBroadcastExec, NetworkCoalesceExec, NetworkShuffleExec, Stage};
use datafusion::common::Result;
use datafusion::physical_expr::Partitioning;
use datafusion::physical_plan::repartition::RepartitionExec;
use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties};
use std::sync::Arc;
/// An [`ExecutionPlan`] node that marks a boundary between two [`Stage`]s.
///
/// The node types recognized as boundaries in this module are `NetworkShuffleExec`,
/// `NetworkCoalesceExec` and `NetworkBroadcastExec`.
/// NOTE(review): the "stage boundary" wording is inferred from the method set below —
/// confirm against the implementors.
pub trait NetworkBoundary: ExecutionPlan {
    /// Returns the [`Stage`] currently attached as this node's input.
    fn input_stage(&self) -> &Stage;

    /// Builds a plan node that uses `input_stage` as its input stage.
    ///
    /// Presumably returns a copy of `self` with only the stage swapped — TODO confirm
    /// in the implementors.
    fn with_input_stage(&self, input_stage: Stage) -> Result<Arc<dyn ExecutionPlan>>;

    /// Describes the [`ProducerHead`] that should sit at the root of the producing
    /// side's plan, given how many tasks run on the consuming side.
    fn producer_head(&self, consumer_tasks: usize) -> ProducerHead;
}
/// The node to place at the root of a producer plan, as applied by
/// `insert_producer_head`.
///
/// Derives `Debug` and `Clone` so heads can be logged and reused. (`Partitioning`
/// implements both.)
#[derive(Debug, Clone)]
pub enum ProducerHead {
    /// Add no new root. Any existing `RepartitionExec` or `BroadcastExec` at the root
    /// is still removed by `insert_producer_head`.
    None,
    /// Wrap the producer plan in a `BroadcastExec`.
    ///
    /// `insert_producer_head` divides `output_partitions` by the producer plan's
    /// output partition count to get the value passed to `BroadcastExec::new`.
    BroadcastExec { output_partitions: usize },
    /// Wrap the producer plan in a `RepartitionExec` that uses `partitioning`.
    RepartitionExec { partitioning: Partitioning },
}
/// Helpers for recognizing [`NetworkBoundary`] nodes behind a type-erased plan.
pub trait NetworkBoundaryExt {
    /// Returns `self` viewed as a [`NetworkBoundary`], or `None` if it is not one.
    fn as_network_boundary(&self) -> Option<&dyn NetworkBoundary>;

    /// Returns `true` exactly when [`Self::as_network_boundary`] yields `Some`.
    fn is_network_boundary(&self) -> bool {
        let boundary = self.as_network_boundary();
        boundary.is_some()
    }
}
impl NetworkBoundaryExt for dyn ExecutionPlan {
    /// Tries each known boundary node type in turn (shuffle, coalesce, broadcast).
    /// Returns the first successful downcast as a `&dyn NetworkBoundary`, or `None`.
    ///
    /// Only the three types below are recognized. A new `NetworkBoundary`
    /// implementor must be added here, or it will not be detected as a boundary.
    fn as_network_boundary(&self) -> Option<&dyn NetworkBoundary> {
        if let Some(shuffle) = self.downcast_ref::<NetworkShuffleExec>() {
            return Some(shuffle);
        }
        if let Some(coalesce) = self.downcast_ref::<NetworkCoalesceExec>() {
            return Some(coalesce);
        }
        if let Some(broadcast) = self.downcast_ref::<NetworkBroadcastExec>() {
            return Some(broadcast);
        }
        None
    }
}
pub(crate) fn insert_producer_head(
input: Arc<dyn ExecutionPlan>,
head: ProducerHead,
) -> Result<Arc<dyn ExecutionPlan>> {
let input = if let Some(r_exec) = input.downcast_ref::<RepartitionExec>() {
Arc::clone(r_exec.input())
} else if let Some(b_exec) = input.downcast_ref::<BroadcastExec>() {
Arc::clone(b_exec.input())
} else {
input
};
let plan = match head {
ProducerHead::None => input,
ProducerHead::BroadcastExec { output_partitions } => {
let partitions = input.output_partitioning().partition_count();
Arc::new(BroadcastExec::new(input, output_partitions / partitions))
}
ProducerHead::RepartitionExec { partitioning } => {
Arc::new(RepartitionExec::try_new(input, partitioning)?)
}
};
Ok(plan)
}