use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
/// Lightweight, keyword-based SQL analyzer.
///
/// Build one with [`SqlParser::new`] / `Default` (PostgreSQL dialect) or
/// [`SqlParser::with_dialect`].
#[derive(Debug, Clone)]
pub struct SqlParser {
    /// Dialect chosen at construction time.
    // The constructors store this, but no method in this module reads it yet, so rustc would
    // report the field as dead code (derived Debug/Clone impls do not count as reads).
    #[allow(dead_code)]
    dialect: SqlDialect,
}
/// SQL dialect a `SqlParser` is configured for.
///
/// `PostgreSQL` is the `Default` variant.
#[derive(Debug, Clone, Copy, Default)]
pub enum SqlDialect {
/// PostgreSQL (the default dialect).
#[default]
PostgreSQL,
/// MySQL.
MySQL,
/// SQLite.
SQLite,
}
impl SqlParser {
    /// Creates a parser configured for [`SqlDialect::PostgreSQL`].
    pub fn new() -> Self {
        Self {
            dialect: SqlDialect::PostgreSQL,
        }
    }

    /// Creates a parser configured for `dialect`.
    ///
    /// The dialect is stored on the parser but is not consulted by any analysis below yet.
    pub fn with_dialect(dialect: SqlDialect) -> Self {
        Self { dialect }
    }

    /// Performs a lightweight, keyword-based analysis of a single SQL statement.
    ///
    /// This is not a full SQL grammar:
    /// * the statement kind comes from the first whitespace-separated word (case-insensitive);
    /// * table names and the `LIMIT` / `WHERE` flags are found by scanning the
    ///   [normalized](SqlParser::normalize) text, so string literals never produce false matches
    ///   and newlines/tabs/parentheses count as word boundaries.
    ///
    /// # Errors
    ///
    /// Returns [`ParseError::EmptyQuery`] if `sql` is empty or whitespace-only.
    pub fn parse(&self, sql: &str) -> Result<ParsedQuery, ParseError> {
        let trimmed = sql.trim();
        if trimmed.is_empty() {
            return Err(ParseError::EmptyQuery);
        }

        let upper = trimmed.to_uppercase();
        let first_word = upper.split_whitespace().next().unwrap_or("");
        let is_select = first_word == "SELECT";
        let is_insert = first_word == "INSERT";
        let is_update = first_word == "UPDATE";
        let is_delete = first_word == "DELETE";
        let is_ddl = matches!(first_word, "CREATE" | "ALTER" | "DROP" | "TRUNCATE");

        let normalized = self.normalize(trimmed);
        // Literals are already `?` in the normalized text; splitting on every non-identifier
        // character makes `\nLIMIT`, `WHERE(` etc. match while `where_clause` does not.
        let has_keyword = |keyword: &str| {
            normalized
                .split(|c: char| !(c.is_alphanumeric() || c == '_'))
                .any(|word| word.eq_ignore_ascii_case(keyword))
        };
        let has_limit = has_keyword("LIMIT");
        let has_where = has_keyword("WHERE");
        let tables = self.extract_tables(&normalized);
        let has_select_star = is_select && self.has_select_star(trimmed);

        Ok(ParsedQuery {
            original: trimmed.to_string(),
            normalized,
            tables,
            has_select_star,
            has_limit,
            has_where,
            is_select,
            is_insert,
            is_update,
            is_delete,
            is_ddl,
        })
    }

