use crate::config::ConfigOptions;
use crate::error::Result;
use crate::physical_optimizer::join_selection::swap_hash_join;
use crate::physical_optimizer::pipeline_checker::{
check_finiteness_requirements, PipelineStatePropagator,
};
use crate::physical_optimizer::PhysicalOptimizerRule;
use crate::physical_plan::joins::utils::JoinSide;
use crate::physical_plan::joins::{
convert_sort_expr_with_filter_schema, HashJoinExec, PartitionMode,
SymmetricHashJoinExec,
};
use crate::physical_plan::rewrite::TreeNodeRewritable;
use crate::physical_plan::ExecutionPlan;
use datafusion_common::DataFusionError;
use datafusion_expr::logical_plan::JoinType;
use datafusion_physical_expr::expressions::{BinaryExpr, CastExpr, Column, Literal};
use datafusion_physical_expr::intervals::{is_datatype_supported, is_operator_supported};
use datafusion_physical_expr::PhysicalExpr;
use std::sync::Arc;
/// The [`PipelineFixer`] rule tries to modify a given plan so that it can
/// accommodate infinite sources, if there are any. Its subrules currently
/// rewrite hash joins: converting them to symmetric hash joins or swapping
/// their inputs, depending on which sides are unbounded.
#[derive(Default)]
pub struct PipelineFixer {}

impl PipelineFixer {
    /// Creates a new [`PipelineFixer`] rule instance.
    pub fn new() -> Self {
        Self {}
    }
}
/// Signature of a [`PipelineFixer`] subrule: given the current pipeline
/// state, returns `Some` with the (possibly rewritten) state when the rule
/// applies to the plan node, or `None` when it does not apply.
type PipelineFixerSubrule =
    dyn Fn(PipelineStatePropagator) -> Option<Result<PipelineStatePropagator>>;
