Skip to main content

database_mcp_sql/
sanitize.rs

//! SQL quoting and validation for identifiers and literals.
2
use crate::SqlError;
use sqlparser::dialect::Dialect;
5
6/// Wraps `value` in the dialect's identifier quote character.
7///
8/// Derives the quote character from [`Dialect::identifier_quote_style`],
9/// falling back to `"` (ANSI double-quote) when the dialect returns `None`.
10/// Escapes internal occurrences of the quote character by doubling them.
11#[must_use]
12pub fn quote_ident(value: &str, dialect: &impl Dialect) -> String {
13    let q = dialect.identifier_quote_style(value).unwrap_or('"');
14    let mut out = String::with_capacity(value.len() + 2);
15    out.push(q);
16    for ch in value.chars() {
17        if ch == q {
18            out.push(q);
19        }
20        out.push(ch);
21    }
22    out.push(q);
23    out
24}
25
/// Wraps `value` in single quotes for use as a SQL string literal.
///
/// Escapes backslashes and single quotes by doubling them. Backslash
/// doubling is required for safety under `MySQL`'s default SQL mode,
/// which treats `\` as an escape character inside string literals.
///
/// All other characters, including non-ASCII text, are copied unchanged;
/// an empty `value` produces `''`.
///
/// NOTE(review): on engines that treat `\` literally inside string
/// literals, the doubled backslash would be stored as two characters.
/// Confirm this trade-off is acceptable for non-`MySQL` callers.
#[must_use]
pub fn quote_literal(value: &str) -> String {
    // Reserve room for the two wrapping quotes; escapes grow the buffer
    // on demand.
    let mut out = String::with_capacity(value.len() + 2);
    out.push('\'');
    for ch in value.chars() {
        // Both special characters are escaped the same way: by emitting
        // the character twice.
        if matches!(ch, '\\' | '\'') {
            out.push(ch);
        }
        out.push(ch);
    }
    out.push('\'');
    out
}
46
47/// Validates that `name` is a non-empty identifier without control characters.
48///
49/// # Errors
50///
51/// Returns [`SqlError::InvalidIdentifier`] if the name is empty,
52/// whitespace-only, or contains control characters.
53pub fn validate_ident(name: &str) -> Result<(), SqlError> {
54    if name.trim().is_empty() || name.chars().any(char::is_control) {
55        return Err(SqlError::InvalidIdentifier(name.to_string()));
56    }
57    Ok(())
58}
59
#[cfg(test)]
mod tests {
    use sqlparser::dialect::{MySqlDialect, PostgreSqlDialect, SQLiteDialect};

    use super::*;

    // === validate_ident: accepted and rejected names ===

    #[test]
    fn accepts_standard_names() {
        assert!(validate_ident("users").is_ok());
        assert!(validate_ident("my_table").is_ok());
        assert!(validate_ident("DB_123").is_ok());
    }

    #[test]
    fn accepts_hyphenated_names() {
        assert!(validate_ident("eu-docker").is_ok());
        assert!(validate_ident("access-logs").is_ok());
    }

    #[test]
    fn accepts_special_chars() {
        assert!(validate_ident("my.db").is_ok());
        assert!(validate_ident("123db").is_ok());
        assert!(validate_ident("café").is_ok());
        assert!(validate_ident("a b").is_ok());
    }

    #[test]
    fn rejects_empty() {
        assert!(validate_ident("").is_err());
    }

    #[test]
    fn rejects_whitespace_only() {
        assert!(validate_ident("   ").is_err());
        assert!(validate_ident("\t").is_err());
    }

    #[test]
    fn rejects_control_chars() {
        assert!(validate_ident("test\x00db").is_err());
        assert!(validate_ident("test\ndb").is_err());
        assert!(validate_ident("test\x1Fdb").is_err());
    }

    // === quote_ident: one test per dialect ===

    #[test]
    fn quote_with_postgres_dialect() {
        let d = PostgreSqlDialect {};
        assert_eq!(quote_ident("users", &d), "\"users\"");
        assert_eq!(quote_ident("eu-docker", &d), "\"eu-docker\"");
        assert_eq!(quote_ident("test\"db", &d), "\"test\"\"db\"");
    }

    #[test]
    fn quote_with_mysql_dialect() {
        let d = MySqlDialect {};
        assert_eq!(quote_ident("users", &d), "`users`");
        assert_eq!(quote_ident("test`db", &d), "`test``db`");
    }

    #[test]
    fn quote_with_sqlite_dialect() {
        let d = SQLiteDialect {};
        assert_eq!(quote_ident("users", &d), "`users`");
        assert_eq!(quote_ident("test`db", &d), "`test``db`");
    }

    #[test]
    fn quote_literal_escapes_single_quotes() {
        assert_eq!(quote_literal("my_db"), "'my_db'");
        assert_eq!(quote_literal(""), "''");
        assert_eq!(quote_literal("it's"), "'it''s'");
        assert_eq!(quote_literal("a'b'c"), "'a''b''c'");
    }

