use crate::DistributedTaskContext;
use crate::common::task_ctx_with_extension;
use datafusion::arrow::array::RecordBatch;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::common::{internal_err, plan_err};
use datafusion::error::DataFusionError;
use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
use datafusion::physical_plan::empty::EmptyExec;
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use datafusion::physical_plan::union::UnionExec;
use datafusion::physical_plan::{
DisplayAs, DisplayFormatType, EmptyRecordBatchStream, ExecutionPlan, ExecutionPlanProperties,
Partitioning, PlanProperties,
};
use futures::{Stream, StreamExt};
use itertools::Itertools;
use std::fmt::Formatter;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::vec;
/// Union of several child plans in which each distributed task executes only the subset of
/// children assigned to it in `task_idx_map`.
///
/// Built with `ChildrenIsolatorUnionExec::from_children_and_weights`. That constructor decides,
/// from one [`ChildWeight`] per child, which task runs which child and how many tasks each
/// child is spread across.
#[derive(Debug, Clone)]
pub struct ChildrenIsolatorUnionExec {
    /// Plan properties taken from a [`UnionExec`] over the same children. The partitioning is
    /// replaced by `UnknownPartitioning` of the largest per-task partition count.
    pub(crate) properties: Arc<PlanProperties>,
    /// Metrics set used by `execute` to create baseline metrics for each output partition.
    pub(crate) metrics: ExecutionPlanMetricsSet,
    /// All children of the union. In a copy produced by `to_task_specialized`, children not
    /// assigned to that task are replaced by an `EmptyExec` with the same schema and
    /// partition count.
    pub(crate) children: Vec<Arc<dyn ExecutionPlan>>,
    /// One weight per child. It is kept so `with_new_children` can recompute the task split.
    pub(crate) child_weights: Vec<ChildWeight>,
    /// Indexed by task index. Each entry lists, in execution order, the children that task
    /// runs, each with the [`DistributedTaskContext`] that child is executed with.
    pub(crate) task_idx_map: Vec<
        Vec<(
            /* child index */ usize,
            /* inner distributed task ctx for the isolated child*/ DistributedTaskContext,
        )>,
    >,
}
/// How strongly a child of `ChildrenIsolatorUnionExec` competes for the task budget.
///
/// `weight` sets the child's proportional share of the tasks. `max`, when present, is a hard
/// upper bound on the number of tasks the child may receive.
#[derive(Debug, Clone, Copy)]
pub(crate) struct ChildWeight {
    pub(crate) weight: f64,
    pub(crate) max: Option<usize>,
}

impl ChildWeight {
    /// A child asking for a share of the budget proportional to `w`, with no upper bound.
    pub fn desired(w: f64) -> Self {
        let uncapped = None;
        Self {
            max: uncapped,
            weight: w,
        }
    }

