use std::collections::HashMap;
use serde_json::Value;
/// Replaces each `?` placeholder in `template` with the SQL literal of the
/// matching entry in `binds`, consuming binds from left to right.
///
/// Every `?` character counts as a placeholder, including one that sits inside
/// quoted text in `template`. When `binds` runs out, any further `?` is copied
/// through untouched; binds left over after the template ends are ignored.
/// Each bound value is rendered with `sql_literal`.
#[must_use]
pub fn sanitize_sql(template: &str, binds: &[Value]) -> String {
    let mut remaining = binds.iter();
    // Reserve room for the template plus a rough guess at each rendered bind.
    let mut output = String::with_capacity(template.len() + binds.len() * 8);
    for ch in template.chars() {
        // Only a `?` pulls the next bind; every other character passes through.
        let bound = if ch == '?' { remaining.next() } else { None };
        match bound {
            Some(value) => output.push_str(&sql_literal(value)),
            None => output.push(ch),
        }
    }
    output
}
/// Escapes the LIKE wildcards `%` and `_`, and the escape character `\`
/// itself, by prefixing each with a backslash.
///
/// The result is meant to be embedded in a LIKE pattern whose escape character
/// is `\`. Every other character, including non-ASCII text, is copied unchanged.
#[must_use]
pub fn sanitize_sql_like(input: &str) -> String {
    const ESCAPE: char = '\\';
    let mut escaped = String::with_capacity(input.len());
    for ch in input.chars() {
        if matches!(ch, '\\' | '%' | '_') {
            escaped.push(ESCAPE);
        }
        escaped.push(ch);
    }
    escaped
}
#[must_use]
pub fn sanitize_sql_array(values: &[Value]) -> String {
values
.iter()
.map(sql_literal)
.collect::<Vec<_>>()
.join(", ")
}
/// Builds an `AND`-joined condition list from column/value pairs.
///
/// Pairs are emitted in ascending key order, so the SQL is deterministic
/// regardless of `HashMap` iteration order.
///
/// A `null` value produces `key IS NULL`. Any other value produces
/// `key = <literal>`, with the literal rendered by `sql_literal`.
///
/// Key handling:
/// - A key that is a plain identifier, optionally dot-qualified (segments
///   matching `[A-Za-z_][A-Za-z0-9_]*`, e.g. `users.name`), is emitted verbatim.
/// - Any other key is wrapped in double quotes with embedded `"` doubled. It is
///   then treated as a single quoted identifier, so it cannot inject SQL syntax.
/// - Reserved words that happen to be plain identifiers (e.g. `order`) are not
///   quoted.
///
/// Returns an empty string for an empty map.
#[must_use]
pub fn sanitize_sql_hash(hash: &HashMap<String, Value>) -> String {
    let mut pairs: Vec<(&String, &Value)> = hash.iter().collect();
    // Map keys are unique, so an unstable sort gives the same order as a stable one.
    pairs.sort_unstable_by(|left, right| left.0.cmp(right.0));
    pairs
        .into_iter()
        .map(|(key, value)| {
            let column = sql_hash_identifier(key);
            if value.is_null() {
                format!("{column} IS NULL")
            } else {
                format!("{column} = {}", sql_literal(value))
            }
        })
        .collect::<Vec<_>>()
        .join(" AND ")
}

/// Renders a hash key as a column reference that cannot escape identifier position.
fn sql_hash_identifier(key: &str) -> String {
    if is_plain_identifier_path(key) {
        key.to_owned()
    } else {
        format!("\"{}\"", key.replace('"', "\"\""))
    }
}

