use crate::optimizer::ApplyOrder;
use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::tree_node::Transformed;
use datafusion_common::{DFSchema, assert_or_internal_err};
use datafusion_common::{NullEquality, Result};
use datafusion_expr::utils::split_conjunction_owned;
use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair};
use datafusion_expr::{BinaryExpr, Expr, ExprSchemable, Join, LogicalPlan, Operator};
/// One extracted join key: a pair of expressions `(left, right)`, where the
/// first element is typed against the left input schema and the second against
/// the right input schema.
type EquijoinPredicate = (Expr, Expr);
/// Optimizer rule that moves equality predicates out of a join's filter
/// expression and into the join's `on` key list.
///
/// The rule carries no state; the rewrite itself lives in this type's
/// `OptimizerRule` implementation.
#[derive(Default, Debug)]
pub struct ExtractEquijoinPredicate;

impl ExtractEquijoinPredicate {
    /// Creates a new instance of the rule.
    ///
    /// Equivalent to [`ExtractEquijoinPredicate::default`].
    pub fn new() -> Self {
        Self
    }
}
impl OptimizerRule for ExtractEquijoinPredicate {
    /// Reports that this rule is implemented through [`Self::rewrite`].
    fn supports_rewrite(&self) -> bool {
        true
    }

    /// Identifier of this rule.
    fn name(&self) -> &str {
        "extract_equijoin_predicate"
    }

    /// Requests bottom-up application of the rule.
    fn apply_order(&self) -> Option<ApplyOrder> {
        Some(ApplyOrder::BottomUp)
    }

