Skip to main content

dbmcp_sql/
sanitize.rs

1//! SQL quoting and validation for identifiers and literals.
2
3use crate::SqlError;
4use sqlparser::dialect::Dialect;
5
6/// Wraps `value` in the dialect's identifier quote character.
7///
8/// Derives the quote character from [`Dialect::identifier_quote_style`],
9/// falling back to `"` (ANSI double-quote) when the dialect returns `None`.
10/// Escapes internal occurrences of the quote character by doubling them.
11#[must_use]
12pub fn quote_ident(value: &str, dialect: &impl Dialect) -> String {
13    let q = dialect.identifier_quote_style(value).unwrap_or('"');
14    let mut out = String::with_capacity(value.len() + 2);
15    out.push(q);
16    for ch in value.chars() {
17        if ch == q {
18            out.push(q);
19        }
20        out.push(ch);
21    }
22    out.push(q);
23    out
24}
25
/// Renders `value` as a single-quoted SQL string literal.
///
/// Every backslash and every single quote in `value` is emitted twice; all
/// other characters are copied through unchanged. Doubling the backslash is
/// what keeps the literal safe under `MySQL`'s default SQL mode, where `\`
/// acts as an escape character inside string literals.
#[must_use]
pub fn quote_literal(value: &str) -> String {
    // Two extra bytes for the surrounding quotes; escapes grow the buffer on demand.
    let mut quoted = String::with_capacity(value.len() + 2);
    quoted.push('\'');
    for c in value.chars() {
        match c {
            '\\' => quoted.push_str("\\\\"),
            '\'' => quoted.push_str("''"),
            other => quoted.push(other),
        }
    }
    quoted.push('\'');
    quoted
}
46
47/// Validates that `name` is a non-empty identifier without control characters.
48///
49/// Returns the input on success so the call composes inside iterator chains
50/// (e.g. `Option::map(validate_ident).transpose()?`).
51///
52/// # Errors
53///
54/// Returns [`SqlError::InvalidIdentifier`] if the name is empty,
55/// whitespace-only, or contains control characters.
56pub fn validate_ident(name: &str) -> Result<&str, SqlError> {
57    if name.trim().is_empty() || name.chars().any(char::is_control) {
58        return Err(SqlError::InvalidIdentifier(name.to_string()));
59    }
60    Ok(name)
61}
62
#[cfg(test)]
mod tests {
    use sqlparser::dialect::{MySqlDialect, PostgreSqlDialect, SQLiteDialect};

    use super::*;

    // --- validate_ident: basic acceptance and rejection ---

    #[test]
    fn accepts_standard_names() {
        for &name in &["users", "my_table", "DB_123"] {
            assert!(validate_ident(name).is_ok());
        }
    }

    #[test]
    fn accepts_hyphenated_names() {
        for &name in &["eu-docker", "access-logs"] {
            assert!(validate_ident(name).is_ok());
        }
    }

    #[test]
    fn accepts_special_chars() {
        // Dots, leading digits, non-ASCII letters and inner spaces are all allowed.
        for &name in &["my.db", "123db", "café", "a b"] {
            assert!(validate_ident(name).is_ok());
        }
    }

    #[test]
    fn rejects_empty() {
        let outcome = validate_ident("");
        assert!(outcome.is_err());
    }

    #[test]
    fn rejects_whitespace_only() {
        for &name in &["   ", "\t"] {
            assert!(validate_ident(name).is_err());
        }
    }

    #[test]
    fn rejects_control_chars() {
        for &name in &["test\x00db", "test\ndb", "test\x1Fdb"] {
            assert!(validate_ident(name).is_err());
        }
    }

    // --- quote_ident / quote_literal: per-dialect basics ---

    #[test]
    fn quote_with_postgres_dialect() {
        let pg = PostgreSqlDialect {};
        let cases = [
            ("users", "\"users\""),
            ("eu-docker", "\"eu-docker\""),
            ("test\"db", "\"test\"\"db\""),
        ];
        for &(raw, expected) in &cases {
            assert_eq!(quote_ident(raw, &pg), expected);
        }
    }

    #[test]
    fn quote_with_mysql_dialect() {
        let mysql = MySqlDialect {};
        let cases = [("users", "`users`"), ("test`db", "`test``db`")];
        for &(raw, expected) in &cases {
            assert_eq!(quote_ident(raw, &mysql), expected);
        }
    }

    #[test]
    fn quote_with_sqlite_dialect() {
        let sqlite = SQLiteDialect {};
        let cases = [("users", "`users`"), ("test`db", "`test``db`")];
        for &(raw, expected) in &cases {
            assert_eq!(quote_ident(raw, &sqlite), expected);
        }
    }

    #[test]
    fn quote_literal_escapes_single_quotes() {
        let cases = [
            ("my_db", "'my_db'"),
            ("", "''"),
            ("it's", "'it''s'"),
            ("a'b'c", "'a''b''c'"),
        ];
        for &(raw, expected) in &cases {
            assert_eq!(quote_literal(raw), expected);
        }
    }

    // === T006: validate_ident boundary tests ===

    #[test]
    fn accepts_long_identifier() {
        let long_name = "a".repeat(10_000);
        assert!(validate_ident(&long_name).is_ok());
    }