/// Returns true when every `.`-separated segment matches `[A-Za-z_][A-Za-z0-9_]*`.
fn is_plain_identifier_path(key: &str) -> bool {
    key.split('.').all(|segment| {
        let mut chars = segment.chars();
        matches!(chars.next(), Some(first) if first.is_ascii_alphabetic() || first == '_')
            && chars.all(|c| c.is_ascii_alphanumeric() || c == '_')
    })
}
fn sql_literal(value: &Value) -> String {
match value {
Value::Null => "NULL".to_owned(),
Value::Bool(flag) => {
if *flag {
"TRUE".to_owned()
} else {
"FALSE".to_owned()
}
}
Value::Number(number) => number.to_string(),
Value::String(text) => format!("'{}'", text.replace('\'', "''")),
Value::Array(_) | Value::Object(_) => match serde_json::to_string(value) {
Ok(serialized) => format!("'{}'", serialized.replace('\'', "''")),
Err(_) => "NULL".to_owned(),
},
}
}
// Unit tests for the SQL sanitization helpers, including edge cases:
// empty inputs, quotes nested inside JSON payloads, and multibyte text.
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use serde_json::{Value, json};
    use super::{sanitize_sql, sanitize_sql_array, sanitize_sql_hash, sanitize_sql_like};

    #[test]
    fn sanitize_sql_replaces_placeholders_in_order() {
        let sql = sanitize_sql("name = ? AND age >= ?", &[json!("Alice"), json!(21)]);
        assert_eq!(sql, "name = 'Alice' AND age >= 21");
    }

    #[test]
    fn sanitize_sql_leaves_extra_placeholders_when_binds_run_out() {
        let sql = sanitize_sql("name = ? AND age = ?", &[json!("Alice")]);
        assert_eq!(sql, "name = 'Alice' AND age = ?");
    }

    #[test]
    fn sanitize_sql_ignores_extra_binds() {
        let sql = sanitize_sql("id = ?", &[json!(1), json!(2)]);
        assert_eq!(sql, "id = 1");
    }

    #[test]
    fn sanitize_sql_without_placeholders_returns_template_unchanged() {
        assert_eq!(sanitize_sql("SELECT 1", &[json!(1)]), "SELECT 1");
    }

    #[test]
    fn sanitize_sql_preserves_multibyte_characters() {
        let sql = sanitize_sql("name = ? -- café", &[json!("naïve")]);
        assert_eq!(sql, "name = 'naïve' -- café");
    }

    #[test]
    fn sanitize_sql_escapes_string_quotes() {
        let sql = sanitize_sql("name = ?", &[json!("O'Brien")]);
        assert_eq!(sql, "name = 'O''Brien'");
    }

    #[test]
    fn sanitize_sql_handles_null_boolean_and_numbers() {
        let sql = sanitize_sql(
            "deleted_at IS ? OR active = ? OR score = ?",
            &[Value::Null, json!(true), json!(12.5)],
        );
        assert_eq!(sql, "deleted_at IS NULL OR active = TRUE OR score = 12.5");
    }

    #[test]
    fn sanitize_sql_serializes_json_values() {
        let sql = sanitize_sql("payload = ?", &[json!({"role": "admin"})]);
        assert_eq!(sql, "payload = '{\"role\":\"admin\"}'");
    }

    #[test]
    fn sanitize_sql_escapes_quotes_inside_json_values() {
        let sql = sanitize_sql("payload = ?", &[json!({"name": "O'Brien"})]);
        assert_eq!(sql, "payload = '{\"name\":\"O''Brien\"}'");
    }

    #[test]
    fn sanitize_sql_neutralizes_common_injection_payloads() {
        let payload = "' OR 1=1 --";
        let sql = sanitize_sql("name = ?", &[json!(payload)]);
        assert_eq!(sql, "name = ''' OR 1=1 --'");
        assert!(!sql.contains("name = ' OR 1=1 --"));
    }

    #[test]
    fn sanitize_sql_like_escapes_percent_underscore_and_backslash() {
        assert_eq!(
            sanitize_sql_like("100%_done\\today"),
            "100\\%\\_done\\\\today"
        );
    }

    #[test]
    fn sanitize_sql_like_leaves_safe_text_unchanged() {
        assert_eq!(sanitize_sql_like("plain-text"), "plain-text");
    }

    #[test]
    fn sanitize_sql_like_returns_empty_string_for_empty_input() {
        assert_eq!(sanitize_sql_like(""), "");
    }

    #[test]
    fn sanitize_sql_array_joins_values() {
        let sql = sanitize_sql_array(&[json!(1), json!("Alice"), Value::Null]);
        assert_eq!(sql, "1, 'Alice', NULL");
    }

    #[test]
    fn sanitize_sql_array_returns_empty_string_for_empty_input() {
        assert_eq!(sanitize_sql_array(&[]), "");
    }

    #[test]
    fn sanitize_sql_array_serializes_nested_arrays() {
        assert_eq!(sanitize_sql_array(&[json!([1, "a"])]), "'[1,\"a\"]'");
    }

    #[test]
    fn sanitize_sql_hash_sorts_keys_for_determinism() {
        let hash = HashMap::from([
            ("name".to_owned(), json!("Alice")),
            ("age".to_owned(), json!(30)),
        ]);
        assert_eq!(sanitize_sql_hash(&hash), "age = 30 AND name = 'Alice'");
    }

    #[test]
    fn sanitize_sql_hash_uses_is_null_for_null_values() {
        let hash = HashMap::from([("deleted_at".to_owned(), Value::Null)]);
        assert_eq!(sanitize_sql_hash(&hash), "deleted_at IS NULL");
    }

    #[test]
    fn sanitize_sql_hash_returns_empty_string_for_empty_map() {
        assert_eq!(sanitize_sql_hash(&HashMap::new()), "");
    }

    #[test]
    fn sanitize_sql_hash_escapes_injection_strings() {
        let hash = HashMap::from([("name".to_owned(), json!("Robert'); DROP TABLE users;--"))]);
        assert_eq!(
            sanitize_sql_hash(&hash),
            "name = 'Robert''); DROP TABLE users;--'"
        );
    }

    // Generates one single-placeholder test per (value, expected SQL) pair.
    macro_rules! sanitize_sql_case {
        ($name:ident, $value:expr, $expected:expr) => {
            #[test]
            fn $name() {
                assert_eq!(sanitize_sql("value = ?", &[$value]), $expected);
            }
        };
    }

    sanitize_sql_case!(sanitize_sql_string_case, json!("hello"), "value = 'hello'");
    sanitize_sql_case!(sanitize_sql_integer_case, json!(42), "value = 42");
    sanitize_sql_case!(sanitize_sql_float_case, json!(2.72), "value = 2.72");
    sanitize_sql_case!(sanitize_sql_false_case, json!(false), "value = FALSE");
    sanitize_sql_case!(sanitize_sql_null_case, Value::Null, "value = NULL");
}