1use std::collections::hash_map::DefaultHasher;
6use std::hash::{Hash, Hasher};
7
8pub struct SqlParser {
10 #[allow(dead_code)]
12 dialect: SqlDialect,
13}
14
/// SQL dialect a [`SqlParser`] is configured for.
///
/// `PostgreSQL` is the default variant.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum SqlDialect {
    /// PostgreSQL (the default).
    #[default]
    PostgreSQL,
    /// MySQL.
    MySQL,
    /// SQLite.
    SQLite,
}
23
24impl SqlParser {
25 pub fn new() -> Self {
27 Self {
28 dialect: SqlDialect::PostgreSQL,
29 }
30 }
31
32 pub fn with_dialect(dialect: SqlDialect) -> Self {
34 Self { dialect }
35 }
36
37 pub fn parse(&self, sql: &str) -> Result<ParsedQuery, ParseError> {
39 let trimmed = sql.trim();
40
41 if trimmed.is_empty() {
42 return Err(ParseError::EmptyQuery);
43 }
44
45 let upper = trimmed.to_uppercase();
46 let first_word = upper.split_whitespace().next().unwrap_or("");
47
48 let is_select = first_word == "SELECT";
49 let is_insert = first_word == "INSERT";
50 let is_update = first_word == "UPDATE";
51 let is_delete = first_word == "DELETE";
52 let is_ddl = matches!(first_word, "CREATE" | "ALTER" | "DROP" | "TRUNCATE");
53
54 let tables = self.extract_tables(trimmed);
55 let has_select_star = is_select && self.has_select_star(trimmed);
56 let has_limit = upper.contains(" LIMIT ");
57 let has_where = upper.contains(" WHERE ");
58
59 let normalized = self.normalize(trimmed);
60
61 Ok(ParsedQuery {
62 original: trimmed.to_string(),
63 normalized,
64 tables,
65 has_select_star,
66 has_limit,
67 has_where,
68 is_select,
69 is_insert,
70 is_update,
71 is_delete,
72 is_ddl,
73 })
74 }
75
76 pub fn normalize(&self, sql: &str) -> String {
78 let mut result = String::with_capacity(sql.len());
79 let mut chars = sql.chars().peekable();
80
81 while let Some(c) = chars.next() {
82 match c {
83 '\'' => {
85 result.push('?');
86 let mut escaped = false;
87 for inner in chars.by_ref() {
88 if inner == '\'' && !escaped {
89 break;
90 }
91 escaped = inner == '\\' && !escaped;
92 }
93 }
94 '"' => {
96 result.push(c);
97 for inner in chars.by_ref() {
98 result.push(inner);
99 if inner == '"' {
100 break;
101 }
102 }
103 }
104 '0'..='9' => {
106 result.push('?');
107 while chars
108 .peek()
109 .map(|c| c.is_ascii_digit() || *c == '.')
110 .unwrap_or(false)
111 {
112 chars.next();
113 }
114 }
115 '$' => {
117 result.push('?');
118 while chars.peek().map(|c| c.is_ascii_digit()).unwrap_or(false) {
119 chars.next();
120 }
121 }
122 _ => result.push(c),
124 }
125 }
126
127 let mut prev_space = false;
129 result
130 .chars()
131 .filter(|&c| {
132 if c.is_whitespace() {
133 if prev_space {
134 return false;
135 }
136 prev_space = true;
137 } else {
138 prev_space = false;
139 }
140 true
141 })
142 .collect::<String>()
143 .trim()
144 .to_string()
145 }
146
147 fn extract_tables(&self, sql: &str) -> Vec<String> {
149 let mut tables = Vec::new();
150 let upper = sql.to_uppercase();
151 let words: Vec<&str> = sql.split_whitespace().collect();
152 let upper_words: Vec<&str> = upper.split_whitespace().collect();
153
154 let table_keywords = ["FROM", "JOIN", "INTO", "UPDATE"];
156
157 for (i, word) in upper_words.iter().enumerate() {
158 if table_keywords.contains(&word.trim_end_matches(',')) {
159 if let Some(table) = words.get(i + 1) {
160 let table =
161 table.trim_matches(|c| c == ',' || c == '(' || c == ')' || c == ';');
162 if !table.is_empty() && !is_keyword(table) {
163 let table_name = table.split('.').next_back().unwrap_or(table);
165 tables.push(table_name.to_string());
166 }
167 }
168 }
169 }
170
171 tables.sort();
173 tables.dedup();
174 tables
175 }
176
177 fn has_select_star(&self, sql: &str) -> bool {
179 let upper = sql.to_uppercase();
180
181 if let Some(select_pos) = upper.find("SELECT") {
183 let after_select = &upper[select_pos + 6..];
184 let trimmed = after_select.trim_start();
185
186 if trimmed.starts_with("*") {
188 return true;
189 }
190 if let Some(after_distinct) = trimmed.strip_prefix("DISTINCT") {
191 if after_distinct.trim_start().starts_with('*') {
192 return true;
193 }
194 }
195 if let Some(after_all) = trimmed.strip_prefix("ALL") {
196 if after_all.trim_start().starts_with('*') {
197 return true;
198 }
199 }
200 }
201
202 false
203 }
204
205 pub fn to_sql(&self, parsed: &ParsedQuery) -> String {
207 parsed.original.clone()
210 }
211}
212
213impl Default for SqlParser {
214 fn default() -> Self {
215 Self::new()
216 }
217}
218
/// Result of analyzing one SQL statement with [`SqlParser::parse`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParsedQuery {
    /// The statement text as given, with surrounding whitespace trimmed.
    pub original: String,

    /// The statement with literal values replaced by `?` placeholders.
    pub normalized: String,

    /// Table names found in the statement, sorted and without duplicates.
    pub tables: Vec<String>,

    /// Whether a `SELECT` statement selects `*`.
    pub has_select_star: bool,

    /// Whether a `LIMIT` keyword was detected.
    pub has_limit: bool,

    /// Whether a `WHERE` keyword was detected.
    pub has_where: bool,

    /// Whether the statement starts with `SELECT`.
    pub is_select: bool,

    /// Whether the statement starts with `INSERT`.
    pub is_insert: bool,

    /// Whether the statement starts with `UPDATE`.
    pub is_update: bool,

    /// Whether the statement starts with `DELETE`.
    pub is_delete: bool,

    /// Whether the statement starts with `CREATE`, `ALTER`, `DROP` or
    /// `TRUNCATE`.
    pub is_ddl: bool,
}
255
256impl ParsedQuery {
257 pub fn fingerprint(&self) -> u64 {
259 let mut hasher = DefaultHasher::new();
260 self.normalized.to_uppercase().hash(&mut hasher);
261 hasher.finish()
262 }
263
264 pub fn is_write(&self) -> bool {
266 self.is_insert || self.is_update || self.is_delete || self.is_ddl
267 }
268
269 pub fn is_read(&self) -> bool {
271 self.is_select && !self.is_ddl
272 }
273}
274
/// Kind of SQL statement, as determined by its leading keyword.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SqlStatement {
    /// `SELECT`
    Select,
    /// `INSERT`
    Insert,
    /// `UPDATE`
    Update,
    /// `DELETE`
    Delete,
    /// `CREATE`
    Create,
    /// `ALTER`
    Alter,
    /// `DROP`
    Drop,
    /// `TRUNCATE`
    Truncate,
    /// Any statement whose leading keyword is not one of the above.
    Other,
}
288
289impl SqlStatement {
290 pub fn from_sql(sql: &str) -> Self {
292 let first_word = sql.split_whitespace().next().unwrap_or("");
293 match first_word.to_uppercase().as_str() {
294 "SELECT" => Self::Select,
295 "INSERT" => Self::Insert,
296 "UPDATE" => Self::Update,
297 "DELETE" => Self::Delete,
298 "CREATE" => Self::Create,
299 "ALTER" => Self::Alter,
300 "DROP" => Self::Drop,
301 "TRUNCATE" => Self::Truncate,
302 _ => Self::Other,
303 }
304 }
305
306 pub fn is_write(&self) -> bool {
308 matches!(
309 self,
310 Self::Insert
311 | Self::Update
312 | Self::Delete
313 | Self::Create
314 | Self::Alter
315 | Self::Drop
316 | Self::Truncate
317 )
318 }
319}
320
/// Errors reported while analyzing SQL text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParseError {
    /// The input was empty or contained only whitespace.
    EmptyQuery,

    /// The input could not be understood; the payload describes the problem.
    InvalidSyntax(String),

    /// The statement kind is not supported; the payload names the statement.
    UnsupportedStatement(String),
}
333
334impl std::fmt::Display for ParseError {
335 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
336 match self {
337 Self::EmptyQuery => write!(f, "Empty query"),
338 Self::InvalidSyntax(msg) => write!(f, "Invalid syntax: {}", msg),
339 Self::UnsupportedStatement(stmt) => write!(f, "Unsupported statement: {}", stmt),
340 }
341 }
342}
343
// Marker impl: lets `ParseError` be used wherever a `std::error::Error` is
// expected; all methods keep their default implementations.
impl std::error::Error for ParseError {}
345
346impl From<ParseError> for super::RewriteError {
347 fn from(e: ParseError) -> Self {
348 super::RewriteError::ParseError(e.to_string())
349 }
350}
351
/// Returns `true` when `word`, compared case-insensitively, is one of the SQL
/// keywords this module recognizes.
fn is_keyword(word: &str) -> bool {
    const KEYWORDS: &[&str] = &[
        "SELECT",
        "FROM",
        "WHERE",
        "AND",
        "OR",
        "NOT",
        "INSERT",
        "INTO",
        "VALUES",
        "UPDATE",
        "SET",
        "DELETE",
        "CREATE",
        "ALTER",
        "DROP",
        "TABLE",
        "INDEX",
        "VIEW",
        "JOIN",
        "LEFT",
        "RIGHT",
        "INNER",
        "OUTER",
        "CROSS",
        "ON",
        "GROUP",
        "BY",
        "ORDER",
        "HAVING",
        "LIMIT",
        "OFFSET",
        "UNION",
        "INTERSECT",
        "EXCEPT",
        "AS",
        "DISTINCT",
        "ALL",
        "NULL",
        "TRUE",
        "FALSE",
        "CASE",
        "WHEN",
        "THEN",
        "ELSE",
        "END",
        "EXISTS",
        "IN",
        "BETWEEN",
        "LIKE",
        "IS",
        "ASC",
        "DESC",
    ];

    let candidate = word.to_uppercase();
    KEYWORDS.contains(&candidate.as_str())
}
411
#[cfg(test)]
mod tests {
    use super::*;

