// database_mcp_sql/validation.rs

//! AST-based SQL validation for read-only mode enforcement.

use std::ops::ControlFlow;

use sqlparser::ast::{Expr, Function, Query, SetExpr, Statement, Visit, Visitor};
use sqlparser::dialect::Dialect;
use sqlparser::parser::Parser;

use crate::SqlError;

/// Pagination category of a statement that has already passed read-only validation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StatementKind {
    /// A row-producing query: plain `SELECT`, `WITH ... SELECT`, or a compound
    /// query (`UNION` / `INTERSECT` / `EXCEPT`). Safe to wrap as a subquery and
    /// page with an appended `LIMIT` / `OFFSET`.
    Select,
    /// A read-only statement that cannot serve as a subquery source (`SHOW`,
    /// `DESC` / `DESCRIBE`, `USE`, or `EXPLAIN`); its result is returned in a
    /// single response.
    NonSelect,
}

19/// Validates that a SQL query is read-only and classifies it for pagination.
20///
21/// Parses the query using the given dialect and checks:
22/// 1. Exactly one statement (multi-statement injection blocked)
23/// 2. Statement type is read-only (SELECT, SHOW, DESCRIBE, USE, EXPLAIN)
24/// 3. No dangerous functions (`LOAD_FILE`)
25/// 4. No INTO OUTFILE/DUMPFILE clauses
26///
27/// Returns the classified [`StatementKind`] on success so callers that
28/// paginate can dispatch `Select` to a subquery-wrapped fetch and route
29/// `NonSelect` through a single-page pass-through.
30///
31/// # Errors
32///
33/// Returns [`SqlError`] if the query is not allowed in read-only mode.
34pub fn validate_read_only(sql: &str, dialect: &impl Dialect) -> Result<StatementKind, SqlError> {
35    let trimmed = sql.trim();
36    if trimmed.is_empty() {
37        return Err(SqlError::ReadOnlyViolation);
38    }
39
40    // Pre-check for INTO OUTFILE/DUMPFILE — sqlparser may not parse MySQL-specific syntax
41    let upper = trimmed.to_uppercase();
42    if upper.contains("INTO OUTFILE") || upper.contains("INTO DUMPFILE") {
43        return Err(SqlError::IntoOutfileBlocked);
44    }
45
46    let statements =
47        Parser::parse_sql(dialect, trimmed).map_err(|e| SqlError::Query(format!("SQL parse error: {e}")))?;
48
49    // Must be exactly one statement
50    if statements.is_empty() {
51        return Err(SqlError::ReadOnlyViolation);
52    }
53    if statements.len() > 1 {
54        return Err(SqlError::MultiStatement);
55    }
56
57    let stmt = &statements[0];
58
59    // Check statement type is read-only and classify it.
60    match stmt {
61        Statement::Query(_) => {
62            // SELECT — but check for dangerous functions
63            check_dangerous_functions(stmt)?;
64            Ok(StatementKind::Select)
65        }
66        Statement::ShowTables { .. }
67        | Statement::ShowColumns { .. }
68        | Statement::ShowCreate { .. }
69        | Statement::ShowVariable { .. }
70        | Statement::ShowVariables { .. }
71        | Statement::ShowStatus { .. }
72        | Statement::ShowDatabases { .. }
73        | Statement::ShowSchemas { .. }
74        | Statement::ShowCollation { .. }
75        | Statement::ShowFunctions { .. }
76        | Statement::ShowViews { .. }
77        | Statement::ShowObjects(_)
78        | Statement::ExplainTable { .. }
79        | Statement::Explain { .. }
80        | Statement::Use(_) => Ok(StatementKind::NonSelect),
81        _ => Err(SqlError::ReadOnlyViolation),
82    }
83}
84
85/// Check for dangerous function calls like `LOAD_FILE()` in the AST.
86fn check_dangerous_functions(stmt: &Statement) -> Result<(), SqlError> {
87    let mut checker = DangerousFunctionChecker { found: None };
88    let _ = stmt.visit(&mut checker);
89    if let Some(err) = checker.found {
90        return Err(err);
91    }
92    Ok(())
93}
94
95struct DangerousFunctionChecker {
96    found: Option<SqlError>,
97}
98
99impl Visitor for DangerousFunctionChecker {
100    type Break = ();
101
102    fn pre_visit_expr(&mut self, expr: &Expr) -> std::ops::ControlFlow<Self::Break> {
103        if let Expr::Function(Function { name, .. }) = expr {
104            let func_name = name.to_string().to_uppercase();
105            if func_name == "LOAD_FILE" {
106                self.found = Some(SqlError::LoadFileBlocked);
107                return std::ops::ControlFlow::Break(());
108            }
109        }
110        std::ops::ControlFlow::Continue(())
111    }
112}
113
#[cfg(test)]
mod tests {
    use sqlparser::dialect::{MySqlDialect, PostgreSqlDialect, SQLiteDialect};

