database_mcp_sql/
validation.rs1use database_mcp_server::AppError;
7use sqlparser::ast::{Expr, Function, Statement, Visit, Visitor};
8use sqlparser::dialect::Dialect;
9#[cfg(test)]
10use sqlparser::dialect::MySqlDialect;
11use sqlparser::parser::Parser;
12
13pub fn validate_read_only_with_dialect(sql: &str, dialect: &impl Dialect) -> Result<(), AppError> {
25 let trimmed = sql.trim();
26 if trimmed.is_empty() {
27 return Err(AppError::ReadOnlyViolation);
28 }
29
30 let upper = trimmed.to_uppercase();
32 if upper.contains("INTO OUTFILE") || upper.contains("INTO DUMPFILE") {
33 return Err(AppError::IntoOutfileBlocked);
34 }
35
36 let statements =
37 Parser::parse_sql(dialect, trimmed).map_err(|e| AppError::Query(format!("SQL parse error: {e}")))?;
38
39 if statements.is_empty() {
41 return Err(AppError::ReadOnlyViolation);
42 }
43 if statements.len() > 1 {
44 return Err(AppError::MultiStatement);
45 }
46
47 let stmt = &statements[0];
48
49 match stmt {
51 Statement::Query(_) => {
52 check_dangerous_functions(stmt)?;
54 }
55 Statement::ShowTables { .. }
56 | Statement::ShowColumns { .. }
57 | Statement::ShowCreate { .. }
58 | Statement::ShowVariable { .. }
59 | Statement::ShowVariables { .. }
60 | Statement::ShowStatus { .. }
61 | Statement::ShowDatabases { .. }
62 | Statement::ShowSchemas { .. }
63 | Statement::ShowCollation { .. }
64 | Statement::ShowFunctions { .. }
65 | Statement::ShowViews { .. }
66 | Statement::ShowObjects(_)
67 | Statement::ExplainTable { .. }
68 | Statement::Explain { .. }
69 | Statement::Use(_) => {
70 }
72 _ => {
73 return Err(AppError::ReadOnlyViolation);
74 }
75 }
76
77 Ok(())
78}
79
80#[cfg(test)]
86pub fn validate_read_only(sql: &str) -> Result<(), AppError> {
87 validate_read_only_with_dialect(sql, &MySqlDialect {})
88}
89
90fn check_dangerous_functions(stmt: &Statement) -> Result<(), AppError> {
92 let mut checker = DangerousFunctionChecker { found: None };
93 let _ = stmt.visit(&mut checker);
94 if let Some(err) = checker.found {
95 return Err(err);
96 }
97 Ok(())
98}
99
100struct DangerousFunctionChecker {
101 found: Option<AppError>,
102}
103
104impl Visitor for DangerousFunctionChecker {
105 type Break = ();
106
107 fn pre_visit_expr(&mut self, expr: &Expr) -> std::ops::ControlFlow<Self::Break> {
108 if let Expr::Function(Function { name, .. }) = expr {
109 let func_name = name.to_string().to_uppercase();
110 if func_name == "LOAD_FILE" {
111 self.found = Some(AppError::LoadFileBlocked);
112 return std::ops::ControlFlow::Break(());
113 }
114 }
115 std::ops::ControlFlow::Continue(())
116 }
117}
118
119#[cfg(test)]
120mod tests {
121 use super::*;
122
123 #[test]
126 fn test_select_allowed() {
127 assert!(validate_read_only("SELECT * FROM users").is_ok());
128 assert!(validate_read_only("select * from users").is_ok());
129 }
130
131 #[test]
132 fn test_show_allowed() {
133 assert!(validate_read_only("SHOW DATABASES").is_ok());
134 assert!(validate_read_only("SHOW TABLES").is_ok());
135 }
136
137 #[test]
138 fn test_describe_allowed() {
139 assert!(validate_read_only("DESC users").is_ok());
141 assert!(validate_read_only("DESCRIBE users").is_ok());
142 }
143
144 #[test]
145 fn test_use_allowed() {
146 assert!(validate_read_only("USE mydb").is_ok());
147 }
148
149 #[test]
152 fn test_insert_blocked() {
153 assert!(matches!(
154 validate_read_only("INSERT INTO users VALUES (1)"),
155 Err(AppError::ReadOnlyViolation)
156 ));
157 }
158
159 #[test]
160 fn test_update_blocked() {
161 assert!(matches!(
162 validate_read_only("UPDATE users SET name='x'"),
163 Err(AppError::ReadOnlyViolation)
164 ));
165 }
166
167 #[test]
168 fn test_delete_blocked() {
169 assert!(matches!(
170 validate_read_only("DELETE FROM users"),
171 Err(AppError::ReadOnlyViolation)
172 ));
173 }
174
175 #[test]
176 fn test_drop_blocked() {
177 assert!(matches!(
178 validate_read_only("DROP TABLE users"),
179 Err(AppError::ReadOnlyViolation)
180 ));
181 }
182
183 #[test]
184 fn test_create_blocked() {
185 assert!(matches!(
186 validate_read_only("CREATE TABLE test (id INT)"),
187 Err(AppError::ReadOnlyViolation)
188 ));
189 }
190
191 #[test]
194 fn test_comment_bypass_single_line() {
195 let result = validate_read_only("SELECT 1 -- \nDELETE FROM users");
200 assert!(result.is_ok() || matches!(result, Err(AppError::MultiStatement)));
202 }
203
204 #[test]
205 fn test_comment_bypass_multi_line() {
206 assert!(matches!(
208 validate_read_only("/* SELECT */ DELETE FROM users"),
209 Err(AppError::ReadOnlyViolation)
210 ));
211 }
212
213 #[test]
216 fn test_load_file_blocked() {
217 assert!(matches!(
218 validate_read_only("SELECT LOAD_FILE('/etc/passwd')"),
219 Err(AppError::LoadFileBlocked)
220 ));
221 }
222
223 #[test]
224 fn test_load_file_case_insensitive() {
225 assert!(matches!(
226 validate_read_only("SELECT load_file('/etc/passwd')"),
227 Err(AppError::LoadFileBlocked)
228 ));
229 }
230
231 #[test]
232 fn test_load_file_with_spaces() {
233 assert!(matches!(
235 validate_read_only("SELECT LOAD_FILE ('/etc/passwd')"),
236 Err(AppError::LoadFileBlocked)
237 ));
238 }
239
240 #[test]
243 fn test_into_outfile_blocked() {
244 assert!(matches!(
245 validate_read_only("SELECT * FROM users INTO OUTFILE '/tmp/out'"),
246 Err(AppError::IntoOutfileBlocked)
247 ));
248 }
249
250 #[test]
251 fn test_into_dumpfile_blocked() {
252 assert!(matches!(
253 validate_read_only("SELECT * FROM users INTO DUMPFILE '/tmp/out'"),
254 Err(AppError::IntoOutfileBlocked)
255 ));
256 }
257
258 #[test]
261 fn test_load_file_in_string_allowed() {
262 assert!(validate_read_only("SELECT 'LOAD_FILE(/etc/passwd)' FROM dual").is_ok());
264 }
265
266 #[test]
269 fn test_empty_query_blocked() {
270 assert!(matches!(validate_read_only(""), Err(AppError::ReadOnlyViolation)));
271 }
272
273 #[test]
274 fn test_comment_only_blocked() {
275 let result = validate_read_only("-- just a comment");
277 assert!(result.is_err());
278 }
279
280 #[test]
283 fn test_multi_statement_blocked() {
284 assert!(matches!(
285 validate_read_only("SELECT 1; SELECT 2"),
286 Err(AppError::MultiStatement)
287 ));
288 }
289
290 #[test]
291 fn test_multi_statement_injection_blocked() {
292 assert!(matches!(
293 validate_read_only("SELECT 1; DROP TABLE users"),
294 Err(AppError::MultiStatement)
295 ));
296 }
297
298 #[test]
299 fn test_set_statement_blocked() {
300 assert!(matches!(
301 validate_read_only("SET @var = 1"),
302 Err(AppError::ReadOnlyViolation)
303 ));
304 }
305
306 #[test]
307 fn test_malformed_sql_rejected() {
308 let result = validate_read_only("SELEC * FORM users");
309 assert!(result.is_err());
310 }
311
312 #[test]
313 fn test_select_with_subquery_allowed() {
314 assert!(validate_read_only("SELECT * FROM (SELECT 1) AS t").is_ok());
315 }
316
317 #[test]
318 fn test_select_with_where_allowed() {
319 assert!(validate_read_only("SELECT * FROM users WHERE id = 1").is_ok());
320 }
321
322 #[test]
323 fn test_select_count_allowed() {
324 assert!(validate_read_only("SELECT COUNT(*) FROM users").is_ok());
325 }
326}