    /// Produces a canonical form of `sql` suitable for fingerprinting.
    ///
    /// * Single-quoted string literals become `?`. Doubled quotes (`''`) and backslash-escaped
    ///   quotes (`\'`) inside them do not end the literal.
    /// * Numeric literals and `$n` positional parameters become `?`. Digits and `$` that
    ///   continue an identifier (`t1`, `col_2`, `a$b`) are kept, so distinct names stay distinct.
    /// * Double-quoted identifiers are copied verbatim.
    /// * Every other run of whitespace becomes a single space, and the result is trimmed.
    pub fn normalize(&self, sql: &str) -> String {
        /// Characters that may continue an identifier or keyword.
        fn continues_identifier(c: char) -> bool {
            c.is_alphanumeric() || c == '_' || c == '$'
        }

        let mut out = String::with_capacity(sql.len());
        let mut chars = sql.chars().peekable();
        // True when the last emitted character continues an identifier/keyword.
        let mut in_word = false;
        // True when the last emitted character is a collapsed whitespace separator.
        let mut after_space = false;

        while let Some(c) = chars.next() {
            if c.is_whitespace() {
                if !after_space {
                    out.push(' ');
                    after_space = true;
                }
                in_word = false;
                continue;
            }
            match c {
                '\'' => {
                    out.push('?');
                    Self::skip_string_literal(&mut chars);
                }
                '"' => {
                    out.push(c);
                    for inner in chars.by_ref() {
                        out.push(inner);
                        if inner == '"' {
                            break;
                        }
                    }
                }
                '0'..='9' if !in_word => {
                    out.push('?');
                    while matches!(chars.peek(), Some(d) if d.is_ascii_digit() || *d == '.') {
                        chars.next();
                    }
                }
                '$' if !in_word => {
                    out.push('?');
                    while matches!(chars.peek(), Some(d) if d.is_ascii_digit()) {
                        chars.next();
                    }
                }
                other => out.push(other),
            }
            after_space = false;
            in_word = matches!(out.chars().next_back(), Some(last) if continues_identifier(last));
        }

        out.trim().to_string()
    }

    /// Consumes the remainder of a single-quoted literal whose opening quote was already read.
    ///
    /// `''` and `\'` are treated as part of the literal; an unterminated literal consumes the
    /// rest of the input.
    fn skip_string_literal(chars: &mut std::iter::Peekable<std::str::Chars<'_>>) {
        let mut escaped = false;
        while let Some(c) = chars.next() {
            if c == '\'' && !escaped {
                if chars.peek() == Some(&'\'') {
                    // SQL-standard doubled quote: an escaped `'`, not the end of the literal.
                    chars.next();
                    continue;
                }
                return;
            }
            escaped = c == '\\' && !escaped;
        }
    }

    /// Collects the sorted, de-duplicated table names that follow `FROM`, `JOIN`, `INTO` or
    /// `UPDATE` (case-insensitive) in `sql`.
    fn extract_tables(&self, sql: &str) -> Vec<String> {
        const TABLE_KEYWORDS: [&str; 4] = ["FROM", "JOIN", "INTO", "UPDATE"];

        let words: Vec<&str> = sql.split_whitespace().collect();
        let mut tables: Vec<String> = words
            .windows(2)
            .filter(|pair| {
                let keyword = pair[0].trim_end_matches(',');
                TABLE_KEYWORDS
                    .iter()
                    .any(|candidate| keyword.eq_ignore_ascii_case(candidate))
            })
            .filter_map(|pair| Self::table_name(pair[1]))
            .collect();
        tables.sort_unstable();
        tables.dedup();
        tables
    }

    /// Extracts a bare table name from the token that follows a table keyword.
    ///
    /// Handles `users(name)`, `users,orders`, `(users)`, `users;` and `schema.users`. Returns
    /// `None` for empty tokens and SQL keywords (e.g. the `SELECT` of a sub-query).
    fn table_name(token: &str) -> Option<String> {
        let token = token.trim_start_matches(|c: char| matches!(c, '(' | ')' | ',' | ';'));
        let end = token
            .find(|c: char| c == '(' || c == ',')
            .unwrap_or(token.len());
        let token = token[..end].trim_end_matches(|c: char| matches!(c, ')' | ';'));
        if token.is_empty() || is_keyword(token) {
            return None;
        }
        let name = token.rsplit('.').next().unwrap_or(token);
        (!name.is_empty()).then(|| name.to_string())
    }