    /// A child that uses `n` both as its proportional weight and as a hard cap on its tasks.
    pub fn maximum(n: usize) -> Self {
        let as_weight = n as f64;
        Self {
            max: Some(n),
            weight: as_weight,
        }
    }
}
impl ChildrenIsolatorUnionExec {
pub(crate) fn from_children_and_weights(
children: impl IntoIterator<Item = Arc<dyn ExecutionPlan>>,
children_weights: impl IntoIterator<Item = ChildWeight>,
task_count: usize,
) -> Result<Self, DataFusionError> {
let children = children.into_iter().collect_vec();
let weights = children_weights.into_iter().collect_vec();
if children.len() != weights.len() {
return internal_err!(
"ChildrenIsolatorUnionExec received {} children but a vec of {} weights for those children. This is a bug in the distributed planning logic, please report it",
children.len(),
weights.len()
);
}
let task_idx_map = split_children(&weights, task_count)?;
let mut partition_counts = vec![0; task_idx_map.len()];
for (t, children_in_task) in task_idx_map.iter().enumerate() {
for (child_idx, _) in children_in_task {
partition_counts[t] += children[*child_idx].output_partitioning().partition_count();
}
}
let Some(partition_count) = partition_counts.iter().max() else {
return internal_err!(
"ChildrenIsolatorUnionExec built an empty task_idx_map. This is a bug in the distributed planning logic, please report it"
);
};
let mut properties = UnionExec::try_new(children.clone())?
.properties()
.as_ref()
.clone();
properties.partitioning = Partitioning::UnknownPartitioning(*partition_count);
Ok(Self {
properties: Arc::new(properties),
metrics: ExecutionPlanMetricsSet::default(),
children,
child_weights: weights,
task_idx_map,
})
}
pub(crate) fn child_task_counts(&self) -> Vec<usize> {
let mut counts = vec![0; self.children.len()];
for children_in_task in &self.task_idx_map {
for (child_idx, child_task_ctx) in children_in_task {
counts[*child_idx] = counts[*child_idx].max(child_task_ctx.task_count);
}
}
counts
}
pub(crate) fn to_task_specialized(&self, task_i: usize) -> Self {
let mut children_to_keep = vec![];
for (child_i, _) in &self.task_idx_map[task_i] {
children_to_keep.push(*child_i);
}
let new_children = self
.children
.iter()
.enumerate()
.map(
|(child_i, plan)| match children_to_keep.contains(&child_i) {
true => Arc::clone(plan),
false => Arc::new(
EmptyExec::new(plan.schema())
.with_partitions(plan.output_partitioning().partition_count()),
) as Arc<dyn ExecutionPlan>,
},
)
.collect_vec();
Self {
children: new_children,
properties: self.properties.clone(),
metrics: self.metrics.clone(),
child_weights: self.child_weights.clone(),
task_idx_map: self.task_idx_map.clone(),
}
}
}
impl DisplayAs for ChildrenIsolatorUnionExec {
    /// Renders one `tN:[...]` group per task, listing the children that task runs.
    ///
    /// A child spread over several tasks is shown as `cK(index/count)`; otherwise it is shown
    /// as `cK`. Tree rendering emits nothing.
    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
        match t {
            DisplayFormatType::TreeRender => Ok(()),
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                f.write_str("DistributedUnionExec:")?;
                for (task_i, assigned) in self.task_idx_map.iter().enumerate() {
                    write!(f, " t{task_i}:[")?;
                    for (pos, (child_idx, ctx)) in assigned.iter().enumerate() {
                        // A leading separator on every entry but the first is equivalent to
                        // a trailing one on every entry but the last.
                        if pos > 0 {
                            f.write_str(", ")?;
                        }
                        if ctx.task_count <= 1 {
                            write!(f, "c{child_idx}")?;
                        } else {
                            write!(f, "c{child_idx}({}/{})", ctx.task_index, ctx.task_count)?;
                        }
                    }
                    f.write_str("]")?;
                }
                Ok(())
            }
        }
    }
}
impl ExecutionPlan for ChildrenIsolatorUnionExec {
    fn name(&self) -> &str {
        "ChildrenIsolatorUnionExec"
    }

    fn properties(&self) -> &Arc<PlanProperties> {
        &self.properties
    }

    /// Rebuilds the node over `children`. The task split is recomputed from the stored child
    /// weights and the same number of tasks (`task_idx_map.len()`) this node was planned with.
    ///
    /// # Errors
    /// Plan error if the number of children changes. Otherwise propagates errors from
    /// `from_children_and_weights`.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
        if children.len() != self.children.len() {
            return plan_err!(
                "Number of children must match the original plan, have {} but expected {}",
                children.len(),
                self.children.len()
            );
        }
        Ok(Arc::new(Self::from_children_and_weights(
            children,
            self.child_weights.clone(),
            self.task_idx_map.len(),
        )?))
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        self.children.iter().collect()
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    /// Executes output `partition` for the current distributed task.
    ///
    /// A task's output partitions are the concatenation of the partitions of the children
    /// assigned to it, in `task_idx_map` order. `partition` is mapped to a (child, child
    /// partition) pair by subtracting each child's partition count in turn. The selected child
    /// is executed with a task context built by `task_ctx_with_extension` from that child's
    /// own [`DistributedTaskContext`].
    ///
    /// The node advertises the largest partition count of any task. Partitions past this
    /// task's own total therefore yield an empty stream.
    ///
    /// # Errors
    /// Internal error if the current task index has no entry in `task_idx_map`, or if an
    /// assigned child index does not exist. Errors from the child's `execute` are propagated.
    fn execute(
        &self,
        mut partition: usize,
        context: Arc<TaskContext>,
    ) -> datafusion::common::Result<SendableRecordBatchStream> {
        let d_ctx = DistributedTaskContext::from_ctx(&context);
        // Indexing directly would panic if the node runs under a different task count than
        // the one it was planned for; report it as an error instead.
        let Some(children) = self.task_idx_map.get(d_ctx.task_index) else {
            return internal_err!(
                "ChildrenIsolatorUnionExec was executed in task {} but was only planned for {} tasks. This is a bug in the distributed planning logic, please report it",
                d_ctx.task_index,
                self.task_idx_map.len()
            );
        };
        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
        let elapsed_compute = baseline_metrics.elapsed_compute().clone();
        let _timer = elapsed_compute.timer();
        for (child_idx, child_task_ctx) in children.iter().cloned() {
            let Some(input) = self.children.get(child_idx) else {
                return internal_err!("Could not find child with index {child_idx}");
            };
            let child_partitions = input.output_partitioning().partition_count();
            if partition < child_partitions {
                let context = Arc::new(task_ctx_with_extension(context.as_ref(), child_task_ctx));
                let stream = input.execute(partition, context)?;
                return Ok(Box::pin(ObservedStream::new(
                    stream,
                    baseline_metrics,
                    None,
                )));
            }
            partition -= child_partitions;
        }
        Ok(Box::pin(EmptyRecordBatchStream::new(self.schema())))
    }
}
/// Stream wrapper that records [`BaselineMetrics`] for every poll of `inner`.
///
/// When `fetch` is `Some(n)`, the stream ends after `n` rows, slicing the batch that crosses
/// the limit.
pub(crate) struct ObservedStream {
    inner: SendableRecordBatchStream,
    baseline_metrics: BaselineMetrics,
    fetch: Option<usize>,
    produced: usize,
}

