use std::collections::BTreeSet;
use crate::decorrelate::PullUpCorrelatedExpr;
use crate::optimizer::ApplyOrder;
use crate::{OptimizerConfig, OptimizerRule};
use datafusion_expr::{Join, lit};
use datafusion_common::Result;
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
};
use datafusion_expr::logical_plan::JoinType;
use datafusion_expr::utils::conjunction;
use datafusion_expr::{LogicalPlan, LogicalPlanBuilder};
/// Optimizer rule that rewrites correlated lateral joins into ordinary joins.
///
/// The rule carries no configuration; all of its work happens in
/// [`OptimizerRule::rewrite`].
#[derive(Default, Debug)]
pub struct DecorrelateLateralJoin {}

impl DecorrelateLateralJoin {
    /// Creates a new instance of the rule (equivalent to `Default::default()`).
    pub fn new() -> Self {
        <Self as Default>::default()
    }
}
impl OptimizerRule for DecorrelateLateralJoin {
    /// Identifier used by the optimizer for this rule.
    fn name(&self) -> &str {
        "decorrelate_lateral_join"
    }

    /// The rule is applied while walking the plan from the root downwards.
    fn apply_order(&self) -> Option<ApplyOrder> {
        Some(ApplyOrder::TopDown)
    }

    /// This rule implements the owned-plan `rewrite` API.
    fn supports_rewrite(&self) -> bool {
        true
    }

    /// Rewrites a single plan node: only `Join` nodes are candidates, every
    /// other node is handed back unchanged.
    fn rewrite(
        &self,
        plan: LogicalPlan,
        _config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        match plan {
            LogicalPlan::Join(join) => rewrite_internal(join),
            other => Ok(Transformed::no(other)),
        }
    }
}
/// Attempts to decorrelate a lateral join.
///
/// The rewrite only fires for an `INNER` join whose right input is a
/// [`LogicalPlan::Subquery`] that contains outer references. The correlated
/// predicates are pulled out of the subquery with [`PullUpCorrelatedExpr`]
/// and become the condition of a regular join against the left input:
///
/// * an inner join normally, or
/// * a left join when `pulled_up_scalar_agg` is set, so left rows without a
///   matching subquery row are kept. NOTE(review): aggregates such as
///   `COUNT` are not corrected for those null-extended rows here (the
///   "count bug" information collected by the pull-up is not used yet).
///
/// The original join's own `ON` filter is re-applied as a filter on top of
/// the new join so it is not lost. Joins that carry equi-join keys in `on`
/// are left untouched rather than rewritten without those keys.
///
/// Returns `Transformed::no` when the join is not a candidate or the subquery
/// cannot be pulled up. A rewritten plan is returned with
/// `TreeNodeRecursion::Jump`, so its children are not revisited in this pass.
///
/// # Errors
/// Propagates errors from traversing/rewriting the subquery and from building
/// the new join plan.
fn rewrite_internal(join: Join) -> Result<Transformed<LogicalPlan>> {
    // Only inner lateral joins are handled for now.
    if join.join_type != JoinType::Inner {
        return Ok(Transformed::no(LogicalPlan::Join(join)));
    }

    // Search the right input (including nested subqueries) for an outer
    // reference; the visitor stops at the first one it finds.
    let right_is_correlated = match join.right.apply_with_subqueries(|p| {
        if p.contains_outer_reference() {
            Ok(TreeNodeRecursion::Stop)
        } else {
            Ok(TreeNodeRecursion::Continue)
        }
    })? {
        TreeNodeRecursion::Stop => true,
        TreeNodeRecursion::Continue => false,
        TreeNodeRecursion::Jump => {
            unreachable!("visitor never returns TreeNodeRecursion::Jump")
        }
    };
    if !right_is_correlated {
        // Not a lateral join: nothing to rewrite on this node, but keep
        // descending (like for non-inner joins above) so lateral joins nested
        // inside either input are still visited.
        return Ok(Transformed::no(LogicalPlan::Join(join)));
    }

    // The rewrite below only carries over `join.filter`; equi-join keys would
    // be silently dropped, so leave such joins alone instead.
    if !join.on.is_empty() {
        return Ok(Transformed::no(LogicalPlan::Join(join)));
    }

    let LogicalPlan::Subquery(subquery) = join.right.as_ref() else {
        return Ok(Transformed::no(LogicalPlan::Join(join)));
    };

    let mut pull_up = PullUpCorrelatedExpr::new().with_need_handle_count_bug(true);
    let rewritten_subquery = subquery
        .subquery
        .as_ref()
        .clone()
        .rewrite(&mut pull_up)
        .data()?;
    if !pull_up.can_pull_up {
        return Ok(Transformed::no(LogicalPlan::Join(join)));
    }

    // The pulled-up correlated predicates become the new join condition;
    // with none, the join degenerates to `ON true`.
    let join_filter = conjunction(pull_up.join_filters).unwrap_or_else(|| lit(true));
    let new_join_type = if pull_up.pulled_up_scalar_agg {
        JoinType::Left
    } else {
        JoinType::Inner
    };

    let mut builder = LogicalPlanBuilder::from(join.left).join_on(
        rewritten_subquery,
        new_join_type,
        Some(join_filter),
    )?;
    // Re-apply the lateral join's own ON condition. Filtering above the new
    // join matches INNER-join semantics for both the inner and the left form.
    if let Some(on_filter) = join.filter {
        builder = builder.filter(on_filter)?;
    }
    let new_plan = builder.build()?;
    Ok(Transformed::new(new_plan, true, TreeNodeRecursion::Jump))
}