use std::sync::Arc;
use arrow::datatypes::Schema;
use crate::config::ConfigOptions;
use crate::logical_expr::JoinType;
use crate::physical_plan::expressions::Column;
use crate::physical_plan::joins::{
utils::{ColumnIndex, JoinFilter, JoinSide},
CrossJoinExec, HashJoinExec, PartitionMode,
};
use crate::physical_plan::projection::ProjectionExec;
use crate::physical_plan::{ExecutionPlan, PhysicalExpr};
use super::optimizer::PhysicalOptimizerRule;
use crate::error::Result;
use datafusion_common::tree_node::{Transformed, TreeNode};
/// Physical optimizer rule that selects the hash join partition mode and
/// swaps join inputs based on available statistics, so that the smaller
/// side ends up as the build (left) side.
#[derive(Default)]
pub struct JoinSelection {}

impl JoinSelection {
    /// Creates a new [`JoinSelection`] rule.
    pub fn new() -> Self {
        JoinSelection {}
    }
}
/// Returns `true` when statistics indicate the right input is smaller than
/// the left one, i.e. swapping the inputs would put the smaller side on the
/// build (left) side.
///
/// Byte-size statistics are preferred; row counts are used as a fallback
/// when either side lacks a byte-size estimate. Returns `false` when the
/// statistics are insufficient to make a decision.
fn should_swap_join_order(left: &dyn ExecutionPlan, right: &dyn ExecutionPlan) -> bool {
    // Fetch statistics once per side: `statistics()` may be expensive to
    // compute for some plans, and the original code evaluated it twice per
    // input (once for byte sizes, again for the row-count fallback).
    let left_stats = left.statistics();
    let right_stats = right.statistics();
    let (left_size, right_size) = match (
        left_stats.total_byte_size,
        right_stats.total_byte_size,
    ) {
        (Some(l), Some(r)) => (Some(l), Some(r)),
        // Byte size unknown on at least one side: fall back to row counts.
        _ => (left_stats.num_rows, right_stats.num_rows),
    };
    match (left_size, right_size) {
        (Some(l), Some(r)) => l > r,
        _ => false,
    }
}
/// Returns `true` when `plan` is estimated to be small enough to collect
/// into a single partition for the build side of a hash join.
///
/// The byte-size estimate is preferred; the row count is used when the byte
/// size is unknown. A zero estimate is rejected because empty statistics
/// typically indicate missing information rather than a truly empty input.
fn supports_collect_by_size(
    plan: &dyn ExecutionPlan,
    collection_size_threshold: usize,
) -> bool {
    // Fetch the statistics once; `statistics()` may be costly and the
    // original code called it twice (byte-size probe, then row-count probe).
    let stats = plan.statistics();
    if let Some(size) = stats.total_byte_size {
        size != 0 && size < collection_size_threshold
    } else if let Some(row_count) = stats.num_rows {
        row_count != 0 && row_count < collection_size_threshold
    } else {
        false
    }
}
pub fn supports_swap(join_type: JoinType) -> bool {
matches!(
join_type,
JoinType::Inner
| JoinType::Left
| JoinType::Right
| JoinType::Full
| JoinType::LeftSemi
| JoinType::RightSemi
| JoinType::LeftAnti
| JoinType::RightAnti
)
}
pub fn swap_join_type(join_type: JoinType) -> JoinType {
match join_type {
JoinType::Inner => JoinType::Inner,
JoinType::Full => JoinType::Full,
JoinType::Left => JoinType::Right,
JoinType::Right => JoinType::Left,
JoinType::LeftSemi => JoinType::RightSemi,
JoinType::RightSemi => JoinType::LeftSemi,
JoinType::LeftAnti => JoinType::RightAnti,
JoinType::RightAnti => JoinType::LeftAnti,
}
}
/// Builds a new [`HashJoinExec`] with the inputs of `hash_join` exchanged,
/// adjusting the join type, equi-join keys, and filter accordingly.
///
/// Semi/anti joins only output columns from one side, so the swapped join is
/// returned directly; for all other join types a [`ProjectionExec`] is added
/// on top to restore the original output column order.
pub fn swap_hash_join(
    hash_join: &HashJoinExec,
    partition_mode: PartitionMode,
) -> Result<Arc<dyn ExecutionPlan>> {
    let left = hash_join.left();
    let right = hash_join.right();

    // Mirror the equi-join keys: every (left, right) pair becomes (right, left).
    let swapped_on = hash_join
        .on()
        .iter()
        .map(|(l, r)| (r.clone(), l.clone()))
        .collect();

    let new_join = HashJoinExec::try_new(
        Arc::clone(right),
        Arc::clone(left),
        swapped_on,
        swap_join_filter(hash_join.filter()),
        &swap_join_type(*hash_join.join_type()),
        partition_mode,
        hash_join.null_equals_null(),
    )?;

    match hash_join.join_type() {
        // The swapped semi/anti join already exposes the correct schema.
        JoinType::LeftSemi
        | JoinType::RightSemi
        | JoinType::LeftAnti
        | JoinType::RightAnti => Ok(Arc::new(new_join)),
        // Other join types need a projection to put columns back in order.
        _ => {
            let reverting_projection = ProjectionExec::try_new(
                swap_reverting_projection(&left.schema(), &right.schema()),
                Arc::new(new_join),
            )?;
            Ok(Arc::new(reverting_projection))
        }
    }
}
/// Builds projection expressions that restore the original `[left, right]`
/// column order for a join whose inputs were swapped to `[right, left]`.
pub fn swap_reverting_projection(
    left_schema: &Schema,
    right_schema: &Schema,
) -> Vec<(Arc<dyn PhysicalExpr>, String)> {
    // In the swapped plan the right-side columns come first, so the original
    // left columns now start at offset `num_right`.
    let num_right = right_schema.fields().len();

    let left_exprs = left_schema.fields().iter().enumerate().map(|(i, field)| {
        (
            Arc::new(Column::new(field.name(), num_right + i)) as Arc<dyn PhysicalExpr>,
            field.name().to_owned(),
        )
    });
    let right_exprs = right_schema.fields().iter().enumerate().map(|(i, field)| {
        (
            Arc::new(Column::new(field.name(), i)) as Arc<dyn PhysicalExpr>,
            field.name().to_owned(),
        )
    });

    // Emit left columns first so the projection output matches the
    // pre-swap schema.
    left_exprs.chain(right_exprs).collect()
}
/// Rewrites a join filter for a swapped join by flipping the side of every
/// referenced column. The filter expression and its intermediate schema are
/// reused unchanged, since intermediate indices are unaffected by the swap.
fn swap_join_filter(filter: Option<&JoinFilter>) -> Option<JoinFilter> {
    filter.map(|filter| {
        let column_indices = filter
            .column_indices()
            .iter()
            .map(|idx| ColumnIndex {
                index: idx.index,
                // A column that referenced the left input now references the
                // right input, and vice versa.
                side: match idx.side {
                    JoinSide::Left => JoinSide::Right,
                    _ => JoinSide::Left,
                },
            })
            .collect();
        JoinFilter::new(
            filter.expression().clone(),
            column_indices,
            filter.schema().clone(),
        )
    })
}
impl PhysicalOptimizerRule for JoinSelection {
    /// Walks the plan bottom-up, swapping the inputs of hash joins and cross
    /// joins when statistics suggest the right side is smaller, and resolving
    /// `PartitionMode::Auto` into a concrete partition mode.
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        config: &ConfigOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let config = &config.optimizer;
        // Build sides smaller than this threshold may be collected into a
        // single partition (CollectLeft mode).
        let collect_left_threshold = config.hash_join_single_partition_threshold;
        plan.transform_up(&|plan| {
            let transformed = if let Some(hash_join) =
                plan.as_any().downcast_ref::<HashJoinExec>()
            {
                match hash_join.partition_mode() {
                    // Auto: prefer CollectLeft when one side fits under the
                    // threshold, otherwise fall back to a partitioned join.
                    PartitionMode::Auto => {
                        try_collect_left(hash_join, Some(collect_left_threshold))?
                            .map_or_else(
                                || partitioned_hash_join(hash_join).map(Some),
                                |v| Ok(Some(v)),
                            )?
                    }
                    // CollectLeft was requested explicitly: no size threshold
                    // is applied, only the join-type constraints.
                    PartitionMode::CollectLeft => try_collect_left(hash_join, None)?
                        .map_or_else(
                            || partitioned_hash_join(hash_join).map(Some),
                            |v| Ok(Some(v)),
                        )?,
                    // Partitioned: keep the mode, but swap the inputs when the
                    // statistics say the right side is smaller.
                    PartitionMode::Partitioned => {
                        let left = hash_join.left();
                        let right = hash_join.right();
                        if should_swap_join_order(&**left, &**right)
                            && supports_swap(*hash_join.join_type())
                        {
                            swap_hash_join(hash_join, PartitionMode::Partitioned)
                                .map(Some)?
                        } else {
                            None
                        }
                    }
                }
            } else if let Some(cross_join) = plan.as_any().downcast_ref::<CrossJoinExec>()
            {
                let left = cross_join.left();
                let right = cross_join.right();
                // Cross joins also benefit from having the smaller input on
                // the left (build) side.
                if should_swap_join_order(&**left, &**right) {
                    let new_join =
                        CrossJoinExec::new(Arc::clone(right), Arc::clone(left));
                    // Restore the original output column order for consumers.
                    let proj: Arc<dyn ExecutionPlan> = Arc::new(ProjectionExec::try_new(
                        swap_reverting_projection(&left.schema(), &right.schema()),
                        Arc::new(new_join),
                    )?);
                    Some(proj)
                } else {
                    None
                }
            } else {
                None
            };
            Ok(if let Some(transformed) = transformed {
                Transformed::Yes(transformed)
            } else {
                Transformed::No(plan)
            })
        })
    }

    fn name(&self) -> &str {
        "join_selection"
    }

    /// Schema checking stays enabled: the swaps above preserve the schema via
    /// reverting projections, so any mismatch indicates a bug in this rule.
    fn schema_check(&self) -> bool {
        true
    }
}
/// Attempts to convert `hash_join` into `CollectLeft` mode, swapping the
/// inputs first when that would place the smaller side on the build (left)
/// side.
///
/// `collect_threshold` bounds the estimated size (bytes, or rows as a
/// fallback) of a side eligible for collection; `None` means any size is
/// eligible. Returns `Ok(None)` when neither side can serve as a collected
/// build side, letting the caller fall back to a partitioned join.
fn try_collect_left(
    hash_join: &HashJoinExec,
    collect_threshold: Option<usize>,
) -> Result<Option<Arc<dyn ExecutionPlan>>> {
    let left = hash_join.left();
    let right = hash_join.right();
    let join_type = hash_join.join_type();

    // A side can be collected only if the join type allows building on it
    // and (when a threshold is given) its statistics say it is small enough.
    let left_can_collect = match join_type {
        JoinType::Left | JoinType::Full | JoinType::LeftAnti => false,
        JoinType::Inner
        | JoinType::LeftSemi
        | JoinType::Right
        | JoinType::RightSemi
        | JoinType::RightAnti => collect_threshold.map_or(true, |threshold| {
            supports_collect_by_size(&**left, threshold)
        }),
    };
    let right_can_collect = match join_type {
        JoinType::Right | JoinType::Full | JoinType::RightAnti => false,
        JoinType::Inner
        | JoinType::RightSemi
        | JoinType::Left
        | JoinType::LeftSemi
        | JoinType::LeftAnti => collect_threshold.map_or(true, |threshold| {
            supports_collect_by_size(&**right, threshold)
        }),
    };

    match (left_can_collect, right_can_collect) {
        // Both sides qualify: swap when statistics favor it and the join
        // type permits; otherwise keep the current build side (next arm).
        (true, true)
            if should_swap_join_order(&**left, &**right)
                && supports_swap(*hash_join.join_type()) =>
        {
            Ok(Some(swap_hash_join(hash_join, PartitionMode::CollectLeft)?))
        }
        // The left side qualifies and either the right does not or no swap
        // is warranted: rebuild the same join in CollectLeft mode. This arm
        // deduplicates the identical construction the original `(true, true)`
        // and `(true, false)` branches both contained.
        (true, _) => Ok(Some(Arc::new(HashJoinExec::try_new(
            Arc::clone(left),
            Arc::clone(right),
            hash_join.on().to_vec(),
            hash_join.filter().cloned(),
            hash_join.join_type(),
            PartitionMode::CollectLeft,
            hash_join.null_equals_null(),
        )?))),
        // Only the right side qualifies: swapping is mandatory, so give up
        // when the join type cannot be swapped.
        (false, true) if supports_swap(*hash_join.join_type()) => {
            Ok(Some(swap_hash_join(hash_join, PartitionMode::CollectLeft)?))
        }
        // Neither side can be collected (or a required swap is unsupported).
        _ => Ok(None),
    }
}
/// Returns a `Partitioned`-mode version of `hash_join`, swapping its inputs
/// first when statistics indicate the right side is smaller and the join
/// type allows swapping.
fn partitioned_hash_join(hash_join: &HashJoinExec) -> Result<Arc<dyn ExecutionPlan>> {
    let left = hash_join.left();
    let right = hash_join.right();

    // Swap path: smaller right side becomes the new build (left) side.
    if should_swap_join_order(&**left, &**right) && supports_swap(*hash_join.join_type())
    {
        return swap_hash_join(hash_join, PartitionMode::Partitioned);
    }

    // No swap: rebuild the same join with an explicit Partitioned mode.
    Ok(Arc::new(HashJoinExec::try_new(
        Arc::clone(left),
        Arc::clone(right),
        hash_join.on().to_vec(),
        hash_join.filter().cloned(),
        hash_join.join_type(),
        PartitionMode::Partitioned,
        hash_join.null_equals_null(),
    )?))
}
#[cfg(test)]
mod tests {
    // These tests drive `JoinSelection` with `StatisticsExec`, a test-only
    // plan node that reports fixed statistics without executing anything.
    use crate::{
        physical_plan::{
            displayable, joins::PartitionMode, ColumnStatistics, Statistics,
        },
        test::exec::StatisticsExec,
    };

