use crate::PhysicalOptimizerRule;
use crate::optimizer::{ConfigOnlyContext, PhysicalOptimizerContext};
use datafusion_common::Statistics;
use datafusion_common::config::ConfigOptions;
use datafusion_common::error::Result;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::{JoinSide, JoinType, internal_err};
use datafusion_expr_common::sort_properties::SortProperties;
use datafusion_physical_expr::LexOrdering;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_plan::execution_plan::EmissionType;
use datafusion_physical_plan::joins::utils::ColumnIndex;
use datafusion_physical_plan::joins::{
CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode,
StreamJoinPartitionMode, SymmetricHashJoinExec,
};
use datafusion_physical_plan::operator_statistics::StatisticsRegistry;
use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties};
use std::sync::Arc;
/// Physical optimizer rule that rewrites the joins in a plan.
///
/// It first adapts hash joins to unbounded inputs (converting them to a
/// symmetric hash join, or swapping their inputs), and then uses input
/// statistics to choose hash-join partition modes and to order join inputs so
/// that the smaller input ends up on the left (build) side.
#[derive(Default, Debug)]
pub struct JoinSelection {}

impl JoinSelection {
    /// Creates a new [`JoinSelection`] rule; equivalent to [`Default::default`].
    pub fn new() -> Self {
        Self {}
    }
}
/// Fetches statistics for `plan`.
///
/// When a `registry` is supplied, the registry computes the statistics and
/// this returns a clone of the base `Arc<Statistics>` from that result.
/// Otherwise it asks the plan directly via `partition_statistics(None)`
/// (i.e. not for one specific partition).
fn get_stats(
    plan: &dyn ExecutionPlan,
    registry: Option<&StatisticsRegistry>,
) -> Result<Arc<Statistics>> {
    match registry {
        None => plan.partition_statistics(None),
        Some(reg) => {
            let computed = reg.compute(plan)?;
            Ok(Arc::<Statistics>::clone(computed.base_arc()))
        }
    }
}
/// Decides whether the inputs of a join should be swapped so the smaller input
/// ends up on the left.
///
/// Returns `Ok(false)` immediately when `optimizer.join_reordering` is off.
/// Otherwise compares `total_byte_size` of both inputs when both are known, and
/// falls back to `num_rows` when either byte size is missing. Returns
/// `Ok(true)` only if the left value is strictly larger than the right one.
/// Equal values, or missing values on either side, yield `Ok(false)`.
///
/// # Errors
///
/// Propagates any error from fetching either input's statistics.
pub(crate) fn should_swap_join_order(
    left: &dyn ExecutionPlan,
    right: &dyn ExecutionPlan,
    config: &ConfigOptions,
    registry: Option<&StatisticsRegistry>,
) -> Result<bool> {
    if config.optimizer.join_reordering {
        let lhs = get_stats(left, registry)?;
        let rhs = get_stats(right, registry)?;

        // Byte size is the preferred signal.
        if let (Some(lhs_bytes), Some(rhs_bytes)) = (
            lhs.total_byte_size.get_value(),
            rhs.total_byte_size.get_value(),
        ) {
            return Ok(lhs_bytes > rhs_bytes);
        }

        // Fall back to row counts when byte sizes are incomplete.
        if let (Some(lhs_rows), Some(rhs_rows)) =
            (lhs.num_rows.get_value(), rhs.num_rows.get_value())
        {
            return Ok(lhs_rows > rhs_rows);
        }
    }
    Ok(false)
}
/// Returns `true` if `plan` looks small enough to be collected into a single
/// partition, i.e. to serve as a `CollectLeft` build side.
///
/// `total_byte_size` is checked against `threshold_byte_size` when known.
/// Otherwise `num_rows` is checked against `threshold_num_rows`. With no usable
/// statistics, or when fetching statistics fails, the answer is `false`.
///
/// A reported size of `0` is not trusted. Zero is what a source with missing or
/// broken statistics commonly reports, and treating it as "tiny" could
/// broadcast an input of unknown, possibly very large, size. Such inputs are
/// therefore not considered collectable.
fn supports_collect_by_thresholds(
    plan: &dyn ExecutionPlan,
    threshold_byte_size: usize,
    threshold_num_rows: usize,
    registry: Option<&StatisticsRegistry>,
) -> bool {
    let Ok(stats) = get_stats(plan, registry) else {
        return false;
    };
    if let Some(&byte_size) = stats.total_byte_size.get_value() {
        byte_size != 0 && byte_size < threshold_byte_size
    } else if let Some(&num_rows) = stats.num_rows.get_value() {
        num_rows != 0 && num_rows < threshold_num_rows
    } else {
        false
    }
}
impl PhysicalOptimizerRule for JoinSelection {
    /// Runs the rule when only a [`ConfigOptions`] is available.
    ///
    /// Wraps `config` in a [`ConfigOnlyContext`] and delegates to
    /// `optimize_with_context`.
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        config: &ConfigOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        self.optimize_with_context(plan, &ConfigOnlyContext::new(config))
    }
    /// Rewrites the joins in `plan` in two bottom-up (`transform_up`) passes.
    ///
    /// 1. Pipeline-fixing subrules run on every node, in this order:
    ///    `hash_join_convert_symmetric_subrule`, then `hash_join_swap_subrule`.
    /// 2. `statistical_join_selection_subrule` chooses hash-join partition
    ///    modes and swaps join inputs based on input statistics.
    ///
    /// When `optimizer.use_statistics_registry` is enabled, statistics come
    /// from the context's registry. If the context supplies none, a locally
    /// built registry with the built-in providers is used instead. When the
    /// option is disabled, `registry` is `None`.
    fn optimize_with_context(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        context: &dyn PhysicalOptimizerContext,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let config = context.config_options();
        // Owns the fallback registry (if one is needed) so that the borrowed
        // `registry` below stays valid for the rest of this function.
        let mut default_registry = None;
        let registry: Option<&StatisticsRegistry> =
            if config.optimizer.use_statistics_registry {
                Some(context.statistics_registry().unwrap_or_else(|| {
                    default_registry
                        .insert(StatisticsRegistry::default_with_builtin_providers())
                }))
            } else {
                None
            };
        // Pass 1: adapt hash joins to unbounded inputs.
        let subrules: Vec<Box<PipelineFixerSubrule>> = vec![
            Box::new(hash_join_convert_symmetric_subrule),
            Box::new(hash_join_swap_subrule),
        ];
        let new_plan = plan
            .transform_up(|p| apply_subrules(p, &subrules, config))
            .data()?;
        // Pass 2: statistics-driven partition-mode selection and input swapping.
        new_plan
            .transform_up(|plan| {
                statistical_join_selection_subrule(plan, config, registry)
            })
            .data()
    }
    /// Name under which this rule is reported.
    fn name(&self) -> &str {
        "join_selection"
    }
    /// Returns `true`, opting this rule into the optimizer's schema check.
    /// Presumably this verifies the rewrite does not change the plan's schema;
    /// confirm against the `PhysicalOptimizerRule` trait docs.
    fn schema_check(&self) -> bool {
        true
    }
}
/// Tries to turn `hash_join` into a `CollectLeft` join.
///
/// An input is "collectable" when `ignore_threshold` is set, or when its
/// statistics fall under `optimizer.hash_join_single_partition_threshold`
/// (bytes) / `optimizer.hash_join_single_partition_threshold_rows` (rows).
/// A join counts as swappable when its join type supports swapping and it is
/// not null-aware.
///
/// * Both inputs collectable: swap to `CollectLeft` if the join is swappable
///   and [`should_swap_join_order`] says so; otherwise rebuild as `CollectLeft`.
/// * Only the left input collectable: rebuild as `CollectLeft`.
/// * Only the right input collectable: swap to `CollectLeft` when
///   `optimizer.join_reordering` is on and the join is swappable; otherwise
///   `Ok(None)`.
/// * Neither input collectable: `Ok(None)`, so the caller picks another mode.
///
/// # Errors
///
/// Propagates errors from statistics lookup, swapping, or rebuilding the join.
pub(crate) fn try_collect_left(
    hash_join: &HashJoinExec,
    ignore_threshold: bool,
    config: &ConfigOptions,
    registry: Option<&StatisticsRegistry>,
) -> Result<Option<Arc<dyn ExecutionPlan>>> {
    let opts = &config.optimizer;
    let build = hash_join.left();
    let probe = hash_join.right();

    // Whether an input is small enough (or thresholds are ignored) to collect.
    let fits = |side: &Arc<dyn ExecutionPlan>| -> bool {
        ignore_threshold
            || supports_collect_by_thresholds(
                &**side,
                opts.hash_join_single_partition_threshold,
                opts.hash_join_single_partition_threshold_rows,
                registry,
            )
    };
    // Rebuild the join unchanged except for the CollectLeft partition mode.
    let rebuilt_as_collect_left = || -> Result<Arc<dyn ExecutionPlan>> {
        Ok(Arc::new(
            hash_join
                .builder()
                .with_partition_mode(PartitionMode::CollectLeft)
                .build()?,
        ))
    };

    let build_fits = fits(build);
    let probe_fits = fits(probe);
    let swappable = hash_join.join_type().supports_swap() && !hash_join.null_aware;

    if build_fits {
        if probe_fits
            && swappable
            && should_swap_join_order(&**build, &**probe, config, registry)?
        {
            return hash_join.swap_inputs(PartitionMode::CollectLeft).map(Some);
        }
        return rebuilt_as_collect_left().map(Some);
    }
    if probe_fits && opts.join_reordering && swappable {
        return hash_join.swap_inputs(PartitionMode::CollectLeft).map(Some);
    }
    Ok(None)
}
/// Picks a non-collect-left layout for `hash_join`.
///
/// If the join type supports swapping, the join is not null-aware, and
/// [`should_swap_join_order`] prefers the right input as build side, the inputs
/// are swapped into a `Partitioned` join. Otherwise the join is rebuilt with
/// `Partitioned` mode, except null-aware joins, which are rebuilt with
/// `CollectLeft`. NOTE(review): presumably null-aware semantics need a single,
/// global view of the build side; confirm with `HashJoinExec` docs.
///
/// # Errors
///
/// Propagates errors from statistics lookup, swapping, or rebuilding the join.
pub(crate) fn partitioned_hash_join(
    hash_join: &HashJoinExec,
    config: &ConfigOptions,
    registry: Option<&StatisticsRegistry>,
) -> Result<Arc<dyn ExecutionPlan>> {
    let null_aware = hash_join.null_aware;
    let swap_sides = hash_join.join_type().supports_swap()
        && !null_aware
        && should_swap_join_order(
            &**hash_join.left(),
            &**hash_join.right(),
            config,
            registry,
        )?;
    if swap_sides {
        return hash_join.swap_inputs(PartitionMode::Partitioned);
    }

    let mode = match null_aware {
        true => PartitionMode::CollectLeft,
        false => PartitionMode::Partitioned,
    };
    let rebuilt = hash_join.builder().with_partition_mode(mode).build()?;
    Ok(Arc::new(rebuilt))
}
/// Statistics-driven rewrite of a single plan node.
///
/// * [`HashJoinExec`] with `Auto` mode: try `CollectLeft` under the size
///   thresholds, else fall back to [`partitioned_hash_join`].
/// * [`HashJoinExec`] with `CollectLeft` mode: same, but ignoring thresholds.
/// * [`HashJoinExec`] with `Partitioned` mode: swap inputs (keeping
///   `Partitioned`) when the join type supports it, the join is not null-aware,
///   and [`should_swap_join_order`] agrees.
/// * [`CrossJoinExec`]: swap inputs when [`should_swap_join_order`] agrees.
/// * [`NestedLoopJoinExec`]: swap inputs when the join type supports it and
///   [`should_swap_join_order`] agrees.
///
/// Any node left as-is is returned as `Transformed::no`.
///
/// # Errors
///
/// Propagates errors from statistics lookup or from rebuilding/swapping joins.
fn statistical_join_selection_subrule(
    plan: Arc<dyn ExecutionPlan>,
    config: &ConfigOptions,
    registry: Option<&StatisticsRegistry>,
) -> Result<Transformed<Arc<dyn ExecutionPlan>>> {
    // Whether statistics say the right input should become the left input.
    let prefers_swap = |lhs: &Arc<dyn ExecutionPlan>, rhs: &Arc<dyn ExecutionPlan>| {
        should_swap_join_order(&**lhs, &**rhs, config, registry)
    };

    let replacement: Option<Arc<dyn ExecutionPlan>> =
        if let Some(hash_join) = plan.downcast_ref::<HashJoinExec>() {
            // Try CollectLeft first; fall back to the partitioned layout.
            let collect_left_or_partitioned =
                |ignore_threshold: bool| -> Result<Arc<dyn ExecutionPlan>> {
                    match try_collect_left(hash_join, ignore_threshold, config, registry)? {
                        Some(collected) => Ok(collected),
                        None => partitioned_hash_join(hash_join, config, registry),
                    }
                };
            match hash_join.partition_mode() {
                PartitionMode::Auto => Some(collect_left_or_partitioned(false)?),
                PartitionMode::CollectLeft => Some(collect_left_or_partitioned(true)?),
                PartitionMode::Partitioned => {
                    let swap_sides = hash_join.join_type().supports_swap()
                        && !hash_join.null_aware
                        && prefers_swap(hash_join.left(), hash_join.right())?;
                    if swap_sides {
                        Some(hash_join.swap_inputs(PartitionMode::Partitioned)?)
                    } else {
                        None
                    }
                }
            }
        } else if let Some(cross_join) = plan.downcast_ref::<CrossJoinExec>() {
            if prefers_swap(cross_join.left(), cross_join.right())? {
                Some(cross_join.swap_inputs()?)
            } else {
                None
            }
        } else if let Some(nl_join) = plan.downcast_ref::<NestedLoopJoinExec>() {
            let swap_sides = nl_join.join_type().supports_swap()
                && prefers_swap(nl_join.left(), nl_join.right())?;
            if swap_sides {
                Some(nl_join.swap_inputs()?)
            } else {
                None
            }
        } else {
            None
        };

    Ok(match replacement {
        Some(new_plan) => Transformed::yes(new_plan),
        None => Transformed::no(plan),
    })
}
/// Signature of a pipeline-fixing subrule.
///
/// A subrule receives one plan node together with the [`ConfigOptions`] and
/// returns either that node unchanged or a rewritten replacement. Subrules are
/// applied one after another to each node by `apply_subrules`.
pub type PipelineFixerSubrule =
    dyn Fn(Arc<dyn ExecutionPlan>, &ConfigOptions) -> Result<Arc<dyn ExecutionPlan>>;
