database_mcp_backend/
validation.rs1use crate::error::AppError;
7use sqlparser::ast::{Expr, Function, Statement, Visit, Visitor};
8use sqlparser::dialect::Dialect;
9#[cfg(test)]
10use sqlparser::dialect::MySqlDialect;
11use sqlparser::parser::Parser;
12
13pub fn validate_read_only_with_dialect(sql: &str, dialect: &impl Dialect) -> Result<(), AppError> {
25 let trimmed = sql.trim();
26 if trimmed.is_empty() {
27 return Err(AppError::ReadOnlyViolation);
28 }
29
30 let upper = trimmed.to_uppercase();
32 if upper.contains("INTO OUTFILE") || upper.contains("INTO DUMPFILE") {
33 return Err(AppError::IntoOutfileBlocked);
34 }
35
36 let statements =
37 Parser::parse_sql(dialect, trimmed).map_err(|e| AppError::Query(format!("SQL parse error: {e}")))?;
38
39 if statements.is_empty() {
41 return Err(AppError::ReadOnlyViolation);
42 }
43 if statements.len() > 1 {
44 return Err(AppError::MultiStatement);
45 }
46
47 let stmt = &statements[0];
48
49 match stmt {
51 Statement::Query(_) => {
52 check_dangerous_functions(stmt)?;
54 }
55 Statement::ShowTables { .. }
56 | Statement::ShowColumns { .. }
57 | Statement::ShowCreate { .. }
58 | Statement::ShowVariable { .. }
59 | Statement::ShowVariables { .. }
60 | Statement::ShowStatus { .. }
61 | Statement::ShowDatabases { .. }
62 | Statement::ShowSchemas { .. }
63 | Statement::ShowCollation { .. }
64 | Statement::ShowFunctions { .. }
65 | Statement::ShowViews { .. }
66 | Statement::ShowObjects(_)
67 | Statement::ExplainTable { .. }
68 | Statement::Explain { .. }
69 | Statement::Use(_) => {
70 }
72 _ => {
73 return Err(AppError::ReadOnlyViolation);
74 }
75 }
76
77 Ok(())
78}
79
/// Test-only shorthand that validates `sql` with the MySQL dialect.
#[cfg(test)]
pub fn validate_read_only(sql: &str) -> Result<(), AppError> {
    let dialect = MySqlDialect {};
    validate_read_only_with_dialect(sql, &dialect)
}

