Skip to main content

rippy_cli/
sql.rs

1/// Classify a SQL statement as read-only, write, or ambiguous.
2///
3/// Returns:
4/// - `Some(true)` — read-only (SELECT, SHOW, DESCRIBE, EXPLAIN)
5/// - `Some(false)` — write (INSERT, UPDATE, DELETE, CREATE, DROP, etc.)
6/// - `None` — ambiguous, multiple statements, or unrecognizable
7#[must_use]
8pub fn classify_sql(sql: &str) -> Option<bool> {
9    let cleaned = strip_comments(sql);
10    let trimmed = cleaned.trim();
11
12    if trimmed.is_empty() {
13        return None;
14    }
15
16    // Check for multiple statements (semicolons outside quotes)
17    let statements: Vec<&str> = split_statements(trimmed);
18    if statements.len() > 1 {
19        // All must be read-only for the whole thing to be read-only
20        let results: Vec<Option<bool>> = statements.iter().copied().map(classify_single).collect();
21        if results.iter().any(Option::is_none) {
22            return None;
23        }
24        return Some(results.iter().all(|r| *r == Some(true)));
25    }
26
27    classify_single(trimmed)
28}
29
30/// Classify a single SQL statement.
31fn classify_single(sql: &str) -> Option<bool> {
32    let trimmed = sql.trim();
33    if trimmed.is_empty() {
34        return None;
35    }
36
37    // Strip leading CTE: WITH ... AS (...) — look at the final statement
38    let main_stmt = skip_cte(trimmed);
39    let upper = main_stmt.to_uppercase();
40
41    // Find the first keyword
42    let first_word = upper.split_whitespace().next()?;
43
44    match first_word {
45        "SELECT" => {
46            // SELECT INTO is a write operation
47            if contains_keyword(&upper, "INTO")
48                && !contains_keyword(&upper, "INTO OUTFILE")
49                && is_select_into(&upper)
50            {
51                Some(false)
52            } else {
53                Some(true)
54            }
55        }
56        "SHOW" | "DESCRIBE" | "DESC" | "EXPLAIN" | "PRAGMA" | "TABLE" => Some(true),
57        "WITH" => {
58            // CTE not fully stripped — try to find the main statement
59            None
60        }
61        "INSERT" | "UPDATE" | "DELETE" | "CREATE" | "DROP" | "ALTER" | "TRUNCATE" | "REPLACE"
62        | "MERGE" | "GRANT" | "REVOKE" | "RENAME" | "UPSERT" | "VACUUM" | "REINDEX" | "ANALYZE" => {
63            Some(false)
64        }
65        _ => None,
66    }
67}
68
/// Check if the SELECT statement has an INTO clause that makes it a write.
///
/// `upper_sql` must already be upper-cased. The statement is scanned word by
/// word (split on any whitespace, so `INTO` on its own line is found too) and
/// counts as a write when some `INTO` is followed by a word other than
/// `OUTFILE` or `DUMPFILE`. An `INTO` with nothing after it is ignored.
///
/// This is a heuristic: it does not understand string literals, so text such
/// as `'a INTO b'` inside a literal is also reported as a write.
fn is_select_into(upper_sql: &str) -> bool {
    let words = upper_sql.split_whitespace();
    // Pair every word with its successor and inspect each INTO, not just the
    // first one — an earlier `INTO OUTFILE` must not mask a later `INTO table`.
    words
        .clone()
        .zip(words.skip(1))
        .any(|(word, next)| word == "INTO" && !matches!(next, "OUTFILE" | "DUMPFILE"))
}
79
/// Strip SQL comments (-- line comments and /* */ block comments).
///
/// - Text inside single- or double-quoted literals is copied verbatim, so
///   comment markers inside quotes are preserved.
/// - A line comment is removed up to, but not including, the newline.
/// - A block comment is replaced by a single space so the tokens on either
///   side stay separated (`DELETE/* x */FROM` becomes `DELETE FROM`). An
///   unterminated block comment runs to the end of the input.
///
/// Non-comment text is copied as whole string slices, so multi-byte UTF-8
/// characters pass through unchanged.
fn strip_comments(sql: &str) -> String {
    let bytes = sql.as_bytes();
    let len = bytes.len();
    let mut result = String::with_capacity(len);
    // Start of the run of text that has been scanned but not yet copied.
    let mut copy_from = 0;
    let mut i = 0;
    // The quote byte that opened the literal currently being scanned, if any.
    let mut open_quote: Option<u8> = None;

    while i < len {
        let byte = bytes[i];
        if let Some(quote) = open_quote {
            if byte == quote {
                open_quote = None;
            }
            i += 1;
        } else if byte == b'\'' || byte == b'"' {
            open_quote = Some(byte);
            i += 1;
        } else if byte == b'-' && bytes.get(i + 1) == Some(&b'-') {
            // Line comment: flush pending text, then skip to end of line.
            // Slicing at `i` is safe because `-` is ASCII and therefore
            // always sits on a character boundary.
            result.push_str(&sql[copy_from..i]);
            while i < len && bytes[i] != b'\n' {
                i += 1;
            }
            copy_from = i;
        } else if byte == b'/' && bytes.get(i + 1) == Some(&b'*') {
            // Block comment: flush pending text and leave one space behind.
            result.push_str(&sql[copy_from..i]);
            result.push(' ');
            i += 2;
            while i + 1 < len && !(bytes[i] == b'*' && bytes[i + 1] == b'/') {
                i += 1;
            }
            // Step over `*/`; clamp for the unterminated case.
            i = (i + 2).min(len);
            copy_from = i;
        } else {
            i += 1;
        }
    }

    result.push_str(&sql[copy_from..]);
    result
}
129
/// Break `sql` into individual statements at every semicolon that is not
/// inside a single- or double-quoted literal.
///
/// Each piece is whitespace-trimmed and empty pieces are dropped, so stray or
/// trailing semicolons produce no entries.
fn split_statements(sql: &str) -> Vec<&str> {
    let mut pieces: Vec<&str> = Vec::new();
    let mut segment_start = 0;
    // Quote byte of the literal we are currently inside, if any.
    let mut open_quote: Option<u8> = None;

    for (pos, &byte) in sql.as_bytes().iter().enumerate() {
        match open_quote {
            Some(quote) => {
                if byte == quote {
                    open_quote = None;
                }
            }
            None => match byte {
                b'\'' | b'"' => open_quote = Some(byte),
                b';' => {
                    let piece = sql[segment_start..pos].trim();
                    if !piece.is_empty() {
                        pieces.push(piece);
                    }
                    segment_start = pos + 1;
                }
                _ => {}
            },
        }
    }

    // Whatever follows the final separator is the last statement.
    let tail = sql[segment_start..].trim();
    if !tail.is_empty() {
        pieces.push(tail);
    }
    pieces
}
169
/// Skip a leading CTE (WITH ... AS (...)) and return the main statement.
///
/// If `sql` does not start with the word `WITH` (any case, followed by
/// whitespace), it is returned unchanged. Otherwise the text is scanned for
/// the first closing parenthesis that returns to nesting depth zero and is
/// directly followed by `SELECT`, `INSERT`, `UPDATE` or `DELETE` as a whole
/// word; the text from that keyword onward is returned. Parentheses inside
/// quoted literals are ignored. When no such point exists, `sql` is returned
/// unchanged so the caller still sees the leading `WITH`.
fn skip_cte(sql: &str) -> &str {
    const MAIN_KEYWORDS: [&str; 4] = ["SELECT", "INSERT", "UPDATE", "DELETE"];

    let starts_with_with = sql
        .get(..4)
        .is_some_and(|head| head.eq_ignore_ascii_case("WITH"))
        && sql[4..].starts_with(char::is_whitespace);
    if !starts_with_with {
        return sql;
    }

    // True when `rest` begins with one of the main keywords as a whole word.
    let starts_main_statement = |rest: &str| {
        MAIN_KEYWORDS.iter().any(|keyword| {
            rest.get(..keyword.len())
                .is_some_and(|head| head.eq_ignore_ascii_case(keyword))
                && !rest[keyword.len()..].starts_with(|c: char| c.is_alphanumeric() || c == '_')
        })
    };

    let mut depth = 0i32;
    let mut open_quote: Option<char> = None;
    // `char_indices` yields byte offsets, which is what slicing needs.
    for (pos, ch) in sql.char_indices() {
        if let Some(quote) = open_quote {
            if ch == quote {
                open_quote = None;
            }
            continue;
        }
        match ch {
            '\'' | '"' => open_quote = Some(ch),
            '(' => depth += 1,
            ')' => {
                depth -= 1;
                if depth == 0 {
                    // `)` is one byte long, so `pos + 1` is a char boundary.
                    let rest = sql[pos + 1..].trim();
                    // Stop at the first candidate: parentheses that belong to
                    // the main statement itself must not move the split point.
                    if starts_main_statement(rest) {
                        return rest;
                    }
                }
            }
            _ => {}
        }
    }

    sql
}
206
/// Check if a keyword appears in the SQL as a whole word.
///
/// Matching is exact, so both `upper_sql` and `keyword` are expected to be
/// upper-cased by the caller. An occurrence only counts when it is not
/// directly preceded or followed by an identifier character (alphanumeric or
/// `_`), so `INTO` is not found inside `POINTOF` or `INTO_X`. `keyword` may
/// contain spaces (e.g. `"INTO OUTFILE"`). An empty keyword never matches.
fn contains_keyword(upper_sql: &str, keyword: &str) -> bool {
    if keyword.is_empty() {
        return false;
    }

    let is_ident = |c: char| c.is_alphanumeric() || c == '_';
    upper_sql.match_indices(keyword).any(|(start, hit)| {
        let before = upper_sql[..start].chars().next_back();
        let after = upper_sql[start + hit.len()..].chars().next();
        !before.is_some_and(is_ident) && !after.is_some_and(is_ident)
    })
}
211
#[cfg(test)]
mod tests {
    use super::*;

