use std::sync::Arc;
use datafusion::common::JoinType;
use datafusion::common::tree_node::{Transformed, TreeNode};
use datafusion::config::ConfigOptions;
use datafusion::error::DataFusionError;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::joins::{
CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode,
};
use crate::BroadcastExec;
use super::DistributedConfig;
pub(super) fn insert_broadcast_execs(
plan: Arc<dyn ExecutionPlan>,
cfg: &ConfigOptions,
) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
let d_cfg = DistributedConfig::from_config_options(cfg)?;
if !d_cfg.broadcast_joins {
return Ok(plan);
}
plan.transform_down(|node| {
if !can_broadcast_left_input(node.as_ref()) {
return Ok(Transformed::no(node));
}
let children = node.children();
let Some(build_child) = children.first() else {
return Ok(Transformed::no(node));
};
let broadcast_input = build_child
.downcast_ref::<CoalescePartitionsExec>()
.map_or_else(
|| Arc::clone(build_child),
|coalesce| Arc::clone(coalesce.input()),
);
let broadcast: Arc<dyn ExecutionPlan> = Arc::new(BroadcastExec::new(broadcast_input, 1));
let new_build_child: Arc<dyn ExecutionPlan> =
Arc::new(CoalescePartitionsExec::new(broadcast));
let mut new_children: Vec<Arc<dyn ExecutionPlan>> = children.into_iter().cloned().collect();
new_children[0] = new_build_child;
Ok(Transformed::yes(node.with_new_children(new_children)?))
})
.map(|transformed| transformed.data)
}
/// Reports whether `plan` is a join whose left (build) input may be broadcast.
///
/// - `CrossJoinExec`: always.
/// - `NestedLoopJoinExec`: when `is_left_broadcast_safe` accepts its join type.
/// - `HashJoinExec`: only in `PartitionMode::CollectLeft`, and only when
///   `is_left_broadcast_safe` accepts its join type.
/// - Any other operator: never.
fn can_broadcast_left_input(plan: &dyn ExecutionPlan) -> bool {
    if plan.downcast_ref::<CrossJoinExec>().is_some() {
        return true;
    }
    if let Some(nlj) = plan.downcast_ref::<NestedLoopJoinExec>() {
        return is_left_broadcast_safe(nlj.join_type());
    }
    plan.downcast_ref::<HashJoinExec>().is_some_and(|hj| {
        matches!(hj.partition_mode(), PartitionMode::CollectLeft)
            && is_left_broadcast_safe(hj.join_type())
    })
}
/// Reports whether `join_type` is on the allow-list of join types whose left
/// (build) input may be broadcast.
///
/// The allow-list is `Inner`, `Right`, `RightSemi`, `RightAnti` and `RightMark`.
/// Every other join type is rejected.
fn is_left_broadcast_safe(join_type: &JoinType) -> bool {
    const BROADCAST_SAFE_JOIN_TYPES: [JoinType; 5] = [
        JoinType::Inner,
        JoinType::Right,
        JoinType::RightSemi,
        JoinType::RightAnti,
        JoinType::RightMark,
    ];
    BROADCAST_SAFE_JOIN_TYPES.contains(join_type)
}
// Snapshot tests for `insert_broadcast_execs`: each test plans a SQL query over the
// `weather` test dataset and compares the indented physical plan with an inline snapshot.
#[cfg(test)]
mod tests {
use super::*;
use crate::assert_snapshot;
use crate::test_utils::plans::TestPlanBuilder;
use datafusion::physical_plan::displayable;
// Inner hash join whose build side already has a CoalescePartitionsExec: the rule
// should keep a single CoalescePartitionsExec and insert BroadcastExec beneath it.
#[tokio::test]
async fn test_insert_broadcast_with_existing_coalesce_build_child() {
let query = r#"
SELECT a."MinTemp", b."MaxTemp"
FROM weather a INNER JOIN weather b
ON a."RainToday" = b."RainToday"
"#;
// Baseline plan: it contains no BroadcastExec.
let physical_plan_string = TestPlanBuilder::default()
.num_workers(4)
.physical_plan_as_string(query)
.await;
assert_snapshot!(physical_plan_string, @r"
HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2]
CoalescePartitionsExec
DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet
DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet, predicate=DynamicFilter [ empty ]
");
// After the rule: CoalescePartitionsExec -> BroadcastExec -> original build input.
let plan = sql_to_plan_with_broadcast(query, true, 4).await;
assert_snapshot!(plan, @r"
HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2]
CoalescePartitionsExec
BroadcastExec: input_partitions=3, consumer_tasks=1, output_partitions=3
DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet
DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet, predicate=DynamicFilter [ empty ]
");
}
// With target_partitions = 1 the build side has no CoalescePartitionsExec; the rule
// should add one above the inserted BroadcastExec.
#[tokio::test]
async fn test_insert_broadcast_without_existing_coalesce_build_child() {
let query = r#"
SELECT a."MinTemp", b."MaxTemp"
FROM weather a INNER JOIN weather b
ON a."RainToday" = b."RainToday"
"#;
// Baseline plan: the build child is the DataSourceExec directly.
let physical_plan_string = TestPlanBuilder::new()
.target_partitions(1)
.num_workers(4)
.physical_plan_as_string(query)
.await;
assert_snapshot!(physical_plan_string, @r"
HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2]
DataSourceExec: file_groups={1 group: [[/testdata/weather/result-000000.parquet, /testdata/weather/result-000001.parquet, /testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet
DataSourceExec: file_groups={1 group: [[/testdata/weather/result-000000.parquet, /testdata/weather/result-000001.parquet, /testdata/weather/result-000002.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet, predicate=DynamicFilter [ empty ]
");
let plan = sql_to_plan_with_broadcast(query, true, 1).await;
assert_snapshot!(plan, @r"
HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2]
CoalescePartitionsExec
BroadcastExec: input_partitions=1, consumer_tasks=1, output_partitions=1
DataSourceExec: file_groups={1 group: [[/testdata/weather/result-000000.parquet, /testdata/weather/result-000001.parquet, /testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet
DataSourceExec: file_groups={1 group: [[/testdata/weather/result-000000.parquet, /testdata/weather/result-000001.parquet, /testdata/weather/result-000002.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet, predicate=DynamicFilter [ empty ]
");
}
// A LEFT hash join is not on the broadcast-safe list, so the plan must be unchanged.
#[tokio::test]
async fn test_no_broadcast_left_join() {
let query = r#"
SELECT a."MinTemp", b."MaxTemp"
FROM weather a LEFT JOIN weather b
ON a."RainToday" = b."RainToday"
"#;
let plan = sql_to_plan_with_broadcast(query, true, 4).await;
assert_snapshot!(plan, @r"
HashJoinExec: mode=CollectLeft, join_type=Left, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2]
CoalescePartitionsExec
DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet
DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet, predicate=DynamicFilter [ empty ]
");
}
// With broadcast_joins disabled the rule is a no-op even for an eligible inner join.
#[tokio::test]
async fn test_no_broadcast_when_disabled() {
let query = r#"
SELECT a."MinTemp", b."MaxTemp"
FROM weather a INNER JOIN weather b
ON a."RainToday" = b."RainToday"
"#;
let plan = sql_to_plan_with_broadcast(query, false, 4).await;
assert_snapshot!(plan, @r"
HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2]
CoalescePartitionsExec
DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet
DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet, predicate=DynamicFilter [ empty ]
");
}
// CrossJoinExec is always eligible: its left input gets a BroadcastExec.
#[tokio::test]
async fn test_insert_broadcast_cross_join() {
let query = r#"
SELECT a."MinTemp", b."MaxTemp"
FROM weather a CROSS JOIN weather b
"#;
let plan = sql_to_plan_with_broadcast(query, true, 4).await;
assert_snapshot!(plan, @"
CrossJoinExec
CoalescePartitionsExec
BroadcastExec: input_partitions=3, consumer_tasks=1, output_partitions=3
DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp], file_type=parquet
DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MaxTemp], file_type=parquet
");
}
// A non-equi inner join plans as NestedLoopJoinExec and is eligible for broadcast.
#[tokio::test]
async fn test_insert_broadcast_nested_loop_inner_join() {
let query = r#"
SELECT a."MinTemp", b."MaxTemp"
FROM weather a JOIN weather b ON a."MinTemp" > b."MaxTemp"
"#;
let plan = sql_to_plan_with_broadcast(query, true, 4).await;
assert_snapshot!(plan, @"
NestedLoopJoinExec: join_type=Inner, filter=MinTemp@0 > MaxTemp@1
CoalescePartitionsExec
BroadcastExec: input_partitions=3, consumer_tasks=1, output_partitions=3
DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp], file_type=parquet
DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MaxTemp], file_type=parquet
");
}
// A LEFT nested-loop join is not on the broadcast-safe list, so the plan must be unchanged.
#[tokio::test]
async fn test_no_broadcast_nested_loop_left_join() {
let query = r#"
SELECT a."MinTemp", b."MaxTemp"
FROM weather a LEFT JOIN weather b ON a."MinTemp" > b."MaxTemp"
"#;
let plan = sql_to_plan_with_broadcast(query, true, 4).await;
assert_snapshot!(plan, @"
NestedLoopJoinExec: join_type=Left, filter=MinTemp@0 > MaxTemp@1
CoalescePartitionsExec
DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp], file_type=parquet
DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MaxTemp], file_type=parquet
");
}
// Plans `query` with the given target_partitions and broadcast_joins flag, runs
// `insert_broadcast_execs` using the session's config options, and returns the
// indented display string of the resulting plan. Panics if the rule returns an error.
async fn sql_to_plan_with_broadcast(
query: &str,
broadcast_enabled: bool,
target_partitions: usize,
) -> String {
let test_plan = TestPlanBuilder::new()
.target_partitions(target_partitions)
.broadcast_joins(broadcast_enabled)
.build()
.await;
let ctx = test_plan.get_ctx();
let plan = test_plan.physical_plan(query).await;
let plan = insert_broadcast_execs(plan, ctx.state_ref().read().config_options().as_ref())
.expect("failed to insert broadcasts");
format!("{}", displayable(plan.as_ref()).indent(true))
}
}