contextual_encoder/
sql.rs1use std::fmt;
54
55use crate::engine::{encode_loop, is_unicode_noncharacter};
56
57pub fn for_sql(input: &str) -> String {
81 let mut out = String::with_capacity(input.len());
82 write_sql(&mut out, input).expect("writing to string cannot fail");
83 out
84}
85
86pub fn write_sql<W: fmt::Write>(out: &mut W, input: &str) -> fmt::Result {
90 encode_loop(out, input, needs_sql_encoding, write_sql_encoded)
91}
92
93fn needs_sql_encoding(c: char) -> bool {
94 c == '\'' || c == '\0' || is_unicode_noncharacter(c as u32)
95}
96
97fn write_sql_encoded<W: fmt::Write>(out: &mut W, c: char, _next: Option<char>) -> fmt::Result {
98 match c {
99 '\'' => out.write_str("''"),
100 '\0' => Ok(()), _ if is_unicode_noncharacter(c as u32) => out.write_char(' '),
102 _ => out.write_char(c),
103 }
104}
105
106pub fn for_sql_backslash(input: &str) -> String {
128 let mut out = String::with_capacity(input.len());
129 write_sql_backslash(&mut out, input).expect("writing to string cannot fail");
130 out
131}
132
133pub fn write_sql_backslash<W: fmt::Write>(out: &mut W, input: &str) -> fmt::Result {
137 encode_loop(
138 out,
139 input,
140 needs_sql_backslash_encoding,
141 write_sql_backslash_encoded,
142 )
143}
144
145fn needs_sql_backslash_encoding(c: char) -> bool {
146 matches!(c, '\0' | '\x08' | '\t' | '\n' | '\r' | '\x1A' | '\'' | '\\')
147 || is_unicode_noncharacter(c as u32)
148}
149
150fn write_sql_backslash_encoded<W: fmt::Write>(
151 out: &mut W,
152 c: char,
153 _next: Option<char>,
154) -> fmt::Result {
155 match c {
156 '\0' => out.write_str("\\0"),
157 '\x08' => out.write_str("\\b"),
158 '\t' => out.write_str("\\t"),
159 '\n' => out.write_str("\\n"),
160 '\r' => out.write_str("\\r"),
161 '\x1A' => out.write_str("\\Z"),
162 '\'' => out.write_str("\\'"),
163 '\\' => out.write_str("\\\\"),
164 _ if is_unicode_noncharacter(c as u32) => out.write_char(' '),
165 _ => out.write_char(c),
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    // ---- for_sql: standard quote doubling ----

    #[test]
    fn sql_passthrough() {
        assert_eq!(for_sql("hello world"), "hello world");
        assert_eq!(for_sql(""), "");
        assert_eq!(for_sql("SELECT 1"), "SELECT 1");
        assert_eq!(for_sql("café"), "café");
        assert_eq!(for_sql("日本語"), "日本語");
        assert_eq!(for_sql("\u{1F600}"), "\u{1F600}");
    }

    #[test]
    fn sql_doubles_single_quote() {
        assert_eq!(for_sql("it's"), "it''s");
        assert_eq!(for_sql("'quoted'"), "''quoted''");
        assert_eq!(for_sql("a''b"), "a''''b");
    }

    #[test]
    fn sql_backslash_passes_through() {
        assert_eq!(for_sql(r"back\slash"), r"back\slash");
        assert_eq!(for_sql(r"a\\b"), r"a\\b");
    }

    #[test]
    fn sql_double_quote_passes_through() {
        assert_eq!(for_sql(r#"a"b"#), r#"a"b"#);
    }

    #[test]
    fn sql_removes_nul() {
        assert_eq!(for_sql("before\x00after"), "beforeafter");
        assert_eq!(for_sql("\x00"), "");
        assert_eq!(for_sql("\x00\x00"), "");
    }

    #[test]
    fn sql_control_chars_pass_through() {
        assert_eq!(for_sql("\t"), "\t");
        assert_eq!(for_sql("\n"), "\n");
        assert_eq!(for_sql("\r"), "\r");
        assert_eq!(for_sql("\x08"), "\x08");
        // Ctrl-Z is only escaped by the backslash dialect.
        assert_eq!(for_sql("\x1A"), "\x1A");
    }

    #[test]
    fn sql_nonchars_replaced() {
        assert_eq!(for_sql("\u{FDD0}"), " ");
        assert_eq!(for_sql("\u{FFFE}"), " ");
        assert_eq!(for_sql("\u{1FFFE}"), " ");
    }

    #[test]
    fn sql_quote_and_nul_together() {
        // Quote doubling and NUL stripping must compose within one input.
        assert_eq!(for_sql("it's\x00ok"), "it''sok");
        assert_eq!(for_sql("'\x00'"), "''''");
    }

    #[test]
    fn sql_injection_attempt() {
        assert_eq!(
            for_sql("'; DROP TABLE users; --"),
            "''; DROP TABLE users; --"
        );
    }

    #[test]
    fn sql_writer_matches() {
        let input = "test\x00'escape' café\u{FDD0}";
        let mut w = String::new();
        write_sql(&mut w, input).unwrap();
        assert_eq!(for_sql(input), w);
    }

    // ---- for_sql_backslash: backslash-escaping dialect ----

    #[test]
    fn backslash_passthrough() {
        assert_eq!(for_sql_backslash("hello world"), "hello world");
        assert_eq!(for_sql_backslash(""), "");
        assert_eq!(for_sql_backslash("SELECT 1"), "SELECT 1");
        assert_eq!(for_sql_backslash("café"), "café");
        assert_eq!(for_sql_backslash("日本語"), "日本語");
        assert_eq!(for_sql_backslash("\u{1F600}"), "\u{1F600}");
    }

    #[test]
    fn backslash_escapes_single_quote() {
        assert_eq!(for_sql_backslash("it's"), r"it\'s");
        assert_eq!(for_sql_backslash("'quoted'"), r"\'quoted\'");
    }

    #[test]
    fn backslash_escapes_backslash() {
        assert_eq!(for_sql_backslash(r"a\b"), r"a\\b");
        assert_eq!(for_sql_backslash(r"a\\b"), r"a\\\\b");
    }

    #[test]
    fn backslash_escapes_nul() {
        assert_eq!(for_sql_backslash("before\x00after"), r"before\0after");
        assert_eq!(for_sql_backslash("\x00"), r"\0");
    }

    #[test]
    fn backslash_escapes_newline() {
        assert_eq!(for_sql_backslash("line\nbreak"), r"line\nbreak");
    }

    #[test]
    fn backslash_escapes_carriage_return() {
        assert_eq!(for_sql_backslash("line\rbreak"), r"line\rbreak");
    }

    #[test]
    fn backslash_escapes_tab() {
        assert_eq!(for_sql_backslash("col\tcol"), r"col\tcol");
    }

    #[test]
    fn backslash_escapes_backspace() {
        assert_eq!(for_sql_backslash("a\x08b"), r"a\bb");
    }

    #[test]
    fn backslash_escapes_control_z() {
        assert_eq!(for_sql_backslash("a\x1Ab"), r"a\Zb");
    }

    #[test]
    fn backslash_double_quote_passes_through() {
        assert_eq!(for_sql_backslash(r#"a"b"#), r#"a"b"#);
    }

    #[test]
    fn backslash_other_controls_pass_through() {
        assert_eq!(for_sql_backslash("\x01"), "\x01");
        assert_eq!(for_sql_backslash("\x1F"), "\x1F");
        assert_eq!(for_sql_backslash("\x7F"), "\x7F");
    }

    #[test]
    fn backslash_nonchars_replaced() {
        assert_eq!(for_sql_backslash("\u{FDD0}"), " ");
        assert_eq!(for_sql_backslash("\u{FFFE}"), " ");
        // Same supplementary-plane case as `sql_nonchars_replaced`.
        assert_eq!(for_sql_backslash("\u{1FFFE}"), " ");
    }

    #[test]
    fn backslash_mixed_escapes() {
        // Several escapes in one input must each be applied independently.
        assert_eq!(for_sql_backslash("a'b\\c\nd\x00"), r"a\'b\\c\nd\0");
    }

    #[test]
    fn backslash_injection_attempt() {
        assert_eq!(
            for_sql_backslash("'; DROP TABLE users; --"),
            r"\'; DROP TABLE users; --"
        );
    }

    #[test]
    fn backslash_injection_via_backslash() {
        assert_eq!(for_sql_backslash("\\'"), r"\\\'");
    }

    #[test]
    fn backslash_writer_matches() {
        let input = "test\x00\x08\t\n\r\x1A'\\café\u{FDD0}";
        let mut w = String::new();
        write_sql_backslash(&mut w, input).unwrap();
        assert_eq!(for_sql_backslash(input), w);
    }
}