use sqlparser::ast::{self, Expr, SetExpr};
use crate::parser::normalize::normalize_ident;
/// One equality correlation between the outer relation and the subquery.
///
/// Built from a predicate of the form `<qualifier>.<col> = <qualifier>.<col>`
/// in which one of the two qualifiers is the outer alias.
#[derive(Clone, Debug)]
pub struct CorrelationEq {
    /// Column name taken from the side whose qualifier matched the outer alias.
    pub outer_col: String,
    /// Column name taken from the opposite side of the `=`.
    pub inner_col: String,
}
/// Result of analysing the `WHERE` clause of a lateral subquery.
///
/// The `Default` value (no keys, no residual) is what the analysis returns
/// when there is nothing to inspect.
#[derive(Default, Debug)]
pub struct CorrelationAnalysis {
    /// Equality correlations found in the predicate.
    pub equi_keys: Vec<CorrelationEq>,
    /// Correlations found in non-equality comparisons, stored as
    /// `(column from the other side, column from the outer-alias side)`.
    /// The comparison operator itself is not recorded.
    pub non_equi: Vec<(String, String)>,
    /// Predicate left over once the recognised correlation conditions were
    /// taken out; `None` when nothing is left.
    pub remaining: Option<Expr>,
}
pub fn analyse_lateral_where(subquery: &ast::Query, outer_alias: &str) -> CorrelationAnalysis {
let select = match subquery.body.as_ref() {
SetExpr::Select(s) => s,
_ => return CorrelationAnalysis::default(),
};
let Some(where_expr) = &select.selection else {
return CorrelationAnalysis::default();
};
let mut analysis = CorrelationAnalysis::default();
analysis.remaining = extract_correlation_recursive(
where_expr,
outer_alias,
&mut analysis.equi_keys,
&mut analysis.non_equi,
);
analysis
}
/// Splits a predicate into correlation conditions and a residual filter.
///
/// The expression is classified as follows:
///
/// * `AND` - both operands are processed recursively (left first) and the
///   surviving residuals are joined with `AND` again.
/// * `=` between two two-part identifiers - when either qualifier equals
///   `outer_alias` (compared case-insensitively, left side checked first) a
///   [`CorrelationEq`] is pushed to `equi_keys` and the predicate is
///   consumed; otherwise the predicate is kept.
/// * any other binary operator, or `=` whose operands are not both two-part
///   identifiers - when the predicate references the outer alias it is handed
///   to `extract_non_equi_correlation`. It is consumed only if that call
///   actually recorded an entry in `non_equi`; otherwise it stays in the
///   residual so that no predicate is silently lost.
/// * parenthesised expression - the inner expression is processed and any
///   residual is wrapped in parentheses again, so operator precedence
///   survives if the residual is later rendered as SQL text.
/// * anything else - returned unchanged.
///
/// Returns the residual predicate, or `None` when every part of `expr` was
/// consumed.
fn extract_correlation_recursive(
    expr: &Expr,
    outer_alias: &str,
    equi_keys: &mut Vec<CorrelationEq>,
    non_equi: &mut Vec<(String, String)>,
) -> Option<Expr> {
    // Shared handling for predicates that are not a recognised equi-key.
    // A correlated predicate is dropped from the residual only when it was
    // really captured in `sink`; uncaptured ones are kept.
    let record_or_keep = |predicate: &Expr, sink: &mut Vec<(String, String)>| -> Option<Expr> {
        if !is_correlated_expr(predicate, outer_alias) {
            return Some(predicate.clone());
        }
        let recorded_before = sink.len();
        extract_non_equi_correlation(predicate, outer_alias, sink);
        if sink.len() > recorded_before {
            None
        } else {
            Some(predicate.clone())
        }
    };

    match expr {
        Expr::BinaryOp {
            left,
            op: ast::BinaryOperator::And,
            right,
        } => {
            let lhs = extract_correlation_recursive(left, outer_alias, equi_keys, non_equi);
            let rhs = extract_correlation_recursive(right, outer_alias, equi_keys, non_equi);
            match (lhs, rhs) {
                (Some(l), Some(r)) => Some(Expr::BinaryOp {
                    left: Box::new(l),
                    op: ast::BinaryOperator::And,
                    right: Box::new(r),
                }),
                // At most one side survived: pass it through as-is.
                (l, r) => l.or(r),
            }
        }
        Expr::BinaryOp {
            left,
            op: ast::BinaryOperator::Eq,
            right,
        } => match (compound_parts(left), compound_parts(right)) {
            (Some((left_table, left_col)), Some((right_table, right_col))) => {
                // Lower-case both sides so a mixed-case alias still matches,
                // in line with the case-insensitive checks used for
                // non-equi correlations.
                let alias = outer_alias.to_lowercase();
                if left_table.to_lowercase() == alias {
                    equi_keys.push(CorrelationEq {
                        outer_col: left_col,
                        inner_col: right_col,
                    });
                    None
                } else if right_table.to_lowercase() == alias {
                    equi_keys.push(CorrelationEq {
                        outer_col: right_col,
                        inner_col: left_col,
                    });
                    None
                } else {
                    Some(expr.clone())
                }
            }
            _ => record_or_keep(expr, non_equi),
        },
        Expr::BinaryOp { .. } => record_or_keep(expr, non_equi),
        Expr::Nested(inner) => {
            extract_correlation_recursive(inner, outer_alias, equi_keys, non_equi)
                .map(|residual| Expr::Nested(Box::new(residual)))
        }
        _ => Some(expr.clone()),
    }
}
/// Reports whether `expr` mentions a column qualified by `outer_alias`.
///
/// Only two shapes are inspected: a two-part identifier, whose normalized
/// qualifier is compared ASCII-case-insensitively against `outer_alias`, and
/// a binary operation, whose operands are searched recursively (left first).
/// Every other expression yields `false`.
fn is_correlated_expr(expr: &Expr, outer_alias: &str) -> bool {
    if let Expr::BinaryOp { left, right, .. } = expr {
        return is_correlated_expr(left, outer_alias) || is_correlated_expr(right, outer_alias);
    }

    let Expr::CompoundIdentifier(parts) = expr else {
        return false;
    };
    match parts.as_slice() {
        [qualifier, _column] => normalize_ident(qualifier).eq_ignore_ascii_case(outer_alias),
        _ => false,
    }
}
/// Records a column pair for a binary comparison that involves the outer
/// alias.
///
/// Nothing is recorded unless `expr` is a binary operation whose two operands
/// are both two-part identifiers and at least one qualifier equals
/// `outer_alias` (ASCII case-insensitive). The pushed tuple is
/// `(column of the other side, column of the outer-alias side)`; the
/// right-hand qualifier is checked before the left-hand one. The operator is
/// not stored.
fn extract_non_equi_correlation(
    expr: &Expr,
    outer_alias: &str,
    non_equi: &mut Vec<(String, String)>,
) {
    let Expr::BinaryOp { left, right, .. } = expr else {
        return;
    };

    match (compound_parts(left), compound_parts(right)) {
        // Outer alias on the right-hand side.
        (Some((_, other_col)), Some((qualifier, outer_col)))
            if qualifier.eq_ignore_ascii_case(outer_alias) =>
        {
            non_equi.push((other_col, outer_col));
        }
        // Outer alias on the left-hand side.
        (Some((qualifier, outer_col)), Some((_, other_col)))
            if qualifier.eq_ignore_ascii_case(outer_alias) =>
        {
            non_equi.push((other_col, outer_col));
        }
        _ => {}
    }
}
/// Breaks a two-part identifier into its normalized `(qualifier, column)`
/// names.
///
/// Returns `None` for any expression that is not a compound identifier with
/// exactly two parts.
fn compound_parts(expr: &Expr) -> Option<(String, String)> {
    let Expr::CompoundIdentifier(parts) = expr else {
        return None;
    };
    let [qualifier, column] = parts.as_slice() else {
        return None;
    };
    Some((normalize_ident(qualifier), normalize_ident(column)))
}