use crate::DistributedTaskContext;
use crate::common::task_ctx_with_extension;
use datafusion::arrow::array::RecordBatch;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::common::{internal_err, plan_err};
use datafusion::error::DataFusionError;
use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use datafusion::physical_plan::union::UnionExec;
use datafusion::physical_plan::{
DisplayAs, DisplayFormatType, EmptyRecordBatchStream, ExecutionPlan, ExecutionPlanProperties,
Partitioning, PlanProperties,
};
use futures::{Stream, StreamExt};
use itertools::Itertools;
use std::any::Any;
use std::fmt::Formatter;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::vec;
/// Union-like execution plan node that assigns its children to distributed tasks.
///
/// Each task only executes the children listed for it in `task_idx_map`, and each of those
/// children is executed with its own inner [`DistributedTaskContext`].
#[derive(Debug, Clone)]
pub struct ChildrenIsolatorUnionExec {
/// Plan properties cloned from a `UnionExec` built over the same children, with the
/// partitioning replaced by `UnknownPartitioning(n)`, where `n` is the largest summed
/// child partition count found in any single task.
pub(crate) properties: Arc<PlanProperties>,
/// Metrics set from which the per-partition `BaselineMetrics` are created on `execute`.
pub(crate) metrics: ExecutionPlanMetricsSet,
/// All the children of this node, in the order referenced by the child indexes stored
/// in `task_idx_map`.
pub(crate) children: Vec<Arc<dyn ExecutionPlan>>,
/// Assignment of children to tasks. The outer vector is indexed by task index; each inner
/// vector lists the children executed in that task, together with the task context that
/// is handed to the child when it is executed.
pub(crate) task_idx_map: Vec<
Vec<(
/* child index */ usize,
/* inner distributed task ctx for the isolated child*/ DistributedTaskContext,
)>,
>,
}
impl ChildrenIsolatorUnionExec {
    /// Builds the node from its children, the task count wanted by each child and the number
    /// of tasks available to this node.
    ///
    /// `children` and `children_task_count` are matched by position. The children are spread
    /// over `task_count` tasks by `split_children`, and the resulting node advertises as many
    /// partitions as the task with the largest summed child partition count.
    ///
    /// # Errors
    ///
    /// Returns an internal error when the two inputs have different lengths, and propagates
    /// any error raised by `split_children` or by `UnionExec::try_new`.
    pub(crate) fn from_children_and_task_counts(
        children: impl IntoIterator<Item = Arc<dyn ExecutionPlan>>,
        children_task_count: impl IntoIterator<Item = usize>,
        task_count: usize,
    ) -> Result<Self, DataFusionError> {
        let children: Vec<Arc<dyn ExecutionPlan>> = children.into_iter().collect();
        let task_count_per_children: Vec<usize> = children_task_count.into_iter().collect();
        if children.len() != task_count_per_children.len() {
            return internal_err!(
                "ChildrenIsolatorUnionExec received {} children but a vec of {} positions for those children. This is a bug in the distributed planning logic, please report it",
                children.len(),
                task_count_per_children.len()
            );
        }

        let task_idx_map = split_children(task_count_per_children, task_count)?;

        // A task exposes the partitions of all the children it runs, one after another, so
        // the node needs as many partitions as the most loaded task.
        let widest_task = task_idx_map
            .iter()
            .map(|children_in_task| {
                children_in_task
                    .iter()
                    .map(|(child_idx, _)| {
                        children[*child_idx]
                            .output_partitioning()
                            .partition_count()
                    })
                    .sum::<usize>()
            })
            .max();
        let Some(partition_count) = widest_task else {
            return internal_err!(
                "ChildrenIsolatorUnionExec built an empty task_idx_map. This is a bug in the distributed planning logic, please report it"
            );
        };

        // Start from the properties a plain union would have and override the partitioning.
        let union = UnionExec::try_new(children.clone())?;
        let mut properties = union.properties().as_ref().clone();
        properties.partitioning = Partitioning::UnknownPartitioning(partition_count);

        Ok(Self {
            properties: Arc::new(properties),
            metrics: ExecutionPlanMetricsSet::default(),
            children,
            task_idx_map,
        })
    }
}
impl DisplayAs for ChildrenIsolatorUnionExec {
    /// Renders the task-to-children assignment on a single line.
    ///
    /// Every task is printed as ` t<task>:[...]`, listing its children as `c<child>`. When a
    /// child is split across more than one task, its entry also carries
    /// `(<task index>/<task count>)`. Nothing is written for the tree-render format.
    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
        match t {
            DisplayFormatType::TreeRender => Ok(()),
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "DistributedUnionExec:")?;
                for (task_i, assigned) in self.task_idx_map.iter().enumerate() {
                    write!(f, " t{task_i}:[")?;
                    for (position, (child_idx, inner_ctx)) in assigned.iter().enumerate() {
                        // Separator goes in front of every entry but the first one.
                        if position > 0 {
                            write!(f, ", ")?;
                        }
                        if inner_ctx.task_count <= 1 {
                            write!(f, "c{child_idx}")?;
                        } else {
                            write!(
                                f,
                                "c{child_idx}({}/{})",
                                inner_ctx.task_index, inner_ctx.task_count
                            )?;
                        }
                    }
                    write!(f, "]")?;
                }
                Ok(())
            }
        }
    }
}
impl ExecutionPlan for ChildrenIsolatorUnionExec {
    /// Static name of this plan node.
    fn name(&self) -> &str {
        "ChildrenIsolatorUnionExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Properties computed at construction time.
    fn properties(&self) -> &Arc<PlanProperties> {
        &self.properties
    }

    /// Replaces the children, keeping the task assignment, metrics and properties as they are.
    ///
    /// # Errors
    ///
    /// Returns a plan error if the number of new children differs from the current one, as
    /// the child indexes stored in `task_idx_map` would no longer be valid.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
        if children.len() != self.children.len() {
            return plan_err!(
                "Number of children must match the original plan, have {} but expected {}",
                children.len(),
                self.children.len()
            );
        }
        let mut clone = self.as_ref().clone();
        clone.children = children;
        Ok(Arc::new(clone))
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        self.children.iter().collect()
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    /// Executes `partition` within the task identified by the [`DistributedTaskContext`] found
    /// in `context`.
    ///
    /// The partitions of the children assigned to the current task are laid out one after
    /// another: the partition index is resolved against the first child, then against the
    /// second, and so on. The matching child is executed with a task context carrying its own
    /// inner [`DistributedTaskContext`]. If the index is past the partitions of every child in
    /// this task, an empty stream is returned.
    ///
    /// # Errors
    ///
    /// Returns an internal error if the task index has no entry in `task_idx_map` or if an
    /// entry refers to a child that does not exist, and propagates errors from the child's
    /// own `execute`.
    fn execute(
        &self,
        mut partition: usize,
        context: Arc<TaskContext>,
    ) -> datafusion::common::Result<SendableRecordBatchStream> {
        let d_ctx = DistributedTaskContext::from_ctx(&context);
        // Checked lookup: an unexpected task index is reported as an error, like the other
        // planning invariants in this file, instead of panicking.
        let Some(children_in_task) = self.task_idx_map.get(d_ctx.task_index) else {
            return internal_err!(
                "ChildrenIsolatorUnionExec was executed with task index {} but it only has {} tasks. This is a bug in the distributed planning logic, please report it",
                d_ctx.task_index,
                self.task_idx_map.len()
            );
        };

        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
        let elapsed_compute = baseline_metrics.elapsed_compute().clone();
        // Held until this function returns so that the setup time is accounted for.
        let _timer = elapsed_compute.timer();

        // Entries are cloned lazily: only the ones visited before the match are copied.
        for (child_idx, child_task_ctx) in children_in_task.iter().cloned() {
            let Some(input) = self.children.get(child_idx) else {
                return internal_err!("Could not find child with index {child_idx}");
            };
            let child_partitions = input.output_partitioning().partition_count();
            if partition < child_partitions {
                let context = Arc::new(task_ctx_with_extension(context.as_ref(), child_task_ctx));
                let stream = input.execute(partition, context)?;
                return Ok(Box::pin(ObservedStream::new(
                    stream,
                    baseline_metrics,
                    None,
                )));
            }
            // Not in this child: make the index relative to the next one.
            partition -= child_partitions;
        }

        Ok(Box::pin(EmptyRecordBatchStream::new(self.schema())))
    }
}
/// Record batch stream wrapper that reports every poll to a `BaselineMetrics` and can
/// optionally stop after a maximum number of rows.
pub(crate) struct ObservedStream {
/// The wrapped stream that actually produces the record batches.
inner: SendableRecordBatchStream,
/// Metrics that are fed with the result of each poll.
baseline_metrics: BaselineMetrics,
/// Maximum number of rows to emit, or `None` for no limit.
fetch: Option<usize>,
/// Number of rows emitted so far; only updated when `fetch` is set.
produced: usize,
}
impl ObservedStream {
    /// Wraps `inner`, reporting its polls to `baseline_metrics` and emitting at most `fetch`
    /// rows when a limit is given.
    pub fn new(
        inner: SendableRecordBatchStream,
        baseline_metrics: BaselineMetrics,
        fetch: Option<usize>,
    ) -> Self {
        let produced = 0;
        Self {
            inner,
            baseline_metrics,
            fetch,
            produced,
        }
    }

