1use sqlparser::ast::{self, Query, SetExpr};
4
5use crate::error::{Result, SqlError};
6use crate::functions::registry::FunctionRegistry;
7use crate::parser::normalize::{normalize_ident, normalize_object_name};
8use crate::types::*;
9
10pub fn plan_recursive_cte(
17 query: &Query,
18 catalog: &dyn SqlCatalog,
19 functions: &FunctionRegistry,
20) -> Result<SqlPlan> {
21 let with = query.with.as_ref().ok_or_else(|| SqlError::Parse {
22 detail: "expected WITH clause".into(),
23 })?;
24
25 let cte = with.cte_tables.first().ok_or_else(|| SqlError::Parse {
26 detail: "empty WITH clause".into(),
27 })?;
28
29 let cte_name = normalize_ident(&cte.alias.name);
30
31 let cte_query = &cte.query;
32
33 let (left, right, set_quantifier) = match &*cte_query.body {
35 SetExpr::SetOperation {
36 op: ast::SetOperator::Union,
37 left,
38 right,
39 set_quantifier,
40 } => (left, right, set_quantifier),
41 _ => {
42 return Err(SqlError::Unsupported {
43 detail: "WITH RECURSIVE requires UNION in CTE body".into(),
44 });
45 }
46 };
47
48 let distinct = !matches!(set_quantifier, ast::SetQuantifier::All);
50
51 let base = plan_cte_branch(left, catalog, functions)?;
53
54 let collection = extract_collection(&base).unwrap_or_default();
56
57 let (recursive_filters, join_link) = match plan_cte_branch(right, catalog, functions) {
63 Ok(plan) => (extract_filters(&plan), None),
64 Err(_) => {
65 extract_recursive_info(right, &cte_name)?
68 }
69 };
70
71 if collection.is_empty() {
72 return Err(SqlError::Unsupported {
73 detail: "WITH RECURSIVE requires a base case that scans a collection; \
74 value-generating recursive CTEs are not yet supported"
75 .into(),
76 });
77 }
78
79 Ok(SqlPlan::RecursiveScan {
80 collection,
81 base_filters: extract_filters(&base),
82 recursive_filters,
83 join_link,
84 max_iterations: 100,
85 distinct,
86 limit: 10000,
87 })
88}
89
90type RecursiveInfo = (Vec<Filter>, Option<(String, String)>);
102
103fn extract_recursive_info(expr: &SetExpr, cte_name: &str) -> Result<RecursiveInfo> {
104 let select = match expr {
105 SetExpr::Select(s) => s,
106 _ => {
107 return Err(SqlError::Unsupported {
108 detail: "recursive CTE branch must be SELECT".into(),
109 });
110 }
111 };
112
113 let mut real_table_alias = None;
114 let mut cte_alias = None;
115 let mut join_on_expr = None;
116
117 for from in &select.from {
118 let table_name = extract_table_name(&from.relation);
119 let table_alias = extract_table_alias(&from.relation);
120
121 if let Some(name) = &table_name {
122 if name.eq_ignore_ascii_case(cte_name) {
123 cte_alias = table_alias.or_else(|| Some(name.clone()));
124 } else {
125 real_table_alias = table_alias.or_else(|| Some(name.clone()));
126 }
127 }
128
129 for join in &from.joins {
130 let join_table = extract_table_name(&join.relation);
131 let join_alias = extract_table_alias(&join.relation);
132 if let Some(jt) = &join_table {
133 if jt.eq_ignore_ascii_case(cte_name) {
134 cte_alias = join_alias.or_else(|| Some(jt.clone()));
135 if let Some(cond) = extract_join_on_condition(&join.join_operator) {
136 join_on_expr = Some(cond.clone());
137 }
138 } else {
139 real_table_alias = join_alias.or_else(|| Some(jt.clone()));
140 if join_on_expr.is_none()
141 && let Some(cond) = extract_join_on_condition(&join.join_operator)
142 {
143 join_on_expr = Some(cond.clone());
144 }
145 }
146 }
147 }
148 }
149
150 let join_link = if let (Some(real_alias), Some(cte_al), Some(on_expr)) =
152 (&real_table_alias, &cte_alias, &join_on_expr)
153 {
154 extract_equi_link(on_expr, real_alias, cte_al)
155 } else {
156 None
157 };
158
159 let mut filters = Vec::new();
161 if let Some(where_expr) = &select.selection {
162 let converted = crate::resolver::expr::convert_expr(where_expr)?;
163 filters.push(Filter {
164 expr: FilterExpr::Expr(converted),
165 });
166 }
167
168 Ok((filters, join_link))
169}
170
171fn extract_equi_link(
176 expr: &ast::Expr,
177 real_alias: &str,
178 cte_alias: &str,
179) -> Option<(String, String)> {
180 match expr {
181 ast::Expr::BinaryOp {
182 left,
183 op: ast::BinaryOperator::Eq,
184 right,
185 } => {
186 let left_parts = extract_qualified_column(left)?;
187 let right_parts = extract_qualified_column(right)?;
188
189 if left_parts.0.eq_ignore_ascii_case(real_alias)
191 && right_parts.0.eq_ignore_ascii_case(cte_alias)
192 {
193 Some((left_parts.1, right_parts.1))
194 } else if right_parts.0.eq_ignore_ascii_case(real_alias)
195 && left_parts.0.eq_ignore_ascii_case(cte_alias)
196 {
197 Some((right_parts.1, left_parts.1))
198 } else {
199 None
200 }
201 }
202 ast::Expr::BinaryOp {
204 left,
205 op: ast::BinaryOperator::And,
206 right,
207 } => extract_equi_link(left, real_alias, cte_alias)
208 .or_else(|| extract_equi_link(right, real_alias, cte_alias)),
209 _ => None,
210 }
211}
212
213fn extract_qualified_column(expr: &ast::Expr) -> Option<(String, String)> {
215 match expr {
216 ast::Expr::CompoundIdentifier(parts) if parts.len() == 2 => {
217 Some((normalize_ident(&parts[0]), normalize_ident(&parts[1])))
218 }
219 _ => None,
220 }
221}
222
223fn extract_table_name(relation: &ast::TableFactor) -> Option<String> {
224 match relation {
225 ast::TableFactor::Table { name, .. } => Some(normalize_object_name(name)),
226 _ => None,
227 }
228}
229
230fn extract_table_alias(relation: &ast::TableFactor) -> Option<String> {
231 match relation {
232 ast::TableFactor::Table { alias, .. } => alias.as_ref().map(|a| normalize_ident(&a.name)),
233 _ => None,
234 }
235}
236
237fn extract_join_on_condition(op: &ast::JoinOperator) -> Option<&ast::Expr> {
238 use ast::JoinOperator::*;
239 let constraint = match op {
240 Inner(c) | LeftOuter(c) | RightOuter(c) | FullOuter(c) => c,
241 _ => return None,
242 };
243 match constraint {
244 ast::JoinConstraint::On(expr) => Some(expr),
245 _ => None,
246 }
247}
248
249fn plan_cte_branch(
250 expr: &SetExpr,
251 catalog: &dyn SqlCatalog,
252 functions: &FunctionRegistry,
253) -> Result<SqlPlan> {
254 match expr {
255 SetExpr::Select(select) => {
256 let query = Query {
257 with: None,
258 body: Box::new(SetExpr::Select(select.clone())),
259 order_by: None,
260 limit_clause: None,
261 fetch: None,
262 locks: Vec::new(),
263 for_clause: None,
264 settings: None,
265 format_clause: None,
266 pipe_operators: Vec::new(),
267 };
268 super::select::plan_query(&query, catalog, functions)
269 }
270 _ => Err(SqlError::Unsupported {
271 detail: "CTE branch must be SELECT".into(),
272 }),
273 }
274}
275
276fn extract_collection(plan: &SqlPlan) -> Option<String> {
277 match plan {
278 SqlPlan::Scan { collection, .. } => Some(collection.clone()),
279 _ => None,
280 }
281}
282
283fn extract_filters(plan: &SqlPlan) -> Vec<Filter> {
284 match plan {
285 SqlPlan::Scan { filters, .. } => filters.clone(),
286 _ => Vec::new(),
287 }
288}