use std::sync::Arc;
use crate::config::ConfigOptions;
use crate::error::Result;
use crate::physical_optimizer::PhysicalOptimizerRule;
use crate::physical_plan::joins::utils::{ColumnIndex, JoinFilter};
use crate::physical_plan::joins::{
CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode,
StreamJoinPartitionMode, SymmetricHashJoinExec,
};
use crate::physical_plan::projection::ProjectionExec;
use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties};
use arrow_schema::Schema;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::{internal_err, JoinSide, JoinType};
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::sort_properties::SortProperties;
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
/// Physical optimizer rule that selects a join implementation and build-side
/// order: it converts fully-unbounded hash joins to symmetric hash joins,
/// swaps build/probe sides for unbounded inputs, and uses statistics to pick
/// `CollectLeft` vs. `Partitioned` mode and the smaller build side.
#[derive(Default)]
pub struct JoinSelection {}
impl JoinSelection {
    /// Creates a new [`JoinSelection`] optimizer rule.
    // Replaces the former `#[allow(missing_docs)]` suppression with real docs.
    pub fn new() -> Self {
        Self {}
    }
}
/// Decides, from statistics, whether the join inputs should be swapped so
/// that the smaller input becomes the build (left) side. Prefers byte-size
/// statistics; falls back to row counts; keeps the original order when
/// neither is known on both sides.
fn should_swap_join_order(
    left: &dyn ExecutionPlan,
    right: &dyn ExecutionPlan,
) -> Result<bool> {
    let left_stats = left.statistics()?;
    let right_stats = right.statistics()?;

    // Byte size is the most direct measure of build-side cost; use it
    // whenever both sides report it.
    if let (Some(l), Some(r)) = (
        left_stats.total_byte_size.get_value(),
        right_stats.total_byte_size.get_value(),
    ) {
        return Ok(l > r);
    }
    // Otherwise compare row counts when both are available.
    if let (Some(l), Some(r)) = (
        left_stats.num_rows.get_value(),
        right_stats.num_rows.get_value(),
    ) {
        return Ok(l > r);
    }
    // No comparable statistics: do not swap.
    Ok(false)
}
/// Returns whether `plan` is small enough — per its statistics — to be
/// collected into memory as a single build side, given the configured
/// byte-size and row-count thresholds. Unavailable statistics (or a reported
/// size of zero) are treated as "too large".
fn supports_collect_by_thresholds(
    plan: &dyn ExecutionPlan,
    threshold_byte_size: usize,
    threshold_num_rows: usize,
) -> bool {
    // No statistics at all: assume collecting is unsafe.
    let Ok(stats) = plan.statistics() else {
        return false;
    };
    match (stats.total_byte_size.get_value(), stats.num_rows.get_value()) {
        // Byte size takes precedence whenever it is known.
        (Some(byte_size), _) => *byte_size != 0 && *byte_size < threshold_byte_size,
        // Fall back to the row count otherwise.
        (None, Some(num_rows)) => *num_rows != 0 && *num_rows < threshold_num_rows,
        (None, None) => false,
    }
}
/// Returns whether the given join type may have its inputs swapped
/// (every type listed here has a well-defined swapped counterpart in
/// [`swap_join_type`]).
fn supports_swap(join_type: JoinType) -> bool {
    match join_type {
        JoinType::Inner
        | JoinType::Left
        | JoinType::Right
        | JoinType::Full
        | JoinType::LeftSemi
        | JoinType::RightSemi
        | JoinType::LeftAnti
        | JoinType::RightAnti => true,
    }
}
/// Maps a join type to the type that produces equivalent results once the
/// inputs are swapped: symmetric types map to themselves, sided types flip.
fn swap_join_type(join_type: JoinType) -> JoinType {
    match join_type {
        // Symmetric join types are unaffected by a swap.
        JoinType::Inner => JoinType::Inner,
        JoinType::Full => JoinType::Full,
        // Sided join types exchange their side.
        JoinType::Right => JoinType::Left,
        JoinType::Left => JoinType::Right,
        JoinType::RightSemi => JoinType::LeftSemi,
        JoinType::LeftSemi => JoinType::RightSemi,
        JoinType::RightAnti => JoinType::LeftAnti,
        JoinType::LeftAnti => JoinType::RightAnti,
    }
}
/// Remaps projection indices built over `[left ++ right]` column order so
/// they reference the same columns in `[right ++ left]` order after a swap.
fn swap_join_projection(
    left_schema_len: usize,
    right_schema_len: usize,
    projection: Option<&Vec<usize>>,
) -> Option<Vec<usize>> {
    let remap = |&idx: &usize| {
        if idx < left_schema_len {
            // Originally a left-side column: it now sits after the new
            // (former right) input.
            idx + right_schema_len
        } else {
            // Originally a right-side column: it now comes first.
            idx - left_schema_len
        }
    };
    projection.map(|indices| indices.iter().map(remap).collect())
}
/// Swaps the build (left) and probe (right) inputs of `hash_join`, adjusting
/// the join type, equi-join keys, filter and projection accordingly. For
/// semi/anti joins — which only output one side's columns — the swapped join
/// is returned directly; otherwise a `ProjectionExec` is layered on top to
/// restore the original output column order.
fn swap_hash_join(
    hash_join: &HashJoinExec,
    partition_mode: PartitionMode,
) -> Result<Arc<dyn ExecutionPlan>> {
    let left = hash_join.left();
    let right = hash_join.right();
    let new_join = HashJoinExec::try_new(
        Arc::clone(right),
        Arc::clone(left),
        // Flip each (left, right) equi-join key pair.
        hash_join
            .on()
            .iter()
            .map(|(l, r)| (r.clone(), l.clone()))
            .collect(),
        swap_join_filter(hash_join.filter()),
        &swap_join_type(*hash_join.join_type()),
        // Re-point any projection indices at the swapped child order.
        swap_join_projection(
            left.schema().fields().len(),
            right.schema().fields().len(),
            hash_join.projection.as_ref(),
        ),
        partition_mode,
        hash_join.null_equals_null(),
    )?;
    if matches!(
        hash_join.join_type(),
        JoinType::LeftSemi
            | JoinType::RightSemi
            | JoinType::LeftAnti
            | JoinType::RightAnti
    ) {
        // Semi/anti joins emit only a single input's columns, so no column
        // reordering is required after the swap.
        Ok(Arc::new(new_join))
    } else {
        // Restore the original (left ++ right) column order for consumers.
        let proj = ProjectionExec::try_new(
            swap_reverting_projection(&left.schema(), &right.schema()),
            Arc::new(new_join),
        )?;
        Ok(Arc::new(proj))
    }
}
/// Swaps the inputs of a nested-loop join, flipping the join type and the
/// filter's column sides. A reverting projection is added on top unless the
/// join type (semi/anti) only produces columns from a single input.
fn swap_nl_join(join: &NestedLoopJoinExec) -> Result<Arc<dyn ExecutionPlan>> {
    let swapped = NestedLoopJoinExec::try_new(
        Arc::clone(join.right()),
        Arc::clone(join.left()),
        swap_join_filter(join.filter()),
        &swap_join_type(*join.join_type()),
    )?;
    match join.join_type() {
        // Semi/anti joins emit only one side's columns: no reordering needed.
        JoinType::LeftSemi
        | JoinType::RightSemi
        | JoinType::LeftAnti
        | JoinType::RightAnti => Ok(Arc::new(swapped)),
        // Everything else needs a projection restoring the original order.
        _ => {
            let revert =
                swap_reverting_projection(&join.left().schema(), &join.right().schema());
            Ok(Arc::new(ProjectionExec::try_new(revert, Arc::new(swapped))?))
        }
    }
}
/// Builds projection expressions that restore the original `[left ++ right]`
/// column order on top of a swapped join whose output is `[right ++ left]`.
fn swap_reverting_projection(
    left_schema: &Schema,
    right_schema: &Schema,
) -> Vec<(Arc<dyn PhysicalExpr>, String)> {
    let right_len = right_schema.fields().len();
    let mut exprs: Vec<(Arc<dyn PhysicalExpr>, String)> =
        Vec::with_capacity(left_schema.fields().len() + right_len);
    // Left columns now live *after* the (former right) input, so their input
    // index is offset by `right_len`.
    for (i, field) in left_schema.fields().iter().enumerate() {
        exprs.push((
            Arc::new(Column::new(field.name(), right_len + i)) as _,
            field.name().to_owned(),
        ));
    }
    // Right columns moved to the front of the swapped output.
    for (i, field) in right_schema.fields().iter().enumerate() {
        exprs.push((
            Arc::new(Column::new(field.name(), i)) as _,
            field.name().to_owned(),
        ));
    }
    exprs
}
fn swap_filter(filter: &JoinFilter) -> JoinFilter {
let column_indices = filter
.column_indices()
.iter()
.map(|idx| ColumnIndex {
index: idx.index,
side: idx.side.negate(),
})
.collect();
JoinFilter::new(
filter.expression().clone(),
column_indices,
filter.schema().clone(),
)
}
fn swap_join_filter(filter: Option<&JoinFilter>) -> Option<JoinFilter> {
filter.map(swap_filter)
}
impl PhysicalOptimizerRule for JoinSelection {
    /// Runs the pipeline-fixing subrules (symmetric-hash-join conversion and
    /// unboundedness-driven swaps) bottom-up, then applies statistics-based
    /// join selection bottom-up with the configured collect thresholds.
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        config: &ConfigOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // Subrules are applied in order to every plan node.
        let subrules: Vec<Box<PipelineFixerSubrule>> = vec![
            Box::new(hash_join_convert_symmetric_subrule),
            Box::new(hash_join_swap_subrule),
        ];
        let fixed_plan = plan
            .transform_up(|p| apply_subrules(p, &subrules, config))
            .data()?;

