Skip to main content

krishiv_sql/
recursive_cte.rs

1//! E5.3 — Recursive CTE: iterative fixpoint execution.
2//!
3//! DataFusion does not support `WITH RECURSIVE` natively. This module adds:
4//!
5//! 1. **Detection**: parse `WITH RECURSIVE name AS (base UNION ALL recursive)`.
6//! 2. **Rewriter**: expand a recursive CTE into a `NodeOp::RecursiveCte` plan node.
7//! 3. **Iterative executor**: given a `SqlEngine`, execute base + recursive rounds
8//!    until fixpoint or `max_iterations`.
9//!
10//! # Execution model
11//!
12//! ```text
13//! accumulator = execute(base_query)
14//! for i in 0..max_iterations:
15//!     delta = execute(recursive_query with cte_name = accumulator)
16//!     if delta is empty: break (fixpoint)
17//!     accumulator = accumulator UNION ALL delta
18//! return accumulator
19//! ```
20//!
21//! Each iteration materialises `delta` fully before the next starts. This is
22//! the "naïve" fixpoint strategy, suitable for transitive-closure and
23//! tree-traversal queries on bounded datasets.
24
25use arrow::record_batch::RecordBatch;
26use datafusion::sql::sqlparser::ast::{Query, SetExpr, SetOperator, SetQuantifier, Statement};
27use datafusion::sql::sqlparser::dialect::GenericDialect;
28use datafusion::sql::sqlparser::parser::Parser;
29
30use krishiv_plan::NodeOp;
31
32use crate::{SqlError, SqlResult};
33
/// Default upper bound on the number of recursive rounds for `WITH RECURSIVE`.
///
/// `extract_recursive_cte` stores this value in
/// `RecursiveCteStatement::max_iterations` for every statement it builds.
pub const DEFAULT_MAX_ITERATIONS: u32 = 100;
36
37// ── Detection ─────────────────────────────────────────────────────────────────
38
/// A parsed `WITH RECURSIVE` statement ready for iterative execution.
///
/// The CTE body is stored as two standalone SQL strings (seed and recursive
/// branch). The outer query that follows the `WITH` clause is not captured
/// in this struct.
#[derive(Debug, Clone)]
pub struct RecursiveCteStatement {
    /// The CTE name. The recursive branch refers to it, and the executor
    /// passes it to its registration callback.
    pub name: String,
    /// SQL text for the non-recursive seed query (the left operand of the
    /// `UNION ALL`, wrapped as `SELECT * FROM (...)` by the parser).
    pub base_query: String,
    /// SQL text for the recursive branch (the right operand of the
    /// `UNION ALL`; references `name`).
    pub recursive_query: String,
    /// Hard upper bound on the number of recursive rounds.
    pub max_iterations: u32,
}
51
52/// Attempt to parse `sql` as a `WITH RECURSIVE` statement.
53///
54/// Returns `Ok(Some(...))` when the SQL starts with `WITH RECURSIVE`.
55/// Returns `Ok(None)` for any other SQL (not a recursive CTE).
56/// Returns `Err` when the SQL is syntactically invalid.
57pub fn parse_recursive_cte(sql: &str) -> SqlResult<Option<RecursiveCteStatement>> {
58    let trimmed = sql.trim().trim_end_matches(';');
59    let upper = trimmed.to_ascii_uppercase();
60
61    if !upper.starts_with("WITH RECURSIVE") {
62        return Ok(None);
63    }
64
65    let dialect = GenericDialect {};
66    let stmts = Parser::parse_sql(&dialect, trimmed).map_err(|e| SqlError::Unsupported {
67        feature: format!("WITH RECURSIVE parse error: {e}"),
68    })?;
69
70    let stmt = stmts
71        .into_iter()
72        .next()
73        .ok_or_else(|| SqlError::Unsupported {
74            feature: "WITH RECURSIVE produced no statement".into(),
75        })?;
76
77    extract_recursive_cte(stmt)
78}
79
80fn extract_recursive_cte(stmt: Statement) -> SqlResult<Option<RecursiveCteStatement>> {
81    let Statement::Query(q) = stmt else {
82        return Ok(None);
83    };
84    let Some(with) = &q.with else {
85        return Ok(None);
86    };
87    if !with.recursive {
88        return Ok(None);
89    }
90
91    let cte = with
92        .cte_tables
93        .first()
94        .ok_or_else(|| SqlError::Unsupported {
95            feature: "WITH RECURSIVE requires at least one CTE".into(),
96        })?;
97
98    let name = cte.alias.name.value.clone();
99
100    let (base_query, recursive_query) =
101        split_union_all(&cte.query).ok_or_else(|| SqlError::Unsupported {
102            feature: format!(
103                "WITH RECURSIVE '{name}': body must be `base_query UNION ALL recursive_query`"
104            ),
105        })?;
106
107    Ok(Some(RecursiveCteStatement {
108        name,
109        base_query,
110        recursive_query,
111        max_iterations: DEFAULT_MAX_ITERATIONS,
112    }))
113}
114
115/// Split a `SetExpr` that is `left UNION ALL right` into `(left_sql, right_sql)`.
116fn split_union_all(query: &Query) -> Option<(String, String)> {
117    match query.body.as_ref() {
118        SetExpr::SetOperation {
119            op: SetOperator::Union,
120            set_quantifier: SetQuantifier::All,
121            left,
122            right,
123        } => {
124            let left_sql = format!("SELECT * FROM ({left})");
125            let right_sql = format!("SELECT * FROM ({right})");
126            Some((left_sql, right_sql))
127        }
128        _ => None,
129    }
130}
131
132// ── NodeOp builder ────────────────────────────────────────────────────────────
133
134/// Build a `NodeOp::RecursiveCte` from a parsed `RecursiveCteStatement`.
135pub fn build_recursive_cte_op(stmt: &RecursiveCteStatement) -> NodeOp {
136    NodeOp::RecursiveCte {
137        name: stmt.name.clone(),
138        base_query: stmt.base_query.clone(),
139        recursive_query: stmt.recursive_query.clone(),
140        max_iterations: stmt.max_iterations,
141    }
142}
143
144// ── Iterative executor ────────────────────────────────────────────────────────
145
/// Result of a recursive CTE execution, as returned by
/// `execute_recursive_cte`.
#[derive(Debug)]
pub struct RecursiveCteResult {
    /// Collected batches from all iterations: the seed query's batches
    /// first, followed by the batches of each recursive round in order.
    pub batches: Vec<RecordBatch>,
    /// Number of recursive rounds that produced rows (0 = only base ran).
    pub iterations: u32,
    /// `true` when execution stopped because `max_iterations` was reached.
    /// The limit is checked before a round runs, so this can be `true` even
    /// if one more round would have produced no rows.
    pub hit_limit: bool,
}
156
157/// Execute a recursive CTE using a `SqlEngine`-like executor callback.
158///
159/// `execute_fn` is called with a SQL string and the name of the current
160/// "working table" (a registered view containing the current accumulator rows).
161/// It must return the resulting batches or an error.
162///
163/// `register_batches_fn` is called to register each iteration's accumulator as
164/// a temporary view under `cte_name` so the recursive branch can reference it.
165pub fn execute_recursive_cte<E, R>(
166    stmt: &RecursiveCteStatement,
167    mut execute_fn: E,
168    mut register_batches_fn: R,
169) -> SqlResult<RecursiveCteResult>
170where
171    E: FnMut(&str) -> SqlResult<Vec<RecordBatch>>,
172    R: FnMut(&str, &[RecordBatch]) -> SqlResult<()>,
173{
174    // Hard row cap to prevent divergent recursive CTEs from consuming unbounded
175    // memory while appearing to respect max_iterations.
176    const MAX_ACCUMULATED_ROWS: usize = 10_000_000;
177
178    // Seed: execute the base query.
179    let base_batches = execute_fn(&stmt.base_query)?;
180    let mut accumulator = base_batches;
181
182    let mut iterations = 0u32;
183    let mut hit_limit = false;
184
185    loop {
186        if iterations >= stmt.max_iterations {
187            hit_limit = true;
188            break;
189        }
190
191        let acc_rows: usize = accumulator.iter().map(|b| b.num_rows()).sum();
192        if acc_rows >= MAX_ACCUMULATED_ROWS {
193            return Err(SqlError::Unsupported {
194                feature: format!(
195                    "WITH RECURSIVE: accumulated row count ({acc_rows}) exceeded limit of {MAX_ACCUMULATED_ROWS}"
196                ),
197            });
198        }
199
200        // Register the current accumulator so the recursive branch can reference it.
201        register_batches_fn(&stmt.name, &accumulator)?;
202
203        let delta = execute_fn(&stmt.recursive_query)?;
204        let delta_rows: usize = delta.iter().map(|b| b.num_rows()).sum();
205
206        if delta_rows == 0 {
207            break; // fixpoint reached
208        }
209
210        accumulator.extend(delta);
211        iterations += 1;
212    }
213
214    Ok(RecursiveCteResult {
215        batches: accumulator,
216        iterations,
217        hit_limit,
218    })
219}
220
221// ── Tests ─────────────────────────────────────────────────────────────────────
222
#[cfg(test)]
mod tests {
    use super::*;

