oxcache 0.1.4

A high-performance multi-level cache library for Rust with L1 (memory) and L2 (Redis) caching.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
#![allow(dead_code)]

//! Copyright (c) 2025-2026, Kirky.X
//!
//! MIT License
//!
//! 安全验证模块
//!
//! 提供各种安全验证功能,防止恶意输入导致的安全问题。

use crate::error::{CacheError, Result};

/// Maximum Lua script length in bytes (10KB).
const MAX_LUA_SCRIPT_LENGTH: usize = 10 * 1024;

/// Maximum number of keys a Lua script may be invoked with.
const MAX_LUA_SCRIPT_KEYS: usize = 100;

/// Lua script execution timeout in seconds.
///
/// NOTE(review): not referenced anywhere else in this module; presumably reserved for a
/// caller-side timeout — confirm before relying on it.
const LUA_SCRIPT_TIMEOUT_SECS: u64 = 30;

/// Maximum SCAN pattern length, compared against the pattern's byte length.
const MAX_SCAN_PATTERN_LENGTH: usize = 256;

/// Maximum number of `*` wildcards allowed in a SCAN pattern.
const MAX_SCAN_WILDCARDS: usize = 10;

/// SCAN operation timeout in seconds.
///
/// NOTE(review): not referenced anywhere else in this module — confirm intended use.
const SCAN_TIMEOUT_SECS: u64 = 30;

/// Lower bound of the safe range for the SCAN `COUNT` argument (see `clamp_scan_count`).
const SCAN_COUNT_MIN: usize = 1;
/// Upper bound of the safe range for the SCAN `COUNT` argument (see `clamp_scan_count`).
const SCAN_COUNT_MAX: usize = 1000;

/// Allowlist of Redis commands considered safe inside Lua scripts.
///
/// Intended to permit only safe commands and to exclude dangerous ones such as
/// `FLUSHALL` and `KEYS`.
///
/// NOTE(review): `validate_lua_script` in this module does not consult this list; it uses a
/// denylist of specific command patterns instead. Enforcing this allowlist would reject
/// scripts that call unlisted commands (e.g. `DEL`, `EXPIRE`) — confirm intent first.
const ALLOWED_REDIS_COMMANDS: &[&str] = &[
    // String commands
    "GET",
    "MGET",
    "SET",
    "SETEX",
    "PSETEX",
    "MSET",
    // Hash commands
    "HGET",
    "HMGET",
    "HGETALL",
    "HSET",
    "HMSET",
    // List commands
    "LINDEX",
    "LRANGE",
    "LLEN",
    // Set commands
    "SISMEMBER",
    "SMEMBERS",
    "SCARD",
    // Sorted-set commands
    "ZSCORE",
    "ZRANGE",
    "ZRANGEBYSCORE",
    "ZCARD",
    // Expiry / key-inspection commands
    "TTL",
    "PTTL",
    "EXISTS",
    // Transaction commands
    "UNWATCH",
];

