use crate::error::CacheError;
/// Rejects `input` if any of its characters appears in `dangerous_chars`.
///
/// Characters are scanned left to right and the first offending one is reported. It is
/// rendered as a `\uXXXX` hex escape (at least four digits) instead of the raw character, so
/// control characters such as `\r`, `\n` or `\0` show up visibly in the message. An empty
/// `dangerous_chars` slice accepts every input.
///
/// # Errors
///
/// Returns [`CacheError::InvalidInput`] whose message names `error_context` and the escaped
/// offending character.
pub fn validate_no_dangerous_chars(
    input: &str,
    dangerous_chars: &[char],
    error_context: &str,
) -> crate::Result<()> {
    let offender = input.chars().find(|ch| dangerous_chars.contains(ch));
    match offender {
        None => Ok(()),
        Some(bad) => Err(CacheError::InvalidInput(format!(
            "{} contains dangerous character '\\u{:04x}'",
            error_context, bad as u32
        ))),
    }
}
/// Rejects a zero-length `input`.
///
/// Only the empty string fails; whitespace-only input is accepted because the check is
/// [`str::is_empty`], not a trimmed comparison.
///
/// # Errors
///
/// Returns [`CacheError::InvalidInput`] with the message `"<error_context> cannot be empty"`.
pub fn validate_not_empty(input: &str, error_context: &str) -> crate::Result<()> {
    if !input.is_empty() {
        return Ok(());
    }
    Err(CacheError::InvalidInput(format!("{} cannot be empty", error_context)))
}
/// Rejects `input` when it is longer than `max_length`.
///
/// Length is measured with [`str::len`], i.e. in UTF-8 **bytes**, not characters. A
/// multi-byte character therefore counts more than once. Input of exactly `max_length` bytes
/// is accepted.
///
/// # Errors
///
/// Returns [`CacheError::InvalidInput`] whose message includes `error_context`, the limit, and
/// the actual byte length.
pub fn validate_max_length(
    input: &str,
    max_length: usize,
    error_context: &str,
) -> crate::Result<()> {
    let actual_len = input.len();
    if actual_len <= max_length {
        Ok(())
    } else {
        Err(CacheError::InvalidInput(format!(
            "{} exceeds maximum length of {} (got {})",
            error_context, max_length, actual_len
        )))
    }
}
/// Limits and forbidden characters for Redis keys.
pub mod redis {
    /// One kibibyte, used to spell out the size limit below.
    const KIB: usize = 1024;

    /// Maximum accepted key length: 512 KiB (524 288). Meant to be checked with
    /// `validate_max_length`, which compares UTF-8 byte lengths.
    ///
    /// NOTE(review): presumably a deliberate local policy cap rather than Redis's own
    /// key-size limit — confirm.
    pub const MAX_KEY_LENGTH: usize = 512 * KIB;

