use crate::DistributedTaskContext;
use crate::common::require_one_child;
use crate::distributed_planner::{NetworkBoundary, ProducerHead};
use crate::execution_plans::common::scale_partitioning_props;
use crate::stage::{LocalStage, Stage};
use crate::worker::WorkerConnectionPool;
use datafusion::common::{exec_err, not_impl_err, plan_err};
use datafusion::error::Result;
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
use datafusion::physical_expr_common::metrics::MetricsSet;
use datafusion::physical_plan::limit::LocalLimitExec;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::{
DisplayAs, DisplayFormatType, EmptyRecordBatchStream, ExecutionPlan, PlanProperties,
internal_err,
};
use std::fmt::{Debug, Formatter};
use std::sync::Arc;
use uuid::Uuid;
/// Execution plan node that sits on top of an input [`Stage`] and exposes the
/// partitions of that stage's tasks as its own output partitions.
///
/// While the input stage is `Stage::Local`, executing this node simply executes
/// the wrapped local plan. Once the input stage is `Stage::Remote`, each output
/// partition is mapped to one (input task, input partition) pair and streamed
/// through a worker connection.
#[derive(Debug, Clone)]
pub struct NetworkCoalesceExec {
    /// Plan properties reported by `ExecutionPlan::properties`; the constructors
    /// derive them from the input's properties with a scaled partition count.
    pub(crate) properties: Arc<PlanProperties>,
    /// Stage feeding this node (either `Stage::Local` or `Stage::Remote`).
    pub(crate) input_stage: Stage,
    /// Connection pool created with the input stage's task count; also the
    /// source of the metrics returned by `ExecutionPlan::metrics`.
    pub(crate) worker_connections: WorkerConnectionPool,
}
impl NetworkCoalesceExec {
    /// Builds a node on top of an already constructed `input_stage`.
    ///
    /// The output partition count is the input partition count multiplied by the
    /// largest number of input tasks a single consumer task may be assigned
    /// (`ceil(input tasks / consumer_tasks)`, never less than 1).
    ///
    /// `consumer_tasks` must be non-zero: `div_ceil` panics on a zero divisor.
    pub(crate) fn from_stage(
        input_stage: Stage,
        input_properties: Arc<PlanProperties>,
        consumer_tasks: usize,
    ) -> Self {
        let producer_tasks = input_stage.task_count();
        let widest_group = producer_tasks.div_ceil(consumer_tasks).max(1);
        let properties =
            scale_partitioning_props(&input_properties, |partitions| partitions * widest_group);
        let worker_connections = WorkerConnectionPool::new(producer_tasks);
        Self {
            properties,
            input_stage,
            worker_connections,
        }
    }

    /// Wraps `input` in a local stage with `producer_tasks` tasks and builds the
    /// node that `consumer_tasks` tasks will read from.
    ///
    /// # Errors
    ///
    /// Returns a planning error when `consumer_tasks` is 0.
    pub fn try_new(
        input: Arc<dyn ExecutionPlan>,
        producer_tasks: usize,
        consumer_tasks: usize,
    ) -> Result<Self> {
        if consumer_tasks == 0 {
            return plan_err!("The `consumer_tasks` input of a NetworkCoalesceExec must not be 0");
        }

        // Capture the properties before `input` is moved into the stage.
        let input_properties = Arc::clone(input.properties());
        let local_stage = Stage::Local(LocalStage {
            query_id: Uuid::nil(),
            num: 0,
            plan: input,
            tasks: producer_tasks,
        });
        Ok(Self::from_stage(
            local_stage,
            input_properties,
            consumer_tasks,
        ))
    }

    /// Returns a copy of this node whose local input plan is limited to `fetch`
    /// rows.
    ///
    /// * A non-local input stage is left untouched (a plain clone is returned).
    /// * An input plan that already reports a fetch of at most `fetch` is reused.
    /// * Otherwise the plan's own `with_fetch` is tried first, falling back to
    ///   wrapping the plan in a `LocalLimitExec`.
    pub(crate) fn with_fetch_on_input_stage(&self, fetch: usize) -> Result<Arc<dyn ExecutionPlan>> {
        let mut updated = self.clone();

        if let Stage::Local(local) = &self.input_stage {
            let already_limited = matches!(local.plan.fetch(), Some(existing) if existing <= fetch);
            let limited_plan: Arc<dyn ExecutionPlan> = if already_limited {
                Arc::clone(&local.plan)
            } else {
                match local.plan.with_fetch(Some(fetch)) {
                    Some(plan) => plan,
                    None => Arc::new(LocalLimitExec::new(Arc::clone(&local.plan), fetch)),
                }
            };

            updated.input_stage = Stage::Local(LocalStage {
                query_id: local.query_id,
                num: local.num,
                plan: limited_plan,
                tasks: local.tasks,
            });
        }

        Ok(Arc::new(updated))
    }
}
impl NetworkBoundary for NetworkCoalesceExec {
    /// The stage this boundary reads from.
    fn input_stage(&self) -> &Stage {
        &self.input_stage
    }