    use super::*;

    const MYSQL: MySqlDialect = MySqlDialect {};
    const POSTGRES: PostgreSqlDialect = PostgreSqlDialect {};
    const SQLITE: SQLiteDialect = SQLiteDialect {};

    /// Dialect used by the single-dialect tests below (same as `MYSQL`).
    const DIALECT: MySqlDialect = MySqlDialect {};

    // === Statement classification ===

    #[test]
    fn classifies_select_vs_non_select() {
        // Select-like: plain SELECT, CTE, compound queries all map to StatementKind::Select.
        assert_eq!(validate_read_only("SELECT 1", &DIALECT).unwrap(), StatementKind::Select);
        assert_eq!(
            validate_read_only("WITH x AS (SELECT 1) SELECT * FROM x", &DIALECT).unwrap(),
            StatementKind::Select,
        );
        assert_eq!(
            validate_read_only("SELECT 1 UNION SELECT 2", &DIALECT).unwrap(),
            StatementKind::Select,
        );

        // Catalog/metadata statements map to StatementKind::NonSelect.
        assert_eq!(
            validate_read_only("SHOW DATABASES", &DIALECT).unwrap(),
            StatementKind::NonSelect,
        );
        assert_eq!(
            validate_read_only("DESCRIBE users", &DIALECT).unwrap(),
            StatementKind::NonSelect,
        );
        assert_eq!(
            validate_read_only("USE app", &DIALECT).unwrap(),
            StatementKind::NonSelect,
        );
        assert_eq!(
            validate_read_only("EXPLAIN SELECT 1", &DIALECT).unwrap(),
            StatementKind::NonSelect,
        );
    }

    // === Allowed queries ===

    #[test]
    fn test_select_allowed() {
        assert!(validate_read_only("SELECT * FROM users", &DIALECT).is_ok());
        assert!(validate_read_only("select * from users", &DIALECT).is_ok());
    }

    #[test]
    fn test_show_allowed() {
        assert!(validate_read_only("SHOW DATABASES", &DIALECT).is_ok());
        assert!(validate_read_only("SHOW TABLES", &DIALECT).is_ok());
    }

    #[test]
    fn test_describe_allowed() {
        // sqlparser parses DESC/DESCRIBE as ExplainTable
        assert!(validate_read_only("DESC users", &DIALECT).is_ok());
        assert!(validate_read_only("DESCRIBE users", &DIALECT).is_ok());
    }

    #[test]
    fn test_use_allowed() {
        assert!(validate_read_only("USE mydb", &DIALECT).is_ok());
    }

    // === Blocked statement types ===

    #[test]
    fn test_insert_blocked() {
        assert!(matches!(
            validate_read_only("INSERT INTO users VALUES (1)", &DIALECT),
            Err(SqlError::ReadOnlyViolation)
        ));
    }

    #[test]
    fn test_update_blocked() {
        assert!(matches!(
            validate_read_only("UPDATE users SET name='x'", &DIALECT),
            Err(SqlError::ReadOnlyViolation)
        ));
    }

    #[test]
    fn test_delete_blocked() {
        assert!(matches!(
            validate_read_only("DELETE FROM users", &DIALECT),
            Err(SqlError::ReadOnlyViolation)
        ));
    }

    #[test]
    fn test_drop_blocked() {
        assert!(matches!(
            validate_read_only("DROP TABLE users", &DIALECT),
            Err(SqlError::ReadOnlyViolation)
        ));
    }

    #[test]
    fn test_create_blocked() {
        assert!(matches!(
            validate_read_only("CREATE TABLE test (id INT)", &DIALECT),
            Err(SqlError::ReadOnlyViolation)
        ));
    }

    // === Comment bypass attacks ===

    #[test]
    fn test_comment_bypass_single_line() {
        // A `--` comment only runs to the end of its line, so `DELETE FROM users` on the
        // next line is still tokenized. The validator must never pass it as a write.
        // The accepted outcomes are a single read-only SELECT (e.g. if the parser reads
        // `DELETE` as a column alias) or a multi-statement rejection.
        let result = validate_read_only("SELECT 1 -- \nDELETE FROM users", &DIALECT);
        assert!(result.is_ok() || matches!(result, Err(SqlError::MultiStatement)));
    }