    /// Characters rejected in keys: carriage return, line feed, and NUL.
    ///
    /// NOTE(review): presumably chosen to prevent protocol/log line injection and C-string
    /// truncation — confirm.
    pub const DANGEROUS_CHARS: [char; 3] = [
        '\r', // carriage return
        '\n', // line feed
        '\0', // NUL
    ];
}
pub mod lua_script {
pub const MAX_SCRIPT_LENGTH: usize = 10 * 1024;
#[allow(dead_code)]
pub const DANGEROUS_PATTERNS: &[&str] = &[
"FLUSHALL", "FLUSHDB", "KEYS", "SHUTDOWN", "DEBUG", "CONFIG", "SAVE", "BGSAVE", "MONITOR",
];
#[allow(dead_code)]
pub fn validate_length(script: &str) -> crate::Result<()> {
super::validate_max_length(script, MAX_SCRIPT_LENGTH, "Lua script")
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Dangerous-character set used by the generic (non-Redis) tests.
    const DANGEROUS: [char; 3] = ['\r', '\n', '\0'];

    /// Returns the message of an `InvalidInput` error, failing the test for `Ok` or any other
    /// error variant.
    fn invalid_input_message(result: crate::Result<()>) -> String {
        match result {
            Err(CacheError::InvalidInput(msg)) => msg,
            _ => panic!("expected Err(CacheError::InvalidInput(..))"),
        }
    }

    #[test]
    fn test_validate_no_dangerous_chars_valid() {
        let result = validate_no_dangerous_chars("safe_string", &DANGEROUS, "test");
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_no_dangerous_chars_invalid() {
        let msg = invalid_input_message(validate_no_dangerous_chars(
            "unsafe\nstring",
            &DANGEROUS,
            "test",
        ));
        assert_eq!(msg, "test contains dangerous character '\\u000a'");
    }

    #[test]
    fn test_validate_no_dangerous_chars_each_char_rejected() {
        for &c in DANGEROUS.iter() {
            let input = format!("a{}b", c);
            assert!(
                validate_no_dangerous_chars(&input, &DANGEROUS, "test").is_err(),
                "{:?} should be rejected",
                c
            );
        }
    }

    #[test]
    fn test_validate_no_dangerous_chars_reports_first_offender() {
        let msg =
            invalid_input_message(validate_no_dangerous_chars("a\0b\r", &DANGEROUS, "ctx"));
        assert_eq!(msg, "ctx contains dangerous character '\\u0000'");
    }

    #[test]
    fn test_validate_no_dangerous_chars_empty_set_accepts_all() {
        assert!(validate_no_dangerous_chars("a\nb\0", &[], "test").is_ok());
    }

    #[test]
    fn test_validate_not_empty_valid() {
        let result = validate_not_empty("non_empty", "test");
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_not_empty_invalid() {
        let msg = invalid_input_message(validate_not_empty("", "test"));
        assert_eq!(msg, "test cannot be empty");
    }

    #[test]
    fn test_validate_not_empty_whitespace_is_not_empty() {
        // Only the zero-length string is rejected; no trimming happens.
        assert!(validate_not_empty(" ", "test").is_ok());
    }

    #[test]
    fn test_validate_max_length_valid() {
        let result = validate_max_length("short", 100, "test");
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_max_length_at_limit() {
        assert!(validate_max_length(&"a".repeat(100), 100, "test").is_ok());
    }

    #[test]
    fn test_validate_max_length_invalid() {
        let msg = invalid_input_message(validate_max_length(&"a".repeat(101), 100, "test"));
        assert_eq!(msg, "test exceeds maximum length of 100 (got 101)");
    }

    #[test]
    fn test_validate_max_length_counts_bytes() {
        // U+00E9 is a single char but two UTF-8 bytes; the limit is byte-based.
        assert!(validate_max_length("\u{e9}", 1, "test").is_err());
        assert!(validate_max_length("\u{e9}", 2, "test").is_ok());
    }

    #[test]
    fn test_redis_validate_key_valid() {
        let result = validate_not_empty("my_key", "Redis key");
        assert!(result.is_ok());
    }

    #[test]
    fn test_redis_validate_key_empty() {
        let result = validate_not_empty("", "Redis key");
        assert!(result.is_err());
    }

    #[test]
    fn test_redis_validate_key_dangerous_chars() {
        let result = validate_no_dangerous_chars(
            "key\nwith\nnewlines",
            &redis::DANGEROUS_CHARS,
            "Redis key",
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_redis_validate_key_too_long() {
        let long_key = "a".repeat(redis::MAX_KEY_LENGTH + 1);
        let result = validate_max_length(&long_key, redis::MAX_KEY_LENGTH, "Redis key");
        assert!(result.is_err());
    }

    #[test]
    fn test_redis_validate_key_at_limit() {
        let key = "a".repeat(redis::MAX_KEY_LENGTH);
        assert!(validate_max_length(&key, redis::MAX_KEY_LENGTH, "Redis key").is_ok());
    }

    #[test]
    fn test_lua_validate_length_valid() {
        let script = "return 1";
        let result = lua_script::validate_length(script);
        assert!(result.is_ok());
    }

    #[test]
    fn test_lua_validate_length_at_limit() {
        let script = "a".repeat(lua_script::MAX_SCRIPT_LENGTH);
        assert!(lua_script::validate_length(&script).is_ok());
    }

    #[test]
    fn test_lua_validate_length_too_long() {
        let long_script = "a".repeat(lua_script::MAX_SCRIPT_LENGTH + 1);
        let result = lua_script::validate_length(&long_script);
        assert!(
            result.is_err(),
            "Expected error for script length {} > max {}, but got Ok",
            long_script.len(),
            lua_script::MAX_SCRIPT_LENGTH
        );
        let msg = invalid_input_message(result);
        assert!(msg.starts_with("Lua script exceeds maximum length of "));
    }
}