    /// Run `classify_sql` over every statement in `cases` and require the
    /// same `expected` verdict for each of them.
    fn assert_all(cases: &[&str], expected: Option<bool>) {
        for sql in cases {
            assert_eq!(classify_sql(sql), expected);
        }
    }

    // Plain SELECTs, regardless of keyword case, are reads.
    #[test]
    fn select_is_readonly() {
        assert_all(&["SELECT * FROM users", "select id from orders"], Some(true));
    }

    // Introspection statements are reads.
    #[test]
    fn show_describe_are_readonly() {
        assert_all(
            &["SHOW TABLES", "DESCRIBE users", "EXPLAIN SELECT 1"],
            Some(true),
        );
    }

    // DML and DDL statements are writes.
    #[test]
    fn write_statements() {
        assert_all(
            &[
                "INSERT INTO users VALUES (1)",
                "UPDATE users SET name='x'",
                "DELETE FROM users",
                "CREATE TABLE t (id INT)",
                "DROP TABLE users",
                "ALTER TABLE users ADD col INT",
                "TRUNCATE TABLE users",
            ],
            Some(false),
        );
    }

    // SELECT ... INTO <table> creates a table and therefore writes.
    #[test]
    fn select_into_is_write() {
        let verdict = classify_sql("SELECT * INTO new_table FROM users");
        assert_eq!(verdict, Some(false));
    }

    // A CTE in front of a SELECT keeps the statement read-only.
    #[test]
    fn cte_with_select() {
        let verdict = classify_sql("WITH cte AS (SELECT 1) SELECT * FROM cte");
        assert_eq!(verdict, Some(true));
    }

    // Several reads in one input are still a read.
    #[test]
    fn multi_statement_all_readonly() {
        assert_all(&["SELECT 1; SELECT 2"], Some(true));
    }

    // One write among several statements makes the input a write.
    #[test]
    fn multi_statement_mixed() {
        assert_all(&["SELECT 1; DELETE FROM users"], Some(false));
    }

    // Leading line and block comments do not affect the verdict.
    #[test]
    fn comments_stripped() {
        assert_all(&["-- comment\nSELECT 1", "/* block */ SELECT 1"], Some(true));
    }

    // Empty or whitespace-only input cannot be classified.
    #[test]
    fn empty_is_ambiguous() {
        assert_all(&["", "   "], None);
    }

    // Unknown leading keywords are reported as ambiguous.
    #[test]
    fn exec_is_ambiguous() {
        assert_all(&["EXEC sp_something"], None);
    }
}