    #[test]
    fn test_comment_bypass_multi_line() {
        // "/* SELECT */ DELETE FROM users" — parser strips comment, sees DELETE
        assert!(matches!(
            validate_read_only("/* SELECT */ DELETE FROM users", &DIALECT),
            Err(SqlError::ReadOnlyViolation)
        ));
    }

    // === Dangerous functions ===

    #[test]
    fn test_load_file_blocked() {
        assert!(matches!(
            validate_read_only("SELECT LOAD_FILE('/etc/passwd')", &DIALECT),
            Err(SqlError::LoadFileBlocked)
        ));
    }

    #[test]
    fn test_load_file_case_insensitive() {
        assert!(matches!(
            validate_read_only("SELECT load_file('/etc/passwd')", &DIALECT),
            Err(SqlError::LoadFileBlocked)
        ));
    }

    #[test]
    fn test_load_file_with_spaces() {
        // sqlparser normalizes function calls, so spaces before ( are handled
        assert!(matches!(
            validate_read_only("SELECT LOAD_FILE ('/etc/passwd')", &DIALECT),
            Err(SqlError::LoadFileBlocked)
        ));
    }

    // === INTO OUTFILE/DUMPFILE ===

    #[test]
    fn test_into_outfile_blocked() {
        assert!(matches!(
            validate_read_only("SELECT * FROM users INTO OUTFILE '/tmp/out'", &DIALECT),
            Err(SqlError::IntoOutfileBlocked)
        ));
    }

    #[test]
    fn test_into_dumpfile_blocked() {
        assert!(matches!(
            validate_read_only("SELECT * FROM users INTO DUMPFILE '/tmp/out'", &DIALECT),
            Err(SqlError::IntoOutfileBlocked)
        ));
    }

    // === String literals should NOT trigger false positives ===

    #[test]
    fn test_load_file_in_string_allowed() {
        // LOAD_FILE inside a string literal is NOT a function call in the AST
        assert!(validate_read_only("SELECT 'LOAD_FILE(/etc/passwd)' FROM dual", &DIALECT).is_ok());
    }

    // === Empty / comment-only queries ===

    #[test]
    fn test_empty_query_blocked() {
        assert!(matches!(
            validate_read_only("", &DIALECT),
            Err(SqlError::ReadOnlyViolation)
        ));
    }

    #[test]
    fn test_comment_only_blocked() {
        // Comment-only input: parser returns empty statements or parse error
        let result = validate_read_only("-- just a comment", &DIALECT);
        assert!(result.is_err());
    }

    // === New tests for AST-based validation ===

    #[test]
    fn test_multi_statement_blocked() {
        assert!(matches!(
            validate_read_only("SELECT 1; SELECT 2", &DIALECT),
            Err(SqlError::MultiStatement)
        ));
    }

    #[test]
    fn test_multi_statement_injection_blocked() {
        assert!(matches!(
            validate_read_only("SELECT 1; DROP TABLE users", &DIALECT),
            Err(SqlError::MultiStatement)
        ));
    }

    #[test]
    fn test_set_statement_blocked() {
        assert!(matches!(
            validate_read_only("SET @var = 1", &DIALECT),
            Err(SqlError::ReadOnlyViolation)
        ));
    }

    #[test]
    fn test_malformed_sql_rejected() {
        let result = validate_read_only("SELEC * FORM users", &DIALECT);
        assert!(result.is_err());
    }

    #[test]
    fn test_select_with_subquery_allowed() {
        assert!(validate_read_only("SELECT * FROM (SELECT 1) AS t", &DIALECT).is_ok());
    }

    #[test]
    fn test_select_with_where_allowed() {
        assert!(validate_read_only("SELECT * FROM users WHERE id = 1", &DIALECT).is_ok());
    }

    #[test]
    fn test_select_count_allowed() {
        assert!(validate_read_only("SELECT COUNT(*) FROM users", &DIALECT).is_ok());
    }

    // === T015: Multi-dialect parameterized tests ===

    fn assert_allowed_all_dialects(sql: &str) {
        assert!(validate_read_only(sql, &MYSQL).is_ok(), "MySQL should allow: {sql}");
        assert!(
            validate_read_only(sql, &POSTGRES).is_ok(),
            "Postgres should allow: {sql}"
        );
        assert!(validate_read_only(sql, &SQLITE).is_ok(), "SQLite should allow: {sql}");
    }