impl ObservedStream {
    /// Wraps `inner`. A `fetch` of `None` means "no row limit".
    pub fn new(
        inner: SendableRecordBatchStream,
        baseline_metrics: BaselineMetrics,
        fetch: Option<usize>,
    ) -> Self {
        let produced = 0;
        Self {
            inner,
            baseline_metrics,
            fetch,
            produced,
        }
    }

    /// Applies the `fetch` limit to one poll result and updates the produced-row counter.
    ///
    /// Once the limit has been met, this returns `Ready(None)` without looking at `poll`.
    fn limit_reached(
        &mut self,
        poll: Poll<Option<datafusion::common::Result<RecordBatch>>>,
    ) -> Poll<Option<datafusion::common::Result<RecordBatch>>> {
        let Some(fetch) = self.fetch else {
            return poll;
        };
        if self.produced >= fetch {
            return Poll::Ready(None);
        }
        match poll {
            Poll::Ready(Some(Ok(batch))) => {
                // `produced < fetch` here, so this cannot underflow.
                let remaining = fetch - self.produced;
                let batch = if batch.num_rows() > remaining {
                    batch.slice(0, remaining)
                } else {
                    batch
                };
                self.produced += batch.num_rows();
                Poll::Ready(Some(Ok(batch)))
            }
            other => other,
        }
    }
}

impl RecordBatchStream for ObservedStream {
    fn schema(&self) -> SchemaRef {
        self.inner.schema()
    }
}

impl Stream for ObservedStream {
    type Item = datafusion::common::Result<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let polled = self.inner.poll_next_unpin(cx);
        let polled = if self.fetch.is_some() {
            self.limit_reached(polled)
        } else {
            polled
        };
        self.baseline_metrics.record_poll(polled)
    }
}
/// Splits a budget of `task_count_budget` tasks among the children of a
/// `ChildrenIsolatorUnionExec` according to their [`ChildWeight`]s.
///
/// The returned vector has exactly `task_count_budget` entries, one per task. Entry `t` lists
/// the children task `t` executes. Each child is paired with the [`DistributedTaskContext`] it
/// runs with: its own task index and task count within this union.
///
/// Allocation works in three steps:
/// 1. Each child gets `floor(budget * weight / total_weight)` tasks, clamped to its `max` if
///    it has one. When every weight is 0, the budget is split evenly instead.
/// 2. Tasks left over after flooring and clamping are handed out one at a time. Children are
///    visited in descending order of the fractional part lost to flooring (ties broken by
///    child index), skipping children already at their `max`. If every child is capped,
///    the remaining tasks are left empty.
/// 3. Children that ended up with 0 tasks run whole (`task_count: 1`). They are packed
///    round-robin onto the tasks already populated in step 2, starting from task 0.
///
/// # Errors
/// - Internal error if `task_count_budget` is 0 or `children` is empty.
/// - Plan error if a child has `max == Some(0)`, or a negative or non-finite weight.
fn split_children(
    children: &[ChildWeight],
    task_count_budget: usize,
) -> Result<
    Vec<
        Vec<(
            /* Child index */ usize,
            /* Distributed task ctx for the child */ DistributedTaskContext,
        )>,
    >,
    DataFusionError,
