use crate::SqlError;
use sqlparser::ast::{Expr, Function, Query, SetExpr, Statement, Visit, Visitor};
use sqlparser::dialect::Dialect;
use sqlparser::parser::Parser;
/// Coarse classification of a statement accepted by `validate_read_only`.
///
/// `Select` is produced for `Statement::Query` inputs (plain `SELECT`, `WITH ...`,
/// `UNION`, ...); `NonSelect` for the other allow-listed statements such as
/// `SHOW`, `DESCRIBE`, `EXPLAIN` and `USE`.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum StatementKind {
    /// A query statement (`SELECT`, CTE, set operation).
    Select,
    /// An allow-listed statement that is not a query (`SHOW`, `DESCRIBE`, `EXPLAIN`, `USE`).
    NonSelect,
}
pub fn validate_read_only(sql: &str, dialect: &impl Dialect) -> Result<StatementKind, SqlError> {
let trimmed = sql.trim();
if trimmed.is_empty() {
return Err(SqlError::ReadOnlyViolation);
}
let upper = trimmed.to_uppercase();
if upper.contains("INTO OUTFILE") || upper.contains("INTO DUMPFILE") {
return Err(SqlError::IntoOutfileBlocked);
}
let statements =
Parser::parse_sql(dialect, trimmed).map_err(|e| SqlError::Query(format!("SQL parse error: {e}")))?;
if statements.is_empty() {
return Err(SqlError::ReadOnlyViolation);
}
if statements.len() > 1 {
return Err(SqlError::MultiStatement);
}
let stmt = &statements[0];
match stmt {
Statement::Query(_) => {
check_dangerous_functions(stmt)?;
Ok(StatementKind::Select)
}
Statement::ShowTables { .. }
| Statement::ShowColumns { .. }
| Statement::ShowCreate { .. }
| Statement::ShowVariable { .. }
| Statement::ShowVariables { .. }
| Statement::ShowStatus { .. }
| Statement::ShowDatabases { .. }
| Statement::ShowSchemas { .. }
| Statement::ShowCollation { .. }
| Statement::ShowFunctions { .. }
| Statement::ShowViews { .. }
| Statement::ShowObjects(_)
| Statement::ExplainTable { .. }
| Statement::Explain { .. }
| Statement::Use(_) => Ok(StatementKind::NonSelect),
_ => Err(SqlError::ReadOnlyViolation),
}
}
/// Walks `stmt` with `DangerousFunctionChecker` and surfaces the first violation it recorded.
///
/// # Errors
/// Returns the `SqlError` stored by the checker (for example
/// [`SqlError::LoadFileBlocked`] for a `LOAD_FILE(...)` call). Returns `Ok(())` when
/// nothing was flagged.
fn check_dangerous_functions(stmt: &Statement) -> Result<(), SqlError> {
    let mut visitor = DangerousFunctionChecker { found: None };
    // The traversal result carries no payload (`Break = ()`); the reason for an early
    // stop is read from `found` instead.
    let _ = stmt.visit(&mut visitor);
    visitor.found.map_or(Ok(()), Err)
}
struct DangerousFunctionChecker {
found: Option<SqlError>,
}
impl Visitor for DangerousFunctionChecker {
type Break = ();
fn pre_visit_expr(&mut self, expr: &Expr) -> std::ops::ControlFlow<Self::Break> {
if let Expr::Function(Function { name, .. }) = expr {
let func_name = name.to_string().to_uppercase();
if func_name == "LOAD_FILE" {
self.found = Some(SqlError::LoadFileBlocked);
return std::ops::ControlFlow::Break(());
}
}
std::ops::ControlFlow::Continue(())
}
}
#[cfg(test)]
mod tests {
    use sqlparser::dialect::{MySqlDialect, PostgreSqlDialect, SQLiteDialect};

    use super::*;

    const MYSQL: MySqlDialect = MySqlDialect {};
    const POSTGRES: PostgreSqlDialect = PostgreSqlDialect {};
    const SQLITE: SQLiteDialect = SQLiteDialect {};
    /// Dialect used by the single-dialect tests below; an alias for `MYSQL`.
    const DIALECT: MySqlDialect = MYSQL;

