// heliosdb_proxy/rewriter/parser.rs

//! SQL Parser
//!
//! SQL parsing utilities for query rewriting.

use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

8/// SQL parser
9pub struct SqlParser {
10    /// Dialect-specific settings
11    dialect: SqlDialect,
12}
13
/// SQL dialect a parser can be configured for.
///
/// [`SqlDialect::PostgreSQL`] is the default variant.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum SqlDialect {
    /// PostgreSQL (the default).
    #[default]
    PostgreSQL,
    /// MySQL.
    MySQL,
    /// SQLite.
    SQLite,
}

23impl SqlParser {
24    /// Create a new parser
25    pub fn new() -> Self {
26        Self {
27            dialect: SqlDialect::PostgreSQL,
28        }
29    }
30
31    /// Create a parser with specific dialect
32    pub fn with_dialect(dialect: SqlDialect) -> Self {
33        Self { dialect }
34    }
35
36    /// Parse a SQL query
37    pub fn parse(&self, sql: &str) -> Result<ParsedQuery, ParseError> {
38        let trimmed = sql.trim();
39
40        if trimmed.is_empty() {
41            return Err(ParseError::EmptyQuery);
42        }
43
44        let upper = trimmed.to_uppercase();
45        let first_word = upper.split_whitespace().next().unwrap_or("");
46
47        let is_select = first_word == "SELECT";
48        let is_insert = first_word == "INSERT";
49        let is_update = first_word == "UPDATE";
50        let is_delete = first_word == "DELETE";
51        let is_ddl = matches!(first_word, "CREATE" | "ALTER" | "DROP" | "TRUNCATE");
52
53        let tables = self.extract_tables(trimmed);
54        let has_select_star = is_select && self.has_select_star(trimmed);
55        let has_limit = upper.contains(" LIMIT ");
56        let has_where = upper.contains(" WHERE ");
57
58        let normalized = self.normalize(trimmed);
59
60        Ok(ParsedQuery {
61            original: trimmed.to_string(),
62            normalized,
63            tables,
64            has_select_star,
65            has_limit,
66            has_where,
67            is_select,
68            is_insert,
69            is_update,
70            is_delete,
71            is_ddl,
72        })
73    }
74
75    /// Normalize a query (replace literals with placeholders)
76    pub fn normalize(&self, sql: &str) -> String {
77        let mut result = String::with_capacity(sql.len());
78        let mut chars = sql.chars().peekable();
79
80        while let Some(c) = chars.next() {
81            match c {
82                // String literals
83                '\'' => {
84                    result.push('?');
85                    let mut escaped = false;
86                    for inner in chars.by_ref() {
87                        if inner == '\'' && !escaped {
88                            break;
89                        }
90                        escaped = inner == '\\' && !escaped;
91                    }
92                }
93                // Double-quoted identifiers (keep them)
94                '"' => {
95                    result.push(c);
96                    for inner in chars.by_ref() {
97                        result.push(inner);
98                        if inner == '"' {
99                            break;
100                        }
101                    }
102                }
103                // Numbers
104                '0'..='9' => {
105                    result.push('?');
106                    while chars.peek().map(|c| c.is_ascii_digit() || *c == '.').unwrap_or(false) {
107                        chars.next();
108                    }
109                }
110                // Parameter placeholders
111                '$' => {
112                    result.push('?');
113                    while chars.peek().map(|c| c.is_ascii_digit()).unwrap_or(false) {
114                        chars.next();
115                    }
116                }
117                // Everything else
118                _ => result.push(c),
119            }
120        }
121
122        // Collapse whitespace
123        let mut prev_space = false;
124        result.chars().filter(|&c| {
125            if c.is_whitespace() {
126                if prev_space {
127                    return false;
128                }
129                prev_space = true;
130            } else {
131                prev_space = false;
132            }
133            true
134        }).collect::<String>().trim().to_string()
135    }
136
137    /// Extract table names from query
138    fn extract_tables(&self, sql: &str) -> Vec<String> {
139        let mut tables = Vec::new();
140        let upper = sql.to_uppercase();
141        let words: Vec<&str> = sql.split_whitespace().collect();
142        let upper_words: Vec<&str> = upper.split_whitespace().collect();
143
144        // Look for FROM, JOIN, INTO, UPDATE table names
145        let table_keywords = ["FROM", "JOIN", "INTO", "UPDATE"];
146
147        for (i, word) in upper_words.iter().enumerate() {
148            if table_keywords.contains(&word.trim_end_matches(',')) {
149                if let Some(table) = words.get(i + 1) {
150                    let table = table.trim_matches(|c| c == ',' || c == '(' || c == ')' || c == ';');
151                    if !table.is_empty() && !is_keyword(table) {
152                        // Handle schema.table format
153                        let table_name = table.split('.').last().unwrap_or(table);
154                        tables.push(table_name.to_string());
155                    }
156                }
157            }
158        }
159
160        // Deduplicate
161        tables.sort();
162        tables.dedup();
163        tables
164    }
165
166    /// Check if query has SELECT *
167    fn has_select_star(&self, sql: &str) -> bool {
168        let upper = sql.to_uppercase();
169
170        // Check for SELECT * (with potential whitespace variations)
171        if let Some(select_pos) = upper.find("SELECT") {
172            let after_select = &upper[select_pos + 6..];
173            let trimmed = after_select.trim_start();
174
175            // Check for SELECT * or SELECT DISTINCT *
176            if trimmed.starts_with("*") {
177                return true;
178            }
179            if trimmed.starts_with("DISTINCT") {
180                let after_distinct = trimmed[8..].trim_start();
181                if after_distinct.starts_with("*") {
182                    return true;
183                }
184            }
185            if trimmed.starts_with("ALL") {
186                let after_all = trimmed[3..].trim_start();
187                if after_all.starts_with("*") {
188                    return true;
189                }
190            }
191        }
192
193        false
194    }
195
196    /// Convert AST back to SQL
197    pub fn to_sql(&self, parsed: &ParsedQuery) -> String {
198        // For now, return the normalized version
199        // In production, use sqlparser-rs for full AST manipulation
200        parsed.original.clone()
201    }
202}
203
204impl Default for SqlParser {
205    fn default() -> Self {
206        Self::new()
207    }
208}
209
/// Parsed query representation.
///
/// A flat summary of one SQL statement: its text, a literal-free normalized
/// form, the tables it names and a set of boolean classification flags.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParsedQuery {
    /// Original query text.
    pub original: String,

    /// Normalized query (literals replaced by placeholders).
    pub normalized: String,

    /// Names of the tables referenced by the query.
    pub tables: Vec<String>,

    /// Whether the query is a `SELECT *`.
    pub has_select_star: bool,

    /// Whether the query has a `LIMIT` clause.
    pub has_limit: bool,

    /// Whether the query has a `WHERE` clause.
    pub has_where: bool,

    /// Whether this is a `SELECT` statement.
    pub is_select: bool,

    /// Whether this is an `INSERT` statement.
    pub is_insert: bool,

    /// Whether this is an `UPDATE` statement.
    pub is_update: bool,

    /// Whether this is a `DELETE` statement.
    pub is_delete: bool,

    /// Whether this is a DDL statement.
    pub is_ddl: bool,
}