> {
    if task_count_budget == 0 {
        return internal_err!(
            "ChildrenIsolatorUnionExec had a task count {task_count_budget}. This is a bug in the distributed planning logic, please report it"
        );
    }
    if children.is_empty() {
        return internal_err!(
            "ChildrenIsolatorUnionExec built with no children. This is a bug in the distributed planning logic, please report it"
        );
    }
    for (i, weight) in children.iter().enumerate() {
        if weight.max == Some(0) {
            return plan_err!(
                "ChildrenIsolatorUnionExec child {i} has a max task count of 0, which is invalid"
            );
        }
        // Checked before `is_finite`, so negative infinity is reported as negative.
        if weight.weight < 0.0 {
            return plan_err!(
                "ChildrenIsolatorUnionExec child {i} has a negative desired weight of {}, which is invalid.",
                weight.weight
            );
        }
        if !weight.weight.is_finite() {
            return plan_err!(
                "ChildrenIsolatorUnionExec child {i} has a non-finite desired weight of {}, which is invalid.",
                weight.weight
            );
        }
    }

    let child_count = children.len();
    let total_weight: f64 = children.iter().map(|w| w.weight).sum();

    // Step 1: ideal fractional share of the budget per child, then floor and clamp to `max`.
    let unrounded_child_task_counts: Vec<f64> = if total_weight > 0.0 {
        children
            .iter()
            .map(|w| task_count_budget as f64 * w.weight / total_weight)
            .collect()
    } else {
        // Every weight is zero: split the budget evenly.
        vec![task_count_budget as f64 / child_count as f64; child_count]
    };
    let mut child_task_counts = unrounded_child_task_counts
        .iter()
        .map(|x| x.floor() as usize)
        .collect::<Vec<_>>();
    for (task_count, child_weight) in child_task_counts.iter_mut().zip(children.iter()) {
        if let Some(max) = child_weight.max {
            *task_count = (*task_count).min(max);
        }
    }

    // Step 2: hand out the leftover tasks, largest lost fraction first, respecting caps.
    // The floors sum to at most the budget, so this never over-allocates.
    let allocated_task_counts: usize = child_task_counts.iter().sum();
    let mut unallocated_task_counts = task_count_budget.saturating_sub(allocated_task_counts);
    if unallocated_task_counts > 0 {
        let mut order: Vec<usize> = (0..child_count).collect();
        order.sort_by(|&a, &b| {
            let ra = unrounded_child_task_counts[a] - unrounded_child_task_counts[a].floor();
            let rb = unrounded_child_task_counts[b] - unrounded_child_task_counts[b].floor();
            rb.partial_cmp(&ra)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then(a.cmp(&b))
        });
        while unallocated_task_counts > 0 {
            let mut made_progress = false;
            for &idx in &order {
                if unallocated_task_counts == 0 {
                    break;
                }
                if let Some(max) = children[idx].max
                    && child_task_counts[idx] >= max
                {
                    continue;
                }
                child_task_counts[idx] += 1;
                unallocated_task_counts -= 1;
                made_progress = true;
            }
            // Every child is at its cap: leave the remaining tasks empty.
            if !made_progress {
                break;
            }
        }
    }

    // Step 3: lay the children out over consecutive tasks...
    let mut result = vec![vec![]; task_count_budget];
    let mut task_idx = 0;
    for (child_idx, &task_count) in child_task_counts.iter().enumerate() {
        for task_i in 0..task_count {
            result[task_idx].push((
                child_idx,
                DistributedTaskContext {
                    task_index: task_i,
                    task_count,
                },
            ));
            task_idx += 1;
        }
    }
    // ...then spread the children that received no task of their own round-robin over the
    // populated tasks, so each of them still runs exactly once.
    if task_idx > 0 {
        let mut zero_alloc_i = 0usize;
        for (child_idx, &task_count) in child_task_counts.iter().enumerate() {
            if task_count != 0 {
                continue;
            }
            let slot = zero_alloc_i % task_idx;
            result[slot].push((
                child_idx,
                DistributedTaskContext {
                    task_index: 0,
                    task_count: 1,
                },
            ));
            zero_alloc_i += 1;
        }
    }
    Ok(result)
}
#[cfg(test)]
mod tests {
    // Unit tests for `split_children`: task allocation for various weights and budgets, plus
    // the validation errors it reports.
    use super::*;

