use std::sync::Arc;
use crate::config::ConfigOptions;
use crate::error::Result;
use crate::physical_optimizer::pipeline_checker::PipelineStatePropagator;
use crate::physical_optimizer::PhysicalOptimizerRule;
use crate::physical_plan::joins::utils::{ColumnIndex, JoinFilter};
use crate::physical_plan::joins::{
CrossJoinExec, HashJoinExec, PartitionMode, StreamJoinPartitionMode,
SymmetricHashJoinExec,
};
use crate::physical_plan::projection::ProjectionExec;
use crate::physical_plan::ExecutionPlan;
use arrow_schema::Schema;
use datafusion_common::internal_err;
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{DataFusionError, JoinType};
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::PhysicalExpr;
/// Physical optimizer rule that rewrites join operators.
///
/// It adapts [`HashJoinExec`] nodes to unbounded inputs, chooses a
/// [`PartitionMode`] for hash joins from input statistics, and swaps the
/// inputs of hash and cross joins so that the smaller input is on the left.
#[derive(Default)]
pub struct JoinSelection {}

impl JoinSelection {
    /// Creates a new [`JoinSelection`] rule.
    pub fn new() -> Self {
        Self::default()
    }
}
/// Decides whether the inputs of a join should be swapped so that the
/// smaller input ends up on the left side.
///
/// When both inputs report `total_byte_size`, that statistic decides.
/// Otherwise, when both report `num_rows`, row counts decide. If neither
/// statistic is available on both sides, no swap is suggested.
///
/// # Errors
/// Propagates any error returned by [`ExecutionPlan::statistics`].
fn should_swap_join_order(
    left: &dyn ExecutionPlan,
    right: &dyn ExecutionPlan,
) -> Result<bool> {
    let (left_stats, right_stats) = (left.statistics()?, right.statistics()?);

    let byte_sizes = left_stats
        .total_byte_size
        .get_value()
        .zip(right_stats.total_byte_size.get_value());
    if let Some((left_bytes, right_bytes)) = byte_sizes {
        return Ok(left_bytes > right_bytes);
    }

    // Byte sizes are incomplete: fall back to row counts.
    let row_counts = left_stats
        .num_rows
        .get_value()
        .zip(right_stats.num_rows.get_value());
    Ok(matches!(row_counts, Some((left_rows, right_rows)) if left_rows > right_rows))
}
/// Returns `true` when `plan` is considered small enough to be collected,
/// judged against `collection_size_threshold`.
///
/// `total_byte_size` is consulted first, with `num_rows` as the fallback.
/// The plan is rejected when:
/// - fetching statistics fails,
/// - neither statistic is present, or
/// - the chosen statistic is zero.
fn supports_collect_by_size(
    plan: &dyn ExecutionPlan,
    collection_size_threshold: usize,
) -> bool {
    let stats = match plan.statistics() {
        Ok(stats) => stats,
        Err(_) => return false,
    };
    // `num_rows` is only looked at when the byte size is unknown.
    let estimate = stats
        .total_byte_size
        .get_value()
        .or_else(|| stats.num_rows.get_value());
    // A zero estimate is rejected rather than treated as "tiny".
    matches!(estimate, Some(&value) if value != 0 && value < collection_size_threshold)
}
/// Returns `true` if the inputs of a join of type `join_type` may be
/// swapped. `swap_join_type` gives the mirrored join type.
fn supports_swap(join_type: JoinType) -> bool {
    matches!(
        join_type,
        JoinType::Inner
            | JoinType::Full
            | JoinType::Left
            | JoinType::Right
            | JoinType::LeftSemi
            | JoinType::RightSemi
            | JoinType::LeftAnti
            | JoinType::RightAnti
    )
}
/// Returns the join type equivalent to `join_type` once its left and right
/// inputs are exchanged.
fn swap_join_type(join_type: JoinType) -> JoinType {
    match join_type {
        JoinType::Left => JoinType::Right,
        JoinType::Right => JoinType::Left,
        JoinType::LeftSemi => JoinType::RightSemi,
        JoinType::RightSemi => JoinType::LeftSemi,
        JoinType::LeftAnti => JoinType::RightAnti,
        JoinType::RightAnti => JoinType::LeftAnti,
        // These join types treat both sides the same way.
        symmetric @ (JoinType::Inner | JoinType::Full) => symmetric,
    }
}
/// Rebuilds `hash_join` with its inputs exchanged, using `partition_mode`.
///
/// The rebuilt join has:
/// - each join key pair flipped,
/// - the filter's column sides negated, and
/// - the join type mirrored.
///
/// Semi and anti joins are returned as-is. Every other join type is wrapped
/// in a [`ProjectionExec`] that restores the original column order.
fn swap_hash_join(
    hash_join: &HashJoinExec,
    partition_mode: PartitionMode,
) -> Result<Arc<dyn ExecutionPlan>> {
    let (left, right) = (hash_join.left(), hash_join.right());
    let flipped_on = hash_join
        .on()
        .iter()
        .map(|(left_key, right_key)| (right_key.clone(), left_key.clone()))
        .collect();
    let swapped: Arc<dyn ExecutionPlan> = Arc::new(HashJoinExec::try_new(
        Arc::clone(right),
        Arc::clone(left),
        flipped_on,
        swap_join_filter(hash_join.filter()),
        &swap_join_type(*hash_join.join_type()),
        partition_mode,
        hash_join.null_equals_null(),
    )?);

    let needs_no_projection = matches!(
        hash_join.join_type(),
        JoinType::LeftSemi | JoinType::RightSemi | JoinType::LeftAnti | JoinType::RightAnti
    );
    if needs_no_projection {
        return Ok(swapped);
    }

    let reverting = ProjectionExec::try_new(
        swap_reverting_projection(&left.schema(), &right.schema()),
        swapped,
    )?;
    Ok(Arc::new(reverting))
}
/// Builds projection expressions that map the output of a swapped join
/// (right columns first, then left columns) back to the original
/// left-then-right column order.
fn swap_reverting_projection(
    left_schema: &Schema,
    right_schema: &Schema,
) -> Vec<(Arc<dyn PhysicalExpr>, String)> {
    // In the swapped output, right columns occupy `0..right_len` and the left
    // columns follow them.
    let right_len = right_schema.fields().len();
    let column_expr = |name: &str, index: usize| -> (Arc<dyn PhysicalExpr>, String) {
        (
            Arc::new(Column::new(name, index)) as Arc<dyn PhysicalExpr>,
            name.to_owned(),
        )
    };

    let mut exprs = Vec::with_capacity(left_schema.fields().len() + right_len);
    exprs.extend(
        left_schema
            .fields()
            .iter()
            .enumerate()
            .map(|(offset, field)| column_expr(field.name().as_str(), right_len + offset)),
    );
    exprs.extend(
        right_schema
            .fields()
            .iter()
            .enumerate()
            .map(|(index, field)| column_expr(field.name().as_str(), index)),
    );
    exprs
}
fn swap_filter(filter: &JoinFilter) -> JoinFilter {
let column_indices = filter
.column_indices()
.iter()
.map(|idx| ColumnIndex {
index: idx.index,
side: idx.side.negate(),
})
.collect();
JoinFilter::new(
filter.expression().clone(),
column_indices,
filter.schema().clone(),
)
}
fn swap_join_filter(filter: Option<&JoinFilter>) -> Option<JoinFilter> {
filter.map(swap_filter)
}
impl PhysicalOptimizerRule for JoinSelection {
    /// Optimizes `plan` in two bottom-up passes.
    ///
    /// 1. Pipeline-fixing subrules adapt hash joins to unbounded inputs.
    /// 2. Statistics-based selection picks hash-join partition modes and swaps
    ///    hash/cross join inputs.
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        config: &ConfigOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let propagator = PipelineStatePropagator::new(plan);
        // Applied in this order at every node by `apply_subrules`.
        let subrules: Vec<Box<PipelineFixerSubrule>> = vec![
            Box::new(hash_join_convert_symmetric_subrule),
            Box::new(hash_join_swap_subrule),
        ];
        let fixed =
            propagator.transform_up(&|node| apply_subrules(node, &subrules, config))?;

