Skip to main content

database_mcp_sql/
validation.rs

//! AST-based SQL validation for read-only mode enforcement.
//!
//! Parses SQL with sqlparser using a caller-supplied dialect (the `MySQL`
//! dialect is only used by the test helper) and validates statement type,
//! single-statement enforcement, and dangerous function blocking.
5
6use database_mcp_server::AppError;
7use sqlparser::ast::{Expr, Function, Statement, Visit, Visitor};
8use sqlparser::dialect::Dialect;
9#[cfg(test)]
10use sqlparser::dialect::MySqlDialect;
11use sqlparser::parser::Parser;
12
13/// Validates that a SQL query is read-only.
14///
15/// Parses the query using the given dialect and checks:
16/// 1. Exactly one statement (multi-statement injection blocked)
17/// 2. Statement type is read-only (SELECT, SHOW, DESCRIBE, USE, EXPLAIN)
18/// 3. No dangerous functions (`LOAD_FILE`)
19/// 4. No INTO OUTFILE/DUMPFILE clauses
20///
21/// # Errors
22///
23/// Returns [`AppError`] if the query is not allowed in read-only mode.
24pub fn validate_read_only_with_dialect(sql: &str, dialect: &impl Dialect) -> Result<(), AppError> {
25    let trimmed = sql.trim();
26    if trimmed.is_empty() {
27        return Err(AppError::ReadOnlyViolation);
28    }
29
30    // Pre-check for INTO OUTFILE/DUMPFILE — sqlparser may not parse MySQL-specific syntax
31    let upper = trimmed.to_uppercase();
32    if upper.contains("INTO OUTFILE") || upper.contains("INTO DUMPFILE") {
33        return Err(AppError::IntoOutfileBlocked);
34    }
35
36    let statements =
37        Parser::parse_sql(dialect, trimmed).map_err(|e| AppError::Query(format!("SQL parse error: {e}")))?;
38
39    // Must be exactly one statement
40    if statements.is_empty() {
41        return Err(AppError::ReadOnlyViolation);
42    }
43    if statements.len() > 1 {
44        return Err(AppError::MultiStatement);
45    }
46
47    let stmt = &statements[0];
48
49    // Check statement type is read-only
50    match stmt {
51        Statement::Query(_) => {
52            // SELECT — but check for dangerous functions
53            check_dangerous_functions(stmt)?;
54        }
55        Statement::ShowTables { .. }
56        | Statement::ShowColumns { .. }
57        | Statement::ShowCreate { .. }
58        | Statement::ShowVariable { .. }
59        | Statement::ShowVariables { .. }
60        | Statement::ShowStatus { .. }
61        | Statement::ShowDatabases { .. }
62        | Statement::ShowSchemas { .. }
63        | Statement::ShowCollation { .. }
64        | Statement::ShowFunctions { .. }
65        | Statement::ShowViews { .. }
66        | Statement::ShowObjects(_)
67        | Statement::ExplainTable { .. }
68        | Statement::Explain { .. }
69        | Statement::Use(_) => {
70            // SHOW, DESCRIBE, EXPLAIN, USE are all read-only
71        }
72        _ => {
73            return Err(AppError::ReadOnlyViolation);
74        }
75    }
76
77    Ok(())
78}
79
/// Test-only shorthand that validates `sql` under the `MySQL` dialect.
///
/// # Errors
///
/// Returns `AppError` when [`validate_read_only_with_dialect`] rejects the query.
#[cfg(test)]
pub fn validate_read_only(sql: &str) -> Result<(), AppError> {
    let dialect = MySqlDialect {};
    validate_read_only_with_dialect(sql, &dialect)
}
89
90/// Check for dangerous function calls like `LOAD_FILE()` in the AST.
91fn check_dangerous_functions(stmt: &Statement) -> Result<(), AppError> {
92    let mut checker = DangerousFunctionChecker { found: None };
93    let _ = stmt.visit(&mut checker);
94    if let Some(err) = checker.found {
95        return Err(err);
96    }
97    Ok(())
98}
99
100struct DangerousFunctionChecker {
101    found: Option<AppError>,
102}
103
104impl Visitor for DangerousFunctionChecker {
105    type Break = ();
106
107    fn pre_visit_expr(&mut self, expr: &Expr) -> std::ops::ControlFlow<Self::Break> {
108        if let Expr::Function(Function { name, .. }) = expr {
109            let func_name = name.to_string().to_uppercase();
110            if func_name == "LOAD_FILE" {
111                self.found = Some(AppError::LoadFileBlocked);
112                return std::ops::ControlFlow::Break(());
113            }
114        }
115        std::ops::ControlFlow::Continue(())
116    }
117}
118
#[cfg(test)]
mod tests {
    use super::*;

    // Returns `true` when the validator accepts `sql` under the `MySQL` dialect.
    fn accepted(sql: &str) -> bool {
        validate_read_only(sql).is_ok()
    }

    // Asserts that validating `$sql` fails with an error matching `$err`.
    macro_rules! assert_rejected {
        ($sql:expr, $err:pat) => {
            assert!(matches!(validate_read_only($sql), Err($err)))
        };
    }