        // Thresholds controlling when a build side may be collected.
        let optimizer_config = &config.optimizer;
        let byte_threshold = optimizer_config.hash_join_single_partition_threshold;
        let row_threshold = optimizer_config.hash_join_single_partition_threshold_rows;
        fixed_plan
            .transform_up(|p| {
                statistical_join_selection_subrule(p, byte_threshold, row_threshold)
            })
            .data()
    }

    fn name(&self) -> &str {
        "join_selection"
    }

    fn schema_check(&self) -> bool {
        true
    }
}
/// Attempts to plan `hash_join` in `CollectLeft` mode, preferring the smaller
/// side as the build side. Returns `Ok(None)` when neither side is small
/// enough to collect (and a swap cannot help), so the caller can fall back to
/// a partitioned join.
fn try_collect_left(
    hash_join: &HashJoinExec,
    ignore_threshold: bool,
    threshold_byte_size: usize,
    threshold_num_rows: usize,
) -> Result<Option<Arc<dyn ExecutionPlan>>> {
    let left = hash_join.left();
    let right = hash_join.right();

    let left_can_collect = ignore_threshold
        || supports_collect_by_thresholds(
            &**left,
            threshold_byte_size,
            threshold_num_rows,
        );
    let right_can_collect = ignore_threshold
        || supports_collect_by_thresholds(
            &**right,
            threshold_byte_size,
            threshold_num_rows,
        );

    // Rebuilds the join unchanged except for forcing `CollectLeft` mode.
    let collect_left_unswapped = || -> Result<Arc<dyn ExecutionPlan>> {
        Ok(Arc::new(HashJoinExec::try_new(
            Arc::clone(left),
            Arc::clone(right),
            hash_join.on().to_vec(),
            hash_join.filter().cloned(),
            hash_join.join_type(),
            hash_join.projection.clone(),
            PartitionMode::CollectLeft,
            hash_join.null_equals_null(),
        )?))
    };

    match (left_can_collect, right_can_collect) {
        (true, true) => {
            // Both sides collectible: build on whichever is smaller.
            if should_swap_join_order(&**left, &**right)?
                && supports_swap(*hash_join.join_type())
            {
                Ok(Some(swap_hash_join(hash_join, PartitionMode::CollectLeft)?))
            } else {
                collect_left_unswapped().map(Some)
            }
        }
        // Only the left side is collectible: keep it as the build side.
        (true, false) => collect_left_unswapped().map(Some),
        // Only the right side is collectible: swap if the join type allows.
        (false, true) => {
            if supports_swap(*hash_join.join_type()) {
                swap_hash_join(hash_join, PartitionMode::CollectLeft).map(Some)
            } else {
                Ok(None)
            }
        }
        (false, false) => Ok(None),
    }
}
/// Plans `hash_join` in `Partitioned` mode, swapping its inputs first when
/// statistics say the right side is smaller and the join type permits.
fn partitioned_hash_join(hash_join: &HashJoinExec) -> Result<Arc<dyn ExecutionPlan>> {
    let left = hash_join.left();
    let right = hash_join.right();
    if should_swap_join_order(&**left, &**right)? && supports_swap(*hash_join.join_type())
    {
        return swap_hash_join(hash_join, PartitionMode::Partitioned);
    }
    // No beneficial swap: rebuild the same join in partitioned mode.
    Ok(Arc::new(HashJoinExec::try_new(
        Arc::clone(left),
        Arc::clone(right),
        hash_join.on().to_vec(),
        hash_join.filter().cloned(),
        hash_join.join_type(),
        hash_join.projection.clone(),
        PartitionMode::Partitioned,
        hash_join.null_equals_null(),
    )?))
}
/// Statistics-driven join selection, applied to a single plan node:
/// - `HashJoinExec` in `Auto` mode: try `CollectLeft` (threshold-checked),
///   falling back to `Partitioned` when neither side is collectible.
/// - `HashJoinExec` already in `CollectLeft`: re-plan ignoring thresholds
///   (the mode was requested explicitly), still preferring the smaller side.
/// - `HashJoinExec` in `Partitioned` mode: swap inputs when statistics say
///   the right side is smaller and the join type permits.
/// - `CrossJoinExec` / `NestedLoopJoinExec`: swap inputs when beneficial,
///   with a projection restoring the original column order.
fn statistical_join_selection_subrule(
    plan: Arc<dyn ExecutionPlan>,
    collect_threshold_byte_size: usize,
    collect_threshold_num_rows: usize,
) -> Result<Transformed<Arc<dyn ExecutionPlan>>> {
    let transformed =
        if let Some(hash_join) = plan.as_any().downcast_ref::<HashJoinExec>() {
            match hash_join.partition_mode() {
                PartitionMode::Auto => try_collect_left(
                    hash_join,
                    false,
                    collect_threshold_byte_size,
                    collect_threshold_num_rows,
                )?
                .map_or_else(
                    // Not collectible: fall back to a partitioned hash join.
                    || partitioned_hash_join(hash_join).map(Some),
                    |v| Ok(Some(v)),
                )?,
                // `ignore_threshold = true`: the mode was explicitly chosen.
                PartitionMode::CollectLeft => try_collect_left(hash_join, true, 0, 0)?
                    .map_or_else(
                        || partitioned_hash_join(hash_join).map(Some),
                        |v| Ok(Some(v)),
                    )?,
                PartitionMode::Partitioned => {
                    let left = hash_join.left();
                    let right = hash_join.right();
                    if should_swap_join_order(&**left, &**right)?
                        && supports_swap(*hash_join.join_type())
                    {
                        swap_hash_join(hash_join, PartitionMode::Partitioned).map(Some)?
                    } else {
                        None
                    }
                }
            }
        } else if let Some(cross_join) = plan.as_any().downcast_ref::<CrossJoinExec>() {
            let left = cross_join.left();
            let right = cross_join.right();
            if should_swap_join_order(&**left, &**right)? {
                let new_join = CrossJoinExec::new(Arc::clone(right), Arc::clone(left));
                // Cross joins output all columns of both inputs; add a
                // projection restoring the original order.
                let proj: Arc<dyn ExecutionPlan> = Arc::new(ProjectionExec::try_new(
                    swap_reverting_projection(&left.schema(), &right.schema()),
                    Arc::new(new_join),
                )?);
                Some(proj)
            } else {
                None
            }
        } else if let Some(nl_join) = plan.as_any().downcast_ref::<NestedLoopJoinExec>() {
            let left = nl_join.left();
            let right = nl_join.right();
            if should_swap_join_order(&**left, &**right)? {
                swap_nl_join(nl_join).map(Some)?
            } else {
                None
            }
        } else {
            // Not a join node: leave it untouched.
            None
        };
    Ok(if let Some(transformed) = transformed {
        Transformed::yes(transformed)
    } else {
        Transformed::no(plan)
    })
}
/// A pipeline-fixing subrule: takes a plan node and the session
/// configuration, and returns a (possibly rewritten) plan node.
pub type PipelineFixerSubrule =
    dyn Fn(Arc<dyn ExecutionPlan>, &ConfigOptions) -> Result<Arc<dyn ExecutionPlan>>;