    use super::*;
    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::ScalarValue;

    /// Returns `(big, small)` single-column plans: `big` has the larger byte
    /// size while `small` has the larger row count, so the byte-size
    /// statistic is what drives the swap decision.
    fn create_big_and_small() -> (Arc<dyn ExecutionPlan>, Arc<dyn ExecutionPlan>) {
        let big = Arc::new(StatisticsExec::new(
            Statistics {
                num_rows: Some(10),
                total_byte_size: Some(100000),
                ..Default::default()
            },
            Schema::new(vec![Field::new("big_col", DataType::Int32, false)]),
        ));

        let small = Arc::new(StatisticsExec::new(
            Statistics {
                num_rows: Some(100000),
                total_byte_size: Some(10),
                ..Default::default()
            },
            Schema::new(vec![Field::new("small_col", DataType::Int32, false)]),
        ));
        (big, small)
    }

    /// Builds single-column statistics with the given min/max/distinct values.
    fn create_column_stats(
        min: Option<u64>,
        max: Option<u64>,
        distinct_count: Option<usize>,
    ) -> Option<Vec<ColumnStatistics>> {
        Some(vec![ColumnStatistics {
            distinct_count,
            min_value: min.map(|size| ScalarValue::UInt64(Some(size))),
            max_value: max.map(|size| ScalarValue::UInt64(Some(size))),
            ..Default::default()
        }])
    }