    /// A plain `SELECT *` with a `WHERE` clause sets the expected flags and
    /// reports exactly the one table.
    #[test]
    fn test_parse_select() {
        let parser = SqlParser::new();
        let parsed = parser.parse("SELECT * FROM users WHERE id = 1").unwrap();

        assert!(parsed.is_select);
        assert!(parsed.is_read());
        assert!(!parsed.is_write());
        assert!(parsed.has_select_star);
        assert!(parsed.has_where);
        assert!(!parsed.has_limit);
        assert_eq!(parsed.tables, vec!["users".to_string()]);
    }

    /// `INSERT INTO` is classified as a write against the target table.
    #[test]
    fn test_parse_insert() {
        let parser = SqlParser::new();
        let parsed = parser
            .parse("INSERT INTO users (name) VALUES ('test')")
            .unwrap();

        assert!(parsed.is_insert);
        assert!(parsed.is_write());
        assert!(!parsed.is_read());
        assert_eq!(parsed.tables, vec!["users".to_string()]);
    }

    /// Numeric and string literals are both replaced by `?`.
    #[test]
    fn test_normalize() {
        let parser = SqlParser::new();

        let normalized = parser.normalize("SELECT * FROM users WHERE id = 123 AND name = 'test'");
        assert_eq!(
            normalized,
            "SELECT * FROM users WHERE id = ? AND name = ?"
        );
    }

