redis_sdk/
lib.rs

use anyhow::Error;
use bb8::{Pool, PooledConnection};
use bb8_redis::RedisConnectionManager;
use lru::LruCache;
use redis::AsyncCommands;
use serde::{de::DeserializeOwned, Serialize};
use std::{
    num::NonZeroUsize,
    sync::Arc,
    time::{Duration, Instant},
};
use tokio::sync::RwLock;
9
10pub type Result<T = (), E = Error> = std::result::Result<T, E>;
11
12#[derive(Clone)]
13pub struct CacheClient {
14    pool: Pool<RedisConnectionManager>,
15    json_cache: Arc<RwLock<LruCache<String, String>>>, // In-memory cache for JSON values
16}
17
18impl CacheClient {
19    /// Initializes a new CacheClient with a connection pool and an optional in-memory cache.
20    pub async fn new(
21        url: &str,
22        db_index: Option<u8>,
23        cache_capacity: usize,
24        connection_size: u32,
25    ) -> Result<Self> {
26        let redis_url = match db_index {
27            Some(db) => format!("{}/{}", url.trim_end_matches('/'), db),
28            None => url.to_string(),
29        };
30
31        let manager = RedisConnectionManager::new(redis_url)?;
32        let pool = Pool::builder()
33            .max_size(connection_size)
34            .build(manager)
35            .await?;
36
37        // 设置 LRU 缓存容量
38        let json_cache = LruCache::new(NonZeroUsize::new(cache_capacity).unwrap());
39
40        Ok(Self {
41            pool,
42            json_cache: Arc::new(RwLock::new(json_cache)),
43        })
44    }
45
46    /// Gets a pooled Redis connection.
47    async fn get_conn(&self) -> Result<PooledConnection<'_, RedisConnectionManager>> {
48        self.pool
49            .get()
50            .await
51            .map_err(|e| Error::msg(format!("Failed to get Redis connection from pool: {}", e)))
52    }
53
54    /// 将 JSON 数据写入缓存,并同时写入 Redis
55    pub async fn set_json<T: Serialize>(&self, key: &str, value: &T) -> Result<()> {
56        let serialized = serde_json::to_string(value)?;
57
58        // 插入到 LRU 缓存中
59        self.json_cache
60            .write()
61            .await
62            .put(key.to_string(), serialized.clone());
63
64        // 写入 Redis
65        self.set(key, &serialized).await
66    }
67
68    /// 从缓存或 Redis 中获取 JSON 数据
69    /// 从缓存或 Redis 中获取 JSON 数据
70    pub async fn get_json<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
71        // 尝试从缓存中读取
72        // 尝试从缓存中读取
73        {
74            let mut json_cache = self.json_cache.write().await; // 改为写锁
75            if let Some(cached_value) = json_cache.get_mut(key) {
76                // 使用 get_mut
77                return serde_json::from_str(cached_value).map(Some).map_err(|e| {
78                    Error::msg(format!(
79                        "Failed to deserialize cached JSON for key '{}': {}",
80                        key, e
81                    ))
82                });
83            }
84        } // 锁在这里释放
85
86        // 如果缓存中没有,尝试从 Redis 获取
87        if let Some(serialized) = self.get(key).await? {
88            let value: T = serde_json::from_str(&serialized).map_err(|e| {
89                Error::msg(format!(
90                    "Failed to deserialize JSON from Redis for key '{}': {}",
91                    key, e
92                ))
93            })?;
94
95            // 将获取到的数据存入缓存
96            self.json_cache
97                .write()
98                .await
99                .put(key.to_string(), serialized);
100
101            Ok(Some(value))
102        } else {
103            Ok(None)
104        }
105    }
106
107    /// 设置键值对,带过期时间
108    pub async fn set_ex(&self, key: &str, value: &str, sec: u64) -> Result<()> {
109        let mut conn = self.get_conn().await?;
110        conn.set_ex(key, value, sec).await.map_err(Error::from)
111    }
112
113    /// 设置键值对,不带过期时间
114    pub async fn set(&self, key: &str, value: &str) -> Result<()> {
115        let mut conn = self.get_conn().await?;
116        conn.set(key, value).await.map_err(Error::from)
117    }
118
119    /// 将 JSON 数据写入缓存并设置过期时间
120    pub async fn set_json_ex<T: Serialize>(&self, key: &str, value: &T, sec: u64) -> Result<()> {
121        let serialized = serde_json::to_string(value)?;
122
123        // 将序列化后的数据写入 LRU 缓存
124        self.json_cache
125            .write()
126            .await
127            .put(key.to_string(), serialized.clone());
128
129        // 将数据写入 Redis,并设置过期时间
130        self.set_ex(key, &serialized, sec).await
131    }
132
133    /// Get a key's value.
134    pub async fn get(&self, key: &str) -> Result<Option<String>> {
135        let mut conn = self.get_conn().await?;
136        conn.get(key).await.map_err(Error::from)
137    }
138
139    /// Check if a key exists.
140    pub async fn exists(&self, key: &str) -> Result<bool> {
141        let mut conn = self.get_conn().await?;
142        conn.exists(key).await.map_err(Error::from)
143    }
144
145    /// Delete a key.
146    pub async fn delete(&self, key: &str) -> Result<()> {
147        let mut json_cache = self.json_cache.write().await; // 获取可变引用
148        json_cache.demote(key);
149        let mut conn = self.get_conn().await?;
150        conn.del(key).await.map_err(Error::from)
151    }
152
153    /// Get all keys matching a pattern.
154    pub async fn keys(&self, pattern: &str) -> Result<Vec<String>> {
155        let mut conn = self.get_conn().await?;
156        conn.keys(pattern).await.map_err(Error::from)
157    }
158
159    /// Add JSON data to cache with a 60-second expiration time.
160    pub async fn add_cache_json<T: Serialize>(&self, key: &str, value: &T) -> Result<()> {
161        self.set_json_ex(key, value, 60).await
162    }
163
164    /// Get JSON data from cache.
165    pub async fn get_cache_json<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
166        self.get_json(key).await
167    }
168}
169
#[cfg(test)]
mod tests {
    //! Integration tests. They need a reachable Redis server, whose URL is read from the
    //! `REDIS_URL` environment variable and otherwise defaults to a local password-protected
    //! instance.
    use super::*;
    use serde::{Deserialize, Serialize};
    use std::time::Duration;
    use tokio::time::sleep;

