1use std::collections::hash_map::DefaultHasher;
6use std::hash::{Hash, Hasher};
7
8pub struct SqlParser {
10 dialect: SqlDialect,
12}
13
/// SQL dialects a parser can be configured for.
///
/// `PostgreSQL` is the default. The enum is a plain fieldless `Copy` type, so
/// it also derives equality and hashing to allow comparisons and use as a map
/// or set key.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum SqlDialect {
    /// PostgreSQL (the default dialect).
    #[default]
    PostgreSQL,
    /// MySQL.
    MySQL,
    /// SQLite.
    SQLite,
}
22
23impl SqlParser {
24 pub fn new() -> Self {
26 Self {
27 dialect: SqlDialect::PostgreSQL,
28 }
29 }
30
31 pub fn with_dialect(dialect: SqlDialect) -> Self {
33 Self { dialect }
34 }
35
36 pub fn parse(&self, sql: &str) -> Result<ParsedQuery, ParseError> {
38 let trimmed = sql.trim();
39
40 if trimmed.is_empty() {
41 return Err(ParseError::EmptyQuery);
42 }
43
44 let upper = trimmed.to_uppercase();
45 let first_word = upper.split_whitespace().next().unwrap_or("");
46
47 let is_select = first_word == "SELECT";
48 let is_insert = first_word == "INSERT";
49 let is_update = first_word == "UPDATE";
50 let is_delete = first_word == "DELETE";
51 let is_ddl = matches!(first_word, "CREATE" | "ALTER" | "DROP" | "TRUNCATE");
52
53 let tables = self.extract_tables(trimmed);
54 let has_select_star = is_select && self.has_select_star(trimmed);
55 let has_limit = upper.contains(" LIMIT ");
56 let has_where = upper.contains(" WHERE ");
57
58 let normalized = self.normalize(trimmed);
59
60 Ok(ParsedQuery {
61 original: trimmed.to_string(),
62 normalized,
63 tables,
64 has_select_star,
65 has_limit,
66 has_where,
67 is_select,
68 is_insert,
69 is_update,
70 is_delete,
71 is_ddl,
72 })
73 }
74
75 pub fn normalize(&self, sql: &str) -> String {
77 let mut result = String::with_capacity(sql.len());
78 let mut chars = sql.chars().peekable();
79
80 while let Some(c) = chars.next() {
81 match c {
82 '\'' => {
84 result.push('?');
85 let mut escaped = false;
86 for inner in chars.by_ref() {
87 if inner == '\'' && !escaped {
88 break;
89 }
90 escaped = inner == '\\' && !escaped;
91 }
92 }
93 '"' => {
95 result.push(c);
96 for inner in chars.by_ref() {
97 result.push(inner);
98 if inner == '"' {
99 break;
100 }
101 }
102 }
103 '0'..='9' => {
105 result.push('?');
106 while chars.peek().map(|c| c.is_ascii_digit() || *c == '.').unwrap_or(false) {
107 chars.next();
108 }
109 }
110 '$' => {
112 result.push('?');
113 while chars.peek().map(|c| c.is_ascii_digit()).unwrap_or(false) {
114 chars.next();
115 }
116 }
117 _ => result.push(c),
119 }
120 }
121
122 let mut prev_space = false;
124 result.chars().filter(|&c| {
125 if c.is_whitespace() {
126 if prev_space {
127 return false;
128 }
129 prev_space = true;
130 } else {
131 prev_space = false;
132 }
133 true
134 }).collect::<String>().trim().to_string()
135 }
136
137 fn extract_tables(&self, sql: &str) -> Vec<String> {
139 let mut tables = Vec::new();
140 let upper = sql.to_uppercase();
141 let words: Vec<&str> = sql.split_whitespace().collect();
142 let upper_words: Vec<&str> = upper.split_whitespace().collect();
143
144 let table_keywords = ["FROM", "JOIN", "INTO", "UPDATE"];
146
147 for (i, word) in upper_words.iter().enumerate() {
148 if table_keywords.contains(&word.trim_end_matches(',')) {
149 if let Some(table) = words.get(i + 1) {
150 let table = table.trim_matches(|c| c == ',' || c == '(' || c == ')' || c == ';');
151 if !table.is_empty() && !is_keyword(table) {
152 let table_name = table.split('.').last().unwrap_or(table);
154 tables.push(table_name.to_string());
155 }
156 }
157 }
158 }
159
160 tables.sort();
162 tables.dedup();
163 tables
164 }
165
166 fn has_select_star(&self, sql: &str) -> bool {
168 let upper = sql.to_uppercase();
169
170 if let Some(select_pos) = upper.find("SELECT") {
172 let after_select = &upper[select_pos + 6..];
173 let trimmed = after_select.trim_start();
174
175 if trimmed.starts_with("*") {
177 return true;
178 }
179 if trimmed.starts_with("DISTINCT") {
180 let after_distinct = trimmed[8..].trim_start();
181 if after_distinct.starts_with("*") {
182 return true;
183 }
184 }
185 if trimmed.starts_with("ALL") {
186 let after_all = trimmed[3..].trim_start();
187 if after_all.starts_with("*") {
188 return true;
189 }
190 }
191 }
192
193 false
194 }
195
196 pub fn to_sql(&self, parsed: &ParsedQuery) -> String {
198 parsed.original.clone()
201 }
202}
203
204impl Default for SqlParser {
205 fn default() -> Self {
206 Self::new()
207 }
208}
209
/// Summary of a single analysed SQL statement.
#[derive(Debug, Clone)]
pub struct ParsedQuery {
    /// The statement text as supplied to the parser.
    pub original: String,