    use arrow::array::Int32Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Build a batch with a single non-nullable `n: Int32` column holding `values`.
    fn int32_batch(values: Vec<i32>) -> SqlResult<RecordBatch> {
        let schema = Arc::new(Schema::new(vec![Field::new("n", DataType::Int32, false)]));
        RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(values))]).map_err(|e| {
            SqlError::Unsupported {
                feature: e.to_string(),
            }
        })
    }

    /// Total number of rows across `batches`.
    fn total_rows(batches: &[RecordBatch]) -> usize {
        batches.iter().map(|b| b.num_rows()).sum()
    }

    #[test]
    fn parses_with_recursive_union_all() {
        let sql = "\
            WITH RECURSIVE cte AS (\
              SELECT 1 AS n \
              UNION ALL \
              SELECT n + 1 FROM cte WHERE n < 5\
            ) SELECT * FROM cte";
        let stmt = parse_recursive_cte(sql)
            .unwrap()
            .expect("WITH RECURSIVE ... UNION ALL should be recognised");
        assert_eq!(stmt.name, "cte");
        assert!(stmt.base_query.contains("SELECT 1"));
        assert!(stmt.recursive_query.to_ascii_uppercase().contains("CTE"));
        assert_eq!(stmt.max_iterations, DEFAULT_MAX_ITERATIONS);
    }

    #[test]
    fn returns_none_for_non_recursive_cte() {
        let sql = "WITH t AS (SELECT 1) SELECT * FROM t";
        let result = parse_recursive_cte(sql).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn returns_none_for_plain_select() {
        let sql = "SELECT * FROM t WHERE x = 1";
        let result = parse_recursive_cte(sql).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn rejects_non_union_all_body() {
        // Plain UNION (no ALL) carries a different set quantifier in the AST,
        // so `split_union_all` does not match it and the statement must be
        // rejected rather than executed with UNION ALL semantics.
        let sql = "\
            WITH RECURSIVE cte AS (\
              SELECT 1 AS n \
              UNION \
              SELECT n + 1 FROM cte\
            ) SELECT * FROM cte";
        let result = parse_recursive_cte(sql);
        assert!(
            matches!(result, Err(SqlError::Unsupported { .. })),
            "plain UNION body must be rejected, got {result:?}"
        );
    }

    #[test]
    fn build_recursive_cte_op_returns_correct_variant() {
        let stmt = RecursiveCteStatement {
            name: "tree".into(),
            base_query: "SELECT id FROM nodes WHERE parent_id IS NULL".into(),
            recursive_query: "SELECT n.id FROM nodes n JOIN tree t ON n.parent_id = t.id".into(),
            max_iterations: 50,
        };
        let op = build_recursive_cte_op(&stmt);
        match op {
            NodeOp::RecursiveCte {
                name,
                base_query,
                recursive_query,
                max_iterations,
            } => {
                assert_eq!(name, "tree");
                assert_eq!(base_query, stmt.base_query);
                assert_eq!(recursive_query, stmt.recursive_query);
                assert_eq!(max_iterations, 50);
            }
            _ => panic!("expected RecursiveCte"),
        }
    }

    #[test]
    fn iterative_executor_stops_at_fixpoint() {
        let stmt = RecursiveCteStatement {
            name: "cte".into(),
            base_query: "SELECT 1 AS n".into(),
            recursive_query: "SELECT n + 1 FROM cte WHERE n < 3".into(),
            max_iterations: DEFAULT_MAX_ITERATIONS,
        };

        // Simulate execution: base returns [{n:1}], then recursive returns
        // [{n:2}], [{n:3}], then empty (fixpoint).
        let mut call_count = 0u32;
        let execute = |sql: &str| -> SqlResult<Vec<RecordBatch>> {
            call_count += 1;
            let values: Vec<i32> = if sql.contains("SELECT 1") {
                vec![1]
            } else {
                // Recursive call: simulate returning empty after 2 rounds.
                match call_count {
                    2 => vec![2],
                    3 => vec![3],
                    _ => vec![],
                }
            };
            if values.is_empty() {
                return Ok(vec![]);
            }
            Ok(vec![int32_batch(values)?])
        };

        let register = |_name: &str, _batches: &[RecordBatch]| -> SqlResult<()> { Ok(()) };

        let result = execute_recursive_cte(&stmt, execute, register).unwrap();
        assert!(!result.hit_limit);
        // Two productive rounds, then one empty round that ends the loop.
        assert_eq!(result.iterations, 2);
        assert_eq!(total_rows(&result.batches), 3);
        assert_eq!(call_count, 4);
    }

    #[test]
    fn iterative_executor_respects_max_iterations() {
        let stmt = RecursiveCteStatement {
            name: "inf".into(),
            base_query: "SELECT 0 AS n".into(),
            recursive_query: "SELECT n + 1 FROM inf".into(),
            max_iterations: 5,
        };

        // Every call yields one row, so the recursion never reaches a fixpoint.
        let execute =
            |_sql: &str| -> SqlResult<Vec<RecordBatch>> { Ok(vec![int32_batch(vec![42i32])?]) };

        let register = |_: &str, _: &[RecordBatch]| -> SqlResult<()> { Ok(()) };

        let result = execute_recursive_cte(&stmt, execute, register).unwrap();
        assert!(result.hit_limit, "should have hit max_iterations");
        assert_eq!(result.iterations, 5);
        // One seed row plus one row per recursive round.
        assert_eq!(total_rows(&result.batches), 6);
    }
}