    // === Allowed queries ===

    #[test]
    fn test_select_allowed() {
        for sql in ["SELECT * FROM users", "select * from users"] {
            assert!(accepted(sql));
        }
    }

    #[test]
    fn test_show_allowed() {
        for sql in ["SHOW DATABASES", "SHOW TABLES"] {
            assert!(accepted(sql));
        }
    }

    #[test]
    fn test_describe_allowed() {
        // Both spellings are parsed by sqlparser as `ExplainTable`.
        for sql in ["DESC users", "DESCRIBE users"] {
            assert!(accepted(sql));
        }
    }

    #[test]
    fn test_use_allowed() {
        assert!(accepted("USE mydb"));
    }

    // === Blocked statement types ===

    #[test]
    fn test_insert_blocked() {
        assert_rejected!("INSERT INTO users VALUES (1)", AppError::ReadOnlyViolation);
    }

    #[test]
    fn test_update_blocked() {
        assert_rejected!("UPDATE users SET name='x'", AppError::ReadOnlyViolation);
    }

    #[test]
    fn test_delete_blocked() {
        assert_rejected!("DELETE FROM users", AppError::ReadOnlyViolation);
    }

    #[test]
    fn test_drop_blocked() {
        assert_rejected!("DROP TABLE users", AppError::ReadOnlyViolation);
    }

    #[test]
    fn test_create_blocked() {
        assert_rejected!("CREATE TABLE test (id INT)", AppError::ReadOnlyViolation);
    }

    // === Comment bypass attacks ===

    #[test]
    fn test_comment_bypass_single_line() {
        // Two outcomes are acceptable for a `--` comment followed by a DELETE on
        // the next line: the input is seen as a lone SELECT and accepted, or it
        // is seen as several statements and rejected as multi-statement.
        let outcome = validate_read_only("SELECT 1 -- \nDELETE FROM users");
        assert!(matches!(outcome, Ok(()) | Err(AppError::MultiStatement)));
    }

    #[test]
    fn test_comment_bypass_multi_line() {
        // The block comment is stripped by the parser, leaving the DELETE.
        assert_rejected!("/* SELECT */ DELETE FROM users", AppError::ReadOnlyViolation);
    }

    // === Dangerous functions ===

    #[test]
    fn test_load_file_blocked() {
        assert_rejected!("SELECT LOAD_FILE('/etc/passwd')", AppError::LoadFileBlocked);
    }

    #[test]
    fn test_load_file_case_insensitive() {
        assert_rejected!("SELECT load_file('/etc/passwd')", AppError::LoadFileBlocked);
    }

    #[test]
    fn test_load_file_with_spaces() {
        // A space before the opening parenthesis must still be seen as a call.
        assert_rejected!("SELECT LOAD_FILE ('/etc/passwd')", AppError::LoadFileBlocked);
    }

    // === INTO OUTFILE/DUMPFILE ===

    #[test]
    fn test_into_outfile_blocked() {
        assert_rejected!(
            "SELECT * FROM users INTO OUTFILE '/tmp/out'",
            AppError::IntoOutfileBlocked
        );
    }

    #[test]
    fn test_into_dumpfile_blocked() {
        assert_rejected!(
            "SELECT * FROM users INTO DUMPFILE '/tmp/out'",
            AppError::IntoOutfileBlocked
        );
    }

    // === String literals should NOT trigger false positives ===

    #[test]
    fn test_load_file_in_string_allowed() {
        // Inside a string literal the text is data, not a function call in the AST.
        assert!(accepted("SELECT 'LOAD_FILE(/etc/passwd)' FROM dual"));
    }

    // === Empty / comment-only queries ===

    #[test]
    fn test_empty_query_blocked() {
        assert_rejected!("", AppError::ReadOnlyViolation);
    }

    #[test]
    fn test_comment_only_blocked() {
        // Either an empty statement list or a parse error; both must be errors.
        assert!(!accepted("-- just a comment"));
    }

    // === New tests for AST-based validation ===

    #[test]
    fn test_multi_statement_blocked() {
        assert_rejected!("SELECT 1; SELECT 2", AppError::MultiStatement);
    }

    #[test]
    fn test_multi_statement_injection_blocked() {
        assert_rejected!("SELECT 1; DROP TABLE users", AppError::MultiStatement);
    }

    #[test]
    fn test_set_statement_blocked() {
        assert_rejected!("SET @var = 1", AppError::ReadOnlyViolation);
    }

    #[test]
    fn test_malformed_sql_rejected() {
        assert!(!accepted("SELEC * FORM users"));
    }

    #[test]
    fn test_select_with_subquery_allowed() {
        assert!(accepted("SELECT * FROM (SELECT 1) AS t"));
    }

    #[test]
    fn test_select_with_where_allowed() {
        assert!(accepted("SELECT * FROM users WHERE id = 1"));
    }

    #[test]
    fn test_select_count_allowed() {
        assert!(accepted("SELECT COUNT(*) FROM users"));
    }
}