Skip to main content

heliosdb_proxy/rewriter/
parser.rs

1//! SQL Parser
2//!
3//! SQL parsing utilities for query rewriting.
4
5use std::collections::hash_map::DefaultHasher;
6use std::hash::{Hash, Hasher};
7
8/// SQL parser
9pub struct SqlParser {
10    /// Dialect-specific settings
11    #[allow(dead_code)]
12    dialect: SqlDialect,
13}
14
/// SQL dialect a parser can be configured for.
///
/// The default dialect is [`SqlDialect::PostgreSQL`]. The enum is a plain
/// fieldless value type, so it is comparable and hashable in addition to
/// being `Copy`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum SqlDialect {
    /// PostgreSQL (the default).
    #[default]
    PostgreSQL,
    /// MySQL.
    MySQL,
    /// SQLite.
    SQLite,
}
23
24impl SqlParser {
25    /// Create a new parser
26    pub fn new() -> Self {
27        Self {
28            dialect: SqlDialect::PostgreSQL,
29        }
30    }
31
32    /// Create a parser with specific dialect
33    pub fn with_dialect(dialect: SqlDialect) -> Self {
34        Self { dialect }
35    }
36
37    /// Parse a SQL query
38    pub fn parse(&self, sql: &str) -> Result<ParsedQuery, ParseError> {
39        let trimmed = sql.trim();
40
41        if trimmed.is_empty() {
42            return Err(ParseError::EmptyQuery);
43        }
44
45        let upper = trimmed.to_uppercase();
46        let first_word = upper.split_whitespace().next().unwrap_or("");
47
48        let is_select = first_word == "SELECT";
49        let is_insert = first_word == "INSERT";
50        let is_update = first_word == "UPDATE";
51        let is_delete = first_word == "DELETE";
52        let is_ddl = matches!(first_word, "CREATE" | "ALTER" | "DROP" | "TRUNCATE");
53
54        let tables = self.extract_tables(trimmed);
55        let has_select_star = is_select && self.has_select_star(trimmed);
56        let has_limit = upper.contains(" LIMIT ");
57        let has_where = upper.contains(" WHERE ");
58
59        let normalized = self.normalize(trimmed);
60
61        Ok(ParsedQuery {
62            original: trimmed.to_string(),
63            normalized,
64            tables,
65            has_select_star,
66            has_limit,
67            has_where,
68            is_select,
69            is_insert,
70            is_update,
71            is_delete,
72            is_ddl,
73        })
74    }
75
76    /// Normalize a query (replace literals with placeholders)
77    pub fn normalize(&self, sql: &str) -> String {
78        let mut result = String::with_capacity(sql.len());
79        let mut chars = sql.chars().peekable();
80
81        while let Some(c) = chars.next() {
82            match c {
83                // String literals
84                '\'' => {
85                    result.push('?');
86                    let mut escaped = false;
87                    for inner in chars.by_ref() {
88                        if inner == '\'' && !escaped {
89                            break;
90                        }
91                        escaped = inner == '\\' && !escaped;
92                    }
93                }
94                // Double-quoted identifiers (keep them)
95                '"' => {
96                    result.push(c);
97                    for inner in chars.by_ref() {
98                        result.push(inner);
99                        if inner == '"' {
100                            break;
101                        }
102                    }
103                }
104                // Numbers
105                '0'..='9' => {
106                    result.push('?');
107                    while chars
108                        .peek()
109                        .map(|c| c.is_ascii_digit() || *c == '.')
110                        .unwrap_or(false)
111                    {
112                        chars.next();
113                    }
114                }
115                // Parameter placeholders
116                '$' => {
117                    result.push('?');
118                    while chars.peek().map(|c| c.is_ascii_digit()).unwrap_or(false) {
119                        chars.next();
120                    }
121                }
122                // Everything else
123                _ => result.push(c),
124            }
125        }
126
127        // Collapse whitespace
128        let mut prev_space = false;
129        result
130            .chars()
131            .filter(|&c| {
132                if c.is_whitespace() {
133                    if prev_space {
134                        return false;
135                    }
136                    prev_space = true;
137                } else {
138                    prev_space = false;
139                }
140                true
141            })
142            .collect::<String>()
143            .trim()
144            .to_string()
145    }
146
147    /// Extract table names from query
148    fn extract_tables(&self, sql: &str) -> Vec<String> {
149        let mut tables = Vec::new();
150        let upper = sql.to_uppercase();
151        let words: Vec<&str> = sql.split_whitespace().collect();
152        let upper_words: Vec<&str> = upper.split_whitespace().collect();
153
154        // Look for FROM, JOIN, INTO, UPDATE table names
155        let table_keywords = ["FROM", "JOIN", "INTO", "UPDATE"];
156
157        for (i, word) in upper_words.iter().enumerate() {
158            if table_keywords.contains(&word.trim_end_matches(',')) {
159                if let Some(table) = words.get(i + 1) {
160                    let table =
161                        table.trim_matches(|c| c == ',' || c == '(' || c == ')' || c == ';');
162                    if !table.is_empty() && !is_keyword(table) {
163                        // Handle schema.table format
164                        let table_name = table.split('.').next_back().unwrap_or(table);
165                        tables.push(table_name.to_string());
166                    }
167                }
168            }
169        }
170
171        // Deduplicate
172        tables.sort();
173        tables.dedup();
174        tables
175    }
176
177    /// Check if query has SELECT *
178    fn has_select_star(&self, sql: &str) -> bool {
179        let upper = sql.to_uppercase();
180
181        // Check for SELECT * (with potential whitespace variations)
182        if let Some(select_pos) = upper.find("SELECT") {
183            let after_select = &upper[select_pos + 6..];
184            let trimmed = after_select.trim_start();
185
186            // Check for SELECT * or SELECT DISTINCT *
187            if trimmed.starts_with("*") {
188                return true;
189            }
190            if let Some(after_distinct) = trimmed.strip_prefix("DISTINCT") {
191                if after_distinct.trim_start().starts_with('*') {
192                    return true;
193                }
194            }
195            if let Some(after_all) = trimmed.strip_prefix("ALL") {
196                if after_all.trim_start().starts_with('*') {
197                    return true;
198                }
199            }
200        }
201
202        false
203    }
204
205    /// Convert AST back to SQL
206    pub fn to_sql(&self, parsed: &ParsedQuery) -> String {
207        // For now, return the normalized version
208        // In production, use sqlparser-rs for full AST manipulation
209        parsed.original.clone()
210    }
211}
212
213impl Default for SqlParser {
214    fn default() -> Self {
215        Self::new()
216    }
217}
218
/// Summary of a parsed query: its text plus the flags derived from it.
#[derive(Debug, Clone)]
pub struct ParsedQuery {
    /// Query text as it was parsed.
    pub original: String,

