Skip to main content

nodedb_sql/planner/lateral/
correlation.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Extract correlation predicates from a LATERAL subquery's WHERE clause.
4//!
5//! A correlated predicate is one where the left side references the outer
6//! table and the right side references the inner table (or vice versa).
7//! The outer table is identified by its alias or name.
8
9use sqlparser::ast::{self, Expr, SetExpr};
10
11use crate::parser::normalize::normalize_ident;
12
/// One equality correlation `outer.outer_col = inner.inner_col` lifted out of a
/// LATERAL subquery's WHERE clause.
///
/// Both fields hold bare column names (the `table.` qualifier has been split
/// off): `outer_col` belongs to the outer (driving) relation, `inner_col` to a
/// relation inside the lateral subquery.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CorrelationEq {
    /// Column on the outer (driving) side.
    pub outer_col: String,
    /// Column on the inner (lateral) side.
    pub inner_col: String,
}
22
23/// Result of analysing a lateral subquery's WHERE clause.
24#[derive(Debug, Default)]
25pub struct CorrelationAnalysis {
26    /// Equi-join pairs `(outer_col, inner_col)` extracted from `inner.col = outer.col`.
27    pub equi_keys: Vec<CorrelationEq>,
28    /// Non-equi correlated predicates as `(inner_col, outer_col)`.
29    pub non_equi: Vec<(String, String)>,
30    /// Remaining WHERE expression with correlated predicates stripped.
31    /// `None` when the entire WHERE was consumed.
32    pub remaining: Option<Expr>,
33}
34
35/// Analyse the WHERE clause of a LATERAL subquery.
36///
37/// `outer_alias` is the alias or name of the driving table (e.g. `"u"` for
38/// `FROM users u`). Any compound identifier `outer_alias.col` or `col = outer_alias.col`
39/// is treated as a correlation reference to the outer side.
40pub fn analyse_lateral_where(subquery: &ast::Query, outer_alias: &str) -> CorrelationAnalysis {
41    let select = match subquery.body.as_ref() {
42        SetExpr::Select(s) => s,
43        _ => return CorrelationAnalysis::default(),
44    };
45    let Some(where_expr) = &select.selection else {
46        return CorrelationAnalysis::default();
47    };
48
49    let mut analysis = CorrelationAnalysis::default();
50    analysis.remaining = extract_correlation_recursive(
51        where_expr,
52        outer_alias,
53        &mut analysis.equi_keys,
54        &mut analysis.non_equi,
55    );
56    analysis
57}
58
59/// Walk the WHERE expression, extracting correlated predicates.
60///
61/// Returns `None` when the expression was fully consumed; `Some(expr)` when
62/// a non-correlated residual remains.
63fn extract_correlation_recursive(
64    expr: &Expr,
65    outer_alias: &str,
66    equi_keys: &mut Vec<CorrelationEq>,
67    non_equi: &mut Vec<(String, String)>,
68) -> Option<Expr> {
69    match expr {
70        // AND: recurse both sides.
71        Expr::BinaryOp {
72            left,
73            op: ast::BinaryOperator::And,
74            right,
75        } => {
76            let l = extract_correlation_recursive(left, outer_alias, equi_keys, non_equi);
77            let r = extract_correlation_recursive(right, outer_alias, equi_keys, non_equi);
78            match (l, r) {
79                (None, None) => None,
80                (Some(e), None) | (None, Some(e)) => Some(e),
81                (Some(l), Some(r)) => Some(Expr::BinaryOp {
82                    left: Box::new(l),
83                    op: ast::BinaryOperator::And,
84                    right: Box::new(r),
85                }),
86            }
87        }
88
89        // Equi-predicate: check if one side is an outer reference.
90        Expr::BinaryOp {
91            left,
92            op: ast::BinaryOperator::Eq,
93            right,
94        } => {
95            let lp = compound_parts(left);
96            let rp = compound_parts(right);
97            match (lp, rp) {
98                (Some((lt, lc)), Some((rt, rc))) => {
99                    let lc_str = lc.as_str();
100                    let rc_str = rc.as_str();
101                    let lt_lower = lt.to_lowercase();
102                    let rt_lower = rt.to_lowercase();
103                    if lt_lower == outer_alias {
104                        // left = outer, right = inner
105                        equi_keys.push(CorrelationEq {
106                            outer_col: lc_str.to_string(),
107                            inner_col: rc_str.to_string(),
108                        });
109                        None
110                    } else if rt_lower == outer_alias {
111                        // right = outer, left = inner
112                        equi_keys.push(CorrelationEq {
113                            outer_col: rc_str.to_string(),
114                            inner_col: lc_str.to_string(),
115                        });
116                        None
117                    } else {
118                        // No outer reference — leave as-is.
119                        Some(expr.clone())
120                    }
121                }
122                _ => {
123                    // Try non-equi correlation detection.
124                    if is_correlated_expr(expr, outer_alias) {
125                        extract_non_equi_correlation(expr, outer_alias, non_equi);
126                        None
127                    } else {
128                        Some(expr.clone())
129                    }
130                }
131            }
132        }
133
134        // Non-equi predicates referencing the outer table.
135        Expr::BinaryOp { .. } => {
136            if is_correlated_expr(expr, outer_alias) {
137                extract_non_equi_correlation(expr, outer_alias, non_equi);
138                None
139            } else {
140                Some(expr.clone())
141            }
142        }
143
144        Expr::Nested(inner) => {
145            extract_correlation_recursive(inner, outer_alias, equi_keys, non_equi)
146        }
147
148        _ => Some(expr.clone()),
149    }
150}
151
152/// True when the expression references the outer table by alias.
153fn is_correlated_expr(expr: &Expr, outer_alias: &str) -> bool {
154    match expr {
155        Expr::CompoundIdentifier(parts) if parts.len() == 2 => {
156            normalize_ident(&parts[0]).eq_ignore_ascii_case(outer_alias)
157        }
158        Expr::BinaryOp { left, right, .. } => {
159            is_correlated_expr(left, outer_alias) || is_correlated_expr(right, outer_alias)
160        }
161        _ => false,
162    }
163}
164
165/// Extract a non-equi correlation from a binary predicate into `non_equi`.
166///
167/// Records `(inner_side, outer_side)` for the predicate.
168fn extract_non_equi_correlation(
169    expr: &Expr,
170    outer_alias: &str,
171    non_equi: &mut Vec<(String, String)>,
172) {
173    let Expr::BinaryOp { left, right, .. } = expr else {
174        return;
175    };
176    let lp = compound_parts(left);
177    let rp = compound_parts(right);
178    if let (Some((lt, lc)), Some((rt, rc))) = (lp, rp) {
179        if rt.eq_ignore_ascii_case(outer_alias) {
180            non_equi.push((lc, rc));
181        } else if lt.eq_ignore_ascii_case(outer_alias) {
182            non_equi.push((rc, lc));
183        }
184    }
185}
186
187/// Extract `(table_alias, column_name)` from a compound identifier.
188fn compound_parts(expr: &Expr) -> Option<(String, String)> {
189    match expr {
190        Expr::CompoundIdentifier(parts) if parts.len() == 2 => {
191            Some((normalize_ident(&parts[0]), normalize_ident(&parts[1])))
192        }
193        _ => None,
194    }
195}