    #[test]
    fn children_split_all_1_task() -> Result<(), Box<dyn std::error::Error>> {
        assert_eq!(
            split_children(&[des(1.0), des(1.0), des(1.0)], 3)?,
            vec![
                vec![(0, ctx(0, 1))],
                vec![(1, ctx(0, 1))],
                vec![(2, ctx(0, 1))]
            ]
        );
        assert_eq!(
            split_children(&[des(1.0), des(1.0), des(1.0)], 2)?,
            vec![vec![(0, ctx(0, 1)), (2, ctx(0, 1))], vec![(1, ctx(0, 1))]]
        );
        assert_eq!(
            split_children(&[des(1.0), des(1.0), des(1.0)], 1)?,
            vec![vec![(0, ctx(0, 1)), (1, ctx(0, 1)), (2, ctx(0, 1))]]
        );
        Ok(())
    }

    #[test]
    fn split_children_different_tasks() -> Result<(), Box<dyn std::error::Error>> {
        assert_eq!(
            split_children(&[des(1.0), des(2.0), des(3.0)], 6)?,
            vec![
                vec![(0, ctx(0, 1))],
                vec![(1, ctx(0, 2))],
                vec![(1, ctx(1, 2))],
                vec![(2, ctx(0, 3))],
                vec![(2, ctx(1, 3))],
                vec![(2, ctx(2, 3))]
            ]
        );
        assert_eq!(
            split_children(&[des(1.0), des(2.0), des(3.0)], 5)?,
            vec![
                vec![(0, ctx(0, 1))],
                vec![(1, ctx(0, 2))],
                vec![(1, ctx(1, 2))],
                vec![(2, ctx(0, 2))],
                vec![(2, ctx(1, 2))],
            ]
        );
        assert_eq!(
            split_children(&[des(1.0), des(2.0), des(3.0)], 4)?,
            vec![
                vec![(0, ctx(0, 1))],
                vec![(1, ctx(0, 1))],
                vec![(2, ctx(0, 2))],
                vec![(2, ctx(1, 2))],
            ]
        );
        assert_eq!(
            split_children(&[des(1.0), des(2.0), des(3.0)], 3)?,
            vec![
                vec![(0, ctx(0, 1))],
                vec![(1, ctx(0, 1))],
                vec![(2, ctx(0, 1))],
            ]
        );
        assert_eq!(
            split_children(&[des(1.0), des(2.0), des(3.0)], 2)?,
            vec![vec![(1, ctx(0, 1)), (0, ctx(0, 1))], vec![(2, ctx(0, 1))]]
        );
        assert_eq!(
            split_children(&[des(1.0), des(2.0), des(3.0)], 1)?,
            vec![vec![(2, ctx(0, 1)), (0, ctx(0, 1)), (1, ctx(0, 1))]]
        );
        Ok(())
    }

    #[test]
    fn split_children_budget_exceeds_children_weight_sum() -> Result<(), Box<dyn std::error::Error>>
    {
        assert_eq!(
            split_children(&[des(1.0), des(1.0)], 3)?,
            vec![
                vec![(0, ctx(0, 2))],
                vec![(0, ctx(1, 2))],
                vec![(1, ctx(0, 1))],
            ]
        );
        assert_eq!(
            split_children(&[des(1.0), des(1.0)], 5)?,
            vec![
                vec![(0, ctx(0, 3))],
                vec![(0, ctx(1, 3))],
                vec![(0, ctx(2, 3))],
                vec![(1, ctx(0, 2))],
                vec![(1, ctx(1, 2))],
            ]
        );
        assert_eq!(
            split_children(&[des(1.0), des(2.0)], 4)?,
            vec![
                vec![(0, ctx(0, 1))],
                vec![(1, ctx(0, 3))],
                vec![(1, ctx(1, 3))],
                vec![(1, ctx(2, 3))],
            ]
        );
        Ok(())
    }