    /// Fingerprints ignore literal values and letter case but not table names.
    #[test]
    fn test_fingerprint() {
        let parser = SqlParser::new();

        let q1 = parser.parse("SELECT * FROM users WHERE id = 1").unwrap();
        let q2 = parser.parse("SELECT * FROM users WHERE id = 2").unwrap();
        let q3 = parser.parse("SELECT * FROM orders WHERE id = 1").unwrap();
        let q4 = parser.parse("select * from users where id = 3").unwrap();

        assert_eq!(q1.fingerprint(), q2.fingerprint());
        assert_eq!(q1.fingerprint(), q4.fingerprint());
        assert_ne!(q1.fingerprint(), q3.fingerprint());
    }

    /// Both sides of a join are reported by table name (not alias), sorted.
    #[test]
    fn test_extract_tables() {
        let parser = SqlParser::new();

        let parsed = parser
            .parse("SELECT u.*, o.total FROM users u JOIN orders o ON u.id = o.user_id")
            .unwrap();

        assert_eq!(
            parsed.tables,
            vec!["orders".to_string(), "users".to_string()]
        );
    }

    #[test]
    fn test_has_select_star() {
        let parser = SqlParser::new();

        assert!(parser.has_select_star("SELECT * FROM users"));
        assert!(parser.has_select_star("select * from users"));
        assert!(parser.has_select_star("SELECT DISTINCT * FROM users"));
        assert!(parser.has_select_star("SELECT ALL * FROM users"));
        assert!(!parser.has_select_star("SELECT id, name FROM users"));
        assert!(!parser.has_select_star("SELECT COUNT(*) FROM users"));
    }

