use arrow::record_batch::RecordBatch;
use datafusion::sql::sqlparser::ast::{Query, SetExpr, SetOperator, SetQuantifier, Statement};
use datafusion::sql::sqlparser::dialect::GenericDialect;
use datafusion::sql::sqlparser::parser::Parser;
use krishiv_plan::NodeOp;
use crate::{SqlError, SqlResult};
/// Iteration cap assigned to [`RecursiveCteStatement::max_iterations`] for
/// statements produced by the parser in this file.
pub const DEFAULT_MAX_ITERATIONS: u32 = 100;
/// A `WITH RECURSIVE` statement split into the pieces needed to evaluate it
/// iteratively.
#[derive(Debug, Clone)]
pub struct RecursiveCteStatement {
/// Name (alias) of the CTE; the executor registers intermediate batches
/// under this name.
pub name: String,
/// SQL text of the non-recursive (left-hand) branch of the `UNION ALL`.
pub base_query: String,
/// SQL text of the recursive (right-hand) branch of the `UNION ALL`.
pub recursive_query: String,
/// Upper bound on the number of recursive iterations the executor will run.
pub max_iterations: u32,
}
/// Parses `sql` as a `WITH RECURSIVE` statement.
///
/// Leading/trailing whitespace and trailing semicolons are ignored.
///
/// # Returns
///
/// * `Ok(None)` when the text does not start with the `WITH RECURSIVE`
///   keywords, or when the parsed statement turns out not to be a recursive
///   CTE query.
/// * `Ok(Some(_))` with the CTE name and its base / recursive branches
///   otherwise; `max_iterations` is set to [`DEFAULT_MAX_ITERATIONS`].
///
/// # Errors
///
/// Returns [`SqlError::Unsupported`] when the SQL fails to parse, yields no
/// statement, or the CTE body is not of the form
/// `base_query UNION ALL recursive_query`.
pub fn parse_recursive_cte(sql: &str) -> SqlResult<Option<RecursiveCteStatement>> {
    let trimmed = sql.trim().trim_end_matches(';');

    // Cheap keyword pre-filter. Tokens are compared instead of a fixed
    // "WITH RECURSIVE" prefix so that any amount/kind of whitespace between
    // the two keywords (newlines, tabs, several spaces) is accepted, and no
    // upper-cased copy of the whole statement is allocated.
    let mut words = trimmed.split_whitespace();
    let has_with = matches!(words.next(), Some(w) if w.eq_ignore_ascii_case("WITH"));
    // Only the first nine bytes of the second token are inspected so that
    // text glued to the keyword (which the parser will judge) still passes
    // the pre-filter, exactly as a prefix test would allow.
    let has_recursive = has_with
        && matches!(
            words.next().and_then(|w| w.get(.."RECURSIVE".len())),
            Some(prefix) if prefix.eq_ignore_ascii_case("RECURSIVE")
        );
    if !has_recursive {
        return Ok(None);
    }

    let dialect = GenericDialect {};
    let stmts = Parser::parse_sql(&dialect, trimmed).map_err(|e| SqlError::Unsupported {
        feature: format!("WITH RECURSIVE parse error: {e}"),
    })?;

    // Only the first parsed statement is considered.
    let stmt = stmts
        .into_iter()
        .next()
        .ok_or_else(|| SqlError::Unsupported {
            feature: "WITH RECURSIVE produced no statement".into(),
        })?;

    extract_recursive_cte(stmt)
}
/// Pulls the recursive CTE out of an already-parsed statement.
///
/// Yields `Ok(None)` for anything that is not a query carrying a
/// `WITH RECURSIVE` clause. Only the first CTE of the clause is inspected;
/// its body must be `base_query UNION ALL recursive_query`, otherwise an
/// [`SqlError::Unsupported`] is returned.
fn extract_recursive_cte(stmt: Statement) -> SqlResult<Option<RecursiveCteStatement>> {
    let query = match stmt {
        Statement::Query(query) => query,
        _ => return Ok(None),
    };

    // A missing WITH clause and a non-recursive WITH clause are treated alike.
    let with_clause = match query.with.as_ref() {
        Some(clause) if clause.recursive => clause,
        _ => return Ok(None),
    };

    let Some(first_cte) = with_clause.cte_tables.first() else {
        return Err(SqlError::Unsupported {
            feature: "WITH RECURSIVE requires at least one CTE".into(),
        });
    };

    let name = first_cte.alias.name.value.clone();
    match split_union_all(&first_cte.query) {
        Some((base_query, recursive_query)) => Ok(Some(RecursiveCteStatement {
            name,
            base_query,
            recursive_query,
            max_iterations: DEFAULT_MAX_ITERATIONS,
        })),
        None => Err(SqlError::Unsupported {
            feature: format!(
                "WITH RECURSIVE '{name}': body must be `base_query UNION ALL recursive_query`"
            ),
        }),
    }
}
fn split_union_all(query: &Query) -> Option<(String, String)> {
match query.body.as_ref() {
SetExpr::SetOperation {
op: SetOperator::Union,
set_quantifier: SetQuantifier::All,
left,
right,
} => {
let left_sql = format!("SELECT * FROM ({left})");
let right_sql = format!("SELECT * FROM ({right})");
Some((left_sql, right_sql))
}
_ => None,
}
}
pub fn build_recursive_cte_op(stmt: &RecursiveCteStatement) -> NodeOp {
NodeOp::RecursiveCte {
name: stmt.name.clone(),
base_query: stmt.base_query.clone(),
recursive_query: stmt.recursive_query.clone(),
max_iterations: stmt.max_iterations,
}
}
/// Outcome of evaluating a recursive CTE with `execute_recursive_cte`.
#[derive(Debug)]
pub struct RecursiveCteResult {
/// Batches of the base query followed by the batches produced by each
/// recursive iteration, in order.
pub batches: Vec<RecordBatch>,
/// Number of recursive iterations that produced at least one row.
pub iterations: u32,
/// `true` when evaluation stopped because `iterations` reached the
/// statement's `max_iterations` rather than because an iteration produced
/// no rows.
pub hit_limit: bool,
}
/// Evaluates a recursive CTE by iterating to a fixpoint.
///
/// The base query is executed once. Then, repeatedly, the rows produced by the
/// previous step (the *working table*) are registered under `stmt.name` and
/// the recursive query is executed against them. Every non-empty result is
/// appended to the output and becomes the next working table. Iteration stops
/// when the recursive query yields no rows or when `stmt.max_iterations`
/// iterations have run.
///
/// Only the previous step's rows are exposed to the recursive term. Exposing
/// the whole accumulated result would make each iteration re-derive rows
/// already produced, duplicating output under `UNION ALL` and preventing
/// termination for queries whose predicate keeps matching early rows.
///
/// Before returning successfully, the complete result is registered under
/// `stmt.name` so that the name resolves to the full CTE output.
///
/// # Parameters
///
/// * `stmt` – the parsed recursive CTE.
/// * `execute_fn` – runs a SQL string and returns its result batches.
/// * `register_batches_fn` – makes the given batches visible to
///   `execute_fn` under the given table name, replacing earlier contents.
///
/// # Errors
///
/// Propagates any error from `execute_fn` / `register_batches_fn`, and returns
/// [`SqlError::Unsupported`] when the accumulated row count reaches the
/// internal limit of 10,000,000 rows before an iteration starts.
pub fn execute_recursive_cte<E, R>(
    stmt: &RecursiveCteStatement,
    mut execute_fn: E,
    mut register_batches_fn: R,
) -> SqlResult<RecursiveCteResult>
where
    E: FnMut(&str) -> SqlResult<Vec<RecordBatch>>,
    R: FnMut(&str, &[RecordBatch]) -> SqlResult<()>,
{
    const MAX_ACCUMULATED_ROWS: usize = 10_000_000;

    /// Total number of rows across `batches`.
    fn count_rows(batches: &[RecordBatch]) -> usize {
        batches.iter().map(|batch| batch.num_rows()).sum()
    }

    let mut accumulator = execute_fn(&stmt.base_query)?;
    // Rows visible to the recursive term in the next iteration. Cloning a
    // `RecordBatch` only bumps reference counts on its column buffers.
    let mut working_table = accumulator.clone();
    // Running total, so the limit check does not re-scan every batch.
    let mut acc_rows = count_rows(&accumulator);
    let mut iterations = 0u32;
    let mut hit_limit = false;

    loop {
        if iterations >= stmt.max_iterations {
            hit_limit = true;
            break;
        }
        if acc_rows >= MAX_ACCUMULATED_ROWS {
            return Err(SqlError::Unsupported {
                feature: format!(
                    "WITH RECURSIVE: accumulated row count ({acc_rows}) exceeded limit of {MAX_ACCUMULATED_ROWS}"
                ),
            });
        }

        register_batches_fn(&stmt.name, &working_table)?;
        let delta = execute_fn(&stmt.recursive_query)?;
        let delta_rows = count_rows(&delta);
        if delta_rows == 0 {
            // Fixpoint: the recursive term produced nothing new.
            break;
        }

        accumulator.extend(delta.iter().cloned());
        acc_rows += delta_rows;
        working_table = delta;
        iterations += 1;
    }

    // Leave the CTE name bound to the complete result rather than to the last
    // working table, so follow-up queries against the name see every row.
    register_batches_fn(&stmt.name, &accumulator)?;

    Ok(RecursiveCteResult {
        batches: accumulator,
        iterations,
        hit_limit,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Int32Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Builds a single-column (`n: Int32`, non-nullable) batch from `values`.
    fn int32_batch(values: Vec<i32>) -> SqlResult<RecordBatch> {
        let schema = Arc::new(Schema::new(vec![Field::new("n", DataType::Int32, false)]));
        RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(values))]).map_err(|e| {
            SqlError::Unsupported {
                feature: e.to_string(),
            }
        })
    }

    #[test]
    fn parses_with_recursive_union_all() {
        let sql = "\
WITH RECURSIVE cte AS (\
SELECT 1 AS n \
UNION ALL \
SELECT n + 1 FROM cte WHERE n < 5\
) SELECT * FROM cte";
        let stmt = parse_recursive_cte(sql)
            .unwrap()
            .expect("WITH RECURSIVE ... UNION ALL should be recognised");
        assert_eq!(stmt.name, "cte");
        assert!(stmt.base_query.contains("SELECT 1"));
        assert!(stmt.recursive_query.to_ascii_uppercase().contains("CTE"));
        assert_eq!(stmt.max_iterations, DEFAULT_MAX_ITERATIONS);
    }

    #[test]
    fn returns_none_for_non_recursive_cte() {
        let sql = "WITH t AS (SELECT 1) SELECT * FROM t";
        let result = parse_recursive_cte(sql).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn returns_none_for_plain_select() {
        let sql = "SELECT * FROM t WHERE x = 1";
        let result = parse_recursive_cte(sql).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn rejects_non_union_all_body() {
        // A plain `UNION` (without `ALL`) is not a supported recursive body and
        // must surface as an error rather than being silently accepted.
        let sql = "\
WITH RECURSIVE cte AS (\
SELECT 1 AS n \
UNION \
SELECT n + 1 FROM cte\
) SELECT * FROM cte";
        let result = parse_recursive_cte(sql);
        assert!(
            matches!(result, Err(SqlError::Unsupported { .. })),
            "plain UNION body must be rejected as unsupported"
        );
    }

    #[test]
    fn build_recursive_cte_op_returns_correct_variant() {
        let stmt = RecursiveCteStatement {
            name: "tree".into(),
            base_query: "SELECT id FROM nodes WHERE parent_id IS NULL".into(),
            recursive_query: "SELECT n.id FROM nodes n JOIN tree t ON n.parent_id = t.id".into(),
            max_iterations: 50,
        };
        let op = build_recursive_cte_op(&stmt);
        match op {
            NodeOp::RecursiveCte {
                name,
                max_iterations,
                ..
            } => {
                assert_eq!(name, "tree");
                assert_eq!(max_iterations, 50);
            }
            _ => panic!("expected RecursiveCte"),
        }
    }

    #[test]
    fn iterative_executor_stops_at_fixpoint() {
        let stmt = RecursiveCteStatement {
            name: "cte".into(),
            base_query: "SELECT 1 AS n".into(),
            recursive_query: "SELECT n + 1 FROM cte WHERE n < 3".into(),
            max_iterations: DEFAULT_MAX_ITERATIONS,
        };

        // Mock engine: call 1 is the base query -> [1]; calls 2 and 3 are
        // recursive steps -> [2], [3]; every later call yields no rows.
        let mut call_count = 0u32;
        let execute = |sql: &str| -> SqlResult<Vec<RecordBatch>> {
            call_count += 1;
            let values: Vec<i32> = if sql.contains("SELECT 1") {
                vec![1]
            } else {
                match call_count {
                    2 => vec![2],
                    3 => vec![3],
                    _ => vec![],
                }
            };
            if values.is_empty() {
                return Ok(vec![]);
            }
            Ok(vec![int32_batch(values)?])
        };
        let register = |_name: &str, _batches: &[RecordBatch]| -> SqlResult<()> { Ok(()) };

        let result = execute_recursive_cte(&stmt, execute, register).unwrap();
        assert!(!result.hit_limit);
        // Two productive recursive steps, then an empty delta ends the loop.
        assert_eq!(result.iterations, 2);
        let total_rows: usize = result.batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 3);
    }

    #[test]
    fn iterative_executor_respects_max_iterations() {
        let stmt = RecursiveCteStatement {
            name: "inf".into(),
            base_query: "SELECT 0 AS n".into(),
            recursive_query: "SELECT n + 1 FROM inf".into(),
            max_iterations: 5,
        };

        // Mock engine that never reaches a fixpoint: every call yields a row.
        let execute =
            |_sql: &str| -> SqlResult<Vec<RecordBatch>> { Ok(vec![int32_batch(vec![42i32])?]) };
        let register = |_: &str, _: &[RecordBatch]| -> SqlResult<()> { Ok(()) };

        let result = execute_recursive_cte(&stmt, execute, register).unwrap();
        assert!(result.hit_limit, "should have hit max_iterations");
        assert_eq!(result.iterations, 5);
    }
}