    #[test]
    fn rejects_mixed_valid_and_control() {
        // Control character at the end, at the start, and in the middle.
        for &name in &["valid\x00", "\x01start", "mid\x7Fdle"] {
            assert!(validate_ident(name).is_err());
        }
    }

    #[test]
    fn accepts_sql_injection_payload_in_ident() {
        // Validation only screens for blank names and control characters.
        let payload = "Robert'; DROP TABLE students;--";
        assert!(validate_ident(payload).is_ok());
    }

    #[test]
    fn accepts_emoji() {
        for &name in &["🎉", "table_🔥"] {
            assert!(validate_ident(name).is_ok());
        }
    }

    #[test]
    fn accepts_cjk() {
        for &name in &["数据库", "テーブル"] {
            assert!(validate_ident(name).is_ok());
        }
    }

    // === T007: quote_ident adversarial tests ===

    #[test]
    fn quote_ident_only_backticks_mysql() {
        let mysql = MySqlDialect {};
        // Two backticks in: each is doubled (4) and the pair of wrappers adds 2 → 6.
        let quoted = quote_ident("``", &mysql);
        assert_eq!(quoted, "``````");
    }

    #[test]
    fn quote_ident_only_double_quotes_postgres() {
        let pg = PostgreSqlDialect {};
        // Two double-quotes in: each is doubled (4) and the pair of wrappers adds 2 → 6.
        let quoted = quote_ident("\"\"", &pg);
        assert_eq!(quoted, "\"\"\"\"\"\"");
    }

    #[test]
    fn quote_ident_quote_at_start_and_end() {
        // Three characters in; both quote characters are doubled and the
        // result is wrapped, giving seven characters out.
        let mysql = MySqlDialect {};
        let mysql_quoted = quote_ident("`x`", &mysql);
        assert_eq!(mysql_quoted, "```x```");

        let pg = PostgreSqlDialect {};
        let pg_quoted = quote_ident("\"x\"", &pg);
        assert_eq!(pg_quoted, "\"\"\"x\"\"\"");
    }

    #[test]
    fn quote_ident_cross_dialect_foreign_quote_passes_through() {
        // A quote character belonging to a different dialect is ordinary content.
        let mysql = MySqlDialect {};
        let mysql_quoted = quote_ident("test\"db", &mysql);
        assert_eq!(mysql_quoted, "`test\"db`");

        let pg = PostgreSqlDialect {};
        let pg_quoted = quote_ident("test`db", &pg);
        assert_eq!(pg_quoted, "\"test`db\"");
    }

    #[test]
    fn quote_ident_empty_string() {
        let mysql = MySqlDialect {};
        let mysql_quoted = quote_ident("", &mysql);
        assert_eq!(mysql_quoted, "``");

        let pg = PostgreSqlDialect {};
        let pg_quoted = quote_ident("", &pg);
        assert_eq!(pg_quoted, "\"\"");
    }

    #[test]
    fn quote_ident_long_string_completes() {
        let pg = PostgreSqlDialect {};
        let long_name = "a".repeat(10_000);
        let quoted_len = quote_ident(&long_name, &pg).len();
        assert_eq!(quoted_len, 10_002);
    }

    // === T008: quote_literal backslash tests ===

    #[test]
    fn quote_literal_trailing_backslash() {
        let quoted = quote_literal("test\\");
        assert_eq!(quoted, "'test\\\\'");
    }

    #[test]
    fn quote_literal_single_backslash() {
        let quoted = quote_literal("\\");
        assert_eq!(quoted, "'\\\\'");
    }

    #[test]
    fn quote_literal_backslash_then_quote() {
        // Two characters in: a backslash followed by a single quote. Both are
        // doubled and the result is wrapped: '\\'''
        let quoted = quote_literal("\\'");
        assert_eq!(quoted, "'\\\\'''");
    }

    #[test]
    fn quote_literal_only_backslashes() {
        // Three backslashes in, six out, plus the wrapping quotes.
        let quoted = quote_literal("\\\\\\");
        assert_eq!(quoted, "'\\\\\\\\\\\\'");
    }

    #[test]
    fn quote_literal_sql_injection_payload() {
        let quoted = quote_literal("Robert'; DROP TABLE students;--");
        assert_eq!(quoted, "'Robert''; DROP TABLE students;--'");
    }

    #[test]
    fn quote_literal_many_quotes_completes() {
        let input = "'".repeat(1_000);
        let quoted_len = quote_literal(&input).len();
        assert_eq!(quoted_len, 2_002);
    }

    // === T009: quote_literal combined edge cases ===

    #[test]
    fn quote_literal_backslash_and_quotes_mixed() {
        // Five characters in: i, t, backslash, single quote, s. The backslash
        // and the quote are each doubled, then wrapped: 'it\\''s'
        let quoted = quote_literal("it\\'s");
        assert_eq!(quoted, "'it\\\\''s'");
    }

    #[test]
    fn quote_literal_no_special_chars() {
        let quoted = quote_literal("plain");
        assert_eq!(quoted, "'plain'");
    }

    #[test]
    fn quote_literal_unicode_untouched() {
        let cases = [("café", "'café'"), ("数据", "'数据'")];
        for &(raw, expected) in &cases {
            assert_eq!(quote_literal(raw), expected);
        }
    }
}