impl PhysicalOptimizerRule for PipelineFixer {
    /// Walks the plan bottom-up, giving each subrule a chance to rewrite
    /// hash joins over unbounded inputs, and checks finiteness requirements
    /// at every node.
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        _config: &ConfigOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // Subrules are tried in this order on every plan node.
        let subrules: Vec<Box<PipelineFixerSubrule>> = vec![
            Box::new(hash_join_convert_symmetric_subrule),
            Box::new(hash_join_swap_subrule),
        ];
        let initial_state = PipelineStatePropagator::new(plan);
        let final_state = initial_state.transform_up(&|state| {
            apply_subrules_and_check_finiteness_requirements(state, &subrules)
        })?;
        Ok(final_state.plan)
    }

    fn name(&self) -> &str {
        "PipelineFixer"
    }

    fn schema_check(&self) -> bool {
        true
    }
}
/// Returns `true` when `expr` consists solely of constructs that interval
/// arithmetic can handle: binary expressions with supported operators, plus
/// columns, literals, and casts — checked recursively over all children.
fn check_support(expr: &Arc<dyn PhysicalExpr>) -> bool {
    let node = expr.as_any();
    let node_supported = match node.downcast_ref::<BinaryExpr>() {
        Some(binary) => is_operator_supported(binary.op()),
        None => node.is::<Column>() || node.is::<Literal>() || node.is::<CastExpr>(),
    };
    node_supported && expr.children().iter().all(check_support)
}
/// Decides whether `hash_join` can be replaced by a symmetric hash join.
/// This requires a join filter whose expression and field types are supported
/// by interval arithmetic, and an output ordering on both inputs that is
/// convertible against the filter schema; otherwise returns `Ok(false)`.
fn is_suitable_for_symmetric_hash_join(hash_join: &HashJoinExec) -> Result<bool> {
    let filter = match hash_join.filter() {
        Some(filter) => filter,
        None => return Ok(false),
    };
    let left = hash_join.left();
    let left_ordering = match left.output_ordering() {
        Some(ordering) => ordering,
        None => return Ok(false),
    };
    let right = hash_join.right();
    let right_ordering = match right.output_ordering() {
        Some(ordering) => ordering,
        None => return Ok(false),
    };
    // The filter expression must only use interval-supported constructs.
    let expr_supported = check_support(filter.expression());
    // Each side's leading sort expression must be expressible in terms of
    // the filter schema for pruning to work.
    let left_convertible = convert_sort_expr_with_filter_schema(
        &JoinSide::Left,
        filter,
        &left.schema(),
        &left_ordering[0],
    )?
    .is_some();
    let right_convertible = convert_sort_expr_with_filter_schema(
        &JoinSide::Right,
        filter,
        &right.schema(),
        &right_ordering[0],
    )?
    .is_some();
    let fields_supported = filter
        .schema()
        .fields()
        .iter()
        .all(|f| is_datatype_supported(f.data_type()));
    Ok(expr_supported && fields_supported && left_convertible && right_convertible)
}
/// Subrule that replaces a [`HashJoinExec`] whose inputs are both unbounded
/// with a [`SymmetricHashJoinExec`], provided the join filter permits it
/// (see [`is_suitable_for_symmetric_hash_join`]). Returns `None` when the
/// node is not a hash join.
fn hash_join_convert_symmetric_subrule(
    input: PipelineStatePropagator,
) -> Option<Result<PipelineStatePropagator>> {
    let plan = input.plan;
    let children_unbounded = input.children_unbounded;
    let hash_join = match plan.as_any().downcast_ref::<HashJoinExec>() {
        Some(hash_join) => hash_join,
        None => return None,
    };
    let (left_unbounded, right_unbounded) =
        (children_unbounded[0], children_unbounded[1]);
    let transformed = if left_unbounded && right_unbounded {
        match is_suitable_for_symmetric_hash_join(hash_join) {
            Ok(true) => SymmetricHashJoinExec::try_new(
                hash_join.left().clone(),
                hash_join.right().clone(),
                hash_join
                    .on()
                    .iter()
                    .map(|(l, r)| (l.clone(), r.clone()))
                    .collect(),
                // `filter()` is `Some` here: suitability requires a filter.
                hash_join.filter().unwrap().clone(),
                hash_join.join_type(),
                hash_join.null_equals_null(),
            )
            .map(|exec| Arc::new(exec) as _),
            Ok(false) => Ok(plan),
            Err(e) => return Some(Err(e)),
        }
    } else {
        Ok(plan)
    };
    Some(transformed.map(|plan| PipelineStatePropagator {
        plan,
        unbounded: left_unbounded || right_unbounded,
        children_unbounded,
    }))
}
/// Subrule that swaps the sides of a [`HashJoinExec`] when the left (build)
/// side is unbounded but the right side is bounded, so that the bounded side
/// becomes the build side. Only join types whose semantics survive a swap
/// (inner and the left variants) are eligible. Returns `None` when the node
/// is not a hash join.
fn hash_join_swap_subrule(
    input: PipelineStatePropagator,
) -> Option<Result<PipelineStatePropagator>> {
    let plan = input.plan;
    let children_unbounded = input.children_unbounded;
    let hash_join = match plan.as_any().downcast_ref::<HashJoinExec>() {
        Some(hash_join) => hash_join,
        None => return None,
    };
    let (left_unbounded, right_unbounded) =
        (children_unbounded[0], children_unbounded[1]);
    let swappable_join_type = matches!(
        *hash_join.join_type(),
        JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti
    );
    // Swapping only helps when just the build side is unbounded.
    let transformed = if left_unbounded && !right_unbounded && swappable_join_type {
        swap(hash_join)
    } else {
        Ok(plan)
    };
    Some(transformed.map(|plan| PipelineStatePropagator {
        plan,
        unbounded: left_unbounded || right_unbounded,
        children_unbounded,
    }))
}
fn swap(hash_join: &HashJoinExec) -> Result<Arc<dyn ExecutionPlan>> {
let partition_mode = hash_join.partition_mode();
let join_type = hash_join.join_type();
match (*partition_mode, *join_type) {
(
_,
JoinType::Right | JoinType::RightSemi | JoinType::RightAnti | JoinType::Full,
) => Err(DataFusionError::Internal(format!(
"{join_type} join cannot be swapped for unbounded input."
))),
(PartitionMode::Partitioned, _) => {
swap_hash_join(hash_join, PartitionMode::Partitioned)
}
(PartitionMode::CollectLeft, _) => {
swap_hash_join(hash_join, PartitionMode::CollectLeft)
}
(PartitionMode::Auto, _) => Err(DataFusionError::Internal(
"Auto is not acceptable for unbounded input here.".to_string(),
)),
}
}
/// Applies each subrule in `physical_optimizer_subrules` to `input` in order,
/// keeping the updated state whenever a subrule applies (returns `Some`), and
/// finally verifies the finiteness requirements of the resulting plan.
///
/// Errors from a subrule or from the finiteness check are propagated.
///
/// Note: takes `&[Box<...>]` instead of `&Vec<Box<...>>` (clippy `ptr_arg`);
/// callers passing `&Vec` still work via deref coercion.
fn apply_subrules_and_check_finiteness_requirements(
    mut input: PipelineStatePropagator,
    physical_optimizer_subrules: &[Box<PipelineFixerSubrule>],
) -> Result<Option<PipelineStatePropagator>> {
    for sub_rule in physical_optimizer_subrules {
        // A subrule returning `None` does not apply to this node; keep the
        // current state and try the next subrule.
        if let Some(value) = sub_rule(input.clone()).transpose()? {
            input = value;
        }
    }
    check_finiteness_requirements(input)
}
#[cfg(test)]
mod util_tests {
    use crate::physical_optimizer::pipeline_fixer::check_support;
    use datafusion_expr::Operator;
    use datafusion_physical_expr::expressions::{BinaryExpr, Column, NegativeExpr};
    use datafusion_physical_expr::PhysicalExpr;
    use std::sync::Arc;