/// Converts a hash join whose inputs are *both* unbounded into a
/// `SymmetricHashJoinExec`, which can process two unbounded streams
/// incrementally. Each child's output ordering is forwarded to the symmetric
/// join only when the join filter references at least one ordered column of
/// that side (per the child's equivalence properties).
fn hash_join_convert_symmetric_subrule(
    input: Arc<dyn ExecutionPlan>,
    config_options: &ConfigOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
    if let Some(hash_join) = input.as_any().downcast_ref::<HashJoinExec>() {
        let left_unbounded = hash_join.left.execution_mode().is_unbounded();
        let right_unbounded = hash_join.right.execution_mode().is_unbounded();
        // Only a fully unbounded join needs (and benefits from) conversion.
        if left_unbounded && right_unbounded {
            let mode = if config_options.optimizer.repartition_joins {
                StreamJoinPartitionMode::Partitioned
            } else {
                StreamJoinPartitionMode::SinglePartition
            };
            // Returns `side`'s output ordering iff the join filter uses a
            // column of that side that has a known (non-Unordered) sort
            // property; otherwise `None`.
            let determine_order = |side: JoinSide| -> Option<Vec<PhysicalSortExpr>> {
                hash_join
                    .filter()
                    .map(|filter| {
                        filter.column_indices().iter().any(
                            |ColumnIndex {
                                 index,
                                 side: column_side,
                             }| {
                                // Skip filter columns from the other side.
                                if *column_side != side {
                                    return false;
                                }
                                let (equivalence, schema) = match side {
                                    JoinSide::Left => (
                                        hash_join.left().equivalence_properties(),
                                        hash_join.left().schema(),
                                    ),
                                    JoinSide::Right => (
                                        hash_join.right().equivalence_properties(),
                                        hash_join.right().schema(),
                                    ),
                                };
                                let name = schema.field(*index).name();
                                let col = Arc::new(Column::new(name, *index)) as _;
                                equivalence.get_expr_ordering(col).data
                                    != SortProperties::Unordered
                            },
                        )
                    })
                    .unwrap_or(false)
                    .then(|| {
                        match side {
                            JoinSide::Left => hash_join.left().output_ordering(),
                            JoinSide::Right => hash_join.right().output_ordering(),
                        }
                        .map(|p| p.to_vec())
                    })
                    .flatten()
            };
            let left_order = determine_order(JoinSide::Left);
            let right_order = determine_order(JoinSide::Right);
            return SymmetricHashJoinExec::try_new(
                hash_join.left().clone(),
                hash_join.right().clone(),
                hash_join.on().to_vec(),
                hash_join.filter().cloned(),
                hash_join.join_type(),
                hash_join.null_equals_null(),
                left_order,
                right_order,
                mode,
            )
            .map(|exec| Arc::new(exec) as _);
        }
    }
    Ok(input)
}
/// When the build (left) side of a hash join is unbounded but the probe
/// (right) side is bounded, swaps the inputs so the bounded side becomes the
/// build side. Only join types whose swapped form remains valid here
/// (Inner/Left/LeftSemi/LeftAnti) are eligible.
fn hash_join_swap_subrule(
    mut input: Arc<dyn ExecutionPlan>,
    _config_options: &ConfigOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
    if let Some(hash_join) = input.as_any().downcast_ref::<HashJoinExec>() {
        let build_unbounded = hash_join.left.execution_mode().is_unbounded();
        let probe_bounded = !hash_join.right.execution_mode().is_unbounded();
        let swappable_type = matches!(
            *hash_join.join_type(),
            JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti
        );
        if build_unbounded && probe_bounded && swappable_type {
            input = swap_join_according_to_unboundedness(hash_join)?;
        }
    }
    Ok(input)
}
/// Swaps a hash join for unboundedness reasons, preserving its partition
/// mode. Errors for join types that cannot be swapped in this situation and
/// for `Auto` mode, which should have been resolved before this point.
fn swap_join_according_to_unboundedness(
    hash_join: &HashJoinExec,
) -> Result<Arc<dyn ExecutionPlan>> {
    let partition_mode = hash_join.partition_mode();
    let join_type = hash_join.join_type();
    // Right-sided and full joins cannot stream the unbounded side once
    // swapped, regardless of partition mode.
    if matches!(
        *join_type,
        JoinType::Right | JoinType::RightSemi | JoinType::RightAnti | JoinType::Full
    ) {
        return internal_err!("{join_type} join cannot be swapped for unbounded input.");
    }
    match *partition_mode {
        PartitionMode::Partitioned => {
            swap_hash_join(hash_join, PartitionMode::Partitioned)
        }
        PartitionMode::CollectLeft => {
            swap_hash_join(hash_join, PartitionMode::CollectLeft)
        }
        PartitionMode::Auto => {
            internal_err!("Auto is not acceptable for unbounded input here.")
        }
    }
}
/// Applies each subrule to `input` in order, threading the (possibly
/// rewritten) plan through. The result is always reported as transformed
/// because the subrules do not track whether they changed the plan.
///
/// Takes `&[Box<PipelineFixerSubrule>]` rather than `&Vec<...>` (clippy
/// `ptr_arg`): callers passing `&Vec` coerce transparently.
fn apply_subrules(
    mut input: Arc<dyn ExecutionPlan>,
    subrules: &[Box<PipelineFixerSubrule>],
    config_options: &ConfigOptions,
) -> Result<Transformed<Arc<dyn ExecutionPlan>>> {
    for subrule in subrules {
        input = subrule(input, config_options)?;
    }
    Ok(Transformed::yes(input))
}
#[cfg(test)]
mod tests_statistical {
    use super::*;
    use crate::{
        physical_plan::{displayable, ColumnStatistics, Statistics},
        test::StatisticsExec,
    };
    use arrow::datatypes::{DataType, Field};
    use datafusion_common::{stats::Precision, JoinType, ScalarValue};
    use datafusion_expr::Operator;
    use datafusion_physical_expr::expressions::BinaryExpr;
    use datafusion_physical_expr::PhysicalExprRef;
    use rstest::rstest;

    /// Statistics with unknown row count and byte size.
    fn empty_statistics() -> Statistics {
        Statistics {
            num_rows: Precision::Absent,
            total_byte_size: Precision::Absent,
            column_statistics: vec![ColumnStatistics::new_unknown()],
        }
    }

    /// Returns the default (rows, bytes) single-partition thresholds.
    fn get_thresholds() -> (usize, usize) {
        let optimizer_options = ConfigOptions::new().optimizer;
        (
            optimizer_options.hash_join_single_partition_threshold_rows,
            optimizer_options.hash_join_single_partition_threshold,
        )
    }