    /// Applies the row limit to the result of polling the inner stream.
    ///
    /// Without a limit the poll result is returned untouched. Once the limit has been reached
    /// the stream is reported as finished. A batch that would go past the limit is truncated
    /// to the rows that still fit; any other poll result is passed through.
    fn limit_reached(
        &mut self,
        poll: Poll<Option<datafusion::common::Result<RecordBatch>>>,
    ) -> Poll<Option<datafusion::common::Result<RecordBatch>>> {
        let fetch = match self.fetch {
            Some(limit) => limit,
            None => return poll,
        };
        if self.produced >= fetch {
            return Poll::Ready(None);
        }
        match &poll {
            Poll::Ready(Some(Ok(batch))) => {
                let rows = batch.num_rows();
                if self.produced + rows > fetch {
                    // Only part of this batch fits below the limit.
                    let remaining = fetch.saturating_sub(self.produced);
                    let truncated = batch.slice(0, remaining);
                    self.produced += truncated.num_rows();
                    Poll::Ready(Some(Ok(truncated)))
                } else {
                    self.produced += rows;
                    poll
                }
            }
            // Pending, errors and end of stream do not count towards the limit.
            _ => poll,
        }
    }
}
impl RecordBatchStream for ObservedStream {
/// Returns the schema reported by the wrapped inner stream.
fn schema(&self) -> SchemaRef {
self.inner.schema()
}
}
impl Stream for ObservedStream {
    type Item = datafusion::common::Result<RecordBatch>;

