Skip to main content

alun_cache/
lib.rs

1//! 缓存模块:本地内存缓存 + Redis 缓存
2//!
3//! 通过配置 `cache.type` 切换:
4//! - `local` → 内存缓存(默认)
5//! - `redis` → Redis 缓存(需配置 redis_url)
6
7use async_trait::async_trait;
8use serde::{Serialize, de::DeserializeOwned};
9use serde_json;
10use alun_core::Result;
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::sync::atomic::{AtomicU64, Ordering};
14use parking_lot::RwLock;
15use std::time::{Instant, Duration};
16use redis::aio::ConnectionManager;
17
18// ──── 缓存 trait ────────────────────────────────────
19
/// Unified cache interface implemented by both the local and the Redis cache.
///
/// Every data method is generic over the value type, so this trait is not
/// dyn-compatible: `&dyn Cache` does not compile. Use [`SharedCache`] to pick an
/// implementation at runtime, or take a generic `C: Cache` parameter.
///
/// # Example
///
/// ```ignore
/// let cache = SharedCache::Local(LocalCache::new(1000, 0));
/// cache.set("key", &"value".to_string()).await?;
/// let val: Option<String> = cache.get("key").await?;
/// ```
#[async_trait]
pub trait Cache: Send + Sync {
    /// Reads the value stored under `key` and deserializes it into `T`.
    ///
    /// Returns `Ok(None)` when the key is absent or has expired; returns an error
    /// when the stored value cannot be deserialized into `T`.
    async fn get<T: DeserializeOwned + Send>(&self, key: &str) -> Result<Option<T>>;

    /// Stores `value` (serialized with serde_json) under `key`.
    ///
    /// Expiry is implementation-defined: `LocalCache` applies its configured
    /// default TTL (no expiry when that TTL is 0), while `RedisCache` issues a
    /// plain `SET` with no expiry.
    async fn set<T: Serialize + Send + Sync>(&self, key: &str, value: &T) -> Result<()>;

    /// Stores `value` under `key`, expiring after `ttl_secs` seconds.
    ///
    /// NOTE(review): `ttl_secs == 0` is rejected by Redis `SETEX`, whereas the
    /// local cache stores an entry that is already (or almost) expired.
    async fn set_ex<T: Serialize + Send + Sync>(&self, key: &str, value: &T, ttl_secs: u64) -> Result<()>;

    /// Deletes `key`; deleting a missing key is not an error.
    async fn del(&self, key: &str) -> Result<()>;

    /// Returns whether `key` exists and has not expired.
    async fn exists(&self, key: &str) -> Result<bool>;

    /// Adds `delta` to the integer counter at `key` and returns the new value.
    ///
    /// A missing key starts from 0.
    async fn incr(&self, key: &str, delta: i64) -> Result<i64>;

    /// Returns every key matching a glob `pattern`.
    ///
    /// The local implementation supports `*` and `?`; Redis applies its own
    /// glob rules.
    async fn keys(&self, pattern: &str) -> Result<Vec<String>>;

    /// Deletes every key matching `pattern` and returns how many were deleted.
    async fn delete_pattern(&self, pattern: &str) -> Result<u64>;

    /// Snapshot of cache statistics.
    ///
    /// The default implementation returns all zeros. Implementations that track
    /// statistics are expected to override it.
    fn stats(&self) -> CacheStats { CacheStats::default() }
}
58
59// ──── 本地缓存条目 ──────────────────────────────────
60
61struct CacheEntry {
62    value: serde_json::Value,
63    expires_at: Option<Instant>,
64}
65
66// ──── 本地内存缓存 ─────────────────────────────────
67
68/// 缓存统计指标
69#[derive(Debug, Clone, Default)]
70pub struct CacheStats {
71    /// 缓存命中次数
72    pub hits: u64,
73    /// 缓存未命中次数
74    pub misses: u64,
75    /// 设置缓存次数
76    pub sets: u64,
77    /// 删除缓存次数
78    pub deletes: u64,
79    /// 淘汰次数
80    pub evictions: u64,
81    /// 过期清理次数
82    pub expired_cleanups: u64,
83}
84
85/// 本地内存缓存(HashMap + RwLock + TTL + 统计 + 后台清理)
86#[derive(Clone)]
87pub struct LocalCache {
88    /// 缓存数据存储(key → 条目)
89    data: Arc<RwLock<HashMap<String, CacheEntry>>>,
90    /// 最大容量(超过后 LRU 淘汰)
91    max_capacity: u64,
92    /// 默认 TTL 秒数(set 时未指定 TTL 则使用此值)
93    default_ttl_secs: u64,
94    /// 缓存统计信息(原子计数器)
95    stats: Arc<AtomicCacheStats>,
96    /// 后台清理任务的间隔秒数
97    cleanup_interval_secs: u64,
98}
99
/// Lock-free counters backing [`CacheStats`].
///
/// Each field mirrors the same-named field of the snapshot type.
struct AtomicCacheStats {
    hits: AtomicU64,
    misses: AtomicU64,
    sets: AtomicU64,
    deletes: AtomicU64,
    evictions: AtomicU64,
    expired_cleanups: AtomicU64,
}