    /// Returns `(big, medium, small)` plans whose row counts and column
    /// statistics exercise nested join reordering.
    fn create_nested_with_min_max() -> (
        Arc<dyn ExecutionPlan>,
        Arc<dyn ExecutionPlan>,
        Arc<dyn ExecutionPlan>,
    ) {
        let big = Arc::new(StatisticsExec::new(
            Statistics {
                num_rows: Some(100_000),
                column_statistics: create_column_stats(
                    Some(0),
                    Some(50_000),
                    Some(50_000),
                ),
                ..Default::default()
            },
            Schema::new(vec![Field::new("big_col", DataType::Int32, false)]),
        ));

        let medium = Arc::new(StatisticsExec::new(
            Statistics {
                num_rows: Some(10_000),
                column_statistics: create_column_stats(
                    Some(1000),
                    Some(5000),
                    Some(1000),
                ),
                ..Default::default()
            },
            Schema::new(vec![Field::new("medium_col", DataType::Int32, false)]),
        ));

        let small = Arc::new(StatisticsExec::new(
            Statistics {
                num_rows: Some(1000),
                column_statistics: create_column_stats(
                    Some(0),
                    Some(100_000),
                    Some(1000),
                ),
                ..Default::default()
            },
            Schema::new(vec![Field::new("small_col", DataType::Int32, false)]),
        ));

        (big, medium, small)
    }

