1use sqlparser::ast::{self, Query, SetExpr};
6
7use crate::error::{Result, SqlError};
8use crate::functions::registry::FunctionRegistry;
9use crate::parser::normalize::{normalize_ident, normalize_object_name_checked};
10use crate::types::*;
11
/// Default recursion limit for plans produced from `WITH RECURSIVE` queries.
///
/// Used as `max_iterations` of `SqlPlan::RecursiveScan` and as `max_depth` of
/// `SqlPlan::RecursiveValue`.
pub const DEFAULT_MAX_RECURSION_DEPTH: usize = 1000;
14
15pub fn plan_recursive_cte(
21 query: &Query,
22 catalog: &dyn SqlCatalog,
23 functions: &FunctionRegistry,
24 temporal: crate::TemporalScope,
25) -> Result<SqlPlan> {
26 let with = query.with.as_ref().ok_or_else(|| SqlError::Parse {
27 detail: "expected WITH clause".into(),
28 })?;
29
30 let cte = with.cte_tables.first().ok_or_else(|| SqlError::Parse {
31 detail: "empty WITH clause".into(),
32 })?;
33
34 let cte_name = normalize_ident(&cte.alias.name);
35 let declared_columns: Vec<String> = cte
36 .alias
37 .columns
38 .iter()
39 .map(|c| normalize_ident(&c.name))
40 .collect();
41
42 let cte_query = &cte.query;
43
44 let (left, right, set_quantifier) = match &*cte_query.body {
46 SetExpr::SetOperation {
47 op: ast::SetOperator::Union,
48 left,
49 right,
50 set_quantifier,
51 } => (left, right, set_quantifier),
52 SetExpr::SetOperation { op, .. } => {
53 return Err(SqlError::InvalidRecursiveSetOp {
54 op: format!("{op}"),
55 });
56 }
57 _ => {
58 return Err(SqlError::InvalidRecursiveSetOp {
59 op: "non-set-operation".into(),
60 });
61 }
62 };
63
64 validate_self_ref_count(right, &cte_name)?;
66
67 let distinct = !matches!(set_quantifier, ast::SetQuantifier::All);
68
69 match plan_cte_branch(left, catalog, functions, temporal) {
72 Ok(base) => {
73 let collection = extract_collection(&base);
74 if collection.is_empty() {
75 plan_recursive_value(left, right, &cte_name, &declared_columns, distinct)
77 } else {
78 plan_recursive_scan_from_parts(
79 &cte_name,
80 &base,
81 &RecursiveParts {
82 left,
83 right,
84 declared_columns: &declared_columns,
85 distinct,
86 },
87 catalog,
88 functions,
89 temporal,
90 )
91 }
92 }
93 Err(_) => {
94 plan_recursive_value(left, right, &cte_name, &declared_columns, distinct)
96 }
97 }
98}
99
/// Borrowed pieces of a recursive CTE that are handed to the scan planner.
struct RecursiveParts<'a> {
    /// Anchor (non-recursive) term: the left operand of the CTE's `UNION`.
    left: &'a SetExpr,
    /// Recursive term: the right operand of the CTE's `UNION`.
    right: &'a SetExpr,
    /// Normalized names from `WITH name(col, ...)`; empty when none are declared.
    declared_columns: &'a [String],
    /// Whether result rows are de-duplicated (`false` only for `UNION ALL`).
    distinct: bool,
}
108
109fn plan_recursive_scan_from_parts(
110 cte_name: &str,
111 base: &SqlPlan,
112 parts: &RecursiveParts<'_>,
113 catalog: &dyn SqlCatalog,
114 functions: &FunctionRegistry,
115 temporal: crate::TemporalScope,
116) -> Result<SqlPlan> {
117 let RecursiveParts {
118 left,
119 right,
120 declared_columns,
121 distinct,
122 } = parts;
123 let collection = extract_collection(base);
124
125 if !declared_columns.is_empty() {
127 let anchor_cols = count_select_cols(left);
128 if anchor_cols != 0 && anchor_cols != declared_columns.len() {
129 return Err(SqlError::RecursiveColumnMismatch {
130 cte_name: cte_name.to_owned(),
131 anchor_cols,
132 declared_cols: declared_columns.len(),
133 });
134 }
135 }
136
137 let (recursive_filters, join_link) = match plan_cte_branch(right, catalog, functions, temporal)
138 {
139 Ok(plan) => (extract_filters(&plan), None),
140 Err(_) => extract_recursive_info(right, cte_name)?,
141 };
142
143 Ok(SqlPlan::RecursiveScan {
144 collection,
145 base_filters: extract_filters(base),
146 recursive_filters,
147 join_link,
148 max_iterations: DEFAULT_MAX_RECURSION_DEPTH,
149 distinct: *distinct,
150 limit: 10000,
151 })
152}
153
154fn plan_recursive_value(
161 left: &SetExpr,
162 right: &SetExpr,
163 cte_name: &str,
164 declared_columns: &[String],
165 distinct: bool,
166) -> Result<SqlPlan> {
167 let init_exprs = extract_select_exprs_as_text(left).ok_or_else(|| SqlError::Parse {
168 detail: "WITH RECURSIVE anchor must be a SELECT".into(),
169 })?;
170
171 if !declared_columns.is_empty() && init_exprs.len() != declared_columns.len() {
173 return Err(SqlError::RecursiveColumnMismatch {
174 cte_name: cte_name.to_owned(),
175 anchor_cols: init_exprs.len(),
176 declared_cols: declared_columns.len(),
177 });
178 }
179
180 let (step_exprs, condition) =
181 extract_step_exprs_and_condition(right).ok_or_else(|| SqlError::Parse {
182 detail: "WITH RECURSIVE step must be a SELECT".into(),
183 })?;
184
185 let columns = if declared_columns.is_empty() {
187 (0..init_exprs.len()).map(|i| format!("col{i}")).collect()
189 } else {
190 declared_columns.to_vec()
191 };
192
193 Ok(SqlPlan::RecursiveValue {
194 cte_name: cte_name.to_owned(),
195 columns,
196 init_exprs,
197 step_exprs,
198 condition,
199 max_depth: DEFAULT_MAX_RECURSION_DEPTH,
200 distinct,
201 })
202}
203
204fn extract_select_exprs_as_text(expr: &SetExpr) -> Option<Vec<String>> {
206 let select = match expr {
207 SetExpr::Select(s) => s,
208 _ => return None,
209 };
210 Some(
211 select
212 .projection
213 .iter()
214 .map(|item| match item {
215 ast::SelectItem::UnnamedExpr(e) => format!("{e}"),
216 ast::SelectItem::ExprWithAlias { expr: e, .. } => format!("{e}"),
217 ast::SelectItem::Wildcard(_) => "*".into(),
218 ast::SelectItem::QualifiedWildcard(name, _) => format!("{name}.*"),
219 })
220 .collect(),
221 )
222}
223
224fn extract_step_exprs_and_condition(expr: &SetExpr) -> Option<(Vec<String>, Option<String>)> {
228 let select = match expr {
229 SetExpr::Select(s) => s,
230 _ => return None,
231 };
232 let step_exprs = select
233 .projection
234 .iter()
235 .map(|item| match item {
236 ast::SelectItem::UnnamedExpr(e) => format!("{e}"),
237 ast::SelectItem::ExprWithAlias { expr: e, .. } => format!("{e}"),
238 ast::SelectItem::Wildcard(_) => "*".into(),
239 ast::SelectItem::QualifiedWildcard(name, _) => format!("{name}.*"),
240 })
241 .collect();
242 let condition = select.selection.as_ref().map(|e| format!("{e}"));
243 Some((step_exprs, condition))
244}
245
246fn count_select_cols(expr: &SetExpr) -> usize {
250 match expr {
251 SetExpr::Select(s) => s.projection.len(),
252 _ => 0,
253 }
254}
255
256fn validate_self_ref_count(expr: &SetExpr, cte_name: &str) -> Result<()> {
261 let select = match expr {
262 SetExpr::Select(s) => s,
263 _ => return Ok(()),
265 };
266
267 let mut count = 0usize;
268
269 for from in &select.from {
270 if table_ref_matches(&from.relation, cte_name) {
271 count += 1;
272 }
273 for join in &from.joins {
274 if table_ref_matches(&join.relation, cte_name) {
275 if is_nullable_join_side(&join.join_operator) {
277 return Err(SqlError::InvalidRecursiveSelfRef {
278 cte_name: cte_name.to_owned(),
279 reason: "self-reference on the nullable side of an outer join is not \
280 permitted; use INNER JOIN or move the CTE reference to the \
281 driving table position"
282 .into(),
283 });
284 }
285 count += 1;
286 }
287 }
288 }
289
290 if where_contains_subquery_ref(&select.selection, cte_name) {
292 return Err(SqlError::InvalidRecursiveSelfRef {
293 cte_name: cte_name.to_owned(),
294 reason: "self-reference inside a subquery is not permitted".into(),
295 });
296 }
297
298 if count > 1 {
299 return Err(SqlError::InvalidRecursiveSelfRef {
300 cte_name: cte_name.to_owned(),
301 reason: format!("self-reference appears {count} times; exactly one is required"),
302 });
303 }
304
305 Ok(())
307}
308
309fn table_ref_matches(factor: &ast::TableFactor, cte_name: &str) -> bool {
310 match factor {
311 ast::TableFactor::Table { name, .. } => normalize_object_name_checked(name)
312 .map(|n| n.eq_ignore_ascii_case(cte_name))
313 .unwrap_or(false),
314 _ => false,
315 }
316}
317
318fn is_nullable_join_side(op: &ast::JoinOperator) -> bool {
319 use ast::JoinOperator::*;
320 matches!(op, LeftOuter(_) | RightOuter(_) | FullOuter(_))
321}
322
323fn where_contains_subquery_ref(selection: &Option<ast::Expr>, cte_name: &str) -> bool {
324 match selection {
325 None => false,
326 Some(e) => expr_contains_subquery_ref(e, cte_name),
327 }
328}
329
330fn expr_contains_subquery_ref(expr: &ast::Expr, cte_name: &str) -> bool {
331 match expr {
332 ast::Expr::InSubquery { subquery, .. } | ast::Expr::Exists { subquery, .. } => {
333 query_references_cte(subquery, cte_name)
334 }
335 ast::Expr::Subquery(q) => query_references_cte(q, cte_name),
336 ast::Expr::BinaryOp { left, right, .. } => {
337 expr_contains_subquery_ref(left, cte_name)
338 || expr_contains_subquery_ref(right, cte_name)
339 }
340 ast::Expr::Nested(inner) => expr_contains_subquery_ref(inner, cte_name),
341 _ => false,
342 }
343}
344
345fn query_references_cte(query: &Query, cte_name: &str) -> bool {
346 match &*query.body {
347 SetExpr::Select(s) => s.from.iter().any(|f| {
348 table_ref_matches(&f.relation, cte_name)
349 || f.joins
350 .iter()
351 .any(|j| table_ref_matches(&j.relation, cte_name))
352 }),
353 _ => false,
354 }
355}
356
/// Output of `extract_recursive_info`: the recursive term's filters, plus an
/// optional equi-join link `(real_table_column, cte_column)`.
type RecursiveInfo = (Vec<Filter>, Option<(String, String)>);
366
/// Syntactically extracts `(recursive_filters, join_link)` from a recursive
/// term that could not be planned as an ordinary `SELECT`.
///
/// The `FROM` clause is walked to find the alias (or name) of the CTE
/// reference and of the other, "real" table. The `ON` condition is chosen as
/// follows:
/// * an `ON` on the join that attaches the CTE always wins;
/// * otherwise, the first `ON` found on a join attaching a real table is used.
///
/// `extract_equi_link` derives the `(real_column, cte_column)` link from that
/// condition. The whole `WHERE` clause, if present, becomes a single filter.
/// When several real tables appear, the last one seen supplies
/// `real_table_alias`.
///
/// # Errors
/// * `SqlError::Unsupported` if `expr` is not a plain `SELECT`.
/// * Any error from converting the `WHERE` expression.
fn extract_recursive_info(expr: &SetExpr, cte_name: &str) -> Result<RecursiveInfo> {
    let select = match expr {
        SetExpr::Select(s) => s,
        _ => {
            return Err(SqlError::Unsupported {
                detail: "recursive CTE branch must be SELECT".into(),
            });
        }
    };

    let mut real_table_alias = None;
    let mut cte_alias = None;
    let mut join_on_expr = None;

    for from in &select.from {
        let table_name = extract_table_name(&from.relation);
        let table_alias = extract_table_alias(&from.relation);

        // Driving (FROM-position) relation: it carries no ON condition of its own.
        if let Some(name) = &table_name {
            if name.eq_ignore_ascii_case(cte_name) {
                cte_alias = table_alias.or_else(|| Some(name.clone()));
            } else {
                real_table_alias = table_alias.or_else(|| Some(name.clone()));
            }
        }

        for join in &from.joins {
            let join_table = extract_table_name(&join.relation);
            let join_alias = extract_table_alias(&join.relation);
            if let Some(jt) = &join_table {
                if jt.eq_ignore_ascii_case(cte_name) {
                    cte_alias = join_alias.or_else(|| Some(jt.clone()));
                    // The CTE's own ON condition overrides any condition seen earlier.
                    if let Some(cond) = extract_join_on_condition(&join.join_operator) {
                        join_on_expr = Some(cond.clone());
                    }
                } else {
                    real_table_alias = join_alias.or_else(|| Some(jt.clone()));
                    // A real table's ON is only a fallback: keep the first one found.
                    if join_on_expr.is_none()
                        && let Some(cond) = extract_join_on_condition(&join.join_operator)
                    {
                        join_on_expr = Some(cond.clone());
                    }
                }
            }
        }
    }

    // A link needs both sides identified and an ON condition to inspect.
    let join_link = if let (Some(real_alias), Some(cte_al), Some(on_expr)) =
        (&real_table_alias, &cte_alias, &join_on_expr)
    {
        extract_equi_link(on_expr, real_alias, cte_al)
    } else {
        None
    };

    let mut filters = Vec::new();
    if let Some(where_expr) = &select.selection {
        let converted = crate::resolver::expr::convert_expr(where_expr)?;
        filters.push(Filter {
            expr: FilterExpr::Expr(converted),
        });
    }

    Ok((filters, join_link))
}
433
434fn extract_equi_link(
436 expr: &ast::Expr,
437 real_alias: &str,
438 cte_alias: &str,
439) -> Option<(String, String)> {
440 match expr {
441 ast::Expr::BinaryOp {
442 left,
443 op: ast::BinaryOperator::Eq,
444 right,
445 } => {
446 let left_parts = extract_qualified_column(left)?;
447 let right_parts = extract_qualified_column(right)?;
448
449 if left_parts.0.eq_ignore_ascii_case(real_alias)
450 && right_parts.0.eq_ignore_ascii_case(cte_alias)
451 {
452 Some((left_parts.1, right_parts.1))
453 } else if right_parts.0.eq_ignore_ascii_case(real_alias)
454 && left_parts.0.eq_ignore_ascii_case(cte_alias)
455 {
456 Some((right_parts.1, left_parts.1))
457 } else {
458 None
459 }
460 }
461 ast::Expr::BinaryOp {
462 left,
463 op: ast::BinaryOperator::And,
464 right,
465 } => extract_equi_link(left, real_alias, cte_alias)
466 .or_else(|| extract_equi_link(right, real_alias, cte_alias)),
467 _ => None,
468 }
469}
470
471fn extract_qualified_column(expr: &ast::Expr) -> Option<(String, String)> {
472 match expr {
473 ast::Expr::CompoundIdentifier(parts) if parts.len() == 2 => {
474 Some((normalize_ident(&parts[0]), normalize_ident(&parts[1])))
475 }
476 _ => None,
477 }
478}
479
480fn extract_table_name(relation: &ast::TableFactor) -> Option<String> {
481 match relation {
482 ast::TableFactor::Table { name, .. } => normalize_object_name_checked(name).ok(),
483 _ => None,
484 }
485}
486
487fn extract_table_alias(relation: &ast::TableFactor) -> Option<String> {
488 match relation {
489 ast::TableFactor::Table { alias, .. } => alias.as_ref().map(|a| normalize_ident(&a.name)),
490 _ => None,
491 }
492}
493
494fn extract_join_on_condition(op: &ast::JoinOperator) -> Option<&ast::Expr> {
495 use ast::JoinOperator::*;
496 let constraint = match op {
497 Inner(c) | LeftOuter(c) | RightOuter(c) | FullOuter(c) => c,
498 _ => return None,
499 };
500 match constraint {
501 ast::JoinConstraint::On(expr) => Some(expr),
502 _ => None,
503 }
504}
505
506fn plan_cte_branch(
507 expr: &SetExpr,
508 catalog: &dyn SqlCatalog,
509 functions: &FunctionRegistry,
510 temporal: crate::TemporalScope,
511) -> Result<SqlPlan> {
512 match expr {
513 SetExpr::Select(select) => {
514 let query = Query {
515 with: None,
516 body: Box::new(SetExpr::Select(select.clone())),
517 order_by: None,
518 limit_clause: None,
519 fetch: None,
520 locks: Vec::new(),
521 for_clause: None,
522 settings: None,
523 format_clause: None,
524 pipe_operators: Vec::new(),
525 };
526 super::select::plan_query(&query, catalog, functions, temporal)
527 }
528 _ => Err(SqlError::Unsupported {
529 detail: "CTE branch must be SELECT".into(),
530 }),
531 }
532}
533
534fn extract_collection(plan: &SqlPlan) -> String {
535 match plan {
536 SqlPlan::Scan { collection, .. } => collection.clone(),
537 _ => String::new(),
538 }
539}
540
541fn extract_filters(plan: &SqlPlan) -> Vec<Filter> {
542 match plan {
543 SqlPlan::Scan { filters, .. } => filters.clone(),
544 _ => Vec::new(),
545 }
546}