/// Pipeline-fixing subrule that replaces a [`HashJoinExec`] with a
/// [`SymmetricHashJoinExec`].
///
/// The replacement happens only when both inputs are unbounded and both emit
/// incrementally (`EmissionType::Incremental` or `EmissionType::Both`).
///
/// The stream partition mode follows `optimizer.repartition_joins`:
/// `Partitioned` when enabled, `SinglePartition` otherwise. For each input, its
/// output ordering is passed to the symmetric join only if the join filter
/// references at least one column from that input whose sort properties are
/// not `Unordered`. Otherwise that input gets no ordering.
///
/// Every other node, including hash joins that do not qualify, is returned
/// unchanged.
///
/// # Errors
///
/// Propagates any error from [`SymmetricHashJoinExec::try_new`].
fn hash_join_convert_symmetric_subrule(
    input: Arc<dyn ExecutionPlan>,
    config_options: &ConfigOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
    if let Some(hash_join) = input.downcast_ref::<HashJoinExec>() {
        let left_unbounded = hash_join.left.boundedness().is_unbounded();
        let left_incremental = matches!(
            hash_join.left.pipeline_behavior(),
            EmissionType::Incremental | EmissionType::Both
        );
        let right_unbounded = hash_join.right.boundedness().is_unbounded();
        let right_incremental = matches!(
            hash_join.right.pipeline_behavior(),
            EmissionType::Incremental | EmissionType::Both
        );
        // Convert only if both inputs are unbounded and both emit incrementally.
        if left_unbounded && right_unbounded && left_incremental && right_incremental {
            let mode = if config_options.optimizer.repartition_joins {
                StreamJoinPartitionMode::Partitioned
            } else {
                StreamJoinPartitionMode::SinglePartition
            };
            // For `side`, return that child's output ordering if the join filter
            // uses at least one column from `side` that is not `Unordered`.
            let determine_order = |side: JoinSide| -> Option<LexOrdering> {
                hash_join
                    .filter()
                    .map(|filter| {
                        filter.column_indices().iter().any(
                            |ColumnIndex {
                                 index,
                                 side: column_side,
                             }| {
                                // Only columns from the side being inspected matter.
                                if *column_side != side {
                                    return false;
                                }
                                let (equivalence, schema) = match side {
                                    JoinSide::Left => (
                                        hash_join.left().equivalence_properties(),
                                        hash_join.left().schema(),
                                    ),
                                    JoinSide::Right => (
                                        hash_join.right().equivalence_properties(),
                                        hash_join.right().schema(),
                                    ),
                                    JoinSide::None => return false,
                                };
                                let name = schema.field(*index).name();
                                let col = Arc::new(Column::new(name, *index)) as _;
                                equivalence.get_expr_properties(col).sort_properties
                                    != SortProperties::Unordered
                            },
                        )
                    })
                    .unwrap_or(false)
                    .then(|| {
                        match side {
                            JoinSide::Left => hash_join.left().output_ordering(),
                            JoinSide::Right => hash_join.right().output_ordering(),
                            // Only called with `Left` / `Right` below.
                            JoinSide::None => unreachable!(),
                        }
                        .cloned()
                    })
                    .flatten()
            };
            let left_order = determine_order(JoinSide::Left);
            let right_order = determine_order(JoinSide::Right);
            return SymmetricHashJoinExec::try_new(
                Arc::clone(hash_join.left()),
                Arc::clone(hash_join.right()),
                hash_join.on().to_vec(),
                hash_join.filter().cloned(),
                hash_join.join_type(),
                hash_join.null_equality(),
                left_order,
                right_order,
                mode,
            )
            .map(|exec| Arc::new(exec) as _);
        }
    }
    Ok(input)
}
/// Pipeline-fixing subrule that swaps the inputs of a [`HashJoinExec`] whose
/// left input is unbounded while its right input is bounded.
///
/// Only non-null-aware `Inner`, `Left`, `LeftSemi` and `LeftAnti` joins are
/// swapped, via [`swap_join_according_to_unboundedness`]. Every other node is
/// returned as-is. The configuration argument is unused.
///
/// # Errors
///
/// Propagates any error raised while swapping the join inputs.
pub fn hash_join_swap_subrule(
    input: Arc<dyn ExecutionPlan>,
    _config_options: &ConfigOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
    let Some(hash_join) = input.downcast_ref::<HashJoinExec>() else {
        return Ok(input);
    };
    let should_swap = hash_join.left.boundedness().is_unbounded()
        && !hash_join.right.boundedness().is_unbounded()
        && !hash_join.null_aware
        && matches!(
            *hash_join.join_type(),
            JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti
        );
    if should_swap {
        swap_join_according_to_unboundedness(hash_join)
    } else {
        Ok(input)
    }
}
/// Swaps the inputs of `hash_join` so that an unbounded input moves to the
/// probe side.
///
/// `CollectLeft` joins stay `CollectLeft`. `Partitioned` and `Auto` joins
/// become `Partitioned`.
///
/// # Errors
///
/// Returns an internal error for `Right`, `RightSemi`, `RightAnti`,
/// `RightMark` and `Full` joins, which cannot be swapped for unbounded input.
/// Otherwise propagates any error from `swap_inputs`.
pub(crate) fn swap_join_according_to_unboundedness(
    hash_join: &HashJoinExec,
) -> Result<Arc<dyn ExecutionPlan>> {
    let join_type = hash_join.join_type();
    if matches!(
        *join_type,
        JoinType::Right
            | JoinType::RightSemi
            | JoinType::RightAnti
            | JoinType::RightMark
            | JoinType::Full
    ) {
        return internal_err!("{join_type} join cannot be swapped for unbounded input.");
    }
    let target_mode = match *hash_join.partition_mode() {
        PartitionMode::CollectLeft => PartitionMode::CollectLeft,
        PartitionMode::Partitioned | PartitionMode::Auto => PartitionMode::Partitioned,
    };
    hash_join.swap_inputs(target_mode)
}
/// Runs each pipeline-fixing subrule on `input` in order, feeding each
/// subrule's output into the next one.
///
/// The node is reported as transformed only when the final plan is a different
/// `Arc` than the one passed in. This is a pointer comparison via
/// `Arc::ptr_eq`, not a structural one.
///
/// # Errors
///
/// Stops at, and returns, the first error produced by a subrule.
fn apply_subrules(
    input: Arc<dyn ExecutionPlan>,
    subrules: &[Box<PipelineFixerSubrule>],
    config_options: &ConfigOptions,
) -> Result<Transformed<Arc<dyn ExecutionPlan>>> {
    let original = Arc::clone(&input);
    let output = subrules
        .iter()
        .try_fold(input, |plan, subrule| subrule(plan, config_options))?;
    let changed = !Arc::ptr_eq(&original, &output);
    Ok(Transformed::new_transformed(output, changed))
}