    /// Redis URL used when `REDIS_URL` is not set.
    const DEFAULT_REDIS_URL: &str = "redis://:123456@127.0.0.1/";

    #[derive(Serialize, Deserialize, Debug, PartialEq)]
    struct TestData {
        field1: String,
        field2: i32,
    }

    /// Builds a client against the test server (`REDIS_URL` or [`DEFAULT_REDIS_URL`]).
    async fn get_client() -> CacheClient {
        let url = std::env::var("REDIS_URL").unwrap_or_else(|_| DEFAULT_REDIS_URL.to_string());
        CacheClient::new(&url, None, 1000, 15)
            .await
            .expect("failed to create a CacheClient for the test Redis server")
    }

    #[tokio::test]
    async fn test_set_and_get() {
        let client = get_client().await;
        let key = "test_key";
        let value = "test_value";

        client.set(key, value).await.unwrap();
        let result: Option<String> = client.get(key).await.unwrap();
        assert_eq!(result, Some(value.to_string()));

        client.delete(key).await.unwrap();
    }

    #[tokio::test]
    async fn test_set_ex_and_expire() {
        let client = get_client().await;
        let key = "test_key_ex";
        let value = "test_value_ex";

        client.set_ex(key, value, 1).await.unwrap();
        let result: Option<String> = client.get(key).await.unwrap();
        assert_eq!(result, Some(value.to_string()));

        // Wait past the 1-second TTL so Redis has expired the key.
        sleep(Duration::from_secs(2)).await;
        let expired_result: Option<String> = client.get(key).await.unwrap();
        assert_eq!(expired_result, None);
    }

    #[tokio::test]
    async fn test_set_and_get_json() {
        let client = get_client().await;
        let key = "test_json_key";
        let data = TestData {
            field1: "test".to_string(),
            field2: 123,
        };

        client.set_json(key, &data).await.unwrap();
        let result: Option<TestData> = client.get_json(key).await.unwrap();
        assert_eq!(result, Some(data));

        client.delete(key).await.unwrap();
    }

    #[tokio::test]
    async fn test_exists() {
        let client = get_client().await;
        let key = "test_exists_key";

        client.set(key, "value").await.unwrap();
        let exists = client.exists(key).await.unwrap();
        assert!(exists);

        client.delete(key).await.unwrap();
        let exists_after_delete = client.exists(key).await.unwrap();
        assert!(!exists_after_delete);
    }

    #[tokio::test]
    async fn test_keys_pattern() {
        let client = get_client().await;
        let key1 = "test_key_pattern_1";
        let key2 = "test_key_pattern_2";

        client.set(key1, "value1").await.unwrap();
        client.set(key2, "value2").await.unwrap();

        let keys = client.keys("test_key_pattern_*").await.unwrap();
        assert!(keys.contains(&key1.to_string()));
        assert!(keys.contains(&key2.to_string()));

        client.delete(key1).await.unwrap();
        client.delete(key2).await.unwrap();
    }
}