use std::sync::Arc;
use crate::decorrelate::{PullUpCorrelatedExpr, UN_MATCHED_ROW_INDICATOR};
use crate::optimizer::ApplyOrder;
use crate::utils::evaluates_to_null;
use crate::{OptimizerConfig, OptimizerRule};
use datafusion_expr::{Expr, Join, expr};
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
};
use datafusion_common::{Column, DFSchema, Result, ScalarValue, TableReference};
use datafusion_expr::logical_plan::{JoinType, Subquery};
use datafusion_expr::utils::conjunction;
use datafusion_expr::{LogicalPlan, LogicalPlanBuilder, SubqueryAlias};
/// Optimizer rule that attempts to rewrite lateral joins over correlated
/// subqueries into regular joins.
///
/// The rule holds no state; all of the work happens in its
/// `OptimizerRule::rewrite` implementation.
#[derive(Default, Debug)]
pub struct DecorrelateLateralJoin {}

impl DecorrelateLateralJoin {
    /// Creates a new instance of the rule.
    ///
    /// Equivalent to [`DecorrelateLateralJoin::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
impl OptimizerRule for DecorrelateLateralJoin {
    /// Name under which this rule is reported.
    fn name(&self) -> &str {
        "decorrelate_lateral_join"
    }

    /// The rule asks to be applied top-down over the plan tree.
    fn apply_order(&self) -> Option<ApplyOrder> {
        Some(ApplyOrder::TopDown)
    }

    /// This rule implements `rewrite`.
    fn supports_rewrite(&self) -> bool {
        true
    }

