1use crate::SqlError;
4use sqlparser::ast::{Expr, Function, Statement, Visit, Visitor};
5use sqlparser::dialect::Dialect;
6use sqlparser::parser::Parser;
7
/// How [`validate_read_only`] classified an accepted statement.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StatementKind {
    /// A query statement (`SELECT`, `WITH … SELECT`, or a set operation such as `UNION`).
    Select,
    /// An allowed non-query statement: a `SHOW …` variant, `DESCRIBE`/`EXPLAIN`, or `USE`.
    NonSelect,
}
18
19pub fn validate_read_only(sql: &str, dialect: &impl Dialect) -> Result<StatementKind, SqlError> {
35 let trimmed = sql.trim();
36 if trimmed.is_empty() {
37 return Err(SqlError::ReadOnlyViolation);
38 }
39
40 let upper = trimmed.to_uppercase();
42 if upper.contains("INTO OUTFILE") || upper.contains("INTO DUMPFILE") {
43 return Err(SqlError::IntoOutfileBlocked);
44 }
45
46 let statements =
47 Parser::parse_sql(dialect, trimmed).map_err(|e| SqlError::Query(format!("SQL parse error: {e}")))?;
48
49 if statements.is_empty() {
51 return Err(SqlError::ReadOnlyViolation);
52 }
53 if statements.len() > 1 {
54 return Err(SqlError::MultiStatement);
55 }
56
57 let stmt = &statements[0];
58
59 match stmt {
61 Statement::Query(_) => {
62 check_dangerous_functions(stmt)?;
64 Ok(StatementKind::Select)
65 }
66 Statement::ShowTables { .. }
67 | Statement::ShowColumns { .. }
68 | Statement::ShowCreate { .. }
69 | Statement::ShowVariable { .. }
70 | Statement::ShowVariables { .. }
71 | Statement::ShowStatus { .. }
72 | Statement::ShowDatabases { .. }
73 | Statement::ShowSchemas { .. }
74 | Statement::ShowCollation { .. }
75 | Statement::ShowFunctions { .. }
76 | Statement::ShowViews { .. }
77 | Statement::ShowObjects(_)
78 | Statement::ExplainTable { .. }
79 | Statement::Explain { .. }
80 | Statement::Use(_) => Ok(StatementKind::NonSelect),
81 _ => Err(SqlError::ReadOnlyViolation),
82 }
83}
84
85fn check_dangerous_functions(stmt: &Statement) -> Result<(), SqlError> {
87 let mut checker = DangerousFunctionChecker { found: None };
88 let _ = stmt.visit(&mut checker);
89 if let Some(err) = checker.found {
90 return Err(err);
91 }
92 Ok(())
93}
94
95struct DangerousFunctionChecker {
96 found: Option<SqlError>,
97}
98
99impl Visitor for DangerousFunctionChecker {
100 type Break = ();
101
102 fn pre_visit_expr(&mut self, expr: &Expr) -> std::ops::ControlFlow<Self::Break> {
103 if let Expr::Function(Function { name, .. }) = expr {
104 let func_name = name.to_string().to_uppercase();
105 if func_name == "LOAD_FILE" {
106 self.found = Some(SqlError::LoadFileBlocked);
107 return std::ops::ControlFlow::Break(());
108 }
109 }
110 std::ops::ControlFlow::Continue(())
111 }
112}
113
#[cfg(test)]
mod tests {
    use sqlparser::dialect::{MySqlDialect, PostgreSqlDialect, SQLiteDialect};

    use super::*;

    const MYSQL: MySqlDialect = MySqlDialect {};
    const POSTGRES: PostgreSqlDialect = PostgreSqlDialect {};
    const SQLITE: SQLiteDialect = SQLiteDialect {};

    /// Dialect used by the single-dialect tests below.
    const DIALECT: MySqlDialect = MySqlDialect {};

