use super::base::parse_identifier;
use crate::ast::*;
use nom::{
IResult, Parser,
bytes::complete::tag_no_case,
character::complete::{char, multispace0, multispace1},
combinator::{map, opt},
multi::separated_list1,
sequence::{delimited, preceded},
};
/// Parses a `WITH [RECURSIVE] cte [, cte ...]` clause.
///
/// Keywords are matched case-insensitively. At least one whitespace character
/// is required after `WITH`, and after `RECURSIVE` when that keyword is
/// present. CTE definitions are separated by commas with optional surrounding
/// whitespace, and at least one definition is required.
///
/// Returns the remaining input together with the parsed definitions and a
/// flag telling whether `RECURSIVE` was present. The flag is forwarded to
/// every definition in the clause.
pub fn parse_with_clause(input: &str) -> IResult<&str, (Vec<CTEDef>, bool)> {
    // Leading keyword plus the mandatory whitespace that follows it.
    let (rest, _) = tag_no_case("with").parse(input)?;
    let (rest, _) = multispace1(rest)?;

    // `opt` yields `Some(..)` only when `recursive` AND trailing whitespace
    // both matched; otherwise nothing is consumed.
    let (rest, recursive_kw) =
        opt(preceded(tag_no_case("recursive"), multispace1)).parse(rest)?;
    let is_recursive = recursive_kw.is_some();

    let (rest, definitions) = separated_list1(
        (multispace0, char(','), multispace0),
        |cte_input| parse_cte_definition(cte_input, is_recursive),
    )
    .parse(rest)?;

    Ok((rest, (definitions, is_recursive)))
}
fn parse_cte_definition(input: &str, is_recursive: bool) -> IResult<&str, CTEDef> {
let (input, name) = parse_identifier(input)?;
let (input, _) = multispace0(input)?;
let (input, columns) = opt(delimited(
char('('),
separated_list1(
(multispace0, char(','), multispace0),
map(parse_identifier, |s| s.to_string()),
),
char(')'),
))
.parse(input)?;
let (input, _) = multispace0(input)?;
let (input, _) = tag_no_case("as").parse(input)?;
let (input, _) = multispace0(input)?;
let (input, cte_body) =
delimited(char('('), take_until_matching_paren, char(')')).parse(input)?;
let cte_body = cte_body.trim();
if is_recursive {
parse_recursive_cte_strict(input, name, columns.unwrap_or_default(), cte_body)
} else {
let base_query = parse_qail_strict(cte_body).map_err(|_| {
nom::Err::Failure(nom::error::Error::new(input, nom::error::ErrorKind::Verify))
})?;
Ok((
input,
CTEDef {
name: name.to_string(),
recursive: false,
columns: columns.unwrap_or_default(),
base_query: Box::new(base_query),
recursive_query: None,
source_table: None,
},
))
}
}
fn parse_recursive_cte_strict<'a>(
remaining_input: &'a str,
name: &str,
columns: Vec<String>,
body: &str,
) -> IResult<&'a str, CTEDef> {
let (base_str, recursive_str) = split_top_level_union_all(body).map_err(|_| {
nom::Err::Failure(nom::error::Error::new(
remaining_input,
nom::error::ErrorKind::Verify,
))
})?;
let base_query = parse_qail_strict(base_str.trim()).map_err(|_| {
nom::Err::Failure(nom::error::Error::new(
remaining_input,
nom::error::ErrorKind::Verify,
))
})?;
let recursive_query = parse_qail_strict(recursive_str.trim()).map_err(|_| {
nom::Err::Failure(nom::error::Error::new(
remaining_input,
nom::error::ErrorKind::Verify,
))
})?;
if contains_ident_outside_quotes_comments(base_str, name) {
return Err(nom::Err::Failure(nom::error::Error::new(
remaining_input,
nom::error::ErrorKind::Verify,
)));
}
if !contains_ident_outside_quotes_comments(recursive_str, name) {
return Err(nom::Err::Failure(nom::error::Error::new(
remaining_input,
nom::error::ErrorKind::Verify,
)));
}
Ok((
remaining_input,
CTEDef {
name: name.to_string(),
recursive: true,
columns,
base_query: Box::new(base_query),
recursive_query: Some(Box::new(recursive_query)),
source_table: None,
},
))
}
/// Splits `body` at its single top-level `UNION ALL`.
///
/// Returns the text before `UNION` and the text after `ALL`, both untrimmed.
/// The keywords are matched case-insensitively, must be whole words, and may
/// be separated by any amount of ASCII whitespace.
///
/// Text inside the following spans is never inspected for keywords or
/// parentheses: `'...'` and `"..."` (a doubled quote is an escaped quote),
/// `$$...$$`, `-- ...` line comments and `/* ... */` block comments. A
/// `UNION` nested inside parentheses is ignored.
///
/// The scan works on bytes, so non-ASCII text anywhere in `body` is safe. The
/// returned slices always start and end next to ASCII keyword bytes and are
/// therefore valid string slices.
///
/// # Errors
///
/// * a top-level `UNION` that is not followed by `ALL`,
/// * more than one top-level `UNION ALL`,
/// * no top-level `UNION ALL` at all.
pub fn split_top_level_union_all(body: &str) -> Result<(&str, &str), &'static str> {
    const UNION: &[u8] = b"UNION";
    const ALL: &[u8] = b"ALL";

    let bytes = body.as_bytes();
    let len = bytes.len();
    let mut i = 0;
    let mut depth: usize = 0;
    // (start of `UNION`, end of `ALL`) of the separator found so far.
    let mut split: Option<(usize, usize)> = None;

    while i < len {
        match bytes[i] {
            // Quoted span; a doubled quote character is an escape, not the end.
            quote @ (b'\'' | b'"') => {
                i += 1;
                while i < len {
                    if bytes[i] == quote {
                        i += 1;
                        if i < len && bytes[i] == quote {
                            i += 1;
                            continue;
                        }
                        break;
                    }
                    i += 1;
                }
            }
            // `$$ ... $$` dollar-quoted span.
            b'$' if i + 1 < len && bytes[i + 1] == b'$' => {
                i += 2;
                while i + 1 < len {
                    if bytes[i] == b'$' && bytes[i + 1] == b'$' {
                        i += 2;
                        break;
                    }
                    i += 1;
                }
            }
            // `-- ...` comment, up to and including the newline.
            b'-' if i + 1 < len && bytes[i + 1] == b'-' => {
                i += 2;
                while i < len && bytes[i] != b'\n' {
                    i += 1;
                }
                if i < len {
                    i += 1;
                }
            }
            // `/* ... */` comment.
            b'/' if i + 1 < len && bytes[i + 1] == b'*' => {
                i += 2;
                while i + 1 < len {
                    if bytes[i] == b'*' && bytes[i + 1] == b'/' {
                        i += 2;
                        break;
                    }
                    i += 1;
                }
            }
            b'(' => {
                depth += 1;
                i += 1;
            }
            b')' => {
                // Stray closers must not underflow the depth counter.
                depth = depth.saturating_sub(1);
                i += 1;
            }
            b'U' | b'u' if depth == 0 => {
                // Compare bytes, not `&str` slices: a fixed-width window can end
                // inside a multi-byte character, and slicing the string there
                // would panic.
                let union_end = i + UNION.len();
                let is_union_keyword = (i == 0 || !is_ident_char(bytes[i - 1]))
                    && union_end <= len
                    && bytes[i..union_end].eq_ignore_ascii_case(UNION)
                    && (union_end == len || !is_ident_char(bytes[union_end]));
                if !is_union_keyword {
                    i += 1;
                    continue;
                }

                let mut j = union_end;
                while j < len && bytes[j].is_ascii_whitespace() {
                    j += 1;
                }
                let all_end = j + ALL.len();
                let is_all_keyword = all_end <= len
                    && bytes[j..all_end].eq_ignore_ascii_case(ALL)
                    && (all_end == len || !is_ident_char(bytes[all_end]));
                if !is_all_keyword {
                    return Err(
                        "bare UNION (without ALL) found; only UNION ALL is supported in recursive CTEs",
                    );
                }
                if split.is_some() {
                    return Err("multiple top-level UNION ALL found");
                }
                split = Some((i, all_end));
                i = all_end;
            }
            _ => {
                i += 1;
            }
        }
    }

    split
        .map(|(start, end)| (&body[..start], &body[end..]))
        .ok_or("no top-level UNION ALL found")
}
/// Parses `sql` with the root parser and requires that all of it is consumed.
///
/// Whitespace left over after the parsed command is tolerated; anything else
/// counts as trailing input.
///
/// # Errors
///
/// * `"QAIL parse failed"` when the root parser rejects the input.
/// * `"partial parse — trailing input"` when non-whitespace text remains.
pub fn parse_qail_strict(sql: &str) -> Result<Qail, &'static str> {
    let (leftover, command) = super::parse_root(sql).map_err(|_| "QAIL parse failed")?;
    if leftover.trim().is_empty() {
        Ok(command)
    } else {
        Err("partial parse — trailing input")
    }
}
/// Reports whether `ident` occurs in `input` as a whole word outside quoted
/// text and comments.
///
/// The comparison ignores ASCII case. A match must not be directly preceded
/// or followed by an identifier byte (ASCII letter, digit or `_`). Text inside
/// `'...'`, `"..."` (doubled quotes are escapes), `$$...$$`, `-- ...` line
/// comments and `/* ... */` block comments is skipped.
///
/// The scan works on bytes, so non-ASCII characters in `input` are handled
/// without panicking. An empty `ident` never matches.
pub fn contains_ident_outside_quotes_comments(input: &str, ident: &str) -> bool {
    let needle = ident.as_bytes();
    if needle.is_empty() {
        // A zero-length identifier cannot meaningfully occur anywhere.
        return false;
    }

    let bytes = input.as_bytes();
    let len = bytes.len();
    let mut i = 0;
    while i < len {
        match bytes[i] {
            // Quoted span; a doubled quote character is an escape, not the end.
            quote @ (b'\'' | b'"') => {
                i += 1;
                while i < len {
                    if bytes[i] == quote {
                        i += 1;
                        if i < len && bytes[i] == quote {
                            i += 1;
                            continue;
                        }
                        break;
                    }
                    i += 1;
                }
            }
            // `$$ ... $$` dollar-quoted span.
            b'$' if i + 1 < len && bytes[i + 1] == b'$' => {
                i += 2;
                while i + 1 < len {
                    if bytes[i] == b'$' && bytes[i + 1] == b'$' {
                        i += 2;
                        break;
                    }
                    i += 1;
                }
            }
            // `-- ...` comment, up to and including the newline.
            b'-' if i + 1 < len && bytes[i + 1] == b'-' => {
                i += 2;
                while i < len && bytes[i] != b'\n' {
                    i += 1;
                }
                if i < len {
                    i += 1;
                }
            }
            // `/* ... */` comment.
            b'/' if i + 1 < len && bytes[i + 1] == b'*' => {
                i += 2;
                while i + 1 < len {
                    if bytes[i] == b'*' && bytes[i + 1] == b'/' {
                        i += 2;
                        break;
                    }
                    i += 1;
                }
            }
            _ => {
                // Compare bytes rather than slicing the `&str`: `i` advances one
                // byte at a time and may sit inside a multi-byte character,
                // where string slicing would panic.
                let end = i + needle.len();
                if end <= len
                    && bytes[i..end].eq_ignore_ascii_case(needle)
                    && (i == 0 || !is_ident_char(bytes[i - 1]))
                    && (end == len || !is_ident_char(bytes[end]))
                {
                    return true;
                }
                i += 1;
            }
        }
    }
    false
}
/// Returns `true` for bytes that may appear inside an identifier: ASCII
/// letters, ASCII digits and the underscore.
fn is_ident_char(b: u8) -> bool {
    matches!(b, b'_' | b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z')
}
/// Takes input up to, but not including, the `)` that balances an
/// already-open `(`.
///
/// Scanning starts at nesting depth 1. On success the result is
/// `(rest, inner)`, where `rest` begins at the balancing `)` and `inner` is
/// everything before it. Parentheses inside `'...'`, `"..."` (doubled quotes
/// are escapes), `$$...$$`, `-- ...` line comments and `/* ... */` block
/// comments do not affect the depth.
///
/// # Errors
///
/// Returns a recoverable `ErrorKind::TakeUntil` error positioned at `input`
/// when the input ends before the balancing `)` is found.
fn take_until_matching_paren(input: &str) -> IResult<&str, &str> {
    // Scans forward from `at` for the two-byte terminator `first`,`second` and
    // returns the index just past it. When the terminator is missing, the
    // index where scanning stopped is returned instead.
    fn skip_past_pair(src: &[u8], mut at: usize, first: u8, second: u8) -> usize {
        while at + 1 < src.len() {
            if src[at] == first && src[at + 1] == second {
                return at + 2;
            }
            at += 1;
        }
        at
    }

    let src = input.as_bytes();
    let end = src.len();
    let mut open_parens: usize = 1;
    let mut pos = 0;

    while pos < end {
        let lookahead = src.get(pos + 1).copied();
        match (src[pos], lookahead) {
            // Quoted span: a doubled quote is an escape, a single one closes it.
            (quote @ (b'\'' | b'"'), _) => {
                pos += 1;
                while pos < end {
                    let is_quote = src[pos] == quote;
                    pos += 1;
                    if !is_quote {
                        continue;
                    }
                    if src.get(pos) == Some(&quote) {
                        pos += 1;
                    } else {
                        break;
                    }
                }
            }
            (b'$', Some(b'$')) => pos = skip_past_pair(src, pos + 2, b'$', b'$'),
            // Line comment: resume right after the newline, or at end of input.
            (b'-', Some(b'-')) => {
                pos = match src[pos + 2..].iter().position(|&byte| byte == b'\n') {
                    Some(offset) => pos + 2 + offset + 1,
                    None => end,
                };
            }
            (b'/', Some(b'*')) => pos = skip_past_pair(src, pos + 2, b'*', b'/'),
            (b'(', _) => {
                open_parens += 1;
                pos += 1;
            }
            (b')', _) => {
                open_parens -= 1;
                if open_parens == 0 {
                    // `pos` sits on an ASCII `)`, so this is a valid split point.
                    let (inner, rest) = input.split_at(pos);
                    return Ok((rest, inner));
                }
                pos += 1;
            }
            _ => pos += 1,
        }
    }

    Err(nom::Err::Error(nom::error::Error::new(
        input,
        nom::error::ErrorKind::TakeUntil,
    )))
}