Skip to main content

database_mcp_backend/
validation.rs

1//! AST-based SQL validation for read-only mode enforcement.
2//!
3//! Parses SQL using sqlparser's `MySQL` dialect and validates statement
4//! type, single-statement enforcement, and dangerous function blocking.
5
6use crate::error::AppError;
7use sqlparser::ast::{Expr, Function, Statement, Visit, Visitor};
8use sqlparser::dialect::Dialect;
9#[cfg(test)]
10use sqlparser::dialect::MySqlDialect;
11use sqlparser::parser::Parser;
12
13/// Validates that a SQL query is read-only.
14///
15/// Parses the query using the given dialect and checks:
16/// 1. Exactly one statement (multi-statement injection blocked)
17/// 2. Statement type is read-only (SELECT, SHOW, DESCRIBE, USE, EXPLAIN)
18/// 3. No dangerous functions (`LOAD_FILE`)
19/// 4. No INTO OUTFILE/DUMPFILE clauses
20///
21/// # Errors
22///
23/// Returns [`AppError`] if the query is not allowed in read-only mode.
24pub fn validate_read_only_with_dialect(sql: &str, dialect: &impl Dialect) -> Result<(), AppError> {
25    let trimmed = sql.trim();
26    if trimmed.is_empty() {
27        return Err(AppError::ReadOnlyViolation);
28    }
29
30    // Pre-check for INTO OUTFILE/DUMPFILE — sqlparser may not parse MySQL-specific syntax
31    let upper = trimmed.to_uppercase();
32    if upper.contains("INTO OUTFILE") || upper.contains("INTO DUMPFILE") {
33        return Err(AppError::IntoOutfileBlocked);
34    }
35
36    let statements =
37        Parser::parse_sql(dialect, trimmed).map_err(|e| AppError::Query(format!("SQL parse error: {e}")))?;
38
39    // Must be exactly one statement
40    if statements.is_empty() {
41        return Err(AppError::ReadOnlyViolation);
42    }
43    if statements.len() > 1 {
44        return Err(AppError::MultiStatement);
45    }
46
47    let stmt = &statements[0];
48
49    // Check statement type is read-only
50    match stmt {
51        Statement::Query(_) => {
52            // SELECT — but check for dangerous functions
53            check_dangerous_functions(stmt)?;
54        }
55        Statement::ShowTables { .. }
56        | Statement::ShowColumns { .. }
57        | Statement::ShowCreate { .. }
58        | Statement::ShowVariable { .. }
59        | Statement::ShowVariables { .. }
60        | Statement::ShowStatus { .. }
61        | Statement::ShowDatabases { .. }
62        | Statement::ShowSchemas { .. }
63        | Statement::ShowCollation { .. }
64        | Statement::ShowFunctions { .. }
65        | Statement::ShowViews { .. }
66        | Statement::ShowObjects(_)
67        | Statement::ExplainTable { .. }
68        | Statement::Explain { .. }
69        | Statement::Use(_) => {
70            // SHOW, DESCRIBE, EXPLAIN, USE are all read-only
71        }
72        _ => {
73            return Err(AppError::ReadOnlyViolation);
74        }
75    }
76
77    Ok(())
78}
79
/// Test-only shortcut: runs [`validate_read_only_with_dialect`] with the MySQL dialect.
///
/// # Errors
///
/// Forwards whatever [`validate_read_only_with_dialect`] returns.
#[cfg(test)]
pub fn validate_read_only(sql: &str) -> Result<(), AppError> {
    let mysql = MySqlDialect {};
    validate_read_only_with_dialect(sql, &mysql)
}
85
86/// Check for dangerous function calls like `LOAD_FILE()` in the AST.
87fn check_dangerous_functions(stmt: &Statement) -> Result<(), AppError> {
88    let mut checker = DangerousFunctionChecker { found: None };
89    let _ = stmt.visit(&mut checker);
90    if let Some(err) = checker.found {
91        return Err(err);
92    }
93    Ok(())
94}
95
96struct DangerousFunctionChecker {
97    found: Option<AppError>,
98}
99
100impl Visitor for DangerousFunctionChecker {
101    type Break = ();
102
103    fn pre_visit_expr(&mut self, expr: &Expr) -> std::ops::ControlFlow<Self::Break> {
104        if let Expr::Function(Function { name, .. }) = expr {
105            let func_name = name.to_string().to_uppercase();
106            if func_name == "LOAD_FILE" {
107                self.found = Some(AppError::LoadFileBlocked);
108                return std::ops::ControlFlow::Break(());
109            }
110        }
111        std::ops::ControlFlow::Continue(())
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    // --- Queries that must pass validation ---

    #[test]
    fn test_select_allowed() {
        for sql in ["SELECT * FROM users", "select * from users"] {
            assert!(validate_read_only(sql).is_ok());
        }
    }

    #[test]
    fn test_show_allowed() {
        for sql in ["SHOW DATABASES", "SHOW TABLES"] {
            assert!(validate_read_only(sql).is_ok());
        }
    }

    #[test]
    fn test_describe_allowed() {
        // Both spellings come back from sqlparser as ExplainTable.
        for sql in ["DESC users", "DESCRIBE users"] {
            assert!(validate_read_only(sql).is_ok());
        }
    }

    #[test]
    fn test_use_allowed() {
        let outcome = validate_read_only("USE mydb");
        assert!(outcome.is_ok());
    }

    // --- Statement types that must be rejected ---

    #[test]
    fn test_insert_blocked() {
        let outcome = validate_read_only("INSERT INTO users VALUES (1)");
        assert!(matches!(outcome, Err(AppError::ReadOnlyViolation)));
    }

    #[test]
    fn test_update_blocked() {
        let outcome = validate_read_only("UPDATE users SET name='x'");
        assert!(matches!(outcome, Err(AppError::ReadOnlyViolation)));
    }

    #[test]
    fn test_delete_blocked() {
        let outcome = validate_read_only("DELETE FROM users");
        assert!(matches!(outcome, Err(AppError::ReadOnlyViolation)));
    }

    #[test]
    fn test_drop_blocked() {
        let outcome = validate_read_only("DROP TABLE users");
        assert!(matches!(outcome, Err(AppError::ReadOnlyViolation)));
    }

    #[test]
    fn test_create_blocked() {
        let outcome = validate_read_only("CREATE TABLE test (id INT)");
        assert!(matches!(outcome, Err(AppError::ReadOnlyViolation)));
    }

    // --- Attempts to hide a write behind a comment ---

    #[test]
    fn test_comment_bypass_single_line() {
        // The `--` comment is expected to swallow the rest of its own line only, so the
        // input is either seen as a lone SELECT (allowed) or as two statements (blocked
        // as multi-statement). Both outcomes are acceptable here.
        let outcome = validate_read_only("SELECT 1 -- \nDELETE FROM users");
        assert!(matches!(outcome, Ok(()) | Err(AppError::MultiStatement)));
    }

    #[test]
    fn test_comment_bypass_multi_line() {
        // The block comment is dropped by the parser, leaving a plain DELETE.
        let outcome = validate_read_only("/* SELECT */ DELETE FROM users");
        assert!(matches!(outcome, Err(AppError::ReadOnlyViolation)));
    }

    // --- Dangerous functions ---

    #[test]
    fn test_load_file_blocked() {
        let outcome = validate_read_only("SELECT LOAD_FILE('/etc/passwd')");
        assert!(matches!(outcome, Err(AppError::LoadFileBlocked)));
    }

    #[test]
    fn test_load_file_case_insensitive() {
        let outcome = validate_read_only("SELECT load_file('/etc/passwd')");
        assert!(matches!(outcome, Err(AppError::LoadFileBlocked)));
    }

    #[test]
    fn test_load_file_with_spaces() {
        // Whitespace between the function name and `(` still yields a function call node.
        let outcome = validate_read_only("SELECT LOAD_FILE ('/etc/passwd')");
        assert!(matches!(outcome, Err(AppError::LoadFileBlocked)));
    }

    // --- INTO OUTFILE / DUMPFILE ---

    #[test]
    fn test_into_outfile_blocked() {
        let outcome = validate_read_only("SELECT * FROM users INTO OUTFILE '/tmp/out'");
        assert!(matches!(outcome, Err(AppError::IntoOutfileBlocked)));
    }

    #[test]
    fn test_into_dumpfile_blocked() {
        let outcome = validate_read_only("SELECT * FROM users INTO DUMPFILE '/tmp/out'");
        assert!(matches!(outcome, Err(AppError::IntoOutfileBlocked)));
    }

    // --- String literals must not cause false positives ---

    #[test]
    fn test_load_file_in_string_allowed() {
        // Inside a string literal, LOAD_FILE is just text and not a function call in the AST.
        let outcome = validate_read_only("SELECT 'LOAD_FILE(/etc/passwd)' FROM dual");
        assert!(outcome.is_ok());
    }

    // --- Empty and comment-only input ---

    #[test]
    fn test_empty_query_blocked() {
        let outcome = validate_read_only("");
        assert!(matches!(outcome, Err(AppError::ReadOnlyViolation)));
    }

    #[test]
    fn test_comment_only_blocked() {
        // Either the parser yields no statements or it reports a parse error; both are errors.
        let outcome = validate_read_only("-- just a comment");
        assert!(outcome.is_err());
    }

    // --- Behaviour specific to AST-based validation ---

    #[test]
    fn test_multi_statement_blocked() {
        let outcome = validate_read_only("SELECT 1; SELECT 2");
        assert!(matches!(outcome, Err(AppError::MultiStatement)));
    }

    #[test]
    fn test_multi_statement_injection_blocked() {
        let outcome = validate_read_only("SELECT 1; DROP TABLE users");
        assert!(matches!(outcome, Err(AppError::MultiStatement)));
    }

    #[test]
    fn test_set_statement_blocked() {
        let outcome = validate_read_only("SET @var = 1");
        assert!(matches!(outcome, Err(AppError::ReadOnlyViolation)));
    }

    #[test]
    fn test_malformed_sql_rejected() {
        let outcome = validate_read_only("SELEC * FORM users");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_select_with_subquery_allowed() {
        let outcome = validate_read_only("SELECT * FROM (SELECT 1) AS t");
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_select_with_where_allowed() {
        let outcome = validate_read_only("SELECT * FROM users WHERE id = 1");
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_select_count_allowed() {
        let outcome = validate_read_only("SELECT COUNT(*) FROM users");
        assert!(outcome.is_ok());
    }
}