use crate::TaskCountAnnotation::{Desired, Maximum};
use crate::execution_plans::ChildrenIsolatorUnionExec;
use crate::{BroadcastExec, DistributedConfig, TaskCountAnnotation, TaskEstimator};
use datafusion::common::{DataFusionError, plan_datafusion_err};
use datafusion::config::ConfigOptions;
use datafusion::physical_expr::Partitioning;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::execution_plan::CardinalityEffect;
use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
use datafusion::physical_plan::repartition::RepartitionExec;
use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
use datafusion::physical_plan::union::UnionExec;
use std::fmt::{Debug, Formatter};
use std::sync::Arc;
pub(super) enum PlanOrNetworkBoundary {
Plan(Arc<dyn ExecutionPlan>),
Shuffle,
Coalesce,
Broadcast,
}
impl Debug for PlanOrNetworkBoundary {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Self::Plan(plan) => write!(f, "{}", plan.name()),
Self::Shuffle => write!(f, "[NetworkBoundary] Shuffle"),
Self::Coalesce => write!(f, "[NetworkBoundary] Coalesce"),
Self::Broadcast => write!(f, "[NetworkBoundary] Broadcast"),
}
}
}
impl PlanOrNetworkBoundary {
fn is_network_boundary(&self) -> bool {
matches!(self, Self::Shuffle | Self::Coalesce | Self::Broadcast)
}
}
pub(super) struct AnnotatedPlan {
pub(super) plan_or_nb: PlanOrNetworkBoundary,
pub(super) children: Vec<AnnotatedPlan>,
pub(super) task_count: TaskCountAnnotation,
}
impl Debug for AnnotatedPlan {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
fn fmt_dbg(f: &mut Formatter<'_>, plan: &AnnotatedPlan, depth: usize) -> std::fmt::Result {
write!(
f,
"{}{:?}: task_count={:?}",
" ".repeat(depth * 2),
plan.plan_or_nb,
plan.task_count
)?;
writeln!(f)?;
for child in plan.children.iter() {
fmt_dbg(f, child, depth + 1)?;
}
Ok(())
}
fmt_dbg(f, self, 0)
}
}
pub(super) fn annotate_plan(
plan: Arc<dyn ExecutionPlan>,
cfg: &ConfigOptions,
) -> Result<AnnotatedPlan, DataFusionError> {
_annotate_plan(plan, None, cfg, true)
}
fn _annotate_plan(
plan: Arc<dyn ExecutionPlan>,
parent: Option<&Arc<dyn ExecutionPlan>>,
cfg: &ConfigOptions,
root: bool,
) -> Result<AnnotatedPlan, DataFusionError> {
let d_cfg = DistributedConfig::from_config_options(cfg)?;
let broadcast_joins = d_cfg.broadcast_joins;
let estimator = &d_cfg.__private_task_estimator;
let max_tasks = match d_cfg.max_tasks_per_stage {
0 => d_cfg.__private_worker_resolver.0.get_urls()?.len().max(1),
v => v,
};
let annotated_children = plan
.children()
.iter()
.map(|child| _annotate_plan(Arc::clone(child), Some(&plan), cfg, false))
.collect::<Result<Vec<_>, _>>()?;
if plan.children().is_empty() {
return if let Some(estimate) = estimator.task_estimation(&plan, cfg) {
Ok(AnnotatedPlan {
plan_or_nb: PlanOrNetworkBoundary::Plan(plan),
children: Vec::new(),
task_count: estimate.task_count.limit(max_tasks),
})
} else {
Ok(AnnotatedPlan {
plan_or_nb: PlanOrNetworkBoundary::Plan(plan),
children: Vec::new(),
task_count: Maximum(1),
})
};
}
let mut task_count = estimator
.task_estimation(&plan, cfg)
.map_or(Desired(1), |v| v.task_count);
if d_cfg.children_isolator_unions && plan.as_any().is::<UnionExec>() {
let mut count = 0;
for annotated_child in annotated_children.iter() {
count += annotated_child.task_count.as_usize();
}
task_count = Desired(count);
} else if let Some(node) = plan.as_any().downcast_ref::<HashJoinExec>()
&& node.mode == PartitionMode::CollectLeft
&& !broadcast_joins
{
task_count = Maximum(1);
} else {
for annotated_child in annotated_children.iter() {
task_count = match (task_count, &annotated_child.task_count) {
(Desired(desired), Desired(child)) => Desired(desired.max(*child)),
(Maximum(max), Desired(_)) => Maximum(max),
(Desired(_), Maximum(max)) => Maximum(*max),
(Maximum(max_1), Maximum(max_2)) => Maximum(max_1.min(*max_2)),
};
}
}
task_count = task_count.limit(max_tasks);
let mut annotation = AnnotatedPlan {
plan_or_nb: PlanOrNetworkBoundary::Plan(Arc::clone(&plan)),
children: annotated_children,
task_count: task_count.clone(),
};
if let Some(r_exec) = plan.as_any().downcast_ref::<RepartitionExec>() {
if matches!(r_exec.partitioning(), Partitioning::Hash(_, _)) {
annotation = AnnotatedPlan {
plan_or_nb: PlanOrNetworkBoundary::Shuffle,
children: vec![annotation],
task_count,
};
}
} else if let Some(parent) = parent
&& !plan.children().is_empty()
&& (parent.as_any().is::<CoalescePartitionsExec>()
|| parent.as_any().is::<SortPreservingMergeExec>())
{
if plan.as_any().is::<BroadcastExec>() {
annotation = AnnotatedPlan {
plan_or_nb: PlanOrNetworkBoundary::Broadcast,
children: vec![annotation],
task_count,
};
} else {
annotation = AnnotatedPlan {
plan_or_nb: PlanOrNetworkBoundary::Coalesce,
children: vec![annotation],
task_count,
};
}
}
fn propagate_task_count(
annotation: &mut AnnotatedPlan,
task_count: &TaskCountAnnotation,
d_cfg: &DistributedConfig,
) -> Result<(), DataFusionError> {
annotation.task_count = task_count.clone();
let plan = match &annotation.plan_or_nb {
PlanOrNetworkBoundary::Plan(plan) => plan,
PlanOrNetworkBoundary::Broadcast => {
if let Maximum(max) = task_count {
for child in annotation.children.iter_mut() {
let child_task_count = child.task_count.clone().limit(*max);
propagate_task_count(child, &child_task_count, d_cfg)?;
}
}
return Ok(());
}
PlanOrNetworkBoundary::Shuffle => return Ok(()),
PlanOrNetworkBoundary::Coalesce => return Ok(()),
};
if d_cfg.children_isolator_unions && plan.as_any().is::<UnionExec>() {
let c_i_union = ChildrenIsolatorUnionExec::from_children_and_task_counts(
plan.children().into_iter().cloned(),
annotation.children.iter().map(|v| v.task_count.as_usize()),
task_count.as_usize(),
)?;
for children_and_tasks in c_i_union.task_idx_map.iter() {
for (child_i, task_ctx) in children_and_tasks {
if let Some(child) = annotation.children.get_mut(*child_i) {
propagate_task_count(child, &Maximum(task_ctx.task_count), d_cfg)?
};
}
}
annotation.plan_or_nb = PlanOrNetworkBoundary::Plan(Arc::new(c_i_union));
} else {
for child in &mut annotation.children {
propagate_task_count(child, task_count, d_cfg)?;
}
}
Ok(())
}
if annotation.plan_or_nb.is_network_boundary() {
for annotated_child in annotation.children.iter_mut() {
propagate_task_count(annotated_child, &annotation.task_count, d_cfg)?;
}
if matches!(annotation.plan_or_nb, PlanOrNetworkBoundary::Coalesce) {
annotation.task_count = Maximum(1);
return Ok(annotation);
}
fn calculate_scale_factor(annotation: &AnnotatedPlan, f: f64) -> f64 {
let PlanOrNetworkBoundary::Plan(plan) = &annotation.plan_or_nb else {
return 1.0;
};
let mut sf = None;
for plan in &annotation.children {
sf = match sf {
None => Some(calculate_scale_factor(plan, f)),
Some(sf) => Some(sf.max(calculate_scale_factor(plan, f))),
}
}
let sf = sf.unwrap_or(1.0);
match plan.cardinality_effect() {
CardinalityEffect::LowerEqual => sf / f,
CardinalityEffect::GreaterEqual => sf * f,
_ => sf,
}
}
let sf = calculate_scale_factor(
annotation.children.first().ok_or_else(|| {
plan_datafusion_err!("missing child in a plan annotated with a network boundary")
})?,
d_cfg.cardinality_task_count_factor,
);
let prev_task_count = annotation.task_count.as_usize() as f64;
annotation.task_count = Desired((prev_task_count * sf).ceil() as usize);
Ok(annotation)
} else if root {
let task_count = annotation.task_count.clone();
propagate_task_count(&mut annotation, &task_count, d_cfg)?;
Ok(annotation)
} else {
Ok(annotation)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::distributed_planner::insert_broadcast::insert_broadcast_execs;
use crate::test_utils::plans::{
BuildSideOneTaskEstimator, TestPlanOptions, base_session_builder, context_with_query,
sql_to_physical_plan,
};
use crate::{DistributedExt, TaskEstimation, TaskEstimator, assert_snapshot};
use datafusion::config::ConfigOptions;
use datafusion::execution::SessionStateBuilder;
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::filter::FilterExec;
#[tokio::test]
async fn test_select_all() {
let query = r#"
SELECT * FROM weather
"#;
let annotated = sql_to_annotated(query).await;
assert_snapshot!(annotated, @"DataSourceExec: task_count=Desired(3)")
}
#[tokio::test]
async fn test_aggregation() {
let query = r#"
SELECT count(*), "RainToday" FROM weather GROUP BY "RainToday" ORDER BY count(*)
"#;
let annotated = sql_to_annotated(query).await;
assert_snapshot!(annotated, @r"
ProjectionExec: task_count=Maximum(1)
SortPreservingMergeExec: task_count=Maximum(1)
[NetworkBoundary] Coalesce: task_count=Maximum(1)
SortExec: task_count=Desired(2)
ProjectionExec: task_count=Desired(2)
AggregateExec: task_count=Desired(2)
[NetworkBoundary] Shuffle: task_count=Desired(2)
RepartitionExec: task_count=Desired(3)
AggregateExec: task_count=Desired(3)
DataSourceExec: task_count=Desired(3)
")
}
#[tokio::test]
async fn test_left_join() {
let query = r#"
SELECT a."MinTemp", b."MaxTemp" FROM weather a LEFT JOIN weather b ON a."RainToday" = b."RainToday"
"#;
let annotated = sql_to_annotated(query).await;
assert_snapshot!(annotated, @r"
HashJoinExec: task_count=Maximum(1)
CoalescePartitionsExec: task_count=Maximum(1)
DataSourceExec: task_count=Maximum(1)
DataSourceExec: task_count=Maximum(1)
")
}
#[tokio::test]
async fn test_left_join_distributed() {
let query = r#"
WITH a AS (
SELECT
AVG("MinTemp") as "MinTemp",
"RainTomorrow"
FROM weather
WHERE "RainToday" = 'yes'
GROUP BY "RainTomorrow"
), b AS (
SELECT
AVG("MaxTemp") as "MaxTemp",
"RainTomorrow"
FROM weather
WHERE "RainToday" = 'no'
GROUP BY "RainTomorrow"
)
SELECT
a."MinTemp",
b."MaxTemp"
FROM a
LEFT JOIN b
ON a."RainTomorrow" = b."RainTomorrow"
"#;
let annotated = sql_to_annotated(query).await;
assert_snapshot!(annotated, @r"
HashJoinExec: task_count=Maximum(1)
CoalescePartitionsExec: task_count=Maximum(1)
[NetworkBoundary] Coalesce: task_count=Maximum(1)
ProjectionExec: task_count=Desired(2)
AggregateExec: task_count=Desired(2)
[NetworkBoundary] Shuffle: task_count=Desired(2)
RepartitionExec: task_count=Desired(3)
AggregateExec: task_count=Desired(3)
FilterExec: task_count=Desired(3)
RepartitionExec: task_count=Desired(3)
DataSourceExec: task_count=Desired(3)
ProjectionExec: task_count=Maximum(1)
AggregateExec: task_count=Maximum(1)
[NetworkBoundary] Shuffle: task_count=Maximum(1)
RepartitionExec: task_count=Desired(3)
AggregateExec: task_count=Desired(3)
FilterExec: task_count=Desired(3)
RepartitionExec: task_count=Desired(3)
DataSourceExec: task_count=Desired(3)
")
}
#[tokio::test]
async fn test_inner_join() {
let query = r#"
SELECT a."MinTemp", b."MaxTemp" FROM weather a INNER JOIN weather b ON a."RainToday" = b."RainToday"
"#;
let annotated = sql_to_annotated(query).await;
assert_snapshot!(annotated, @r"
HashJoinExec: task_count=Maximum(1)
CoalescePartitionsExec: task_count=Maximum(1)
DataSourceExec: task_count=Maximum(1)
DataSourceExec: task_count=Maximum(1)
")
}
#[tokio::test]
async fn test_distinct() {
let query = r#"
SELECT DISTINCT "RainToday" FROM weather
"#;
let annotated = sql_to_annotated(query).await;
assert_snapshot!(annotated, @r"
AggregateExec: task_count=Desired(2)
[NetworkBoundary] Shuffle: task_count=Desired(2)
RepartitionExec: task_count=Desired(3)
AggregateExec: task_count=Desired(3)
DataSourceExec: task_count=Desired(3)
")
}
#[tokio::test]
async fn test_union_all() {
let query = r#"
SELECT "MinTemp" FROM weather WHERE "RainToday" = 'yes'
UNION ALL
SELECT "MaxTemp" FROM weather WHERE "RainToday" = 'no'
"#;
let annotated = sql_to_annotated(query).await;
assert_snapshot!(annotated, @r"
ChildrenIsolatorUnionExec: task_count=Desired(4)
FilterExec: task_count=Maximum(2)
RepartitionExec: task_count=Maximum(2)
DataSourceExec: task_count=Maximum(2)
ProjectionExec: task_count=Maximum(2)
FilterExec: task_count=Maximum(2)
RepartitionExec: task_count=Maximum(2)
DataSourceExec: task_count=Maximum(2)
")
}
#[tokio::test]
async fn test_subquery() {
let query = r#"
SELECT * FROM (
SELECT "MinTemp", "MaxTemp" FROM weather WHERE "RainToday" = 'yes'
) AS subquery WHERE "MinTemp" > 5
"#;
let annotated = sql_to_annotated(query).await;
assert_snapshot!(annotated, @r"
FilterExec: task_count=Desired(3)
RepartitionExec: task_count=Desired(3)
DataSourceExec: task_count=Desired(3)
")
}
#[tokio::test]
async fn test_window_function() {
let query = r#"
SELECT "MinTemp", ROW_NUMBER() OVER (PARTITION BY "RainToday" ORDER BY "MinTemp") as rn
FROM weather
"#;
let annotated = sql_to_annotated(query).await;
assert_snapshot!(annotated, @r"
ProjectionExec: task_count=Desired(3)
BoundedWindowAggExec: task_count=Desired(3)
SortExec: task_count=Desired(3)
[NetworkBoundary] Shuffle: task_count=Desired(3)
RepartitionExec: task_count=Desired(3)
DataSourceExec: task_count=Desired(3)
")
}
#[tokio::test]
async fn test_children_isolator_union() {
let query = r#"
SET distributed.children_isolator_unions = true;
SET distributed.files_per_task = 1;
SELECT "MinTemp" FROM weather WHERE "RainToday" = 'yes'
UNION ALL
SELECT "MaxTemp" FROM weather WHERE "RainToday" = 'no'
UNION ALL
SELECT "Rainfall" FROM weather WHERE "RainTomorrow" = 'yes'
"#;
let annotated = sql_to_annotated(query).await;
assert_snapshot!(annotated, @r"
ChildrenIsolatorUnionExec: task_count=Desired(4)
FilterExec: task_count=Maximum(1)
RepartitionExec: task_count=Maximum(1)
DataSourceExec: task_count=Maximum(1)
ProjectionExec: task_count=Maximum(1)
FilterExec: task_count=Maximum(1)
RepartitionExec: task_count=Maximum(1)
DataSourceExec: task_count=Maximum(1)
ProjectionExec: task_count=Maximum(2)
FilterExec: task_count=Maximum(2)
RepartitionExec: task_count=Maximum(2)
DataSourceExec: task_count=Maximum(2)
")
}
#[tokio::test]
async fn test_intermediate_task_estimator() {
let query = r#"
SELECT DISTINCT "RainToday" FROM weather
"#;
let annotated = sql_to_annotated_with_estimator(query, |_: &RepartitionExec| {
Some(TaskEstimation::maximum(1))
})
.await;
assert_snapshot!(annotated, @r"
AggregateExec: task_count=Desired(1)
[NetworkBoundary] Shuffle: task_count=Desired(1)
RepartitionExec: task_count=Maximum(1)
AggregateExec: task_count=Maximum(1)
DataSourceExec: task_count=Maximum(1)
")
}
#[tokio::test]
async fn test_union_all_limited_by_intermediate_estimator() {
let query = r#"
SELECT "MinTemp" FROM weather WHERE "RainToday" = 'yes'
UNION ALL
SELECT "MaxTemp" FROM weather WHERE "RainToday" = 'no'
"#;
let annotated = sql_to_annotated_with_estimator(query, |_: &FilterExec| {
Some(TaskEstimation::maximum(1))
})
.await;
assert_snapshot!(annotated, @r"
ChildrenIsolatorUnionExec: task_count=Desired(2)
FilterExec: task_count=Maximum(1)
RepartitionExec: task_count=Maximum(1)
DataSourceExec: task_count=Maximum(1)
ProjectionExec: task_count=Maximum(1)
FilterExec: task_count=Maximum(1)
RepartitionExec: task_count=Maximum(1)
DataSourceExec: task_count=Maximum(1)
")
}
#[tokio::test]
async fn test_broadcast_join_annotation() {
let query = r#"
SELECT a."MinTemp", b."MaxTemp"
FROM weather a INNER JOIN weather b
ON a."RainToday" = b."RainToday"
"#;
let annotated = sql_to_annotated_broadcast(query, 4, 4, true).await;
assert_snapshot!(annotated, @r"
HashJoinExec: task_count=Desired(3)
CoalescePartitionsExec: task_count=Desired(3)
[NetworkBoundary] Broadcast: task_count=Desired(3)
BroadcastExec: task_count=Desired(3)
DataSourceExec: task_count=Desired(3)
DataSourceExec: task_count=Desired(3)
")
}
#[tokio::test]
async fn test_broadcast_datasource_as_build_child() {
let query = r#"
SELECT a."MinTemp", b."MaxTemp"
FROM weather a INNER JOIN weather b
ON a."RainToday" = b."RainToday"
"#;
let physical_plan = sql_to_physical_plan(query, 1, 4).await;
assert_snapshot!(physical_plan, @r"
HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2]
DataSourceExec: file_groups={1 group: [[/testdata/weather/result-000000.parquet, /testdata/weather/result-000001.parquet, /testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet
DataSourceExec: file_groups={1 group: [[/testdata/weather/result-000000.parquet, /testdata/weather/result-000001.parquet, /testdata/weather/result-000002.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet, predicate=DynamicFilter [ empty ]
");
let annotated = sql_to_annotated_broadcast(query, 1, 4, true).await;
assert!(annotated.contains("Broadcast"));
assert_snapshot!(annotated, @r"
HashJoinExec: task_count=Desired(3)
CoalescePartitionsExec: task_count=Desired(3)
[NetworkBoundary] Broadcast: task_count=Desired(3)
BroadcastExec: task_count=Desired(3)
DataSourceExec: task_count=Desired(3)
DataSourceExec: task_count=Desired(3)
");
}
#[tokio::test]
async fn test_broadcast_one_to_many() {
let query = r#"
SELECT a."MinTemp", b."MaxTemp"
FROM weather a INNER JOIN weather b
ON a."RainToday" = b."RainToday"
"#;
let annotated =
sql_to_annotated_broadcast_with_estimator(query, 3, BuildSideOneTaskEstimator).await;
assert_snapshot!(annotated, @r"
HashJoinExec: task_count=Desired(3)
CoalescePartitionsExec: task_count=Desired(3)
[NetworkBoundary] Broadcast: task_count=Desired(3)
BroadcastExec: task_count=Maximum(1)
DataSourceExec: task_count=Maximum(1)
DataSourceExec: task_count=Desired(3)
");
}
#[tokio::test]
async fn test_broadcast_build_coalesce_caps_join_stage() {
let query = r#"
SELECT a."MinTemp", b."MaxTemp"
FROM weather a INNER JOIN weather b
ON a."RainToday" = b."RainToday"
"#;
let annotated =
sql_to_annotated_broadcast_with_estimator(query, 3, BroadcastBuildCoalesceMaxEstimator)
.await;
assert_snapshot!(annotated, @r"
HashJoinExec: task_count=Maximum(1)
CoalescePartitionsExec: task_count=Maximum(1)
[NetworkBoundary] Broadcast: task_count=Maximum(1)
BroadcastExec: task_count=Desired(1)
DataSourceExec: task_count=Desired(1)
DataSourceExec: task_count=Maximum(1)
");
}
#[tokio::test]
async fn test_broadcast_disabled_default() {
let query = r#"
SELECT a."MinTemp", b."MaxTemp"
FROM weather a INNER JOIN weather b
ON a."RainToday" = b."RainToday"
"#;
let annotated = sql_to_annotated_broadcast(query, 4, 4, false).await;
assert!(!annotated.contains("Broadcast"));
assert_snapshot!(annotated, @r"
HashJoinExec: task_count=Maximum(1)
CoalescePartitionsExec: task_count=Maximum(1)
DataSourceExec: task_count=Maximum(1)
DataSourceExec: task_count=Maximum(1)
")
}
#[tokio::test]
async fn test_broadcast_multi_join_chain() {
let query = r#"
SELECT a."MinTemp", b."MaxTemp", c."Rainfall"
FROM weather a
INNER JOIN weather b ON a."RainToday" = b."RainToday"
INNER JOIN weather c ON b."RainToday" = c."RainToday"
"#;
let annotated = sql_to_annotated_broadcast(query, 4, 4, true).await;
assert_snapshot!(annotated, @r"
HashJoinExec: task_count=Desired(3)
CoalescePartitionsExec: task_count=Desired(3)
[NetworkBoundary] Broadcast: task_count=Desired(3)
BroadcastExec: task_count=Desired(3)
HashJoinExec: task_count=Desired(3)
CoalescePartitionsExec: task_count=Desired(3)
[NetworkBoundary] Broadcast: task_count=Desired(3)
BroadcastExec: task_count=Desired(3)
DataSourceExec: task_count=Desired(3)
DataSourceExec: task_count=Desired(3)
DataSourceExec: task_count=Desired(3)
")
}
#[tokio::test]
async fn test_broadcast_union_children_isolator_annotation() {
let query = r#"
SET distributed.children_isolator_unions = true;
SELECT a."MinTemp", b."MaxTemp"
FROM weather a INNER JOIN weather b
ON a."RainToday" = b."RainToday"
UNION ALL
SELECT a."MinTemp", b."MaxTemp"
FROM weather a INNER JOIN weather b
ON a."RainToday" = b."RainToday"
UNION ALL
SELECT a."MinTemp", b."MaxTemp"
FROM weather a INNER JOIN weather b
ON a."RainToday" = b."RainToday"
"#;
let annotated = sql_to_annotated_broadcast(query, 4, 4, true).await;
assert_snapshot!(annotated, @r"
ChildrenIsolatorUnionExec: task_count=Desired(4)
HashJoinExec: task_count=Maximum(1)
CoalescePartitionsExec: task_count=Maximum(1)
[NetworkBoundary] Broadcast: task_count=Maximum(1)
BroadcastExec: task_count=Desired(1)
DataSourceExec: task_count=Desired(1)
DataSourceExec: task_count=Maximum(1)
HashJoinExec: task_count=Maximum(1)
CoalescePartitionsExec: task_count=Maximum(1)
[NetworkBoundary] Broadcast: task_count=Maximum(1)
BroadcastExec: task_count=Desired(1)
DataSourceExec: task_count=Desired(1)
DataSourceExec: task_count=Maximum(1)
HashJoinExec: task_count=Maximum(2)
CoalescePartitionsExec: task_count=Maximum(2)
[NetworkBoundary] Broadcast: task_count=Maximum(2)
BroadcastExec: task_count=Desired(2)
DataSourceExec: task_count=Desired(2)
DataSourceExec: task_count=Maximum(2)
");
}
#[allow(clippy::type_complexity)]
struct CallbackEstimator {
f: Arc<dyn Fn(&dyn ExecutionPlan) -> Option<TaskEstimation> + Send + Sync>,
}
impl CallbackEstimator {
fn new<T: ExecutionPlan + 'static>(
f: impl Fn(&T) -> Option<TaskEstimation> + Send + Sync + 'static,
) -> Self {
let f = Arc::new(move |plan: &dyn ExecutionPlan| -> Option<TaskEstimation> {
if let Some(plan) = plan.as_any().downcast_ref::<T>() {
f(plan)
} else {
None
}
});
Self { f }
}
}
impl TaskEstimator for CallbackEstimator {
fn task_estimation(
&self,
plan: &Arc<dyn ExecutionPlan>,
_: &ConfigOptions,
) -> Option<TaskEstimation> {
(self.f)(plan.as_ref())
}
fn scale_up_leaf_node(
&self,
_: &Arc<dyn ExecutionPlan>,
_: usize,
_: &ConfigOptions,
) -> Option<Arc<dyn ExecutionPlan>> {
None
}
}
#[derive(Debug)]
struct BroadcastBuildCoalesceMaxEstimator;
impl TaskEstimator for BroadcastBuildCoalesceMaxEstimator {
fn task_estimation(
&self,
plan: &Arc<dyn ExecutionPlan>,
_: &ConfigOptions,
) -> Option<TaskEstimation> {
let coalesce = plan.as_any().downcast_ref::<CoalescePartitionsExec>()?;
if coalesce.input().as_any().is::<BroadcastExec>() {
Some(TaskEstimation::maximum(1))
} else {
None
}
}
fn scale_up_leaf_node(
&self,
_: &Arc<dyn ExecutionPlan>,
_: usize,
_: &ConfigOptions,
) -> Option<Arc<dyn ExecutionPlan>> {
None
}
}
async fn sql_to_annotated(query: &str) -> String {
annotate_test_plan(query, TestPlanOptions::default(), |b| b).await
}
async fn sql_to_annotated_broadcast(
query: &str,
target_partitions: usize,
num_workers: usize,
broadcast_enabled: bool,
) -> String {
let options = TestPlanOptions {
target_partitions,
num_workers,
broadcast_enabled,
};
annotate_test_plan(query, options, |b| b).await
}
async fn sql_to_annotated_with_estimator<T: ExecutionPlan + Send + Sync + 'static>(
query: &str,
estimator: impl Fn(&T) -> Option<TaskEstimation> + Send + Sync + 'static,
) -> String {
let options = TestPlanOptions::default();
annotate_test_plan(query, options, |b| {
b.with_distributed_task_estimator(CallbackEstimator::new(estimator))
})
.await
}
async fn sql_to_annotated_broadcast_with_estimator(
query: &str,
num_workers: usize,
estimator: impl TaskEstimator + Send + Sync + 'static,
) -> String {
let options = TestPlanOptions {
target_partitions: 4,
num_workers,
broadcast_enabled: true,
};
annotate_test_plan(query, options, |b| {
b.with_distributed_task_estimator(estimator)
})
.await
}
async fn annotate_test_plan(
query: &str,
options: TestPlanOptions,
configure: impl FnOnce(SessionStateBuilder) -> SessionStateBuilder,
) -> String {
let builder = base_session_builder(
options.target_partitions,
options.num_workers,
options.broadcast_enabled,
);
let builder = configure(builder);
let (ctx, query) = context_with_query(builder, query).await;
let df = ctx.sql(&query).await.unwrap();
let mut plan = df.create_physical_plan().await.unwrap();
plan = insert_broadcast_execs(plan, ctx.state_ref().read().config_options().as_ref())
.expect("failed to insert broadcasts");
let annotated = annotate_plan(plan, ctx.state_ref().read().config_options().as_ref())
.expect("failed to annotate plan");
format!("{annotated:?}")
}
}