#![allow(dead_code)]
use crate::error::{CacheError, Result};
// ---- Lua script limits ----

/// Largest accepted Lua script, in bytes (10 KiB); enforced by `validate_lua_script`.
const MAX_LUA_SCRIPT_LENGTH: usize = 10 * 1024;
/// Largest `key_count` accepted by `validate_lua_script`.
const MAX_LUA_SCRIPT_KEYS: usize = 100;
/// Lua script timeout in seconds.
/// NOTE(review): not referenced by the functions in this module — confirm where it is used.
const LUA_SCRIPT_TIMEOUT_SECS: u64 = 30;

// ---- SCAN limits ----

/// Largest accepted SCAN MATCH pattern, in bytes; enforced by `validate_scan_pattern`.
const MAX_SCAN_PATTERN_LENGTH: usize = 256;
/// Most `*` wildcards a SCAN MATCH pattern may contain; enforced by `validate_scan_pattern`.
const MAX_SCAN_WILDCARDS: usize = 10;
/// SCAN timeout in seconds.
/// NOTE(review): not referenced by the functions in this module — confirm where it is used.
const SCAN_TIMEOUT_SECS: u64 = 30;
/// Lower bound applied to a SCAN COUNT hint by `clamp_scan_count`.
const SCAN_COUNT_MIN: usize = 1;
/// Upper bound applied to a SCAN COUNT hint by `clamp_scan_count`.
const SCAN_COUNT_MAX: usize = 1000;