    /// The normalised statement text; this is what `fingerprint` hashes.
    pub normalized: String,

    /// Names of the tables found in the statement.
    pub tables: Vec<String>,

    /// Whether the statement selects `*`.
    pub has_select_star: bool,

    /// Whether the statement has a `LIMIT` clause.
    pub has_limit: bool,

    /// Whether the statement has a `WHERE` clause.
    pub has_where: bool,

    /// Whether the statement is a `SELECT`.
    pub is_select: bool,

    /// Whether the statement is an `INSERT`.
    pub is_insert: bool,

    /// Whether the statement is an `UPDATE`.
    pub is_update: bool,

    /// Whether the statement is a `DELETE`.
    pub is_delete: bool,

    /// Whether the statement is a DDL statement.
    pub is_ddl: bool,
}

impl ParsedQuery {
    /// Hashes the upper-cased `normalized` text with `DefaultHasher`.
    ///
    /// Statements whose normalised text differs only in letter case produce
    /// the same fingerprint.
    pub fn fingerprint(&self) -> u64 {
        let canonical = self.normalized.to_uppercase();
        let mut state = DefaultHasher::new();
        canonical.hash(&mut state);
        state.finish()
    }

    /// Returns `true` if any of the insert, update, delete or DDL flags is set.
    pub fn is_write(&self) -> bool {
        [self.is_insert, self.is_update, self.is_delete, self.is_ddl]
            .iter()
            .any(|&flag| flag)
    }

    /// Returns `true` for a `SELECT` that is not also flagged as DDL.
    pub fn is_read(&self) -> bool {
        !self.is_ddl && self.is_select
    }
}
265
/// Coarse classification of a SQL statement by its leading keyword.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlStatement {
    /// `SELECT ...`
    Select,
    /// `INSERT ...`
    Insert,
    /// `UPDATE ...`
    Update,
    /// `DELETE ...`
    Delete,
    /// `CREATE ...`
    Create,
    /// `ALTER ...`
    Alter,
    /// `DROP ...`
    Drop,
    /// `TRUNCATE ...`
    Truncate,
    /// Anything else, including empty input.
    Other,
}

