1#[must_use]
8pub fn classify_sql(sql: &str) -> Option<bool> {
9 let cleaned = strip_comments(sql);
10 let trimmed = cleaned.trim();
11
12 if trimmed.is_empty() {
13 return None;
14 }
15
16 let statements: Vec<&str> = split_statements(trimmed);
18 if statements.len() > 1 {
19 let results: Vec<Option<bool>> = statements.iter().copied().map(classify_single).collect();
21 if results.iter().any(Option::is_none) {
22 return None;
23 }
24 return Some(results.iter().all(|r| *r == Some(true)));
25 }
26
27 classify_single(trimmed)
28}
29
30fn classify_single(sql: &str) -> Option<bool> {
32 let trimmed = sql.trim();
33 if trimmed.is_empty() {
34 return None;
35 }
36
37 let main_stmt = skip_cte(trimmed);
39 let upper = main_stmt.to_uppercase();
40
41 let first_word = upper.split_whitespace().next()?;
43
44 match first_word {
45 "SELECT" => {
46 if contains_keyword(&upper, "INTO")
48 && !contains_keyword(&upper, "INTO OUTFILE")
49 && is_select_into(&upper)
50 {
51 Some(false)
52 } else {
53 Some(true)
54 }
55 }
56 "SHOW" | "DESCRIBE" | "DESC" | "EXPLAIN" | "PRAGMA" | "TABLE" => Some(true),
57 "WITH" => {
58 None
60 }
61 "INSERT" | "UPDATE" | "DELETE" | "CREATE" | "DROP" | "ALTER" | "TRUNCATE" | "REPLACE"
62 | "MERGE" | "GRANT" | "REVOKE" | "RENAME" | "UPSERT" | "VACUUM" | "REINDEX" | "ANALYZE" => {
63 Some(false)
64 }
65 _ => None,
66 }
67}
68
/// Reports whether an upper-cased `SELECT` statement contains an `INTO` target that is not a
/// file export.
///
/// The text is split into identifier-like tokens (runs of alphanumerics and `_`). This finds
/// `INTO` regardless of the surrounding whitespace or punctuation (`INTO\nt`, `INTO\tt`,
/// `INTO(t)`), but never inside an identifier such as `POINTOFSALE`. Returns `true` if any
/// `INTO` token is not immediately followed by `OUTFILE` or `DUMPFILE`.
///
/// Quoted literals are deliberately not skipped. An `INTO` inside a string therefore errs
/// towards "write", which is the conservative direction for a read-only check.
fn is_select_into(upper_sql: &str) -> bool {
    let tokens: Vec<&str> = upper_sql
        .split(|c: char| !(c.is_alphanumeric() || c == '_'))
        .filter(|tok| !tok.is_empty())
        .collect();

    tokens.iter().enumerate().any(|(i, &tok)| {
        tok == "INTO" && !matches!(tokens.get(i + 1).copied(), Some("OUTFILE" | "DUMPFILE"))
    })
}
79
/// Removes `--` line comments and `/* ... */` block comments from `sql`.
///
/// Text inside single- or double-quoted literals is copied verbatim, so comment markers inside
/// quotes are kept. The input is processed per `char`, so non-ASCII text survives unchanged.
///
/// Comment handling:
/// - A line comment is dropped up to, but not including, its newline.
/// - Each block comment is replaced by one space, as SQL treats a comment as whitespace. This
///   keeps `INTO/**/t` as two tokens.
/// - An unterminated block comment swallows the rest of the input.
///
/// NOTE(review): nested block comments (PostgreSQL) and backslash-escaped quotes (MySQL) are
/// not recognised.
fn strip_comments(sql: &str) -> String {
    let mut out = String::with_capacity(sql.len());
    let mut chars = sql.chars().peekable();
    let mut open_quote: Option<char> = None;

    while let Some(ch) = chars.next() {
        if let Some(quote) = open_quote {
            out.push(ch);
            if ch == quote {
                open_quote = None;
            }
            continue;
        }

        match ch {
            '\'' | '"' => {
                open_quote = Some(ch);
                out.push(ch);
            }
            '-' if chars.peek() == Some(&'-') => {
                // Skip to (not past) the newline; the newline itself is copied on the next
                // iteration so the following tokens stay separated.
                while chars.next_if(|&c| c != '\n').is_some() {}
            }
            '/' if chars.peek() == Some(&'*') => {
                chars.next(); // the opening '*' must not pair with a following '/'
                let mut prev = '\0';
                for c in chars.by_ref() {
                    if prev == '*' && c == '/' {
                        break;
                    }
                    prev = c;
                }
                out.push(' ');
            }
            _ => out.push(ch),
        }
    }
    out
}
129
/// Splits `sql` on every `;` that is not inside a single- or double-quoted literal.
///
/// Each piece is trimmed, and pieces that are empty after trimming are discarded. The returned
/// slices borrow from `sql`.
fn split_statements(sql: &str) -> Vec<&str> {
    let mut pieces: Vec<&str> = Vec::new();
    let mut segment_start = 0usize;
    let mut open_quote: Option<u8> = None;

    for (pos, &byte) in sql.as_bytes().iter().enumerate() {
        match open_quote {
            // Inside a literal only the matching quote is significant.
            Some(quote) => {
                if byte == quote {
                    open_quote = None;
                }
            }
            None => match byte {
                b'\'' | b'"' => open_quote = Some(byte),
                b';' => {
                    let piece = sql[segment_start..pos].trim();
                    if !piece.is_empty() {
                        pieces.push(piece);
                    }
                    segment_start = pos + 1;
                }
                _ => {}
            },
        }
    }

    let tail = sql[segment_start..].trim();
    if !tail.is_empty() {
        pieces.push(tail);
    }
    pieces
}
169
/// Returns the main statement that follows a leading `WITH` (CTE) clause.
///
/// If the first word of `sql` is not `WITH`, `sql` is returned unchanged. Otherwise the text is
/// scanned for top-level parenthesised groups (CTE bodies and column lists), ignoring
/// parentheses inside quoted literals. The first top-level `)` whose following text (after an
/// optional stray `,`) starts with `SELECT`, `INSERT`, `UPDATE` or `DELETE` marks the main
/// statement, which is returned trimmed.
///
/// `sql` is also returned unchanged, so the caller still sees `WITH`, in two cases:
/// - no main statement is found;
/// - a CTE body starts with `INSERT`/`UPDATE`/`DELETE`/`MERGE`. Such a data-modifying CTE
///   (PostgreSQL) writes even when the main statement is a `SELECT`.
fn skip_cte(sql: &str) -> &str {
    /// Leading run of ASCII alphanumerics/underscores of `s`, after leading whitespace.
    fn leading_word(s: &str) -> &str {
        let s = s.trim_start();
        let end = s
            .find(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))
            .unwrap_or(s.len());
        &s[..end]
    }

    /// Case-insensitive membership test of `word` in `keywords`.
    fn is_one_of(word: &str, keywords: &[&str]) -> bool {
        keywords.iter().any(|k| word.eq_ignore_ascii_case(k))
    }

    const MAIN_STATEMENTS: &[&str] = &["SELECT", "INSERT", "UPDATE", "DELETE"];
    const MODIFYING_BODIES: &[&str] = &["INSERT", "UPDATE", "DELETE", "MERGE"];

    let starts_with_with = sql
        .split_whitespace()
        .next()
        .is_some_and(|w| w.eq_ignore_ascii_case("WITH"));
    if !starts_with_with {
        return sql;
    }

    let mut depth = 0usize;
    let mut open_quote: Option<char> = None;
    let mut group_start = 0usize;

    // `char_indices` yields byte offsets, and `(`/`)` are one byte, so `pos + 1` is always a
    // valid slice boundary even with non-ASCII text.
    for (pos, ch) in sql.char_indices() {
        if let Some(quote) = open_quote {
            if ch == quote {
                open_quote = None;
            }
            continue;
        }

        match ch {
            '\'' | '"' => open_quote = Some(ch),
            '(' => {
                if depth == 0 {
                    group_start = pos + 1;
                }
                depth += 1;
            }
            ')' if depth > 0 => {
                depth -= 1;
                if depth != 0 {
                    continue;
                }
                if is_one_of(leading_word(&sql[group_start..pos]), MODIFYING_BODIES) {
                    return sql;
                }
                let after = sql[pos + 1..].trim();
                let after = after.strip_prefix(',').map_or(after, str::trim);
                if is_one_of(leading_word(after), MAIN_STATEMENTS) {
                    return after;
                }
            }
            _ => {}
        }
    }

    sql
}
206
/// Reports whether `keyword` occurs in `upper_sql` as a whole word (or whole phrase).
///
/// A match counts only if the characters immediately before and after it are not identifier
/// characters (alphanumerics or `_`). So `INTO` matches in `SELECT * INTO t` but not inside
/// `POINTOFSALE` or `INTO_COUNT`.
///
/// Matching is case-sensitive: `upper_sql` is expected to be upper-cased already and `keyword`
/// to be upper-case. An empty `keyword` always matches.
fn contains_keyword(upper_sql: &str, keyword: &str) -> bool {
    if keyword.is_empty() {
        return true;
    }
    let is_ident = |c: char| c.is_alphanumeric() || c == '_';

    upper_sql.match_indices(keyword).any(|(start, matched)| {
        let clear_before = !upper_sql[..start].chars().next_back().is_some_and(is_ident);
        let clear_after = !upper_sql[start + matched.len()..]
            .chars()
            .next()
            .is_some_and(is_ident);
        clear_before && clear_after
    })
}
211
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks every `(sql, expected)` pair, naming the offending SQL on failure.
    fn expect_all(cases: &[(&str, Option<bool>)]) {
        for (sql, expected) in cases.iter().copied() {
            assert_eq!(classify_sql(sql), expected, "classify_sql({sql:?})");
        }
    }

    #[test]
    fn select_is_readonly() {
        expect_all(&[
            ("SELECT * FROM users", Some(true)),
            ("select id from orders", Some(true)),
        ]);
    }

    #[test]
    fn show_describe_are_readonly() {
        expect_all(&[
            ("SHOW TABLES", Some(true)),
            ("DESCRIBE users", Some(true)),
            ("EXPLAIN SELECT 1", Some(true)),
        ]);
    }

    #[test]
    fn write_statements() {
        let writes = [
            "INSERT INTO users VALUES (1)",
            "UPDATE users SET name='x'",
            "DELETE FROM users",
            "CREATE TABLE t (id INT)",
            "DROP TABLE users",
            "ALTER TABLE users ADD col INT",
            "TRUNCATE TABLE users",
        ];
        for sql in writes {
            assert_eq!(classify_sql(sql), Some(false), "classify_sql({sql:?})");
        }
    }

    #[test]
    fn select_into_is_write() {
        expect_all(&[("SELECT * INTO new_table FROM users", Some(false))]);
    }

    #[test]
    fn cte_with_select() {
        expect_all(&[("WITH cte AS (SELECT 1) SELECT * FROM cte", Some(true))]);
    }

    #[test]
    fn multi_statement_all_readonly() {
        expect_all(&[("SELECT 1; SELECT 2", Some(true))]);
    }

    #[test]
    fn multi_statement_mixed() {
        expect_all(&[("SELECT 1; DELETE FROM users", Some(false))]);
    }

    #[test]
    fn comments_stripped() {
        expect_all(&[
            ("-- comment\nSELECT 1", Some(true)),
            ("/* block */ SELECT 1", Some(true)),
        ]);
    }

    #[test]
    fn empty_is_ambiguous() {
        expect_all(&[("", None), (" ", None)]);
    }

    #[test]
    fn exec_is_ambiguous() {
        expect_all(&[("EXEC sp_something", None)]);
    }
}