/// Redis commands named as allowed, grouped by data type.
///
/// NOTE(review): not consulted by the functions in this module; presumably enforced by a
/// caller — confirm.
const ALLOWED_REDIS_COMMANDS: &[&str] = &[
    // Strings
    "GET", "MGET", "SET", "SETEX", "PSETEX", "MSET",
    // Hashes
    "HGET", "HMGET", "HGETALL", "HSET", "HMSET",
    // Lists
    "LINDEX", "LRANGE", "LLEN",
    // Sets
    "SISMEMBER", "SMEMBERS", "SCARD",
    // Sorted sets
    "ZSCORE", "ZRANGE", "ZRANGEBYSCORE", "ZCARD",
    // Key metadata and transactions
    "TTL", "PTTL", "EXISTS", "UNWATCH",
];
/// Validates a Redis key before it is sent to the server.
///
/// Rules are checked in this order, and the first failure is reported:
/// 1. The key must be non-empty and at most 512 KiB long (measured in bytes).
/// 2. The key must not contain CR, LF or NUL.
/// 3. The key must not contain any other Unicode control character, except TAB.
/// 4. The key must not contain any of a fixed list of SQL-injection markers.
///    This check is case-insensitive.
/// 5. The key must not contain a plain or percent-encoded path-traversal sequence.
///    This check is case-insensitive.
/// 6. A shell-metacharacter check applies to keys that are longer than 10 bytes and
///    have an alphabetic character among their first five characters. Such keys must
///    not contain `;`, `|`, `&` or `` ` ``.
///
/// # Errors
///
/// Returns `CacheError::InvalidInput` with a message naming the violated rule.
pub fn validate_redis_key(key: &str) -> Result<()> {
    if key.is_empty() {
        return Err(CacheError::InvalidInput(
            "Redis key cannot be empty".to_string(),
        ));
    }
    if key.len() > 512 * 1024 {
        return Err(CacheError::InvalidInput(
            "Redis key exceeds maximum length of 512KB".to_string(),
        ));
    }

    // CR/LF/NUL get their own pass so they are always reported with this message,
    // even when another control character appears earlier in the key.
    if let Some(c) = key.chars().find(|&c| matches!(c, '\r' | '\n' | '\0')) {
        return Err(CacheError::InvalidInput(format!(
            "Redis key contains forbidden character: {:?}",
            c
        )));
    }
    // TAB is deliberately tolerated; CR/LF/NUL were already rejected above.
    if let Some(c) = key.chars().find(|&c| c.is_control() && c != '\t') {
        return Err(CacheError::InvalidInput(format!(
            "Redis key contains control character: U+{:04X}",
            c as u32
        )));
    }

    const SQL_INJECTION_PATTERNS: &[&str] = &[
        "' OR '",
        "'--",
        "'; DROP",
        "'; DELETE",
        "'; INSERT",
        "1=1",
        "1=2",
        "UNION SELECT",
        "xp_cmdshell",
        "OR 1=1",
        "AND 1=1",
        "' OR '1'='1",
        "admin'--",
    ];
    // Upper-case BOTH sides. Some patterns ("xp_cmdshell", "admin'--") are written in
    // lower case and could never match an upper-cased key otherwise.
    let key_upper = key.to_uppercase();
    if let Some(pattern) = SQL_INJECTION_PATTERNS
        .iter()
        .copied()
        .find(|pattern| key_upper.contains(&pattern.to_uppercase()))
    {
        return Err(CacheError::InvalidInput(format!(
            "Redis key contains suspicious SQL injection pattern: {}",
            pattern
        )));
    }

    const PATH_TRAVERSAL_PATTERNS: &[&str] = &[
        "../",
        "..\\",
        "%2e%2e",
        "%252e%252e",
        "..%2f",
        "..%5c",
        "%2e%2e%2f",
        "%2e%2e%5c",
    ];
    // Lower-case the key once instead of once per pattern.
    let key_lower = key.to_lowercase();
    if let Some(pattern) = PATH_TRAVERSAL_PATTERNS
        .iter()
        .copied()
        .find(|pattern| key_lower.contains(&pattern.to_lowercase()))
    {
        return Err(CacheError::InvalidInput(format!(
            "Redis key contains path traversal pattern: {}",
            pattern
        )));
    }

    // Heuristic: only keys longer than 10 bytes whose first five characters include a
    // letter are screened for shell metacharacters. The first offending one is reported.
    if key.len() > 10 && key.chars().take(5).any(char::is_alphabetic) {
        if let Some(c) = key.chars().find(|&c| matches!(c, ';' | '|' | '&' | '`')) {
            return Err(CacheError::InvalidInput(format!(
                "Redis key contains potential command injection character: {:?}",
                c
            )));
        }
    }

    Ok(())
}
/// Validates a Lua script (and its declared key count) before it is sent to Redis.
///
/// This is a best-effort substring denylist, not a sandbox. Matching is case-insensitive.
///
/// Whitespace handling:
/// - Call patterns such as `redis.call('FLUSHALL'` are matched against a copy of the
///   script with all whitespace removed.
/// - This means `redis.call ( 'FLUSHALL' )` is caught, because Lua accepts whitespace
///   between those tokens.
/// - The single-string-argument call form `redis.call 'FLUSHALL'` is also caught.
///
/// Known gaps: the checks cannot see commands reached through aliases
/// (e.g. `local c = redis.call; c('FLUSHALL')`).
///
/// # Errors
///
/// Returns `CacheError::InvalidInput` for the first rule that fails:
/// - script longer than `MAX_LUA_SCRIPT_LENGTH` bytes;
/// - `key_count` above `MAX_LUA_SCRIPT_KEYS`;
/// - a literal call to FLUSHALL/FLUSHDB, KEYS, or an administrative command;
/// - a `redis.call(cmd)`-style dynamic call;
/// - any literal `redis.call`/`redis.pcall` when the script also contains `..` anywhere;
/// - nested EVAL/EVALSHA;
/// - an infinite-loop pattern (`while true`, `while 1`, `repeat`, `goto`);
/// - `os.exec*` or `io.open`/`io.popen`;
/// - dynamic code loading (`load(`, `loadstring`, `loadfile`, `dofile`).
pub fn validate_lua_script(script: &str, key_count: usize) -> Result<()> {
    if script.len() > MAX_LUA_SCRIPT_LENGTH {
        return Err(CacheError::InvalidInput(format!(
            "Lua script exceeds maximum length of {} bytes (got {} bytes)",
            MAX_LUA_SCRIPT_LENGTH,
            script.len()
        )));
    }
    if key_count > MAX_LUA_SCRIPT_KEYS {
        return Err(CacheError::InvalidInput(format!(
            "Lua script exceeds maximum key count of {} (got {} keys)",
            MAX_LUA_SCRIPT_KEYS, key_count
        )));
    }

    let script_upper = script.to_uppercase();
    // Lua ignores whitespace between `redis`, `.`, `call`, `(` and the argument, so call
    // patterns are matched on a whitespace-free copy. Any whitespace-free needle found in
    // `script_upper` is also found here, so this only ever rejects more than before.
    let compact: String = script_upper.chars().filter(|c| !c.is_whitespace()).collect();
    // Keyword pairs like `WHILE TRUE` need exactly one separating space.
    let collapsed = script_upper.split_whitespace().collect::<Vec<_>>().join(" ");

    for cmd in &["FLUSHALL", "FLUSHDB"] {
        if script_calls_redis_command(&compact, cmd) {
            return Err(CacheError::InvalidInput(format!(
                "Lua script calls forbidden Redis command: {}",
                cmd
            )));
        }
    }

    if script_calls_redis_command(&compact, "KEYS") {
        return Err(CacheError::InvalidInput(
            "Lua script contains forbidden command: KEYS".to_string(),
        ));
    }

    const DANGEROUS_COMMANDS: &[&str] = &[
        "SHUTDOWN",
        "DEBUG",
        "CONFIG",
        "SAVE",
        "BGSAVE",
        "BGREWRITEAOF",
        "LASTSAVE",
        "MONITOR",
        "SYNC",
    ];
    for cmd in DANGEROUS_COMMANDS {
        if script_calls_redis_command(&compact, cmd) {
            return Err(CacheError::InvalidInput(format!(
                "Lua script calls forbidden Redis command: {}",
                cmd
            )));
        }
    }

    if script_contains_any(
        &compact,
        &[
            "REDIS.CALL(CMD)",
            "REDIS.CALL(VAR)",
            "REDIS.CALL(COMMAND",
            "REDIS.PCALL(CMD)",
            "REDIS.PCALL(VAR)",
            "REDIS.PCALL(COMMAND",
        ],
    ) {
        return Err(CacheError::InvalidInput(
            "Lua script uses dynamic command execution which is not allowed".to_string(),
        ));
    }

    // Deliberately coarse: any literal redis.call/pcall plus a `..` anywhere is rejected.
    if script_upper.contains("..")
        && script_contains_any(
            &compact,
            &[
                "REDIS.CALL(\"",
                "REDIS.CALL('",
                "REDIS.PCALL(\"",
                "REDIS.PCALL('",
            ],
        )
    {
        return Err(CacheError::InvalidInput(
            "Lua script uses string concatenation for command execution which is not allowed"
                .to_string(),
        ));
    }

    if script_contains_any(&compact, &["REDIS.EVAL(", "REDIS.EVALSHA("])
        || script_calls_redis_command(&compact, "EVAL")
        || script_calls_redis_command(&compact, "EVALSHA")
    {
        return Err(CacheError::InvalidInput(
            "Lua script contains nested redis.eval/evalsha which is not allowed".to_string(),
        ));
    }

    if script_contains_any(&collapsed, &["WHILE TRUE", "WHILE 1", "REPEAT", "GOTO"])
        || script_contains_any(&compact, &["WHILE(TRUE)", "WHILE(1)"])
    {
        return Err(CacheError::InvalidInput(
            "Lua script contains potential infinite loop patterns".to_string(),
        ));
    }

    // `os . execute` is valid Lua, hence the compact copy. "OS.EXEC" also covers "OS.EXECUTE".
    if script_contains_any(&compact, &["OS.EXECUTE", "OS.EXEC", "IO.POPEN", "IO.OPEN"]) {
        return Err(CacheError::InvalidInput(
            "Lua script contains system command execution which is not allowed".to_string(),
        ));
    }

    if script_contains_any(&script_upper, &["LOADSTRING", "DOFILE", "LOADFILE"])
        || compact.contains("LOAD(")
    {
        return Err(CacheError::InvalidInput(
            "Lua script contains dynamic code loading which is not allowed".to_string(),
        ));
    }

    Ok(())
}

