use std::fmt;
use crate::engine::{encode_loop, is_unicode_noncharacter};
pub fn for_sql(input: &str) -> String {
let mut out = String::with_capacity(input.len());
write_sql(&mut out, input).expect("writing to string cannot fail");
out
}
pub fn write_sql<W: fmt::Write>(out: &mut W, input: &str) -> fmt::Result {
encode_loop(out, input, needs_sql_encoding, write_sql_encoded)
}
fn needs_sql_encoding(c: char) -> bool {
c == '\'' || c == '\0' || is_unicode_noncharacter(c as u32)
}
fn write_sql_encoded<W: fmt::Write>(out: &mut W, c: char, _next: Option<char>) -> fmt::Result {
match c {
'\'' => out.write_str("''"),
'\0' => Ok(()), _ if is_unicode_noncharacter(c as u32) => out.write_char(' '),
_ => out.write_char(c),
}
}
pub fn for_sql_backslash(input: &str) -> String {
let mut out = String::with_capacity(input.len());
write_sql_backslash(&mut out, input).expect("writing to string cannot fail");
out
}
pub fn write_sql_backslash<W: fmt::Write>(out: &mut W, input: &str) -> fmt::Result {
encode_loop(
out,
input,
needs_sql_backslash_encoding,
write_sql_backslash_encoded,
)
}
fn needs_sql_backslash_encoding(c: char) -> bool {
matches!(c, '\0' | '\x08' | '\t' | '\n' | '\r' | '\x1A' | '\'' | '\\')
|| is_unicode_noncharacter(c as u32)
}
fn write_sql_backslash_encoded<W: fmt::Write>(
out: &mut W,
c: char,
_next: Option<char>,
) -> fmt::Result {
match c {
'\0' => out.write_str("\\0"),
'\x08' => out.write_str("\\b"),
'\t' => out.write_str("\\t"),
'\n' => out.write_str("\\n"),
'\r' => out.write_str("\\r"),
'\x1A' => out.write_str("\\Z"),
'\'' => out.write_str("\\'"),
'\\' => out.write_str("\\\\"),
_ if is_unicode_noncharacter(c as u32) => out.write_char(' '),
_ => out.write_char(c),
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn sql_passthrough() {
assert_eq!(for_sql("hello world"), "hello world");
assert_eq!(for_sql(""), "");
assert_eq!(for_sql("SELECT 1"), "SELECT 1");
assert_eq!(for_sql("café"), "café");
assert_eq!(for_sql("日本語"), "日本語");
assert_eq!(for_sql("\u{1F600}"), "\u{1F600}");
}
#[test]
fn sql_doubles_single_quote() {
assert_eq!(for_sql("it's"), "it''s");
assert_eq!(for_sql("'quoted'"), "''quoted''");
assert_eq!(for_sql("a''b"), "a''''b");
}
#[test]
fn sql_backslash_passes_through() {
assert_eq!(for_sql(r"back\slash"), r"back\slash");
assert_eq!(for_sql(r"a\\b"), r"a\\b");
}
#[test]
fn sql_double_quote_passes_through() {
assert_eq!(for_sql(r#"a"b"#), r#"a"b"#);
}
#[test]
fn sql_removes_nul() {
assert_eq!(for_sql("before\x00after"), "beforeafter");
assert_eq!(for_sql("\x00"), "");
assert_eq!(for_sql("\x00\x00"), "");
}
#[test]
fn sql_control_chars_pass_through() {
assert_eq!(for_sql("\t"), "\t");
assert_eq!(for_sql("\n"), "\n");
assert_eq!(for_sql("\r"), "\r");
assert_eq!(for_sql("\x08"), "\x08");
}
#[test]
fn sql_nonchars_replaced() {
assert_eq!(for_sql("\u{FDD0}"), " ");
assert_eq!(for_sql("\u{FFFE}"), " ");
assert_eq!(for_sql("\u{1FFFE}"), " ");
}
#[test]
fn sql_injection_attempt() {
assert_eq!(
for_sql("'; DROP TABLE users; --"),
"''; DROP TABLE users; --"
);
}
#[test]
fn sql_writer_matches() {
let input = "test\x00'escape' café\u{FDD0}";
let mut w = String::new();
write_sql(&mut w, input).unwrap();
assert_eq!(for_sql(input), w);
}
#[test]
fn backslash_passthrough() {
assert_eq!(for_sql_backslash("hello world"), "hello world");
assert_eq!(for_sql_backslash(""), "");
assert_eq!(for_sql_backslash("SELECT 1"), "SELECT 1");
assert_eq!(for_sql_backslash("café"), "café");
assert_eq!(for_sql_backslash("日本語"), "日本語");
assert_eq!(for_sql_backslash("\u{1F600}"), "\u{1F600}");
}
#[test]
fn backslash_escapes_single_quote() {
assert_eq!(for_sql_backslash("it's"), r"it\'s");
assert_eq!(for_sql_backslash("'quoted'"), r"\'quoted\'");
}
#[test]
fn backslash_escapes_backslash() {
assert_eq!(for_sql_backslash(r"a\b"), r"a\\b");
assert_eq!(for_sql_backslash(r"a\\b"), r"a\\\\b");
}
#[test]
fn backslash_escapes_nul() {
assert_eq!(for_sql_backslash("before\x00after"), r"before\0after");
assert_eq!(for_sql_backslash("\x00"), r"\0");
}
#[test]
fn backslash_escapes_newline() {
assert_eq!(for_sql_backslash("line\nbreak"), r"line\nbreak");
}
#[test]
fn backslash_escapes_carriage_return() {
assert_eq!(for_sql_backslash("line\rbreak"), r"line\rbreak");
}
#[test]
fn backslash_escapes_tab() {
assert_eq!(for_sql_backslash("col\tcol"), r"col\tcol");
}
#[test]
fn backslash_escapes_backspace() {
assert_eq!(for_sql_backslash("a\x08b"), r"a\bb");
}
#[test]
fn backslash_escapes_control_z() {
assert_eq!(for_sql_backslash("a\x1Ab"), r"a\Zb");
}
#[test]
fn backslash_double_quote_passes_through() {
assert_eq!(for_sql_backslash(r#"a"b"#), r#"a"b"#);
}
#[test]
fn backslash_other_controls_pass_through() {
assert_eq!(for_sql_backslash("\x01"), "\x01");
assert_eq!(for_sql_backslash("\x7F"), "\x7F");
}
#[test]
fn backslash_nonchars_replaced() {
assert_eq!(for_sql_backslash("\u{FDD0}"), " ");
assert_eq!(for_sql_backslash("\u{FFFE}"), " ");
}
#[test]
fn backslash_injection_attempt() {
assert_eq!(
for_sql_backslash("'; DROP TABLE users; --"),
r"\'; DROP TABLE users; --"
);
}
#[test]
fn backslash_injection_via_backslash() {
assert_eq!(for_sql_backslash("\\'"), r"\\\'");
}
#[test]
fn backslash_writer_matches() {
let input = "test\x00\x08\t\n\r\x1A'\\café\u{FDD0}";
let mut w = String::new();
write_sql_backslash(&mut w, input).unwrap();
assert_eq!(for_sql_backslash(input), w);
}
}