pub mod base;
pub mod binary_ops;
pub mod case_when;
pub mod clauses;
pub mod cte;
pub mod ddl;
pub mod dml;
pub mod expressions;
pub mod functions;
pub mod joins;
pub mod special_funcs;
use self::base::*;
use self::clauses::*;
use self::ddl::*;
use self::dml::*;
use self::joins::*;
use crate::ast::*;
use nom::{
IResult, Parser,
bytes::complete::tag_no_case,
character::complete::{multispace0, multispace1},
combinator::opt,
multi::many0,
};
pub fn parse(input: &str) -> Result<Qail, String> {
let cleaned = strip_sql_comments(input);
let desugared = desugar_bracket_filter(&cleaned);
match parse_root(&desugared) {
Ok(("", cmd)) => Ok(cmd),
Ok((remaining, _)) => Err(format!("Unexpected trailing content: '{}'", remaining)),
Err(e) => Err(format!("Parse error: {:?}", e)),
}
}
/// Expands the `table[filter]` shorthand into an explicit `where` clause.
///
/// The first `[` in the trimmed input is treated as the start of a filter
/// when the text in front of it contains a space (i.e. an action and a table
/// name) and none of the clause keywords `where`, `fields`, `having`,
/// `order`, `limit`, `offset` or `join` appears before it. The matching `]`
/// is located while honouring nested brackets and single/double quoted text.
///
/// The filter is then merged into whatever follows the closing bracket:
///
/// * if a `where` clause follows, the filter is appended with `AND`;
/// * otherwise a new `where <filter>` clause is created;
/// * in both cases the filter is placed *before* the first top-level
///   `having`, `order by`, `limit <n>` or `offset <n>` clause, because
///   `parse_root` only accepts `where` ahead of those clauses.
///
/// In every other situation (no bracket, bracket belongs to a later clause,
/// unbalanced bracket) the trimmed input is returned unchanged.
fn desugar_bracket_filter(input: &str) -> String {
    /// True when `text` begins with `keyword` (ASCII case-insensitive)
    /// immediately followed by whitespace.
    fn starts_with_keyword(text: &str, keyword: &str) -> bool {
        // `get` returns `None` when the keyword is longer than the text or
        // would split a multi-byte character, so this never panics.
        match (text.get(..keyword.len()), text.get(keyword.len()..)) {
            (Some(word), Some(after)) => {
                word.eq_ignore_ascii_case(keyword)
                    && after.chars().next().map_or(false, char::is_whitespace)
            }
            _ => false,
        }
    }

    /// True when `text` starts a clause that has to come after `where`.
    ///
    /// `limit` / `offset` only count when followed by a number or a
    /// parameter marker, so columns that merely share those names
    /// (`where limit > 5`) are not mistaken for clauses.
    fn opens_trailing_clause(text: &str) -> bool {
        let has_operand = |keyword: &str| {
            starts_with_keyword(text, keyword)
                && text[keyword.len()..]
                    .trim_start()
                    .starts_with(|c: char| c.is_ascii_digit() || c == '$' || c == ':')
        };
        starts_with_keyword(text, "having")
            || (starts_with_keyword(text, "order")
                && starts_with_keyword(text["order".len()..].trim_start(), "by"))
            || has_operand("limit")
            || has_operand("offset")
    }

    /// Byte offset of the first trailing clause in `clauses` that is outside
    /// quotes and parentheses and starts on a word boundary.
    fn trailing_clause_start(clauses: &str) -> Option<usize> {
        let mut in_single_quote = false;
        let mut in_double_quote = false;
        let mut paren_depth = 0usize;
        let mut prev_is_word = false;
        for (i, c) in clauses.char_indices() {
            match c {
                '\'' if !in_double_quote => in_single_quote = !in_single_quote,
                '"' if !in_single_quote => in_double_quote = !in_double_quote,
                _ if in_single_quote || in_double_quote => {}
                '(' => paren_depth += 1,
                ')' => paren_depth = paren_depth.saturating_sub(1),
                _ if paren_depth == 0
                    && !prev_is_word
                    && opens_trailing_clause(&clauses[i..]) =>
                {
                    return Some(i);
                }
                _ => {}
            }
            prev_is_word = c.is_alphanumeric() || c == '_';
        }
        None
    }

    // Keywords that, when seen before the bracket, mean the bracket belongs
    // to a later clause rather than to the table name.
    const CLAUSE_MARKERS: [&str; 7] = [
        " where ", " fields ", " having ", " order ", " limit ", " offset ", " join ",
    ];

    let trimmed = input.trim();
    let bracket_start = match trimmed.find('[') {
        Some(pos) => pos,
        None => return trimmed.to_string(),
    };

    let before_bracket = &trimmed[..bracket_start];
    // Without a space there is no "<action> <table>" prefix to attach to.
    if !before_bracket.contains(' ') {
        return trimmed.to_string();
    }
    let before_lower = before_bracket.to_ascii_lowercase();
    if CLAUSE_MARKERS.iter().any(|marker| before_lower.contains(*marker)) {
        return trimmed.to_string();
    }

    // Find the `]` that closes the filter, skipping nested brackets and
    // anything inside quotes.
    let after_bracket = &trimmed[bracket_start + 1..];
    let mut depth = 1;
    let mut in_single_quote = false;
    let mut in_double_quote = false;
    let mut bracket_end = None;
    for (i, c) in after_bracket.char_indices() {
        match c {
            '\'' if !in_double_quote => in_single_quote = !in_single_quote,
            '"' if !in_single_quote => in_double_quote = !in_double_quote,
            '[' if !in_single_quote && !in_double_quote => depth += 1,
            ']' if !in_single_quote && !in_double_quote => {
                depth -= 1;
                if depth == 0 {
                    bracket_end = Some(i);
                    break;
                }
            }
            _ => {}
        }
    }
    let end_pos = match bracket_end {
        Some(pos) => pos,
        None => return trimmed.to_string(),
    };

    let filter = &after_bracket[..end_pos];
    let rest = after_bracket[end_pos + 1..].trim();

    // Split what follows the bracket into the part that may precede the
    // filter (`head`) and the clauses that must follow it (`tail`).
    let (head, tail) = match trailing_clause_start(rest) {
        Some(split) => (rest[..split].trim_end(), &rest[split..]),
        None => (rest, ""),
    };
    let head_lower = head.to_lowercase();
    let has_where = head_lower.contains("where ") || head_lower.contains("where\n");

    let mut query = String::with_capacity(trimmed.len() + " where ".len());
    query.push_str(before_bracket);
    if !head.is_empty() {
        query.push(' ');
        query.push_str(head);
    }
    query.push_str(if has_where { " AND " } else { " where " });
    query.push_str(filter);
    if !tail.is_empty() {
        query.push(' ');
        query.push_str(tail);
    }
    query
}
pub fn parse_root(input: &str) -> IResult<&str, Qail> {
let input = input.trim();
if let Ok((remaining, cmd)) = parse_txn_command(input) {
return Ok((remaining, cmd));
}
if let Ok((remaining, cmd)) = parse_procedural_command(input) {
return Ok((remaining, cmd));
}
if let Ok((remaining, cmd)) = parse_create_index(input) {
return Ok((remaining, cmd));
}
let lower_input = input.to_lowercase();
let (input, ctes) = if lower_input.starts_with("with")
&& lower_input
.chars()
.nth(4)
.map(|c| c.is_whitespace())
.unwrap_or(false)
{
let (remaining, (cte_defs, _is_recursive)) = cte::parse_with_clause(input)?;
let (remaining, _) = multispace0(remaining)?;
(remaining, cte_defs)
} else {
(input, vec![])
};
let (input, (action, distinct)) = parse_action(input)?;
let (input, _) = multispace1(input)?;
let (input, distinct_on) = if distinct {
if let Ok((remaining, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("on").parse(input)
{
let (remaining, _) = multispace0(remaining)?;
let (remaining, exprs) = nom::sequence::delimited(
nom::character::complete::char('('),
nom::multi::separated_list1(
(
multispace0,
nom::character::complete::char(','),
multispace0,
),
expressions::parse_expression,
),
nom::character::complete::char(')'),
)
.parse(remaining)?;
let (remaining, _) = multispace1(remaining)?;
(remaining, exprs)
} else {
(input, vec![])
}
} else {
(input, vec![])
};
let (input, table) = parse_identifier(input)?;
let (input, _) = multispace0(input)?;
if matches!(action, Action::Make) {
return parse_create_table(input, table);
}
let (input, joins) = many0(parse_join_clause).parse(input)?;
let (input, _) = multispace0(input)?;
let (input, set_cages) = if matches!(action, Action::Set) {
opt(parse_values_clause).parse(input)?
} else {
(input, None)
};
let (input, _) = multispace0(input)?;
let (input, columns) = opt(parse_fields_clause).parse(input)?;
let (input, _) = multispace0(input)?;
let (input, source_query) = if matches!(action, Action::Add) {
opt(dml::parse_source_query).parse(input)?
} else {
(input, None)
};
let (input, _) = multispace0(input)?;
let (input, add_cages) = if source_query.is_none() && matches!(action, Action::Add) {
opt(dml::parse_insert_values).parse(input)?
} else {
(input, None)
};
let (input, _) = multispace0(input)?;
let (input, where_cages) = opt(parse_where_clause).parse(input)?;
let (input, _) = multispace0(input)?;
let (input, having) = opt(parse_having_clause).parse(input)?;
let (input, _) = multispace0(input)?;
let (input, on_conflict) = if matches!(action, Action::Add) {
opt(dml::parse_on_conflict).parse(input)?
} else {
(input, None)
};
let (input, _) = multispace0(input)?;
let (input, order_cages) = opt(parse_order_by_clause).parse(input)?;
let (input, _) = multispace0(input)?;
let (input, limit_cage) = opt(parse_limit_clause).parse(input)?;
let (input, _) = multispace0(input)?;
let (input, offset_cage) = opt(parse_offset_clause).parse(input)?;
let mut cages = Vec::new();
if let Some(sc) = set_cages {
cages.push(sc);
}
if let Some(ac) = add_cages {
cages.push(ac);
}
if let Some(wc) = where_cages {
cages.extend(wc);
}
if let Some(oc) = order_cages {
cages.extend(oc);
}
if let Some(lc) = limit_cage {
cages.push(lc);
}
if let Some(oc) = offset_cage {
cages.push(oc);
}
Ok((
input,
Qail {
action,
table: table.to_string(),
columns: columns.unwrap_or_else(|| vec![Expr::Star]),
joins,
cages,
distinct,
distinct_on,
index_def: None,
table_constraints: vec![],
set_ops: vec![],
having: having.unwrap_or_default(),
group_by_mode: GroupByMode::default(),
returning: None,
ctes,
on_conflict,
source_query,
channel: None,
payload: None,
savepoint_name: None,
from_tables: vec![],
using_tables: vec![],
lock_mode: None,
skip_locked: false,
fetch: None,
default_values: false,
overriding: None,
sample: None,
only_table: false,
vector: None,
score_threshold: None,
vector_name: None,
with_vector: false,
vector_size: None,
distance: None,
on_disk: None,
function_def: None,
trigger_def: None,
policy_def: None,
},
))
}
/// Removes SQL-style comments from `input`.
///
/// * `-- ...` line comments are dropped up to the end of the line; the
///   terminating newline itself is kept.
/// * `/* ... */` block comments are replaced by a single space. Block
///   comments do not nest.
/// * An unterminated block comment swallows the rest of the input and
///   leaves a literal `/*` in the output (presumably so the parser rejects
///   the query instead of silently accepting a truncated one).
///
/// Text inside single- or double-quoted literals is copied verbatim, so
/// comment markers that are part of a string (`'a--b'`, `'/* x */'`) are
/// preserved.
fn strip_sql_comments(input: &str) -> String {
    let mut result = String::with_capacity(input.len());
    let mut chars = input.chars().peekable();
    // The quote character of the literal currently being copied, if any.
    let mut open_quote: Option<char> = None;

    while let Some(c) = chars.next() {
        if let Some(quote) = open_quote {
            result.push(c);
            // A doubled quote (`''`) closes and immediately reopens the
            // literal, which keeps the state correct for escaped quotes.
            if c == quote {
                open_quote = None;
            }
            continue;
        }

        match c {
            '\'' | '"' => {
                open_quote = Some(c);
                result.push(c);
            }
            '-' if chars.peek() == Some(&'-') => {
                // Skip to the end of the line but keep the line break.
                for skipped in chars.by_ref() {
                    if skipped == '\n' {
                        result.push('\n');
                        break;
                    }
                }
            }
            '/' if chars.peek() == Some(&'*') => {
                // Consume the `*` so that `/*/` is not seen as a full comment.
                chars.next();
                let mut closed = false;
                while let Some(skipped) = chars.next() {
                    if skipped == '*' && chars.peek() == Some(&'/') {
                        chars.next();
                        // Keep tokens on either side of the comment separated.
                        result.push(' ');
                        closed = true;
                        break;
                    }
                }
                if !closed {
                    result.push_str("/*");
                }
            }
            _ => result.push(c),
        }
    }
    result
}