    /// Query text with literals replaced by placeholders.
    pub normalized: String,

    /// Names of the tables the query refers to.
    pub tables: Vec<String>,

    /// Whether the query is a `SELECT *`.
    pub has_select_star: bool,

    /// Whether the query has a `LIMIT` clause.
    pub has_limit: bool,

    /// Whether the query has a `WHERE` clause.
    pub has_where: bool,

    /// Whether the statement is a `SELECT`.
    pub is_select: bool,

    /// Whether the statement is an `INSERT`.
    pub is_insert: bool,

    /// Whether the statement is an `UPDATE`.
    pub is_update: bool,

    /// Whether the statement is a `DELETE`.
    pub is_delete: bool,

    /// Whether the statement is DDL.
    pub is_ddl: bool,
}

impl ParsedQuery {
    /// Hash of the upper-cased normalized text.
    ///
    /// Two queries whose normalized forms differ only in letter case share a
    /// fingerprint.
    pub fn fingerprint(&self) -> u64 {
        let canonical = self.normalized.to_uppercase();
        let mut state = DefaultHasher::new();
        canonical.hash(&mut state);
        state.finish()
    }

    /// True for INSERT, UPDATE, DELETE and DDL statements.
    pub fn is_write(&self) -> bool {
        [self.is_insert, self.is_update, self.is_delete, self.is_ddl]
            .iter()
            .any(|&flag| flag)
    }