    /// Builds the binary expression `a <op> a` over a column at index 0.
    fn col_op_col(op: Operator) -> Arc<dyn PhysicalExpr> {
        Arc::new(BinaryExpr::new(
            Arc::new(Column::new("a", 0)),
            op,
            Arc::new(Column::new("a", 0)),
        ))
    }

    #[test]
    fn check_expr_supported() {
        // Arithmetic over columns is supported.
        assert!(check_support(&col_op_col(Operator::Plus)));
        // A bare column reference is supported as well.
        let bare_column = Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>;
        assert!(check_support(&bare_column));
        // `OR` is not a supported operator.
        assert!(!check_support(&col_op_col(Operator::Or)));
        // An unsupported child (negation) makes the whole tree unsupported.
        let with_negation = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("a", 0)),
            Operator::Or,
            Arc::new(NegativeExpr::new(Arc::new(Column::new("a", 0)))),
        )) as Arc<dyn PhysicalExpr>;
        assert!(!check_support(&with_negation));
    }
}
#[cfg(test)]
mod hash_join_tests {
    use super::*;
    use crate::physical_optimizer::join_selection::swap_join_type;
    use crate::physical_optimizer::test_utils::SourceType;
    use crate::physical_plan::expressions::Column;
    use crate::physical_plan::joins::PartitionMode;
    use crate::physical_plan::projection::ProjectionExec;
    use crate::test_util::UnboundedExec;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use std::sync::Arc;

    /// Describes one scenario for [`hash_join_swap_subrule`]: a join built
    /// from the `initial_*` fields must come out matching the `expected_*`
    /// fields.
    struct TestCase {
        case: String,
        initial_sources_unbounded: (SourceType, SourceType),
        initial_join_type: JoinType,
        initial_mode: PartitionMode,
        expected_sources_unbounded: (SourceType, SourceType),
        expected_join_type: JoinType,
        expected_mode: PartitionMode,
        expecting_swap: bool,
    }

    /// Maps a (left, right) unbounded-flag pair to source types.
    fn sources(
        left_unbounded: bool,
        right_unbounded: bool,
    ) -> (SourceType, SourceType) {
        let source = |unbounded| {
            if unbounded {
                SourceType::Unbounded
            } else {
                SourceType::Bounded
            }
        };
        (source(left_unbounded), source(right_unbounded))
    }

    /// Builds a case in which the subrule must leave the join untouched, so
    /// every expectation mirrors the corresponding input.
    fn unchanged(
        name: &str,
        left_unbounded: bool,
        right_unbounded: bool,
        join_type: JoinType,
        mode: PartitionMode,
    ) -> TestCase {
        TestCase {
            case: name.to_string(),
            initial_sources_unbounded: sources(left_unbounded, right_unbounded),
            initial_join_type: join_type,
            initial_mode: mode,
            expected_sources_unbounded: sources(left_unbounded, right_unbounded),
            expected_join_type: join_type,
            expected_mode: mode,
            expecting_swap: false,
        }
    }

    /// Builds a case in which the subrule must swap the join: an unbounded
    /// left (build) side with a bounded right side, and expectations with the
    /// sides flipped and the join type swapped.
    fn swapped(name: &str, join_type: JoinType, mode: PartitionMode) -> TestCase {
        TestCase {
            case: name.to_string(),
            initial_sources_unbounded: sources(true, false),
            initial_join_type: join_type,
            initial_mode: mode,
            expected_sources_unbounded: sources(false, true),
            expected_join_type: swap_join_type(join_type),
            expected_mode: mode,
            expecting_swap: true,
        }
    }