/// Cloning copies every counter's current value into a fresh, detached set of
/// counters.
///
/// Each counter is loaded independently with `Relaxed` ordering, so the copy
/// is not atomic across fields.
impl Clone for AtomicCacheStats {
    fn clone(&self) -> Self {
        let copy = |counter: &AtomicU64| AtomicU64::new(counter.load(Ordering::Relaxed));
        Self {
            hits: copy(&self.hits),
            misses: copy(&self.misses),
            sets: copy(&self.sets),
            deletes: copy(&self.deletes),
            evictions: copy(&self.evictions),
            expired_cleanups: copy(&self.expired_cleanups),
        }
    }
}
121
122impl LocalCache {
123    /// 创建本地内存缓存
124    ///
125    /// - `max_capacity`: 超过此容量后按 LRU 策略淘汰
126    /// - `default_ttl_secs`: 默认过期秒数(0 = 永不过期)
127    pub fn new(max_capacity: u64, default_ttl_secs: u64) -> Self {
128        Self {
129            data: Arc::new(RwLock::new(HashMap::new())),
130            max_capacity,
131            default_ttl_secs,
132            stats: Arc::new(AtomicCacheStats {
133                hits: AtomicU64::new(0),
134                misses: AtomicU64::new(0),
135                sets: AtomicU64::new(0),
136                deletes: AtomicU64::new(0),
137                evictions: AtomicU64::new(0),
138                expired_cleanups: AtomicU64::new(0),
139            }),
140            cleanup_interval_secs: 60,
141        }
142    }
143
144    pub fn with_cleanup_interval(mut self, interval_secs: u64) -> Self {
145        self.cleanup_interval_secs = interval_secs;
146        self
147    }
148
149    /// 获取缓存统计快照
150    pub fn stats(&self) -> CacheStats {
151        CacheStats {
152            hits: self.stats.hits.load(Ordering::Relaxed),
153            misses: self.stats.misses.load(Ordering::Relaxed),
154            sets: self.stats.sets.load(Ordering::Relaxed),
155            deletes: self.stats.deletes.load(Ordering::Relaxed),
156            evictions: self.stats.evictions.load(Ordering::Relaxed),
157            expired_cleanups: self.stats.expired_cleanups.load(Ordering::Relaxed),
158        }
159    }
160
161    /// 获取当前缓存条目数量
162    pub fn len(&self) -> usize {
163        self.data.read().len()
164    }
165
166    /// 缓存是否为空
167    pub fn is_empty(&self) -> bool {
168        self.data.read().is_empty()
169    }
170
171    /// 手动清理所有过期条目,返回清理数
172    pub fn cleanup_expired(&self) -> u64 {
173        let mut guard = self.data.write();
174        let expired: Vec<String> = guard.iter()
175            .filter(|(_, entry)| entry.expires_at.map_or(false, |t| Instant::now() > t))
176            .map(|(k, _)| k.clone())
177            .collect();
178        let count = expired.len() as u64;
179        for k in &expired { guard.remove(k); }
180        self.stats.expired_cleanups.fetch_add(count, Ordering::Relaxed);
181        count
182    }
183
184    /// 启动后台过期清理任务(每 `interval_secs` 秒执行一次)
185    pub fn start_background_cleanup(&self) {
186        let data = Arc::clone(&self.data);
187        let stats = Arc::clone(&self.stats);
188        let interval = self.cleanup_interval_secs;
189
190        tokio::spawn(async move {
191            loop {
192                tokio::time::sleep(Duration::from_secs(interval)).await;
193                let mut guard = data.write();
194                let now = Instant::now();
195                let expired: Vec<String> = guard.iter()
196                    .filter(|(_, entry)| entry.expires_at.map_or(false, |t| now > t))
197                    .map(|(k, _)| k.clone())
198                    .collect();
199                let count = expired.len() as u64;
200                for k in &expired { guard.remove(k); }
201                if count > 0 {
202                    stats.expired_cleanups.fetch_add(count, Ordering::Relaxed);
203                    tracing::debug!("缓存后台清理: 移除 {} 个过期条目", count);
204                }
205            }
206        });
207    }
208}
209
210#[async_trait]
211impl Cache for LocalCache {
212    async fn get<T: DeserializeOwned + Send>(&self, key: &str) -> Result<Option<T>> {
213        let guard = self.data.read();
214        if let Some(entry) = guard.get(key) {
215            if let Some(expires) = entry.expires_at {
216                if Instant::now() > expires {
217                    drop(guard);
218                    self.data.write().remove(key);
219                    self.stats.misses.fetch_add(1, Ordering::Relaxed);
220                    return Ok(None);
221                }
222            }
223            self.stats.hits.fetch_add(1, Ordering::Relaxed);
224            let val: T = serde_json::from_value(entry.value.clone())
225                .map_err(|e| alun_core::Error::Msg(e.to_string()))?;
226            return Ok(Some(val));
227        }
228        self.stats.misses.fetch_add(1, Ordering::Relaxed);
229        Ok(None)
230    }
231
232    async fn set<T: Serialize + Send + Sync>(&self, key: &str, value: &T) -> Result<()> {
233        let v = serde_json::to_value(value)
234            .map_err(|e| alun_core::Error::Msg(e.to_string()))?;
235        let mut guard = self.data.write();
236        if self.max_capacity > 0 && guard.len() as u64 >= self.max_capacity {
237            drop(guard);
238            return Err(alun_core::Error::Msg(format!("缓存容量已达上限: {}", self.max_capacity)));
239        }
240        self.stats.sets.fetch_add(1, Ordering::Relaxed);
241        let expires_at = if self.default_ttl_secs > 0 {
242            Some(Instant::now() + Duration::from_secs(self.default_ttl_secs))
243        } else {
244            None
245        };
246        guard.insert(key.to_string(), CacheEntry { value: v, expires_at });
247        Ok(())
248    }
249
250    async fn set_ex<T: Serialize + Send + Sync>(&self, key: &str, value: &T, ttl_secs: u64) -> Result<()> {
251        let v = serde_json::to_value(value)
252            .map_err(|e| alun_core::Error::Msg(e.to_string()))?;
253        self.stats.sets.fetch_add(1, Ordering::Relaxed);
254        self.data.write().insert(key.to_string(), CacheEntry {
255            value: v,
256            expires_at: Some(Instant::now() + Duration::from_secs(ttl_secs)),
257        });
258        Ok(())
259    }
260
261    async fn del(&self, key: &str) -> Result<()> {
262        let removed = self.data.write().remove(key).is_some();
263        if removed { self.stats.deletes.fetch_add(1, Ordering::Relaxed); }
264        Ok(())
265    }
266
267    async fn exists(&self, key: &str) -> Result<bool> {
268        let guard = self.data.read();
269        let found = guard.get(key).map_or(false, |entry| {
270            entry.expires_at.map_or(true, |exp| Instant::now() <= exp)
271        });
272        if found { self.stats.hits.fetch_add(1, Ordering::Relaxed); }
273        else { self.stats.misses.fetch_add(1, Ordering::Relaxed); }
274        Ok(found)
275    }
276
277    async fn incr(&self, key: &str, delta: i64) -> Result<i64> {
278        let mut guard = self.data.write();
279        let entry = guard.entry(key.to_string()).or_insert_with(|| CacheEntry {
280            value: serde_json::Value::Number(serde_json::Number::from(0i64)),
281            expires_at: None,
282        });
283        let current = entry.value.as_i64().unwrap_or(0);
284        let new_val = current + delta;
285        entry.value = serde_json::Value::Number(serde_json::Number::from(new_val));
286        Ok(new_val)
287    }
288
289    async fn keys(&self, pattern: &str) -> Result<Vec<String>> {
290        let guard = self.data.read();
291        Ok(guard.keys()
292            .filter(|k| match_pattern(k, pattern))
293            .cloned()
294            .collect())
295    }
296
297    async fn delete_pattern(&self, pattern: &str) -> Result<u64> {
298        let mut guard = self.data.write();
299        let to_remove: Vec<String> = guard.keys()
300            .filter(|k| match_pattern(k, pattern))
301            .cloned()
302            .collect();
303        let count = to_remove.len() as u64;
304        for k in to_remove { guard.remove(&k); }
305        Ok(count)
306    }
307}
308
309// ──── Redis 缓存 ────────────────────────────────────
310
311/// Redis 缓存实现
312#[derive(Clone)]
313pub struct RedisCache {
314    /// Redis 连接管理器
315    conn: ConnectionManager,
316}
317
318impl RedisCache {
319    /// 创建 Redis 缓存(需传入已建立的连接管理器)
320    pub fn new(conn: ConnectionManager) -> Self {
321        Self { conn }
322    }
323
324    /// 从 URL 创建连接
325    pub async fn connect(url: &str) -> Result<Self> {
326        let client = redis::Client::open(url)
327            .map_err(|e| alun_core::Error::Config(format!("Redis URL 无效: {}", e)))?;
328        let conn = ConnectionManager::new(client).await
329            .map_err(|e| alun_core::Error::Config(format!("Redis 连接失败: {}", e)))?;
330        Ok(Self { conn })
331    }
332
333    fn map_err(e: redis::RedisError) -> alun_core::Error {
334        alun_core::Error::Msg(e.to_string())
335    }
336}
337
338#[async_trait]
339impl Cache for RedisCache {
340    async fn get<T: DeserializeOwned + Send>(&self, key: &str) -> Result<Option<T>> {
341        let result: Option<String> = redis::cmd("GET")
342            .arg(key)
343            .query_async(&mut self.conn.clone())
344            .await
345            .map_err(Self::map_err)?;
346
347        if let Some(json) = result {
348            let val: T = serde_json::from_str(&json)
349                .map_err(|e| alun_core::Error::Msg(e.to_string()))?;
350            Ok(Some(val))
351        } else {
352            Ok(None)
353        }
354    }
355
356    async fn set<T: Serialize + Send + Sync>(&self, key: &str, value: &T) -> Result<()> {
357        let json = serde_json::to_string(value)
358            .map_err(|e| alun_core::Error::Msg(e.to_string()))?;
359        redis::cmd("SET")
360            .arg(key).arg(&json)
361            .query_async::<()>(&mut self.conn.clone())
362            .await
363            .map_err(Self::map_err)
364    }
365
366    async fn set_ex<T: Serialize + Send + Sync>(&self, key: &str, value: &T, ttl_secs: u64) -> Result<()> {
367        let json = serde_json::to_string(value)
368            .map_err(|e| alun_core::Error::Msg(e.to_string()))?;
369        redis::cmd("SETEX")
370            .arg(key).arg(ttl_secs).arg(&json)
371            .query_async::<()>(&mut self.conn.clone())
372            .await
373            .map_err(Self::map_err)
374    }
375
376    async fn del(&self, key: &str) -> Result<()> {
377        redis::cmd("DEL")
378            .arg(key)
379            .query_async::<()>(&mut self.conn.clone())
380            .await
381            .map_err(Self::map_err)
382    }
383
384    async fn exists(&self, key: &str) -> Result<bool> {
385        redis::cmd("EXISTS")
386            .arg(key)
387            .query_async::<i32>(&mut self.conn.clone())
388            .await
389            .map_err(Self::map_err)
390            .map(|v| v > 0)
391    }
392
393    async fn incr(&self, key: &str, delta: i64) -> Result<i64> {
394        let result: i64 = if delta == 1 {
395            redis::cmd("INCR")
396                .arg(key)
397                .query_async(&mut self.conn.clone())
398                .await
399                .map_err(Self::map_err)?
400        } else {
401            redis::cmd("INCRBY")
402                .arg(key).arg(delta)
403                .query_async(&mut self.conn.clone())
404                .await
405                .map_err(Self::map_err)?
406        };
407        Ok(result)
408    }
409
410    async fn keys(&self, pattern: &str) -> Result<Vec<String>> {
411        redis::cmd("KEYS")
412            .arg(pattern)
413            .query_async::<Vec<String>>(&mut self.conn.clone())
414            .await
415            .map_err(Self::map_err)
416    }
417
418    async fn delete_pattern(&self, pattern: &str) -> Result<u64> {
419        let keys: Vec<String> = self.keys(pattern).await?;
420        if keys.is_empty() { return Ok(0); }
421        let mut cmd = redis::cmd("DEL");
422        for k in &keys { cmd.arg(k); }
423        cmd.query_async::<u64>(&mut self.conn.clone())
424            .await
425            .map_err(Self::map_err)
426    }
427}
428
429// ──── 模式匹配 ──────────────────────────────────────
430
/// Matches `key` against a glob `pattern`.
///
/// Supports `*` (any run of bytes, including none) and `?` (exactly one byte);
/// every other byte matches itself. Matching is byte-wise, so `?` matches a
/// single byte of a multi-byte UTF-8 character, which is also how Redis globs
/// treat keys.
fn match_pattern(key: &str, pattern: &str) -> bool {
    match_pattern_rec(key.as_bytes(), 0, pattern.as_bytes(), 0)
}