    /// A join with the bigger input on the left should be swapped, with a
    /// projection on top restoring the original column order.
    #[tokio::test]
    async fn test_join_with_swap() {
        let (big, small) = create_big_and_small();

        let join = HashJoinExec::try_new(
            Arc::clone(&big),
            Arc::clone(&small),
            vec![(
                Column::new_with_schema("big_col", &big.schema()).unwrap(),
                Column::new_with_schema("small_col", &small.schema()).unwrap(),
            )],
            None,
            &JoinType::Left,
            PartitionMode::CollectLeft,
            false,
        )
        .unwrap();

        let optimized_join = JoinSelection::new()
            .optimize(Arc::new(join), &ConfigOptions::new())
            .unwrap();

        let swapping_projection = optimized_join
            .as_any()
            .downcast_ref::<ProjectionExec>()
            .expect("A proj is required to swap columns back to their original order");

        // The projection must re-emit the original [big_col, small_col]
        // order, reading from the swapped child schema.
        assert_eq!(swapping_projection.expr().len(), 2);
        let (col, name) = &swapping_projection.expr()[0];
        assert_eq!(name, "big_col");
        assert_col_expr(col, "big_col", 1);
        let (col, name) = &swapping_projection.expr()[1];
        assert_eq!(name, "small_col");
        assert_col_expr(col, "small_col", 0);

        let swapped_join = swapping_projection
            .input()
            .as_any()
            .downcast_ref::<HashJoinExec>()
            .expect("The type of the plan should not be changed");

        // After the swap the small input is the build (left) side.
        assert_eq!(swapped_join.left().statistics().total_byte_size, Some(10));
        assert_eq!(
            swapped_join.right().statistics().total_byte_size,
            Some(100000)
        );
    }