    /// Returns `true` if the projection of the first `SELECT` in `sql` starts with `*`,
    /// optionally preceded by `DISTINCT` or `ALL`.
    fn has_select_star(&self, sql: &str) -> bool {
        let upper = sql.to_uppercase();
        let select_pos = match upper.find("SELECT") {
            Some(pos) => pos,
            None => return false,
        };
        let projection = upper[select_pos + "SELECT".len()..].trim_start();
        if projection.starts_with('*') {
            return true;
        }
        ["DISTINCT", "ALL"].iter().any(|&quantifier| {
            matches!(
                projection.strip_prefix(quantifier),
                Some(rest) if rest.trim_start().starts_with('*')
            )
        })
    }

    /// Renders a parsed query back to SQL; currently this is the trimmed original text.
    pub fn to_sql(&self, parsed: &ParsedQuery) -> String {
        parsed.original.clone()
    }
}
impl Default for SqlParser {
fn default() -> Self {
Self::new()
}
}
/// Result of `SqlParser::parse`: the query text plus flags derived from a keyword scan.
///
/// All fields are public, so a hand-built value need not satisfy the invariants described
/// below; they hold for values produced by `SqlParser::parse`.
#[derive(Debug, Clone)]
pub struct ParsedQuery {
/// The input with leading/trailing whitespace removed.
pub original: String,
/// Output of `SqlParser::normalize` for `original` (literals replaced by `?`).
pub normalized: String,
/// Sorted, de-duplicated table names that follow FROM / JOIN / INTO / UPDATE,
/// with any `schema.` prefix removed.
pub tables: Vec<String>,
/// SELECT statement whose projection starts with `*` (optionally after DISTINCT / ALL).
pub has_select_star: bool,
/// A LIMIT keyword was detected.
pub has_limit: bool,
/// A WHERE keyword was detected.
pub has_where: bool,
/// First word is SELECT.
pub is_select: bool,
/// First word is INSERT.
pub is_insert: bool,
/// First word is UPDATE.
pub is_update: bool,
/// First word is DELETE.
pub is_delete: bool,
/// First word is CREATE, ALTER, DROP or TRUNCATE.
pub is_ddl: bool,
}
impl ParsedQuery {
    /// Hashes the upper-cased normalized text, so queries that differ only in literal values
    /// or letter case share a fingerprint.
    ///
    /// Uses [`DefaultHasher`], whose algorithm is not guaranteed to stay the same across Rust
    /// releases. Do not persist fingerprints or compare them between builds.
    pub fn fingerprint(&self) -> u64 {
        let canonical = self.normalized.to_uppercase();
        let mut state = DefaultHasher::new();
        canonical.hash(&mut state);
        state.finish()
    }

    /// `true` for INSERT, UPDATE and DELETE statements and for DDL.
    pub fn is_write(&self) -> bool {
        [self.is_insert, self.is_update, self.is_delete, self.is_ddl]
            .iter()
            .any(|&flag| flag)
    }

    /// `true` for SELECT statements that are not also flagged as DDL.
    pub fn is_read(&self) -> bool {
        !self.is_ddl && self.is_select
    }
}
/// Coarse statement kind, determined by `SqlStatement::from_sql` from the leading keyword.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlStatement {
/// `SELECT …`
Select,
/// `INSERT …`
Insert,
/// `UPDATE …`
Update,
/// `DELETE …`
Delete,
/// `CREATE …`
Create,
/// `ALTER …`
Alter,
/// `DROP …`
Drop,
/// `TRUNCATE …`
Truncate,
/// Any other (or empty) leading word, e.g. `WITH`.
Other,
}
impl SqlStatement {
    /// Classifies `sql` by its first whitespace-separated word, compared after upper-casing.
    ///
    /// Unrecognised leading words and empty input map to [`SqlStatement::Other`].
    pub fn from_sql(sql: &str) -> Self {
        let keyword = sql
            .split_whitespace()
            .next()
            .map(str::to_uppercase)
            .unwrap_or_default();
        match keyword.as_str() {
            "SELECT" => Self::Select,
            "INSERT" => Self::Insert,
            "UPDATE" => Self::Update,
            "DELETE" => Self::Delete,
            "CREATE" => Self::Create,
            "ALTER" => Self::Alter,
            "DROP" => Self::Drop,
            "TRUNCATE" => Self::Truncate,
            _ => Self::Other,
        }
    }