    #[test]
    fn classifies_select_vs_non_select() {
        assert_eq!(validate_read_only("SELECT 1", &DIALECT).unwrap(), StatementKind::Select);
        assert_eq!(
            validate_read_only("WITH x AS (SELECT 1) SELECT * FROM x", &DIALECT).unwrap(),
            StatementKind::Select,
        );
        assert_eq!(
            validate_read_only("SELECT 1 UNION SELECT 2", &DIALECT).unwrap(),
            StatementKind::Select,
        );
        assert_eq!(
            validate_read_only("SHOW DATABASES", &DIALECT).unwrap(),
            StatementKind::NonSelect,
        );
        assert_eq!(
            validate_read_only("DESCRIBE users", &DIALECT).unwrap(),
            StatementKind::NonSelect,
        );
        assert_eq!(
            validate_read_only("USE app", &DIALECT).unwrap(),
            StatementKind::NonSelect,
        );
        assert_eq!(
            validate_read_only("EXPLAIN SELECT 1", &DIALECT).unwrap(),
            StatementKind::NonSelect,
        );
    }

    #[test]
    fn test_select_allowed() {
        assert!(validate_read_only("SELECT * FROM users", &DIALECT).is_ok());
        assert!(validate_read_only("select * from users", &DIALECT).is_ok());
    }

    #[test]
    fn test_show_allowed() {
        assert!(validate_read_only("SHOW DATABASES", &DIALECT).is_ok());
        assert!(validate_read_only("SHOW TABLES", &DIALECT).is_ok());
    }

    #[test]
    fn test_describe_allowed() {
        assert!(validate_read_only("DESC users", &DIALECT).is_ok());
        assert!(validate_read_only("DESCRIBE users", &DIALECT).is_ok());
    }

    #[test]
    fn test_use_allowed() {
        assert!(validate_read_only("USE mydb", &DIALECT).is_ok());
    }

    #[test]
    fn test_insert_blocked() {
        assert!(matches!(
            validate_read_only("INSERT INTO users VALUES (1)", &DIALECT),
            Err(SqlError::ReadOnlyViolation)
        ));
    }

    #[test]
    fn test_update_blocked() {
        assert!(matches!(
            validate_read_only("UPDATE users SET name='x'", &DIALECT),
            Err(SqlError::ReadOnlyViolation)
        ));
    }

    #[test]
    fn test_delete_blocked() {
        assert!(matches!(
            validate_read_only("DELETE FROM users", &DIALECT),
            Err(SqlError::ReadOnlyViolation)
        ));
    }

    #[test]
    fn test_drop_blocked() {
        assert!(matches!(
            validate_read_only("DROP TABLE users", &DIALECT),
            Err(SqlError::ReadOnlyViolation)
        ));
    }

    #[test]
    fn test_create_blocked() {
        assert!(matches!(
            validate_read_only("CREATE TABLE test (id INT)", &DIALECT),
            Err(SqlError::ReadOnlyViolation)
        ));
    }

    #[test]
    fn test_comment_bypass_single_line() {
        let result = validate_read_only("SELECT 1 -- \nDELETE FROM users", &DIALECT);
        assert!(result.is_ok() || matches!(result, Err(SqlError::MultiStatement)));
    }

    #[test]
    fn test_comment_bypass_multi_line() {
        assert!(matches!(
            validate_read_only("/* SELECT */ DELETE FROM users", &DIALECT),
            Err(SqlError::ReadOnlyViolation)
        ));
    }

    #[test]
    fn test_load_file_blocked() {
        assert!(matches!(
            validate_read_only("SELECT LOAD_FILE('/etc/passwd')", &DIALECT),
            Err(SqlError::LoadFileBlocked)
        ));
    }

    #[test]
    fn test_load_file_case_insensitive() {
        assert!(matches!(
            validate_read_only("SELECT load_file('/etc/passwd')", &DIALECT),
            Err(SqlError::LoadFileBlocked)
        ));
    }