/// Glob-matches `key[ki..]` against `pat[pi..]`.
///
/// Despite the name this is iterative. It uses the classic single-backtrack
/// algorithm: on a mismatch only the most recent `*` is revisited and made to
/// absorb one more byte. That keeps the worst case at O(key.len() * pat.len())
/// instead of the exponential blow-up naive recursion hits on patterns such as
/// `a*a*a*...*b`.
fn match_pattern_rec(key: &[u8], ki: usize, pat: &[u8], pi: usize) -> bool {
    let (mut k, mut p) = (ki, pi);
    // (index of the last `*` seen in `pat`, key index that `*` currently extends to)
    let mut backtrack: Option<(usize, usize)> = None;
    while k < key.len() {
        match pat.get(p) {
            Some(b'*') => {
                backtrack = Some((p, k));
                p += 1;
            }
            Some(&c) if c == b'?' || c == key[k] => {
                k += 1;
                p += 1;
            }
            _ => match backtrack {
                Some((star_p, star_k)) => {
                    backtrack = Some((star_p, star_k + 1));
                    p = star_p + 1;
                    k = star_k + 1;
                }
                None => return false,
            },
        }
    }
    // Key exhausted: whatever remains of the pattern must consist only of `*`.
    pat.get(p..).map_or(true, |rest| rest.iter().all(|&c| c == b'*'))
}
455
456// ──── 共享缓存(枚举消除 dyn 不兼容) ────────────────
457
/// Runtime-selected cache backend.
///
/// Each concrete implementation is wrapped in an enum because [`Cache`] has
/// generic methods and therefore cannot be used as `dyn Cache`.
#[derive(Clone)]
pub enum SharedCache {
    /// In-process memory cache.
    Local(LocalCache),
    /// Redis-backed cache.
    Redis(RedisCache),
}
464
465#[async_trait]
466impl Cache for SharedCache {
467    async fn get<T: DeserializeOwned + Send>(&self, key: &str) -> Result<Option<T>> {
468        match self {
469            SharedCache::Local(c) => c.get(key).await,
470            SharedCache::Redis(c) => c.get(key).await,
471        }
472    }
473
474    async fn set<T: Serialize + Send + Sync>(&self, key: &str, value: &T) -> Result<()> {
475        match self {
476            SharedCache::Local(c) => c.set(key, value).await,
477            SharedCache::Redis(c) => c.set(key, value).await,
478        }
479    }
480
481    async fn set_ex<T: Serialize + Send + Sync>(&self, key: &str, value: &T, ttl_secs: u64) -> Result<()> {
482        match self {
483            SharedCache::Local(c) => c.set_ex(key, value, ttl_secs).await,
484            SharedCache::Redis(c) => c.set_ex(key, value, ttl_secs).await,
485        }
486    }
487
488    async fn del(&self, key: &str) -> Result<()> {
489        match self {
490            SharedCache::Local(c) => c.del(key).await,
491            SharedCache::Redis(c) => c.del(key).await,
492        }
493    }
494
495    async fn exists(&self, key: &str) -> Result<bool> {
496        match self { SharedCache::Local(c) => c.exists(key).await, SharedCache::Redis(c) => c.exists(key).await }
497    }
498
499    async fn incr(&self, key: &str, delta: i64) -> Result<i64> {
500        match self { SharedCache::Local(c) => c.incr(key, delta).await, SharedCache::Redis(c) => c.incr(key, delta).await }
501    }
502
503    async fn keys(&self, pattern: &str) -> Result<Vec<String>> {
504        match self { SharedCache::Local(c) => c.keys(pattern).await, SharedCache::Redis(c) => c.keys(pattern).await }
505    }
506
507    async fn delete_pattern(&self, pattern: &str) -> Result<u64> {
508        match self { SharedCache::Local(c) => c.delete_pattern(pattern).await, SharedCache::Redis(c) => c.delete_pattern(pattern).await }
509    }
510}
511
512// ──── 工厂函数 ──────────────────────────────────────
513
514/// 从配置创建共享缓存实例
515pub async fn create_cache(cache_config: &alun_config::CacheConfig, redis_config: &alun_config::RedisConfig) -> Result<SharedCache> {
516    match cache_config.r#type.as_str() {
517        "redis" => {
518            tracing::info!("使用 Redis 缓存 url={}", redis_config.url);
519            Ok(SharedCache::Redis(RedisCache::connect(&redis_config.url).await?))
520        }
521        _ => {
522            tracing::info!("使用本地缓存 capacity={}", cache_config.max_capacity);
523            Ok(SharedCache::Local(LocalCache::new(cache_config.max_capacity, cache_config.default_ttl)))
524        }
525    }
526}
527
#[cfg(test)]
mod tests {
    use super::*;

