// rustrails-record 0.1.2
//
// ORM layer (ActiveRecord equivalent)
// Documentation
use std::collections::HashMap;

use serde_json::Value;

/// Substitutes each `?` in `template` with the SQL literal form of the next bind value.
///
/// Binds are consumed left to right. Every `?` character counts as a placeholder,
/// wherever it appears in the template. When the binds run out, remaining `?`
/// characters are copied through unchanged; binds beyond the number of
/// placeholders are ignored.
#[must_use]
pub fn sanitize_sql(template: &str, binds: &[Value]) -> String {
    let mut output = String::with_capacity(template.len() + binds.len() * 8);
    let mut remaining = binds.iter();
    let mut segments = template.split('?');

    // `split` always yields at least one (possibly empty) leading segment.
    if let Some(head) = segments.next() {
        output.push_str(head);
    }

    // Each later segment is preceded by exactly one `?` in the template.
    for segment in segments {
        match remaining.next() {
            Some(value) => output.push_str(&sql_literal(value)),
            None => output.push('?'),
        }
        output.push_str(segment);
    }

    output
}

/// Prefixes every `%`, `_`, and `\` in `input` with a backslash so the text can be
/// used inside a `LIKE` pattern without those characters acting as wildcards or
/// escapes. All other characters are copied through unchanged.
#[must_use]
pub fn sanitize_sql_like(input: &str) -> String {
    input
        .chars()
        .fold(String::with_capacity(input.len()), |mut escaped, symbol| {
            if matches!(symbol, '\\' | '%' | '_') {
                escaped.push('\\');
            }
            escaped.push(symbol);
            escaped
        })
}

/// Joins SQL literals into a comma-separated list.
#[must_use]
pub fn sanitize_sql_array(values: &[Value]) -> String {
    values
        .iter()
        .map(sql_literal)
        .collect::<Vec<_>>()
        .join(", ")
}

/// Builds deterministic `key = value` predicates joined by `AND`.
#[must_use]
pub fn sanitize_sql_hash(hash: &HashMap<String, Value>) -> String {
    let mut pairs = hash.iter().collect::<Vec<_>>();
    pairs.sort_by(|left, right| left.0.cmp(right.0));

    pairs
        .into_iter()
        .map(|(key, value)| {
            if value.is_null() {
                format!("{key} IS NULL")
            } else {
                format!("{key} = {}", sql_literal(value))
            }
        })
        .collect::<Vec<_>>()
        .join(" AND ")
}