    #[tokio::test]
    async fn test_join_with_swap_full() -> Result<()> {
        // A `Full` join can never be swapped, regardless of which sides are
        // unbounded.
        let cases = vec![
            unchanged(
                "Bounded - Unbounded 1",
                false,
                true,
                JoinType::Full,
                PartitionMode::Partitioned,
            ),
            unchanged(
                "Unbounded - Bounded 2",
                true,
                false,
                JoinType::Full,
                PartitionMode::Partitioned,
            ),
            unchanged(
                "Bounded - Bounded 3",
                false,
                false,
                JoinType::Full,
                PartitionMode::Partitioned,
            ),
            unchanged(
                "Unbounded - Unbounded 4",
                true,
                true,
                JoinType::Full,
                PartitionMode::Partitioned,
            ),
        ];
        for case in cases {
            test_join_with_maybe_swap_unbounded_case(case).await?
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_cases_without_collect_left_check() -> Result<()> {
        // These join types are swappable in both partition modes; only the
        // unbounded-left / bounded-right combination triggers a swap.
        let mut cases = vec![];
        for join_type in [JoinType::LeftSemi, JoinType::Inner] {
            cases.push(swapped(
                "Unbounded - Bounded / CollectLeft",
                join_type,
                PartitionMode::CollectLeft,
            ));
            cases.push(unchanged(
                "Bounded - Unbounded / CollectLeft",
                false,
                true,
                join_type,
                PartitionMode::CollectLeft,
            ));
            cases.push(unchanged(
                "Unbounded - Unbounded / CollectLeft",
                true,
                true,
                join_type,
                PartitionMode::CollectLeft,
            ));
            cases.push(unchanged(
                "Bounded - Bounded / CollectLeft",
                false,
                false,
                join_type,
                PartitionMode::CollectLeft,
            ));
            cases.push(swapped(
                "Unbounded - Bounded / Partitioned",
                join_type,
                PartitionMode::Partitioned,
            ));
            cases.push(unchanged(
                "Bounded - Unbounded / Partitioned",
                false,
                true,
                join_type,
                PartitionMode::Partitioned,
            ));
            cases.push(unchanged(
                "Bounded - Bounded / Partitioned",
                false,
                false,
                join_type,
                PartitionMode::Partitioned,
            ));
            cases.push(unchanged(
                "Unbounded - Unbounded / Partitioned",
                true,
                true,
                join_type,
                PartitionMode::Partitioned,
            ));
        }
        for case in cases {
            test_join_with_maybe_swap_unbounded_case(case).await?
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_not_support_collect_left() -> Result<()> {
        // `Left` and `LeftAnti` joins are exercised in `Partitioned` mode;
        // swapping is still expected for the unbounded-left case.
        let mut cases = vec![];
        for join_type in [JoinType::Left, JoinType::LeftAnti] {
            cases.push(swapped(
                "Unbounded - Bounded",
                join_type,
                PartitionMode::Partitioned,
            ));
            cases.push(unchanged(
                "Bounded - Unbounded",
                false,
                true,
                join_type,
                PartitionMode::Partitioned,
            ));
            cases.push(unchanged(
                "Bounded - Bounded",
                false,
                false,
                join_type,
                PartitionMode::Partitioned,
            ));
            cases.push(unchanged(
                "Unbounded - Unbounded",
                true,
                true,
                join_type,
                PartitionMode::Partitioned,
            ));
        }
        for case in cases {
            test_join_with_maybe_swap_unbounded_case(case).await?
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_not_supporting_swaps_possible_collect_left() -> Result<()> {
        // Right-side join types cannot be swapped for unbounded input, so
        // the subrule must leave every combination untouched.
        let mut cases = vec![];
        for join_type in [JoinType::Right, JoinType::RightAnti, JoinType::RightSemi] {
            cases.push(unchanged(
                "Unbounded - Bounded / CollectLeft",
                true,
                false,
                join_type,
                PartitionMode::CollectLeft,
            ));
            cases.push(unchanged(
                "Bounded - Unbounded / CollectLeft",
                false,
                true,
                join_type,
                PartitionMode::CollectLeft,
            ));
            cases.push(unchanged(
                "Unbounded - Unbounded / CollectLeft",
                true,
                true,
                join_type,
                PartitionMode::CollectLeft,
            ));
            cases.push(unchanged(
                "Bounded - Bounded / CollectLeft",
                false,
                false,
                join_type,
                PartitionMode::CollectLeft,
            ));
            cases.push(unchanged(
                "Unbounded - Bounded / Partitioned",
                true,
                false,
                join_type,
                PartitionMode::Partitioned,
            ));
            cases.push(unchanged(
                "Bounded - Unbounded / Partitioned",
                false,
                true,
                join_type,
                PartitionMode::Partitioned,
            ));
            cases.push(unchanged(
                "Bounded - Bounded / Partitioned",
                false,
                false,
                join_type,
                PartitionMode::Partitioned,
            ));
            cases.push(unchanged(
                "Unbounded - Unbounded / Partitioned",
                true,
                true,
                join_type,
                PartitionMode::Partitioned,
            ));
        }
        for case in cases {
            test_join_with_maybe_swap_unbounded_case(case).await?
        }
        Ok(())
    }

    /// Builds a [`HashJoinExec`] over one source per side (bounded or
    /// unbounded per the test case), runs [`hash_join_swap_subrule`] on it,
    /// unwraps the column-restoring projection if a swap occurred, and checks
    /// the resulting join against the expectations in `t`.
    #[allow(clippy::vtable_address_comparisons)]
    async fn test_join_with_maybe_swap_unbounded_case(t: TestCase) -> Result<()> {
        let left_unbounded = t.initial_sources_unbounded.0 == SourceType::Unbounded;
        let right_unbounded = t.initial_sources_unbounded.1 == SourceType::Unbounded;
        // `None` batch count makes the exec unbounded; `Some(1)` keeps it
        // bounded.
        let left_exec = Arc::new(UnboundedExec::new(
            (!left_unbounded).then_some(1),
            RecordBatch::new_empty(Arc::new(Schema::new(vec![Field::new(
                "a",
                DataType::Int32,
                false,
            )]))),
            2,
        )) as Arc<dyn ExecutionPlan>;
        let right_exec = Arc::new(UnboundedExec::new(
            (!right_unbounded).then_some(1),
            RecordBatch::new_empty(Arc::new(Schema::new(vec![Field::new(
                "b",
                DataType::Int32,
                false,
            )]))),
            2,
        )) as Arc<dyn ExecutionPlan>;
        let join = HashJoinExec::try_new(
            Arc::clone(&left_exec),
            Arc::clone(&right_exec),
            vec![(
                Column::new_with_schema("a", &left_exec.schema())?,
                Column::new_with_schema("b", &right_exec.schema())?,
            )],
            None,
            &t.initial_join_type,
            t.initial_mode,
            false,
        )?;
        let initial_hash_join_state = PipelineStatePropagator {
            plan: Arc::new(join),
            unbounded: false,
            children_unbounded: vec![left_unbounded, right_unbounded],
        };
        let optimized_hash_join =
            hash_join_swap_subrule(initial_hash_join_state).unwrap()?;
        let optimized_join_plan = optimized_hash_join.plan;
        // A swap inserts a projection on top to restore column order; look
        // through it to reach the join itself.
        let projection_added = optimized_join_plan.as_any().is::<ProjectionExec>();
        let plan = if projection_added {
            let proj = optimized_join_plan
                .as_any()
                .downcast_ref::<ProjectionExec>()
                .expect(
                    "A proj is required to swap columns back to their original order",
                );
            proj.input().clone()
        } else {
            optimized_join_plan
        };
        if let Some(HashJoinExec {
            left,
            right,
            join_type,
            mode,
            ..
        }) = plan.as_any().downcast_ref::<HashJoinExec>()
        {
            // A swap must exchange both children or neither.
            let left_changed = Arc::ptr_eq(left, &right_exec);
            let right_changed = Arc::ptr_eq(right, &left_exec);
            assert_eq!(left_changed, right_changed);
            assert_eq!(
                (
                    t.case.as_str(),
                    if left.unbounded_output(&[])? {
                        SourceType::Unbounded
                    } else {
                        SourceType::Bounded
                    },
                    if right.unbounded_output(&[])? {
                        SourceType::Unbounded
                    } else {
                        SourceType::Bounded
                    },
                    join_type,
                    mode,
                    left_changed && right_changed
                ),
                (
                    t.case.as_str(),
                    t.expected_sources_unbounded.0,
                    t.expected_sources_unbounded.1,
                    &t.expected_join_type,
                    &t.expected_mode,
                    t.expecting_swap
                )
            );
        };
        Ok(())
    }
}