86fn check_dangerous_functions(stmt: &Statement) -> Result<(), AppError> {
88 let mut checker = DangerousFunctionChecker { found: None };
89 let _ = stmt.visit(&mut checker);
90 if let Some(err) = checker.found {
91 return Err(err);
92 }
93 Ok(())
94}
95
96struct DangerousFunctionChecker {
97 found: Option<AppError>,
98}
99
100impl Visitor for DangerousFunctionChecker {
101 type Break = ();
102
103 fn pre_visit_expr(&mut self, expr: &Expr) -> std::ops::ControlFlow<Self::Break> {
104 if let Expr::Function(Function { name, .. }) = expr {
105 let func_name = name.to_string().to_uppercase();
106 if func_name == "LOAD_FILE" {
107 self.found = Some(AppError::LoadFileBlocked);
108 return std::ops::ControlFlow::Break(());
109 }
110 }
111 std::ops::ControlFlow::Continue(())
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `sql` passes read-only validation.
    fn assert_allowed(sql: &str) {
        assert!(validate_read_only(sql).is_ok());
    }

    /// Asserts that validating `$sql` produces a result matching the given
    /// pattern (pattern tokens are forwarded verbatim to `matches!`).
    macro_rules! assert_validates_to {
        ($sql:expr, $($pattern:tt)+) => {
            assert!(matches!(validate_read_only($sql), $($pattern)+))
        };
    }

    // ---- Reads and metadata statements are accepted ----

    #[test]
    fn test_select_allowed() {
        assert_allowed("SELECT * FROM users");
        assert_allowed("select * from users");
    }

    #[test]
    fn test_show_allowed() {
        assert_allowed("SHOW DATABASES");
        assert_allowed("SHOW TABLES");
    }

    #[test]
    fn test_describe_allowed() {
        assert_allowed("DESC users");
        assert_allowed("DESCRIBE users");
    }

    #[test]
    fn test_use_allowed() {
        assert_allowed("USE mydb");
    }

    // ---- Writes and DDL are rejected ----

    #[test]
    fn test_insert_blocked() {
        assert_validates_to!("INSERT INTO users VALUES (1)", Err(AppError::ReadOnlyViolation));
    }

    #[test]
    fn test_update_blocked() {
        assert_validates_to!("UPDATE users SET name='x'", Err(AppError::ReadOnlyViolation));
    }

    #[test]
    fn test_delete_blocked() {
        assert_validates_to!("DELETE FROM users", Err(AppError::ReadOnlyViolation));
    }

    #[test]
    fn test_drop_blocked() {
        assert_validates_to!("DROP TABLE users", Err(AppError::ReadOnlyViolation));
    }

    #[test]
    fn test_create_blocked() {
        assert_validates_to!("CREATE TABLE test (id INT)", Err(AppError::ReadOnlyViolation));
    }

    // ---- Comment tricks ----

    #[test]
    fn test_comment_bypass_single_line() {
        // Either outcome is tolerated: the input is accepted as one statement
        // or rejected as multiple statements.
        assert_validates_to!(
            "SELECT 1 -- \nDELETE FROM users",
            Ok(_) | Err(AppError::MultiStatement)
        );
    }

    #[test]
    fn test_comment_bypass_multi_line() {
        assert_validates_to!("/* SELECT */ DELETE FROM users", Err(AppError::ReadOnlyViolation));
    }

    // ---- File access ----

    #[test]
    fn test_load_file_blocked() {
        assert_validates_to!("SELECT LOAD_FILE('/etc/passwd')", Err(AppError::LoadFileBlocked));
    }

    #[test]
    fn test_load_file_case_insensitive() {
        assert_validates_to!("SELECT load_file('/etc/passwd')", Err(AppError::LoadFileBlocked));
    }

    #[test]
    fn test_load_file_with_spaces() {
        assert_validates_to!("SELECT LOAD_FILE ('/etc/passwd')", Err(AppError::LoadFileBlocked));
    }

    #[test]
    fn test_into_outfile_blocked() {
        assert_validates_to!(
            "SELECT * FROM users INTO OUTFILE '/tmp/out'",
            Err(AppError::IntoOutfileBlocked)
        );
    }

    #[test]
    fn test_into_dumpfile_blocked() {
        assert_validates_to!(
            "SELECT * FROM users INTO DUMPFILE '/tmp/out'",
            Err(AppError::IntoOutfileBlocked)
        );
    }

    #[test]
    fn test_load_file_in_string_allowed() {
        assert_allowed("SELECT 'LOAD_FILE(/etc/passwd)' FROM dual");
    }

    // ---- Degenerate and multi-statement input ----

    #[test]
    fn test_empty_query_blocked() {
        assert_validates_to!("", Err(AppError::ReadOnlyViolation));
    }

    #[test]
    fn test_comment_only_blocked() {
        assert_validates_to!("-- just a comment", Err(_));
    }

    #[test]
    fn test_multi_statement_blocked() {
        assert_validates_to!("SELECT 1; SELECT 2", Err(AppError::MultiStatement));
    }

    #[test]
    fn test_multi_statement_injection_blocked() {
        assert_validates_to!("SELECT 1; DROP TABLE users", Err(AppError::MultiStatement));
    }

    #[test]
    fn test_set_statement_blocked() {
        assert_validates_to!("SET @var = 1", Err(AppError::ReadOnlyViolation));
    }

    #[test]
    fn test_malformed_sql_rejected() {
        assert_validates_to!("SELEC * FORM users", Err(_));
    }

    // ---- Richer reads remain accepted ----

    #[test]
    fn test_select_with_subquery_allowed() {
        assert_allowed("SELECT * FROM (SELECT 1) AS t");
    }

    #[test]
    fn test_select_with_where_allowed() {
        assert_allowed("SELECT * FROM users WHERE id = 1");
    }

    #[test]
    fn test_select_count_allowed() {
        assert_allowed("SELECT COUNT(*) FROM users");
    }
}