    /// A left join cannot collect its left side, so the optimizer swaps the
    /// inputs (turning it into a right join) to collect the small side.
    #[tokio::test]
    async fn test_left_join_with_swap() {
        let (big, small) = create_big_and_small();
        // Left out join should alway swap when the mode is PartitionMode::CollectLeft, even left side is small and right side is large
        let join = HashJoinExec::try_new(
            Arc::clone(&small),
            Arc::clone(&big),
            vec![(
                Column::new_with_schema("small_col", &small.schema()).unwrap(),
                Column::new_with_schema("big_col", &big.schema()).unwrap(),
            )],
            None,
            &JoinType::Left,
            PartitionMode::CollectLeft,
            false,
        )
        .unwrap();

        let optimized_join = JoinSelection::new()
            .optimize(Arc::new(join), &ConfigOptions::new())
            .unwrap();

        let swapping_projection = optimized_join
            .as_any()
            .downcast_ref::<ProjectionExec>()
            .expect("A proj is required to swap columns back to their original order");

        assert_eq!(swapping_projection.expr().len(), 2);
        println!("swapping_projection {swapping_projection:?}");
        let (col, name) = &swapping_projection.expr()[0];
        assert_eq!(name, "small_col");
        assert_col_expr(col, "small_col", 1);
        let (col, name) = &swapping_projection.expr()[1];
        assert_eq!(name, "big_col");
        assert_col_expr(col, "big_col", 0);

        let swapped_join = swapping_projection
            .input()
            .as_any()
            .downcast_ref::<HashJoinExec>()
            .expect("The type of the plan should not be changed");

        // The big input ends up on the left after the forced swap.
        assert_eq!(
            swapped_join.left().statistics().total_byte_size,
            Some(100000)
        );
        assert_eq!(swapped_join.right().statistics().total_byte_size, Some(10));
    }