    /// True for a SELECT that is not also flagged as DDL.
    pub fn is_read(&self) -> bool {
        if self.is_ddl {
            return false;
        }
        self.is_select
    }
}
274
/// Kind of SQL statement, decided by the statement's leading keyword.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlStatement {
    Select,
    Insert,
    Update,
    Delete,
    Create,
    Alter,
    Drop,
    Truncate,
    Other,
}

impl SqlStatement {
    /// Classify `sql` by its first whitespace-delimited word.
    ///
    /// The comparison ignores case. Input that is empty or starts with an
    /// unrecognized word maps to [`SqlStatement::Other`].
    pub fn from_sql(sql: &str) -> Self {
        let keyword = sql
            .split_whitespace()
            .next()
            .map(str::to_uppercase)
            .unwrap_or_default();

        match keyword.as_str() {
            "SELECT" => SqlStatement::Select,
            "INSERT" => SqlStatement::Insert,
            "UPDATE" => SqlStatement::Update,
            "DELETE" => SqlStatement::Delete,
            "CREATE" => SqlStatement::Create,
            "ALTER" => SqlStatement::Alter,
            "DROP" => SqlStatement::Drop,
            "TRUNCATE" => SqlStatement::Truncate,
            _ => SqlStatement::Other,
        }
    }

    /// True for every statement kind except `Select` and `Other`.
    pub fn is_write(&self) -> bool {
        match self {
            SqlStatement::Select | SqlStatement::Other => false,
            SqlStatement::Insert
            | SqlStatement::Update
            | SqlStatement::Delete
            | SqlStatement::Create
            | SqlStatement::Alter
            | SqlStatement::Drop
            | SqlStatement::Truncate => true,
        }
    }
}
320
/// Errors reported while parsing a query.
#[derive(Debug, Clone)]
pub enum ParseError {
    /// The query text was empty.
    EmptyQuery,

    /// The query text is not valid SQL; carries a description.
    InvalidSyntax(String),

    /// The statement is not supported; carries the statement.
    UnsupportedStatement(String),
}

impl std::fmt::Display for ParseError {
    /// Writes a short human-readable message for the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ParseError::EmptyQuery => f.write_str("Empty query"),
            ParseError::InvalidSyntax(detail) => write!(f, "Invalid syntax: {}", detail),
            ParseError::UnsupportedStatement(statement) => {
                write!(f, "Unsupported statement: {}", statement)
            }
        }
    }
}