/// Validates that a Redis cache key is safe to send to the server.
///
/// Prevents Redis protocol injection and rejects keys that contain common
/// injection-looking payloads. The content heuristics (rules 5-7) are deliberately
/// conservative and can reject otherwise legitimate keys (e.g. `page:1=1`).
///
/// # Validation rules
///
/// Checked in this order; the first rule that fails determines the error:
///
/// 1. The key must not be empty.
/// 2. The key must not exceed 512KB (measured in bytes).
/// 3. The key must not contain `\r`, `\n` or `\0`.
/// 4. The key must not contain any other Unicode control character (tab is allowed).
/// 5. The key must not contain any listed SQL-injection substring (case-insensitive).
/// 6. The key must not contain any listed path-traversal substring (case-insensitive).
/// 7. A key longer than 10 bytes whose first five characters include a letter must not
///    contain `;`, `|`, `&` or `` ` ``.
///
/// # Arguments
///
/// * `key` - The cache key to validate.
///
/// # Returns
///
/// * `Ok(())` - The key passed every rule.
/// * `Err(CacheError::InvalidInput)` - The key violated a rule; the message describes which.
pub fn validate_redis_key(key: &str) -> Result<()> {
    if key.is_empty() {
        return Err(CacheError::InvalidInput(
            "Redis key cannot be empty".to_string(),
        ));
    }

    // Redis itself accepts keys up to 512MB; we cap at 512KB to limit abuse.
    if key.len() > 512 * 1024 {
        return Err(CacheError::InvalidInput(
            "Redis key exceeds maximum length of 512KB".to_string(),
        ));
    }

    // The Redis protocol uses `\r\n` as its delimiter, so these must never appear in a key.
    if let Some(c) = key.chars().find(|&c| matches!(c, '\r' | '\n' | '\0')) {
        return Err(CacheError::InvalidInput(format!(
            "Redis key contains forbidden character: {:?}",
            c
        )));
    }

    // Any remaining control character. `\r`, `\n` and `\0` were rejected above; tab is allowed.
    if let Some(c) = key
        .chars()
        .find(|&c| c.is_control() && !matches!(c, '\r' | '\n' | '\0' | '\t'))
    {
        return Err(CacheError::InvalidInput(format!(
            "Redis key contains control character: U+{:04X}",
            c as u32
        )));
    }

    const SQL_INJECTION_PATTERNS: &[&str] = &[
        "' OR '",
        "'--",
        "'; DROP",
        "'; DELETE",
        "'; INSERT",
        "1=1",
        "1=2",
        "UNION SELECT",
        "xp_cmdshell",
        "OR 1=1",
        "AND 1=1",
        "' OR '1'='1",
        "admin'--",
    ];

    // Upper-case both the key and each pattern. Some patterns are written in lower case
    // (`xp_cmdshell`), and comparing them against the upper-cased key alone would never match.
    let key_upper = key.to_uppercase();
    if let Some(pattern) = SQL_INJECTION_PATTERNS
        .iter()
        .find(|pattern| key_upper.contains(&pattern.to_uppercase()))
    {
        return Err(CacheError::InvalidInput(format!(
            "Redis key contains suspicious SQL injection pattern: {}",
            pattern
        )));
    }

    const PATH_TRAVERSAL_PATTERNS: &[&str] = &[
        "../",
        "..\\",
        "%2e%2e",
        "%252e%252e",
        "..%2f",
        "..%5c",
        "%2e%2e%2f",
        "%2e%2e%5c",
    ];

    // Lower-case the key once instead of once per pattern.
    let key_lower = key.to_lowercase();
    if let Some(pattern) = PATH_TRAVERSAL_PATTERNS
        .iter()
        .find(|pattern| key_lower.contains(&pattern.to_lowercase()))
    {
        return Err(CacheError::InvalidInput(format!(
            "Redis key contains path traversal pattern: {}",
            pattern
        )));
    }

    // Shell metacharacters are only flagged for keys longer than 10 bytes whose first five
    // characters include a letter, to limit false positives on short or numeric keys.
    let command_like = key.len() > 10 && key.chars().take(5).any(|x| x.is_alphabetic());
    if command_like {
        if let Some(c) = key.chars().find(|&c| matches!(c, ';' | '|' | '&' | '`')) {
            return Err(CacheError::InvalidInput(format!(
                "Redis key contains potential command injection character: {:?}",
                c
            )));
        }
    }

    Ok(())
}