247impl ParsedQuery {
248    /// Calculate query fingerprint
249    pub fn fingerprint(&self) -> u64 {
250        let mut hasher = DefaultHasher::new();
251        self.normalized.to_uppercase().hash(&mut hasher);
252        hasher.finish()
253    }
254
255    /// Check if query modifies data
256    pub fn is_write(&self) -> bool {
257        self.is_insert || self.is_update || self.is_delete || self.is_ddl
258    }
259
260    /// Check if query is read-only
261    pub fn is_read(&self) -> bool {
262        self.is_select && !self.is_ddl
263    }
264}
265
/// SQL statement type, identified by the statement's leading keyword.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SqlStatement {
    /// `SELECT` statement.
    Select,
    /// `INSERT` statement.
    Insert,
    /// `UPDATE` statement.
    Update,
    /// `DELETE` statement.
    Delete,
    /// `CREATE` statement.
    Create,
    /// `ALTER` statement.
    Alter,
    /// `DROP` statement.
    Drop,
    /// `TRUNCATE` statement.
    Truncate,
    /// Any statement not covered by the other variants.
    Other,
}

280impl SqlStatement {
281    /// Parse from SQL string
282    pub fn from_sql(sql: &str) -> Self {
283        let first_word = sql.trim().split_whitespace().next().unwrap_or("");
284        match first_word.to_uppercase().as_str() {
285            "SELECT" => Self::Select,
286            "INSERT" => Self::Insert,
287            "UPDATE" => Self::Update,
288            "DELETE" => Self::Delete,
289            "CREATE" => Self::Create,
290            "ALTER" => Self::Alter,
291            "DROP" => Self::Drop,
292            "TRUNCATE" => Self::Truncate,
293            _ => Self::Other,
294        }
295    }
296
297    /// Check if statement modifies data
298    pub fn is_write(&self) -> bool {
299        matches!(self, Self::Insert | Self::Update | Self::Delete | Self::Create | Self::Alter | Self::Drop | Self::Truncate)
300    }
301}
302
/// Parse error.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParseError {
    /// The query was empty.
    EmptyQuery,

    /// Invalid syntax, with a description of the problem.
    InvalidSyntax(String),

    /// Unsupported statement, carrying the offending statement text.
    UnsupportedStatement(String),
}