    /// Statistics comfortably below the collect-left thresholds.
    fn small_statistics() -> Statistics {
        let (threshold_num_rows, threshold_byte_size) = get_thresholds();
        Statistics {
            num_rows: Precision::Inexact(threshold_num_rows / 128),
            total_byte_size: Precision::Inexact(threshold_byte_size / 128),
            column_statistics: vec![ColumnStatistics::new_unknown()],
        }
    }

    /// Statistics above the collect-left thresholds.
    fn big_statistics() -> Statistics {
        let (threshold_num_rows, threshold_byte_size) = get_thresholds();
        Statistics {
            num_rows: Precision::Inexact(threshold_num_rows * 2),
            total_byte_size: Precision::Inexact(threshold_byte_size * 2),
            column_statistics: vec![ColumnStatistics::new_unknown()],
        }
    }

    /// Statistics even larger than [`big_statistics`].
    fn bigger_statistics() -> Statistics {
        let (threshold_num_rows, threshold_byte_size) = get_thresholds();
        Statistics {
            num_rows: Precision::Inexact(threshold_num_rows * 4),
            total_byte_size: Precision::Inexact(threshold_byte_size * 4),
            column_statistics: vec![ColumnStatistics::new_unknown()],
        }
    }

    /// Returns a (big, small) pair of single-column statistics sources.
    fn create_big_and_small() -> (Arc<dyn ExecutionPlan>, Arc<dyn ExecutionPlan>) {
        let big = Arc::new(StatisticsExec::new(
            big_statistics(),
            Schema::new(vec![Field::new("big_col", DataType::Int32, false)]),
        ));
        let small = Arc::new(StatisticsExec::new(
            small_statistics(),
            Schema::new(vec![Field::new("small_col", DataType::Int32, false)]),
        ));
        (big, small)
    }

    /// Single-column statistics with optional min/max/distinct values.
    fn create_column_stats(
        min: Option<u64>,
        max: Option<u64>,
        distinct_count: Option<usize>,
    ) -> Vec<ColumnStatistics> {
        vec![ColumnStatistics {
            distinct_count: distinct_count
                .map(Precision::Inexact)
                .unwrap_or(Precision::Absent),
            min_value: min
                .map(|size| Precision::Inexact(ScalarValue::UInt64(Some(size))))
                .unwrap_or(Precision::Absent),
            max_value: max
                .map(|size| Precision::Inexact(ScalarValue::UInt64(Some(size))))
                .unwrap_or(Precision::Absent),
            ..Default::default()
        }]
    }

    /// A `big_col > small_col` filter for the nested-loop join tests.
    fn nl_join_filter() -> Option<JoinFilter> {
        let column_indices = vec![
            ColumnIndex {
                index: 0,
                side: JoinSide::Left,
            },
            ColumnIndex {
                index: 0,
                side: JoinSide::Right,
            },
        ];
        let intermediate_schema = Schema::new(vec![
            Field::new("big_col", DataType::Int32, false),
            Field::new("small_col", DataType::Int32, false),
        ]);
        let expression = Arc::new(BinaryExpr::new(
            Arc::new(Column::new_with_schema("big_col", &intermediate_schema).unwrap()),
            Operator::Gt,
            Arc::new(Column::new_with_schema("small_col", &intermediate_schema).unwrap()),
        )) as _;
        Some(JoinFilter::new(
            expression,
            column_indices,
            intermediate_schema,
        ))
    }

    /// Returns (big, medium, small) sources carrying min/max column stats.
    fn create_nested_with_min_max() -> (
        Arc<dyn ExecutionPlan>,
        Arc<dyn ExecutionPlan>,
        Arc<dyn ExecutionPlan>,
    ) {
        let big = Arc::new(StatisticsExec::new(
            Statistics {
                num_rows: Precision::Inexact(100_000),
                column_statistics: create_column_stats(
                    Some(0),
                    Some(50_000),
                    Some(50_000),
                ),
                total_byte_size: Precision::Absent,
            },
            Schema::new(vec![Field::new("big_col", DataType::Int32, false)]),
        ));
        let medium = Arc::new(StatisticsExec::new(
            Statistics {
                num_rows: Precision::Inexact(10_000),
                column_statistics: create_column_stats(
                    Some(1000),
                    Some(5000),
                    Some(1000),
                ),
                total_byte_size: Precision::Absent,
            },
            Schema::new(vec![Field::new("medium_col", DataType::Int32, false)]),
        ));
        let small = Arc::new(StatisticsExec::new(
            Statistics {
                num_rows: Precision::Inexact(1000),
                column_statistics: create_column_stats(
                    Some(0),
                    Some(100_000),
                    Some(1000),
                ),
                total_byte_size: Precision::Absent,
            },
            Schema::new(vec![Field::new("small_col", DataType::Int32, false)]),
        ));
        (big, medium, small)
    }