/// Validates a Lua script before it is sent to Redis.
///
/// This is a best-effort, substring-based denylist, not a sandbox. It rejects scripts that
/// exceed size limits or that visibly invoke dangerous Redis commands or Lua facilities;
/// obfuscated scripts (e.g. `redis["call"](...)`) can still evade it, and some checks are
/// aggressive enough to reject legitimate scripts.
///
/// # Validation rules
///
/// Checked in this order; the first rule that fails determines the error:
///
/// 1. The script must be at most `MAX_LUA_SCRIPT_LENGTH` bytes long.
/// 2. `key_count` must not exceed `MAX_LUA_SCRIPT_KEYS`.
/// 3. The script must not call `FLUSHALL`, `FLUSHDB`, `KEYS` or an administrative command
///    (`SHUTDOWN`, `DEBUG`, `CONFIG`, ...) through `redis.call` / `redis.pcall`. Matching is
///    case-insensitive and ignores whitespace, so `redis.call( 'flushall' )` is caught too.
/// 4. The command name must not come from a `cmd` / `var` / `command` variable, and the
///    script must not combine a quoted `redis.call` / `redis.pcall` with the `..` operator.
/// 5. The script must not call `EVAL` / `EVALSHA` (nested evaluation).
/// 6. The script must not contain likely infinite-loop constructs (`while true`, `repeat`,
///    `goto`, ...), OS/IO access (`os.execute`, `io.popen`, `io.open`) or dynamic code
///    loading (`loadstring`, `load(`, `dofile`, `loadfile`).
///
/// # Arguments
///
/// * `script` - The Lua source code.
/// * `key_count` - The number of keys the script will be invoked with.
///
/// # Returns
///
/// * `Ok(())` - The script passed every check.
/// * `Err(CacheError::InvalidInput)` - A check failed; the message describes which.
pub fn validate_lua_script(script: &str, key_count: usize) -> Result<()> {
    if script.len() > MAX_LUA_SCRIPT_LENGTH {
        return Err(CacheError::InvalidInput(format!(
            "Lua script exceeds maximum length of {} bytes (got {} bytes)",
            MAX_LUA_SCRIPT_LENGTH,
            script.len()
        )));
    }

    if key_count > MAX_LUA_SCRIPT_KEYS {
        return Err(CacheError::InvalidInput(format!(
            "Lua script exceeds maximum key count of {} (got {} keys)",
            MAX_LUA_SCRIPT_KEYS, key_count
        )));
    }

    // Upper-case once so every check below is case-insensitive.
    let script_upper = script.to_uppercase();
    // Whitespace-free copy for the `redis.call` / `redis.pcall` checks. Lua permits spaces
    // in `redis . call ( 'CMD'`, which would otherwise slip past exact substring matches.
    let compact_upper: String = script_upper
        .chars()
        .filter(|c| !c.is_whitespace())
        .collect();

    // FLUSHALL / FLUSHDB wipe data and are always forbidden.
    for cmd in ["FLUSHALL", "FLUSHDB"].iter() {
        if lua_calls_command(&compact_upper, cmd) {
            return Err(CacheError::InvalidInput(format!(
                "Lua script calls forbidden Redis command: {}",
                cmd
            )));
        }
    }

    // The Redis `KEYS` command may block the server. Only a quoted 'KEYS' / "KEYS" command
    // argument is matched, so Lua's `KEYS[...]` array stays allowed.
    if lua_calls_command(&compact_upper, "KEYS") {
        return Err(CacheError::InvalidInput(
            "Lua script contains forbidden command: KEYS".to_string(),
        ));
    }

    // Administrative commands that are dangerous when invoked from a script.
    const DANGEROUS_COMMANDS: &[&str] = &[
        "SHUTDOWN",
        "DEBUG",
        "CONFIG",
        "SAVE",
        "BGSAVE",
        "BGREWRITEAOF",
        "LASTSAVE",
        "MONITOR",
        "SYNC",
    ];
    for cmd in DANGEROUS_COMMANDS {
        if lua_calls_command(&compact_upper, cmd) {
            return Err(CacheError::InvalidInput(format!(
                "Lua script calls forbidden Redis command: {}",
                cmd
            )));
        }
    }

    // Command name taken from a variable (common variable names only).
    const DYNAMIC_CALL_PATTERNS: &[&str] = &[
        "REDIS.CALL(CMD)",
        "REDIS.CALL(VAR)",
        "REDIS.CALL(COMMAND",
        "REDIS.PCALL(CMD)",
        "REDIS.PCALL(VAR)",
        "REDIS.PCALL(COMMAND",
    ];
    if DYNAMIC_CALL_PATTERNS
        .iter()
        .any(|pattern| compact_upper.contains(*pattern))
    {
        return Err(CacheError::InvalidInput(
            "Lua script uses dynamic command execution which is not allowed".to_string(),
        ));
    }

    // A quoted redis.call/pcall combined with `..` anywhere in the script is treated as an
    // attempt to assemble the command name by string concatenation.
    const QUOTED_CALL_PREFIXES: &[&str] = &[
        "REDIS.CALL(\"",
        "REDIS.CALL('",
        "REDIS.PCALL(\"",
        "REDIS.PCALL('",
    ];
    if script_upper.contains("..")
        && QUOTED_CALL_PREFIXES
            .iter()
            .any(|prefix| compact_upper.contains(*prefix))
    {
        return Err(CacheError::InvalidInput(
            "Lua script uses string concatenation for command execution which is not allowed"
                .to_string(),
        ));
    }

    // Nested eval/evalsha. 'EVALSHA' is checked explicitly because the quoted 'EVAL'
    // pattern includes the closing quote and therefore does not match it.
    if compact_upper.contains("REDIS.EVAL(")
        || compact_upper.contains("REDIS.EVALSHA(")
        || lua_calls_command(&compact_upper, "EVAL")
        || lua_calls_command(&compact_upper, "EVALSHA")
    {
        return Err(CacheError::InvalidInput(
            "Lua script contains nested redis.eval/evalsha which is not allowed".to_string(),
        ));
    }

    // Likely infinite-loop constructs.
    if script_upper.contains("WHILE TRUE")
        || script_upper.contains("WHILE 1")
        || script_upper.contains("WHILE (TRUE)")
        || script_upper.contains("WHILE (1)")
        || script_upper.contains("REPEAT")
        || script_upper.contains("GOTO")
    {
        return Err(CacheError::InvalidInput(
            "Lua script contains potential infinite loop patterns".to_string(),
        ));
    }

    // OS command execution and file access.
    if script_upper.contains("OS.EXECUTE")
        || script_upper.contains("OS.EXEC")
        || script_upper.contains("IO.POPEN")
        || script_upper.contains("IO.OPEN")
    {
        return Err(CacheError::InvalidInput(
            "Lua script contains system command execution which is not allowed".to_string(),
        ));
    }

    // Dynamic code loading.
    if script_upper.contains("LOADSTRING")
        || script_upper.contains("LOAD(")
        || script_upper.contains("DOFILE")
        || script_upper.contains("LOADFILE")
    {
        return Err(CacheError::InvalidInput(
            "Lua script contains dynamic code loading which is not allowed".to_string(),
        ));
    }

    Ok(())
}