    #[test]
    fn classifies_select_vs_non_select() {
        assert_eq!(validate_read_only("SELECT 1", &DIALECT).unwrap(), StatementKind::Select);
        assert_eq!(
            validate_read_only("WITH x AS (SELECT 1) SELECT * FROM x", &DIALECT).unwrap(),
            StatementKind::Select,
        );
        assert_eq!(
            validate_read_only("SELECT 1 UNION SELECT 2", &DIALECT).unwrap(),
            StatementKind::Select,
        );

        assert_eq!(
            validate_read_only("SHOW DATABASES", &DIALECT).unwrap(),
            StatementKind::NonSelect,
        );
        assert_eq!(
            validate_read_only("DESCRIBE users", &DIALECT).unwrap(),
            StatementKind::NonSelect,
        );
        assert_eq!(
            validate_read_only("USE app", &DIALECT).unwrap(),
            StatementKind::NonSelect,
        );
        assert_eq!(
            validate_read_only("EXPLAIN SELECT 1", &DIALECT).unwrap(),
            StatementKind::NonSelect,
        );
    }

    #[test]
    fn test_select_allowed() {
        assert!(validate_read_only("SELECT * FROM users", &DIALECT).is_ok());
        assert!(validate_read_only("select * from users", &DIALECT).is_ok());
    }

    #[test]
    fn test_show_allowed() {
        assert!(validate_read_only("SHOW DATABASES", &DIALECT).is_ok());
        assert!(validate_read_only("SHOW TABLES", &DIALECT).is_ok());
    }

    #[test]
    fn test_describe_allowed() {
        assert!(validate_read_only("DESC users", &DIALECT).is_ok());
        assert!(validate_read_only("DESCRIBE users", &DIALECT).is_ok());
    }

    #[test]
    fn test_use_allowed() {
        assert!(validate_read_only("USE mydb", &DIALECT).is_ok());
    }

    #[test]
    fn test_insert_blocked() {
        assert!(matches!(
            validate_read_only("INSERT INTO users VALUES (1)", &DIALECT),
            Err(SqlError::ReadOnlyViolation)
        ));
    }

    #[test]
    fn test_update_blocked() {
        assert!(matches!(
            validate_read_only("UPDATE users SET name='x'", &DIALECT),
            Err(SqlError::ReadOnlyViolation)
        ));
    }

    #[test]
    fn test_delete_blocked() {
        assert!(matches!(
            validate_read_only("DELETE FROM users", &DIALECT),
            Err(SqlError::ReadOnlyViolation)
        ));
    }

    #[test]
    fn test_drop_blocked() {
        assert!(matches!(
            validate_read_only("DROP TABLE users", &DIALECT),
            Err(SqlError::ReadOnlyViolation)
        ));
    }

    #[test]
    fn test_create_blocked() {
        assert!(matches!(
            validate_read_only("CREATE TABLE test (id INT)", &DIALECT),
            Err(SqlError::ReadOnlyViolation)
        ));
    }

    #[test]
    fn test_comment_bypass_single_line() {
        // The `--` comment ends at the newline, so the parser sees
        // `SELECT 1 DELETE FROM users`. Whatever it makes of that, it must never be
        // accepted as a DELETE.
        let result = validate_read_only("SELECT 1 -- \nDELETE FROM users", &DIALECT);
        assert!(result.is_ok() || matches!(result, Err(SqlError::MultiStatement)));
    }

    #[test]
    fn test_comment_bypass_multi_line() {
        assert!(matches!(
            validate_read_only("/* SELECT */ DELETE FROM users", &DIALECT),
            Err(SqlError::ReadOnlyViolation)
        ));
    }

    #[test]
    fn test_load_file_blocked() {
        assert!(matches!(
            validate_read_only("SELECT LOAD_FILE('/etc/passwd')", &DIALECT),
            Err(SqlError::LoadFileBlocked)
        ));
    }

    #[test]
    fn test_load_file_case_insensitive() {
        assert!(matches!(
            validate_read_only("SELECT load_file('/etc/passwd')", &DIALECT),
            Err(SqlError::LoadFileBlocked)
        ));
    }