    /// Semi/anti joins keep a single side's schema, so a swap needs no
    /// reverting projection on top.
    #[tokio::test]
    async fn test_join_with_swap_semi() {
        let join_types = [JoinType::LeftSemi, JoinType::LeftAnti];
        for join_type in join_types {
            let (big, small) = create_big_and_small();

            let join = HashJoinExec::try_new(
                Arc::clone(&big),
                Arc::clone(&small),
                vec![(
                    Column::new_with_schema("big_col", &big.schema()).unwrap(),
                    Column::new_with_schema("small_col", &small.schema()).unwrap(),
                )],
                None,
                &join_type,
                PartitionMode::Partitioned,
                false,
            )
            .unwrap();

            let original_schema = join.schema();

            let optimized_join = JoinSelection::new()
                .optimize(Arc::new(join), &ConfigOptions::new())
                .unwrap();

            let swapped_join = optimized_join
                .as_any()
                .downcast_ref::<HashJoinExec>()
                .expect(
                    "A proj is not required to swap columns back to their original order",
                );

            assert_eq!(swapped_join.schema().fields().len(), 1);

            assert_eq!(swapped_join.left().statistics().total_byte_size, Some(10));
            assert_eq!(
                swapped_join.right().statistics().total_byte_size,
                Some(100000)
            );
            // The swap must not alter the output schema of the join.
            assert_eq!(original_schema, swapped_join.schema());
        }
    }

    /// Optimizes `$PLAN` and compares its indented display output against
    /// `$EXPECTED_LINES`, line by line.
    macro_rules! assert_optimized {
        ($EXPECTED_LINES: expr, $PLAN: expr) => {
            let expected_lines =
                $EXPECTED_LINES.iter().map(|s| *s).collect::<Vec<&str>>();

            let optimized = JoinSelection::new()
                .optimize(Arc::new($PLAN), &ConfigOptions::new())
                .unwrap();

            let plan = displayable(optimized.as_ref()).indent(true).to_string();
            let actual_lines = plan.split("\n").collect::<Vec<&str>>();

            assert_eq!(
                &expected_lines, &actual_lines,
                "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
                expected_lines, actual_lines
            );
        };
    }

    /// Both the inner join and the outer join should be swapped so the
    /// smaller side becomes the build side at each level.
    #[tokio::test]
    async fn test_nested_join_swap() {
        let (big, medium, small) = create_nested_with_min_max();

        // Form the inner join: big JOIN small
        let child_join = HashJoinExec::try_new(
            Arc::clone(&big),
            Arc::clone(&small),
            vec![(
                Column::new_with_schema("big_col", &big.schema()).unwrap(),
                Column::new_with_schema("small_col", &small.schema()).unwrap(),
            )],
            None,
            &JoinType::Inner,
            PartitionMode::CollectLeft,
            false,
        )
        .unwrap();
        let child_schema = child_join.schema();

        // Form the outer join: medium LEFT JOIN (big JOIN small)
        let join = HashJoinExec::try_new(
            Arc::clone(&medium),
            Arc::new(child_join),
            vec![(
                Column::new_with_schema("medium_col", &medium.schema()).unwrap(),
                Column::new_with_schema("small_col", &child_schema).unwrap(),
            )],
            None,
            &JoinType::Left,
            PartitionMode::CollectLeft,
            false,
        )
        .unwrap();

        let expected = [
            "ProjectionExec: expr=[medium_col@2 as medium_col, big_col@0 as big_col, small_col@1 as small_col]",
            " HashJoinExec: mode=CollectLeft, join_type=Right, on=[(small_col@1, medium_col@0)]",
            " ProjectionExec: expr=[big_col@1 as big_col, small_col@0 as small_col]",
            " HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(small_col@0, big_col@0)]",
            " StatisticsExec: col_count=1, row_count=Some(1000)",
            " StatisticsExec: col_count=1, row_count=Some(100000)",
            " StatisticsExec: col_count=1, row_count=Some(10000)",
            ""
        ];
        assert_optimized!(expected, join);
    }