/// Reports whether `compact_upper` contains `redis.call` or `redis.pcall` invoked with the
/// quoted command name `cmd` (single or double quotes) as its first argument.
///
/// `compact_upper` must already be upper-cased with all whitespace removed, and `cmd` must
/// be upper-case. Only the start of the call is matched, so any arguments may follow.
fn lua_calls_command(compact_upper: &str, cmd: &str) -> bool {
    ["REDIS.CALL(", "REDIS.PCALL("].iter().any(|call| {
        ['\'', '"'].iter().any(|quote| {
            compact_upper.contains(&format!("{}{}{}{}", call, quote, cmd, quote))
        })
    })
}

/// Validates a SCAN `MATCH` pattern before it is handed to Redis.
///
/// The aim is to keep server-side glob matching cheap and to block ReDoS-style patterns.
///
/// # Rules
///
/// 1. The pattern is at most `MAX_SCAN_PATTERN_LENGTH` (256) long, measured in bytes.
/// 2. The pattern holds at most `MAX_SCAN_WILDCARDS` (10) `*` characters.
///
/// # Arguments
///
/// * `pattern` - The SCAN pattern string.
///
/// # Returns
///
/// * `Ok(())` - Both limits are respected.
/// * `Err(CacheError::InvalidInput)` - A limit was exceeded.
pub fn validate_scan_pattern(pattern: &str) -> Result<()> {
    // NOTE: the message says "characters", but the value reported is the byte length.
    let byte_len = pattern.len();
    if byte_len > MAX_SCAN_PATTERN_LENGTH {
        return Err(CacheError::InvalidInput(format!(
            "SCAN pattern exceeds maximum length of {} characters (got {} characters)",
            MAX_SCAN_PATTERN_LENGTH, byte_len
        )));
    }

    let star_count = pattern.matches('*').count();
    if star_count <= MAX_SCAN_WILDCARDS {
        return Ok(());
    }

    Err(CacheError::InvalidInput(format!(
        "SCAN pattern contains too many wildcards (max {}, got {})",
        MAX_SCAN_WILDCARDS, star_count
    )))
}

