pub mod base;
pub mod binary_ops;
pub mod case_when;
pub mod clauses;
pub mod cte;
pub mod ddl;
pub mod dml;
pub mod expressions;
pub mod functions;
pub mod joins;
pub mod merge;
pub mod special_funcs;
use self::base::*;
use self::clauses::*;
use self::ddl::*;
use self::dml::*;
use self::joins::*;
use crate::ast::*;
use nom::{
IResult, Parser,
bytes::complete::tag_no_case,
character::complete::{multispace0, multispace1},
combinator::opt,
multi::many0,
};
pub fn parse(input: &str) -> Result<Qail, String> {
let cleaned = strip_sql_comments(input);
let desugared = desugar_bracket_filter(&cleaned);
match parse_root(&desugared) {
Ok(("", cmd)) => Ok(cmd),
Ok((remaining, _)) => Err(format!("Unexpected trailing content: '{}'", remaining)),
Err(e) => Err(format!("Parse error: {:?}", e)),
}
}
/// Rewrites the bracket-filter shorthand `action table[cond] rest` into an
/// explicit `where` clause that [`parse_root`] understands.
///
/// Rewrites:
/// * `get users[active = true]` becomes `get users where active = true`.
/// * If `rest` already contains a `where` clause, `cond` is appended to it
///   with `AND`.
/// * `cond` is always placed before any trailing `having`, `order by`,
///   `limit` or `offset` clause. [`parse_root`] reads those clauses after
///   `where`, so text after them would be left unparsed.
///
/// The trimmed input is returned unchanged when:
/// * there is no space before the first `[`, or
/// * a clause keyword precedes the `[`, as in `where tags[1] = 'x'`, or
/// * the bracket is never closed.
fn desugar_bracket_filter(input: &str) -> String {
    // Keywords that, when they appear before the first `[`, mean the bracket
    // belongs to an expression (e.g. an array subscript), not the table.
    const LEADING_CLAUSES: [&str; 7] = [
        " where ", " fields ", " having ", " order ", " limit ", " offset ", " join ",
    ];
    // Clauses that parse_root reads after WHERE; the filter must precede them.
    const TRAILING_CLAUSES: &[&[&str]] =
        &[&["having"], &["order", "by"], &["limit"], &["offset"]];
    const WHERE_CLAUSE: &[&[&str]] = &[&["where"]];

    let trimmed = input.trim();
    let Some(bracket_start) = trimmed.find('[') else {
        return trimmed.to_string();
    };
    let before_bracket = &trimmed[..bracket_start];
    // No space means there is no `action table` prefix, so this is not the shorthand.
    if !before_bracket.contains(' ') {
        return trimmed.to_string();
    }
    let before_lower = before_bracket.to_ascii_lowercase();
    if LEADING_CLAUSES.iter().any(|kw| before_lower.contains(kw)) {
        return trimmed.to_string();
    }

    let after_bracket = &trimmed[bracket_start + 1..];
    let Some(end_pos) = bracket_filter_end(after_bracket) else {
        return trimmed.to_string();
    };
    let filter = &after_bracket[..end_pos];
    let rest = after_bracket[end_pos + 1..].trim();

    // Split `rest` into the part that may hold a WHERE clause (`head`) and
    // the clauses that must come after it (`tail`).
    let split = find_keyword_outside_quotes(rest, TRAILING_CLAUSES).unwrap_or(rest.len());
    let (head, tail) = rest.split_at(split);
    let head = head.trim_end();

    // NOTE(review): the filter is joined with a plain `AND`, so an `OR` in an
    // existing WHERE clause binds however parse_where_clause decides — confirm.
    let mut desugared = if find_keyword_outside_quotes(head, WHERE_CLAUSE).is_some() {
        format!("{} {} AND {}", before_bracket, head, filter)
    } else if head.is_empty() {
        format!("{} where {}", before_bracket, filter)
    } else {
        format!("{} {} where {}", before_bracket, head, filter)
    };
    if !tail.is_empty() {
        desugared.push(' ');
        desugared.push_str(tail);
    }
    desugared
}

/// Returns the byte offset in `after_open` of the `]` that closes an
/// already-consumed `[`. Nested brackets and brackets inside single- or
/// double-quoted text are skipped. Returns `None` if the bracket is never
/// closed.
fn bracket_filter_end(after_open: &str) -> Option<usize> {
    let mut depth = 1usize;
    let mut in_single_quote = false;
    let mut in_double_quote = false;
    for (i, c) in after_open.char_indices() {
        match c {
            '\'' if !in_double_quote => in_single_quote = !in_single_quote,
            '"' if !in_single_quote => in_double_quote = !in_double_quote,
            '[' if !in_single_quote && !in_double_quote => depth += 1,
            ']' if !in_single_quote && !in_double_quote => {
                depth -= 1;
                if depth == 0 {
                    return Some(i);
                }
            }
            _ => {}
        }
    }
    None
}

