Skip to main content

nodedb_sql/planner/
cte.rs

1//! CTE (WITH clause) and WITH RECURSIVE planning.
2
3use sqlparser::ast::{self, Query, SetExpr};
4
5use crate::error::{Result, SqlError};
6use crate::functions::registry::FunctionRegistry;
7use crate::parser::normalize::{normalize_ident, normalize_object_name};
8use crate::types::*;
9
10/// Plan a WITH RECURSIVE query.
11///
12/// Supports table-based recursive CTEs where the base case scans a real
13/// collection and the recursive step references both the collection and
14/// the CTE. Value-generating CTEs (no underlying collection) return an
15/// explicit unsupported error.
16pub fn plan_recursive_cte(
17    query: &Query,
18    catalog: &dyn SqlCatalog,
19    functions: &FunctionRegistry,
20) -> Result<SqlPlan> {
21    let with = query.with.as_ref().ok_or_else(|| SqlError::Parse {
22        detail: "expected WITH clause".into(),
23    })?;
24
25    let cte = with.cte_tables.first().ok_or_else(|| SqlError::Parse {
26        detail: "empty WITH clause".into(),
27    })?;
28
29    let cte_name = normalize_ident(&cte.alias.name);
30
31    let cte_query = &cte.query;
32
33    // The CTE body should be a UNION of base case and recursive case.
34    let (left, right, set_quantifier) = match &*cte_query.body {
35        SetExpr::SetOperation {
36            op: ast::SetOperator::Union,
37            left,
38            right,
39            set_quantifier,
40        } => (left, right, set_quantifier),
41        _ => {
42            return Err(SqlError::Unsupported {
43                detail: "WITH RECURSIVE requires UNION in CTE body".into(),
44            });
45        }
46    };
47
48    // UNION ALL → distinct = false; UNION → distinct = true.
49    let distinct = !matches!(set_quantifier, ast::SetQuantifier::All);
50
51    // Plan the base case (should not reference the CTE name).
52    let base = plan_cte_branch(left, catalog, functions)?;
53
54    // Extract the source collection from the base case.
55    let collection = extract_collection(&base).unwrap_or_default();
56
57    // Plan the recursive branch. The recursive branch references the CTE
58    // name in its FROM clause — either directly (value-gen) or via a JOIN
59    // with a real table. We attempt to plan it; if it fails because the
60    // CTE name isn't in the catalog, we try to extract the real table from
61    // a JOIN and use it with the CTE self-reference as the recursive filter.
62    let (recursive_filters, join_link) = match plan_cte_branch(right, catalog, functions) {
63        Ok(plan) => (extract_filters(&plan), None),
64        Err(_) => {
65            // The recursive branch references the CTE name. Try to extract
66            // the real collection, filters, and join link from the AST.
67            extract_recursive_info(right, &cte_name)?
68        }
69    };
70
71    if collection.is_empty() {
72        return Err(SqlError::Unsupported {
73            detail: "WITH RECURSIVE requires a base case that scans a collection; \
74                     value-generating recursive CTEs are not yet supported"
75                .into(),
76        });
77    }
78
79    Ok(SqlPlan::RecursiveScan {
80        collection,
81        base_filters: extract_filters(&base),
82        recursive_filters,
83        join_link,
84        max_iterations: 100,
85        distinct,
86        limit: 10000,
87    })
88}
89
90/// Extract recursive info from the AST when normal planning fails
91/// because the FROM clause references the CTE name.
92///
93/// Returns `(filters, join_link)` where `join_link` is the
94/// `(collection_field, working_table_field)` pair for the working-table
95/// hash-join.
96///
97/// Handles the common tree-traversal pattern:
98/// `SELECT t.id FROM tree t INNER JOIN cte_name d ON t.parent_id = d.id`
99/// → join_link = `("parent_id", "id")`
100/// `(filters, join_link)` where `join_link` is `(collection_field, working_table_field)`.
101type RecursiveInfo = (Vec<Filter>, Option<(String, String)>);
102
103fn extract_recursive_info(expr: &SetExpr, cte_name: &str) -> Result<RecursiveInfo> {
104    let select = match expr {
105        SetExpr::Select(s) => s,
106        _ => {
107            return Err(SqlError::Unsupported {
108                detail: "recursive CTE branch must be SELECT".into(),
109            });
110        }
111    };
112
113    let mut real_table_alias = None;
114    let mut cte_alias = None;
115    let mut join_on_expr = None;
116
117    for from in &select.from {
118        let table_name = extract_table_name(&from.relation);
119        let table_alias = extract_table_alias(&from.relation);
120
121        if let Some(name) = &table_name {
122            if name.eq_ignore_ascii_case(cte_name) {
123                cte_alias = table_alias.or_else(|| Some(name.clone()));
124            } else {
125                real_table_alias = table_alias.or_else(|| Some(name.clone()));
126            }
127        }
128
129        for join in &from.joins {
130            let join_table = extract_table_name(&join.relation);
131            let join_alias = extract_table_alias(&join.relation);
132            if let Some(jt) = &join_table {
133                if jt.eq_ignore_ascii_case(cte_name) {
134                    cte_alias = join_alias.or_else(|| Some(jt.clone()));
135                    if let Some(cond) = extract_join_on_condition(&join.join_operator) {
136                        join_on_expr = Some(cond.clone());
137                    }
138                } else {
139                    real_table_alias = join_alias.or_else(|| Some(jt.clone()));
140                    if join_on_expr.is_none()
141                        && let Some(cond) = extract_join_on_condition(&join.join_operator)
142                    {
143                        join_on_expr = Some(cond.clone());
144                    }
145                }
146            }
147        }
148    }
149
150    // Extract the join link from the ON condition.
151    let join_link = if let (Some(real_alias), Some(cte_al), Some(on_expr)) =
152        (&real_table_alias, &cte_alias, &join_on_expr)
153    {
154        extract_equi_link(on_expr, real_alias, cte_al)
155    } else {
156        None
157    };
158
159    // Convert the WHERE clause to filters if present.
160    let mut filters = Vec::new();
161    if let Some(where_expr) = &select.selection {
162        let converted = crate::resolver::expr::convert_expr(where_expr)?;
163        filters.push(Filter {
164            expr: FilterExpr::Expr(converted),
165        });
166    }
167
168    Ok((filters, join_link))
169}
170
171/// Extract `(collection_field, cte_field)` from an equi-join ON clause.
172///
173/// Given `t.parent_id = d.id` where `t` is the real table alias and `d`
174/// is the CTE alias, returns `("parent_id", "id")`.
175fn extract_equi_link(
176    expr: &ast::Expr,
177    real_alias: &str,
178    cte_alias: &str,
179) -> Option<(String, String)> {
180    match expr {
181        ast::Expr::BinaryOp {
182            left,
183            op: ast::BinaryOperator::Eq,
184            right,
185        } => {
186            let left_parts = extract_qualified_column(left)?;
187            let right_parts = extract_qualified_column(right)?;
188
189            // Determine which side is the real table and which is the CTE.
190            if left_parts.0.eq_ignore_ascii_case(real_alias)
191                && right_parts.0.eq_ignore_ascii_case(cte_alias)
192            {
193                Some((left_parts.1, right_parts.1))
194            } else if right_parts.0.eq_ignore_ascii_case(real_alias)
195                && left_parts.0.eq_ignore_ascii_case(cte_alias)
196            {
197                Some((right_parts.1, left_parts.1))
198            } else {
199                None
200            }
201        }
202        // For AND-combined conditions, take the first equi-link found.
203        ast::Expr::BinaryOp {
204            left,
205            op: ast::BinaryOperator::And,
206            right,
207        } => extract_equi_link(left, real_alias, cte_alias)
208            .or_else(|| extract_equi_link(right, real_alias, cte_alias)),
209        _ => None,
210    }
211}
212
213/// Extract `(table_or_alias, column)` from a qualified column reference.
214fn extract_qualified_column(expr: &ast::Expr) -> Option<(String, String)> {
215    match expr {
216        ast::Expr::CompoundIdentifier(parts) if parts.len() == 2 => {
217            Some((normalize_ident(&parts[0]), normalize_ident(&parts[1])))
218        }
219        _ => None,
220    }
221}
222
223fn extract_table_name(relation: &ast::TableFactor) -> Option<String> {
224    match relation {
225        ast::TableFactor::Table { name, .. } => Some(normalize_object_name(name)),
226        _ => None,
227    }
228}
229
230fn extract_table_alias(relation: &ast::TableFactor) -> Option<String> {
231    match relation {
232        ast::TableFactor::Table { alias, .. } => alias.as_ref().map(|a| normalize_ident(&a.name)),
233        _ => None,
234    }
235}
236
237fn extract_join_on_condition(op: &ast::JoinOperator) -> Option<&ast::Expr> {
238    use ast::JoinOperator::*;
239    let constraint = match op {
240        Inner(c) | LeftOuter(c) | RightOuter(c) | FullOuter(c) => c,
241        _ => return None,
242    };
243    match constraint {
244        ast::JoinConstraint::On(expr) => Some(expr),
245        _ => None,
246    }
247}
248
249fn plan_cte_branch(
250    expr: &SetExpr,
251    catalog: &dyn SqlCatalog,
252    functions: &FunctionRegistry,
253) -> Result<SqlPlan> {
254    match expr {
255        SetExpr::Select(select) => {
256            let query = Query {
257                with: None,
258                body: Box::new(SetExpr::Select(select.clone())),
259                order_by: None,
260                limit_clause: None,
261                fetch: None,
262                locks: Vec::new(),
263                for_clause: None,
264                settings: None,
265                format_clause: None,
266                pipe_operators: Vec::new(),
267            };
268            super::select::plan_query(&query, catalog, functions)
269        }
270        _ => Err(SqlError::Unsupported {
271            detail: "CTE branch must be SELECT".into(),
272        }),
273    }
274}
275
276fn extract_collection(plan: &SqlPlan) -> Option<String> {
277    match plan {
278        SqlPlan::Scan { collection, .. } => Some(collection.clone()),
279        _ => None,
280    }
281}
282
283fn extract_filters(plan: &SqlPlan) -> Vec<Filter> {
284    match plan {
285        SqlPlan::Scan { filters, .. } => filters.clone(),
286        _ => Vec::new(),
287    }
288}