/// Forces a caller-supplied SCAN `COUNT` hint into the safe range
/// `SCAN_COUNT_MIN..=SCAN_COUNT_MAX` (1 to 1000).
///
/// # Arguments
///
/// * `count` - The requested count.
///
/// # Returns
///
/// `count` itself when it is already in range; otherwise the nearest bound.
pub fn clamp_scan_count(count: usize) -> usize {
    Ord::clamp(count, SCAN_COUNT_MIN, SCAN_COUNT_MAX)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Fails the current test unless `result` is `Err(CacheError::InvalidInput(_))`.
    fn assert_invalid_input(result: Result<()>) {
        assert!(
            matches!(result, Err(CacheError::InvalidInput(_))),
            "expected Err(CacheError::InvalidInput(_))"
        );
    }

    // ------------------------------------------------------------------------
    // Redis key validation
    // ------------------------------------------------------------------------

    #[test]
    fn test_validate_redis_key_valid() {
        for key in &["user:123", "cache:data:value", "test_key"] {
            assert!(validate_redis_key(key).is_ok(), "key {:?} should be accepted", key);
        }
    }

    #[test]
    fn test_validate_redis_key_empty() {
        assert_invalid_input(validate_redis_key(""));
    }

    #[test]
    fn test_validate_redis_key_too_long() {
        let oversized = "x".repeat(512 * 1024 + 1);
        assert_invalid_input(validate_redis_key(&oversized));
    }

    #[test]
    fn test_validate_redis_key_contains_crlf() {
        for key in &["key\r\n", "key\rvalue", "key\nvalue"] {
            assert!(validate_redis_key(key).is_err(), "key {:?} should be rejected", key);
        }
    }

    #[test]
    fn test_validate_redis_key_contains_null() {
        assert!(validate_redis_key("key\0value").is_err());
    }

    // ------------------------------------------------------------------------
    // Lua script validation
    // ------------------------------------------------------------------------

    #[test]
    fn test_validate_lua_script_valid() {
        if let Err(e) = validate_lua_script("return redis.call('GET', KEYS[1])", 1) {
            panic!("Unexpected error: {:?}", e);
        }
    }

    #[test]
    fn test_validate_lua_script_too_long() {
        let oversized = "x".repeat(MAX_LUA_SCRIPT_LENGTH + 1);
        assert_invalid_input(validate_lua_script(&oversized, 1));
    }

    #[test]
    fn test_validate_lua_script_too_many_keys() {
        assert_invalid_input(validate_lua_script(
            "return redis.call('GET', KEYS[1])",
            MAX_LUA_SCRIPT_KEYS + 1,
        ));
    }

    #[test]
    fn test_validate_lua_script_flushall() {
        assert_invalid_input(validate_lua_script("return redis.call('FLUSHALL')", 0));
    }

    #[test]
    fn test_validate_lua_script_flushdb() {
        assert_invalid_input(validate_lua_script("return redis.call('FLUSHDB')", 0));
    }

    #[test]
    fn test_validate_lua_script_keys_command() {
        assert_invalid_input(validate_lua_script("return redis.call('KEYS', '*')", 0));
    }

    #[test]
    fn test_validate_lua_script_shutdown() {
        assert_invalid_input(validate_lua_script("return redis.call('SHUTDOWN')", 0));
    }

    #[test]
    fn test_validate_lua_script_case_insensitive() {
        assert_invalid_input(validate_lua_script("return redis.call('flushall')", 0));
    }

    #[test]
    fn test_validate_lua_script_safe_commands() {
        let script = r#"
            local val = redis.call('GET', KEYS[1])
            if val then
                redis.call('SETEX', KEYS[2], 60, val)
            end
            return val
        "#;
        assert!(validate_lua_script(script, 2).is_ok());
    }

    // ------------------------------------------------------------------------
    // SCAN pattern validation
    // ------------------------------------------------------------------------

    #[test]
    fn test_validate_scan_pattern_valid() {
        for pattern in &["user:*", "session:*:data", "cache?"] {
            assert!(
                validate_scan_pattern(pattern).is_ok(),
                "pattern {:?} should be accepted",
                pattern
            );
        }
    }

    #[test]
    fn test_validate_scan_pattern_too_long() {
        let oversized = "x".repeat(MAX_SCAN_PATTERN_LENGTH + 1);
        assert_invalid_input(validate_scan_pattern(&oversized));
    }

    #[test]
    fn test_validate_scan_pattern_too_many_wildcards() {
        let stars = "*".repeat(MAX_SCAN_WILDCARDS + 1);
        assert_invalid_input(validate_scan_pattern(&stars));
    }

    #[test]
    fn test_validate_scan_pattern_exact_wildcard_limit() {
        let stars = "*".repeat(MAX_SCAN_WILDCARDS);
        assert!(validate_scan_pattern(&stars).is_ok());
    }

    #[test]
    fn test_clamp_scan_count() {
        let cases = [
            (0, SCAN_COUNT_MIN),
            (500, 500),
            (1000, SCAN_COUNT_MAX),
            (2000, SCAN_COUNT_MAX),
        ];
        for &(input, expected) in &cases {
            assert_eq!(clamp_scan_count(input), expected, "input {}", input);
        }
    }
}