impl std::error::Error for ParseError {}
345
346impl From<ParseError> for super::RewriteError {
347    fn from(e: ParseError) -> Self {
348        super::RewriteError::ParseError(e.to_string())
349    }
350}
351
/// Check if a word is a SQL keyword.
///
/// The comparison is ASCII case-insensitive and does not allocate. Besides
/// the core statement and clause keywords, the list covers the modifiers that
/// can directly follow `FROM`, `JOIN` or `UPDATE` (`ONLY`, `LATERAL`, `OF`,
/// `NOWAIT`, `SKIP`), so that for example `FOR UPDATE SKIP LOCKED` is not
/// mistaken for a reference to a table named `SKIP`.
fn is_keyword(word: &str) -> bool {
    const KEYWORDS: &[&str] = &[
        "SELECT",
        "FROM",
        "WHERE",
        "AND",
        "OR",
        "NOT",
        "INSERT",
        "INTO",
        "VALUES",
        "UPDATE",
        "SET",
        "DELETE",
        "CREATE",
        "ALTER",
        "DROP",
        "TABLE",
        "INDEX",
        "VIEW",
        "JOIN",
        "LEFT",
        "RIGHT",
        "INNER",
        "OUTER",
        "CROSS",
        "FULL",
        "NATURAL",
        "LATERAL",
        "ONLY",
        "USING",
        "ON",
        "GROUP",
        "BY",
        "ORDER",
        "HAVING",
        "LIMIT",
        "OFFSET",
        "UNION",
        "INTERSECT",
        "EXCEPT",
        "AS",
        "DISTINCT",
        "ALL",
        "NULL",
        "TRUE",
        "FALSE",
        "CASE",
        "WHEN",
        "THEN",
        "ELSE",
        "END",
        "EXISTS",
        "IN",
        "BETWEEN",
        "LIKE",
        "IS",
        "ASC",
        "DESC",
        // Row-locking clause words: `FOR UPDATE OF t`, `NOWAIT`, `SKIP LOCKED`
        "OF",
        "NOWAIT",
        "SKIP",
    ];

    KEYWORDS
        .iter()
        .any(|keyword| word.eq_ignore_ascii_case(keyword))
}
411
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `sql` with a default parser; the test fails if parsing fails.
    fn parse_ok(sql: &str) -> ParsedQuery {
        SqlParser::new().parse(sql).unwrap()
    }

    /// Whether `parsed` lists `table` among its tables.
    fn mentions(parsed: &ParsedQuery, table: &str) -> bool {
        parsed.tables.iter().any(|name| name == table)
    }

    // A plain SELECT * with a WHERE clause sets the matching flags.
    #[test]
    fn test_parse_select() {
        let parsed = parse_ok("SELECT * FROM users WHERE id = 1");

        assert!(parsed.is_select);
        assert!(parsed.has_select_star);
        assert!(parsed.has_where);
        assert!(!parsed.has_limit);
        assert!(mentions(&parsed, "users"));
    }

    // INSERT is classified and its target table is reported.
    #[test]
    fn test_parse_insert() {
        let parsed = parse_ok("INSERT INTO users (name) VALUES ('test')");

        assert!(parsed.is_insert);
        assert!(mentions(&parsed, "users"));
    }

    // Numeric and string literals are replaced by placeholders.
    #[test]
    fn test_normalize() {
        let normalized =
            SqlParser::new().normalize("SELECT * FROM users WHERE id = 123 AND name = 'test'");

        for fragment in ["id = ?", "name = ?"] {
            assert!(normalized.contains(fragment));
        }
    }

    // Fingerprints ignore literal values but not the query structure.
    #[test]
    fn test_fingerprint() {
        let users_1 = parse_ok("SELECT * FROM users WHERE id = 1").fingerprint();
        let users_2 = parse_ok("SELECT * FROM users WHERE id = 2").fingerprint();
        let orders_1 = parse_ok("SELECT * FROM orders WHERE id = 1").fingerprint();

        // Same query structure should have same fingerprint
        assert_eq!(users_1, users_2);
        // Different query structure should have different fingerprint
        assert_ne!(users_1, orders_1);
    }

    // A joined query reports the table (or its alias) after FROM.
    #[test]
    fn test_extract_tables() {
        let parsed =
            parse_ok("SELECT u.*, o.total FROM users u JOIN orders o ON u.id = o.user_id");

        assert!(mentions(&parsed, "u") || mentions(&parsed, "users"));
    }

    // SELECT * and SELECT DISTINCT * count; an explicit column list does not.
    #[test]
    fn test_has_select_star() {
        let parser = SqlParser::new();

        for sql in ["SELECT * FROM users", "SELECT DISTINCT * FROM users"] {
            assert!(parser.has_select_star(sql));
        }
        assert!(!parser.has_select_star("SELECT id, name FROM users"));
    }

    // Empty and whitespace-only input is rejected.
    #[test]
    fn test_empty_query() {
        let parser = SqlParser::new();

        for sql in ["", "   "] {
            assert!(matches!(parser.parse(sql), Err(ParseError::EmptyQuery)));
        }
    }

    // The leading keyword decides the statement type.
    #[test]
    fn test_sql_statement_type() {
        let cases = [
            ("SELECT * FROM users", SqlStatement::Select),
            ("INSERT INTO users", SqlStatement::Insert),
            ("UPDATE users SET", SqlStatement::Update),
            ("DELETE FROM users", SqlStatement::Delete),
            ("CREATE TABLE users", SqlStatement::Create),
        ];

        for (sql, expected) in cases {
            assert_eq!(SqlStatement::from_sql(sql), expected);
        }
    }
}
511            SqlStatement::from_sql("CREATE TABLE users"),
512            SqlStatement::Create
513        );
514    }
515}