use std::sync::Arc;
use datafusion::common::JoinType;
use datafusion::common::tree_node::{Transformed, TreeNode};
use datafusion::config::ConfigOptions;
use datafusion::error::DataFusionError;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
use crate::BroadcastExec;
use super::DistributedConfig;
pub(super) fn insert_broadcast_execs(
plan: Arc<dyn ExecutionPlan>,
cfg: &ConfigOptions,
) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
let d_cfg = DistributedConfig::from_config_options(cfg)?;
if !d_cfg.broadcast_joins {
return Ok(plan);
}
plan.transform_down(|node| {
let Some(hash_join) = node.as_any().downcast_ref::<HashJoinExec>() else {
return Ok(Transformed::no(node));
};
if hash_join.partition_mode() != &PartitionMode::CollectLeft {
return Ok(Transformed::no(node));
}
let join_type = hash_join.join_type();
if !matches!(
join_type,
JoinType::Inner
| JoinType::Right
| JoinType::RightSemi
| JoinType::RightAnti
| JoinType::RightMark
) {
return Ok(Transformed::no(node));
}
let children = node.children();
let Some(build_child) = children.first() else {
return Ok(Transformed::no(node));
};
let broadcast_input = if let Some(coalesce) = build_child
.as_any()
.downcast_ref::<CoalescePartitionsExec>()
{
Arc::clone(coalesce.input())
} else {
Arc::clone(build_child)
};
let broadcast = Arc::new(BroadcastExec::new(
broadcast_input,
1, ));
let new_build_child: Arc<dyn ExecutionPlan> =
Arc::new(CoalescePartitionsExec::new(broadcast));
let mut new_children: Vec<Arc<dyn ExecutionPlan>> = children.into_iter().cloned().collect();
new_children[0] = new_build_child;
Ok(Transformed::yes(node.with_new_children(new_children)?))
})
.map(|transformed| transformed.data)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert_snapshot;
    use crate::test_utils::plans::{base_session_builder, context_with_query, sql_to_physical_plan};
    use datafusion::physical_plan::displayable;

    /// Self-join of `weather` on `RainToday`.
    /// In these tests DataFusion plans it as a `CollectLeft` inner join.
    const INNER_SELF_JOIN: &str = r#"
        SELECT a."MinTemp", b."MaxTemp"
        FROM weather a INNER JOIN weather b
        ON a."RainToday" = b."RainToday"
    "#;

    /// The same self-join written as a `LEFT JOIN`, which the rewrite must leave untouched.
    const LEFT_SELF_JOIN: &str = r#"
        SELECT a."MinTemp", b."MaxTemp"
        FROM weather a LEFT JOIN weather b
        ON a."RainToday" = b."RainToday"
    "#;

    #[tokio::test]
    async fn test_insert_broadcast_with_existing_coalesce_build_child() {
        // The unmodified plan already coalesces the build side. The rewrite swaps that
        // coalesce for coalesce -> broadcast instead of stacking a second one.
        let before = sql_to_physical_plan(INNER_SELF_JOIN, 4, 4).await;
        assert_snapshot!(before, @r"
        HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2]
          CoalescePartitionsExec
            DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet
          DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet, predicate=DynamicFilter [ empty ]
        ");

        let after = sql_to_plan_with_broadcast(INNER_SELF_JOIN, true, 4).await;
        assert_snapshot!(after, @r"
        HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2]
          CoalescePartitionsExec
            BroadcastExec: input_partitions=3, consumer_tasks=1, output_partitions=3
              DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet
          DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet, predicate=DynamicFilter [ empty ]
        ");
    }

    #[tokio::test]
    async fn test_insert_broadcast_without_existing_coalesce_build_child() {
        // With a single file group the build side is read directly. The rewrite therefore
        // has to introduce the coalesce on top of the broadcast itself.
        let before = sql_to_physical_plan(INNER_SELF_JOIN, 1, 4).await;
        assert_snapshot!(before, @r"
        HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2]
          DataSourceExec: file_groups={1 group: [[/testdata/weather/result-000000.parquet, /testdata/weather/result-000001.parquet, /testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet
          DataSourceExec: file_groups={1 group: [[/testdata/weather/result-000000.parquet, /testdata/weather/result-000001.parquet, /testdata/weather/result-000002.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet, predicate=DynamicFilter [ empty ]
        ");

        let after = sql_to_plan_with_broadcast(INNER_SELF_JOIN, true, 1).await;
        assert_snapshot!(after, @r"
        HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2]
          CoalescePartitionsExec
            BroadcastExec: input_partitions=1, consumer_tasks=1, output_partitions=1
              DataSourceExec: file_groups={1 group: [[/testdata/weather/result-000000.parquet, /testdata/weather/result-000001.parquet, /testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet
          DataSourceExec: file_groups={1 group: [[/testdata/weather/result-000000.parquet, /testdata/weather/result-000001.parquet, /testdata/weather/result-000002.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet, predicate=DynamicFilter [ empty ]
        ");
    }

    #[tokio::test]
    async fn test_no_broadcast_left_join() {
        // Left joins are not in the broadcast-safe list, so no BroadcastExec may appear.
        let rendered = sql_to_plan_with_broadcast(LEFT_SELF_JOIN, true, 4).await;
        assert_snapshot!(rendered, @r"
        HashJoinExec: mode=CollectLeft, join_type=Left, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2]
          CoalescePartitionsExec
            DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet
          DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet
        ");
    }

    #[tokio::test]
    async fn test_no_broadcast_when_disabled() {
        // With the feature switched off the plan must come back exactly as DataFusion built it.
        let rendered = sql_to_plan_with_broadcast(INNER_SELF_JOIN, false, 4).await;
        assert_snapshot!(rendered, @r"
        HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2]
          CoalescePartitionsExec
            DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet
          DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet, predicate=DynamicFilter [ empty ]
        ");
    }

    /// Plans `query` on a session with `target_partitions` and 4 workers.
    ///
    /// It then runs [`insert_broadcast_execs`] over the resulting physical plan and returns
    /// the rewritten plan rendered as an indented tree.
    async fn sql_to_plan_with_broadcast(
        query: &str,
        broadcast_enabled: bool,
        target_partitions: usize,
    ) -> String {
        let session = base_session_builder(target_partitions, 4, broadcast_enabled);
        let (ctx, sql) = context_with_query(session, query).await;
        let physical_plan = ctx
            .sql(&sql)
            .await
            .unwrap()
            .create_physical_plan()
            .await
            .unwrap();
        let rewritten = insert_broadcast_execs(
            physical_plan,
            ctx.state_ref().read().config_options().as_ref(),
        )
        .expect("failed to insert broadcasts");
        displayable(rewritten.as_ref()).indent(true).to_string()
    }
}