    #[test]
    fn test_load_file_with_spaces() {
        assert!(matches!(
            validate_read_only("SELECT LOAD_FILE ('/etc/passwd')", &DIALECT),
            Err(SqlError::LoadFileBlocked)
        ));
    }

    #[test]
    fn test_into_outfile_blocked() {
        assert!(matches!(
            validate_read_only("SELECT * FROM users INTO OUTFILE '/tmp/out'", &DIALECT),
            Err(SqlError::IntoOutfileBlocked)
        ));
    }

    #[test]
    fn test_into_dumpfile_blocked() {
        assert!(matches!(
            validate_read_only("SELECT * FROM users INTO DUMPFILE '/tmp/out'", &DIALECT),
            Err(SqlError::IntoOutfileBlocked)
        ));
    }

    #[test]
    fn test_load_file_in_string_allowed() {
        assert!(validate_read_only("SELECT 'LOAD_FILE(/etc/passwd)' FROM dual", &DIALECT).is_ok());
    }

    #[test]
    fn test_empty_query_blocked() {
        assert!(matches!(
            validate_read_only("", &DIALECT),
            Err(SqlError::ReadOnlyViolation)
        ));
    }

    #[test]
    fn test_comment_only_blocked() {
        let result = validate_read_only("-- just a comment", &DIALECT);
        assert!(result.is_err());
    }

    #[test]
    fn test_multi_statement_blocked() {
        assert!(matches!(
            validate_read_only("SELECT 1; SELECT 2", &DIALECT),
            Err(SqlError::MultiStatement)
        ));
    }

    #[test]
    fn test_multi_statement_injection_blocked() {
        assert!(matches!(
            validate_read_only("SELECT 1; DROP TABLE users", &DIALECT),
            Err(SqlError::MultiStatement)
        ));
    }

    #[test]
    fn test_set_statement_blocked() {
        assert!(matches!(
            validate_read_only("SET @var = 1", &DIALECT),
            Err(SqlError::ReadOnlyViolation)
        ));
    }

    #[test]
    fn test_malformed_sql_rejected() {
        let result = validate_read_only("SELEC * FORM users", &DIALECT);
        assert!(result.is_err());
    }

    #[test]
    fn test_select_with_subquery_allowed() {
        assert!(validate_read_only("SELECT * FROM (SELECT 1) AS t", &DIALECT).is_ok());
    }

    #[test]
    fn test_select_with_where_allowed() {
        assert!(validate_read_only("SELECT * FROM users WHERE id = 1", &DIALECT).is_ok());
    }

    #[test]
    fn test_select_count_allowed() {
        assert!(validate_read_only("SELECT COUNT(*) FROM users", &DIALECT).is_ok());
    }

    /// Asserts that `sql` is accepted under the MySQL, Postgres and SQLite dialects.
    fn assert_allowed_all_dialects(sql: &str) {
        assert!(validate_read_only(sql, &MYSQL).is_ok(), "MySQL should allow: {sql}");
        assert!(
            validate_read_only(sql, &POSTGRES).is_ok(),
            "Postgres should allow: {sql}"
        );
        assert!(validate_read_only(sql, &SQLITE).is_ok(), "SQLite should allow: {sql}");
    }

    /// Asserts that `sql` is rejected (with any error) under all three dialects.
    fn assert_blocked_all_dialects(sql: &str) {
        assert!(validate_read_only(sql, &MYSQL).is_err(), "MySQL should block: {sql}");
        assert!(
            validate_read_only(sql, &POSTGRES).is_err(),
            "Postgres should block: {sql}"
        );
        assert!(validate_read_only(sql, &SQLITE).is_err(), "SQLite should block: {sql}");
    }

    #[test]
    fn select_allowed_all_dialects() {
        assert_allowed_all_dialects("SELECT * FROM users");
        assert_allowed_all_dialects("SELECT 1");
        assert_allowed_all_dialects("SELECT COUNT(*) FROM t");
    }