    /// A cache with room for 100 entries and no default TTL, as every test uses.
    fn fresh_cache() -> LocalCache {
        LocalCache::new(100, 0)
    }

    #[tokio::test]
    async fn test_local_cache_get_set() {
        let cache = fresh_cache();
        cache.set("key1", &"value1".to_string()).await.unwrap();
        let stored: Option<String> = cache.get("key1").await.unwrap();
        assert_eq!(stored, Some("value1".to_string()));

        cache.del("key1").await.unwrap();
        let after_delete: Option<String> = cache.get("key1").await.unwrap();
        assert_eq!(after_delete, None);
    }

    #[tokio::test]
    async fn test_set_ex_expiration() {
        let cache = fresh_cache();
        cache.set_ex("temp", &"expire".to_string(), 1).await.unwrap();
        let before: Option<String> = cache.get("temp").await.unwrap();
        assert_eq!(before, Some("expire".to_string()));

        // Wait past the 1-second TTL.
        tokio::time::sleep(Duration::from_secs(2)).await;
        let after: Option<String> = cache.get("temp").await.unwrap();
        assert_eq!(after, None);
    }

    #[tokio::test]
    async fn test_incr() {
        let cache = fresh_cache();
        for (delta, expected) in [(1, 1), (5, 6), (-2, 4)] {
            assert_eq!(cache.incr("counter", delta).await.unwrap(), expected);
        }
    }

    #[tokio::test]
    async fn test_keys_pattern() {
        let cache = fresh_cache();
        for (key, value) in [("user:1", "alice"), ("user:2", "bob"), ("order:1", "o1")] {
            cache.set(key, &value).await.unwrap();
        }

        let matched = cache.keys("user:*").await.unwrap();
        assert_eq!(matched.len(), 2);
        for expected in ["user:1", "user:2"] {
            assert!(matched.contains(&expected.to_string()));
        }
    }

    #[tokio::test]
    async fn test_delete_pattern() {
        let cache = fresh_cache();
        for (key, value) in [("session:a", "s1"), ("session:b", "s2"), ("user:1", "alice")] {
            cache.set(key, &value).await.unwrap();
        }

        let removed = cache.delete_pattern("session:*").await.unwrap();
        assert_eq!(removed, 2);
        for gone in ["session:a", "session:b"] {
            assert!(!cache.exists(gone).await.unwrap());
        }
        assert!(cache.exists("user:1").await.unwrap());
    }
}