    /// Big build side + small probe side => inputs swapped, with a projection
    /// restoring the original column order.
    #[tokio::test]
    async fn test_join_with_swap() {
        let (big, small) = create_big_and_small();
        let join = Arc::new(
            HashJoinExec::try_new(
                Arc::clone(&big),
                Arc::clone(&small),
                vec![(
                    Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()),
                    Arc::new(
                        Column::new_with_schema("small_col", &small.schema()).unwrap(),
                    ),
                )],
                None,
                &JoinType::Left,
                None,
                PartitionMode::CollectLeft,
                false,
            )
            .unwrap(),
        );
        let optimized_join = JoinSelection::new()
            .optimize(join.clone(), &ConfigOptions::new())
            .unwrap();
        let swapping_projection = optimized_join
            .as_any()
            .downcast_ref::<ProjectionExec>()
            .expect("A proj is required to swap columns back to their original order");
        assert_eq!(swapping_projection.expr().len(), 2);
        let (col, name) = &swapping_projection.expr()[0];
        assert_eq!(name, "big_col");
        assert_col_expr(col, "big_col", 1);
        let (col, name) = &swapping_projection.expr()[1];
        assert_eq!(name, "small_col");
        assert_col_expr(col, "small_col", 0);
        let swapped_join = swapping_projection
            .input()
            .as_any()
            .downcast_ref::<HashJoinExec>()
            .expect("The type of the plan should not be changed");
        // After the swap, the small input (8192 bytes) is the build side.
        assert_eq!(
            swapped_join.left().statistics().unwrap().total_byte_size,
            Precision::Inexact(8192)
        );
        assert_eq!(
            swapped_join.right().statistics().unwrap().total_byte_size,
            Precision::Inexact(2097152)
        );
    }

    /// Small build side already on the left => no swap, no projection.
    #[tokio::test]
    async fn test_left_join_no_swap() {
        let (big, small) = create_big_and_small();
        let join = Arc::new(
            HashJoinExec::try_new(
                Arc::clone(&small),
                Arc::clone(&big),
                vec![(
                    Arc::new(
                        Column::new_with_schema("small_col", &small.schema()).unwrap(),
                    ),
                    Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()),
                )],
                None,
                &JoinType::Left,
                None,
                PartitionMode::CollectLeft,
                false,
            )
            .unwrap(),
        );
        let optimized_join = JoinSelection::new()
            .optimize(join.clone(), &ConfigOptions::new())
            .unwrap();
        let swapped_join = optimized_join
            .as_any()
            .downcast_ref::<HashJoinExec>()
            .expect("The type of the plan should not be changed");
        assert_eq!(
            swapped_join.left().statistics().unwrap().total_byte_size,
            Precision::Inexact(8192)
        );
        assert_eq!(
            swapped_join.right().statistics().unwrap().total_byte_size,
            Precision::Inexact(2097152)
        );
    }

    /// Semi/anti joins swap without any reverting projection (their schema
    /// only covers one input), so the schema must stay identical.
    #[tokio::test]
    async fn test_join_with_swap_semi() {
        let join_types = [JoinType::LeftSemi, JoinType::LeftAnti];
        for join_type in join_types {
            let (big, small) = create_big_and_small();
            let join = Arc::new(
                HashJoinExec::try_new(
                    Arc::clone(&big),
                    Arc::clone(&small),
                    vec![(
                        Arc::new(
                            Column::new_with_schema("big_col", &big.schema()).unwrap(),
                        ),
                        Arc::new(
                            Column::new_with_schema("small_col", &small.schema())
                                .unwrap(),
                        ),
                    )],
                    None,
                    &join_type,
                    None,
                    PartitionMode::Partitioned,
                    false,
                )
                .unwrap(),
            );
            let original_schema = join.schema();
            let optimized_join = JoinSelection::new()
                .optimize(join.clone(), &ConfigOptions::new())
                .unwrap();
            let swapped_join = optimized_join
                .as_any()
                .downcast_ref::<HashJoinExec>()
                .expect(
                    "A proj is not required to swap columns back to their original order",
                );
            assert_eq!(swapped_join.schema().fields().len(), 1);
            assert_eq!(
                swapped_join.left().statistics().unwrap().total_byte_size,
                Precision::Inexact(8192)
            );
            assert_eq!(
                swapped_join.right().statistics().unwrap().total_byte_size,
                Precision::Inexact(2097152)
            );
            assert_eq!(original_schema, swapped_join.schema());
        }
    }

    /// Compares the optimized plan's indented display against expected lines.
    macro_rules! assert_optimized {
        ($EXPECTED_LINES: expr, $PLAN: expr) => {
            let expected_lines =
                $EXPECTED_LINES.iter().map(|s| *s).collect::<Vec<&str>>();
            let plan = Arc::new($PLAN);
            let optimized = JoinSelection::new()
                .optimize(plan.clone(), &ConfigOptions::new())
                .unwrap();
            let plan_string = displayable(optimized.as_ref()).indent(true).to_string();
            let actual_lines = plan_string.split("\n").collect::<Vec<&str>>();
            assert_eq!(
                &expected_lines, &actual_lines,
                "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
                expected_lines, actual_lines
            );
        };
    }

    /// Swaps propagate through nested joins; both levels get reordered.
    #[tokio::test]
    async fn test_nested_join_swap() {
        let (big, medium, small) = create_nested_with_min_max();
        let child_join = HashJoinExec::try_new(
            Arc::clone(&big),
            Arc::clone(&small),
            vec![(
                Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()),
                Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()),
            )],
            None,
            &JoinType::Inner,
            None,
            PartitionMode::CollectLeft,
            false,
        )
        .unwrap();
        let child_schema = child_join.schema();
        let join = HashJoinExec::try_new(
            Arc::clone(&medium),
            Arc::new(child_join),
            vec![(
                Arc::new(
                    Column::new_with_schema("medium_col", &medium.schema()).unwrap(),
                ),
                Arc::new(Column::new_with_schema("small_col", &child_schema).unwrap()),
            )],
            None,
            &JoinType::Left,
            None,
            PartitionMode::CollectLeft,
            false,
        )
        .unwrap();
        let expected = [
            "ProjectionExec: expr=[medium_col@2 as medium_col, big_col@0 as big_col, small_col@1 as small_col]",
            "  HashJoinExec: mode=CollectLeft, join_type=Right, on=[(small_col@1, medium_col@0)]",
            "    ProjectionExec: expr=[big_col@1 as big_col, small_col@0 as small_col]",
            "      HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(small_col@0, big_col@0)]",
            "        StatisticsExec: col_count=1, row_count=Inexact(1000)",
            "        StatisticsExec: col_count=1, row_count=Inexact(100000)",
            "    StatisticsExec: col_count=1, row_count=Inexact(10000)",
            "",
        ];
        assert_optimized!(expected, join);
    }

    /// Inner join with the small side already left => no swap.
    #[tokio::test]
    async fn test_join_no_swap() {
        let (big, small) = create_big_and_small();
        let join = Arc::new(
            HashJoinExec::try_new(
                Arc::clone(&small),
                Arc::clone(&big),
                vec![(
                    Arc::new(
                        Column::new_with_schema("small_col", &small.schema()).unwrap(),
                    ),
                    Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()),
                )],
                None,
                &JoinType::Inner,
                None,
                PartitionMode::CollectLeft,
                false,
            )
            .unwrap(),
        );
        let optimized_join = JoinSelection::new()
            .optimize(join.clone(), &ConfigOptions::new())
            .unwrap();
        let swapped_join = optimized_join
            .as_any()
            .downcast_ref::<HashJoinExec>()
            .expect("The type of the plan should not be changed");
        assert_eq!(
            swapped_join.left().statistics().unwrap().total_byte_size,
            Precision::Inexact(8192)
        );
        assert_eq!(
            swapped_join.right().statistics().unwrap().total_byte_size,
            Precision::Inexact(2097152)
        );
    }

    /// Nested-loop joins with full-output join types swap and get a
    /// reverting projection; the filter sides must flip too.
    #[rstest(
        join_type,
        case::inner(JoinType::Inner),
        case::left(JoinType::Left),
        case::right(JoinType::Right),
        case::full(JoinType::Full)
    )]
    #[tokio::test]
    async fn test_nl_join_with_swap(join_type: JoinType) {
        let (big, small) = create_big_and_small();
        let join = Arc::new(
            NestedLoopJoinExec::try_new(
                Arc::clone(&big),
                Arc::clone(&small),
                nl_join_filter(),
                &join_type,
            )
            .unwrap(),
        );
        let optimized_join = JoinSelection::new()
            .optimize(join.clone(), &ConfigOptions::new())
            .unwrap();
        let swapping_projection = optimized_join
            .as_any()
            .downcast_ref::<ProjectionExec>()
            .expect("A proj is required to swap columns back to their original order");
        assert_eq!(swapping_projection.expr().len(), 2);
        let (col, name) = &swapping_projection.expr()[0];
        assert_eq!(name, "big_col");
        assert_col_expr(col, "big_col", 1);
        let (col, name) = &swapping_projection.expr()[1];
        assert_eq!(name, "small_col");
        assert_col_expr(col, "small_col", 0);
        let swapped_join = swapping_projection
            .input()
            .as_any()
            .downcast_ref::<NestedLoopJoinExec>()
            .expect("The type of the plan should not be changed");
        let swapped_filter = swapped_join.filter().unwrap();
        let swapped_big_col_idx = swapped_filter.schema().index_of("big_col").unwrap();
        let swapped_big_col_side = swapped_filter
            .column_indices()
            .get(swapped_big_col_idx)
            .unwrap()
            .side;
        assert_eq!(
            swapped_big_col_side,
            JoinSide::Right,
            "Filter column side should be swapped"
        );
        assert_eq!(
            swapped_join.left().statistics().unwrap().total_byte_size,
            Precision::Inexact(8192)
        );
        assert_eq!(
            swapped_join.right().statistics().unwrap().total_byte_size,
            Precision::Inexact(2097152)
        );
    }

    /// Semi/anti nested-loop joins swap without a projection; schema and
    /// filter sides must still be consistent.
    #[rstest(
        join_type,
        case::left_semi(JoinType::LeftSemi),
        case::left_anti(JoinType::LeftAnti),
        case::right_semi(JoinType::RightSemi),
        case::right_anti(JoinType::RightAnti)
    )]
    #[tokio::test]
    async fn test_nl_join_with_swap_no_proj(join_type: JoinType) {
        let (big, small) = create_big_and_small();
        let join = Arc::new(
            NestedLoopJoinExec::try_new(
                Arc::clone(&big),
                Arc::clone(&small),
                nl_join_filter(),
                &join_type,
            )
            .unwrap(),
        );
        let optimized_join = JoinSelection::new()
            .optimize(join.clone(), &ConfigOptions::new())
            .unwrap();
        let swapped_join = optimized_join
            .as_any()
            .downcast_ref::<NestedLoopJoinExec>()
            .expect("The type of the plan should not be changed");
        assert_eq!(
            join.schema(),
            swapped_join.schema(),
            "Join schema should not be modified while optimization"
        );
        let swapped_filter = swapped_join.filter().unwrap();
        let swapped_big_col_idx = swapped_filter.schema().index_of("big_col").unwrap();
        let swapped_big_col_side = swapped_filter
            .column_indices()
            .get(swapped_big_col_idx)
            .unwrap()
            .side;
        assert_eq!(
            swapped_big_col_side,
            JoinSide::Right,
            "Filter column side should be swapped"
        );
        assert_eq!(
            swapped_join.left().statistics().unwrap().total_byte_size,
            Precision::Inexact(8192)
        );
        assert_eq!(
            swapped_join.right().statistics().unwrap().total_byte_size,
            Precision::Inexact(2097152)
        );
    }

    /// The reverting projection lists left columns first (offset by the
    /// right schema length), then right columns.
    #[tokio::test]
    async fn test_swap_reverting_projection() {
        let left_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]);
        let right_schema = Schema::new(vec![Field::new("c", DataType::Int32, false)]);
        let proj = swap_reverting_projection(&left_schema, &right_schema);
        assert_eq!(proj.len(), 3);
        let (col, name) = &proj[0];
        assert_eq!(name, "a");
        assert_col_expr(col, "a", 1);
        let (col, name) = &proj[1];
        assert_eq!(name, "b");
        assert_col_expr(col, "b", 2);
        let (col, name) = &proj[2];
        assert_eq!(name, "c");
        assert_col_expr(col, "c", 0);
    }

    /// Asserts that `expr` is a `Column` with the given name and index.
    fn assert_col_expr(expr: &Arc<dyn PhysicalExpr>, name: &str, index: usize) {
        let col = expr
            .as_any()
            .downcast_ref::<Column>()
            .expect("Projection items should be Column expression");
        assert_eq!(col.name(), name);
        assert_eq!(col.index(), index);
    }

    /// `Auto` mode picks CollectLeft (swapping when needed) whenever at
    /// least one side is small or of unknown size.
    #[tokio::test]
    async fn test_join_selection_collect_left() {
        let big = Arc::new(StatisticsExec::new(
            big_statistics(),
            Schema::new(vec![Field::new("big_col", DataType::Int32, false)]),
        ));
        let small = Arc::new(StatisticsExec::new(
            small_statistics(),
            Schema::new(vec![Field::new("small_col", DataType::Int32, false)]),
        ));
        let empty = Arc::new(StatisticsExec::new(
            empty_statistics(),
            Schema::new(vec![Field::new("empty_col", DataType::Int32, false)]),
        ));
        let join_on = vec![(
            Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()) as _,
            Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _,
        )];
        check_join_partition_mode(
            small.clone(),
            big.clone(),
            join_on,
            false,
            PartitionMode::CollectLeft,
        );
        let join_on = vec![(
            Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _,
            Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()) as _,
        )];
        check_join_partition_mode(
            big.clone(),
            small.clone(),
            join_on,
            true,
            PartitionMode::CollectLeft,
        );
        let join_on = vec![(
            Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()) as _,
            Arc::new(Column::new_with_schema("empty_col", &empty.schema()).unwrap()) as _,
        )];
        check_join_partition_mode(
            small.clone(),
            empty.clone(),
            join_on,
            false,
            PartitionMode::CollectLeft,
        );
        let join_on = vec![(
            Arc::new(Column::new_with_schema("empty_col", &empty.schema()).unwrap()) as _,
            Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()) as _,
        )];
        check_join_partition_mode(
            empty.clone(),
            small.clone(),
            join_on,
            true,
            PartitionMode::CollectLeft,
        );
    }

    /// `Auto` mode falls back to Partitioned when both sides exceed the
    /// thresholds (or are too big / unknown to collect).
    #[tokio::test]
    async fn test_join_selection_partitioned() {
        let bigger = Arc::new(StatisticsExec::new(
            bigger_statistics(),
            Schema::new(vec![Field::new("bigger_col", DataType::Int32, false)]),
        ));
        let big = Arc::new(StatisticsExec::new(
            big_statistics(),
            Schema::new(vec![Field::new("big_col", DataType::Int32, false)]),
        ));
        let empty = Arc::new(StatisticsExec::new(
            empty_statistics(),
            Schema::new(vec![Field::new("empty_col", DataType::Int32, false)]),
        ));
        let join_on = vec![(
            Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _,
            Arc::new(Column::new_with_schema("bigger_col", &bigger.schema()).unwrap())
                as _,
        )];
        check_join_partition_mode(
            big.clone(),
            bigger.clone(),
            join_on,
            false,
            PartitionMode::Partitioned,
        );
        let join_on = vec![(
            Arc::new(Column::new_with_schema("bigger_col", &bigger.schema()).unwrap())
                as _,
            Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _,
        )];
        check_join_partition_mode(
            bigger.clone(),
            big.clone(),
            join_on,
            true,
            PartitionMode::Partitioned,
        );
        let join_on = vec![(
            Arc::new(Column::new_with_schema("empty_col", &empty.schema()).unwrap()) as _,
            Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _,
        )];
        check_join_partition_mode(
            empty.clone(),
            big.clone(),
            join_on,
            false,
            PartitionMode::Partitioned,
        );
        let join_on = vec![(
            Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _,
            Arc::new(Column::new_with_schema("empty_col", &empty.schema()).unwrap()) as _,
        )];
        check_join_partition_mode(big, empty, join_on, false, PartitionMode::Partitioned);
    }

    /// Optimizes an Auto-mode join and asserts the resulting partition mode
    /// and whether a swap (reverting projection) occurred.
    fn check_join_partition_mode(
        left: Arc<StatisticsExec>,
        right: Arc<StatisticsExec>,
        on: Vec<(PhysicalExprRef, PhysicalExprRef)>,
        is_swapped: bool,
        expected_mode: PartitionMode,
    ) {
        let join = Arc::new(
            HashJoinExec::try_new(
                left,
                right,
                on,
                None,
                &JoinType::Inner,
                None,
                PartitionMode::Auto,
                false,
            )
            .unwrap(),
        );
        let optimized_join = JoinSelection::new()
            .optimize(join.clone(), &ConfigOptions::new())
            .unwrap();
        if !is_swapped {
            let swapped_join = optimized_join
                .as_any()
                .downcast_ref::<HashJoinExec>()
                .expect("The type of the plan should not be changed");
            assert_eq!(*swapped_join.partition_mode(), expected_mode);
        } else {
            let swapping_projection = optimized_join
                .as_any()
                .downcast_ref::<ProjectionExec>()
                .expect(
                    "A proj is required to swap columns back to their original order",
                );
            let swapped_join = swapping_projection
                .input()
                .as_any()
                .downcast_ref::<HashJoinExec>()
                .expect("The type of the plan should not be changed");
            assert_eq!(*swapped_join.partition_mode(), expected_mode);
        }
    }
}
#[cfg(test)]
mod util_tests {
    use std::sync::Arc;