impl SqlStatement {
    /// Classifies `sql` by its first keyword, case-insensitively.
    ///
    /// Leading whitespace and opening parentheses are skipped, and the
    /// keyword ends at the first character that is not an ASCII letter,
    /// digit or underscore. `select*from t` and `(SELECT 1)` are therefore
    /// both `Select`, while `DROP_X` is `Other`.
    pub fn from_sql(sql: &str) -> Self {
        let start = sql.trim_start_matches(|c: char| c.is_whitespace() || c == '(');
        let end = start
            .find(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))
            .unwrap_or(start.len());

        match start[..end].to_ascii_uppercase().as_str() {
            "SELECT" => Self::Select,
            "INSERT" => Self::Insert,
            "UPDATE" => Self::Update,
            "DELETE" => Self::Delete,
            "CREATE" => Self::Create,
            "ALTER" => Self::Alter,
            "DROP" => Self::Drop,
            "TRUNCATE" => Self::Truncate,
            _ => Self::Other,
        }
    }

    /// Returns `true` for every variant except `Select` and `Other`.
    pub fn is_write(&self) -> bool {
        matches!(
            self,
            Self::Insert
                | Self::Update
                | Self::Delete
                | Self::Create
                | Self::Alter
                | Self::Drop
                | Self::Truncate
        )
    }
}
302
/// Errors reported while analysing a SQL statement.
#[derive(Debug, Clone)]
pub enum ParseError {
    /// The input contained no statement text.
    EmptyQuery,

    /// The statement could not be understood; carries a description.
    InvalidSyntax(String),

    /// The statement kind is not supported; carries the statement.
    UnsupportedStatement(String),
}

impl std::fmt::Display for ParseError {
    /// Writes a short human-readable message for the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ParseError::EmptyQuery => f.write_str("Empty query"),
            ParseError::InvalidSyntax(detail) => {
                f.write_fmt(format_args!("Invalid syntax: {}", detail))
            }
            ParseError::UnsupportedStatement(statement) => {
                f.write_fmt(format_args!("Unsupported statement: {}", statement))
            }
        }
    }
}