    #[test]
    fn test_load_file_with_spaces() {
        assert!(matches!(
            validate_read_only("SELECT LOAD_FILE ('/etc/passwd')", &DIALECT),
            Err(SqlError::LoadFileBlocked)
        ));
    }

    #[test]
    fn test_into_outfile_blocked() {
        assert!(matches!(
            validate_read_only("SELECT * FROM users INTO OUTFILE '/tmp/out'", &DIALECT),
            Err(SqlError::IntoOutfileBlocked)
        ));
    }

    #[test]
    fn test_into_dumpfile_blocked() {
        assert!(matches!(
            validate_read_only("SELECT * FROM users INTO DUMPFILE '/tmp/out'", &DIALECT),
            Err(SqlError::IntoOutfileBlocked)
        ));
    }

    #[test]
    fn test_load_file_in_string_allowed() {
        // A string literal that merely mentions LOAD_FILE is not a function call.
        assert!(validate_read_only("SELECT 'LOAD_FILE(/etc/passwd)' FROM dual", &DIALECT).is_ok());
    }

    #[test]
    fn test_empty_query_blocked() {
        assert!(matches!(
            validate_read_only("", &DIALECT),
            Err(SqlError::ReadOnlyViolation)
        ));
    }

    #[test]
    fn test_comment_only_blocked() {
        let result = validate_read_only("-- just a comment", &DIALECT);
        assert!(result.is_err());
    }

    #[test]
    fn test_multi_statement_blocked() {
        assert!(matches!(
            validate_read_only("SELECT 1; SELECT 2", &DIALECT),
            Err(SqlError::MultiStatement)
        ));
    }

    #[test]
    fn test_multi_statement_injection_blocked() {
        assert!(matches!(
            validate_read_only("SELECT 1; DROP TABLE users", &DIALECT),
            Err(SqlError::MultiStatement)
        ));
    }

    #[test]
    fn test_set_statement_blocked() {
        assert!(matches!(
            validate_read_only("SET @var = 1", &DIALECT),
            Err(SqlError::ReadOnlyViolation)
        ));
    }

    #[test]
    fn test_malformed_sql_rejected() {
        let result = validate_read_only("SELEC * FORM users", &DIALECT);
        assert!(result.is_err());
    }

    #[test]
    fn test_select_with_subquery_allowed() {
        assert!(validate_read_only("SELECT * FROM (SELECT 1) AS t", &DIALECT).is_ok());
    }

    #[test]
    fn test_select_with_where_allowed() {
        assert!(validate_read_only("SELECT * FROM users WHERE id = 1", &DIALECT).is_ok());
    }

    #[test]
    fn test_select_count_allowed() {
        assert!(validate_read_only("SELECT COUNT(*) FROM users", &DIALECT).is_ok());
    }

    /// Asserts that `sql` is accepted under the MySQL, Postgres and SQLite dialects.
    fn assert_allowed_all_dialects(sql: &str) {
        assert!(validate_read_only(sql, &MYSQL).is_ok(), "MySQL should allow: {sql}");
        assert!(
            validate_read_only(sql, &POSTGRES).is_ok(),
            "Postgres should allow: {sql}"
        );
        assert!(validate_read_only(sql, &SQLITE).is_ok(), "SQLite should allow: {sql}");
    }

    /// Asserts that `sql` is rejected under the MySQL, Postgres and SQLite dialects.
    fn assert_blocked_all_dialects(sql: &str) {
        assert!(validate_read_only(sql, &MYSQL).is_err(), "MySQL should block: {sql}");
        assert!(
            validate_read_only(sql, &POSTGRES).is_err(),
            "Postgres should block: {sql}"
        );
        assert!(validate_read_only(sql, &SQLITE).is_err(), "SQLite should block: {sql}");
    }

    #[test]
    fn select_allowed_all_dialects() {
        assert_allowed_all_dialects("SELECT * FROM users");
        assert_allowed_all_dialects("SELECT 1");
        assert_allowed_all_dialects("SELECT COUNT(*) FROM t");
    }

