use sqlparser::dialect::Dialect;
use sqlparser::parser::Parser;
use crate::ast::{RawStatement, SqltStatement};
use crate::dialect::DialectId;
use crate::error::Result;
mod split;
/// Parses `sql` into a list of [`SqltStatement`]s using the upstream parser dialect
/// returned by `dialect.upstream()`.
///
/// For [`DialectId::MariaDb`] the input is first run through `preprocess_mariadb`, and a
/// parse failure is handed to `mariadb_with_fallback` instead of being returned directly.
/// Every other dialect parses the input unchanged.
///
/// # Errors
///
/// Returns the converted parser error when parsing fails for a non-MariaDB dialect, or
/// whatever error `mariadb_with_fallback` reports for MariaDB input.
pub fn parse(sql: &str, dialect: DialectId) -> Result<Vec<SqltStatement>> {
    let upstream = dialect.upstream();
    let is_mariadb = dialect == DialectId::MariaDb;

    let preprocessed = if is_mariadb {
        preprocess_mariadb(sql)
    } else {
        sql.to_string()
    };

    let err = match Parser::parse_sql(&*upstream, &preprocessed) {
        Ok(stmts) => return Ok(stmts.into_iter().map(SqltStatement::from).collect()),
        Err(err) => err,
    };

    if is_mariadb {
        // Retry statement by statement, keeping unparseable pieces as raw statements.
        mariadb_with_fallback(sql, &preprocessed, &*upstream, err)
    } else {
        Err(err.into())
    }
}
/// Rewrites MariaDB/MySQL comment syntax in `sql` before it is handed to the parser.
///
/// Two rewrites are applied, both only outside of quoted text and ordinary comments:
///
/// * A terminated conditional comment (`/*!NNNNN body */` or `/*M!NNNNNN body */`) is
///   unwrapped: the opening marker, the version digits and the closing `*/` are each
///   replaced by the same number of spaces, and `body` is copied verbatim.
/// * A `--` that is immediately followed by end of input, `\n` or `\r` gets one space
///   appended, producing `-- `.
///
/// Everything else is copied through unchanged. Single-quoted text honours backslash
/// escapes; double-quoted and backtick-quoted text end at the next matching quote.
///
/// The output is assembled as raw bytes. Only ASCII bytes are ever substituted or
/// inserted, and always next to other ASCII bytes, so multi-byte UTF-8 sequences from the
/// input are preserved exactly.
fn preprocess_mariadb(sql: &str) -> String {
    let bytes = sql.as_bytes();
    let mut out: Vec<u8> = Vec::with_capacity(bytes.len());
    let mut state = State::Normal;
    let mut i = 0;

    while i < bytes.len() {
        let b = bytes[i];
        match state {
            State::Normal => match b {
                b'\'' => {
                    out.push(b);
                    state = State::SingleQuote;
                    i += 1;
                }
                b'"' => {
                    out.push(b);
                    state = State::DoubleQuote;
                    i += 1;
                }
                b'`' => {
                    out.push(b);
                    state = State::Backtick;
                    i += 1;
                }
                b'/' if bytes.get(i + 1) == Some(&b'*') => {
                    if let Some((marker_end, close)) = conditional_comment_span(bytes, i) {
                        let digits = bytes[marker_end..close]
                            .iter()
                            .take_while(|d| d.is_ascii_digit())
                            .count();
                        let body_start = marker_end + digits;
                        // Blank out the `/*!` / `/*M!` marker and the version digits.
                        out.resize(out.len() + (body_start - i), b' ');
                        out.extend_from_slice(&bytes[body_start..close]);
                        // Blank out the closing `*/`.
                        out.extend_from_slice(b"  ");
                        i = close + 2;
                        continue;
                    }
                    // Ordinary (or unterminated conditional) block comment: copy as-is.
                    out.extend_from_slice(b"/*");
                    state = State::BlockComment;
                    i += 2;
                }
                b'-' if bytes.get(i + 1) == Some(&b'-') => {
                    out.extend_from_slice(b"--");
                    let next = bytes.get(i + 2).copied();
                    if matches!(next, None | Some(b'\n') | Some(b'\r')) {
                        out.push(b' ');
                    }
                    state = State::LineComment;
                    i += 2;
                }
                _ => {
                    out.push(b);
                    i += 1;
                }
            },
            State::SingleQuote => {
                out.push(b);
                if b == b'\\' && i + 1 < bytes.len() {
                    // Copy the escaped byte so an escaped quote does not end the literal.
                    out.push(bytes[i + 1]);
                    i += 2;
                    continue;
                }
                if b == b'\'' {
                    state = State::Normal;
                }
                i += 1;
            }
            State::DoubleQuote => {
                out.push(b);
                if b == b'"' {
                    state = State::Normal;
                }
                i += 1;
            }
            State::Backtick => {
                out.push(b);
                if b == b'`' {
                    state = State::Normal;
                }
                i += 1;
            }
            State::BlockComment => {
                out.push(b);
                if b == b'*' && bytes.get(i + 1) == Some(&b'/') {
                    out.push(b'/');
                    i += 2;
                    state = State::Normal;
                    continue;
                }
                i += 1;
            }
            State::LineComment => {
                out.push(b);
                if b == b'\n' {
                    state = State::Normal;
                }
                i += 1;
            }
        }
    }

    String::from_utf8(out)
        .expect("only ASCII bytes are substituted or inserted, so the output stays valid UTF-8")
}