/// Returns the byte offset of the first position in `text`, outside single-
/// or double-quoted spans, where one of `phrases` starts.
///
/// Matching rules for each phrase:
/// * A phrase is a list of words, matched ASCII-case-insensitively.
/// * Consecutive words must be separated by at least one whitespace
///   character.
/// * The match must start and end on identifier boundaries, so `order_id`
///   does not match `order`.
fn find_keyword_outside_quotes(text: &str, phrases: &[&[&str]]) -> Option<usize> {
    let bytes = text.as_bytes();
    // Non-ASCII bytes are treated as identifier characters, to stay conservative.
    let is_word = |b: u8| b.is_ascii_alphanumeric() || b == b'_' || !b.is_ascii();
    let phrase_at = |start: usize, words: &[&str]| -> bool {
        let mut pos = start;
        for (n, word) in words.iter().enumerate() {
            if n > 0 {
                let gap = bytes[pos..]
                    .iter()
                    .take_while(|b| b.is_ascii_whitespace())
                    .count();
                if gap == 0 {
                    return false;
                }
                pos += gap;
            }
            match bytes.get(pos..pos + word.len()) {
                Some(chunk) if chunk.eq_ignore_ascii_case(word.as_bytes()) => pos += word.len(),
                _ => return false,
            }
        }
        !bytes.get(pos).is_some_and(|&b| is_word(b))
    };

    let (mut in_single, mut in_double) = (false, false);
    for (i, &b) in bytes.iter().enumerate() {
        match b {
            b'\'' if !in_double => in_single = !in_single,
            b'"' if !in_single => in_double = !in_double,
            _ if in_single || in_double => {}
            _ if i > 0 && is_word(bytes[i - 1]) => {}
            _ => {
                if phrases.iter().any(|&words| phrase_at(i, words)) {
                    return Some(i);
                }
            }
        }
    }
    None
}
pub fn parse_root(input: &str) -> IResult<&str, Qail> {
let input = input.trim();
if let Ok((remaining, cmd)) = parse_txn_command(input) {
return Ok((remaining, cmd));
}
if let Ok((remaining, cmd)) = parse_procedural_command(input) {
return Ok((remaining, cmd));
}
if let Ok((remaining, cmd)) = parse_create_index(input) {
return Ok((remaining, cmd));
}
let lower_input = input.to_lowercase();
let (input, ctes) = if lower_input.starts_with("with")
&& lower_input
.chars()
.nth(4)
.map(|c| c.is_whitespace())
.unwrap_or(false)
{
let (remaining, (cte_defs, _is_recursive)) = cte::parse_with_clause(input)?;
let (remaining, _) = multispace0(remaining)?;
(remaining, cte_defs)
} else {
(input, vec![])
};
let (input, (action, distinct)) = parse_action(input)?;
let (input, _) = multispace1(input)?;
let (input, distinct_on) = if distinct {
if let Ok((remaining, _)) = tag_no_case::<_, _, nom::error::Error<&str>>("on").parse(input)
{
let (remaining, _) = multispace0(remaining)?;
let (remaining, exprs) = nom::sequence::delimited(
nom::character::complete::char('('),
nom::multi::separated_list1(
(
multispace0,
nom::character::complete::char(','),
multispace0,
),
expressions::parse_expression,
),
nom::character::complete::char(')'),
)
.parse(remaining)?;
let (remaining, _) = multispace1(remaining)?;
(remaining, exprs)
} else {
(input, vec![])
}
} else {
(input, vec![])
};
let (input, table) = parse_identifier(input)?;
let (input, _) = multispace0(input)?;
if matches!(action, Action::Make) {
return parse_create_table(input, table);
}
if matches!(action, Action::Merge) {
return merge::parse_merge_after_target(input, table, ctes);
}
let (input, joins) = many0(parse_join_clause).parse(input)?;
let (input, _) = multispace0(input)?;
let (input, set_cages) = if matches!(action, Action::Set) {
opt(parse_values_clause).parse(input)?
} else {
(input, None)
};
let (input, _) = multispace0(input)?;
let (input, columns) = opt(parse_fields_clause).parse(input)?;
let (input, _) = multispace0(input)?;
let (input, source_query) = if matches!(action, Action::Add) {
opt(dml::parse_source_query).parse(input)?
} else {
(input, None)
};
let (input, _) = multispace0(input)?;
let (input, add_cages) = if source_query.is_none() && matches!(action, Action::Add) {
opt(dml::parse_insert_values).parse(input)?
} else {
(input, None)
};
let (input, _) = multispace0(input)?;
let (input, where_cages) = opt(parse_where_clause).parse(input)?;
let (input, _) = multispace0(input)?;
let (input, having) = opt(parse_having_clause).parse(input)?;
let (input, _) = multispace0(input)?;
let (input, on_conflict) = if matches!(action, Action::Add) {
opt(dml::parse_on_conflict).parse(input)?
} else {
(input, None)
};
let (input, _) = multispace0(input)?;
let (input, order_cages) = opt(parse_order_by_clause).parse(input)?;
let (input, _) = multispace0(input)?;
let (input, limit_cage) = opt(parse_limit_clause).parse(input)?;
let (input, _) = multispace0(input)?;
let (input, offset_cage) = opt(parse_offset_clause).parse(input)?;
let mut cages = Vec::new();
if let Some(sc) = set_cages {
cages.push(sc);
}
if let Some(ac) = add_cages {
cages.push(ac);
}
if let Some(wc) = where_cages {
cages.extend(wc);
}
if let Some(oc) = order_cages {
cages.extend(oc);
}
if let Some(lc) = limit_cage {
cages.push(lc);
}
if let Some(oc) = offset_cage {
cages.push(oc);
}
Ok((
input,
Qail {
action,
table: table.to_string(),
columns: columns.unwrap_or_else(|| vec![Expr::Star]),
joins,
cages,
distinct,
distinct_on,
index_def: None,
table_constraints: vec![],
set_ops: vec![],
having: having.unwrap_or_default(),
group_by_mode: GroupByMode::default(),
returning: None,
ctes,
on_conflict,
merge: None,
source_query,
channel: None,
payload: None,
savepoint_name: None,
from_tables: vec![],
using_tables: vec![],
lock_mode: None,
skip_locked: false,
fetch: None,
default_values: false,
overriding: None,
sample: None,
only_table: false,
vector: None,
score_threshold: None,
vector_name: None,
with_vector: false,
vector_size: None,
distance: None,
on_disk: None,
function_def: None,
trigger_def: None,
policy_def: None,
},
))
}
/// Removes `--` line comments and `/* ... */` block comments from `input`.
///
/// Text in the following spans is copied verbatim, so comment markers
/// inside them survive:
/// * single-quoted spans;
/// * double-quoted spans;
/// * triple-quoted raw spans (`'''...'''` / `"""..."""`);
/// * dollar-quoted bodies (`$$...$$`, `$tag$...$tag$`).
///
/// A doubled quote (`''` or `""`) inside a quoted span is copied as-is and
/// does not end the span.
///
/// Comment handling:
/// * A line comment is dropped through its terminating newline, and that
///   newline is emitted in its place. A line comment at end of input emits
///   nothing.
/// * A closed block comment is replaced by a single space.
/// * An unterminated block comment drops everything after `/*` and emits the
///   two characters `/*` at the end of the output.
fn strip_sql_comments(input: &str) -> String {
    let mut result = String::with_capacity(input.len());
    let bytes = input.as_bytes();
    let mut i = 0;
    let mut in_single_quote = false;
    let mut in_double_quote = false;
    // Closing delimiter of the raw (triple-quoted or dollar-quoted) span we
    // are currently inside, if any.
    let mut raw_delimiter: Option<String> = None;
    while i < input.len() {
        // Inside a raw span: copy verbatim until the same delimiter reappears.
        if let Some(ref delimiter) = raw_delimiter {
            if input[i..].starts_with(delimiter) {
                result.push_str(delimiter);
                i += delimiter.len();
                raw_delimiter = None;
            } else {
                push_char_at(input, &mut result, &mut i);
            }
            continue;
        }
        // Inside '...': a doubled '' is an escaped quote; a single ' closes it.
        if in_single_quote {
            if bytes[i] == b'\'' {
                result.push('\'');
                i += 1;
                if i < input.len() && bytes[i] == b'\'' {
                    result.push('\'');
                    i += 1;
                } else {
                    in_single_quote = false;
                }
            } else {
                push_char_at(input, &mut result, &mut i);
            }
            continue;
        }
        // Inside "...": same escaping rule with "".
        if in_double_quote {
            if bytes[i] == b'"' {
                result.push('"');
                i += 1;
                if i < input.len() && bytes[i] == b'"' {
                    result.push('"');
                    i += 1;
                } else {
                    in_double_quote = false;
                }
            } else {
                push_char_at(input, &mut result, &mut i);
            }
            continue;
        }
        // Triple quotes are checked before single/double quotes so that
        // ''' / """ open a raw span instead of an ordinary quoted span.
        if input[i..].starts_with("'''") || input[i..].starts_with("\"\"\"") {
            let delimiter = &input[i..i + 3];
            result.push_str(delimiter);
            raw_delimiter = Some(delimiter.to_string());
            i += 3;
            continue;
        }
        if bytes[i] == b'\'' {
            in_single_quote = true;
            result.push('\'');
            i += 1;
            continue;
        }
        if bytes[i] == b'"' {
            in_double_quote = true;
            result.push('"');
            i += 1;
            continue;
        }
        // `$$` or `$tag$` opens a dollar-quoted raw span closed by the same text.
        if let Some(delimiter_len) = dollar_quote_delimiter_len(bytes, i) {
            let delimiter = &input[i..i + delimiter_len];
            result.push_str(delimiter);
            raw_delimiter = Some(delimiter.to_string());
            i += delimiter_len;
            continue;
        }
        if bytes[i] == b'-' && i + 1 < input.len() && bytes[i + 1] == b'-' {
            // Line comment: skip to end of line, keeping the newline itself.
            i += 2;
            while i < input.len() {
                let Some(ch) = input.get(i..).and_then(|s| s.chars().next()) else {
                    break;
                };
                i += ch.len_utf8();
                if ch == '\n' {
                    result.push('\n');
                    break;
                }
            }
        } else if bytes[i] == b'/' && i + 1 < input.len() && bytes[i + 1] == b'*' {
            // Block comment (not nested): skip to the first `*/`.
            i += 2;
            let mut closed = false;
            while i < input.len() {
                if bytes[i] == b'*' && i + 1 < input.len() && bytes[i + 1] == b'/' {
                    i += 2;
                    // A space keeps the tokens on either side of the comment apart.
                    result.push(' '); closed = true;
                    break;
                }
                advance_char(input, &mut i);
            }
            // Unterminated: re-emit the opener so downstream parsing sees it.
            if !closed {
                result.push_str("/*");
            }
        } else {
            push_char_at(input, &mut result, &mut i);
        }
    }
    result
}
/// Copies the character starting at byte offset `*index` of `input` into
/// `output` and moves `index` just past it.
///
/// If `*index` is at or beyond the end of `input`, or not on a character
/// boundary, nothing is copied and `index` is set to `input.len()`.
fn push_char_at(input: &str, output: &mut String, index: &mut usize) {
    let Some(ch) = input.get(*index..).and_then(|tail| tail.chars().next()) else {
        *index = input.len();
        return;
    };
    output.push(ch);
    *index += ch.len_utf8();
}
/// Moves `index` past the character that starts at byte offset `*index` of
/// `input`.
///
/// If `*index` is at or beyond the end of `input`, or not on a character
/// boundary, `index` is set to `input.len()` instead.
fn advance_char(input: &str, index: &mut usize) {
    *index = match input.get(*index..).and_then(|tail| tail.chars().next()) {
        Some(ch) => *index + ch.len_utf8(),
        None => input.len(),
    };
}
/// Returns the byte length of a dollar-quote delimiter beginning at
/// `bytes[start]`, or `None` if none starts there.
///
/// Recognised forms are `$$` and `$tag$`. The tag must start with an ASCII
/// letter or `_` and continue with ASCII letters, digits or `_`. A `$`
/// followed by a digit, as in `$1`, is therefore not a delimiter.
fn dollar_quote_delimiter_len(bytes: &[u8], start: usize) -> Option<usize> {
    let (&opener, body) = bytes.get(start..)?.split_first()?;
    if opener != b'$' {
        return None;
    }
    let first = *body.first()?;
    if first == b'$' {
        return Some(2);
    }
    if !(first.is_ascii_alphabetic() || first == b'_') {
        return None;
    }
    // The tag continues until the first non-identifier byte, which must be
    // the closing `$`.
    let tag_rest = &body[1..];
    let stop = tag_rest
        .iter()
        .position(|&b| !(b.is_ascii_alphanumeric() || b == b'_'))?;
    // Length = opening `$` + first tag byte + `stop` more tag bytes + closing `$`.
    (tag_rest[stop] == b'$').then_some(stop + 3)
}