/// Returns `true` if `haystack` contains at least one of `needles`.
fn script_contains_any(haystack: &str, needles: &[&str]) -> bool {
    needles.iter().any(|needle| haystack.contains(needle))
}

/// Returns `true` if the upper-cased, whitespace-free `compact_script` invokes `cmd` through
/// `REDIS.CALL`/`REDIS.PCALL` with a literal command name.
///
/// Recognised call forms:
/// - parenthesised, e.g. `REDIS.CALL('CMD'`;
/// - Lua's single-string-argument form, e.g. `REDIS.CALL'CMD'`.
///
/// Each form is checked with `'…'`, `"…"` and `[[…]]` quoting.
fn script_calls_redis_command(compact_script: &str, cmd: &str) -> bool {
    const FUNCTIONS: [&str; 2] = ["REDIS.CALL", "REDIS.PCALL"];
    const OPENERS: [&str; 2] = ["(", ""];
    const QUOTES: [(&str, &str); 3] = [("'", "'"), ("\"", "\""), ("[[", "]]")];
    FUNCTIONS.iter().any(|func| {
        OPENERS.iter().any(|paren| {
            QUOTES.iter().any(|(open, close)| {
                compact_script.contains(&format!("{}{}{}{}{}", func, paren, open, cmd, close))
            })
        })
    })
}
/// Validates a glob pattern intended for `SCAN ... MATCH`.
///
/// Checks:
/// - The pattern's length in bytes must not exceed `MAX_SCAN_PATTERN_LENGTH`.
/// - It must contain no more than `MAX_SCAN_WILDCARDS` unescaped `*` wildcards.
///
/// Escape handling: a backslash escapes the character after it, so `\*` is a literal
/// asterisk in Redis glob syntax and is not counted. `?` and `[...]` are not limited.
///
/// # Errors
///
/// Returns `CacheError::InvalidInput` when either limit is exceeded.
pub fn validate_scan_pattern(pattern: &str) -> Result<()> {
    // `str::len` counts bytes, so the message says bytes.
    if pattern.len() > MAX_SCAN_PATTERN_LENGTH {
        return Err(CacheError::InvalidInput(format!(
            "SCAN pattern exceeds maximum length of {} bytes (got {} bytes)",
            MAX_SCAN_PATTERN_LENGTH,
            pattern.len()
        )));
    }

    let mut wildcard_count = 0;
    let mut chars = pattern.chars();
    while let Some(c) = chars.next() {
        match c {
            // The escaped character is matched literally, so skip it.
            '\\' => {
                chars.next();
            }
            '*' => wildcard_count += 1,
            _ => {}
        }
    }
    if wildcard_count > MAX_SCAN_WILDCARDS {
        return Err(CacheError::InvalidInput(format!(
            "SCAN pattern contains too many wildcards (max {}, got {})",
            MAX_SCAN_WILDCARDS, wildcard_count
        )));
    }
    Ok(())
}
/// Forces a SCAN `COUNT` hint into the inclusive range `SCAN_COUNT_MIN..=SCAN_COUNT_MAX`.
///
/// Values below the minimum are raised to it; values above the maximum are lowered to it.
pub fn clamp_scan_count(count: usize) -> usize {
    let raised = count.max(SCAN_COUNT_MIN);
    raised.min(SCAN_COUNT_MAX)
}
#[cfg(test)]
mod tests {
    //! Unit tests for the Redis key, Lua script and SCAN validators.
    use super::*;