    #[test]
    fn test_empty_query() {
        let parser = SqlParser::new();
        assert!(matches!(parser.parse(""), Err(ParseError::EmptyQuery)));
        assert!(matches!(parser.parse("   "), Err(ParseError::EmptyQuery)));
        assert!(matches!(parser.parse("\n\t"), Err(ParseError::EmptyQuery)));
    }

    #[test]
    fn test_sql_statement_type() {
        let cases = [
            ("SELECT * FROM users", SqlStatement::Select),
            ("select 1", SqlStatement::Select),
            ("INSERT INTO users", SqlStatement::Insert),
            ("UPDATE users SET", SqlStatement::Update),
            ("DELETE FROM users", SqlStatement::Delete),
            ("CREATE TABLE users", SqlStatement::Create),
            ("ALTER TABLE users ADD COLUMN age INT", SqlStatement::Alter),
            ("DROP TABLE users", SqlStatement::Drop),
            ("TRUNCATE users", SqlStatement::Truncate),
            ("EXPLAIN SELECT 1", SqlStatement::Other),
            ("", SqlStatement::Other),
        ];

        for (sql, expected) in cases {
            assert_eq!(SqlStatement::from_sql(sql), expected, "sql: {:?}", sql);
        }
    }

    #[test]
    fn test_sql_statement_is_write() {
        assert!(SqlStatement::Insert.is_write());
        assert!(SqlStatement::Update.is_write());
        assert!(SqlStatement::Delete.is_write());
        assert!(SqlStatement::Create.is_write());
        assert!(SqlStatement::Alter.is_write());
        assert!(SqlStatement::Drop.is_write());
        assert!(SqlStatement::Truncate.is_write());
        assert!(!SqlStatement::Select.is_write());
        assert!(!SqlStatement::Other.is_write());
    }
}