    /// Polls the inner stream, applies the row limit when one is configured, and hands the
    /// outcome to the baseline metrics, which return the value that is finally yielded.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        let polled = this.inner.poll_next_unpin(cx);
        let polled = match this.fetch {
            Some(_) => this.limit_reached(polled),
            None => polled,
        };
        this.baseline_metrics.record_poll(polled)
    }
}
/// Distributes the tasks of every child over `task_count_budget` tasks.
///
/// `task_count_per_children[i]` is the number of tasks child `i` would like to run in. The
/// result has one entry per task of the budget; each entry lists the `(child index, task
/// context)` pairs that run in that task, where the task context tells the child which of its
/// own tasks it is (`task_index`) out of how many (`task_count`).
///
/// When the children ask for more tasks than the budget allows, the largest task counts are
/// reduced one by one (the first child wins ties) until the total fits or every child is down
/// to a single task. The remaining child tasks are then assigned in order, giving one extra
/// entry to the first tasks when the division is not exact.
///
/// # Errors
///
/// Returns an internal error when the budget is greater than the sum of the child task
/// counts, or when the budget is zero.
fn split_children(
    mut task_count_per_children: Vec<usize>,
    task_count_budget: usize,
) -> Result<Vec<Vec<(usize, DistributedTaskContext)>>, DataFusionError> {
    let total_children_tasks: usize = task_count_per_children.iter().sum();
    if task_count_budget > total_children_tasks {
        return internal_err!(
            "ChildrenIsolatorUnionExec had a task count {task_count_budget}, which is greater than the sum of child task counts {total_children_tasks}. This is a bug in the distributed planning logic, please report it"
        );
    }
    if task_count_budget == 0 {
        return internal_err!(
            "ChildrenIsolatorUnionExec had a task count {task_count_budget}. This is a bug in the distributed planning logic, please report it"
        );
    }

    // Shave one task at a time from the first child holding the largest task count, for as
    // many tasks as the children are over budget, never going below one task per child.
    for _ in 0..(total_children_tasks - task_count_budget) {
        let Some(&largest) = task_count_per_children.iter().max() else {
            break;
        };
        if largest <= 1 {
            break;
        }
        if let Some(count) = task_count_per_children
            .iter_mut()
            .find(|count| **count == largest)
        {
            *count -= 1;
        }
    }

    let total_child_tasks: usize = task_count_per_children.iter().sum();
    let base_per_task = total_child_tasks / task_count_budget;
    let extra = total_child_tasks % task_count_budget;
    // The first `extra` tasks take one entry more than the rest.
    let capacity_of = |task: usize| base_per_task + usize::from(task < extra);
    let last_task = task_count_budget - 1;

    // Every task of every child, in child order.
    let child_tasks = task_count_per_children
        .iter()
        .enumerate()
        .flat_map(|(child_idx, &task_count)| {
            (0..task_count).map(move |task_index| {
                (
                    child_idx,
                    DistributedTaskContext {
                        task_index,
                        task_count,
                    },
                )
            })
        });

    let mut result = vec![Vec::new(); task_count_budget];
    let mut task_idx = 0;
    let mut assigned = 0;
    for child_task in child_tasks {
        result[task_idx].push(child_task);
        assigned += 1;
        // Move on once the task is full; the last task absorbs whatever is left.
        if assigned >= capacity_of(task_idx) && task_idx < last_task {
            task_idx += 1;
            assigned = 0;
        }
    }
    Ok(result)
}
#[cfg(test)]
mod tests {
use super::*;
// Children that need a single task each: with a budget as large as the number of children
// every child gets its own task, and smaller budgets group several children in one task.
#[test]
fn children_split_all_1_task() -> Result<(), Box<dyn std::error::Error>> {
assert_eq!(
split_children(vec![1, 1, 1], 3)?,
vec![
vec![(0, ctx(0, 1))],
vec![(1, ctx(0, 1))],
vec![(2, ctx(0, 1))]
]
);
assert_eq!(
split_children(vec![1, 1, 1], 2)?,
vec![vec![(0, ctx(0, 1)), (1, ctx(0, 1))], vec![(2, ctx(0, 1))]]
);
assert_eq!(
split_children(vec![1, 1, 1], 1)?,
vec![vec![(0, ctx(0, 1)), (1, ctx(0, 1)), (2, ctx(0, 1))]]
);
Ok(())
}
// Children asking for 1, 2 and 3 tasks: a budget equal to the sum keeps every child task in
// its own task; smaller budgets first reduce the task count of the largest children, and once
// every child is down to one task, children start sharing tasks.
#[test]
fn split_children_different_tasks() -> Result<(), Box<dyn std::error::Error>> {
assert_eq!(
split_children(vec![1, 2, 3], 6)?,
vec![
vec![(0, ctx(0, 1))],
vec![(1, ctx(0, 2))],
vec![(1, ctx(1, 2))],
vec![(2, ctx(0, 3))],
vec![(2, ctx(1, 3))],
vec![(2, ctx(2, 3))]
]
);
assert_eq!(
split_children(vec![1, 2, 3], 5)?,
vec![
vec![(0, ctx(0, 1))],
vec![(1, ctx(0, 2))],
vec![(1, ctx(1, 2))],
vec![(2, ctx(0, 2))],
vec![(2, ctx(1, 2))],
]
);
assert_eq!(
split_children(vec![1, 2, 3], 4)?,
vec![
vec![(0, ctx(0, 1))],
vec![(1, ctx(0, 1))],
vec![(2, ctx(0, 2))],
vec![(2, ctx(1, 2))],
]
);
assert_eq!(
split_children(vec![1, 2, 3], 3)?,
vec![
vec![(0, ctx(0, 1))],
vec![(1, ctx(0, 1))],
vec![(2, ctx(0, 1))],
]
);
assert_eq!(
split_children(vec![1, 2, 3], 2)?,
vec![vec![(0, ctx(0, 1)), (1, ctx(0, 1))], vec![(2, ctx(0, 1))]]
);
assert_eq!(
split_children(vec![1, 2, 3], 1)?,
vec![vec![(0, ctx(0, 1)), (1, ctx(0, 1)), (2, ctx(0, 1))]]
);
Ok(())
}
// Shorthand for building the expected task context of a child.
fn ctx(task_index: usize, task_count: usize) -> DistributedTaskContext {
DistributedTaskContext {
task_index,
task_count,
}
}
}