    fn assert_blocked_all_dialects(sql: &str) {
        assert!(validate_read_only(sql, &MYSQL).is_err(), "MySQL should block: {sql}");
        assert!(
            validate_read_only(sql, &POSTGRES).is_err(),
            "Postgres should block: {sql}"
        );
        assert!(validate_read_only(sql, &SQLITE).is_err(), "SQLite should block: {sql}");
    }

    #[test]
    fn select_allowed_all_dialects() {
        assert_allowed_all_dialects("SELECT * FROM users");
        assert_allowed_all_dialects("SELECT 1");
        assert_allowed_all_dialects("SELECT COUNT(*) FROM t");
    }

    #[test]
    fn insert_blocked_all_dialects() {
        assert_blocked_all_dialects("INSERT INTO users VALUES (1)");
    }

    #[test]
    fn update_blocked_all_dialects() {
        assert_blocked_all_dialects("UPDATE users SET name = 'x'");
    }

    #[test]
    fn delete_blocked_all_dialects() {
        assert_blocked_all_dialects("DELETE FROM users");
    }

    #[test]
    fn drop_blocked_all_dialects() {
        assert_blocked_all_dialects("DROP TABLE users");
    }

    #[test]
    fn create_blocked_all_dialects() {
        assert_blocked_all_dialects("CREATE TABLE test (id INT)");
    }

    #[test]
    fn multi_statement_blocked_all_dialects() {
        let sql = "SELECT 1; DROP TABLE x";
        assert!(matches!(validate_read_only(sql, &MYSQL), Err(SqlError::MultiStatement)));
        assert!(matches!(
            validate_read_only(sql, &POSTGRES),
            Err(SqlError::MultiStatement)
        ));
        assert!(matches!(
            validate_read_only(sql, &SQLITE),
            Err(SqlError::MultiStatement)
        ));
    }

    #[test]
    fn empty_blocked_all_dialects() {
        assert_blocked_all_dialects("");
        assert_blocked_all_dialects("   ");
    }

    // === T016: Postgres-specific tests ===

    #[test]
    fn postgres_copy_to_blocked() {
        let result = validate_read_only("COPY users TO '/tmp/out.csv'", &POSTGRES);
        assert!(
            matches!(result, Err(SqlError::ReadOnlyViolation)),
            "Postgres COPY TO should be blocked: {result:?}"
        );
    }

    #[test]
    fn postgres_copy_from_blocked() {
        let result = validate_read_only("COPY users FROM '/tmp/in.csv'", &POSTGRES);
        assert!(result.is_err(), "Postgres COPY FROM should be blocked: {result:?}");
    }

    #[test]
    fn postgres_generate_series_allowed() {
        assert!(validate_read_only("SELECT * FROM generate_series(1, 10)", &POSTGRES).is_ok());
    }

    // === T017: SQLite-specific and cross-dialect tests ===

    #[test]
    fn show_databases_across_dialects() {
        assert!(validate_read_only("SHOW DATABASES", &MYSQL).is_ok());
        // Other dialects may or may not parse SHOW DATABASES; whatever they return,
        // it must not be misclassified as a write.
        for (name, result) in [
            ("Postgres", validate_read_only("SHOW DATABASES", &POSTGRES)),
            ("SQLite", validate_read_only("SHOW DATABASES", &SQLITE)),
        ] {
            assert!(
                !matches!(result, Err(SqlError::ReadOnlyViolation)),
                "{name}: SHOW DATABASES should not be classified as a write: {result:?}"
            );
        }
    }

    // === T018: Unicode and null-byte validation tests ===

    #[test]
    fn unicode_greek_question_mark_not_misclassified() {
        // U+037E GREEK QUESTION MARK looks like ';' but is not a statement separator.
        let sql = "SELECT 1\u{037E} DROP TABLE users";
        let result = validate_read_only(sql, &MYSQL);
        assert!(
            result.is_err(),
            "SQL with Greek question mark should not silently succeed as single SELECT"
        );
    }

    #[test]
    fn unicode_fullwidth_semicolon_not_misclassified() {
        // U+FF1B FULLWIDTH SEMICOLON must not split the input into two statements.
        let sql = "SELECT 1\u{FF1B} DROP TABLE users";
        let result = validate_read_only(sql, &MYSQL);
        assert!(
            !matches!(result, Err(SqlError::MultiStatement)),
            "fullwidth semicolon is a single token, not a statement separator: {result:?}"
        );
    }

    #[test]
    fn null_byte_in_sql() {
        let sql = "SELECT 1\x00; DROP TABLE x";
        let result = validate_read_only(sql, &MYSQL);
        assert!(result.is_err(), "SQL with null byte should be rejected: {result:?}");
    }
}