    /// `true` for data-modifying statements (INSERT/UPDATE/DELETE) and DDL.
    pub fn is_write(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::Select | Self::Other => false,
            Self::Insert
            | Self::Update
            | Self::Delete
            | Self::Create
            | Self::Alter
            | Self::Drop
            | Self::Truncate => true,
        }
    }
}
/// Error type returned by `SqlParser::parse`.
#[derive(Debug, Clone)]
pub enum ParseError {
/// The input was empty or whitespace-only.
EmptyQuery,
/// Displayed as `Invalid syntax: <message>`.
/// NOTE(review): not constructed in this module; presumably produced by callers elsewhere.
InvalidSyntax(String),
/// Displayed as `Unsupported statement: <statement>`.
/// NOTE(review): not constructed in this module; presumably produced by callers elsewhere.
UnsupportedStatement(String),
}
impl std::fmt::Display for ParseError {
    /// Writes a short human-readable description of the failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::EmptyQuery => f.write_str("Empty query"),
            Self::InvalidSyntax(detail) => write!(f, "Invalid syntax: {detail}"),
            Self::UnsupportedStatement(statement) => {
                write!(f, "Unsupported statement: {statement}")
            }
        }
    }
}
// Marker impl: the message comes from the `Display` impl, and `source()` keeps its default
// (`None`) because `ParseError` wraps no underlying error.
impl std::error::Error for ParseError {}
impl From<ParseError> for super::RewriteError {
    /// Wraps a parse failure as the parent module's `ParseError` variant, keeping only its
    /// `Display` text.
    fn from(err: ParseError) -> Self {
        let message = err.to_string();
        Self::ParseError(message)
    }
}
/// Returns `true` if `word`, after Unicode upper-casing, is one of the SQL reserved words this
/// module recognises.
///
/// Table extraction uses it to reject keywords that appear where a table name was expected.
fn is_keyword(word: &str) -> bool {
    const KEYWORDS: &[&str] = &[
        "SELECT", "FROM", "WHERE", "AND", "OR", "NOT", "INSERT", "INTO", "VALUES", "UPDATE",
        "SET", "DELETE", "CREATE", "ALTER", "DROP", "TABLE", "INDEX", "VIEW", "JOIN", "LEFT",
        "RIGHT", "INNER", "OUTER", "CROSS", "ON", "GROUP", "BY", "ORDER", "HAVING", "LIMIT",
        "OFFSET", "UNION", "INTERSECT", "EXCEPT", "AS", "DISTINCT", "ALL", "NULL", "TRUE",
        "FALSE", "CASE", "WHEN", "THEN", "ELSE", "END", "EXISTS", "IN", "BETWEEN", "LIKE", "IS",
        "ASC", "DESC",
    ];
    let upper = word.to_uppercase();
    KEYWORDS.contains(&upper.as_str())
}
#[cfg(test)]
mod tests {
    //! Unit tests for the keyword-based SQL analyzer.
    use super::*;

    #[test]
    fn test_parse_select() {
        let parser = SqlParser::new();
        let parsed = parser.parse("SELECT * FROM users WHERE id = 1").unwrap();
        assert!(parsed.is_select);
        assert!(parsed.has_select_star);
        assert!(parsed.has_where);
        assert!(!parsed.has_limit);
        assert!(parsed.tables.contains(&"users".to_string()));
    }

