Skip to main content

nodedb_sql/planner/
subquery.rs

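//! Subquery planning: IN (SELECT ...), NOT IN (SELECT ...), [NOT] EXISTS (SELECT ...),
//! and scalar subqueries.
//!
//! Rewrites WHERE-clause subqueries into semi/anti joins (and, for scalar
//! subqueries, a cross join) so the existing hash-join executor handles them
//! without a dedicated subquery engine. Only predicates reachable through
//! `AND`s and parentheses are rewritten; anything else is left in the filter.
//!
//! Supported patterns:
//!   - `WHERE col IN (SELECT col2 FROM tbl ...)`  → semi-join
//!   - `WHERE col NOT IN (SELECT col2 FROM tbl ...)` → anti-join
//!   - `WHERE EXISTS (SELECT ... FROM tbl WHERE tbl.col = outer.col)` → semi-join
//!   - `WHERE NOT EXISTS (SELECT ... FROM tbl WHERE tbl.col = outer.col)` → anti-join
//!   - `WHERE col > (SELECT AGG(...) FROM tbl ...)` → scalar subquery (materialized
//!     and cross-joined with the outer rows)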
10
11use sqlparser::ast::{self, Expr, SetExpr};
12
13use crate::error::{Result, SqlError};
14use crate::functions::registry::FunctionRegistry;
15use crate::parser::normalize::normalize_ident;
16use crate::types::*;
17
18/// Result of extracting subqueries from a WHERE clause.
19pub struct SubqueryExtraction {
20    /// Semi/anti joins to wrap around the base scan.
21    pub joins: Vec<SubqueryJoin>,
22    /// Remaining WHERE expression with subqueries removed (None if nothing remains).
23    pub remaining_where: Option<Expr>,
24}
25
26/// A subquery that was rewritten as a join.
27pub struct SubqueryJoin {
28    /// The column on the outer table to join on.
29    pub outer_column: String,
30    /// The planned inner SELECT.
31    pub inner_plan: SqlPlan,
32    /// The column from the inner SELECT to join on.
33    pub inner_column: String,
34    /// Semi (IN) or Anti (NOT IN).
35    pub join_type: JoinType,
36}
37
/// Build the output-column key for an aggregate call: `function(field)`,
/// e.g. `("avg", "amount")` → `"avg(amount)"`.
fn canonical_aggregate_key(function: &str, field: &str) -> String {
    [function, "(", field, ")"].concat()
}
41
42/// Extract `IN (SELECT ...)` and `NOT IN (SELECT ...)` patterns from a WHERE clause.
43///
44/// Returns the extracted subquery joins and the remaining WHERE expression
45/// (with subquery predicates removed). If the entire WHERE is a single
46/// subquery predicate, `remaining_where` is `None`.
47pub fn extract_subqueries(
48    expr: &Expr,
49    catalog: &dyn SqlCatalog,
50    functions: &FunctionRegistry,
51) -> Result<SubqueryExtraction> {
52    let mut joins = Vec::new();
53    let remaining = extract_recursive(expr, &mut joins, catalog, functions)?;
54    Ok(SubqueryExtraction {
55        joins,
56        remaining_where: remaining,
57    })
58}
59
60/// Recursively walk the WHERE expression, extracting subquery predicates.
61///
62/// Returns `None` if the entire expression was consumed (subquery-only),
63/// or `Some(expr)` with the remaining non-subquery predicates.
64fn extract_recursive(
65    expr: &Expr,
66    joins: &mut Vec<SubqueryJoin>,
67    catalog: &dyn SqlCatalog,
68    functions: &FunctionRegistry,
69) -> Result<Option<Expr>> {
70    match expr {
71        // AND: recurse both sides, reconstruct with remaining parts.
72        Expr::BinaryOp {
73            left,
74            op: ast::BinaryOperator::And,
75            right,
76        } => {
77            let left_remaining = extract_recursive(left, joins, catalog, functions)?;
78            let right_remaining = extract_recursive(right, joins, catalog, functions)?;
79            match (left_remaining, right_remaining) {
80                (None, None) => Ok(None),
81                (Some(l), None) => Ok(Some(l)),
82                (None, Some(r)) => Ok(Some(r)),
83                (Some(l), Some(r)) => Ok(Some(Expr::BinaryOp {
84                    left: Box::new(l),
85                    op: ast::BinaryOperator::And,
86                    right: Box::new(r),
87                })),
88            }
89        }
90
91        // IN (SELECT ...): rewrite as semi-join.
92        Expr::InSubquery {
93            expr: outer_expr,
94            subquery,
95            negated,
96        } => {
97            if let Some(join) =
98                try_plan_in_subquery(outer_expr, subquery, *negated, catalog, functions)?
99            {
100                joins.push(join);
101                Ok(None) // This predicate is consumed.
102            } else {
103                // Cannot plan as join — return original expression.
104                Ok(Some(expr.clone()))
105            }
106        }
107
108        // Scalar subquery comparison: `col > (SELECT AGG(...) FROM ...)`
109        Expr::BinaryOp { left, op, right } if is_comparison_op(op) => {
110            if let Expr::Subquery(subquery) = right.as_ref() {
111                if let Some(scalar) = try_plan_scalar_subquery(subquery, catalog, functions)? {
112                    joins.push(scalar.join);
113                    Ok(Some(Expr::BinaryOp {
114                        left: left.clone(),
115                        op: op.clone(),
116                        right: Box::new(scalar.replacement_expr),
117                    }))
118                } else {
119                    Ok(Some(expr.clone()))
120                }
121            } else {
122                Ok(Some(expr.clone()))
123            }
124        }
125
126        // EXISTS (SELECT ...): rewrite as semi-join.
127        // NOT EXISTS (SELECT ...): rewrite as anti-join.
128        Expr::Exists { subquery, negated } => {
129            if let Some(join) = try_plan_exists_subquery(subquery, *negated, catalog, functions)? {
130                joins.push(join);
131                Ok(None)
132            } else {
133                Ok(Some(expr.clone()))
134            }
135        }
136
137        // Nested parentheses.
138        Expr::Nested(inner) => extract_recursive(inner, joins, catalog, functions),
139
140        // Not a subquery pattern — return as-is.
141        _ => Ok(Some(expr.clone())),
142    }
143}
144
145/// Try to plan `col IN (SELECT col2 FROM tbl ...)` as a semi/anti join.
146fn try_plan_in_subquery(
147    outer_expr: &Expr,
148    subquery: &ast::Query,
149    negated: bool,
150    catalog: &dyn SqlCatalog,
151    functions: &FunctionRegistry,
152) -> Result<Option<SubqueryJoin>> {
153    // Extract outer column name.
154    let outer_col = match outer_expr {
155        Expr::Identifier(ident) => normalize_ident(ident),
156        Expr::CompoundIdentifier(parts) if parts.len() == 2 => normalize_ident(&parts[1]),
157        _ => return Ok(None), // Complex expression, can't rewrite.
158    };
159
160    // Plan the inner SELECT.
161    let inner_plan = super::select::plan_query(subquery, catalog, functions)?;
162
163    // Extract the projected column from the inner plan.
164    let inner_col = extract_single_projected_column(subquery)?;
165
166    Ok(Some(SubqueryJoin {
167        outer_column: outer_col,
168        inner_plan,
169        inner_column: inner_col,
170        join_type: if negated {
171            JoinType::Anti
172        } else {
173            JoinType::Semi
174        },
175    }))
176}
177
178/// Extract the single column name from a subquery's SELECT list.
179///
180/// For `SELECT user_id FROM orders`, returns `"user_id"`.
181fn extract_single_projected_column(query: &ast::Query) -> Result<String> {
182    let select = match &*query.body {
183        SetExpr::Select(s) => s,
184        _ => {
185            return Err(SqlError::Unsupported {
186                detail: "subquery must be a simple SELECT".into(),
187            });
188        }
189    };
190
191    if select.projection.len() != 1 {
192        return Err(SqlError::Unsupported {
193            detail: format!(
194                "subquery must select exactly 1 column, got {}",
195                select.projection.len()
196            ),
197        });
198    }
199
200    match &select.projection[0] {
201        ast::SelectItem::UnnamedExpr(expr) => match expr {
202            Expr::Identifier(ident) => Ok(normalize_ident(ident)),
203            Expr::CompoundIdentifier(parts) if parts.len() == 2 => Ok(normalize_ident(&parts[1])),
204            _ => Err(SqlError::Unsupported {
205                detail: "subquery projection must be a column reference".into(),
206            }),
207        },
208        ast::SelectItem::ExprWithAlias { alias, .. } => Ok(normalize_ident(alias)),
209        _ => Err(SqlError::Unsupported {
210            detail: "subquery projection must be a column reference".into(),
211        }),
212    }
213}
214
215/// Plan `EXISTS (SELECT 1 FROM tbl WHERE tbl.col = outer.col)` as a semi/anti join.
216///
217/// Extracts the correlated column from the subquery's WHERE clause.
218fn try_plan_exists_subquery(
219    subquery: &ast::Query,
220    negated: bool,
221    catalog: &dyn SqlCatalog,
222    functions: &FunctionRegistry,
223) -> Result<Option<SubqueryJoin>> {
224    let select = match &*subquery.body {
225        SetExpr::Select(s) => s,
226        _ => return Ok(None),
227    };
228
229    // Look for a correlated predicate in the WHERE: inner.col = outer.col
230    let (outer_col, inner_col) = match &select.selection {
231        Some(expr) => match extract_correlated_eq(expr) {
232            Some(pair) => pair,
233            None => return Ok(None),
234        },
235        None => return Ok(None),
236    };
237
238    // Build a simplified subquery without the correlated predicate for planning.
239    let inner_plan = super::select::plan_query(subquery, catalog, functions)?;
240
241    Ok(Some(SubqueryJoin {
242        outer_column: outer_col,
243        inner_plan,
244        inner_column: inner_col,
245        join_type: if negated {
246            JoinType::Anti
247        } else {
248            JoinType::Semi
249        },
250    }))
251}
252
253/// Extract a correlated equality predicate from a WHERE clause.
254///
255/// Looks for patterns like `o.user_id = u.id` and returns (outer_col, inner_col).
256/// The "inner" column is the one qualified with the subquery's table alias;
257/// the "outer" column is the one referencing the outer query's table.
258fn extract_correlated_eq(expr: &Expr) -> Option<(String, String)> {
259    match expr {
260        Expr::BinaryOp {
261            left,
262            op: ast::BinaryOperator::Eq,
263            right,
264        } => {
265            let left_parts = extract_qualified_column(left);
266            let right_parts = extract_qualified_column(right);
267            match (left_parts, right_parts) {
268                (Some((_lt, lc)), Some((_rt, rc))) => {
269                    // Convention: left is inner (subquery table), right is outer.
270                    // But we can't distinguish without schema, so just return both.
271                    Some((rc, lc))
272                }
273                _ => None,
274            }
275        }
276        // For AND, try to find a correlated eq in either side.
277        Expr::BinaryOp {
278            left,
279            op: ast::BinaryOperator::And,
280            right,
281        } => extract_correlated_eq(left).or_else(|| extract_correlated_eq(right)),
282        Expr::Nested(inner) => extract_correlated_eq(inner),
283        _ => None,
284    }
285}
286
287/// Extract table.column from a qualified identifier.
288fn extract_qualified_column(expr: &Expr) -> Option<(String, String)> {
289    match expr {
290        Expr::CompoundIdentifier(parts) if parts.len() == 2 => {
291            Some((normalize_ident(&parts[0]), normalize_ident(&parts[1])))
292        }
293        Expr::Identifier(ident) => Some((String::new(), normalize_ident(ident))),
294        _ => None,
295    }
296}
297
298fn is_comparison_op(op: &ast::BinaryOperator) -> bool {
299    matches!(
300        op,
301        ast::BinaryOperator::Gt
302            | ast::BinaryOperator::GtEq
303            | ast::BinaryOperator::Lt
304            | ast::BinaryOperator::LtEq
305            | ast::BinaryOperator::Eq
306            | ast::BinaryOperator::NotEq
307    )
308}
309
310/// Result of planning a scalar subquery.
311struct ScalarSubqueryResult {
312    join: SubqueryJoin,
313    replacement_expr: Expr,
314}
315
316/// Plan a scalar subquery (e.g., `(SELECT AVG(amount) FROM orders)`).
317///
318/// Rewrites `col > (SELECT AVG(amount) FROM orders)` as:
319///   cross-join with the aggregate result (1 row), then filter `col > result_col`.
320///
321/// The cross-join produces a cartesian product, but since the aggregate returns
322/// exactly 1 row, every outer row gets paired with that single result row.
323fn try_plan_scalar_subquery(
324    subquery: &ast::Query,
325    catalog: &dyn SqlCatalog,
326    functions: &FunctionRegistry,
327) -> Result<Option<ScalarSubqueryResult>> {
328    let inner_plan = super::select::plan_query(subquery, catalog, functions)?;
329
330    // Extract the result column name from the subquery's SELECT list.
331    let result_col = match extract_scalar_column(subquery) {
332        Some(col) => col,
333        None => return Ok(None),
334    };
335
336    let replacement = Expr::Identifier(ast::Ident::new(&result_col));
337
338    Ok(Some(ScalarSubqueryResult {
339        join: SubqueryJoin {
340            outer_column: String::new(),
341            inner_plan,
342            inner_column: String::new(),
343            join_type: JoinType::Cross,
344        },
345        replacement_expr: replacement,
346    }))
347}
348
349/// Extract the projected column name from a scalar subquery.
350///
351/// Handles aliased aggregates like `SELECT AVG(amount) AS avg_amount`.
352/// For unaliased aggregates, returns the canonical aggregate key emitted by
353/// the aggregate executor (e.g. `avg(amount)`, `count(*)`).
354fn extract_scalar_column(query: &ast::Query) -> Option<String> {
355    let select = match &*query.body {
356        SetExpr::Select(s) => s,
357        _ => return None,
358    };
359    if select.projection.len() != 1 {
360        return None;
361    }
362    match &select.projection[0] {
363        ast::SelectItem::ExprWithAlias { alias, .. } => Some(normalize_ident(alias)),
364        ast::SelectItem::UnnamedExpr(expr) => match expr {
365            Expr::Identifier(ident) => Some(normalize_ident(ident)),
366            Expr::CompoundIdentifier(parts) if parts.len() == 2 => Some(normalize_ident(&parts[1])),
367            Expr::Function(func) => {
368                let func_name = func
369                    .name
370                    .0
371                    .iter()
372                    .map(|p| match p {
373                        ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
374                        _ => String::new(),
375                    })
376                    .collect::<Vec<_>>()
377                    .join(".")
378                    .to_lowercase();
379                let arg = match &func.args {
380                    ast::FunctionArguments::List(arg_list) => arg_list
381                        .args
382                        .first()
383                        .and_then(|a| match a {
384                            ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(
385                                Expr::Identifier(ident),
386                            )) => Some(normalize_ident(ident)),
387                            ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(
388                                Expr::CompoundIdentifier(parts),
389                            )) if parts.len() == 2 => Some(normalize_ident(&parts[1])),
390                            ast::FunctionArg::Unnamed(
391                                ast::FunctionArgExpr::Wildcard
392                                | ast::FunctionArgExpr::QualifiedWildcard(_),
393                            ) => Some("all".to_string()),
394                            _ => None,
395                        })
396                        .unwrap_or_else(|| "*".to_string()),
397                    _ => "*".to_string(),
398                };
399                Some(canonical_aggregate_key(&func_name, &arg))
400            }
401            _ => None,
402        },
403        _ => None,
404    }
405}
406
#[cfg(test)]
mod tests {
    use super::extract_scalar_column;
    use crate::parser::statement::parse_sql;
    use sqlparser::ast::Statement;

    /// Parse `sql` as a single query and run `extract_scalar_column` on it.
    fn scalar_column(sql: &str) -> Option<String> {
        let statements = parse_sql(sql).unwrap();
        let Statement::Query(query) = &statements[0] else {
            panic!("expected query");
        };
        extract_scalar_column(query)
    }

    #[test]
    fn unaliased_scalar_aggregate_uses_canonical_aggregate_key() {
        assert_eq!(
            scalar_column("SELECT AVG(amount) FROM orders"),
            Some("avg(amount)".into())
        );
    }

    #[test]
    fn aliased_scalar_aggregate_uses_alias() {
        assert_eq!(
            scalar_column("SELECT AVG(amount) AS avg_amount FROM orders"),
            Some("avg_amount".into())
        );
    }

    #[test]
    fn qualified_aggregate_argument_drops_table_qualifier() {
        assert_eq!(
            scalar_column("SELECT SUM(o.amount) FROM orders o"),
            Some("sum(amount)".into())
        );
    }

    #[test]
    fn multi_column_projection_is_not_scalar() {
        assert_eq!(scalar_column("SELECT a, b FROM t"), None);
    }
}