Skip to main content

nodedb_sql/planner/
cte.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! CTE (WITH clause) and WITH RECURSIVE planning.
4
5use sqlparser::ast::{self, Query, SetExpr};
6
7use crate::error::{Result, SqlError};
8use crate::functions::registry::FunctionRegistry;
9use crate::parser::normalize::{normalize_ident, normalize_object_name_checked};
10use crate::types::*;
11
/// Default maximum recursion depth for WITH RECURSIVE queries.
///
/// Used as the iteration cap of collection-backed recursive scans and as the
/// depth cap of value-generating recursive CTEs planned in this module.
pub const DEFAULT_MAX_RECURSION_DEPTH: usize = 1_000;
14
15/// Plan a WITH RECURSIVE query.
16///
17/// Dispatches to either `plan_recursive_scan` (collection-backed) or
18/// `plan_recursive_value` (pure expression / value-generating) based on
19/// whether the anchor arm references a real collection.
20pub fn plan_recursive_cte(
21    query: &Query,
22    catalog: &dyn SqlCatalog,
23    functions: &FunctionRegistry,
24    temporal: crate::TemporalScope,
25) -> Result<SqlPlan> {
26    let with = query.with.as_ref().ok_or_else(|| SqlError::Parse {
27        detail: "expected WITH clause".into(),
28    })?;
29
30    let cte = with.cte_tables.first().ok_or_else(|| SqlError::Parse {
31        detail: "empty WITH clause".into(),
32    })?;
33
34    let cte_name = normalize_ident(&cte.alias.name);
35    let declared_columns: Vec<String> = cte
36        .alias
37        .columns
38        .iter()
39        .map(|c| normalize_ident(&c.name))
40        .collect();
41
42    let cte_query = &cte.query;
43
44    // Validate set operator: only UNION / UNION ALL permitted.
45    let (left, right, set_quantifier) = match &*cte_query.body {
46        SetExpr::SetOperation {
47            op: ast::SetOperator::Union,
48            left,
49            right,
50            set_quantifier,
51        } => (left, right, set_quantifier),
52        SetExpr::SetOperation { op, .. } => {
53            return Err(SqlError::InvalidRecursiveSetOp {
54                op: format!("{op}"),
55            });
56        }
57        _ => {
58            return Err(SqlError::InvalidRecursiveSetOp {
59                op: "non-set-operation".into(),
60            });
61        }
62    };
63
64    // Validate self-reference count in the recursive arm.
65    validate_self_ref_count(right, &cte_name)?;
66
67    let distinct = !matches!(set_quantifier, ast::SetQuantifier::All);
68
69    // Try to detect whether this is a collection-backed or value-generating CTE
70    // by attempting to plan the anchor arm against the catalog.
71    match plan_cte_branch(left, catalog, functions, temporal) {
72        Ok(base) => {
73            let collection = extract_collection(&base);
74            if collection.is_empty() {
75                // Anchor planned but produced no collection → treat as value-gen.
76                plan_recursive_value(left, right, &cte_name, &declared_columns, distinct)
77            } else {
78                plan_recursive_scan_from_parts(
79                    &cte_name,
80                    &base,
81                    &RecursiveParts {
82                        left,
83                        right,
84                        declared_columns: &declared_columns,
85                        distinct,
86                    },
87                    catalog,
88                    functions,
89                    temporal,
90                )
91            }
92        }
93        Err(_) => {
94            // Anchor references CTE name or uses value expressions → value-gen.
95            plan_recursive_value(left, right, &cte_name, &declared_columns, distinct)
96        }
97    }
98}
99
100// ── Collection-backed recursive scan ─────────────────────────────────────────
101
102struct RecursiveParts<'a> {
103    left: &'a SetExpr,
104    right: &'a SetExpr,
105    declared_columns: &'a [String],
106    distinct: bool,
107}
108
109fn plan_recursive_scan_from_parts(
110    cte_name: &str,
111    base: &SqlPlan,
112    parts: &RecursiveParts<'_>,
113    catalog: &dyn SqlCatalog,
114    functions: &FunctionRegistry,
115    temporal: crate::TemporalScope,
116) -> Result<SqlPlan> {
117    let RecursiveParts {
118        left,
119        right,
120        declared_columns,
121        distinct,
122    } = parts;
123    let collection = extract_collection(base);
124
125    // Validate column count if columns were declared.
126    if !declared_columns.is_empty() {
127        let anchor_cols = count_select_cols(left);
128        if anchor_cols != 0 && anchor_cols != declared_columns.len() {
129            return Err(SqlError::RecursiveColumnMismatch {
130                cte_name: cte_name.to_owned(),
131                anchor_cols,
132                declared_cols: declared_columns.len(),
133            });
134        }
135    }
136
137    let (recursive_filters, join_link) = match plan_cte_branch(right, catalog, functions, temporal)
138    {
139        Ok(plan) => (extract_filters(&plan), None),
140        Err(_) => extract_recursive_info(right, cte_name)?,
141    };
142
143    Ok(SqlPlan::RecursiveScan {
144        collection,
145        base_filters: extract_filters(base),
146        recursive_filters,
147        join_link,
148        max_iterations: DEFAULT_MAX_RECURSION_DEPTH,
149        distinct: *distinct,
150        limit: 10000,
151    })
152}
153
154// ── Value-generating recursive CTE ───────────────────────────────────────────
155
156/// Plan a value-generating WITH RECURSIVE CTE (no collection reference).
157///
158/// Produces a `SqlPlan::RecursiveValue` that carries the anchor and step
159/// expressions as raw SQL text for evaluation in the Data Plane.
160fn plan_recursive_value(
161    left: &SetExpr,
162    right: &SetExpr,
163    cte_name: &str,
164    declared_columns: &[String],
165    distinct: bool,
166) -> Result<SqlPlan> {
167    let init_exprs = extract_select_exprs_as_text(left).ok_or_else(|| SqlError::Parse {
168        detail: "WITH RECURSIVE anchor must be a SELECT".into(),
169    })?;
170
171    // Validate column count against declared columns list.
172    if !declared_columns.is_empty() && init_exprs.len() != declared_columns.len() {
173        return Err(SqlError::RecursiveColumnMismatch {
174            cte_name: cte_name.to_owned(),
175            anchor_cols: init_exprs.len(),
176            declared_cols: declared_columns.len(),
177        });
178    }
179
180    let (step_exprs, condition) =
181        extract_step_exprs_and_condition(right).ok_or_else(|| SqlError::Parse {
182            detail: "WITH RECURSIVE step must be a SELECT".into(),
183        })?;
184
185    // Infer column names from anchor if not declared.
186    let columns = if declared_columns.is_empty() {
187        // Default column names: col0, col1, ...
188        (0..init_exprs.len()).map(|i| format!("col{i}")).collect()
189    } else {
190        declared_columns.to_vec()
191    };
192
193    Ok(SqlPlan::RecursiveValue {
194        cte_name: cte_name.to_owned(),
195        columns,
196        init_exprs,
197        step_exprs,
198        condition,
199        max_depth: DEFAULT_MAX_RECURSION_DEPTH,
200        distinct,
201    })
202}
203
204/// Extract SELECT projection items as raw SQL text strings.
205fn extract_select_exprs_as_text(expr: &SetExpr) -> Option<Vec<String>> {
206    let select = match expr {
207        SetExpr::Select(s) => s,
208        _ => return None,
209    };
210    Some(
211        select
212            .projection
213            .iter()
214            .map(|item| match item {
215                ast::SelectItem::UnnamedExpr(e) => format!("{e}"),
216                ast::SelectItem::ExprWithAlias { expr: e, .. } => format!("{e}"),
217                ast::SelectItem::Wildcard(_) => "*".into(),
218                ast::SelectItem::QualifiedWildcard(name, _) => format!("{name}.*"),
219            })
220            .collect(),
221    )
222}
223
224/// Extract step SELECT expressions and optional WHERE condition as SQL text.
225///
226/// Returns `(step_exprs, condition)`.
227fn extract_step_exprs_and_condition(expr: &SetExpr) -> Option<(Vec<String>, Option<String>)> {
228    let select = match expr {
229        SetExpr::Select(s) => s,
230        _ => return None,
231    };
232    let step_exprs = select
233        .projection
234        .iter()
235        .map(|item| match item {
236            ast::SelectItem::UnnamedExpr(e) => format!("{e}"),
237            ast::SelectItem::ExprWithAlias { expr: e, .. } => format!("{e}"),
238            ast::SelectItem::Wildcard(_) => "*".into(),
239            ast::SelectItem::QualifiedWildcard(name, _) => format!("{name}.*"),
240        })
241        .collect();
242    let condition = select.selection.as_ref().map(|e| format!("{e}"));
243    Some((step_exprs, condition))
244}
245
246// ── Validation ────────────────────────────────────────────────────────────────
247
248/// Count SELECT projection columns; returns 0 if the expression is not a SELECT.
249fn count_select_cols(expr: &SetExpr) -> usize {
250    match expr {
251        SetExpr::Select(s) => s.projection.len(),
252        _ => 0,
253    }
254}
255
256/// Validate that the CTE name appears exactly once in the recursive arm and
257/// not inside a subquery, aggregate function, or the nullable side of an outer join.
258///
259/// Returns `Ok(())` if the reference is valid, or a typed error otherwise.
260fn validate_self_ref_count(expr: &SetExpr, cte_name: &str) -> Result<()> {
261    let select = match expr {
262        SetExpr::Select(s) => s,
263        // Non-SELECT arm: no self-ref needed.
264        _ => return Ok(()),
265    };
266
267    let mut count = 0usize;
268
269    for from in &select.from {
270        if table_ref_matches(&from.relation, cte_name) {
271            count += 1;
272        }
273        for join in &from.joins {
274            if table_ref_matches(&join.relation, cte_name) {
275                // Reject self-ref on the nullable side of an outer join.
276                if is_nullable_join_side(&join.join_operator) {
277                    return Err(SqlError::InvalidRecursiveSelfRef {
278                        cte_name: cte_name.to_owned(),
279                        reason: "self-reference on the nullable side of an outer join is not \
280                                 permitted; use INNER JOIN or move the CTE reference to the \
281                                 driving table position"
282                            .into(),
283                    });
284                }
285                count += 1;
286            }
287        }
288    }
289
290    // Subquery self-references are not permitted.
291    if where_contains_subquery_ref(&select.selection, cte_name) {
292        return Err(SqlError::InvalidRecursiveSelfRef {
293            cte_name: cte_name.to_owned(),
294            reason: "self-reference inside a subquery is not permitted".into(),
295        });
296    }
297
298    if count > 1 {
299        return Err(SqlError::InvalidRecursiveSelfRef {
300            cte_name: cte_name.to_owned(),
301            reason: format!("self-reference appears {count} times; exactly one is required"),
302        });
303    }
304
305    // count == 0 is fine for the value-generating case (no table ref at all).
306    Ok(())
307}
308
309fn table_ref_matches(factor: &ast::TableFactor, cte_name: &str) -> bool {
310    match factor {
311        ast::TableFactor::Table { name, .. } => normalize_object_name_checked(name)
312            .map(|n| n.eq_ignore_ascii_case(cte_name))
313            .unwrap_or(false),
314        _ => false,
315    }
316}
317
318fn is_nullable_join_side(op: &ast::JoinOperator) -> bool {
319    use ast::JoinOperator::*;
320    matches!(op, LeftOuter(_) | RightOuter(_) | FullOuter(_))
321}
322
323fn where_contains_subquery_ref(selection: &Option<ast::Expr>, cte_name: &str) -> bool {
324    match selection {
325        None => false,
326        Some(e) => expr_contains_subquery_ref(e, cte_name),
327    }
328}
329
330fn expr_contains_subquery_ref(expr: &ast::Expr, cte_name: &str) -> bool {
331    match expr {
332        ast::Expr::InSubquery { subquery, .. } | ast::Expr::Exists { subquery, .. } => {
333            query_references_cte(subquery, cte_name)
334        }
335        ast::Expr::Subquery(q) => query_references_cte(q, cte_name),
336        ast::Expr::BinaryOp { left, right, .. } => {
337            expr_contains_subquery_ref(left, cte_name)
338                || expr_contains_subquery_ref(right, cte_name)
339        }
340        ast::Expr::Nested(inner) => expr_contains_subquery_ref(inner, cte_name),
341        _ => false,
342    }
343}
344
345fn query_references_cte(query: &Query, cte_name: &str) -> bool {
346    match &*query.body {
347        SetExpr::Select(s) => s.from.iter().any(|f| {
348            table_ref_matches(&f.relation, cte_name)
349                || f.joins
350                    .iter()
351                    .any(|j| table_ref_matches(&j.relation, cte_name))
352        }),
353        _ => false,
354    }
355}
356
357// ── Helpers shared with collection-backed path ────────────────────────────────
358
359/// Extract recursive info from the AST when normal planning fails
360/// because the FROM clause references the CTE name.
361///
362/// Returns `(filters, join_link)` where `join_link` is the
363/// `(collection_field, working_table_field)` pair for the working-table
364/// hash-join.
365type RecursiveInfo = (Vec<Filter>, Option<(String, String)>);
366
367fn extract_recursive_info(expr: &SetExpr, cte_name: &str) -> Result<RecursiveInfo> {
368    let select = match expr {
369        SetExpr::Select(s) => s,
370        _ => {
371            return Err(SqlError::Unsupported {
372                detail: "recursive CTE branch must be SELECT".into(),
373            });
374        }
375    };
376
377    let mut real_table_alias = None;
378    let mut cte_alias = None;
379    let mut join_on_expr = None;
380
381    for from in &select.from {
382        let table_name = extract_table_name(&from.relation);
383        let table_alias = extract_table_alias(&from.relation);
384
385        if let Some(name) = &table_name {
386            if name.eq_ignore_ascii_case(cte_name) {
387                cte_alias = table_alias.or_else(|| Some(name.clone()));
388            } else {
389                real_table_alias = table_alias.or_else(|| Some(name.clone()));
390            }
391        }
392
393        for join in &from.joins {
394            let join_table = extract_table_name(&join.relation);
395            let join_alias = extract_table_alias(&join.relation);
396            if let Some(jt) = &join_table {
397                if jt.eq_ignore_ascii_case(cte_name) {
398                    cte_alias = join_alias.or_else(|| Some(jt.clone()));
399                    if let Some(cond) = extract_join_on_condition(&join.join_operator) {
400                        join_on_expr = Some(cond.clone());
401                    }
402                } else {
403                    real_table_alias = join_alias.or_else(|| Some(jt.clone()));
404                    if join_on_expr.is_none()
405                        && let Some(cond) = extract_join_on_condition(&join.join_operator)
406                    {
407                        join_on_expr = Some(cond.clone());
408                    }
409                }
410            }
411        }
412    }
413
414    // Extract the join link from the ON condition.
415    let join_link = if let (Some(real_alias), Some(cte_al), Some(on_expr)) =
416        (&real_table_alias, &cte_alias, &join_on_expr)
417    {
418        extract_equi_link(on_expr, real_alias, cte_al)
419    } else {
420        None
421    };
422
423    let mut filters = Vec::new();
424    if let Some(where_expr) = &select.selection {
425        let converted = crate::resolver::expr::convert_expr(where_expr)?;
426        filters.push(Filter {
427            expr: FilterExpr::Expr(converted),
428        });
429    }
430
431    Ok((filters, join_link))
432}
433
434/// Extract `(collection_field, cte_field)` from an equi-join ON clause.
435fn extract_equi_link(
436    expr: &ast::Expr,
437    real_alias: &str,
438    cte_alias: &str,
439) -> Option<(String, String)> {
440    match expr {
441        ast::Expr::BinaryOp {
442            left,
443            op: ast::BinaryOperator::Eq,
444            right,
445        } => {
446            let left_parts = extract_qualified_column(left)?;
447            let right_parts = extract_qualified_column(right)?;
448
449            if left_parts.0.eq_ignore_ascii_case(real_alias)
450                && right_parts.0.eq_ignore_ascii_case(cte_alias)
451            {
452                Some((left_parts.1, right_parts.1))
453            } else if right_parts.0.eq_ignore_ascii_case(real_alias)
454                && left_parts.0.eq_ignore_ascii_case(cte_alias)
455            {
456                Some((right_parts.1, left_parts.1))
457            } else {
458                None
459            }
460        }
461        ast::Expr::BinaryOp {
462            left,
463            op: ast::BinaryOperator::And,
464            right,
465        } => extract_equi_link(left, real_alias, cte_alias)
466            .or_else(|| extract_equi_link(right, real_alias, cte_alias)),
467        _ => None,
468    }
469}
470
471fn extract_qualified_column(expr: &ast::Expr) -> Option<(String, String)> {
472    match expr {
473        ast::Expr::CompoundIdentifier(parts) if parts.len() == 2 => {
474            Some((normalize_ident(&parts[0]), normalize_ident(&parts[1])))
475        }
476        _ => None,
477    }
478}
479
480fn extract_table_name(relation: &ast::TableFactor) -> Option<String> {
481    match relation {
482        ast::TableFactor::Table { name, .. } => normalize_object_name_checked(name).ok(),
483        _ => None,
484    }
485}
486
487fn extract_table_alias(relation: &ast::TableFactor) -> Option<String> {
488    match relation {
489        ast::TableFactor::Table { alias, .. } => alias.as_ref().map(|a| normalize_ident(&a.name)),
490        _ => None,
491    }
492}
493
494fn extract_join_on_condition(op: &ast::JoinOperator) -> Option<&ast::Expr> {
495    use ast::JoinOperator::*;
496    let constraint = match op {
497        Inner(c) | LeftOuter(c) | RightOuter(c) | FullOuter(c) => c,
498        _ => return None,
499    };
500    match constraint {
501        ast::JoinConstraint::On(expr) => Some(expr),
502        _ => None,
503    }
504}
505
506fn plan_cte_branch(
507    expr: &SetExpr,
508    catalog: &dyn SqlCatalog,
509    functions: &FunctionRegistry,
510    temporal: crate::TemporalScope,
511) -> Result<SqlPlan> {
512    match expr {
513        SetExpr::Select(select) => {
514            let query = Query {
515                with: None,
516                body: Box::new(SetExpr::Select(select.clone())),
517                order_by: None,
518                limit_clause: None,
519                fetch: None,
520                locks: Vec::new(),
521                for_clause: None,
522                settings: None,
523                format_clause: None,
524                pipe_operators: Vec::new(),
525            };
526            super::select::plan_query(&query, catalog, functions, temporal)
527        }
528        _ => Err(SqlError::Unsupported {
529            detail: "CTE branch must be SELECT".into(),
530        }),
531    }
532}
533
534fn extract_collection(plan: &SqlPlan) -> String {
535    match plan {
536        SqlPlan::Scan { collection, .. } => collection.clone(),
537        _ => String::new(),
538    }
539}
540
541fn extract_filters(plan: &SqlPlan) -> Vec<Filter> {
542    match plan {
543        SqlPlan::Scan { filters, .. } => filters.clone(),
544        _ => Vec::new(),
545    }
546}