    /// Returns a copy of this node reading from `input_stage` instead.
    ///
    /// The partition count is rescaled by the ratio between the new and the
    /// previous input task counts (the previous count is clamped to at least 1),
    /// and the connection pool is rebuilt for the new task count.
    fn with_input_stage(&self, input_stage: Stage) -> Result<Arc<dyn ExecutionPlan>> {
        let new_task_count = input_stage.task_count();
        let previous_task_count = self.input_stage.task_count().max(1);

        let properties = scale_partitioning_props(self.properties(), |partitions| {
            partitions * new_task_count / previous_task_count
        });

        Ok(Arc::new(Self {
            properties,
            input_stage,
            worker_connections: WorkerConnectionPool::new(new_task_count),
        }))
    }

    /// This boundary never requests a producer head, regardless of how many
    /// consumer tasks there are.
    fn producer_head(&self, _consumer_task_count: usize) -> ProducerHead {
        ProducerHead::None
    }
}
impl DisplayAs for NetworkCoalesceExec {
    /// Renders a one-line summary with the input stage number, the number of
    /// output partitions and the number of input tasks. The same text is used
    /// for every display format.
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
        write!(
            f,
            "[Stage {stage}] => NetworkCoalesceExec: output_partitions={partitions}, input_tasks={input_tasks}",
            input_tasks = self.input_stage.task_count(),
            partitions = self.properties.partitioning.partition_count(),
            stage = self.input_stage.num(),
        )
    }
}
impl ExecutionPlan for NetworkCoalesceExec {
    fn name(&self) -> &str {
        "NetworkCoalesceExec"
    }

    fn properties(&self) -> &Arc<PlanProperties> {
        &self.properties
    }

    /// The local input plan when the input stage has one, otherwise no children.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        match &self.input_stage.local_plan() {
            Some(v) => vec![v],
            None => vec![],
        }
    }

    /// Replaces the plan of a local input stage with the single provided child.
    ///
    /// # Errors
    ///
    /// Fails when the input stage is remote, or when `require_one_child`
    /// rejects `children`.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let mut self_clone = self.as_ref().clone();
        match &mut self_clone.input_stage {
            Stage::Local(local) => {
                local.plan = require_one_child(children)?;
            }
            Stage::Remote(_) => return not_impl_err!("NetworkBoundary cannot accept children"),
        }
        Ok(Arc::new(self_clone))
    }

    /// Streams output partition `partition`.
    ///
    /// For a local input stage the call is delegated to the stage itself. For a
    /// remote input stage the output partitions are laid out as consecutive
    /// "slots" of `partitions_per_task` partitions each; slot `i` maps to the
    /// `i`-th input task of the group assigned to the current consumer task
    /// (see `task_group`). Slots beyond the group's actual length are padding
    /// and produce an empty stream.
    ///
    /// # Errors
    ///
    /// * Execution error when the task context reports
    ///   `task_index >= task_count`, or when the partition count is too small to
    ///   give each slot at least one partition.
    /// * Internal error when `partition` falls outside every slot.
    /// * Any error returned while obtaining or executing the worker connection.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        let remote_stage = match &self.input_stage {
            Stage::Local(local) => return local.execute(partition, context),
            Stage::Remote(remote_stage) => remote_stage,
        };

        let task_context = DistributedTaskContext::from_ctx(&context);
        if task_context.task_index >= task_context.task_count {
            return exec_err!(
                "NetworkCoalesceExec invalid task context: task_index={} >= task_count={}",
                task_context.task_index,
                task_context.task_count
            );
        }

        let input_task_count = self.input_stage.task_count();
        // `task_count` is non-zero here (the guard above rejects 0), so
        // `div_ceil` cannot divide by zero, and `.max(1)` keeps the divisor
        // below non-zero even when the input stage has no tasks.
        let slots_per_consumer = input_task_count
            .div_ceil(task_context.task_count)
            .max(1);
        let partitions_per_task =
            self.properties().partitioning.partition_count() / slots_per_consumer;
        if partitions_per_task == 0 {
            return exec_err!("NetworkCoalesceExec has 0 partitions per input task");
        }

        let group = task_group(
            input_task_count,
            task_context.task_index,
            task_context.task_count,
        );
        let input_task_offset = partition / partitions_per_task;
        let target_partition = partition % partitions_per_task;

        // The out-of-range check has to run before the padding check: `len` is
        // never larger than `max_len`, so testing `len` first would turn every
        // out-of-range partition into a silent empty stream and leave this
        // branch unreachable. The bound is clamped to 1 to mirror the slot
        // count used above, so a zero-task input stage still yields empty
        // streams for its in-range partitions.
        if input_task_offset >= group.max_len.max(1) {
            return internal_err!(
                "NetworkCoalesceExec input_task_offset={} >= group.max_len={}",
                input_task_offset,
                group.max_len
            );
        }
        // Padding slot: this consumer task's group is shorter than the widest
        // group, so there is no input task behind this slot.
        if input_task_offset >= group.len {
            return Ok(Box::pin(EmptyRecordBatchStream::new(self.schema())));
        }

        let target_task = group.start_task + input_task_offset;
        let worker_connection = self.worker_connections.get_or_init_worker_connection(
            remote_stage,
            0..partitions_per_task,
            target_task,
            self.producer_head(task_context.task_count),
            &context,
        )?;
        let stream = worker_connection.execute(target_partition)?;
        Ok(Box::pin(RecordBatchStreamAdapter::new(
            self.schema(),
            stream,
        )))
    }

    /// Snapshot of the metrics held by the worker connection pool.
    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.worker_connections.metrics.clone_inner())
    }
}
/// Contiguous range of input tasks assigned to one consumer task.
#[derive(Debug, Clone, Copy)]
struct TaskGroup {
    /// Index of the first input task in the group.
    start_task: usize,
    /// Number of input tasks in this group.
    len: usize,
    /// Size of the largest group across all consumer tasks.
    max_len: usize,
}