316impl std::fmt::Display for ParseError {
317    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
318        match self {
319            Self::EmptyQuery => write!(f, "Empty query"),
320            Self::InvalidSyntax(msg) => write!(f, "Invalid syntax: {}", msg),
321            Self::UnsupportedStatement(stmt) => write!(f, "Unsupported statement: {}", stmt),
322        }
323    }
324}
325
326impl std::error::Error for ParseError {}
327
328impl From<ParseError> for super::RewriteError {
329    fn from(e: ParseError) -> Self {
330        super::RewriteError::ParseError(e.to_string())
331    }
332}
333
/// Checks whether `word` is one of the SQL keywords this module knows about.
///
/// The comparison ignores ASCII case and does not allocate. An empty string is
/// not a keyword.
fn is_keyword(word: &str) -> bool {
    const KEYWORDS: &[&str] = &[
        "SELECT", "FROM", "WHERE", "AND", "OR", "NOT",
        "INSERT", "INTO", "VALUES", "UPDATE", "SET", "DELETE",
        "CREATE", "ALTER", "DROP", "TABLE", "INDEX", "VIEW",
        "JOIN", "LEFT", "RIGHT", "INNER", "OUTER", "CROSS", "ON",
        "GROUP", "BY", "ORDER", "HAVING", "LIMIT", "OFFSET",
        "UNION", "INTERSECT", "EXCEPT", "AS", "DISTINCT", "ALL",
        "NULL", "TRUE", "FALSE", "CASE", "WHEN", "THEN", "ELSE", "END",
        "EXISTS", "IN", "BETWEEN", "LIKE", "IS", "ASC", "DESC",
    ];

    KEYWORDS
        .iter()
        .any(|keyword| word.eq_ignore_ascii_case(keyword))
}

/// Unit tests for the parser, statement classification and error types.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_select() {
        let parser = SqlParser::new();
        let parsed = parser.parse("SELECT * FROM users WHERE id = 1").unwrap();

        assert!(parsed.is_select);
        assert!(parsed.has_select_star);
        assert!(parsed.has_where);
        assert!(!parsed.has_limit);
        assert!(parsed.tables.contains(&"users".to_string()));
    }

    #[test]
    fn test_parse_insert() {
        let parser = SqlParser::new();
        let parsed = parser.parse("INSERT INTO users (name) VALUES ('test')").unwrap();

        assert!(parsed.is_insert);
        assert!(parsed.tables.contains(&"users".to_string()));
    }

    #[test]
    fn test_normalize() {
        let parser = SqlParser::new();

        let normalized = parser.normalize("SELECT * FROM users WHERE id = 123 AND name = 'test'");
        assert!(normalized.contains("id = ?"));
        assert!(normalized.contains("name = ?"));
    }

    #[test]
    fn test_fingerprint() {
        let parser = SqlParser::new();

        let q1 = parser.parse("SELECT * FROM users WHERE id = 1").unwrap();
        let q2 = parser.parse("SELECT * FROM users WHERE id = 2").unwrap();
        let q3 = parser.parse("SELECT * FROM orders WHERE id = 1").unwrap();

        // Same query structure should have same fingerprint
        assert_eq!(q1.fingerprint(), q2.fingerprint());
        // Different query structure should have different fingerprint
        assert_ne!(q1.fingerprint(), q3.fingerprint());
    }

    #[test]
    fn test_extract_tables() {
        let parser = SqlParser::new();

        let parsed = parser.parse(
            "SELECT u.*, o.total FROM users u JOIN orders o ON u.id = o.user_id"
        ).unwrap();

        // Both joined tables are reported, sorted, and the aliases are not.
        assert_eq!(parsed.tables, vec!["orders".to_string(), "users".to_string()]);
    }

    #[test]
    fn test_has_select_star() {
        let parser = SqlParser::new();

        assert!(parser.has_select_star("SELECT * FROM users"));
        assert!(parser.has_select_star("SELECT DISTINCT * FROM users"));
        assert!(!parser.has_select_star("SELECT id, name FROM users"));
    }

    #[test]
    fn test_empty_query() {
        let parser = SqlParser::new();
        assert!(matches!(parser.parse(""), Err(ParseError::EmptyQuery)));
        assert!(matches!(parser.parse("   "), Err(ParseError::EmptyQuery)));
    }

    #[test]
    fn test_sql_statement_type() {
        assert_eq!(SqlStatement::from_sql("SELECT * FROM users"), SqlStatement::Select);
        assert_eq!(SqlStatement::from_sql("INSERT INTO users"), SqlStatement::Insert);
        assert_eq!(SqlStatement::from_sql("UPDATE users SET"), SqlStatement::Update);
        assert_eq!(SqlStatement::from_sql("DELETE FROM users"), SqlStatement::Delete);
        assert_eq!(SqlStatement::from_sql("CREATE TABLE users"), SqlStatement::Create);
    }
}