    /// Hands `Join` nodes to `rewrite_internal`; every other plan node is
    /// returned unchanged.
    fn rewrite(
        &self,
        plan: LogicalPlan,
        _config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        match plan {
            LogicalPlan::Join(join) => rewrite_internal(join),
            other => Ok(Transformed::no(other)),
        }
    }
}
/// Attempts to turn a lateral join over a correlated subquery into a plain join.
///
/// The join is returned unchanged (`Transformed::no`) when any of these hold:
/// - the join type is neither `Inner` nor `Left`;
/// - the right input is not a `Subquery`, optionally wrapped in a `SubqueryAlias`;
/// - no plan node inside the subquery contains an outer reference;
/// - `PullUpCorrelatedExpr` reports `can_pull_up == false`, or it collected a
///   pulled-up HAVING expression;
/// - the join is a `Left` join and the pulled-up correlation filter references a
///   column that exists in neither the left schema nor the rewritten right schema.
///
/// Otherwise the result is built in up to four layers, in this order:
/// 1. a join of the left input with the rewritten subquery;
/// 2. a projection that substitutes default values for entries found in the
///    collected count-expression map (only if such a map was recorded);
/// 3. a projection that replaces right-side columns with typed NULLs where the
///    original ON condition is not true (only for `Left` joins over a pulled-up
///    scalar aggregate);
/// 4. a filter applying the original ON condition (only for `Inner` joins).
///
/// The rewritten plan is returned with `TreeNodeRecursion::Jump`.
fn rewrite_internal(join: Join) -> Result<Transformed<LogicalPlan>> {
// Only INNER and LEFT lateral joins are handled here.
if !matches!(join.join_type, JoinType::Inner | JoinType::Left) {
return Ok(Transformed::no(LogicalPlan::Join(join)));
}
let original_join_type = join.join_type;
// The right side must be a subquery, possibly under an alias.
let Some((subquery, alias)) = extract_lateral_subquery(join.right.as_ref()) else {
return Ok(Transformed::no(LogicalPlan::Join(join)));
};
// Walk the subquery plan and stop at the first node that contains an outer
// reference; a `Stop` result therefore means "correlated".
let has_outer_refs = matches!(
subquery.subquery.apply_with_subqueries(|p| {
if p.contains_outer_reference() {
Ok(TreeNodeRecursion::Stop)
} else {
Ok(TreeNodeRecursion::Continue)
}
})?,
TreeNodeRecursion::Stop
);
// Nothing to decorrelate when the subquery has no outer references.
if !has_outer_refs {
return Ok(Transformed::no(LogicalPlan::Join(join)));
}
let subquery_plan = subquery.subquery.as_ref();
// Cloned because `join` must stay intact for the early-return paths below.
let original_join_filter = join.filter.clone();
// NOTE(review): `with_need_handle_count_bug(true)` presumably makes the
// rewriter record aggregates whose value over empty input is not NULL
// (e.g. COUNT) — confirm against `PullUpCorrelatedExpr`.
let mut pull_up = PullUpCorrelatedExpr::new().with_need_handle_count_bug(true);
let rewritten_subquery = subquery_plan.clone().rewrite(&mut pull_up).data()?;
// Give up if the rewriter could not pull the correlated expressions up.
if !pull_up.can_pull_up {
return Ok(Transformed::no(LogicalPlan::Join(join)));
}
// Pulled-up HAVING expressions are not handled by this rewrite.
if pull_up.pull_up_having_expr.is_some() {
return Ok(Transformed::no(LogicalPlan::Join(join)));
}
// Combine the pulled-up join filters into one optional predicate.
let correlation_filter = conjunction(pull_up.join_filters);
// Default-value expressions recorded for the rewritten subquery, if any.
let collected_count_expr_map = pull_up
.collected_count_expr_map
.get(&rewritten_subquery)
.cloned();
// If the subquery was aliased, re-wrap the rewritten plan in the same alias
// and re-qualify filter columns that belong to the inner schema so that they
// refer to the alias instead.
let (right_plan, correlation_filter, original_join_filter) =
if let Some(ref alias) = alias {
let inner_schema = Arc::clone(rewritten_subquery.schema());
let right = LogicalPlan::SubqueryAlias(SubqueryAlias::try_new(
Arc::new(rewritten_subquery),
alias.clone(),
)?);
let corr = correlation_filter
.map(|f| requalify_filter(f, &inner_schema, alias))
.transpose()?;
let on = original_join_filter
.map(|f| requalify_filter(f, &inner_schema, alias))
.transpose()?;
(right, corr, on)
} else {
(rewritten_subquery, correlation_filter, original_join_filter)
};
// For LEFT joins, bail out when the correlation filter mentions a column that
// neither join input provides.
// NOTE(review): such columns presumably belong to an enclosing query scope — confirm.
if original_join_type == JoinType::Left
&& let Some(ref filter) = correlation_filter
{
let left_schema = join.left.schema();
let right_schema = right_plan.schema();
let has_outer_scope_refs = filter
.column_refs()
.iter()
.any(|col| !left_schema.has_column(col) && !right_schema.has_column(col));
if has_outer_scope_refs {
return Ok(Transformed::no(LogicalPlan::Join(join)));
}
}
// A pulled-up scalar aggregate forces a LEFT join even if the original join
// was INNER.
let join_type =
if original_join_type == JoinType::Left || pull_up.pulled_up_scalar_agg {
JoinType::Left
} else {
JoinType::Inner
};
// Decide where each predicate goes:
// - LEFT + scalar aggregate: join on the correlation filter only; the original
//   ON condition is applied later through a projection.
// - LEFT otherwise: join on correlation filter AND original ON condition.
// - INNER: join on the correlation filter; the original ON condition becomes
//   a filter above the join.
let (join_filter, post_join_filter, on_condition_for_projection) =
if original_join_type == JoinType::Left {
if pull_up.pulled_up_scalar_agg {
(correlation_filter, None, original_join_filter)
} else {
let combined = conjunction(
correlation_filter.into_iter().chain(original_join_filter),
);
(combined, None, None)
}
} else {
(correlation_filter, original_join_filter, None)
};
// Captured before `join.left` is moved into the builder; the projections below
// use it to tell left-side columns (index < count) from the rest.
let left_field_count = join.left.schema().fields().len();
let new_plan = LogicalPlanBuilder::from(join.left)
.join_on(right_plan, join_type, join_filter)?
.build()?;
// Layer 2: substitute recorded default values for columns listed in the
// count-expression map.
let new_plan = if let Some(expr_map) = collected_count_expr_map {
let join_schema = new_plan.schema();
let alias_qualifier = alias.as_ref();
let mut proj_exprs: Vec<Expr> = vec![];
for (i, (qualifier, field)) in join_schema.iter().enumerate() {
let col = Expr::Column(Column::new(qualifier.cloned(), field.name()));
let name = field.name();
// Only columns past the left input's fields that have a recorded default
// which `evaluates_to_null` reports as false are wrapped.
if i >= left_field_count
&& let Some(default_value) = expr_map.get(name.as_str())
&& !evaluates_to_null(default_value.clone(), default_value.column_refs())?
{
// CASE WHEN <indicator> IS NULL THEN <default> ELSE <col> END
// NOTE(review): the indicator column is presumably NULL exactly for
// left rows without a match — confirm against `UN_MATCHED_ROW_INDICATOR`.
let indicator_col =
Column::new(alias_qualifier.cloned(), UN_MATCHED_ROW_INDICATOR);
let case_expr = Expr::Case(expr::Case {
expr: None,
when_then_expr: vec![(
Box::new(Expr::IsNull(Box::new(Expr::Column(indicator_col)))),
Box::new(default_value.clone()),
)],
else_expr: Some(Box::new(col)),
});
// Keep the original qualifier and name so the output schema is unchanged.
proj_exprs.push(Expr::Alias(expr::Alias {
expr: Box::new(case_expr),
relation: qualifier.cloned(),
name: name.to_string(),
metadata: None,
}));
continue;
}
proj_exprs.push(col);
}
LogicalPlanBuilder::from(new_plan)
.project(proj_exprs)?
.build()?
} else {
new_plan
};
// Layer 3: apply the original ON condition of a LEFT join as a projection.
// Left-side columns pass through; every other column becomes
// CASE WHEN (<on_cond>) IS NOT TRUE THEN <typed NULL> ELSE <col> END.
let new_plan = if let Some(on_cond) = on_condition_for_projection {
let schema = Arc::clone(new_plan.schema());
let mut proj_exprs: Vec<Expr> = vec![];
for (i, (qualifier, field)) in schema.iter().enumerate() {
let col = Expr::Column(Column::new(qualifier.cloned(), field.name()));
if i < left_field_count {
proj_exprs.push(col);
continue;
}
// Literal built from the field's data type so the CASE branches agree on type.
let typed_null =
Expr::Literal(ScalarValue::try_from(field.data_type())?, None);
let case_expr = Expr::Case(expr::Case {
expr: None,
when_then_expr: vec![(
Box::new(Expr::IsNotTrue(Box::new(on_cond.clone()))),
Box::new(typed_null),
)],
else_expr: Some(Box::new(col)),
});
proj_exprs.push(case_expr.alias_qualified(qualifier.cloned(), field.name()));
}
LogicalPlanBuilder::from(new_plan)
.project(proj_exprs)?
.build()?
} else {
new_plan
};
// Layer 4: for INNER joins the original ON condition is applied as a filter
// on top of everything built so far.
let new_plan = if let Some(on_filter) = post_join_filter {
LogicalPlanBuilder::from(new_plan)
.filter(on_filter)?
.build()?
} else {
new_plan
};
// NOTE(review): `Jump` presumably keeps the top-down traversal from descending
// into the plan that was just built — confirm against `TreeNodeRecursion`.
Ok(Transformed::new(new_plan, true, TreeNodeRecursion::Jump))
}
/// Returns the subquery found at `plan`, together with its alias when the
/// subquery sits directly under a `SubqueryAlias`.
///
/// Yields `None` for any other plan shape, including an alias whose input is
/// not a `Subquery`.
fn extract_lateral_subquery(
    plan: &LogicalPlan,
) -> Option<(Subquery, Option<TableReference>)> {
    let (found, alias) = match plan {
        LogicalPlan::Subquery(inner) => (inner, None),
        LogicalPlan::SubqueryAlias(aliased) => match aliased.input.as_ref() {
            LogicalPlan::Subquery(inner) => (inner, Some(aliased.alias.clone())),
            _ => return None,
        },
        _ => return None,
    };
    Some((found.clone(), alias))
}
/// Rewrites `filter` so that every column reference present in `inner_schema`
/// is qualified with `alias` instead, keeping the column name.
///
/// Columns that `inner_schema` does not contain, and all non-column
/// expressions, are left untouched.
fn requalify_filter(
    filter: Expr,
    inner_schema: &DFSchema,
    alias: &TableReference,
) -> Result<Expr> {
    let rewritten = filter.transform(|node| {
        // Work out the replacement first so the borrow of `node` has ended
        // before it is moved into the result.
        let replacement = match &node {
            Expr::Column(column) if inner_schema.has_column(column) => {
                Some(Column::new(Some(alias.clone()), column.name.clone()))
            }
            _ => None,
        };
        Ok(match replacement {
            Some(requalified) => Transformed::yes(Expr::Column(requalified)),
            None => Transformed::no(node),
        })
    });
    rewritten.data()
}