1use arrow::record_batch::RecordBatch;
26use datafusion::sql::sqlparser::ast::{Query, SetExpr, SetOperator, SetQuantifier, Statement};
27use datafusion::sql::sqlparser::dialect::GenericDialect;
28use datafusion::sql::sqlparser::parser::Parser;
29
30use krishiv_plan::NodeOp;
31
32use crate::{SqlError, SqlResult};
33
/// Number of recursive steps allowed by default for a parsed `WITH RECURSIVE`
/// CTE; `parse_recursive_cte` stores this in `RecursiveCteStatement::max_iterations`.
pub const DEFAULT_MAX_ITERATIONS: u32 = 100;
36
/// A `WITH RECURSIVE` CTE split into its non-recursive seed branch and its
/// recursive step branch.
#[derive(Clone, Debug)]
pub struct RecursiveCteStatement {
    /// CTE name; intermediate results are registered under this table name so
    /// the recursive branch can read them.
    pub name: String,
    /// SQL for the seed (left) branch of the `UNION ALL`.
    pub base_query: String,
    /// SQL for the recursive (right) branch of the `UNION ALL`, which
    /// references `name`.
    pub recursive_query: String,
    /// Upper bound on the number of recursive steps that will be executed.
    pub max_iterations: u32,
}
51
52pub fn parse_recursive_cte(sql: &str) -> SqlResult<Option<RecursiveCteStatement>> {
58 let trimmed = sql.trim().trim_end_matches(';');
59 let upper = trimmed.to_ascii_uppercase();
60
61 if !upper.starts_with("WITH RECURSIVE") {
62 return Ok(None);
63 }
64
65 let dialect = GenericDialect {};
66 let stmts = Parser::parse_sql(&dialect, trimmed).map_err(|e| SqlError::Unsupported {
67 feature: format!("WITH RECURSIVE parse error: {e}"),
68 })?;
69
70 let stmt = stmts
71 .into_iter()
72 .next()
73 .ok_or_else(|| SqlError::Unsupported {
74 feature: "WITH RECURSIVE produced no statement".into(),
75 })?;
76
77 extract_recursive_cte(stmt)
78}
79
80fn extract_recursive_cte(stmt: Statement) -> SqlResult<Option<RecursiveCteStatement>> {
81 let Statement::Query(q) = stmt else {
82 return Ok(None);
83 };
84 let Some(with) = &q.with else {
85 return Ok(None);
86 };
87 if !with.recursive {
88 return Ok(None);
89 }
90
91 let cte = with
92 .cte_tables
93 .first()
94 .ok_or_else(|| SqlError::Unsupported {
95 feature: "WITH RECURSIVE requires at least one CTE".into(),
96 })?;
97
98 let name = cte.alias.name.value.clone();
99
100 let (base_query, recursive_query) =
101 split_union_all(&cte.query).ok_or_else(|| SqlError::Unsupported {
102 feature: format!(
103 "WITH RECURSIVE '{name}': body must be `base_query UNION ALL recursive_query`"
104 ),
105 })?;
106
107 Ok(Some(RecursiveCteStatement {
108 name,
109 base_query,
110 recursive_query,
111 max_iterations: DEFAULT_MAX_ITERATIONS,
112 }))
113}
114
115fn split_union_all(query: &Query) -> Option<(String, String)> {
117 match query.body.as_ref() {
118 SetExpr::SetOperation {
119 op: SetOperator::Union,
120 set_quantifier: SetQuantifier::All,
121 left,
122 right,
123 } => {
124 let left_sql = format!("SELECT * FROM ({left})");
125 let right_sql = format!("SELECT * FROM ({right})");
126 Some((left_sql, right_sql))
127 }
128 _ => None,
129 }
130}
131
132pub fn build_recursive_cte_op(stmt: &RecursiveCteStatement) -> NodeOp {
136 NodeOp::RecursiveCte {
137 name: stmt.name.clone(),
138 base_query: stmt.base_query.clone(),
139 recursive_query: stmt.recursive_query.clone(),
140 max_iterations: stmt.max_iterations,
141 }
142}
143
144#[derive(Debug)]
148pub struct RecursiveCteResult {
149 pub batches: Vec<RecordBatch>,
151 pub iterations: u32,
153 pub hit_limit: bool,
155}
156
157pub fn execute_recursive_cte<E, R>(
166 stmt: &RecursiveCteStatement,
167 mut execute_fn: E,
168 mut register_batches_fn: R,
169) -> SqlResult<RecursiveCteResult>
170where
171 E: FnMut(&str) -> SqlResult<Vec<RecordBatch>>,
172 R: FnMut(&str, &[RecordBatch]) -> SqlResult<()>,
173{
174 const MAX_ACCUMULATED_ROWS: usize = 10_000_000;
177
178 let base_batches = execute_fn(&stmt.base_query)?;
180 let mut accumulator = base_batches;
181
182 let mut iterations = 0u32;
183 let mut hit_limit = false;
184
185 loop {
186 if iterations >= stmt.max_iterations {
187 hit_limit = true;
188 break;
189 }
190
191 let acc_rows: usize = accumulator.iter().map(|b| b.num_rows()).sum();
192 if acc_rows >= MAX_ACCUMULATED_ROWS {
193 return Err(SqlError::Unsupported {
194 feature: format!(
195 "WITH RECURSIVE: accumulated row count ({acc_rows}) exceeded limit of {MAX_ACCUMULATED_ROWS}"
196 ),
197 });
198 }
199
200 register_batches_fn(&stmt.name, &accumulator)?;
202
203 let delta = execute_fn(&stmt.recursive_query)?;
204 let delta_rows: usize = delta.iter().map(|b| b.num_rows()).sum();
205
206 if delta_rows == 0 {
207 break; }
209
210 accumulator.extend(delta);
211 iterations += 1;
212 }
213
214 Ok(RecursiveCteResult {
215 batches: accumulator,
216 iterations,
217 hit_limit,
218 })
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::Arc;

    use arrow::array::Int32Array;
    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};

    /// Schema with a single non-nullable `n: Int32` column.
    fn n_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![Field::new("n", DataType::Int32, false)]))
    }

    /// One batch holding `values` under `schema`, or no batches if `values` is empty.
    fn int_batches(schema: &SchemaRef, values: Vec<i32>) -> SqlResult<Vec<RecordBatch>> {
        if values.is_empty() {
            return Ok(vec![]);
        }
        let batch = RecordBatch::try_new(
            Arc::clone(schema),
            vec![Arc::new(Int32Array::from(values))],
        )
        .map_err(|e| SqlError::Unsupported {
            feature: e.to_string(),
        })?;
        Ok(vec![batch])
    }

    /// Total number of rows across `batches`.
    fn total_rows(batches: &[RecordBatch]) -> usize {
        batches.iter().map(|b| b.num_rows()).sum()
    }

    #[test]
    fn parses_with_recursive_union_all() {
        let sql = "\
            WITH RECURSIVE cte AS (\
                SELECT 1 AS n \
                UNION ALL \
                SELECT n + 1 FROM cte WHERE n < 5\
            ) SELECT * FROM cte";
        let stmt = parse_recursive_cte(sql)
            .unwrap()
            .expect("recursive CTE should be detected");
        assert_eq!(stmt.name, "cte");
        assert!(stmt.base_query.contains("SELECT 1"));
        assert!(stmt.recursive_query.to_ascii_uppercase().contains("CTE"));
        assert_eq!(stmt.max_iterations, DEFAULT_MAX_ITERATIONS);
    }

    #[test]
    fn parses_lowercase_keywords_and_trailing_semicolon() {
        let sql = "with recursive cte as (select 1 as n union all \
                   select n + 1 from cte where n < 5) select * from cte;";
        let stmt = parse_recursive_cte(sql)
            .unwrap()
            .expect("recursive CTE should be detected");
        assert_eq!(stmt.name, "cte");
        assert!(stmt.recursive_query.to_ascii_uppercase().contains("CTE"));
    }

    #[test]
    fn returns_none_for_non_recursive_cte() {
        let sql = "WITH t AS (SELECT 1) SELECT * FROM t";
        assert!(parse_recursive_cte(sql).unwrap().is_none());
    }

    #[test]
    fn returns_none_for_plain_select() {
        let sql = "SELECT * FROM t WHERE x = 1";
        assert!(parse_recursive_cte(sql).unwrap().is_none());
    }

    #[test]
    fn rejects_non_union_all_body() {
        // Plain `UNION` (distinct) is not supported; only `UNION ALL` bodies are.
        let sql = "\
            WITH RECURSIVE cte AS (\
                SELECT 1 AS n \
                UNION \
                SELECT n + 1 FROM cte\
            ) SELECT * FROM cte";
        let result = parse_recursive_cte(sql);
        assert!(
            result.is_err(),
            "UNION (distinct) body should be rejected, got {result:?}"
        );
    }

    #[test]
    fn build_recursive_cte_op_returns_correct_variant() {
        let stmt = RecursiveCteStatement {
            name: "tree".into(),
            base_query: "SELECT id FROM nodes WHERE parent_id IS NULL".into(),
            recursive_query: "SELECT n.id FROM nodes n JOIN tree t ON n.parent_id = t.id".into(),
            max_iterations: 50,
        };
        match build_recursive_cte_op(&stmt) {
            NodeOp::RecursiveCte {
                name,
                max_iterations,
                ..
            } => {
                assert_eq!(name, "tree");
                assert_eq!(max_iterations, 50);
            }
            _ => panic!("expected RecursiveCte"),
        }
    }

    #[test]
    fn iterative_executor_stops_at_fixpoint() {
        let schema = n_schema();
        let stmt = RecursiveCteStatement {
            name: "cte".into(),
            base_query: "SELECT 1 AS n".into(),
            recursive_query: "SELECT n + 1 FROM cte WHERE n < 3".into(),
            max_iterations: DEFAULT_MAX_ITERATIONS,
        };

        // Call 1 is the base query; calls 2 and 3 each yield one new row;
        // call 4 yields nothing, which is the fixpoint.
        let mut call_count = 0u32;
        let execute = |sql: &str| -> SqlResult<Vec<RecordBatch>> {
            call_count += 1;
            let values = if sql.contains("SELECT 1") {
                vec![1]
            } else {
                match call_count {
                    2 => vec![2],
                    3 => vec![3],
                    _ => vec![],
                }
            };
            int_batches(&schema, values)
        };
        let register = |_name: &str, _batches: &[RecordBatch]| -> SqlResult<()> { Ok(()) };

        let result = execute_recursive_cte(&stmt, execute, register).unwrap();
        assert!(!result.hit_limit);
        assert_eq!(result.iterations, 2);
        assert_eq!(total_rows(&result.batches), 3);
    }

    #[test]
    fn iterative_executor_respects_max_iterations() {
        let schema = n_schema();
        let stmt = RecursiveCteStatement {
            name: "inf".into(),
            base_query: "SELECT 0 AS n".into(),
            recursive_query: "SELECT n + 1 FROM inf".into(),
            max_iterations: 5,
        };

        // Every call yields one row, so the recursion never reaches a fixpoint.
        let execute = |_sql: &str| int_batches(&schema, vec![42]);
        let register = |_: &str, _: &[RecordBatch]| -> SqlResult<()> { Ok(()) };

        let result = execute_recursive_cte(&stmt, execute, register).unwrap();
        assert!(result.hit_limit, "should have hit max_iterations");
        assert_eq!(result.iterations, 5);
        // One base row plus one row per recursive step.
        assert_eq!(total_rows(&result.batches), 6);
    }
}