    /// When the small input is already on the left no swap should happen,
    /// and the plan type stays a plain `HashJoinExec`.
    #[tokio::test]
    async fn test_join_no_swap() {
        let (big, small) = create_big_and_small();
        let join = HashJoinExec::try_new(
            Arc::clone(&small),
            Arc::clone(&big),
            vec![(
                Column::new_with_schema("small_col", &small.schema()).unwrap(),
                Column::new_with_schema("big_col", &big.schema()).unwrap(),
            )],
            None,
            &JoinType::Inner,
            PartitionMode::CollectLeft,
            false,
        )
        .unwrap();

        let optimized_join = JoinSelection::new()
            .optimize(Arc::new(join), &ConfigOptions::new())
            .unwrap();

        let swapped_join = optimized_join
            .as_any()
            .downcast_ref::<HashJoinExec>()
            .expect("The type of the plan should not be changed");

        assert_eq!(swapped_join.left().statistics().total_byte_size, Some(10));
        assert_eq!(
            swapped_join.right().statistics().total_byte_size,
            Some(100000)
        );
    }

    /// The reverting projection must emit left columns (offset by the
    /// right-side width) followed by right columns.
    #[tokio::test]
    async fn test_swap_reverting_projection() {
        let left_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]);

        let right_schema = Schema::new(vec![Field::new("c", DataType::Int32, false)]);

        let proj = swap_reverting_projection(&left_schema, &right_schema);

        assert_eq!(proj.len(), 3);

        let (col, name) = &proj[0];
        assert_eq!(name, "a");
        assert_col_expr(col, "a", 1);

        let (col, name) = &proj[1];
        assert_eq!(name, "b");
        assert_col_expr(col, "b", 2);

        let (col, name) = &proj[2];
        assert_eq!(name, "c");
        assert_col_expr(col, "c", 0);
    }

    /// Asserts that `expr` is a `Column` with the given name and index.
    fn assert_col_expr(expr: &Arc<dyn PhysicalExpr>, name: &str, index: usize) {
        let col = expr
            .as_any()
            .downcast_ref::<Column>()
            .expect("Projection items should be Column expression");
        assert_eq!(col.name(), name);
        assert_eq!(col.index(), index);
    }

    /// `PartitionMode::Auto` should resolve to CollectLeft when either side
    /// (or a side with unknown statistics) is small enough, swapping the
    /// inputs when the small side starts on the right.
    #[tokio::test]
    async fn test_join_selection_collect_left() {
        let big = Arc::new(StatisticsExec::new(
            Statistics {
                num_rows: Some(10000000),
                total_byte_size: Some(10000000),
                ..Default::default()
            },
            Schema::new(vec![Field::new("big_col", DataType::Int32, false)]),
        ));

        let small = Arc::new(StatisticsExec::new(
            Statistics {
                num_rows: Some(10),
                total_byte_size: Some(10),
                ..Default::default()
            },
            Schema::new(vec![Field::new("small_col", DataType::Int32, false)]),
        ));

        let empty = Arc::new(StatisticsExec::new(
            Statistics {
                num_rows: None,
                total_byte_size: None,
                ..Default::default()
            },
            Schema::new(vec![Field::new("empty_col", DataType::Int32, false)]),
        ));

        // Small already on the left: no swap needed.
        let join_on = vec![(
            Column::new_with_schema("small_col", &small.schema()).unwrap(),
            Column::new_with_schema("big_col", &big.schema()).unwrap(),
        )];
        check_join_partition_mode(
            small.clone(),
            big.clone(),
            join_on,
            false,
            PartitionMode::CollectLeft,
        );

        // Small on the right: swap so it becomes the build side.
        let join_on = vec![(
            Column::new_with_schema("big_col", &big.schema()).unwrap(),
            Column::new_with_schema("small_col", &small.schema()).unwrap(),
        )];
        check_join_partition_mode(
            big,
            small.clone(),
            join_on,
            true,
            PartitionMode::CollectLeft,
        );

        // Unknown statistics on the right, small left: keep as-is.
        let join_on = vec![(
            Column::new_with_schema("small_col", &small.schema()).unwrap(),
            Column::new_with_schema("empty_col", &empty.schema()).unwrap(),
        )];
        check_join_partition_mode(
            small.clone(),
            empty.clone(),
            join_on,
            false,
            PartitionMode::CollectLeft,
        );

        // Unknown statistics on the left, small right: swap.
        let join_on = vec![(
            Column::new_with_schema("empty_col", &empty.schema()).unwrap(),
            Column::new_with_schema("small_col", &small.schema()).unwrap(),
        )];
        check_join_partition_mode(
            empty,
            small,
            join_on,
            true,
            PartitionMode::CollectLeft,
        );
    }

    /// `PartitionMode::Auto` should resolve to Partitioned when both sides
    /// are too large to collect, swapping only when the right is smaller.
    #[tokio::test]
    async fn test_join_selection_partitioned() {
        let big1 = Arc::new(StatisticsExec::new(
            Statistics {
                num_rows: Some(10000000),
                total_byte_size: Some(10000000),
                ..Default::default()
            },
            Schema::new(vec![Field::new("big_col1", DataType::Int32, false)]),
        ));

        let big2 = Arc::new(StatisticsExec::new(
            Statistics {
                num_rows: Some(20000000),
                total_byte_size: Some(20000000),
                ..Default::default()
            },
            Schema::new(vec![Field::new("big_col2", DataType::Int32, false)]),
        ));

        let empty = Arc::new(StatisticsExec::new(
            Statistics {
                num_rows: None,
                total_byte_size: None,
                ..Default::default()
            },
            Schema::new(vec![Field::new("empty_col", DataType::Int32, false)]),
        ));

        // Smaller big1 already on the left: no swap.
        let join_on = vec![(
            Column::new_with_schema("big_col1", &big1.schema()).unwrap(),
            Column::new_with_schema("big_col2", &big2.schema()).unwrap(),
        )];
        check_join_partition_mode(
            big1.clone(),
            big2.clone(),
            join_on,
            false,
            PartitionMode::Partitioned,
        );

        // Smaller big1 on the right: swap.
        let join_on = vec![(
            Column::new_with_schema("big_col2", &big2.schema()).unwrap(),
            Column::new_with_schema("big_col1", &big1.schema()).unwrap(),
        )];
        check_join_partition_mode(
            big2,
            big1.clone(),
            join_on,
            true,
            PartitionMode::Partitioned,
        );

        // Unknown statistics never trigger a swap in Partitioned mode.
        let join_on = vec![(
            Column::new_with_schema("empty_col", &empty.schema()).unwrap(),
            Column::new_with_schema("big_col1", &big1.schema()).unwrap(),
        )];
        check_join_partition_mode(
            empty.clone(),
            big1.clone(),
            join_on,
            false,
            PartitionMode::Partitioned,
        );

        let join_on = vec![(
            Column::new_with_schema("big_col1", &big1.schema()).unwrap(),
            Column::new_with_schema("empty_col", &empty.schema()).unwrap(),
        )];
        check_join_partition_mode(
            big1,
            empty,
            join_on,
            false,
            PartitionMode::Partitioned,
        );
    }

    /// Optimizes an `Auto`-mode join of `left`/`right` and asserts the
    /// resulting partition mode, expecting a reverting projection on top
    /// iff `is_swapped`.
    fn check_join_partition_mode(
        left: Arc<StatisticsExec>,
        right: Arc<StatisticsExec>,
        on: Vec<(Column, Column)>,
        is_swapped: bool,
        expected_mode: PartitionMode,
    ) {
        let join = HashJoinExec::try_new(
            left,
            right,
            on,
            None,
            &JoinType::Inner,
            PartitionMode::Auto,
            false,
        )
        .unwrap();

        let optimized_join = JoinSelection::new()
            .optimize(Arc::new(join), &ConfigOptions::new())
            .unwrap();

        if !is_swapped {
            let swapped_join = optimized_join
                .as_any()
                .downcast_ref::<HashJoinExec>()
                .expect("The type of the plan should not be changed");
            assert_eq!(*swapped_join.partition_mode(), expected_mode);
        } else {
            let swapping_projection = optimized_join
                .as_any()
                .downcast_ref::<ProjectionExec>()
                .expect(
                    "A proj is required to swap columns back to their original order",
                );
            let swapped_join = swapping_projection
                .input()
                .as_any()
                .downcast_ref::<HashJoinExec>()
                .expect("The type of the plan should not be changed");

            assert_eq!(*swapped_join.partition_mode(), expected_mode);
        }
    }
}