    use arrow_schema::{DataType, Field, Schema};
    use datafusion_expr::Operator;
    use datafusion_physical_expr::expressions::{BinaryExpr, Column, NegativeExpr};
    use datafusion_physical_expr::intervals::utils::check_support;
    use datafusion_physical_expr::PhysicalExpr;

    /// Verifies which expression shapes interval arithmetic supports.
    #[test]
    fn check_expr_supported() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
        ]));
        // `a + a` is supported.
        let supported_expr = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("a", 0)),
            Operator::Plus,
            Arc::new(Column::new("a", 0)),
        )) as Arc<dyn PhysicalExpr>;
        assert!(check_support(&supported_expr, &schema));
        // A bare column reference is supported.
        let supported_expr_2 = Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>;
        assert!(check_support(&supported_expr_2, &schema));
        // Logical `OR` is not supported.
        let unsupported_expr = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("a", 0)),
            Operator::Or,
            Arc::new(Column::new("a", 0)),
        )) as Arc<dyn PhysicalExpr>;
        assert!(!check_support(&unsupported_expr, &schema));
        // `OR` with a negated operand is likewise unsupported.
        let unsupported_expr_2 = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("a", 0)),
            Operator::Or,
            Arc::new(NegativeExpr::new(Arc::new(Column::new("a", 0)))),
        )) as Arc<dyn PhysicalExpr>;
        assert!(!check_support(&unsupported_expr_2, &schema));
    }
}
#[cfg(test)]
mod hash_join_tests {
    // Tests for `hash_join_swap_subrule`: the rule may swap a hash join's
    // inputs so that the bounded side becomes the build (left) side, but only
    // for join types where the swap preserves semantics.
    use super::*;
    use crate::physical_optimizer::test_utils::SourceType;
    use crate::test_util::UnboundedExec;
    use arrow::datatypes::{DataType, Field};
    use arrow::record_batch::RecordBatch;