/// Locates a conditional comment that opens at `open`, where `bytes[open..]` starts with `/*`.
///
/// Returns `(marker_end, close)`: `marker_end` is the index just past the `/*!` or `/*M!`
/// marker and `close` is the index of the `*` in the first `*/` at or after `marker_end`.
/// Returns `None` when the comment has no such marker or is never terminated.
fn conditional_comment_span(bytes: &[u8], open: usize) -> Option<(usize, usize)> {
    let after = open + 2;
    let marker_end = match bytes.get(after) {
        Some(b'!') => after + 1,
        Some(b'M') if bytes.get(after + 1) == Some(&b'!') => after + 2,
        _ => return None,
    };
    let offset = bytes
        .get(marker_end..)?
        .windows(2)
        .position(|pair| pair == b"*/")?;
    Some((marker_end, marker_end + offset))
}

/// Lexical context tracked by `preprocess_mariadb` while scanning the input.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum State {
    /// Outside any quoted text or comment.
    Normal,
    /// Inside `'...'`; backslash escapes are honoured.
    SingleQuote,
    /// Inside `"..."`.
    DoubleQuote,
    /// Inside `` `...` ``.
    Backtick,
    /// Inside `/* ... */`.
    BlockComment,
    /// After `--`, up to and including the next `\n`.
    LineComment,
}
/// Re-parses MariaDB input statement by statement after the whole-script parse failed.
///
/// `original_sql` is split with `split::split_statements_with_lines`. Each non-blank piece
/// is trimmed, preprocessed on its own, prefixed with `start_line - 1` newlines and parsed
/// with `upstream`. A piece that fails to parse, or parses to no statements, is kept as a
/// [`SqltStatement::Raw`] carrying the trimmed text, a reason from `classify_mariadb_raw`
/// and its starting line.
///
/// `preprocessed` is accepted but not used.
///
/// # Errors
///
/// Returns `original_err` (converted) when the splitter yields no pieces, or when no piece
/// had to be kept as a raw statement.
fn mariadb_with_fallback(
    original_sql: &str,
    preprocessed: &str,
    upstream: &dyn Dialect,
    original_err: sqlparser::parser::ParserError,
) -> Result<Vec<SqltStatement>> {
    let pieces = split::split_statements_with_lines(original_sql);
    if pieces.is_empty() {
        return Err(original_err.into());
    }
    let _ = preprocessed;

    let mut statements = Vec::with_capacity(pieces.len());
    let mut used_raw = false;

    for (start_line, piece) in pieces {
        let body = piece.trim();
        if body.is_empty() {
            continue;
        }

        // Lead with blank lines so the piece starts on the same line number as in the
        // original script when the parser sees it.
        let blank_lines = (start_line as usize).saturating_sub(1);
        let mut candidate = "\n".repeat(blank_lines);
        candidate.push_str(&preprocess_mariadb(body));

        match Parser::parse_sql(upstream, &candidate) {
            Ok(parsed) if !parsed.is_empty() => {
                statements.extend(parsed.into_iter().map(SqltStatement::from));
            }
            _ => {
                let reason = classify_mariadb_raw(body);
                statements.push(SqltStatement::Raw(RawStatement {
                    sqlt_raw: body.to_string(),
                    reason,
                    start_line: Some(start_line),
                }));
                used_raw = true;
            }
        }
    }

    if used_raw {
        Ok(statements)
    } else {
        Err(original_err.into())
    }
}
/// Returns a short reason tag describing why `stmt` is being kept as a raw statement.
///
/// Matching is done on an ASCII-uppercased copy of `stmt` with leading whitespace removed.
/// A leading `/*!` or `/*M!` marker, its version digits and the whitespace after them are
/// skipped first. Rules are tried in order and the first match wins; `"unrepresented"` is
/// returned when none applies.
fn classify_mariadb_raw(stmt: &str) -> String {
    let upper = stmt.to_ascii_uppercase();
    let trimmed = upper.trim_start();

    // Look past a conditional-comment opener so the statement keyword is at the front.
    let marker_stripped = trimmed
        .strip_prefix("/*!")
        .or_else(|| trimmed.strip_prefix("/*M!"));
    let head = match marker_stripped {
        Some(rest) => rest
            .trim_start_matches(|c: char| c.is_ascii_digit())
            .trim_start(),
        None => trimmed,
    };

    let starts = |prefixes: &[&str]| prefixes.iter().any(|p| head.starts_with(*p));
    let has = |needles: &[&str]| needles.iter().any(|n| head.contains(*n));

    let reason = if has(&["WITH SYSTEM VERSIONING", "PERIOD FOR SYSTEM_TIME"]) {
        "system_versioning"
    } else if has(&["FOR SYSTEM_TIME"]) {
        "temporal_query"
    } else if starts(&["CREATE PACKAGE", "CREATE OR REPLACE PACKAGE"]) {
        "create_package"
    } else if starts(&["CREATE SEQUENCE", "CREATE OR REPLACE SEQUENCE"]) {
        "sequence_options"
    } else if has(&["VECTOR(", "VEC_DISTANCE"]) {
        "vector_type"
    } else if starts(&["DELIMITER"]) {
        "delimiter"
    } else if starts(&["CREATE TRIGGER", "CREATE FUNCTION", "CREATE PROCEDURE"]) {
        "stored_program_body"
    } else if starts(&["ALTER TABLE"]) && has(&[" DISABLE KEYS", " ENABLE KEYS"]) {
        "optimization_hint"
    } else if starts(&["CREATE DEFINER=", "CREATE ALGORITHM="]) || has(&[" DEFINER=`"]) {
        "definer_clause"
    } else if starts(&["CREATE EVENT"]) {
        "create_event"
    } else {
        "unrepresented"
    };

    reason.to_string()
}