    #[test]
    fn insert_blocked_all_dialects() {
        assert_blocked_all_dialects("INSERT INTO users VALUES (1)");
    }

    #[test]
    fn update_blocked_all_dialects() {
        assert_blocked_all_dialects("UPDATE users SET name = 'x'");
    }

    #[test]
    fn delete_blocked_all_dialects() {
        assert_blocked_all_dialects("DELETE FROM users");
    }

    #[test]
    fn drop_blocked_all_dialects() {
        assert_blocked_all_dialects("DROP TABLE users");
    }

    #[test]
    fn create_blocked_all_dialects() {
        assert_blocked_all_dialects("CREATE TABLE test (id INT)");
    }

    #[test]
    fn multi_statement_blocked_all_dialects() {
        let sql = "SELECT 1; DROP TABLE x";
        assert!(matches!(validate_read_only(sql, &MYSQL), Err(SqlError::MultiStatement)));
        assert!(matches!(
            validate_read_only(sql, &POSTGRES),
            Err(SqlError::MultiStatement)
        ));
        assert!(matches!(
            validate_read_only(sql, &SQLITE),
            Err(SqlError::MultiStatement)
        ));
    }

    #[test]
    fn empty_blocked_all_dialects() {
        assert_blocked_all_dialects("");
        assert_blocked_all_dialects(" ");
    }

    #[test]
    fn postgres_copy_to_blocked() {
        let result = validate_read_only("COPY users TO '/tmp/out.csv'", &POSTGRES);
        assert!(
            matches!(result, Err(SqlError::ReadOnlyViolation)),
            "Postgres COPY TO should be blocked: {result:?}"
        );
    }

    #[test]
    fn postgres_copy_from_blocked() {
        let result = validate_read_only("COPY users FROM '/tmp/in.csv'", &POSTGRES);
        assert!(result.is_err(), "Postgres COPY FROM should be blocked: {result:?}");
    }

    #[test]
    fn postgres_generate_series_allowed() {
        assert!(validate_read_only("SELECT * FROM generate_series(1, 10)", &POSTGRES).is_ok());
    }

    #[test]
    fn show_databases_across_dialects() {
        assert!(validate_read_only("SHOW DATABASES", &MYSQL).is_ok());
        // Postgres and SQLite may or may not parse SHOW DATABASES. The hard
        // requirement is that a rejection is never reported as a write.
        for (dialect_name, result) in [
            ("Postgres", validate_read_only("SHOW DATABASES", &POSTGRES)),
            ("SQLite", validate_read_only("SHOW DATABASES", &SQLITE)),
        ] {
            if let Err(e) = &result {
                assert!(
                    !matches!(e, SqlError::ReadOnlyViolation),
                    "{dialect_name}: SHOW DATABASES should not be classified as a write: {e}"
                );
            }
        }
    }

    #[test]
    fn unicode_greek_question_mark_not_misclassified() {
        // U+037E GREEK QUESTION MARK renders like `;` but is not a statement separator.
        let sql = "SELECT 1\u{037E} DROP TABLE users";
        let result = validate_read_only(sql, &MYSQL);
        assert!(
            result.is_err(),
            "SQL with a Greek question mark should not silently succeed as single SELECT"
        );
    }

    #[test]
    fn unicode_fullwidth_semicolon_not_misclassified() {
        // U+FF1B FULLWIDTH SEMICOLON is not a statement separator, so the trailing
        // `DROP TABLE users` is left over and the input must be rejected.
        let sql = "SELECT 1\u{FF1B} DROP TABLE users";
        let result = validate_read_only(sql, &MYSQL);
        assert!(
            result.is_err(),
            "fullwidth semicolon is not a statement separator; input must be rejected: {result:?}"
        );
    }

    #[test]
    fn null_byte_in_sql() {
        let sql = "SELECT 1\x00; DROP TABLE x";
        let result = validate_read_only(sql, &MYSQL);
        assert!(result.is_err(), "SQL with null byte should be rejected: {result:?}");
    }
}