impl std::error::Error for ParseError {}
327
328impl From<ParseError> for super::RewriteError {
329 fn from(e: ParseError) -> Self {
330 super::RewriteError::ParseError(e.to_string())
331 }
332}
333
/// Reports whether `word` is one of the SQL keywords this module recognises.
///
/// The comparison is ASCII case-insensitive, so it does not allocate and a
/// non-ASCII identifier is never folded onto a keyword. The list includes
/// `TRUNCATE` (treated as a statement keyword elsewhere in this file) and the
/// join/clause words `FULL`, `NATURAL`, `USING`, `LATERAL`, `ONLY`, `WITH`
/// and `RETURNING`.
fn is_keyword(word: &str) -> bool {
    const KEYWORDS: &[&str] = &[
        "SELECT", "FROM", "WHERE", "AND", "OR", "NOT",
        "INSERT", "INTO", "VALUES", "UPDATE", "SET", "DELETE",
        "CREATE", "ALTER", "DROP", "TRUNCATE", "TABLE", "INDEX", "VIEW",
        "JOIN", "LEFT", "RIGHT", "FULL", "INNER", "OUTER", "CROSS", "NATURAL",
        "LATERAL", "ON", "USING", "ONLY",
        "GROUP", "BY", "ORDER", "HAVING", "LIMIT", "OFFSET",
        "UNION", "INTERSECT", "EXCEPT", "AS", "DISTINCT", "ALL",
        "WITH", "RETURNING",
        "NULL", "TRUE", "FALSE", "CASE", "WHEN", "THEN", "ELSE", "END",
        "EXISTS", "IN", "BETWEEN", "LIKE", "IS", "ASC", "DESC",
    ];

    KEYWORDS
        .iter()
        .any(|keyword| keyword.eq_ignore_ascii_case(word))
}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// A simple `SELECT *` sets the select, star and where flags and reports its table.
    #[test]
    fn test_parse_select() {
        let parser = SqlParser::new();
        let parsed = parser.parse("SELECT * FROM users WHERE id = 1").unwrap();

        assert!(parsed.is_select);
        assert!(parsed.has_select_star);
        assert!(parsed.has_where);
        assert!(!parsed.has_limit);
        assert!(parsed.is_read());
        assert!(!parsed.is_write());
        assert_eq!(parsed.tables, vec!["users".to_string()]);
    }

    /// An `INSERT` is a write and reports its target table.
    #[test]
    fn test_parse_insert() {
        let parser = SqlParser::new();
        let parsed = parser.parse("INSERT INTO users (name) VALUES ('test')").unwrap();

        assert!(parsed.is_insert);
        assert!(parsed.is_write());
        assert!(!parsed.is_read());
        assert!(parsed.tables.contains(&"users".to_string()));
    }

    /// Numeric and string literals are replaced by `?`.
    #[test]
    fn test_normalize() {
        let parser = SqlParser::new();

        let normalized = parser.normalize("SELECT * FROM users WHERE id = 123 AND name = 'test'");
        assert!(normalized.contains("id = ?"));
        assert!(normalized.contains("name = ?"));
    }

    /// Fingerprints ignore literal values but distinguish tables.
    #[test]
    fn test_fingerprint() {
        let parser = SqlParser::new();

        let q1 = parser.parse("SELECT * FROM users WHERE id = 1").unwrap();
        let q2 = parser.parse("SELECT * FROM users WHERE id = 2").unwrap();
        let q3 = parser.parse("SELECT * FROM orders WHERE id = 1").unwrap();

        assert_eq!(q1.fingerprint(), q2.fingerprint());
        assert_ne!(q1.fingerprint(), q3.fingerprint());
    }

    /// Both joined tables are reported, sorted, and aliases are not mistaken for tables.
    #[test]
    fn test_extract_tables() {
        let parser = SqlParser::new();

        let parsed = parser.parse(
            "SELECT u.*, o.total FROM users u JOIN orders o ON u.id = o.user_id"
        ).unwrap();

        assert_eq!(parsed.tables, vec!["orders".to_string(), "users".to_string()]);
    }

    /// `*` is detected directly after `SELECT` or after `DISTINCT`.
    #[test]
    fn test_has_select_star() {
        let parser = SqlParser::new();

        assert!(parser.has_select_star("SELECT * FROM users"));
        assert!(parser.has_select_star("SELECT DISTINCT * FROM users"));
        assert!(!parser.has_select_star("SELECT id, name FROM users"));
    }

    /// Empty and whitespace-only input is rejected.
    #[test]
    fn test_empty_query() {
        let parser = SqlParser::new();
        assert!(matches!(parser.parse(""), Err(ParseError::EmptyQuery)));
        assert!(matches!(parser.parse("   "), Err(ParseError::EmptyQuery)));
    }

    /// Every statement kind is recognised from its leading keyword, in any case.
    #[test]
    fn test_sql_statement_type() {
        assert_eq!(SqlStatement::from_sql("SELECT * FROM users"), SqlStatement::Select);
        assert_eq!(SqlStatement::from_sql("INSERT INTO users"), SqlStatement::Insert);
        assert_eq!(SqlStatement::from_sql("UPDATE users SET"), SqlStatement::Update);
        assert_eq!(SqlStatement::from_sql("DELETE FROM users"), SqlStatement::Delete);
        assert_eq!(SqlStatement::from_sql("CREATE TABLE users"), SqlStatement::Create);
        assert_eq!(SqlStatement::from_sql("ALTER TABLE users"), SqlStatement::Alter);
        assert_eq!(SqlStatement::from_sql("DROP TABLE users"), SqlStatement::Drop);
        assert_eq!(SqlStatement::from_sql("TRUNCATE TABLE users"), SqlStatement::Truncate);
        assert_eq!(SqlStatement::from_sql("EXPLAIN SELECT 1"), SqlStatement::Other);
        assert_eq!(SqlStatement::from_sql("select 1"), SqlStatement::Select);
    }
}