/// Computes the group of input tasks that consumer task `task_index` (out of
/// `task_count`) reads from when `input_task_count` input tasks are split into
/// contiguous, near-equal groups.
///
/// The first `input_task_count % task_count` groups hold one extra task. With
/// `task_count == 0` an all-zero group is returned.
fn task_group(input_task_count: usize, task_index: usize, task_count: usize) -> TaskGroup {
    match task_count {
        0 => TaskGroup {
            start_task: 0,
            len: 0,
            max_len: 0,
        },
        consumers => {
            let quotient = input_task_count / consumers;
            let remainder = input_task_count % consumers;

            // Groups before `remainder` are one task longer than the rest.
            let extra_for_this_group = if task_index < remainder { 1 } else { 0 };
            let extra_for_widest_group = if remainder > 0 { 1 } else { 0 };

            TaskGroup {
                // Every earlier group contributes `quotient` tasks, plus one for
                // each earlier group that received an extra task.
                start_task: (task_index * quotient) + task_index.min(remainder),
                len: quotient + extra_for_this_group,
                max_len: quotient + extra_for_widest_group,
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::arrow::datatypes::Schema;
    use datafusion::physical_plan::empty::EmptyExec;

    /// One producer/consumer task-count combination to validate.
    #[derive(Clone, Copy)]
    struct Case {
        name: &'static str,
        input_tasks: usize,
        consumer_tasks: usize,
    }

    /// Reference implementation of the grouping: returns `(start, len)` for each
    /// consumer task, built by accumulating group lengths so it is independent
    /// of the closed-form arithmetic in `task_group`.
    fn expected_groups(input_tasks: usize, consumer_tasks: usize) -> Vec<(usize, usize)> {
        assert!(consumer_tasks > 0, "consumer_tasks must be non-zero");
        let base_tasks_per_group = input_tasks / consumer_tasks;
        let groups_with_extra_task = input_tasks % consumer_tasks;
        let mut groups = Vec::with_capacity(consumer_tasks);
        let mut start_task = 0;
        for task_index in 0..consumer_tasks {
            let len = base_tasks_per_group + usize::from(task_index < groups_with_extra_task);
            groups.push((start_task, len));
            start_task += len;
        }
        groups
    }

    /// Checks the partition count produced by `try_new` and that the groups
    /// computed by `task_group` are contiguous, disjoint, cover every input
    /// task exactly once and leave the expected number of padding slots.
    fn assert_case(case: Case) -> Result<()> {
        let child: Arc<dyn ExecutionPlan> = Arc::new(EmptyExec::new(Arc::new(Schema::empty())));
        let child_partitions = child.properties().partitioning.partition_count();
        let exec = NetworkCoalesceExec::try_new(
            Arc::clone(&child),
            case.input_tasks,
            case.consumer_tasks,
        )?;
        let max_group_size = case.input_tasks.div_ceil(case.consumer_tasks).max(1);
        assert_eq!(
            exec.properties().partitioning.partition_count(),
            child_partitions * max_group_size
        );

        let groups = expected_groups(case.input_tasks, case.consumer_tasks);
        assert_eq!(groups.len(), case.consumer_tasks);

        let mut seen = vec![false; case.input_tasks];
        let mut expected_start = 0;
        let mut padding_slots = 0;
        for (index, (start, len)) in groups.into_iter().enumerate() {
            // Cross-check the production grouping against the reference one, so
            // a regression in `task_group` is actually caught by these tests.
            let actual = task_group(case.input_tasks, index, case.consumer_tasks);
            assert_eq!(
                (actual.start_task, actual.len),
                (start, len),
                "case {} group {} disagrees with task_group",
                case.name,
                index
            );
            assert_eq!(
                actual.max_len,
                case.input_tasks.div_ceil(case.consumer_tasks),
                "case {} group {} reports an unexpected max_len",
                case.name,
                index
            );

            assert_eq!(
                start, expected_start,
                "case {} group {} should be contiguous",
                case.name, index
            );
            assert!(
                start + len <= case.input_tasks,
                "case {} group {} exceeds input task count",
                case.name,
                index
            );
            for (offset, seen_task) in seen.iter_mut().skip(start).take(len).enumerate() {
                let task = start + offset;
                assert!(
                    !*seen_task,
                    "case {} input task {} appears twice",
                    case.name, task
                );
                *seen_task = true;
            }
            expected_start = start + len;
            padding_slots += max_group_size - len;
        }

        assert_eq!(
            expected_start, case.input_tasks,
            "case {} groups should cover all input tasks",
            case.name
        );
        assert!(
            seen.iter().all(|v| *v),
            "case {} missing at least one input task",
            case.name
        );
        let total_slots = case.consumer_tasks * max_group_size;
        let total_padding = total_slots - case.input_tasks;
        assert_eq!(
            padding_slots, total_padding,
            "case {} padding slots mismatch",
            case.name
        );
        Ok(())
    }

    const ONE_TO_MANY_INPUT: usize = 1;
    const ONE_TO_MANY_OUTPUT: usize = 3;
    const MANY_TO_ONE_INPUT: usize = 4;
    const MANY_TO_ONE_OUTPUT: usize = 1;
    const MANY_TO_FEWER_INPUT: usize = 5;
    const MANY_TO_FEWER_OUTPUT: usize = 2;
    const FEWER_TO_MANY_INPUT: usize = 2;
    const FEWER_TO_MANY_OUTPUT: usize = 5;

    #[test]
    fn validates_partition_coverage_one_to_many() -> Result<()> {
        assert_case(Case {
            name: "1_to_n",
            input_tasks: ONE_TO_MANY_INPUT,
            consumer_tasks: ONE_TO_MANY_OUTPUT,
        })
    }

    #[test]
    fn validates_partition_coverage_many_to_one() -> Result<()> {
        assert_case(Case {
            name: "n_to_1",
            input_tasks: MANY_TO_ONE_INPUT,
            consumer_tasks: MANY_TO_ONE_OUTPUT,
        })
    }

    #[test]
    fn validates_partition_coverage_many_to_fewer() -> Result<()> {
        assert_case(Case {
            name: "n_to_m_n_gt_m",
            input_tasks: MANY_TO_FEWER_INPUT,
            consumer_tasks: MANY_TO_FEWER_OUTPUT,
        })
    }

    #[test]
    fn validates_partition_coverage_fewer_to_many() -> Result<()> {
        assert_case(Case {
            name: "m_to_n_n_gt_m",
            input_tasks: FEWER_TO_MANY_INPUT,
            consumer_tasks: FEWER_TO_MANY_OUTPUT,
        })
    }

    #[test]
    fn rejects_zero_consumer_tasks() {
        let child: Arc<dyn ExecutionPlan> = Arc::new(EmptyExec::new(Arc::new(Schema::empty())));
        assert!(NetworkCoalesceExec::try_new(child, 1, 0).is_err());
    }

    #[test]
    fn task_group_with_zero_task_count_is_empty() {
        let group = task_group(3, 0, 0);
        assert_eq!((group.start_task, group.len, group.max_len), (0, 0, 0));
    }
}