use std::sync::Arc;
use crate::PhysicalOptimizerRule;
use datafusion_common::config::ConfigOptions;
use datafusion_common::error::Result;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::{internal_err, JoinSide, JoinType};
use datafusion_expr_common::sort_properties::SortProperties;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::LexOrdering;
use datafusion_physical_plan::execution_plan::EmissionType;
use datafusion_physical_plan::joins::utils::{ColumnIndex, JoinFilter};
use datafusion_physical_plan::joins::{
CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode,
StreamJoinPartitionMode, SymmetricHashJoinExec,
};
use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties};
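/// The [`JoinSelection`] rule tries to modify a given plan so that it can
/// accommodate infinite sources and utilize statistical information (if there
/// is any) to obtain more performant plans. To achieve the first goal, it
/// replaces pipeline-breaking join operations with pipeline-friendly ones so
/// that queries over unbounded inputs remain runnable. To achieve the second
/// goal, it selects the proper `PartitionMode` and the build (left) side for
/// hash joins using the available statistics.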
#[derive(Default, Debug)]
pub struct JoinSelection {}
impl JoinSelection {
#[allow(missing_docs)]
pub fn new() -> Self {
Self {}
}
}
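/// Checks the statistics of the two join inputs and decides whether their
/// order should be swapped: the comparison uses the total byte size when it
/// is known for both sides, and falls back to the row count otherwise.
/// Returns `true` when the left (build) side is estimated to be larger than
/// the right (probe) side, and `false` when no estimate is available.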
pub(crate) fn should_swap_join_order(
left: &dyn ExecutionPlan,
right: &dyn ExecutionPlan,
) -> Result<bool> {
let left_stats = left.statistics()?;
let right_stats = right.statistics()?;
match (
left_stats.total_byte_size.get_value(),
right_stats.total_byte_size.get_value(),
) {
(Some(l), Some(r)) => Ok(l > r),
_ => match (
left_stats.num_rows.get_value(),
right_stats.num_rows.get_value(),
) {
(Some(l), Some(r)) => Ok(l > r),
_ => Ok(false),
},
}
}
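/// Checks whether the given plan is small enough, according to its
/// statistics, to be collected into a single partition (i.e. to serve as the
/// build side of a `CollectLeft` hash join). The byte-size estimate is
/// preferred; the row count is consulted only when the byte size is unknown.
/// Returns `false` when no usable statistics are available.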
fn supports_collect_by_thresholds(
plan: &dyn ExecutionPlan,
threshold_byte_size: usize,
threshold_num_rows: usize,
) -> bool {
let Ok(stats) = plan.statistics() else {
return false;
};
if let Some(byte_size) = stats.total_byte_size.get_value() {
*byte_size != 0 && *byte_size < threshold_byte_size
} else if let Some(num_rows) = stats.num_rows.get_value() {
*num_rows != 0 && *num_rows < threshold_num_rows
} else {
false
}
}
#[deprecated(since = "45.0.0", note = "use JoinType::supports_swap instead")]
#[allow(dead_code)]
pub(crate) fn supports_swap(join_type: JoinType) -> bool {
join_type.supports_swap()
}
#[deprecated(since = "45.0.0", note = "use datafusion-functions-nested instead")]
#[allow(dead_code)]
pub(crate) fn swap_join_type(join_type: JoinType) -> JoinType {
join_type.swap()
}
#[deprecated(since = "45.0.0", note = "use HashJoinExec::swap_inputs instead")]
pub fn swap_hash_join(
hash_join: &HashJoinExec,
partition_mode: PartitionMode,
) -> Result<Arc<dyn ExecutionPlan>> {
hash_join.swap_inputs(partition_mode)
}
#[deprecated(since = "45.0.0", note = "use NestedLoopJoinExec::swap_inputs")]
#[allow(dead_code)]
pub(crate) fn swap_nl_join(join: &NestedLoopJoinExec) -> Result<Arc<dyn ExecutionPlan>> {
join.swap_inputs()
}
#[deprecated(since = "45.0.0", note = "use filter.map(JoinFilter::swap) instead")]
#[allow(dead_code)]
fn swap_join_filter(filter: Option<&JoinFilter>) -> Option<JoinFilter> {
filter.map(JoinFilter::swap)
}
#[deprecated(since = "45.0.0", note = "use JoinFilter::swap instead")]
#[allow(dead_code)]
pub(crate) fn swap_filter(filter: &JoinFilter) -> JoinFilter {
filter.swap()
}
impl PhysicalOptimizerRule for JoinSelection {
fn optimize(
&self,
plan: Arc<dyn ExecutionPlan>,
config: &ConfigOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
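        // First, fix potentially pipeline-breaking joins: convert hash joins
        // over two unbounded, incrementally-emitting inputs into symmetric
        // hash joins, and swap build/probe sides when only the build side is
        // unbounded.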
let subrules: Vec<Box<PipelineFixerSubrule>> = vec![
Box::new(hash_join_convert_symmetric_subrule),
Box::new(hash_join_swap_subrule),
];
let new_plan = plan
.transform_up(|p| apply_subrules(p, &subrules, config))
.data()?;
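        // Then, choose the partition mode and build side for hash joins based
        // on statistics, using the configured collect-left thresholds.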
let config = &config.optimizer;
let collect_threshold_byte_size = config.hash_join_single_partition_threshold;
let collect_threshold_num_rows = config.hash_join_single_partition_threshold_rows;
new_plan
.transform_up(|plan| {
statistical_join_selection_subrule(
plan,
collect_threshold_byte_size,
collect_threshold_num_rows,
)
})
.data()
}
fn name(&self) -> &str {
"join_selection"
}
fn schema_check(&self) -> bool {
true
}
}
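/// Tries to create a `CollectLeft` hash join, swapping the inputs when doing
/// so would put the smaller input on the build (left) side. Returns
/// `Ok(None)` if neither side is small enough to collect (or the join type
/// does not support swapping when only the right side is collectible); the
/// caller then falls back to a partitioned hash join. When `ignore_threshold`
/// is `true`, the size thresholds are not checked.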
pub(crate) fn try_collect_left(
hash_join: &HashJoinExec,
ignore_threshold: bool,
threshold_byte_size: usize,
threshold_num_rows: usize,
) -> Result<Option<Arc<dyn ExecutionPlan>>> {
let left = hash_join.left();
let right = hash_join.right();
let left_can_collect = ignore_threshold
|| supports_collect_by_thresholds(
&**left,
threshold_byte_size,
threshold_num_rows,
);
let right_can_collect = ignore_threshold
|| supports_collect_by_thresholds(
&**right,
threshold_byte_size,
threshold_num_rows,
);
match (left_can_collect, right_can_collect) {
(true, true) => {
if hash_join.join_type().supports_swap()
&& should_swap_join_order(&**left, &**right)?
{
Ok(Some(hash_join.swap_inputs(PartitionMode::CollectLeft)?))
} else {
Ok(Some(Arc::new(HashJoinExec::try_new(
Arc::clone(left),
Arc::clone(right),
hash_join.on().to_vec(),
hash_join.filter().cloned(),
hash_join.join_type(),
hash_join.projection.clone(),
PartitionMode::CollectLeft,
hash_join.null_equals_null(),
)?)))
}
}
(true, false) => Ok(Some(Arc::new(HashJoinExec::try_new(
Arc::clone(left),
Arc::clone(right),
hash_join.on().to_vec(),
hash_join.filter().cloned(),
hash_join.join_type(),
hash_join.projection.clone(),
PartitionMode::CollectLeft,
hash_join.null_equals_null(),
)?))),
(false, true) => {
if hash_join.join_type().supports_swap() {
hash_join.swap_inputs(PartitionMode::CollectLeft).map(Some)
} else {
Ok(None)
}
}
(false, false) => Ok(None),
}
}
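/// Creates a `Partitioned` hash join from the given one, swapping the inputs
/// first if the statistics indicate that the right side is smaller and the
/// join type supports swapping.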
pub(crate) fn partitioned_hash_join(
hash_join: &HashJoinExec,
) -> Result<Arc<dyn ExecutionPlan>> {
let left = hash_join.left();
let right = hash_join.right();
if hash_join.join_type().supports_swap() && should_swap_join_order(&**left, &**right)?
{
hash_join.swap_inputs(PartitionMode::Partitioned)
} else {
Ok(Arc::new(HashJoinExec::try_new(
Arc::clone(left),
Arc::clone(right),
hash_join.on().to_vec(),
hash_join.filter().cloned(),
hash_join.join_type(),
hash_join.projection.clone(),
PartitionMode::Partitioned,
hash_join.null_equals_null(),
)?))
}
}
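/// Statistics-based join selection subrule: resolves `PartitionMode::Auto`
/// for hash joins, re-evaluates the build side of `CollectLeft` and
/// `Partitioned` hash joins, and swaps the inputs of cross joins and nested
/// loop joins when the statistics indicate that the right side is smaller.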
fn statistical_join_selection_subrule(
plan: Arc<dyn ExecutionPlan>,
collect_threshold_byte_size: usize,
collect_threshold_num_rows: usize,
) -> Result<Transformed<Arc<dyn ExecutionPlan>>> {
let transformed =
if let Some(hash_join) = plan.as_any().downcast_ref::<HashJoinExec>() {
match hash_join.partition_mode() {
PartitionMode::Auto => try_collect_left(
hash_join,
false,
collect_threshold_byte_size,
collect_threshold_num_rows,
)?
.map_or_else(
|| partitioned_hash_join(hash_join).map(Some),
|v| Ok(Some(v)),
)?,
PartitionMode::CollectLeft => try_collect_left(hash_join, true, 0, 0)?
.map_or_else(
|| partitioned_hash_join(hash_join).map(Some),
|v| Ok(Some(v)),
)?,
PartitionMode::Partitioned => {
let left = hash_join.left();
let right = hash_join.right();
if hash_join.join_type().supports_swap()
&& should_swap_join_order(&**left, &**right)?
{
hash_join
.swap_inputs(PartitionMode::Partitioned)
.map(Some)?
} else {
None
}
}
}
} else if let Some(cross_join) = plan.as_any().downcast_ref::<CrossJoinExec>() {
let left = cross_join.left();
let right = cross_join.right();
if should_swap_join_order(&**left, &**right)? {
cross_join.swap_inputs().map(Some)?
} else {
None
}
} else if let Some(nl_join) = plan.as_any().downcast_ref::<NestedLoopJoinExec>() {
let left = nl_join.left();
let right = nl_join.right();
if nl_join.join_type().supports_swap()
&& should_swap_join_order(&**left, &**right)?
{
nl_join.swap_inputs().map(Some)?
} else {
None
}
} else {
None
};
Ok(if let Some(transformed) = transformed {
Transformed::yes(transformed)
} else {
Transformed::no(plan)
})
}
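/// A pipeline-fixing subrule: a function that takes a plan and the session
/// configuration, and returns a (possibly rewritten) plan.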
pub type PipelineFixerSubrule =
dyn Fn(Arc<dyn ExecutionPlan>, &ConfigOptions) -> Result<Arc<dyn ExecutionPlan>>;
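/// Converts a hash join into a symmetric hash join when both of its inputs
/// are unbounded and emit their output incrementally, since a regular hash
/// join cannot finish building its hash table on an unbounded input. If a
/// side has an ordered column referenced by the join filter, that side's
/// output ordering is passed along so the symmetric hash join can prune its
/// internal state.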
fn hash_join_convert_symmetric_subrule(
input: Arc<dyn ExecutionPlan>,
config_options: &ConfigOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
if let Some(hash_join) = input.as_any().downcast_ref::<HashJoinExec>() {
let left_unbounded = hash_join.left.boundedness().is_unbounded();
let left_incremental = matches!(
hash_join.left.pipeline_behavior(),
EmissionType::Incremental | EmissionType::Both
);
let right_unbounded = hash_join.right.boundedness().is_unbounded();
let right_incremental = matches!(
hash_join.right.pipeline_behavior(),
EmissionType::Incremental | EmissionType::Both
);
        if left_unbounded && right_unbounded && left_incremental && right_incremental {
let mode = if config_options.optimizer.repartition_joins {
StreamJoinPartitionMode::Partitioned
} else {
StreamJoinPartitionMode::SinglePartition
};
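            // For each side, check whether any filter column coming from that
            // side is ordered; if so, forward that side's output ordering to
            // the symmetric hash join.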
let determine_order = |side: JoinSide| -> Option<LexOrdering> {
hash_join
.filter()
.map(|filter| {
filter.column_indices().iter().any(
|ColumnIndex {
index,
side: column_side,
}| {
if *column_side != side {
return false;
}
let (equivalence, schema) = match side {
JoinSide::Left => (
hash_join.left().equivalence_properties(),
hash_join.left().schema(),
),
JoinSide::Right => (
hash_join.right().equivalence_properties(),
hash_join.right().schema(),
),
JoinSide::None => return false,
};
let name = schema.field(*index).name();
let col = Arc::new(Column::new(name, *index)) as _;
equivalence.get_expr_properties(col).sort_properties
!= SortProperties::Unordered
},
)
})
.unwrap_or(false)
.then(|| {
match side {
JoinSide::Left => hash_join.left().output_ordering(),
JoinSide::Right => hash_join.right().output_ordering(),
JoinSide::None => unreachable!(),
}
.map(|p| LexOrdering::new(p.to_vec()))
})
.flatten()
};
let left_order = determine_order(JoinSide::Left);
let right_order = determine_order(JoinSide::Right);
return SymmetricHashJoinExec::try_new(
Arc::clone(hash_join.left()),
Arc::clone(hash_join.right()),
hash_join.on().to_vec(),
hash_join.filter().cloned(),
hash_join.join_type(),
hash_join.null_equals_null(),
left_order,
right_order,
mode,
)
.map(|exec| Arc::new(exec) as _);
}
}
Ok(input)
}
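/// Swaps the build and probe sides of a hash join when the left (build) side
/// is unbounded but the right side is bounded, so that the bounded input is
/// the one that gets buffered. Only join types whose semantics are preserved
/// by the swap (`Inner`, `Left`, `LeftSemi`, `LeftAnti`) are rewritten.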
pub fn hash_join_swap_subrule(
mut input: Arc<dyn ExecutionPlan>,
_config_options: &ConfigOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
if let Some(hash_join) = input.as_any().downcast_ref::<HashJoinExec>() {
if hash_join.left.boundedness().is_unbounded()
&& !hash_join.right.boundedness().is_unbounded()
&& matches!(
*hash_join.join_type(),
JoinType::Inner
| JoinType::Left
| JoinType::LeftSemi
| JoinType::LeftAnti
)
{
input = swap_join_according_to_unboundedness(hash_join)?;
}
}
Ok(input)
}
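/// Swaps the inputs of the given hash join, preserving its partition mode.
/// Returns an error for join types that cannot be swapped in this context
/// (the right/full variants) and for `PartitionMode::Auto`.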
pub(crate) fn swap_join_according_to_unboundedness(
hash_join: &HashJoinExec,
) -> Result<Arc<dyn ExecutionPlan>> {
let partition_mode = hash_join.partition_mode();
let join_type = hash_join.join_type();
match (*partition_mode, *join_type) {
(
_,
JoinType::Right | JoinType::RightSemi | JoinType::RightAnti | JoinType::Full,
) => internal_err!("{join_type} join cannot be swapped for unbounded input."),
(PartitionMode::Partitioned, _) => {
hash_join.swap_inputs(PartitionMode::Partitioned)
}
(PartitionMode::CollectLeft, _) => {
hash_join.swap_inputs(PartitionMode::CollectLeft)
}
(PartitionMode::Auto, _) => {
internal_err!("Auto is not acceptable for unbounded input here.")
}
}
}
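/// Runs the given pipeline-fixing subrules on `input` one after another. The
/// result is always reported as transformed because the subrules do not track
/// whether they actually changed the plan.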
fn apply_subrules(
mut input: Arc<dyn ExecutionPlan>,
subrules: &Vec<Box<PipelineFixerSubrule>>,
config_options: &ConfigOptions,
) -> Result<Transformed<Arc<dyn ExecutionPlan>>> {
for subrule in subrules {
input = subrule(input, config_options)?;
}
Ok(Transformed::yes(input))
}