    /// One swap-rule scenario: the join we build and the join we expect after
    /// `hash_join_swap_subrule` has run.
    struct TestCase {
        // Human-readable label included in assertion output.
        case: String,
        // Boundedness of the initial (left, right) inputs.
        initial_sources_unbounded: (SourceType, SourceType),
        initial_join_type: JoinType,
        initial_mode: PartitionMode,
        // Expected boundedness of (left, right) after optimization.
        expected_sources_unbounded: (SourceType, SourceType),
        expected_join_type: JoinType,
        expected_mode: PartitionMode,
        // Whether the rule is expected to swap the join's inputs.
        expecting_swap: bool,
    }

    /// Full joins are never swapped, regardless of which side is unbounded.
    #[tokio::test]
    async fn test_join_with_swap_full() -> Result<()> {
        let cases = vec![
            TestCase {
                case: "Bounded - Unbounded 1".to_string(),
                initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded),
                initial_join_type: JoinType::Full,
                initial_mode: PartitionMode::Partitioned,
                expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded),
                expected_join_type: JoinType::Full,
                expected_mode: PartitionMode::Partitioned,
                expecting_swap: false,
            },
            TestCase {
                case: "Unbounded - Bounded 2".to_string(),
                initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded),
                initial_join_type: JoinType::Full,
                initial_mode: PartitionMode::Partitioned,
                expected_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded),
                expected_join_type: JoinType::Full,
                expected_mode: PartitionMode::Partitioned,
                expecting_swap: false,
            },
            TestCase {
                case: "Bounded - Bounded 3".to_string(),
                initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded),
                initial_join_type: JoinType::Full,
                initial_mode: PartitionMode::Partitioned,
                expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded),
                expected_join_type: JoinType::Full,
                expected_mode: PartitionMode::Partitioned,
                expecting_swap: false,
            },
            TestCase {
                case: "Unbounded - Unbounded 4".to_string(),
                initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded),
                initial_join_type: JoinType::Full,
                initial_mode: PartitionMode::Partitioned,
                expected_sources_unbounded: (
                    SourceType::Unbounded,
                    SourceType::Unbounded,
                ),
                expected_join_type: JoinType::Full,
                expected_mode: PartitionMode::Partitioned,
                expecting_swap: false,
            },
        ];
        for case in cases {
            test_join_with_maybe_swap_unbounded_case(case).await?
        }
        Ok(())
    }

    /// `LeftSemi` and `Inner` joins are swappable: the rule swaps them exactly
    /// when the left input is unbounded and the right is bounded.
    #[tokio::test]
    async fn test_cases_without_collect_left_check() -> Result<()> {
        let mut cases = vec![];
        let join_types = vec![JoinType::LeftSemi, JoinType::Inner];
        for join_type in join_types {
            cases.push(TestCase {
                case: "Unbounded - Bounded / CollectLeft".to_string(),
                initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded),
                initial_join_type: join_type,
                initial_mode: PartitionMode::CollectLeft,
                expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded),
                expected_join_type: swap_join_type(join_type),
                expected_mode: PartitionMode::CollectLeft,
                expecting_swap: true,
            });
            cases.push(TestCase {
                case: "Bounded - Unbounded / CollectLeft".to_string(),
                initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded),
                initial_join_type: join_type,
                initial_mode: PartitionMode::CollectLeft,
                expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded),
                expected_join_type: join_type,
                expected_mode: PartitionMode::CollectLeft,
                expecting_swap: false,
            });
            cases.push(TestCase {
                case: "Unbounded - Unbounded / CollectLeft".to_string(),
                initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded),
                initial_join_type: join_type,
                initial_mode: PartitionMode::CollectLeft,
                expected_sources_unbounded: (
                    SourceType::Unbounded,
                    SourceType::Unbounded,
                ),
                expected_join_type: join_type,
                expected_mode: PartitionMode::CollectLeft,
                expecting_swap: false,
            });
            cases.push(TestCase {
                case: "Bounded - Bounded / CollectLeft".to_string(),
                initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded),
                initial_join_type: join_type,
                initial_mode: PartitionMode::CollectLeft,
                expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded),
                expected_join_type: join_type,
                expected_mode: PartitionMode::CollectLeft,
                expecting_swap: false,
            });
            cases.push(TestCase {
                case: "Unbounded - Bounded / Partitioned".to_string(),
                initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded),
                initial_join_type: join_type,
                initial_mode: PartitionMode::Partitioned,
                expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded),
                expected_join_type: swap_join_type(join_type),
                expected_mode: PartitionMode::Partitioned,
                expecting_swap: true,
            });
            cases.push(TestCase {
                case: "Bounded - Unbounded / Partitioned".to_string(),
                initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded),
                initial_join_type: join_type,
                initial_mode: PartitionMode::Partitioned,
                expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded),
                expected_join_type: join_type,
                expected_mode: PartitionMode::Partitioned,
                expecting_swap: false,
            });
            cases.push(TestCase {
                case: "Bounded - Bounded / Partitioned".to_string(),
                initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded),
                initial_join_type: join_type,
                initial_mode: PartitionMode::Partitioned,
                expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded),
                expected_join_type: join_type,
                expected_mode: PartitionMode::Partitioned,
                expecting_swap: false,
            });
            cases.push(TestCase {
                case: "Unbounded - Unbounded / Partitioned".to_string(),
                initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded),
                initial_join_type: join_type,
                initial_mode: PartitionMode::Partitioned,
                expected_sources_unbounded: (
                    SourceType::Unbounded,
                    SourceType::Unbounded,
                ),
                expected_join_type: join_type,
                expected_mode: PartitionMode::Partitioned,
                expecting_swap: false,
            });
        }
        for case in cases {
            test_join_with_maybe_swap_unbounded_case(case).await?
        }
        Ok(())
    }

    /// `Left` and `LeftAnti` joins cannot use CollectLeft, but are still
    /// swapped under `Partitioned` mode when left is unbounded and right is
    /// bounded.
    #[tokio::test]
    async fn test_not_support_collect_left() -> Result<()> {
        let mut cases = vec![];
        let the_ones_not_support_collect_left = vec![JoinType::Left, JoinType::LeftAnti];
        for join_type in the_ones_not_support_collect_left {
            cases.push(TestCase {
                case: "Unbounded - Bounded".to_string(),
                initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded),
                initial_join_type: join_type,
                initial_mode: PartitionMode::Partitioned,
                expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded),
                expected_join_type: swap_join_type(join_type),
                expected_mode: PartitionMode::Partitioned,
                expecting_swap: true,
            });
            cases.push(TestCase {
                case: "Bounded - Unbounded".to_string(),
                initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded),
                initial_join_type: join_type,
                initial_mode: PartitionMode::Partitioned,
                expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded),
                expected_join_type: join_type,
                expected_mode: PartitionMode::Partitioned,
                expecting_swap: false,
            });
            cases.push(TestCase {
                case: "Bounded - Bounded".to_string(),
                initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded),
                initial_join_type: join_type,
                initial_mode: PartitionMode::Partitioned,
                expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded),
                expected_join_type: join_type,
                expected_mode: PartitionMode::Partitioned,
                expecting_swap: false,
            });
            cases.push(TestCase {
                case: "Unbounded - Unbounded".to_string(),
                initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded),
                initial_join_type: join_type,
                initial_mode: PartitionMode::Partitioned,
                expected_sources_unbounded: (
                    SourceType::Unbounded,
                    SourceType::Unbounded,
                ),
                expected_join_type: join_type,
                expected_mode: PartitionMode::Partitioned,
                expecting_swap: false,
            });
        }
        for case in cases {
            test_join_with_maybe_swap_unbounded_case(case).await?
        }
        Ok(())
    }

    /// Right-sided joins (`Right`, `RightAnti`, `RightSemi`) are never swapped
    /// by this rule, in either partition mode.
    #[tokio::test]
    async fn test_not_supporting_swaps_possible_collect_left() -> Result<()> {
        let mut cases = vec![];
        let the_ones_not_support_collect_left =
            vec![JoinType::Right, JoinType::RightAnti, JoinType::RightSemi];
        for join_type in the_ones_not_support_collect_left {
            cases.push(TestCase {
                case: "Unbounded - Bounded / CollectLeft".to_string(),
                initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded),
                initial_join_type: join_type,
                initial_mode: PartitionMode::CollectLeft,
                expected_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded),
                expected_join_type: join_type,
                expected_mode: PartitionMode::CollectLeft,
                expecting_swap: false,
            });
            cases.push(TestCase {
                case: "Bounded - Unbounded / CollectLeft".to_string(),
                initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded),
                initial_join_type: join_type,
                initial_mode: PartitionMode::CollectLeft,
                expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded),
                expected_join_type: join_type,
                expected_mode: PartitionMode::CollectLeft,
                expecting_swap: false,
            });
            cases.push(TestCase {
                case: "Unbounded - Unbounded / CollectLeft".to_string(),
                initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded),
                initial_join_type: join_type,
                initial_mode: PartitionMode::CollectLeft,
                expected_sources_unbounded: (
                    SourceType::Unbounded,
                    SourceType::Unbounded,
                ),
                expected_join_type: join_type,
                expected_mode: PartitionMode::CollectLeft,
                expecting_swap: false,
            });
            cases.push(TestCase {
                case: "Bounded - Bounded / CollectLeft".to_string(),
                initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded),
                initial_join_type: join_type,
                initial_mode: PartitionMode::CollectLeft,
                expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded),
                expected_join_type: join_type,
                expected_mode: PartitionMode::CollectLeft,
                expecting_swap: false,
            });
            cases.push(TestCase {
                case: "Unbounded - Bounded / Partitioned".to_string(),
                initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded),
                initial_join_type: join_type,
                initial_mode: PartitionMode::Partitioned,
                expected_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded),
                expected_join_type: join_type,
                expected_mode: PartitionMode::Partitioned,
                expecting_swap: false,
            });
            cases.push(TestCase {
                case: "Bounded - Unbounded / Partitioned".to_string(),
                initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded),
                initial_join_type: join_type,
                initial_mode: PartitionMode::Partitioned,
                expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded),
                expected_join_type: join_type,
                expected_mode: PartitionMode::Partitioned,
                expecting_swap: false,
            });
            cases.push(TestCase {
                case: "Bounded - Bounded / Partitioned".to_string(),
                initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded),
                initial_join_type: join_type,
                initial_mode: PartitionMode::Partitioned,
                expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded),
                expected_join_type: join_type,
                expected_mode: PartitionMode::Partitioned,
                expecting_swap: false,
            });
            cases.push(TestCase {
                case: "Unbounded - Unbounded / Partitioned".to_string(),
                initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded),
                initial_join_type: join_type,
                initial_mode: PartitionMode::Partitioned,
                expected_sources_unbounded: (
                    SourceType::Unbounded,
                    SourceType::Unbounded,
                ),
                expected_join_type: join_type,
                expected_mode: PartitionMode::Partitioned,
                expecting_swap: false,
            });
        }
        for case in cases {
            test_join_with_maybe_swap_unbounded_case(case).await?
        }
        Ok(())
    }

    /// Builds a hash join matching `t`, applies `hash_join_swap_subrule`, and
    /// checks the resulting join's input boundedness, join type, partition
    /// mode, and whether its inputs were actually swapped.
    async fn test_join_with_maybe_swap_unbounded_case(t: TestCase) -> Result<()> {
        let left_unbounded = t.initial_sources_unbounded.0 == SourceType::Unbounded;
        let right_unbounded = t.initial_sources_unbounded.1 == SourceType::Unbounded;
        // `UnboundedExec` with `None` batches models an unbounded source;
        // `Some(1)` makes it a bounded one.
        let left_exec = Arc::new(UnboundedExec::new(
            (!left_unbounded).then_some(1),
            RecordBatch::new_empty(Arc::new(Schema::new(vec![Field::new(
                "a",
                DataType::Int32,
                false,
            )]))),
            2,
        )) as Arc<dyn ExecutionPlan>;
        let right_exec = Arc::new(UnboundedExec::new(
            (!right_unbounded).then_some(1),
            RecordBatch::new_empty(Arc::new(Schema::new(vec![Field::new(
                "b",
                DataType::Int32,
                false,
            )]))),
            2,
        )) as Arc<dyn ExecutionPlan>;
        let join = Arc::new(HashJoinExec::try_new(
            Arc::clone(&left_exec),
            Arc::clone(&right_exec),
            vec![(
                Arc::new(Column::new_with_schema("a", &left_exec.schema())?),
                Arc::new(Column::new_with_schema("b", &right_exec.schema())?),
            )],
            None,
            &t.initial_join_type,
            None,
            t.initial_mode,
            false,
        )?);
        let optimized_join_plan =
            hash_join_swap_subrule(join.clone(), &ConfigOptions::new())?;
        // A swap wraps the join in a projection restoring the original column
        // order; unwrap it before inspecting the join itself.
        let projection_added = optimized_join_plan.as_any().is::<ProjectionExec>();
        let plan = if projection_added {
            let proj = optimized_join_plan
                .as_any()
                .downcast_ref::<ProjectionExec>()
                .expect(
                    "A proj is required to swap columns back to their original order",
                );
            proj.input().clone()
        } else {
            optimized_join_plan
        };
        if let Some(HashJoinExec {
            left,
            right,
            join_type,
            mode,
            ..
        }) = plan.as_any().downcast_ref::<HashJoinExec>()
        {
            // A swap exchanges both inputs, so either both sides changed or
            // neither did; pointer identity against the originals detects it.
            let left_changed = Arc::ptr_eq(left, &right_exec);
            let right_changed = Arc::ptr_eq(right, &left_exec);
            assert_eq!(left_changed, right_changed);
            assert_eq!(
                (
                    t.case.as_str(),
                    if left.execution_mode().is_unbounded() {
                        SourceType::Unbounded
                    } else {
                        SourceType::Bounded
                    },
                    if right.execution_mode().is_unbounded() {
                        SourceType::Unbounded
                    } else {
                        SourceType::Bounded
                    },
                    join_type,
                    mode,
                    left_changed && right_changed
                ),
                (
                    t.case.as_str(),
                    t.expected_sources_unbounded.0,
                    t.expected_sources_unbounded.1,
                    &t.expected_join_type,
                    &t.expected_mode,
                    t.expecting_swap
                )
            );
        } else {
            // Previously this was a bare `if let` with no `else`, so a plan
            // that was not a `HashJoinExec` let the test pass vacuously.
            panic!(
                "Expected a HashJoinExec (possibly under a projection) for case '{}'",
                t.case
            );
        }
        Ok(())
    }
}