use database_mcp_server::AppError;
use sqlparser::ast::{Expr, Function, Statement, Visit, Visitor};
use sqlparser::dialect::Dialect;
#[cfg(test)]
use sqlparser::dialect::MySqlDialect;
use sqlparser::parser::Parser;
pub fn validate_read_only_with_dialect(sql: &str, dialect: &impl Dialect) -> Result<(), AppError> {
let trimmed = sql.trim();
if trimmed.is_empty() {
return Err(AppError::ReadOnlyViolation);
}
let upper = trimmed.to_uppercase();
if upper.contains("INTO OUTFILE") || upper.contains("INTO DUMPFILE") {
return Err(AppError::IntoOutfileBlocked);
}
let statements =
Parser::parse_sql(dialect, trimmed).map_err(|e| AppError::Query(format!("SQL parse error: {e}")))?;
if statements.is_empty() {
return Err(AppError::ReadOnlyViolation);
}
if statements.len() > 1 {
return Err(AppError::MultiStatement);
}
let stmt = &statements[0];
match stmt {
Statement::Query(_) => {
check_dangerous_functions(stmt)?;
}
Statement::ShowTables { .. }
| Statement::ShowColumns { .. }
| Statement::ShowCreate { .. }
| Statement::ShowVariable { .. }
| Statement::ShowVariables { .. }
| Statement::ShowStatus { .. }
| Statement::ShowDatabases { .. }
| Statement::ShowSchemas { .. }
| Statement::ShowCollation { .. }
| Statement::ShowFunctions { .. }
| Statement::ShowViews { .. }
| Statement::ShowObjects(_)
| Statement::ExplainTable { .. }
| Statement::Explain { .. }
| Statement::Use(_) => {
}
_ => {
return Err(AppError::ReadOnlyViolation);
}
}
Ok(())
}
/// Test-only convenience wrapper: validates `sql` using the MySQL dialect.
///
/// Returns whatever [`validate_read_only_with_dialect`] returns for that dialect.
#[cfg(test)]
pub fn validate_read_only(sql: &str) -> Result<(), AppError> {
    let dialect = MySqlDialect {};
    validate_read_only_with_dialect(sql, &dialect)
}
/// Walks `stmt` with a [`DangerousFunctionChecker`] and reports the violation it
/// recorded, if any.
///
/// # Errors
///
/// Returns the error stored in the checker's `found` field when the walk
/// flagged something; otherwise returns `Ok(())`.
fn check_dangerous_functions(stmt: &Statement) -> Result<(), AppError> {
    let mut visitor = DangerousFunctionChecker { found: None };
    // The returned `ControlFlow` only says whether the walk stopped early;
    // the actual outcome is carried in `visitor.found`.
    let _ = stmt.visit(&mut visitor);
    match visitor.found {
        Some(violation) => Err(violation),
        None => Ok(()),
    }
}
/// AST visitor state used by `check_dangerous_functions`.
struct DangerousFunctionChecker {
/// Error recorded by the visitor when it flags something; `None` otherwise.
found: Option<AppError>,
}
impl Visitor for DangerousFunctionChecker {
type Break = ();
fn pre_visit_expr(&mut self, expr: &Expr) -> std::ops::ControlFlow<Self::Break> {
if let Expr::Function(Function { name, .. }) = expr {
let func_name = name.to_string().to_uppercase();
if func_name == "LOAD_FILE" {
self.found = Some(AppError::LoadFileBlocked);
return std::ops::ControlFlow::Break(());
}
}
std::ops::ControlFlow::Continue(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Small predicates so each test reads as a single assertion per query.

    /// True when the validator accepts `sql`.
    fn allowed(sql: &str) -> bool {
        validate_read_only(sql).is_ok()
    }

    /// True when `sql` is rejected as a read-only violation.
    fn read_only_violation(sql: &str) -> bool {
        matches!(validate_read_only(sql), Err(AppError::ReadOnlyViolation))
    }

    /// True when `sql` is rejected because of a `LOAD_FILE` call.
    fn load_file_blocked(sql: &str) -> bool {
        matches!(validate_read_only(sql), Err(AppError::LoadFileBlocked))
    }

    /// True when `sql` is rejected by the OUTFILE/DUMPFILE filter.
    fn outfile_blocked(sql: &str) -> bool {
        matches!(validate_read_only(sql), Err(AppError::IntoOutfileBlocked))
    }

    /// True when `sql` is rejected for containing several statements.
    fn multi_statement(sql: &str) -> bool {
        matches!(validate_read_only(sql), Err(AppError::MultiStatement))
    }

    #[test]
    fn test_select_allowed() {
        assert!(allowed("SELECT * FROM users"));
        assert!(allowed("select * from users"));
    }

    #[test]
    fn test_show_allowed() {
        assert!(allowed("SHOW DATABASES"));
        assert!(allowed("SHOW TABLES"));
    }

    #[test]
    fn test_describe_allowed() {
        assert!(allowed("DESC users"));
        assert!(allowed("DESCRIBE users"));
    }

    #[test]
    fn test_use_allowed() {
        assert!(allowed("USE mydb"));
    }

    #[test]
    fn test_insert_blocked() {
        assert!(read_only_violation("INSERT INTO users VALUES (1)"));
    }

    #[test]
    fn test_update_blocked() {
        assert!(read_only_violation("UPDATE users SET name='x'"));
    }

    #[test]
    fn test_delete_blocked() {
        assert!(read_only_violation("DELETE FROM users"));
    }

    #[test]
    fn test_drop_blocked() {
        assert!(read_only_violation("DROP TABLE users"));
    }

    #[test]
    fn test_create_blocked() {
        assert!(read_only_violation("CREATE TABLE test (id INT)"));
    }

    #[test]
    fn test_comment_bypass_single_line() {
        // Either outcome is accepted: the input validates, or it is reported
        // as multiple statements.
        let outcome = validate_read_only("SELECT 1 -- \nDELETE FROM users");
        assert!(matches!(outcome, Ok(()) | Err(AppError::MultiStatement)));
    }

    #[test]
    fn test_comment_bypass_multi_line() {
        assert!(read_only_violation("/* SELECT */ DELETE FROM users"));
    }

    #[test]
    fn test_load_file_blocked() {
        assert!(load_file_blocked("SELECT LOAD_FILE('/etc/passwd')"));
    }

    #[test]
    fn test_load_file_case_insensitive() {
        assert!(load_file_blocked("SELECT load_file('/etc/passwd')"));
    }

    #[test]
    fn test_load_file_with_spaces() {
        assert!(load_file_blocked("SELECT LOAD_FILE ('/etc/passwd')"));
    }

    #[test]
    fn test_into_outfile_blocked() {
        assert!(outfile_blocked("SELECT * FROM users INTO OUTFILE '/tmp/out'"));
    }

    #[test]
    fn test_into_dumpfile_blocked() {
        assert!(outfile_blocked("SELECT * FROM users INTO DUMPFILE '/tmp/out'"));
    }

    #[test]
    fn test_load_file_in_string_allowed() {
        // The function name inside a string literal is data, not a call.
        assert!(allowed("SELECT 'LOAD_FILE(/etc/passwd)' FROM dual"));
    }

    #[test]
    fn test_empty_query_blocked() {
        assert!(read_only_violation(""));
    }

    #[test]
    fn test_comment_only_blocked() {
        assert!(!allowed("-- just a comment"));
    }

    #[test]
    fn test_multi_statement_blocked() {
        assert!(multi_statement("SELECT 1; SELECT 2"));
    }

    #[test]
    fn test_multi_statement_injection_blocked() {
        assert!(multi_statement("SELECT 1; DROP TABLE users"));
    }

    #[test]
    fn test_set_statement_blocked() {
        assert!(read_only_violation("SET @var = 1"));
    }

    #[test]
    fn test_malformed_sql_rejected() {
        assert!(!allowed("SELEC * FORM users"));
    }

    #[test]
    fn test_select_with_subquery_allowed() {
        assert!(allowed("SELECT * FROM (SELECT 1) AS t"));
    }

    #[test]
    fn test_select_with_where_allowed() {
        assert!(allowed("SELECT * FROM users WHERE id = 1"));
    }

    #[test]
    fn test_select_count_allowed() {
        assert!(allowed("SELECT COUNT(*) FROM users"));
    }
}