fn sql_literal(value: &Value) -> String {
    match value {
        Value::Null => "NULL".to_owned(),
        Value::Bool(flag) => {
            if *flag {
                "TRUE".to_owned()
            } else {
                "FALSE".to_owned()
            }
        }
        Value::Number(number) => number.to_string(),
        Value::String(text) => format!("'{}'", text.replace('\'', "''")),
        Value::Array(_) | Value::Object(_) => match serde_json::to_string(value) {
            Ok(serialized) => format!("'{}'", serialized.replace('\'', "''")),
            Err(_) => "NULL".to_owned(),
        },
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use serde_json::{Value, json};

    use super::{sanitize_sql, sanitize_sql_array, sanitize_sql_hash, sanitize_sql_like};

    // --- sanitize_sql: placeholder substitution ---

    #[test]
    fn sanitize_sql_replaces_placeholders_in_order() {
        let binds = [json!("Alice"), json!(21)];

        assert_eq!(
            sanitize_sql("name = ? AND age >= ?", &binds),
            "name = 'Alice' AND age >= 21"
        );
    }

    #[test]
    fn sanitize_sql_leaves_extra_placeholders_when_binds_run_out() {
        let binds = [json!("Alice")];

        assert_eq!(
            sanitize_sql("name = ? AND age = ?", &binds),
            "name = 'Alice' AND age = ?"
        );
    }

    #[test]
    fn sanitize_sql_ignores_extra_binds() {
        let binds = [json!(1), json!(2)];

        assert_eq!(sanitize_sql("id = ?", &binds), "id = 1");
    }

    #[test]
    fn sanitize_sql_escapes_string_quotes() {
        let binds = [json!("O'Brien")];

        assert_eq!(sanitize_sql("name = ?", &binds), "name = 'O''Brien'");
    }

    #[test]
    fn sanitize_sql_handles_null_boolean_and_numbers() {
        let binds = [Value::Null, json!(true), json!(12.5)];
        let rendered = sanitize_sql("deleted_at IS ? OR active = ? OR score = ?", &binds);

        assert_eq!(
            rendered,
            "deleted_at IS NULL OR active = TRUE OR score = 12.5"
        );
    }

    #[test]
    fn sanitize_sql_serializes_json_values() {
        let binds = [json!({"role": "admin"})];

        assert_eq!(
            sanitize_sql("payload = ?", &binds),
            "payload = '{\"role\":\"admin\"}'"
        );
    }

    #[test]
    fn sanitize_sql_neutralizes_common_injection_payloads() {
        let attack = "' OR 1=1 --";
        let rendered = sanitize_sql("name = ?", &[json!(attack)]);

        // The payload's quote is doubled, so it stays inside the string literal.
        assert_eq!(rendered, "name = ''' OR 1=1 --'");
        assert!(!rendered.contains("name = ' OR 1=1 --"));
    }

    // --- sanitize_sql_like: wildcard escaping ---

    #[test]
    fn sanitize_sql_like_escapes_percent_underscore_and_backslash() {
        let escaped = sanitize_sql_like("100%_done\\today");

        assert_eq!(escaped, "100\\%\\_done\\\\today");
    }

    #[test]
    fn sanitize_sql_like_leaves_safe_text_unchanged() {
        let escaped = sanitize_sql_like("plain-text");

        assert_eq!(escaped, "plain-text");
    }

    // --- sanitize_sql_array: literal lists ---

    #[test]
    fn sanitize_sql_array_joins_values() {
        let values = [json!(1), json!("Alice"), Value::Null];

        assert_eq!(sanitize_sql_array(&values), "1, 'Alice', NULL");
    }

    #[test]
    fn sanitize_sql_array_returns_empty_string_for_empty_input() {
        let rendered = sanitize_sql_array(&[]);

        assert_eq!(rendered, "");
    }

    // --- sanitize_sql_hash: predicate building ---

    #[test]
    fn sanitize_sql_hash_sorts_keys_for_determinism() {
        let mut conditions = HashMap::new();
        conditions.insert("name".to_owned(), json!("Alice"));
        conditions.insert("age".to_owned(), json!(30));

        assert_eq!(
            sanitize_sql_hash(&conditions),
            "age = 30 AND name = 'Alice'"
        );
    }

    #[test]
    fn sanitize_sql_hash_uses_is_null_for_null_values() {
        let mut conditions = HashMap::new();
        conditions.insert("deleted_at".to_owned(), Value::Null);

        assert_eq!(sanitize_sql_hash(&conditions), "deleted_at IS NULL");
    }

    #[test]
    fn sanitize_sql_hash_escapes_injection_strings() {
        let mut conditions = HashMap::new();
        conditions.insert("name".to_owned(), json!("Robert'); DROP TABLE users;--"));

        assert_eq!(
            sanitize_sql_hash(&conditions),
            "name = 'Robert''); DROP TABLE users;--'"
        );
    }

    // --- one-bind table cases ---

    /// Generates a test that binds a single value into `value = ?` and checks
    /// the rendered SQL.
    macro_rules! sanitize_sql_case {
        ($test_name:ident, $bind:expr, $rendered:expr) => {
            #[test]
            fn $test_name() {
                let binds = [$bind];
                assert_eq!(sanitize_sql("value = ?", &binds), $rendered);
            }
        };
    }

    sanitize_sql_case!(sanitize_sql_string_case, json!("hello"), "value = 'hello'");
    sanitize_sql_case!(sanitize_sql_integer_case, json!(42), "value = 42");
    sanitize_sql_case!(sanitize_sql_float_case, json!(2.72), "value = 2.72");
    sanitize_sql_case!(sanitize_sql_false_case, json!(false), "value = FALSE");
    sanitize_sql_case!(sanitize_sql_null_case, Value::Null, "value = NULL");
}