Skip to main content

nodedb_sql/planner/
subquery.rs

1// SPDX-License-Identifier: Apache-2.0
2
//! Subquery planning: IN (SELECT ...), NOT IN (SELECT ...), EXISTS / NOT EXISTS,
//! and scalar subqueries.
//!
//! Rewrites WHERE-clause subqueries into semi/anti joins so the existing
//! hash-join executor handles them without a dedicated subquery engine.
//!
//! Supported patterns:
//!   - `WHERE col IN (SELECT col2 FROM tbl ...)`  → semi-join
//!   - `WHERE col NOT IN (SELECT col2 FROM tbl ...)` → anti-join
//!   - `WHERE EXISTS (SELECT ... FROM tbl WHERE tbl.col = outer.col)` → semi-join
//!   - `WHERE NOT EXISTS (SELECT ... FROM tbl WHERE tbl.col = outer.col)` → anti-join
//!   - `WHERE col > (SELECT AGG(...) FROM tbl ...)` → scalar subquery (materialized)
12
13use sqlparser::ast::{self, Expr, SetExpr};
14
15use crate::error::{Result, SqlError};
16use crate::functions::registry::FunctionRegistry;
17use crate::parser::normalize::{SCHEMA_QUALIFIED_MSG, normalize_ident};
18use crate::types::*;
19
20/// Result of extracting subqueries from a WHERE clause.
21pub struct SubqueryExtraction {
22    /// Semi/anti joins to wrap around the base scan.
23    pub joins: Vec<SubqueryJoin>,
24    /// Remaining WHERE expression with subqueries removed (None if nothing remains).
25    pub remaining_where: Option<Expr>,
26}
27
28/// A subquery that was rewritten as a join.
29pub struct SubqueryJoin {
30    /// The column on the outer table to join on.
31    pub outer_column: String,
32    /// The planned inner SELECT.
33    pub inner_plan: SqlPlan,
34    /// The column from the inner SELECT to join on.
35    pub inner_column: String,
36    /// Semi (IN) or Anti (NOT IN).
37    pub join_type: JoinType,
38}
39
/// Render an aggregate call as `function(field)`, the key shape used to refer
/// to the result of an unaliased scalar aggregate.
fn canonical_aggregate_key(function: &str, field: &str) -> String {
    // Two extra bytes for the surrounding parentheses.
    let mut key = String::with_capacity(function.len() + field.len() + 2);
    key.push_str(function);
    key.push('(');
    key.push_str(field);
    key.push(')');
    key
}
43
44/// Extract `IN (SELECT ...)` and `NOT IN (SELECT ...)` patterns from a WHERE clause.
45///
46/// Returns the extracted subquery joins and the remaining WHERE expression
47/// (with subquery predicates removed). If the entire WHERE is a single
48/// subquery predicate, `remaining_where` is `None`.
49pub fn extract_subqueries(
50    expr: &Expr,
51    catalog: &dyn SqlCatalog,
52    functions: &FunctionRegistry,
53    temporal: crate::TemporalScope,
54) -> Result<SubqueryExtraction> {
55    let mut joins = Vec::new();
56    let remaining = extract_recursive(expr, &mut joins, catalog, functions, temporal)?;
57    Ok(SubqueryExtraction {
58        joins,
59        remaining_where: remaining,
60    })
61}
62
63/// Recursively walk the WHERE expression, extracting subquery predicates.
64///
65/// Returns `None` if the entire expression was consumed (subquery-only),
66/// or `Some(expr)` with the remaining non-subquery predicates.
67fn extract_recursive(
68    expr: &Expr,
69    joins: &mut Vec<SubqueryJoin>,
70    catalog: &dyn SqlCatalog,
71    functions: &FunctionRegistry,
72    temporal: crate::TemporalScope,
73) -> Result<Option<Expr>> {
74    match expr {
75        // AND: recurse both sides, reconstruct with remaining parts.
76        Expr::BinaryOp {
77            left,
78            op: ast::BinaryOperator::And,
79            right,
80        } => {
81            let left_remaining = extract_recursive(left, joins, catalog, functions, temporal)?;
82            let right_remaining = extract_recursive(right, joins, catalog, functions, temporal)?;
83            match (left_remaining, right_remaining) {
84                (None, None) => Ok(None),
85                (Some(l), None) => Ok(Some(l)),
86                (None, Some(r)) => Ok(Some(r)),
87                (Some(l), Some(r)) => Ok(Some(Expr::BinaryOp {
88                    left: Box::new(l),
89                    op: ast::BinaryOperator::And,
90                    right: Box::new(r),
91                })),
92            }
93        }
94
95        // IN (SELECT ...): rewrite as semi-join.
96        Expr::InSubquery {
97            expr: outer_expr,
98            subquery,
99            negated,
100        } => {
101            if let Some(join) =
102                try_plan_in_subquery(outer_expr, subquery, *negated, catalog, functions, temporal)?
103            {
104                joins.push(join);
105                Ok(None) // This predicate is consumed.
106            } else {
107                // Cannot plan as join — return original expression.
108                Ok(Some(expr.clone()))
109            }
110        }
111
112        // Scalar subquery comparison: `col > (SELECT AGG(...) FROM ...)`
113        Expr::BinaryOp { left, op, right } if is_comparison_op(op) => {
114            if let Expr::Subquery(subquery) = right.as_ref() {
115                if let Some(scalar) =
116                    try_plan_scalar_subquery(subquery, catalog, functions, temporal)?
117                {
118                    joins.push(scalar.join);
119                    Ok(Some(Expr::BinaryOp {
120                        left: left.clone(),
121                        op: op.clone(),
122                        right: Box::new(scalar.replacement_expr),
123                    }))
124                } else {
125                    Ok(Some(expr.clone()))
126                }
127            } else {
128                Ok(Some(expr.clone()))
129            }
130        }
131
132        // EXISTS (SELECT ...): rewrite as semi-join.
133        // NOT EXISTS (SELECT ...): rewrite as anti-join.
134        Expr::Exists { subquery, negated } => {
135            if let Some(join) =
136                try_plan_exists_subquery(subquery, *negated, catalog, functions, temporal)?
137            {
138                joins.push(join);
139                Ok(None)
140            } else {
141                Ok(Some(expr.clone()))
142            }
143        }
144
145        // Nested parentheses.
146        Expr::Nested(inner) => extract_recursive(inner, joins, catalog, functions, temporal),
147
148        // Not a subquery pattern — return as-is.
149        _ => Ok(Some(expr.clone())),
150    }
151}
152
153/// Try to plan `col IN (SELECT col2 FROM tbl ...)` as a semi/anti join.
154fn try_plan_in_subquery(
155    outer_expr: &Expr,
156    subquery: &ast::Query,
157    negated: bool,
158    catalog: &dyn SqlCatalog,
159    functions: &FunctionRegistry,
160    temporal: crate::TemporalScope,
161) -> Result<Option<SubqueryJoin>> {
162    // Extract outer column name.
163    let outer_col = match outer_expr {
164        Expr::Identifier(ident) => normalize_ident(ident),
165        Expr::CompoundIdentifier(parts) if parts.len() >= 3 => {
166            let qualified: String = parts
167                .iter()
168                .map(normalize_ident)
169                .collect::<Vec<_>>()
170                .join(".");
171            return Err(SqlError::Unsupported {
172                detail: format!(
173                    "schema-qualified column reference '{qualified}': {SCHEMA_QUALIFIED_MSG}"
174                ),
175            });
176        }
177        Expr::CompoundIdentifier(parts) if parts.len() == 2 => normalize_ident(&parts[1]),
178        _ => return Ok(None), // Complex expression, can't rewrite.
179    };
180
181    // Plan the inner SELECT.
182    let inner_plan = super::select::plan_query(subquery, catalog, functions, temporal)?;
183
184    // Extract the projected column from the inner plan.
185    let inner_col = extract_single_projected_column(subquery)?;
186
187    Ok(Some(SubqueryJoin {
188        outer_column: outer_col,
189        inner_plan,
190        inner_column: inner_col,
191        join_type: if negated {
192            JoinType::Anti
193        } else {
194            JoinType::Semi
195        },
196    }))
197}
198
199/// Extract the single column name from a subquery's SELECT list.
200///
201/// For `SELECT user_id FROM orders`, returns `"user_id"`.
202fn extract_single_projected_column(query: &ast::Query) -> Result<String> {
203    let select = match &*query.body {
204        SetExpr::Select(s) => s,
205        _ => {
206            return Err(SqlError::Unsupported {
207                detail: "subquery must be a simple SELECT".into(),
208            });
209        }
210    };
211
212    if select.projection.len() != 1 {
213        return Err(SqlError::Unsupported {
214            detail: format!(
215                "subquery must select exactly 1 column, got {}",
216                select.projection.len()
217            ),
218        });
219    }
220
221    match &select.projection[0] {
222        ast::SelectItem::UnnamedExpr(expr) => match expr {
223            Expr::Identifier(ident) => Ok(normalize_ident(ident)),
224            Expr::CompoundIdentifier(parts) if parts.len() >= 3 => {
225                let qualified: String = parts
226                    .iter()
227                    .map(normalize_ident)
228                    .collect::<Vec<_>>()
229                    .join(".");
230                Err(SqlError::Unsupported {
231                    detail: format!(
232                        "schema-qualified column reference '{qualified}': {SCHEMA_QUALIFIED_MSG}"
233                    ),
234                })
235            }
236            Expr::CompoundIdentifier(parts) if parts.len() == 2 => Ok(normalize_ident(&parts[1])),
237            _ => Err(SqlError::Unsupported {
238                detail: "subquery projection must be a column reference".into(),
239            }),
240        },
241        ast::SelectItem::ExprWithAlias { alias, .. } => Ok(normalize_ident(alias)),
242        _ => Err(SqlError::Unsupported {
243            detail: "subquery projection must be a column reference".into(),
244        }),
245    }
246}
247
248/// Plan `EXISTS (SELECT 1 FROM tbl WHERE tbl.col = outer.col)` as a semi/anti join.
249///
250/// Extracts the correlated column from the subquery's WHERE clause.
251fn try_plan_exists_subquery(
252    subquery: &ast::Query,
253    negated: bool,
254    catalog: &dyn SqlCatalog,
255    functions: &FunctionRegistry,
256    temporal: crate::TemporalScope,
257) -> Result<Option<SubqueryJoin>> {
258    let select = match &*subquery.body {
259        SetExpr::Select(s) => s,
260        _ => return Ok(None),
261    };
262
263    // Look for a correlated predicate in the WHERE: inner.col = outer.col
264    let (outer_col, inner_col) = match &select.selection {
265        Some(expr) => match extract_correlated_eq(expr) {
266            Some(pair) => pair,
267            None => return Ok(None),
268        },
269        None => return Ok(None),
270    };
271
272    // Build a simplified subquery without the correlated predicate for planning.
273    let inner_plan = super::select::plan_query(subquery, catalog, functions, temporal)?;
274
275    Ok(Some(SubqueryJoin {
276        outer_column: outer_col,
277        inner_plan,
278        inner_column: inner_col,
279        join_type: if negated {
280            JoinType::Anti
281        } else {
282            JoinType::Semi
283        },
284    }))
285}
286
287/// Extract a correlated equality predicate from a WHERE clause.
288///
289/// Looks for patterns like `o.user_id = u.id` and returns (outer_col, inner_col).
290/// The "inner" column is the one qualified with the subquery's table alias;
291/// the "outer" column is the one referencing the outer query's table.
292fn extract_correlated_eq(expr: &Expr) -> Option<(String, String)> {
293    match expr {
294        Expr::BinaryOp {
295            left,
296            op: ast::BinaryOperator::Eq,
297            right,
298        } => {
299            let left_parts = extract_qualified_column(left);
300            let right_parts = extract_qualified_column(right);
301            match (left_parts, right_parts) {
302                (Some((_lt, lc)), Some((_rt, rc))) => {
303                    // Convention: left is inner (subquery table), right is outer.
304                    // But we can't distinguish without schema, so just return both.
305                    Some((rc, lc))
306                }
307                _ => None,
308            }
309        }
310        // For AND, try to find a correlated eq in either side.
311        Expr::BinaryOp {
312            left,
313            op: ast::BinaryOperator::And,
314            right,
315        } => extract_correlated_eq(left).or_else(|| extract_correlated_eq(right)),
316        Expr::Nested(inner) => extract_correlated_eq(inner),
317        _ => None,
318    }
319}
320
321/// Extract table.column from a qualified identifier.
322///
323/// Returns `None` for schema-qualified references (`schema.table.col`) — those
324/// are rejected upstream by `convert_expr` when the expression is fully evaluated.
325fn extract_qualified_column(expr: &Expr) -> Option<(String, String)> {
326    match expr {
327        Expr::CompoundIdentifier(parts) if parts.len() >= 3 => {
328            // Schema-qualified: cannot extract table/column sensibly.
329            // convert_expr will reject this path with Unsupported.
330            None
331        }
332        Expr::CompoundIdentifier(parts) if parts.len() == 2 => {
333            Some((normalize_ident(&parts[0]), normalize_ident(&parts[1])))
334        }
335        Expr::Identifier(ident) => Some((String::new(), normalize_ident(ident))),
336        _ => None,
337    }
338}
339
340fn is_comparison_op(op: &ast::BinaryOperator) -> bool {
341    matches!(
342        op,
343        ast::BinaryOperator::Gt
344            | ast::BinaryOperator::GtEq
345            | ast::BinaryOperator::Lt
346            | ast::BinaryOperator::LtEq
347            | ast::BinaryOperator::Eq
348            | ast::BinaryOperator::NotEq
349    )
350}
351
352/// Result of planning a scalar subquery.
353struct ScalarSubqueryResult {
354    join: SubqueryJoin,
355    replacement_expr: Expr,
356}
357
358/// Plan a scalar subquery (e.g., `(SELECT AVG(amount) FROM orders)`).
359///
360/// Rewrites `col > (SELECT AVG(amount) FROM orders)` as:
361///   cross-join with the aggregate result (1 row), then filter `col > result_col`.
362///
363/// The cross-join produces a cartesian product, but since the aggregate returns
364/// exactly 1 row, every outer row gets paired with that single result row.
365fn try_plan_scalar_subquery(
366    subquery: &ast::Query,
367    catalog: &dyn SqlCatalog,
368    functions: &FunctionRegistry,
369    temporal: crate::TemporalScope,
370) -> Result<Option<ScalarSubqueryResult>> {
371    let inner_plan = super::select::plan_query(subquery, catalog, functions, temporal)?;
372
373    // Extract the result column name from the subquery's SELECT list.
374    let result_col = match extract_scalar_column(subquery) {
375        Some(col) => col,
376        None => return Ok(None),
377    };
378
379    let replacement = Expr::Identifier(ast::Ident::new(&result_col));
380
381    Ok(Some(ScalarSubqueryResult {
382        join: SubqueryJoin {
383            outer_column: String::new(),
384            inner_plan,
385            inner_column: String::new(),
386            join_type: JoinType::Cross,
387        },
388        replacement_expr: replacement,
389    }))
390}
391
392/// Extract the projected column name from a scalar subquery.
393///
394/// Handles aliased aggregates like `SELECT AVG(amount) AS avg_amount`.
395/// For unaliased aggregates, returns the canonical aggregate key emitted by
396/// the aggregate executor (e.g. `avg(amount)`, `count(*)`).
397fn extract_scalar_column(query: &ast::Query) -> Option<String> {
398    let select = match &*query.body {
399        SetExpr::Select(s) => s,
400        _ => return None,
401    };
402    if select.projection.len() != 1 {
403        return None;
404    }
405    match &select.projection[0] {
406        ast::SelectItem::ExprWithAlias { alias, .. } => Some(normalize_ident(alias)),
407        ast::SelectItem::UnnamedExpr(expr) => match expr {
408            Expr::Identifier(ident) => Some(normalize_ident(ident)),
409            Expr::CompoundIdentifier(parts) if parts.len() >= 3 => {
410                // Schema-qualified: return None to propagate "unsupported" through convert_expr.
411                None
412            }
413            Expr::CompoundIdentifier(parts) if parts.len() == 2 => Some(normalize_ident(&parts[1])),
414            Expr::Function(func) => {
415                let func_name = func
416                    .name
417                    .0
418                    .iter()
419                    .map(|p| match p {
420                        ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
421                        _ => String::new(),
422                    })
423                    .collect::<Vec<_>>()
424                    .join(".")
425                    .to_lowercase();
426                let arg = match &func.args {
427                    ast::FunctionArguments::List(arg_list) => arg_list
428                        .args
429                        .first()
430                        .and_then(|a| match a {
431                            ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(
432                                Expr::Identifier(ident),
433                            )) => Some(normalize_ident(ident)),
434                            ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(
435                                Expr::CompoundIdentifier(parts),
436                            )) if parts.len() >= 3 => None,
437                            ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(
438                                Expr::CompoundIdentifier(parts),
439                            )) if parts.len() == 2 => Some(normalize_ident(&parts[1])),
440                            ast::FunctionArg::Unnamed(
441                                ast::FunctionArgExpr::Wildcard
442                                | ast::FunctionArgExpr::QualifiedWildcard(_),
443                            ) => Some("all".to_string()),
444                            _ => None,
445                        })
446                        .unwrap_or_else(|| "*".to_string()),
447                    _ => "*".to_string(),
448                };
449                Some(canonical_aggregate_key(&func_name, &arg))
450            }
451            _ => None,
452        },
453        _ => None,
454    }
455}
456
#[cfg(test)]
mod tests {
    use super::extract_scalar_column;
    use crate::parser::statement::parse_sql;
    use sqlparser::ast::Statement;

    /// An unaliased `AVG(amount)` must be referenced by its canonical
    /// aggregate key, `avg(amount)`.
    #[test]
    fn unaliased_scalar_aggregate_uses_canonical_aggregate_key() {
        let parsed = parse_sql("SELECT AVG(amount) FROM orders").unwrap();
        let query = match &parsed[0] {
            Statement::Query(query) => query,
            _ => panic!("expected query"),
        };
        assert_eq!(
            extract_scalar_column(query).as_deref(),
            Some("avg(amount)")
        );
    }
}