use crate::DistributedTaskContext;
use datafusion::common::{Result, Statistics, exec_err, not_impl_err, plan_err};
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
use datafusion::physical_expr_common::metrics::MetricsSet;
use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
use std::fmt::Formatter;
use std::sync::Arc;
/// Leaf execution plan that wraps an original plan together with a list of
/// alternative plans ("variants"), one of which is selected per task at
/// execution time.
#[derive(Debug)]
pub struct DistributedLeafExec {
/// The wrapped plan. It is executed directly when the task count is 1, and
/// it is the source for display output, metrics and statistics.
pub(crate) original: Arc<dyn ExecutionPlan>,
/// Plan properties reported by this node; cloned from the first variant
/// handed to `try_new`.
pub(crate) properties: Arc<PlanProperties>,
/// Alternative plans, indexed by task index when more than one task runs.
pub(crate) variants: Vec<Arc<dyn ExecutionPlan>>,
}
impl DistributedLeafExec {
    /// Creates a [`DistributedLeafExec`] from the `original` plan and its `variants`.
    ///
    /// The properties of the first variant become the properties of the new node.
    /// Every subsequent variant is checked against them.
    ///
    /// # Errors
    ///
    /// Returns a plan error when:
    /// - `variants` yields no elements,
    /// - a variant's partition count differs from the first variant's, or
    /// - a variant's schema differs from the first variant's.
    pub fn try_new(
        original: Arc<dyn ExecutionPlan>,
        variants: impl IntoIterator<Item = Arc<dyn ExecutionPlan>>,
    ) -> Result<Self> {
        let mut remaining = variants.into_iter();

        // The first variant is the reference every other variant is compared to.
        let Some(first) = remaining.next() else {
            return plan_err!("Empty list of variants was provided to DistributedLeafExec");
        };
        let properties = Arc::clone(first.properties());
        let expected_partitions = properties.partitioning.partition_count();

        let mut validated = vec![first];
        for candidate in remaining {
            let candidate_properties = candidate.properties();
            if expected_partitions != candidate_properties.partitioning.partition_count() {
                return plan_err!("Different partition count where provided in two different variants of DistributedLeafExec");
            }
            if properties.eq_properties.schema() != candidate_properties.eq_properties.schema() {
                return plan_err!("Different schemas where provided in two different variants of DistributedLeafExec");
            }
            validated.push(candidate);
        }

        Ok(Self {
            original,
            properties,
            variants: validated,
        })
    }

    /// Returns the wrapped original plan.
    pub fn original(&self) -> &Arc<dyn ExecutionPlan> {
        &self.original
    }

    /// Returns all the variants held by this node, in the order they were provided.
    pub fn variants(&self) -> &[Arc<dyn ExecutionPlan>] {
        self.variants.as_slice()
    }

    /// Returns a new handle to the variant stored at position `task_i`.
    ///
    /// # Panics
    ///
    /// Panics if `task_i` is not a valid index into the variants list.
    pub(crate) fn to_task_specialized(&self, task_i: usize) -> Arc<dyn ExecutionPlan> {
        let selected = &self.variants[task_i];
        Arc::clone(selected)
    }
}
impl DisplayAs for DistributedLeafExec {
    /// Writes the `DistributedLeafExec: ` prefix and then delegates the rest of
    /// the rendering to the wrapped original plan, using the same format type.
    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
        f.write_str("DistributedLeafExec: ")?;
        self.original.fmt_as(t, f)
    }
}
impl ExecutionPlan for DistributedLeafExec {
    /// Static name of this node.
    fn name(&self) -> &str {
        "DistributedLeafExec"
    }

    /// Properties captured at construction time (taken from the first variant).
    fn properties(&self) -> &Arc<PlanProperties> {
        &self.properties
    }

    /// This node is a leaf: neither the original plan nor the variants are
    /// exposed as children.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        Vec::new()
    }

    /// Replaces the children of this node.
    ///
    /// Because this node is a leaf, the only valid replacement is an empty list,
    /// in which case the node is returned unchanged. Generic plan-rewriting
    /// helpers may call this method on leaf nodes with an empty vector, so that
    /// case must not fail.
    ///
    /// # Errors
    ///
    /// Returns a not-implemented error if `children` is non-empty.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if children.is_empty() {
            return Ok(self);
        }
        not_impl_err!("DistributedLeafExec does not accept children")
    }

    /// Executes `partition` on the plan that corresponds to the current task.
    ///
    /// With a single task the original plan is executed. Otherwise the variant
    /// at the task index found in the [`DistributedTaskContext`] is executed.
    ///
    /// # Errors
    ///
    /// Returns an execution error if the task index has no matching variant, or
    /// whatever error the selected plan returns from its own `execute`.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        let d_ctx = DistributedTaskContext::from_ctx(&context);

        // A single task means nothing was split: run the original plan as is.
        if d_ctx.task_count == 1 {
            return self.original.execute(partition, context);
        }

        let Some(plan) = self.variants.get(d_ctx.task_index) else {
            return exec_err!(
                "Task index {} out of range for a per_task vector of length {}",
                d_ctx.task_index,
                self.variants.len()
            );
        };
        plan.execute(partition, context)
    }

    /// Returns the metrics of the original plan.
    ///
    // NOTE(review): when a variant (not the original) is executed, metrics it
    // records on itself are not surfaced here — confirm this is intended.
    fn metrics(&self) -> Option<MetricsSet> {
        self.original.metrics()
    }

    /// Returns the statistics reported by the original plan for `partition`.
    fn partition_statistics(&self, partition: Option<usize>) -> Result<Arc<Statistics>> {
        self.original.partition_statistics(partition)
    }
}