    #[test]
    fn insert_blocked_all_dialects() {
        assert_blocked_all_dialects("INSERT INTO users VALUES (1)");
    }

    #[test]
    fn update_blocked_all_dialects() {
        assert_blocked_all_dialects("UPDATE users SET name = 'x'");
    }

    #[test]
    fn delete_blocked_all_dialects() {
        assert_blocked_all_dialects("DELETE FROM users");
    }

    #[test]
    fn drop_blocked_all_dialects() {
        assert_blocked_all_dialects("DROP TABLE users");
    }

    #[test]
    fn create_blocked_all_dialects() {
        assert_blocked_all_dialects("CREATE TABLE test (id INT)");
    }

    #[test]
    fn multi_statement_blocked_all_dialects() {
        let sql = "SELECT 1; DROP TABLE x";
        assert!(matches!(validate_read_only(sql, &MYSQL), Err(SqlError::MultiStatement)));
        assert!(matches!(
            validate_read_only(sql, &POSTGRES),
            Err(SqlError::MultiStatement)
        ));
        assert!(matches!(
            validate_read_only(sql, &SQLITE),
            Err(SqlError::MultiStatement)
        ));
    }

    #[test]
    fn empty_blocked_all_dialects() {
        assert_blocked_all_dialects("");
        assert_blocked_all_dialects(" ");
        assert_blocked_all_dialects("\t\n");
    }

    #[test]
    fn postgres_copy_to_blocked() {
        let result = validate_read_only("COPY users TO '/tmp/out.csv'", &POSTGRES);
        assert!(
            matches!(result, Err(SqlError::ReadOnlyViolation)),
            "Postgres COPY TO should be blocked: {result:?}"
        );
    }

    #[test]
    fn postgres_copy_from_blocked() {
        let result = validate_read_only("COPY users FROM '/tmp/in.csv'", &POSTGRES);
        assert!(result.is_err(), "Postgres COPY FROM should be blocked: {result:?}");
    }

    #[test]
    fn postgres_generate_series_allowed() {
        assert!(validate_read_only("SELECT * FROM generate_series(1, 10)", &POSTGRES).is_ok());
    }

    #[test]
    fn show_databases_across_dialects() {
        assert!(validate_read_only("SHOW DATABASES", &MYSQL).is_ok());
        // SQLite may or may not parse SHOW DATABASES; only require that validation
        // returns a result instead of panicking.
        let _ = validate_read_only("SHOW DATABASES", &SQLITE);
        // Postgres may reject SHOW DATABASES as a parse error, but it must never be
        // classified as a write.
        if let Err(e) = validate_read_only("SHOW DATABASES", &POSTGRES) {
            assert!(
                !matches!(e, SqlError::ReadOnlyViolation),
                "SHOW DATABASES should not be classified as a write: {e}"
            );
        }
    }

    #[test]
    fn unicode_greek_question_mark_not_misclassified() {
        // U+037E GREEK QUESTION MARK looks like ';' but is not the ASCII separator.
        let sql = "SELECT 1\u{037E} DROP TABLE users";
        let result = validate_read_only(sql, &MYSQL);
        assert!(
            result.is_err(),
            "SQL with a Greek question mark (U+037E) should not silently succeed as a single SELECT"
        );
    }

    #[test]
    fn unicode_fullwidth_semicolon_not_misclassified() {
        // U+FF1B FULLWIDTH SEMICOLON is not the ASCII ';' separator, so the input must
        // never be split into multiple statements.
        let sql = "SELECT 1\u{FF1B} DROP TABLE users";
        let result = validate_read_only(sql, &MYSQL);
        assert!(
            !matches!(result, Err(SqlError::MultiStatement)),
            "fullwidth semicolon is a single token, not a statement separator: {result:?}"
        );
    }

    #[test]
    fn null_byte_in_sql() {
        let sql = "SELECT 1\x00; DROP TABLE x";
        let result = validate_read_only(sql, &MYSQL);
        assert!(result.is_err(), "SQL with null byte should be rejected: {result:?}");
    }
}