        let threshold = config.optimizer.hash_join_single_partition_threshold;
        fixed
            .plan
            .transform_up(&|node| statistical_join_selection_subrule(node, threshold))
    }

    /// Name of this rule.
    fn name(&self) -> &str {
        "join_selection"
    }

    /// Opts in to the optimizer's schema check for this rule (presumably a
    /// check that the output schema is preserved; verify against the trait).
    fn schema_check(&self) -> bool {
        true
    }
}
/// Tries to turn `hash_join` into a [`PartitionMode::CollectLeft`] join.
///
/// A side is eligible for collection unless the join type forbids it:
/// - the left side cannot be collected for Left, Full and LeftAnti joins;
/// - the right side cannot be collected for Right, Full and RightAnti joins.
///
/// When `collect_threshold` is `Some`, an eligible side must also pass
/// `supports_collect_by_size`. If both sides qualify, the inputs are swapped
/// when `should_swap_join_order` and `supports_swap` agree. If only the right
/// side qualifies, the join is swapped (when supported) so that side becomes
/// the left one.
///
/// Returns `Ok(None)` when no CollectLeft plan can be formed.
fn try_collect_left(
    hash_join: &HashJoinExec,
    collect_threshold: Option<usize>,
) -> Result<Option<Arc<dyn ExecutionPlan>>> {
    let left = hash_join.left();
    let right = hash_join.right();
    let join_type = *hash_join.join_type();

    let fits_threshold = |side: &Arc<dyn ExecutionPlan>| {
        collect_threshold.map_or(true, |threshold| {
            supports_collect_by_size(&**side, threshold)
        })
    };
    let left_can_collect = !matches!(
        join_type,
        JoinType::Left | JoinType::Full | JoinType::LeftAnti
    ) && fits_threshold(left);
    let right_can_collect = !matches!(
        join_type,
        JoinType::Right | JoinType::Full | JoinType::RightAnti
    ) && fits_threshold(right);

    // CollectLeft join with the original input order.
    let keep_sides = || -> Result<Option<Arc<dyn ExecutionPlan>>> {
        let join: Arc<dyn ExecutionPlan> = Arc::new(HashJoinExec::try_new(
            Arc::clone(left),
            Arc::clone(right),
            hash_join.on().to_vec(),
            hash_join.filter().cloned(),
            hash_join.join_type(),
            PartitionMode::CollectLeft,
            hash_join.null_equals_null(),
        )?);
        Ok(Some(join))
    };
    // CollectLeft join with the inputs exchanged.
    let swap_sides = || swap_hash_join(hash_join, PartitionMode::CollectLeft).map(Some);

    match (left_can_collect, right_can_collect) {
        (true, true) => {
            if should_swap_join_order(&**left, &**right)? && supports_swap(join_type) {
                swap_sides()
            } else {
                keep_sides()
            }
        }
        (true, false) => keep_sides(),
        (false, true) if supports_swap(join_type) => swap_sides(),
        (false, _) => Ok(None),
    }
}
/// Rebuilds `hash_join` in [`PartitionMode::Partitioned`] mode.
///
/// The inputs are swapped when `should_swap_join_order` suggests it and the
/// join type supports swapping.
fn partitioned_hash_join(hash_join: &HashJoinExec) -> Result<Arc<dyn ExecutionPlan>> {
    let (left, right) = (hash_join.left(), hash_join.right());
    let swap = should_swap_join_order(&**left, &**right)?
        && supports_swap(*hash_join.join_type());
    if swap {
        return swap_hash_join(hash_join, PartitionMode::Partitioned);
    }
    let join = HashJoinExec::try_new(
        Arc::clone(left),
        Arc::clone(right),
        hash_join.on().to_vec(),
        hash_join.filter().cloned(),
        hash_join.join_type(),
        PartitionMode::Partitioned,
        hash_join.null_equals_null(),
    )?;
    Ok(Arc::new(join))
}
/// Statistics-driven rewrite of a single plan node.
///
/// Hash joins are handled by partition mode:
/// - `Auto`: try CollectLeft (bounded by `collect_left_threshold`), falling
///   back to a partitioned join.
/// - `CollectLeft`: try CollectLeft without a threshold, falling back to a
///   partitioned join.
/// - `Partitioned`: only swap the inputs when statistics suggest it.
///
/// Cross joins are swapped (behind a reverting projection) when statistics
/// suggest it. Every other node is returned unchanged.
fn statistical_join_selection_subrule(
    plan: Arc<dyn ExecutionPlan>,
    collect_left_threshold: usize,
) -> Result<Transformed<Arc<dyn ExecutionPlan>>> {
    let any_plan = plan.as_any();
    let replacement: Option<Arc<dyn ExecutionPlan>> =
        if let Some(hash_join) = any_plan.downcast_ref::<HashJoinExec>() {
            let collect_left_or_partitioned =
                |threshold: Option<usize>| -> Result<Arc<dyn ExecutionPlan>> {
                    match try_collect_left(hash_join, threshold)? {
                        Some(collected) => Ok(collected),
                        None => partitioned_hash_join(hash_join),
                    }
                };
            match hash_join.partition_mode() {
                PartitionMode::Auto => {
                    Some(collect_left_or_partitioned(Some(collect_left_threshold))?)
                }
                PartitionMode::CollectLeft => Some(collect_left_or_partitioned(None)?),
                PartitionMode::Partitioned => {
                    let (left, right) = (hash_join.left(), hash_join.right());
                    if should_swap_join_order(&**left, &**right)?
                        && supports_swap(*hash_join.join_type())
                    {
                        Some(swap_hash_join(hash_join, PartitionMode::Partitioned)?)
                    } else {
                        None
                    }
                }
            }
        } else if let Some(cross_join) = any_plan.downcast_ref::<CrossJoinExec>() {
            let (left, right) = (cross_join.left(), cross_join.right());
            if should_swap_join_order(&**left, &**right)? {
                let swapped =
                    Arc::new(CrossJoinExec::new(Arc::clone(right), Arc::clone(left)));
                let reverting: Arc<dyn ExecutionPlan> = Arc::new(ProjectionExec::try_new(
                    swap_reverting_projection(&left.schema(), &right.schema()),
                    swapped,
                )?);
                Some(reverting)
            } else {
                None
            }
        } else {
            None
        };

    Ok(match replacement {
        Some(new_plan) => Transformed::Yes(new_plan),
        None => Transformed::No(plan),
    })
}
/// Signature of a pipeline-fixing subrule, applied to every node by
/// `apply_subrules`.
///
/// A subrule returns `None` when it does not apply to the given node (its
/// argument is then ignored). Otherwise it returns `Some` with either the
/// (possibly rewritten) node state or an error.
pub type PipelineFixerSubrule = dyn Fn(
    PipelineStatePropagator,
    &ConfigOptions,
) -> Option<Result<PipelineStatePropagator>>;
/// Replaces a [`HashJoinExec`] whose two inputs are both unbounded with a
/// [`SymmetricHashJoinExec`].
///
/// Returns `None` for any node that is not a hash join. For a hash join it:
/// - sets `input.unbounded` to whether either child is unbounded;
/// - returns the (possibly rewritten) state.
///
/// The symmetric join uses [`StreamJoinPartitionMode::Partitioned`] when
/// `optimizer.repartition_joins` is enabled, and
/// [`StreamJoinPartitionMode::SinglePartition`] otherwise.
fn hash_join_convert_symmetric_subrule(
    mut input: PipelineStatePropagator,
    config_options: &ConfigOptions,
) -> Option<Result<PipelineStatePropagator>> {
    let hash_join = input.plan.as_any().downcast_ref::<HashJoinExec>()?;
    let left_unbounded = input.children_unbounded[0];
    let right_unbounded = input.children_unbounded[1];
    input.unbounded = left_unbounded || right_unbounded;

    if !(left_unbounded && right_unbounded) {
        return Some(Ok(input));
    }

    let stream_mode = if config_options.optimizer.repartition_joins {
        StreamJoinPartitionMode::Partitioned
    } else {
        StreamJoinPartitionMode::SinglePartition
    };
    let symmetric = SymmetricHashJoinExec::try_new(
        Arc::clone(hash_join.left()),
        Arc::clone(hash_join.right()),
        hash_join.on().to_vec(),
        hash_join.filter().cloned(),
        hash_join.join_type(),
        hash_join.null_equals_null(),
        stream_mode,
    );
    Some(symmetric.map(|exec| {
        input.plan = Arc::new(exec) as _;
        input
    }))
}
/// Swaps the inputs of a [`HashJoinExec`] whose left input is unbounded and
/// whose right input is bounded.
///
/// Only Inner, Left, LeftSemi and LeftAnti joins are swapped.
///
/// Returns `None` for any node that is not a hash join. For a hash join it:
/// - sets `input.unbounded` to whether either child is unbounded;
/// - returns the (possibly rewritten) state.
fn hash_join_swap_subrule(
    mut input: PipelineStatePropagator,
    _config_options: &ConfigOptions,
) -> Option<Result<PipelineStatePropagator>> {
    let hash_join = input.plan.as_any().downcast_ref::<HashJoinExec>()?;
    let (left_unbounded, right_unbounded) =
        (input.children_unbounded[0], input.children_unbounded[1]);
    input.unbounded = left_unbounded || right_unbounded;

    let swappable_type = matches!(
        *hash_join.join_type(),
        JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti
    );
    if !left_unbounded || right_unbounded || !swappable_type {
        return Some(Ok(input));
    }

    Some(
        swap_join_according_to_unboundedness(hash_join).map(|swapped| {
            input.plan = swapped;
            input
        }),
    )
}
/// Swaps the inputs of `hash_join`, keeping its partition mode, so that an
/// unbounded left input moves to the right.
///
/// # Errors
/// Returns an internal error in two cases:
/// - Right, RightSemi, RightAnti and Full joins, which cannot be swapped here
///   (checked first, regardless of mode);
/// - joins still in [`PartitionMode::Auto`].
fn swap_join_according_to_unboundedness(
    hash_join: &HashJoinExec,
) -> Result<Arc<dyn ExecutionPlan>> {
    let join_type = hash_join.join_type();
    if matches!(
        join_type,
        JoinType::Right | JoinType::RightSemi | JoinType::RightAnti | JoinType::Full
    ) {
        return internal_err!("{join_type} join cannot be swapped for unbounded input.");
    }
    match hash_join.partition_mode() {
        mode @ (PartitionMode::Partitioned | PartitionMode::CollectLeft) => {
            swap_hash_join(hash_join, *mode)
        }
        PartitionMode::Auto => {
            internal_err!("Auto is not acceptable for unbounded input here.")
        }
    }
}
/// Applies each pipeline-fixing subrule to `input` in order, then recomputes
/// `input.unbounded` from the resulting plan.
///
/// Subrules that return `None` leave the state untouched. The first subrule
/// error is propagated. The node is always reported as transformed.
///
/// `subrules` is a slice (rather than `&Vec`) so any contiguous collection of
/// subrules can be passed; existing `&Vec` arguments coerce automatically.
fn apply_subrules(
    mut input: PipelineStatePropagator,
    subrules: &[Box<PipelineFixerSubrule>],
    config_options: &ConfigOptions,
) -> Result<Transformed<PipelineStatePropagator>> {
    for subrule in subrules {
        // Subrules consume their argument, so each receives a clone. The
        // result is adopted only when the subrule applied.
        if let Some(value) = subrule(input.clone(), config_options).transpose()? {
            input = value;
        }
    }
    // An error from `unbounded_output` is not propagated here. The output is
    // conservatively treated as unbounded instead.
    input.unbounded = input
        .plan
        .unbounded_output(&input.children_unbounded)
        .unwrap_or(true);
    Ok(Transformed::Yes(input))
}
#[cfg(test)]
mod tests_statistical {
use std::sync::Arc;
use super::*;
use crate::{
physical_plan::{
displayable, joins::PartitionMode, ColumnStatistics, Statistics,
},
test::StatisticsExec,
};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{stats::Precision, JoinType, ScalarValue};
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::PhysicalExpr;
/// Builds two single-column sources whose statistics deliberately disagree.
///
/// - `big` has the larger `total_byte_size` but the smaller `num_rows`.
/// - `small` has the opposite.
///
/// Join decisions therefore depend on which statistic is consulted.
fn create_big_and_small() -> (Arc<dyn ExecutionPlan>, Arc<dyn ExecutionPlan>) {
    let source =
        |num_rows: usize, total_byte_size: usize, column: &str| -> Arc<dyn ExecutionPlan> {
            Arc::new(StatisticsExec::new(
                Statistics {
                    num_rows: Precision::Inexact(num_rows),
                    total_byte_size: Precision::Inexact(total_byte_size),
                    column_statistics: vec![ColumnStatistics::new_unknown()],
                },
                Schema::new(vec![Field::new(column, DataType::Int32, false)]),
            ))
        };
    let big = source(10, 100000, "big_col");
    let small = source(100000, 10, "small_col");
    (big, small)
}
/// Builds statistics for a single `UInt64` column.
///
/// Every `Some` argument becomes an inexact value and every `None` becomes
/// `Precision::Absent`. The remaining fields use their defaults.
fn create_column_stats(
    min: Option<u64>,
    max: Option<u64>,
    distinct_count: Option<usize>,
) -> Vec<ColumnStatistics> {
    let as_bound = |bound: Option<u64>| {
        bound.map_or(Precision::Absent, |value| {
            Precision::Inexact(ScalarValue::UInt64(Some(value)))
        })
    };
    vec![ColumnStatistics {
        distinct_count: distinct_count.map_or(Precision::Absent, Precision::Inexact),
        min_value: as_bound(min),
        max_value: as_bound(max),
        ..Default::default()
    }]
}
fn create_nested_with_min_max() -> (
Arc<dyn ExecutionPlan>,
Arc<dyn ExecutionPlan>,
Arc<dyn ExecutionPlan>,
) {
let big = Arc::new(StatisticsExec::new(
Statistics {
num_rows: Precision::Inexact(100_000),
column_statistics: create_column_stats(
Some(0),
Some(50_000),
Some(50_000),
),
total_byte_size: Precision::Absent,
},
Schema::new(vec![Field::new("big_col", DataType::Int32, false)]),
));
let medium = Arc::new(StatisticsExec::new(
Statistics {
num_rows: Precision::Inexact(10_000),
column_statistics: create_column_stats(
Some(1000),
Some(5000),
Some(1000),
),
total_byte_size: Precision::Absent,
},
Schema::new(vec![Field::new("medium_col", DataType::Int32, false)]),
));
let small = Arc::new(StatisticsExec::new(
Statistics {
num_rows: Precision::Inexact(1000),
column_statistics: create_column_stats(
Some(0),
Some(100_000),
Some(1000),
),
total_byte_size: Precision::Absent,
},
Schema::new(vec![Field::new("small_col", DataType::Int32, false)]),
));
(big, medium, small)
}
#[tokio::test]
async fn test_join_with_swap() {
let (big, small) = create_big_and_small();
let join = HashJoinExec::try_new(
Arc::clone(&big),
Arc::clone(&small),
vec![(
Column::new_with_schema("big_col", &big.schema()).unwrap(),
Column::new_with_schema("small_col", &small.schema()).unwrap(),
)],
None,
&JoinType::Left,
PartitionMode::CollectLeft,
false,
)
.unwrap();
let optimized_join = JoinSelection::new()
.optimize(Arc::new(join), &ConfigOptions::new())
.unwrap();
let swapping_projection = optimized_join
.as_any()
.downcast_ref::<ProjectionExec>()
.expect("A proj is required to swap columns back to their original order");
assert_eq!(swapping_projection.expr().len(), 2);
let (col, name) = &swapping_projection.expr()[0];
assert_eq!(name, "big_col");
assert_col_expr(col, "big_col", 1);
let (col, name) = &swapping_projection.expr()[1];
assert_eq!(name, "small_col");
assert_col_expr(col, "small_col", 0);
let swapped_join = swapping_projection
.input()
.as_any()
.downcast_ref::<HashJoinExec>()
.expect("The type of the plan should not be changed");
assert_eq!(
swapped_join.left().statistics().unwrap().total_byte_size,
Precision::Inexact(10)
);
assert_eq!(
swapped_join.right().statistics().unwrap().total_byte_size,
Precision::Inexact(100000)
);
}
#[tokio::test]
async fn test_left_join_with_swap() {
let (big, small) = create_big_and_small();
let join = HashJoinExec::try_new(
Arc::clone(&small),
Arc::clone(&big),
vec![(
Column::new_with_schema("small_col", &small.schema()).unwrap(),
Column::new_with_schema("big_col", &big.schema()).unwrap(),
)],
None,
&JoinType::Left,
PartitionMode::CollectLeft,
false,
)
.unwrap();
let optimized_join = JoinSelection::new()
.optimize(Arc::new(join), &ConfigOptions::new())
.unwrap();
let swapping_projection = optimized_join
.as_any()
.downcast_ref::<ProjectionExec>()
.expect("A proj is required to swap columns back to their original order");
assert_eq!(swapping_projection.expr().len(), 2);
let (col, name) = &swapping_projection.expr()[0];
assert_eq!(name, "small_col");
assert_col_expr(col, "small_col", 1);
let (col, name) = &swapping_projection.expr()[1];
assert_eq!(name, "big_col");
assert_col_expr(col, "big_col", 0);
let swapped_join = swapping_projection
.input()
.as_any()
.downcast_ref::<HashJoinExec>()
.expect("The type of the plan should not be changed");
assert_eq!(
swapped_join.left().statistics().unwrap().total_byte_size,
Precision::Inexact(100000)
);
assert_eq!(
swapped_join.right().statistics().unwrap().total_byte_size,
Precision::Inexact(10)
);
}
#[tokio::test]
async fn test_join_with_swap_semi() {
let join_types = [JoinType::LeftSemi, JoinType::LeftAnti];
for join_type in join_types {
let (big, small) = create_big_and_small();
let join = HashJoinExec::try_new(
Arc::clone(&big),
Arc::clone(&small),
vec![(
Column::new_with_schema("big_col", &big.schema()).unwrap(),
Column::new_with_schema("small_col", &small.schema()).unwrap(),
)],
None,
&join_type,
PartitionMode::Partitioned,
false,
)
.unwrap();
let original_schema = join.schema();
let optimized_join = JoinSelection::new()
.optimize(Arc::new(join), &ConfigOptions::new())
.unwrap();
let swapped_join = optimized_join
.as_any()
.downcast_ref::<HashJoinExec>()
.expect(
"A proj is not required to swap columns back to their original order",
);
assert_eq!(swapped_join.schema().fields().len(), 1);
assert_eq!(
swapped_join.left().statistics().unwrap().total_byte_size,
Precision::Inexact(10)
);
assert_eq!(
swapped_join.right().statistics().unwrap().total_byte_size,
Precision::Inexact(100000)
);
assert_eq!(original_schema, swapped_join.schema());
}
}
// Runs `JoinSelection` on `$PLAN` and asserts that the indented display of
// the optimized plan matches `$EXPECTED_LINES` line by line.
macro_rules! assert_optimized {
    ($EXPECTED_LINES: expr, $PLAN: expr) => {
        let expected_lines: Vec<&str> = $EXPECTED_LINES.iter().copied().collect();
        let optimized = JoinSelection::new()
            .optimize(Arc::new($PLAN), &ConfigOptions::new())
            .unwrap();
        let plan = displayable(optimized.as_ref()).indent(true).to_string();
        let actual_lines: Vec<&str> = plan.split('\n').collect();
        assert_eq!(
            &expected_lines, &actual_lines,
            "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
            expected_lines, actual_lines
        );
    };
}
/// Nested joins are optimized bottom-up, and each one is swapped
/// independently.
#[tokio::test]
async fn test_nested_join_swap() {
    let (big, medium, small) = create_nested_with_min_max();
    // Inner join in CollectLeft mode: both sides may be collected, and `big`
    // reports more rows than `small`, so this join is swapped.
    let child_join = HashJoinExec::try_new(
        Arc::clone(&big),
        Arc::clone(&small),
        vec![(
            Column::new_with_schema("big_col", &big.schema()).unwrap(),
            Column::new_with_schema("small_col", &small.schema()).unwrap(),
        )],
        None,
        &JoinType::Inner,
        PartitionMode::CollectLeft,
        false,
    )
    .unwrap();
    let child_schema = child_join.schema();
    // Left join in CollectLeft mode: its left side cannot be collected, so
    // it is swapped into a Right join.
    let join = HashJoinExec::try_new(
        Arc::clone(&medium),
        Arc::new(child_join),
        vec![(
            Column::new_with_schema("medium_col", &medium.schema()).unwrap(),
            Column::new_with_schema("small_col", &child_schema).unwrap(),
        )],
        None,
        &JoinType::Left,
        PartitionMode::CollectLeft,
        false,
    )
    .unwrap();
    // The indented display prefixes each nesting level with two more spaces.
    let expected = [
        "ProjectionExec: expr=[medium_col@2 as medium_col, big_col@0 as big_col, small_col@1 as small_col]",
        "  HashJoinExec: mode=CollectLeft, join_type=Right, on=[(small_col@1, medium_col@0)]",
        "    ProjectionExec: expr=[big_col@1 as big_col, small_col@0 as small_col]",
        "      HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(small_col@0, big_col@0)]",
        "        StatisticsExec: col_count=1, row_count=Inexact(1000)",
        "        StatisticsExec: col_count=1, row_count=Inexact(100000)",
        "      StatisticsExec: col_count=1, row_count=Inexact(10000)",
        "",
    ];
    assert_optimized!(expected, join);
}
#[tokio::test]
async fn test_join_no_swap() {
let (big, small) = create_big_and_small();
let join = HashJoinExec::try_new(
Arc::clone(&small),
Arc::clone(&big),
vec![(
Column::new_with_schema("small_col", &small.schema()).unwrap(),
Column::new_with_schema("big_col", &big.schema()).unwrap(),
)],
None,
&JoinType::Inner,
PartitionMode::CollectLeft,
false,
)
.unwrap();
let optimized_join = JoinSelection::new()
.optimize(Arc::new(join), &ConfigOptions::new())
.unwrap();
let swapped_join = optimized_join
.as_any()
.downcast_ref::<HashJoinExec>()
.expect("The type of the plan should not be changed");
assert_eq!(
swapped_join.left().statistics().unwrap().total_byte_size,
Precision::Inexact(10)
);
assert_eq!(
swapped_join.right().statistics().unwrap().total_byte_size,
Precision::Inexact(100000)
);
}
/// The reverting projection lists left columns first. In the swapped join's
/// output, the left columns sit after all right columns.
#[tokio::test]
async fn test_swap_reverting_projection() {
    let left_schema = Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
    ]);
    let right_schema = Schema::new(vec![Field::new("c", DataType::Int32, false)]);
    let proj = swap_reverting_projection(&left_schema, &right_schema);

    let expected = [("a", 1), ("b", 2), ("c", 0)];
    assert_eq!(proj.len(), expected.len());
    for ((col, name), (expected_name, expected_index)) in proj.iter().zip(expected) {
        assert_eq!(name, expected_name);
        assert_col_expr(col, expected_name, expected_index);
    }
}
/// Asserts that `expr` is a [`Column`] with the given `name` and `index`.
fn assert_col_expr(expr: &Arc<dyn PhysicalExpr>, name: &str, index: usize) {
    let column = expr
        .as_any()
        .downcast_ref::<Column>()
        .expect("Projection items should be Column expression");
    assert_eq!((column.name(), column.index()), (name, index));
}
#[tokio::test]
async fn test_join_selection_collect_left() {
let big = Arc::new(StatisticsExec::new(
Statistics {
num_rows: Precision::Inexact(10000000),
total_byte_size: Precision::Inexact(10000000),
column_statistics: vec![ColumnStatistics::new_unknown()],
},
Schema::new(vec![Field::new("big_col", DataType::Int32, false)]),
));
let small = Arc::new(StatisticsExec::new(
Statistics {
num_rows: Precision::Inexact(10),
total_byte_size: Precision::Inexact(10),
column_statistics: vec![ColumnStatistics::new_unknown()],
},
Schema::new(vec![Field::new("small_col", DataType::Int32, false)]),
));
let empty = Arc::new(StatisticsExec::new(
Statistics {
num_rows: Precision::Absent,
total_byte_size: Precision::Absent,
column_statistics: vec![ColumnStatistics::new_unknown()],
},
Schema::new(vec![Field::new("empty_col", DataType::Int32, false)]),
));
let join_on = vec![(
Column::new_with_schema("small_col", &small.schema()).unwrap(),
Column::new_with_schema("big_col", &big.schema()).unwrap(),
)];
check_join_partition_mode(
small.clone(),
big.clone(),
join_on,
false,
PartitionMode::CollectLeft,
);
let join_on = vec![(
Column::new_with_schema("big_col", &big.schema()).unwrap(),
Column::new_with_schema("small_col", &small.schema()).unwrap(),
)];
check_join_partition_mode(
big,
small.clone(),
join_on,
true,
PartitionMode::CollectLeft,
);
let join_on = vec![(
Column::new_with_schema("small_col", &small.schema()).unwrap(),
Column::new_with_schema("empty_col", &empty.schema()).unwrap(),
)];
check_join_partition_mode(
small.clone(),
empty.clone(),
join_on,
false,
PartitionMode::CollectLeft,
);
let join_on = vec![(
Column::new_with_schema("empty_col", &empty.schema()).unwrap(),
Column::new_with_schema("small_col", &small.schema()).unwrap(),
)];
check_join_partition_mode(
empty,
small,
join_on,
true,
PartitionMode::CollectLeft,
);
}
#[tokio::test]
async fn test_join_selection_partitioned() {
let big1 = Arc::new(StatisticsExec::new(
Statistics {
num_rows: Precision::Inexact(10000000),
total_byte_size: Precision::Inexact(10000000),
column_statistics: vec![ColumnStatistics::new_unknown()],
},
Schema::new(vec![Field::new("big_col1", DataType::Int32, false)]),
));
let big2 = Arc::new(StatisticsExec::new(
Statistics {
num_rows: Precision::Inexact(20000000),
total_byte_size: Precision::Inexact(20000000),
column_statistics: vec![ColumnStatistics::new_unknown()],
},
Schema::new(vec![Field::new("big_col2", DataType::Int32, false)]),
));
let empty = Arc::new(StatisticsExec::new(
Statistics {
num_rows: Precision::Absent,
total_byte_size: Precision::Absent,
column_statistics: vec![ColumnStatistics::new_unknown()],
},
Schema::new(vec![Field::new("empty_col", DataType::Int32, false)]),
));
let join_on = vec![(
Column::new_with_schema("big_col1", &big1.schema()).unwrap(),
Column::new_with_schema("big_col2", &big2.schema()).unwrap(),
)];
check_join_partition_mode(
big1.clone(),
big2.clone(),
join_on,
false,
PartitionMode::Partitioned,
);
let join_on = vec![(
Column::new_with_schema("big_col2", &big2.schema()).unwrap(),
Column::new_with_schema("big_col1", &big1.schema()).unwrap(),
)];
check_join_partition_mode(
big2,
big1.clone(),
join_on,
true,
PartitionMode::Partitioned,
);
let join_on = vec![(
Column::new_with_schema("empty_col", &empty.schema()).unwrap(),
Column::new_with_schema("big_col1", &big1.schema()).unwrap(),
)];
check_join_partition_mode(
empty.clone(),
big1.clone(),
join_on,
false,
PartitionMode::Partitioned,
);
let join_on = vec![(
Column::new_with_schema("big_col1", &big1.schema()).unwrap(),
Column::new_with_schema("empty_col", &empty.schema()).unwrap(),
)];
check_join_partition_mode(
big1,
empty,
join_on,
false,
PartitionMode::Partitioned,
);
}
fn check_join_partition_mode(
left: Arc<StatisticsExec>,
right: Arc<StatisticsExec>,
on: Vec<(Column, Column)>,
is_swapped: bool,
expected_mode: PartitionMode,
) {
let join = HashJoinExec::try_new(
left,
right,
on,
None,
&JoinType::Inner,
PartitionMode::Auto,
false,
)
.unwrap();
let optimized_join = JoinSelection::new()
.optimize(Arc::new(join), &ConfigOptions::new())
.unwrap();
if !is_swapped {
let swapped_join = optimized_join
.as_any()
.downcast_ref::<HashJoinExec>()
.expect("The type of the plan should not be changed");
assert_eq!(*swapped_join.partition_mode(), expected_mode);
} else {
let swapping_projection = optimized_join
.as_any()
.downcast_ref::<ProjectionExec>()
.expect(
"A proj is required to swap columns back to their original order",
);
let swapped_join = swapping_projection
.input()
.as_any()
.downcast_ref::<HashJoinExec>()
.expect("The type of the plan should not be changed");
assert_eq!(*swapped_join.partition_mode(), expected_mode);
}
}
}
#[cfg(test)]
mod util_tests {
use std::sync::Arc;
use arrow_schema::{DataType, Field, Schema};
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{BinaryExpr, Column, NegativeExpr};
use datafusion_physical_expr::intervals::utils::check_support;
use datafusion_physical_expr::PhysicalExpr;
/// `check_support` accepts `a + a` and a bare column, and rejects `OR`
/// expressions.
#[test]
fn check_expr_supported() {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, false),
    ]));
    let col_a = || Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>;

    let cases: [(Arc<dyn PhysicalExpr>, bool); 4] = [
        (
            Arc::new(BinaryExpr::new(col_a(), Operator::Plus, col_a()))
                as Arc<dyn PhysicalExpr>,
            true,
        ),
        (col_a(), true),
        (
            Arc::new(BinaryExpr::new(col_a(), Operator::Or, col_a()))
                as Arc<dyn PhysicalExpr>,
            false,
        ),
        (
            Arc::new(BinaryExpr::new(
                col_a(),
                Operator::Or,
                Arc::new(NegativeExpr::new(col_a())),
            )) as Arc<dyn PhysicalExpr>,
            false,
        ),
    ];
    for (expr, supported) in cases {
        assert_eq!(check_support(&expr, &schema), supported);
    }
}
}
#[cfg(test)]
mod hash_join_tests {
use super::*;
use crate::physical_optimizer::join_selection::swap_join_type;
use crate::physical_optimizer::test_utils::SourceType;
use crate::physical_plan::expressions::Column;
use crate::physical_plan::joins::PartitionMode;
use crate::physical_plan::projection::ProjectionExec;
use crate::test_util::UnboundedExec;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_common::utils::DataPtr;
use datafusion_common::JoinType;
use std::sync::Arc;
/// One scenario for the unbounded-input join tests: an initial join setup
/// and the setup expected after optimization.
struct TestCase {
    /// Label identifying the scenario.
    case: String,
    /// Boundedness of the (left, right) inputs before optimization.
    initial_sources_unbounded: (SourceType, SourceType),
    initial_join_type: JoinType,
    initial_mode: PartitionMode,
    /// Expected boundedness of the (left, right) inputs after optimization.
    expected_sources_unbounded: (SourceType, SourceType),
    expected_join_type: JoinType,
    expected_mode: PartitionMode,
    /// Whether the inputs are expected to be swapped. In the visible cases a
    /// swap goes together with flipped sources and a mirrored join type.
    expecting_swap: bool,
}
/// Full joins are never swapped, whichever inputs are unbounded.
#[tokio::test]
async fn test_join_with_swap_full() -> Result<()> {
    let cases = vec![
        TestCase {
            case: "Bounded - Unbounded 1".to_string(),
            initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded),
            initial_join_type: JoinType::Full,
            initial_mode: PartitionMode::Partitioned,
            expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded),
            expected_join_type: JoinType::Full,
            expected_mode: PartitionMode::Partitioned,
            expecting_swap: false,
        },
        TestCase {
            case: "Unbounded - Bounded 2".to_string(),
            initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded),
            initial_join_type: JoinType::Full,
            initial_mode: PartitionMode::Partitioned,
            expected_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded),
            expected_join_type: JoinType::Full,
            expected_mode: PartitionMode::Partitioned,
            expecting_swap: false,
        },
        TestCase {
            case: "Bounded - Bounded 3".to_string(),
            initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded),
            initial_join_type: JoinType::Full,
            initial_mode: PartitionMode::Partitioned,
            expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded),
            expected_join_type: JoinType::Full,
            expected_mode: PartitionMode::Partitioned,
            expecting_swap: false,
        },
        TestCase {
            case: "Unbounded - Unbounded 4".to_string(),
            initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded),
            initial_join_type: JoinType::Full,
            initial_mode: PartitionMode::Partitioned,
            expected_sources_unbounded: (
                SourceType::Unbounded,
                SourceType::Unbounded,
            ),
            expected_join_type: JoinType::Full,
            expected_mode: PartitionMode::Partitioned,
            expecting_swap: false,
        },
    ];
    for case in cases.into_iter() {
        test_join_with_maybe_swap_unbounded_case(case).await?
    }
    Ok(())
}
/// LeftSemi and Inner joins, in both CollectLeft and Partitioned modes.
/// Only the Unbounded-left / Bounded-right combination is expected to be
/// swapped; the mode is kept.
#[tokio::test]
async fn test_cases_without_collect_left_check() -> Result<()> {
    let mut cases = vec![];
    let join_types = vec![JoinType::LeftSemi, JoinType::Inner];
    for join_type in join_types {
        // CollectLeft mode.
        cases.push(TestCase {
            case: "Unbounded - Bounded / CollectLeft".to_string(),
            initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded),
            initial_join_type: join_type,
            initial_mode: PartitionMode::CollectLeft,
            expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded),
            expected_join_type: swap_join_type(join_type),
            expected_mode: PartitionMode::CollectLeft,
            expecting_swap: true,
        });
        cases.push(TestCase {
            case: "Bounded - Unbounded / CollectLeft".to_string(),
            initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded),
            initial_join_type: join_type,
            initial_mode: PartitionMode::CollectLeft,
            expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded),
            expected_join_type: join_type,
            expected_mode: PartitionMode::CollectLeft,
            expecting_swap: false,
        });
        cases.push(TestCase {
            case: "Unbounded - Unbounded / CollectLeft".to_string(),
            initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded),
            initial_join_type: join_type,
            initial_mode: PartitionMode::CollectLeft,
            expected_sources_unbounded: (
                SourceType::Unbounded,
                SourceType::Unbounded,
            ),
            expected_join_type: join_type,
            expected_mode: PartitionMode::CollectLeft,
            expecting_swap: false,
        });
        cases.push(TestCase {
            case: "Bounded - Bounded / CollectLeft".to_string(),
            initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded),
            initial_join_type: join_type,
            initial_mode: PartitionMode::CollectLeft,
            expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded),
            expected_join_type: join_type,
            expected_mode: PartitionMode::CollectLeft,
            expecting_swap: false,
        });
        // Partitioned mode.
        cases.push(TestCase {
            case: "Unbounded - Bounded / Partitioned".to_string(),
            initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded),
            initial_join_type: join_type,
            initial_mode: PartitionMode::Partitioned,
            expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded),
            expected_join_type: swap_join_type(join_type),
            expected_mode: PartitionMode::Partitioned,
            expecting_swap: true,
        });
        cases.push(TestCase {
            case: "Bounded - Unbounded / Partitioned".to_string(),
            initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded),
            initial_join_type: join_type,
            initial_mode: PartitionMode::Partitioned,
            expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded),
            expected_join_type: join_type,
            expected_mode: PartitionMode::Partitioned,
            expecting_swap: false,
        });
        cases.push(TestCase {
            case: "Bounded - Bounded / Partitioned".to_string(),
            initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded),
            initial_join_type: join_type,
            initial_mode: PartitionMode::Partitioned,
            expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded),
            expected_join_type: join_type,
            expected_mode: PartitionMode::Partitioned,
            expecting_swap: false,
        });
        cases.push(TestCase {
            case: "Unbounded - Unbounded / Partitioned".to_string(),
            initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded),
            initial_join_type: join_type,
            initial_mode: PartitionMode::Partitioned,
            expected_sources_unbounded: (
                SourceType::Unbounded,
                SourceType::Unbounded,
            ),
            expected_join_type: join_type,
            expected_mode: PartitionMode::Partitioned,
            expecting_swap: false,
        });
    }
    for case in cases.into_iter() {
        test_join_with_maybe_swap_unbounded_case(case).await?
    }
    Ok(())
}
/// Left and LeftAnti joins, which `try_collect_left` never collects on the
/// left side. Tested in Partitioned mode only; only the Unbounded-left /
/// Bounded-right combination is expected to be swapped.
#[tokio::test]
async fn test_not_support_collect_left() -> Result<()> {
    let mut cases = vec![];
    let the_ones_not_support_collect_left = vec![JoinType::Left, JoinType::LeftAnti];
    for join_type in the_ones_not_support_collect_left {
        cases.push(TestCase {
            case: "Unbounded - Bounded".to_string(),
            initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded),
            initial_join_type: join_type,
            initial_mode: PartitionMode::Partitioned,
            expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded),
            expected_join_type: swap_join_type(join_type),
            expected_mode: PartitionMode::Partitioned,
            expecting_swap: true,
        });
        cases.push(TestCase {
            case: "Bounded - Unbounded".to_string(),
            initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded),
            initial_join_type: join_type,
            initial_mode: PartitionMode::Partitioned,
            expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded),
            expected_join_type: join_type,
            expected_mode: PartitionMode::Partitioned,
            expecting_swap: false,
        });
        cases.push(TestCase {
            case: "Bounded - Bounded".to_string(),
            initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded),
            initial_join_type: join_type,
            initial_mode: PartitionMode::Partitioned,
            expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded),
            expected_join_type: join_type,
            expected_mode: PartitionMode::Partitioned,
            expecting_swap: false,
        });
        cases.push(TestCase {
            case: "Unbounded - Unbounded".to_string(),
            initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded),
            initial_join_type: join_type,
            initial_mode: PartitionMode::Partitioned,
            expected_sources_unbounded: (
                SourceType::Unbounded,
                SourceType::Unbounded,
            ),
            expected_join_type: join_type,
            expected_mode: PartitionMode::Partitioned,
            expecting_swap: false,
        });
    }
    for case in cases.into_iter() {
        test_join_with_maybe_swap_unbounded_case(case).await?
    }
    Ok(())
}
#[tokio::test]
async fn test_not_supporting_swaps_possible_collect_left() -> Result<()> {
let mut cases = vec![];
let the_ones_not_support_collect_left =
vec![JoinType::Right, JoinType::RightAnti, JoinType::RightSemi];
for join_type in the_ones_not_support_collect_left {
cases.push(TestCase {
case: "Unbounded - Bounded / CollectLeft".to_string(),
initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded),
initial_join_type: join_type,
initial_mode: PartitionMode::CollectLeft,
expected_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded),
expected_join_type: join_type,
expected_mode: PartitionMode::CollectLeft,
expecting_swap: false,
});
cases.push(TestCase {
case: "Bounded - Unbounded / CollectLeft".to_string(),
initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded),
initial_join_type: join_type,
initial_mode: PartitionMode::CollectLeft,
expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded),
expected_join_type: join_type,
expected_mode: PartitionMode::CollectLeft,
expecting_swap: false,
});
cases.push(TestCase {
case: "Unbounded - Unbounded / CollectLeft".to_string(),
initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded),
initial_join_type: join_type,
initial_mode: PartitionMode::CollectLeft,
expected_sources_unbounded: (
SourceType::Unbounded,
SourceType::Unbounded,
),
expected_join_type: join_type,
expected_mode: PartitionMode::CollectLeft,
expecting_swap: false,
});
cases.push(TestCase {
case: "Bounded - Bounded / CollectLeft".to_string(),
initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded),
initial_join_type: join_type,
initial_mode: PartitionMode::CollectLeft,
expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded),
expected_join_type: join_type,
expected_mode: PartitionMode::CollectLeft,
expecting_swap: false,
});
cases.push(TestCase {
case: "Unbounded - Bounded / Partitioned".to_string(),
initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded),
initial_join_type: join_type,
initial_mode: PartitionMode::Partitioned,
expected_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded),
expected_join_type: join_type,
expected_mode: PartitionMode::Partitioned,
expecting_swap: false,
});
cases.push(TestCase {
case: "Bounded - Unbounded / Partitioned".to_string(),
initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded),
initial_join_type: join_type,
initial_mode: PartitionMode::Partitioned,
expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded),
expected_join_type: join_type,
expected_mode: PartitionMode::Partitioned,
expecting_swap: false,
});
cases.push(TestCase {
case: "Bounded - Bounded / Partitioned".to_string(),
initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded),
initial_join_type: join_type,
initial_mode: PartitionMode::Partitioned,
expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded),
expected_join_type: join_type,
expected_mode: PartitionMode::Partitioned,
expecting_swap: false,
});
cases.push(TestCase {
case: "Unbounded - Unbounded / Partitioned".to_string(),
initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded),
initial_join_type: join_type,
initial_mode: PartitionMode::Partitioned,
expected_sources_unbounded: (
SourceType::Unbounded,
SourceType::Unbounded,
),
expected_join_type: join_type,
expected_mode: PartitionMode::Partitioned,
expecting_swap: false,
});
}
for case in cases.into_iter() {
test_join_with_maybe_swap_unbounded_case(case).await?
}
Ok(())
}
/// Runs `hash_join_swap_subrule` on a single `HashJoinExec` described by `t`
/// and checks the outcome against the expectations stored in `t`.
///
/// Both inputs are single-column `UnboundedExec`s (column `a` on the left,
/// `b` on the right) joined on `a = b`. If the optimized plan is a
/// `ProjectionExec` (used to swap columns back to their original order), the
/// check looks through it to the join underneath. The join's input
/// boundedness, join type, partition mode, and whether its inputs were swapped
/// are then compared with `t`.
///
/// # Errors
///
/// Propagates errors from building the join, from the rule itself, and from
/// querying the inputs' boundedness.
///
/// # Panics
///
/// Panics if the rule returns `None`, if the optimized plan is not a
/// `HashJoinExec` (optionally under a `ProjectionExec`), if only one of the two
/// inputs moved, or if any expectation in `t` does not hold.
async fn test_join_with_maybe_swap_unbounded_case(t: TestCase) -> Result<()> {
    let left_unbounded = t.initial_sources_unbounded.0 == SourceType::Unbounded;
    let right_unbounded = t.initial_sources_unbounded.1 == SourceType::Unbounded;
    // NOTE(review): a `None` first argument presumably makes `UnboundedExec`
    // produce an unbounded stream — confirm against `UnboundedExec::new`.
    let left_exec = Arc::new(UnboundedExec::new(
        (!left_unbounded).then_some(1),
        RecordBatch::new_empty(Arc::new(Schema::new(vec![Field::new(
            "a",
            DataType::Int32,
            false,
        )]))),
        2,
    )) as Arc<dyn ExecutionPlan>;
    let right_exec = Arc::new(UnboundedExec::new(
        (!right_unbounded).then_some(1),
        RecordBatch::new_empty(Arc::new(Schema::new(vec![Field::new(
            "b",
            DataType::Int32,
            false,
        )]))),
        2,
    )) as Arc<dyn ExecutionPlan>;
    let join = HashJoinExec::try_new(
        Arc::clone(&left_exec),
        Arc::clone(&right_exec),
        vec![(
            Column::new_with_schema("a", &left_exec.schema())?,
            Column::new_with_schema("b", &right_exec.schema())?,
        )],
        None,
        &t.initial_join_type,
        t.initial_mode,
        false,
    )?;
    let initial_hash_join_state = PipelineStatePropagator {
        plan: Arc::new(join),
        unbounded: false,
        children_unbounded: vec![left_unbounded, right_unbounded],
    };
    let optimized_join_plan =
        hash_join_swap_subrule(initial_hash_join_state, &ConfigOptions::new())
            .expect("hash_join_swap_subrule must handle a HashJoinExec")?
            .plan;

    // A swap may wrap the join in a projection that swaps the columns back to
    // their original order; look through it to reach the join itself.
    let plan = if let Some(proj) = optimized_join_plan
        .as_any()
        .downcast_ref::<ProjectionExec>()
    {
        Arc::clone(proj.input())
    } else {
        optimized_join_plan
    };

    // Fail loudly instead of silently skipping every assertion when the rule
    // produced something other than a hash join.
    let Some(HashJoinExec {
        left,
        right,
        join_type,
        mode,
        ..
    }) = plan.as_any().downcast_ref::<HashJoinExec>()
    else {
        panic!("{}: optimized plan is not a HashJoinExec", t.case);
    };

    let left_changed = Arc::ptr_eq(left, &right_exec);
    let right_changed = Arc::ptr_eq(right, &left_exec);
    // The inputs either swapped places together or both stayed put.
    assert_eq!(left_changed, right_changed);
    assert_eq!(
        (
            t.case.as_str(),
            if left.unbounded_output(&[])? {
                SourceType::Unbounded
            } else {
                SourceType::Bounded
            },
            if right.unbounded_output(&[])? {
                SourceType::Unbounded
            } else {
                SourceType::Bounded
            },
            join_type,
            mode,
            left_changed && right_changed
        ),
        (
            t.case.as_str(),
            t.expected_sources_unbounded.0,
            t.expected_sources_unbounded.1,
            &t.expected_join_type,
            &t.expected_mode,
            t.expecting_swap
        )
    );
    Ok(())
}
}