    // Children whose share floors to zero are spread round-robin over the populated tasks,
    // starting from the first task.
    #[test]
    fn split_children_spreads_zero_share_children_round_robin()
    -> Result<(), Box<dyn std::error::Error>> {
        assert_eq!(
            split_children(&[des(10.0), des(1.0), des(1.0)], 3)?,
            vec![
                vec![(0, ctx(0, 3)), (1, ctx(0, 1))],
                vec![(0, ctx(1, 3)), (2, ctx(0, 1))],
                vec![(0, ctx(2, 3))],
            ]
        );
        Ok(())
    }

    #[test]
    fn split_children_respects_maximum_caps() -> Result<(), Box<dyn std::error::Error>> {
        assert_eq!(
            split_children(&[max(1), max(1)], 3)?,
            vec![vec![(0, ctx(0, 1))], vec![(1, ctx(0, 1))], vec![]]
        );
        assert_eq!(
            split_children(&[max(1), des(1.0)], 3)?,
            vec![
                vec![(0, ctx(0, 1))],
                vec![(1, ctx(0, 2))],
                vec![(1, ctx(1, 2))],
            ]
        );
        assert_eq!(
            split_children(&[max(2), des(1.0), des(1.0)], 6)?,
            vec![
                vec![(0, ctx(0, 2))],
                vec![(0, ctx(1, 2))],
                vec![(1, ctx(0, 2))],
                vec![(1, ctx(1, 2))],
                vec![(2, ctx(0, 2))],
                vec![(2, ctx(1, 2))],
            ]
        );
        assert_eq!(
            split_children(&[max(2), max(1)], 3)?,
            vec![
                vec![(0, ctx(0, 2))],
                vec![(0, ctx(1, 2))],
                vec![(1, ctx(0, 1))],
            ]
        );
        Ok(())
    }

    // Budget that a capped child cannot absorb flows to the remaining children, even ones
    // with zero weight.
    #[test]
    fn split_children_redistributes_capped_budget_to_uncapped_children()
    -> Result<(), Box<dyn std::error::Error>> {
        assert_eq!(
            split_children(&[max(1), des(0.0)], 3)?,
            vec![
                vec![(0, ctx(0, 1))],
                vec![(1, ctx(0, 2))],
                vec![(1, ctx(1, 2))],
            ]
        );
        Ok(())
    }

    #[test]
    fn split_children_all_zero_weights_splits_evenly() -> Result<(), Box<dyn std::error::Error>> {
        assert_eq!(
            split_children(&[des(0.0), des(0.0), des(0.0)], 3)?,
            vec![
                vec![(0, ctx(0, 1))],
                vec![(1, ctx(0, 1))],
                vec![(2, ctx(0, 1))],
            ]
        );
        Ok(())
    }

    #[test]
    fn split_children_rejects_zero_task_budget() {
        let err = split_children(&[des(1.0)], 0).unwrap_err();
        assert!(
            err.to_string().contains("task count 0"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn split_children_rejects_empty_children() {
        let err = split_children(&[], 2).unwrap_err();
        assert!(
            err.to_string().contains("no children"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn split_children_rejects_negative_weight() {
        let err = split_children(&[des(1.0), des(-1.0), des(1.0)], 3).unwrap_err();
        assert!(
            err.to_string().contains("negative"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn split_children_rejects_nan_weight() {
        let err = split_children(&[des(f64::NAN), des(1.0)], 2).unwrap_err();
        assert!(
            err.to_string().contains("non-finite"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn split_children_rejects_infinite_weight() {
        let err = split_children(&[des(1.0), des(f64::INFINITY)], 2).unwrap_err();
        assert!(
            err.to_string().contains("non-finite"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn split_children_rejects_zero_max() {
        let err = split_children(
            &[
                des(1.0),
                ChildWeight {
                    weight: 1.0,
                    max: Some(0),
                },
                des(1.0),
            ],
            3,
        )
        .unwrap_err();
        assert!(
            err.to_string().contains("max task count of 0"),
            "unexpected error: {err}"
        );
    }

    /// Shorthand for a child's expected [`DistributedTaskContext`].
    fn ctx(task_index: usize, task_count: usize) -> DistributedTaskContext {
        DistributedTaskContext {
            task_index,
            task_count,
        }
    }

    /// Shorthand for [`ChildWeight::desired`].
    fn des(w: f64) -> ChildWeight {
        ChildWeight::desired(w)
    }

    /// Shorthand for [`ChildWeight::maximum`].
    fn max(n: usize) -> ChildWeight {
        ChildWeight::maximum(n)
    }
}