use crate::logical_plan::consumer::SubstraitConsumer;
use datafusion::common::{Column, JoinType, NullEquality, not_impl_err, plan_err};
use datafusion::logical_expr::requalify_sides_if_needed;
use datafusion::logical_expr::utils::split_conjunction_owned;
use datafusion::logical_expr::{
BinaryExpr, Expr, LogicalPlan, LogicalPlanBuilder, Operator,
};
use substrait::proto::{JoinRel, join_rel};
pub async fn from_join_rel(
consumer: &impl SubstraitConsumer,
join: &JoinRel,
) -> datafusion::common::Result<LogicalPlan> {
if join.post_join_filter.is_some() {
return not_impl_err!("JoinRel with post_join_filter is not yet supported");
}
let left: LogicalPlanBuilder = LogicalPlanBuilder::from(
consumer.consume_rel(join.left.as_ref().unwrap()).await?,
);
let right = LogicalPlanBuilder::from(
consumer.consume_rel(join.right.as_ref().unwrap()).await?,
);
let (left, right, _requalified) = requalify_sides_if_needed(left, right)?;
let join_type = from_substrait_jointype(join.r#type)?;
let in_join_schema = left.schema().join(right.schema())?;
match &join.expression.as_ref() {
Some(expr) => {
let on = consumer.consume_expression(expr, &in_join_schema).await?;
let (join_ons, null_equality, join_filter) =
split_eq_and_noneq_join_predicate_with_nulls_equality(on);
let (left_cols, right_cols): (Vec<_>, Vec<_>) =
itertools::multiunzip(join_ons);
left.join_detailed(
right.build()?,
join_type,
(left_cols, right_cols),
join_filter,
null_equality,
)?
.build()
}
None => {
let on: Vec<String> = vec![];
left.join_detailed(
right.build()?,
join_type,
(on.clone(), on),
None,
NullEquality::NullEqualsNothing,
)?
.build()
}
}
}
fn split_eq_and_noneq_join_predicate_with_nulls_equality(
filter: Expr,
) -> (Vec<(Column, Column)>, NullEquality, Option<Expr>) {
let exprs = split_conjunction_owned(filter);
let mut eq_keys: Vec<(Column, Column)> = vec![];
let mut indistinct_keys: Vec<(Column, Column)> = vec![];
let mut accum_filters: Vec<Expr> = vec![];
for expr in exprs {
match expr {
Expr::BinaryExpr(BinaryExpr {
left,
op: op @ (Operator::Eq | Operator::IsNotDistinctFrom),
right,
}) => match (*left, *right) {
(Expr::Column(l), Expr::Column(r)) => match op {
Operator::Eq => eq_keys.push((l, r)),
Operator::IsNotDistinctFrom => indistinct_keys.push((l, r)),
_ => unreachable!(),
},
(left, right) => {
accum_filters.push(Expr::BinaryExpr(BinaryExpr {
left: Box::new(left),
op,
right: Box::new(right),
}));
}
},
_ => accum_filters.push(expr),
}
}
let (join_keys, null_equality) =
match (eq_keys.is_empty(), indistinct_keys.is_empty()) {
(false, false) => {
for (l, r) in indistinct_keys {
accum_filters.push(Expr::BinaryExpr(BinaryExpr {
left: Box::new(Expr::Column(l)),
op: Operator::IsNotDistinctFrom,
right: Box::new(Expr::Column(r)),
}));
}
(eq_keys, NullEquality::NullEqualsNothing)
}
(false, true) => (eq_keys, NullEquality::NullEqualsNothing),
(true, false) => (indistinct_keys, NullEquality::NullEqualsNull),
(true, true) => (vec![], NullEquality::NullEqualsNothing),
};
let join_filter = accum_filters.into_iter().reduce(Expr::and);
(join_keys, null_equality, join_filter)
}
/// Maps the raw Substrait join-type discriminant to a DataFusion [`JoinType`].
///
/// # Errors
///
/// Returns a plan error when `join_type` does not convert to a
/// `join_rel::JoinType` variant, or when it converts to a variant that has no
/// mapping listed below.
fn from_substrait_jointype(join_type: i32) -> datafusion::common::Result<JoinType> {
    let Ok(substrait_join_type) = join_rel::JoinType::try_from(join_type) else {
        return plan_err!("invalid join type variant {join_type}");
    };

    let mapped = match substrait_join_type {
        join_rel::JoinType::Inner => JoinType::Inner,
        join_rel::JoinType::Left => JoinType::Left,
        join_rel::JoinType::Right => JoinType::Right,
        join_rel::JoinType::Outer => JoinType::Full,
        join_rel::JoinType::LeftAnti => JoinType::LeftAnti,
        join_rel::JoinType::LeftSemi => JoinType::LeftSemi,
        join_rel::JoinType::LeftMark => JoinType::LeftMark,
        join_rel::JoinType::RightMark => JoinType::RightMark,
        join_rel::JoinType::RightAnti => JoinType::RightAnti,
        join_rel::JoinType::RightSemi => JoinType::RightSemi,
        // Every remaining variant is rejected.
        _ => return plan_err!("unsupported join type {substrait_join_type:?}"),
    };
    Ok(mapped)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the function under test.
    fn split(predicate: Expr) -> (Vec<(Column, Column)>, NullEquality, Option<Expr>) {
        split_eq_and_noneq_join_predicate_with_nulls_equality(predicate)
    }

    /// Builds an unqualified column reference.
    fn col(name: &str) -> Expr {
        let column = Column::from_name(name);
        Expr::Column(column)
    }

    /// Builds `left IS NOT DISTINCT FROM right`.
    fn indistinct(left: Expr, right: Expr) -> Expr {
        let (left, right) = (Box::new(left), Box::new(right));
        Expr::BinaryExpr(BinaryExpr {
            left,
            op: Operator::IsNotDistinctFrom,
            right,
        })
    }

    /// Renders key pairs as a comma-separated `l = r` list for assertions.
    fn fmt_keys(keys: &[(Column, Column)]) -> String {
        let mut rendered = Vec::with_capacity(keys.len());
        for (l, r) in keys {
            rendered.push(format!("{l} = {r}"));
        }
        rendered.join(", ")
    }

    #[test]
    fn split_only_eq_keys() {
        let (keys, null_eq, filter) = split(col("a").eq(col("b")));

        assert_eq!(fmt_keys(&keys), "a = b");
        assert!(matches!(null_eq, NullEquality::NullEqualsNothing));
        assert!(filter.is_none());
    }

    #[test]
    fn split_only_indistinct_keys() {
        let (keys, null_eq, filter) = split(indistinct(col("a"), col("b")));

        assert_eq!(fmt_keys(&keys), "a = b");
        assert!(matches!(null_eq, NullEquality::NullEqualsNull));
        assert!(filter.is_none());
    }

    #[test]
    fn split_mixed_eq_and_indistinct_demotes_indistinct_to_filter() {
        let null_safe = indistinct(col("val_l"), col("val_r"));
        let strict = col("id_l").eq(col("id_r"));

        let (keys, null_eq, filter) = split(null_safe.and(strict));

        assert_eq!(fmt_keys(&keys), "id_l = id_r");
        assert!(matches!(null_eq, NullEquality::NullEqualsNothing));
        assert_eq!(
            filter.map(|f| f.to_string()).as_deref(),
            Some("val_l IS NOT DISTINCT FROM val_r")
        );
    }

    #[test]
    fn split_mixed_multiple_indistinct_demoted() {
        let first = indistinct(col("a_l"), col("a_r"));
        let second = indistinct(col("b_l"), col("b_r"));
        let strict = col("id_l").eq(col("id_r"));

        let (keys, null_eq, filter) = split(first.and(second).and(strict));

        assert_eq!(fmt_keys(&keys), "id_l = id_r");
        assert!(matches!(null_eq, NullEquality::NullEqualsNothing));
        assert_eq!(
            filter.map(|f| f.to_string()).as_deref(),
            Some("a_l IS NOT DISTINCT FROM a_r AND b_l IS NOT DISTINCT FROM b_r")
        );
    }

    #[test]
    fn split_non_column_eq_goes_to_filter() {
        let literal = Expr::Literal(
            datafusion::common::ScalarValue::Utf8(Some("x".into())),
            None,
        );

        let (keys, _, filter) = split(literal.eq(col("b")));

        assert!(keys.is_empty());
        assert_eq!(
            filter.map(|f| f.to_string()).as_deref(),
            Some("Utf8(\"x\") = b")
        );
    }
}