    /// Asserts that a validator rejected its input with `CacheError::InvalidInput`.
    fn assert_invalid_input(outcome: Result<()>) {
        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(CacheError::InvalidInput(_))));
    }

    // ---- validate_redis_key ----

    #[test]
    fn test_validate_redis_key_valid() {
        for key in &["user:123", "cache:data:value", "test_key"] {
            assert!(validate_redis_key(key).is_ok());
        }
    }

    #[test]
    fn test_validate_redis_key_empty() {
        assert_invalid_input(validate_redis_key(""));
    }

    #[test]
    fn test_validate_redis_key_too_long() {
        let oversized = "x".repeat(512 * 1024 + 1);
        assert_invalid_input(validate_redis_key(&oversized));
    }

    #[test]
    fn test_validate_redis_key_contains_crlf() {
        for key in &["key\r\n", "key\rvalue", "key\nvalue"] {
            assert!(validate_redis_key(key).is_err());
        }
    }

    #[test]
    fn test_validate_redis_key_contains_null() {
        assert!(validate_redis_key("key\0value").is_err());
    }

    // ---- validate_lua_script ----

    #[test]
    fn test_validate_lua_script_valid() {
        let script = "return redis.call('GET', KEYS[1])";
        if let Err(e) = validate_lua_script(script, 1) {
            panic!("Unexpected error: {:?}", e);
        }
    }

    #[test]
    fn test_validate_lua_script_too_long() {
        let oversized = "x".repeat(MAX_LUA_SCRIPT_LENGTH + 1);
        assert_invalid_input(validate_lua_script(&oversized, 1));
    }

    #[test]
    fn test_validate_lua_script_too_many_keys() {
        let script = "return redis.call('GET', KEYS[1])";
        assert_invalid_input(validate_lua_script(script, MAX_LUA_SCRIPT_KEYS + 1));
    }

    #[test]
    fn test_validate_lua_script_flushall() {
        assert_invalid_input(validate_lua_script("return redis.call('FLUSHALL')", 0));
    }

    #[test]
    fn test_validate_lua_script_flushdb() {
        assert_invalid_input(validate_lua_script("return redis.call('FLUSHDB')", 0));
    }

    #[test]
    fn test_validate_lua_script_keys_command() {
        assert_invalid_input(validate_lua_script("return redis.call('KEYS', '*')", 0));
    }

    #[test]
    fn test_validate_lua_script_shutdown() {
        assert_invalid_input(validate_lua_script("return redis.call('SHUTDOWN')", 0));
    }

    #[test]
    fn test_validate_lua_script_case_insensitive() {
        assert_invalid_input(validate_lua_script("return redis.call('flushall')", 0));
    }

    #[test]
    fn test_validate_lua_script_safe_commands() {
        let script = r#"
local val = redis.call('GET', KEYS[1])
if val then
redis.call('SETEX', KEYS[2], 60, val)
end
return val
"#;
        assert!(validate_lua_script(script, 2).is_ok());
    }

    // ---- validate_scan_pattern ----

    #[test]
    fn test_validate_scan_pattern_valid() {
        for pattern in &["user:*", "session:*:data", "cache?"] {
            assert!(validate_scan_pattern(pattern).is_ok());
        }
    }

    #[test]
    fn test_validate_scan_pattern_too_long() {
        let oversized = "x".repeat(MAX_SCAN_PATTERN_LENGTH + 1);
        assert_invalid_input(validate_scan_pattern(&oversized));
    }

    #[test]
    fn test_validate_scan_pattern_too_many_wildcards() {
        let stars = "*".repeat(MAX_SCAN_WILDCARDS + 1);
        assert_invalid_input(validate_scan_pattern(&stars));
    }

    #[test]
    fn test_validate_scan_pattern_exact_wildcard_limit() {
        let stars = "*".repeat(MAX_SCAN_WILDCARDS);
        assert!(validate_scan_pattern(&stars).is_ok());
    }

    // ---- clamp_scan_count ----

    #[test]
    fn test_clamp_scan_count() {
        let cases = [
            (0, SCAN_COUNT_MIN),
            (500, 500),
            (1000, SCAN_COUNT_MAX),
            (2000, SCAN_COUNT_MAX),
        ];
        for &(input, expected) in &cases {
            assert_eq!(clamp_scan_count(input), expected);
        }
    }
}