    #[test]
    fn test_parse_insert() {
        let parser = SqlParser::new();
        let parsed = parser
            .parse("INSERT INTO users (name) VALUES ('test')")
            .unwrap();
        assert!(parsed.is_insert);
        assert!(parsed.tables.contains(&"users".to_string()));
    }

    #[test]
    fn test_normalize() {
        let parser = SqlParser::new();
        let normalized = parser.normalize("SELECT * FROM users WHERE id = 123 AND name = 'test'");
        assert!(normalized.contains("id = ?"));
        assert!(normalized.contains("name = ?"));
    }

    #[test]
    fn test_fingerprint() {
        let parser = SqlParser::new();
        let q1 = parser.parse("SELECT * FROM users WHERE id = 1").unwrap();
        let q2 = parser.parse("SELECT * FROM users WHERE id = 2").unwrap();
        let q3 = parser.parse("SELECT * FROM orders WHERE id = 1").unwrap();
        // Literal values are normalized away; table names are not.
        assert_eq!(q1.fingerprint(), q2.fingerprint());
        assert_ne!(q1.fingerprint(), q3.fingerprint());
    }

    #[test]
    fn test_extract_tables() {
        let parser = SqlParser::new();
        let parsed = parser
            .parse("SELECT u.*, o.total FROM users u JOIN orders o ON u.id = o.user_id")
            .unwrap();
        // Both real tables are reported by name; the aliases `u` / `o` are not tables.
        assert!(parsed.tables.contains(&"users".to_string()));
        assert!(parsed.tables.contains(&"orders".to_string()));
        assert!(!parsed.tables.contains(&"u".to_string()));
    }

    #[test]
    fn test_has_select_star() {
        let parser = SqlParser::new();
        assert!(parser.has_select_star("SELECT * FROM users"));
        assert!(parser.has_select_star("SELECT DISTINCT * FROM users"));
        assert!(parser.has_select_star("SELECT ALL * FROM users"));
        assert!(parser.has_select_star("select * from users"));
        assert!(!parser.has_select_star("SELECT id, name FROM users"));
    }

    #[test]
    fn test_empty_query() {
        let parser = SqlParser::new();
        assert!(matches!(parser.parse(""), Err(ParseError::EmptyQuery)));
        assert!(matches!(parser.parse(" "), Err(ParseError::EmptyQuery)));
    }

    #[test]
    fn test_parse_error_display() {
        assert_eq!(ParseError::EmptyQuery.to_string(), "Empty query");
    }

    #[test]
    fn test_read_write_classification() {
        let parser = SqlParser::new();
        let select = parser.parse("SELECT id FROM users").unwrap();
        assert!(select.is_read() && !select.is_write());
        let delete = parser.parse("DELETE FROM users WHERE id = 1").unwrap();
        assert!(delete.is_write() && !delete.is_read());
        let ddl = parser.parse("DROP TABLE users").unwrap();
        assert!(ddl.is_ddl && ddl.is_write());
        assert!(SqlStatement::from_sql("truncate users").is_write());
        assert!(!SqlStatement::from_sql("SELECT 1").is_write());
        assert_eq!(SqlStatement::from_sql(""), SqlStatement::Other);
    }

    #[test]
    fn test_sql_statement_type() {
        assert_eq!(
            SqlStatement::from_sql("SELECT * FROM users"),
            SqlStatement::Select
        );
        assert_eq!(
            SqlStatement::from_sql("INSERT INTO users"),
            SqlStatement::Insert
        );
        assert_eq!(
            SqlStatement::from_sql("UPDATE users SET"),
            SqlStatement::Update
        );
        assert_eq!(
            SqlStatement::from_sql("DELETE FROM users"),
            SqlStatement::Delete
        );
        assert_eq!(
            SqlStatement::from_sql("CREATE TABLE users"),
            SqlStatement::Create
        );
    }
}