    // === T006: validate_ident boundary tests ===

    #[test]
    fn accepts_long_identifier() {
        let long_name: String = "a".repeat(10_000);
        assert!(validate_ident(&long_name).is_ok());
    }

    #[test]
    fn rejects_mixed_valid_and_control() {
        assert!(validate_ident("valid\x00").is_err());
        assert!(validate_ident("\x01start").is_err());
        assert!(validate_ident("mid\x7Fdle").is_err());
    }

    // Validation only rejects blank names and control characters, so a
    // payload made of printable characters is accepted here.
    #[test]
    fn accepts_sql_injection_payload_in_ident() {
        assert!(validate_ident("Robert'; DROP TABLE students;--").is_ok());
    }

    #[test]
    fn accepts_emoji() {
        assert!(validate_ident("🎉").is_ok());
        assert!(validate_ident("table_🔥").is_ok());
    }

    #[test]
    fn accepts_cjk() {
        assert!(validate_ident("数据库").is_ok());
        assert!(validate_ident("テーブル").is_ok());
    }

    // === T007: quote_ident adversarial tests ===

    #[test]
    fn quote_ident_only_backticks_mysql() {
        let d = MySqlDialect {};
        // Input: `` (2 backticks). Each doubled → 4, plus wrapping → 6.
        assert_eq!(quote_ident("``", &d), "``````");
    }

    #[test]
    fn quote_ident_only_double_quotes_postgres() {
        let d = PostgreSqlDialect {};
        // Input: "" (2 double-quotes). Each doubled → 4, plus wrapping → 6.
        assert_eq!(quote_ident("\"\"", &d), "\"\"\"\"\"\"");
    }

    #[test]
    fn quote_ident_quote_at_start_and_end() {
        let mysql = MySqlDialect {};
        // Input: `x` (3 chars). Backticks doubled → ``x`` plus wrapping → 7.
        assert_eq!(quote_ident("`x`", &mysql), "```x```");

        let pg = PostgreSqlDialect {};
        assert_eq!(quote_ident("\"x\"", &pg), "\"\"\"x\"\"\"");
    }

    // Only the dialect's own quote character is doubled; the other
    // dialect's quote character is copied through as ordinary text.
    #[test]
    fn quote_ident_cross_dialect_foreign_quote_passes_through() {
        let mysql = MySqlDialect {};
        assert_eq!(quote_ident("test\"db", &mysql), "`test\"db`");

        let pg = PostgreSqlDialect {};
        assert_eq!(quote_ident("test`db", &pg), "\"test`db\"");
    }

    #[test]
    fn quote_ident_empty_string() {
        let mysql = MySqlDialect {};
        assert_eq!(quote_ident("", &mysql), "``");

        let pg = PostgreSqlDialect {};
        assert_eq!(quote_ident("", &pg), "\"\"");
    }

    #[test]
    fn quote_ident_long_string_completes() {
        let long_name: String = "a".repeat(10_000);
        let pg = PostgreSqlDialect {};
        let quoted = quote_ident(&long_name, &pg);
        // 10_000 payload bytes plus the two wrapping quotes.
        assert_eq!(quoted.len(), 10_002);
    }

    // === T008: quote_literal backslash tests ===

    #[test]
    fn quote_literal_trailing_backslash() {
        assert_eq!(quote_literal("test\\"), "'test\\\\'");
    }

    #[test]
    fn quote_literal_single_backslash() {
        assert_eq!(quote_literal("\\"), "'\\\\'");
    }

    #[test]
    fn quote_literal_backslash_then_quote() {
        // Input: \' (2 chars). \ doubled → \\, ' doubled → ''. Wrapped: '\\'''
        assert_eq!(quote_literal("\\'"), "'\\\\'''");
    }

    #[test]
    fn quote_literal_only_backslashes() {
        assert_eq!(quote_literal("\\\\\\"), "'\\\\\\\\\\\\'");
    }

    #[test]
    fn quote_literal_sql_injection_payload() {
        assert_eq!(
            quote_literal("Robert'; DROP TABLE students;--"),
            "'Robert''; DROP TABLE students;--'"
        );
    }

    #[test]
    fn quote_literal_many_quotes_completes() {
        let input: String = "'".repeat(1_000);
        let result = quote_literal(&input);
        // 1_000 quotes doubled to 2_000, plus the two wrapping quotes.
        assert_eq!(result.len(), 2_002);
    }

    // === T009: quote_literal combined edge cases ===

    #[test]
    fn quote_literal_backslash_and_quotes_mixed() {
        // Input: it\'s (5 chars). \ doubled, ' doubled. Wrapped: 'it\\''s'
        assert_eq!(quote_literal("it\\'s"), "'it\\\\''s'");
    }

    #[test]
    fn quote_literal_no_special_chars() {
        assert_eq!(quote_literal("plain"), "'plain'");
    }

    #[test]
    fn quote_literal_unicode_untouched() {
        assert_eq!(quote_literal("café"), "'café'");
        assert_eq!(quote_literal("数据"), "'数据'");
    }
}