    /// Rewrites a join that carries a filter so that equality conjuncts of the
    /// filter become join keys.
    ///
    /// Only `LogicalPlan::Join` nodes with `filter: Some(..)` are inspected;
    /// every other plan is returned untouched as `Transformed::no`.
    ///
    /// For a matching join:
    ///
    /// 1. The filter is split into `=` key pairs and a residual predicate.
    /// 2. If the join had no keys to begin with, step 1 found none, and a
    ///    residual predicate exists, that residual is split again looking for
    ///    `IS NOT DISTINCT FROM` key pairs. When any are found they become the
    ///    join keys and `null_equality` is set to
    ///    `NullEquality::NullEqualsNull`.
    /// 3. The join is rebuilt with the new keys appended to `on` and the
    ///    residual as its filter. The result is `Transformed::yes` when at
    ///    least one key was extracted and `Transformed::no` otherwise.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the predicate-splitting helpers.
    fn rewrite(
        &self,
        plan: LogicalPlan,
        _config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        match plan {
            LogicalPlan::Join(Join {
                left,
                right,
                mut on,
                filter: Some(predicate),
                join_type,
                join_constraint,
                schema,
                mut null_equality,
                null_aware,
            }) => {
                let (left_schema, right_schema) = (left.schema(), right.schema());

                // First pass: plain `=` conjuncts.
                let (mut join_keys, mut residual) =
                    split_eq_and_noneq_join_predicate(predicate, left_schema, right_schema)?;

                // Second pass is attempted only when the join has no keys at
                // all so far. The residual is cloned because it must survive
                // unchanged if the second pass finds nothing.
                let null_safe_candidate = if on.is_empty() && join_keys.is_empty() {
                    residual.clone()
                } else {
                    None
                };

                if let Some(candidate) = null_safe_candidate {
                    let (null_safe_keys, null_safe_residual) =
                        split_is_not_distinct_from_and_other_join_predicate(
                            candidate,
                            left_schema,
                            right_schema,
                        )?;
                    if !null_safe_keys.is_empty() {
                        join_keys = null_safe_keys;
                        residual = null_safe_residual;
                        null_equality = NullEquality::NullEqualsNull;
                    }
                }

                let extracted_any = !join_keys.is_empty();
                on.extend(join_keys);

                let rebuilt = LogicalPlan::Join(Join {
                    left,
                    right,
                    on,
                    filter: residual,
                    join_type,
                    join_constraint,
                    schema,
                    null_equality,
                    null_aware,
                });

                if extracted_any {
                    Ok(Transformed::yes(rebuilt))
                } else {
                    Ok(Transformed::no(rebuilt))
                }
            }
            other => Ok(Transformed::no(other)),
        }
    }
}
/// Splits `filter` into `=` join key pairs and the remaining predicate.
///
/// Thin wrapper around [`split_op_and_other_join_predicates`] with the key
/// operator fixed to [`Operator::Eq`].
fn split_eq_and_noneq_join_predicate(
    filter: Expr,
    left_schema: &DFSchema,
    right_schema: &DFSchema,
) -> Result<(Vec<EquijoinPredicate>, Option<Expr>)> {
    let key_operator = Operator::Eq;
    split_op_and_other_join_predicates(filter, left_schema, right_schema, key_operator)
}
/// Splits `filter` into `IS NOT DISTINCT FROM` join key pairs and the
/// remaining predicate.
///
/// Thin wrapper around [`split_op_and_other_join_predicates`] with the key
/// operator fixed to [`Operator::IsNotDistinctFrom`].
fn split_is_not_distinct_from_and_other_join_predicate(
    filter: Expr,
    left_schema: &DFSchema,
    right_schema: &DFSchema,
) -> Result<(Vec<EquijoinPredicate>, Option<Expr>)> {
    let key_operator = Operator::IsNotDistinctFrom;
    split_op_and_other_join_predicates(filter, left_schema, right_schema, key_operator)
}
/// Partitions the conjuncts of `filter` into join key pairs and a residual
/// predicate.
///
/// `filter` is broken up with `split_conjunction_owned`. A conjunct becomes a
/// join key when all of the following hold:
///
/// * it is a binary expression whose operator equals `operator`;
/// * `find_valid_equijoin_key_pair` yields a `(left, right)` pair for its two
///   operands against `left_schema` / `right_schema`;
/// * the types of both expressions in that pair pass `can_hash`.
///
/// Every other conjunct is kept, in its original order, and the kept
/// conjuncts are re-joined with `Expr::and`. The second tuple element is
/// `None` when nothing was kept.
///
/// # Errors
///
/// Fails the `assert_or_internal_err!` check when `operator` is neither
/// `Operator::Eq` nor `Operator::IsNotDistinctFrom`, and propagates errors
/// from key-pair lookup and type resolution.
fn split_op_and_other_join_predicates(
    filter: Expr,
    left_schema: &DFSchema,
    right_schema: &DFSchema,
    operator: Operator,
) -> Result<(Vec<EquijoinPredicate>, Option<Expr>)> {
    assert_or_internal_err!(
        matches!(operator, Operator::Eq | Operator::IsNotDistinctFrom),
        "split_op_and_other_join_predicates only supports 'Eq' or 'IsNotDistinctFrom' operators, \
        but received: {:?}",
        operator
    );

    let mut join_keys: Vec<EquijoinPredicate> = Vec::new();
    let mut residual: Vec<Expr> = Vec::new();

    for conjunct in split_conjunction_owned(filter) {
        // Decide first whether this conjunct yields a usable key pair; the
        // conjunct itself is only borrowed here so it can still be kept as a
        // filter afterwards.
        let usable_pair = match &conjunct {
            Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == operator => {
                match find_valid_equijoin_key_pair(left, right, left_schema, right_schema)? {
                    Some((lhs, rhs)) => {
                        // Both types are resolved before either is checked.
                        let lhs_type = lhs.get_type(left_schema)?;
                        let rhs_type = rhs.get_type(right_schema)?;
                        if can_hash(&lhs_type) && can_hash(&rhs_type) {
                            Some((lhs, rhs))
                        } else {
                            None
                        }
                    }
                    None => None,
                }
            }
            _ => None,
        };

        match usable_pair {
            Some(pair) => join_keys.push(pair),
            None => residual.push(conjunct),
        }
    }

    Ok((join_keys, residual.into_iter().reduce(Expr::and)))
}
// Unit tests for `ExtractEquijoinPredicate`. Each test builds a logical plan
// with `LogicalPlanBuilder`, runs the rule over it, and compares the indented
// plan display against an inline snapshot.
#[cfg(test)]
mod tests {
use super::*;
use crate::assert_optimized_plan_eq_display_indent_snapshot;
use crate::test::*;
use arrow::datatypes::DataType;
use datafusion_expr::{
JoinType, col, lit, logical_plan::builder::LogicalPlanBuilder,
};
use std::sync::Arc;
// Wraps `assert_optimized_plan_eq_display_indent_snapshot!` so that every test
// runs with a fresh `ExtractEquijoinPredicate` as the rule under test.
macro_rules! assert_optimized_plan_equal {
(
$plan:expr,
@ $expected:literal $(,)?
) => {{
let rule: Arc<dyn crate::OptimizerRule + Send + Sync> = Arc::new(ExtractEquijoinPredicate {});
assert_optimized_plan_eq_display_indent_snapshot!(
rule,
$plan,
@ $expected,
)
}};
}
// A filter that is a single column equality becomes the join key and no
// filter remains.
#[test]
fn join_with_only_column_equi_predicate() -> Result<()> {
let t1 = test_table_scan_with_name("t1")?;
let t2 = test_table_scan_with_name("t2")?;
let plan = LogicalPlanBuilder::from(t1)
.join_on(t2, JoinType::Left, Some(col("t1.a").eq(col("t2.a"))))?
.build()?;
assert_optimized_plan_equal!(
plan,
@r"
Left Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
"
)
}
// An equality between two arithmetic expressions, one per side, is also
// extracted as a join key.
#[test]
fn join_with_only_equi_expr_predicate() -> Result<()> {
let t1 = test_table_scan_with_name("t1")?;
let t2 = test_table_scan_with_name("t2")?;
let plan = LogicalPlanBuilder::from(t1)
.join_on(
t2,
JoinType::Left,
Some((col("t1.a") + lit(10i64)).eq(col("t2.a") * lit(2u32))),
)?
.build()?;
assert_optimized_plan_equal!(
plan,
@r"
Left Join: t1.a + Int64(10) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
"
)
}
// A filter made only of `>=` and `<` comparisons yields no join keys; the
// whole predicate stays in the join filter.
#[test]
fn join_with_only_none_equi_predicate() -> Result<()> {
let t1 = test_table_scan_with_name("t1")?;
let t2 = test_table_scan_with_name("t2")?;
let plan = LogicalPlanBuilder::from(t1)
.join_on(
t2,
JoinType::Left,
Some(
(col("t1.a") + lit(10i64))
.gt_eq(col("t2.a") * lit(2u32))
.and(col("t1.b").lt(lit(100i32))),
),
)?
.build()?;
assert_optimized_plan_equal!(
plan,
@r"
Left Join: Filter: t1.a + Int64(10) >= t2.a * UInt32(2) AND t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
"
)
}
// The join already has an expression key; the equality found in the filter is
// appended after it and the `<` comparison remains as the filter.
#[test]
fn join_with_expr_both_from_filter_and_keys() -> Result<()> {
let t1 = test_table_scan_with_name("t1")?;
let t2 = test_table_scan_with_name("t2")?;
let plan = LogicalPlanBuilder::from(t1)
.join_with_expr_keys(
t2,
JoinType::Left,
(
vec![col("t1.a") + lit(11u32)],
vec![col("t2.a") * lit(2u32)],
),
Some(
(col("t1.a") + lit(10i64))
.eq(col("t2.a") * lit(2u32))
.and(col("t1.b").lt(lit(100i32))),
),
)?
.build()?;
assert_optimized_plan_equal!(
plan,
@r"
Left Join: t1.a + UInt32(11) = t2.a * UInt32(2), t1.a + Int64(10) = t2.a * UInt32(2) Filter: t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
"
)
}
// Equalities that are top-level conjuncts are extracted; the equality nested
// inside the `OR` stays in the filter together with the rest of that
// disjunction.
#[test]
fn join_with_and_or_filter() -> Result<()> {
let t1 = test_table_scan_with_name("t1")?;
let t2 = test_table_scan_with_name("t2")?;
let plan = LogicalPlanBuilder::from(t1)
.join_on(
t2,
JoinType::Left,
Some(
col("t1.c")
.eq(col("t2.c"))
.or((col("t1.a") + col("t1.b")).gt(col("t2.b") + col("t2.c")))
.and(
col("t1.a").eq(col("t2.a")).and(col("t1.b").eq(col("t2.b"))),
),
),
)?
.build()?;
assert_optimized_plan_equal!(
plan,
@r"
Left Join: t1.a = t2.a, t1.b = t2.b Filter: t1.c = t2.c OR t1.a + t1.b > t2.b + t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
"
)
}
// Nested joins: both the inner (t2/t3) and outer (t1/inner) joins get their
// equality extracted, each keeping its own non-equality filter.
#[test]
fn join_with_multiple_table() -> Result<()> {
let t1 = test_table_scan_with_name("t1")?;
let t2 = test_table_scan_with_name("t2")?;
let t3 = test_table_scan_with_name("t3")?;
let input = LogicalPlanBuilder::from(t2)
.join_on(
t3,
JoinType::Left,
Some(
col("t2.a")
.eq(col("t3.a"))
.and((col("t2.a") + col("t3.b")).gt(lit(100u32))),
),
)?
.build()?;
let plan = LogicalPlanBuilder::from(t1)
.join_on(
input,
JoinType::Left,
Some(
col("t1.a")
.eq(col("t2.a"))
.and((col("t1.c") + col("t2.c") + col("t3.c")).lt(lit(100u32))),
),
)?
.build()?;
assert_optimized_plan_equal!(
plan,
@r"
Left Join: t1.a = t2.a Filter: t1.c + t2.c + t3.c < UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]
TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
"
)
}
// In the outer join, `t2.c = t3.c` references only columns of the right input,
// so it is not turned into a join key and remains in the filter.
#[test]
fn join_with_multiple_table_and_eq_filter() -> Result<()> {
let t1 = test_table_scan_with_name("t1")?;
let t2 = test_table_scan_with_name("t2")?;
let t3 = test_table_scan_with_name("t3")?;
let input = LogicalPlanBuilder::from(t2)
.join_on(
t3,
JoinType::Left,
Some(
col("t2.a")
.eq(col("t3.a"))
.and((col("t2.a") + col("t3.b")).gt(lit(100u32))),
),
)?
.build()?;
let plan = LogicalPlanBuilder::from(t1)
.join_on(
input,
JoinType::Left,
Some(col("t1.a").eq(col("t2.a")).and(col("t2.c").eq(col("t3.c")))),
)?
.build()?;
assert_optimized_plan_equal!(
plan,
@r"
Left Join: t1.a = t2.a Filter: t2.c = t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]
TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
"
)
}
// An aliased equality filter is still extracted as a join key; the alias does
// not appear in the expected plan.
#[test]
fn join_with_alias_filter() -> Result<()> {
let t1 = test_table_scan_with_name("t1")?;
let t2 = test_table_scan_with_name("t2")?;
let t1_schema = Arc::clone(t1.schema());
let t2_schema = Arc::clone(t2.schema());
let filter = Expr::eq(
col("t1.a") + lit(1i64).cast_to(&DataType::UInt32, &t1_schema)?,
col("t2.a") + lit(2i32).cast_to(&DataType::UInt32, &t2_schema)?,
)
.alias("t1.a + 1 = t2.a + 2");
let plan = LogicalPlanBuilder::from(t1)
.join_on(t2, JoinType::Left, Some(filter))?
.build()?;
assert_optimized_plan_equal!(
plan,
@r"
Left Join: t1.a + CAST(Int64(1) AS UInt32) = t2.a + CAST(Int32(2) AS UInt32) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
"
)
}
}