use anyhow::Error;
use bb8::{Pool, PooledConnection};
use bb8_redis::RedisConnectionManager;
use lru::LruCache;
use redis::AsyncCommands;
use serde::{de::DeserializeOwned, Serialize};
use std::{
    num::NonZeroUsize,
    sync::Arc,
    time::{Duration, Instant},
};
use tokio::sync::RwLock;
9
10pub type Result<T = (), E = Error> = std::result::Result<T, E>;
11
12#[derive(Clone)]
13pub struct CacheClient {
14 pool: Pool<RedisConnectionManager>,
15 json_cache: Arc<RwLock<LruCache<String, String>>>, }
17
18impl CacheClient {
19 pub async fn new(
21 url: &str,
22 db_index: Option<u8>,
23 cache_capacity: usize,
24 connection_size: u32,
25 ) -> Result<Self> {
26 let redis_url = match db_index {
27 Some(db) => format!("{}/{}", url.trim_end_matches('/'), db),
28 None => url.to_string(),
29 };
30
31 let manager = RedisConnectionManager::new(redis_url)?;
32 let pool = Pool::builder()
33 .max_size(connection_size)
34 .build(manager)
35 .await?;
36
37 let json_cache = LruCache::new(NonZeroUsize::new(cache_capacity).unwrap());
39
40 Ok(Self {
41 pool,
42 json_cache: Arc::new(RwLock::new(json_cache)),
43 })
44 }
45
46 async fn get_conn(&self) -> Result<PooledConnection<'_, RedisConnectionManager>> {
48 self.pool
49 .get()
50 .await
51 .map_err(|e| Error::msg(format!("Failed to get Redis connection from pool: {}", e)))
52 }
53
54 pub async fn set_json<T: Serialize>(&self, key: &str, value: &T) -> Result<()> {
56 let serialized = serde_json::to_string(value)?;
57
58 self.json_cache
60 .write()
61 .await
62 .put(key.to_string(), serialized.clone());
63
64 self.set(key, &serialized).await
66 }
67
68 pub async fn get_json<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
71 {
74 let mut json_cache = self.json_cache.write().await; if let Some(cached_value) = json_cache.get_mut(key) {
76 return serde_json::from_str(cached_value).map(Some).map_err(|e| {
78 Error::msg(format!(
79 "Failed to deserialize cached JSON for key '{}': {}",
80 key, e
81 ))
82 });
83 }
84 } if let Some(serialized) = self.get(key).await? {
88 let value: T = serde_json::from_str(&serialized).map_err(|e| {
89 Error::msg(format!(
90 "Failed to deserialize JSON from Redis for key '{}': {}",
91 key, e
92 ))
93 })?;
94
95 self.json_cache
97 .write()
98 .await
99 .put(key.to_string(), serialized);
100
101 Ok(Some(value))
102 } else {
103 Ok(None)
104 }
105 }
106
107 pub async fn set_ex(&self, key: &str, value: &str, sec: u64) -> Result<()> {
109 let mut conn = self.get_conn().await?;
110 conn.set_ex(key, value, sec).await.map_err(Error::from)
111 }
112
113 pub async fn set(&self, key: &str, value: &str) -> Result<()> {
115 let mut conn = self.get_conn().await?;
116 conn.set(key, value).await.map_err(Error::from)
117 }
118
119 pub async fn set_json_ex<T: Serialize>(&self, key: &str, value: &T, sec: u64) -> Result<()> {
121 let serialized = serde_json::to_string(value)?;
122
123 self.json_cache
125 .write()
126 .await
127 .put(key.to_string(), serialized.clone());
128
129 self.set_ex(key, &serialized, sec).await
131 }
132
133 pub async fn get(&self, key: &str) -> Result<Option<String>> {
135 let mut conn = self.get_conn().await?;
136 conn.get(key).await.map_err(Error::from)
137 }
138
139 pub async fn exists(&self, key: &str) -> Result<bool> {
141 let mut conn = self.get_conn().await?;
142 conn.exists(key).await.map_err(Error::from)
143 }
144
145 pub async fn delete(&self, key: &str) -> Result<()> {
147 let mut json_cache = self.json_cache.write().await; json_cache.demote(key);
149 let mut conn = self.get_conn().await?;
150 conn.del(key).await.map_err(Error::from)
151 }
152
153 pub async fn keys(&self, pattern: &str) -> Result<Vec<String>> {
155 let mut conn = self.get_conn().await?;
156 conn.keys(pattern).await.map_err(Error::from)
157 }
158
159 pub async fn add_cache_json<T: Serialize>(&self, key: &str, value: &T) -> Result<()> {
161 self.set_json_ex(key, value, 60).await
162 }
163
164 pub async fn get_cache_json<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
166 self.get_json(key).await
167 }
168}
169
#[cfg(test)]
mod tests {
    //! Integration tests; they need a reachable Redis server.

    use super::*;
    use serde::{Deserialize, Serialize};
    use std::time::Duration;
    use tokio::time::sleep;

    /// Redis URL used when the `REDIS_URL` environment variable is not set.
    const DEFAULT_REDIS_URL: &str = "redis://:123456@127.0.0.1/";

    #[derive(Serialize, Deserialize, Debug, PartialEq)]
    struct TestData {
        field1: String,
        field2: i32,
    }

    /// Builds a client against `REDIS_URL`, falling back to the local default.
    async fn get_client() -> CacheClient {
        let url = std::env::var("REDIS_URL").unwrap_or_else(|_| DEFAULT_REDIS_URL.to_string());
        CacheClient::new(&url, None, 1000, 15).await.unwrap()
    }

    #[tokio::test]
    async fn test_set_and_get() {
        let client = get_client().await;
        let key = "test_key";
        let value = "test_value";

        client.set(key, value).await.unwrap();
        let result: Option<String> = client.get(key).await.unwrap();
        assert_eq!(result, Some(value.to_string()));

        client.delete(key).await.unwrap();
    }

    #[tokio::test]
    async fn test_set_ex_and_expire() {
        let client = get_client().await;
        let key = "test_key_ex";
        let value = "test_value_ex";

        client.set_ex(key, value, 1).await.unwrap();
        let result: Option<String> = client.get(key).await.unwrap();
        assert_eq!(result, Some(value.to_string()));

        sleep(Duration::from_secs(2)).await;
        let expired_result: Option<String> = client.get(key).await.unwrap();
        assert_eq!(expired_result, None);
    }

    #[tokio::test]
    async fn test_set_and_get_json() {
        let client = get_client().await;
        let key = "test_json_key";
        let data = TestData {
            field1: "test".to_string(),
            field2: 123,
        };

        client.set_json(key, &data).await.unwrap();
        // May be served from the writer's own local cache.
        let result: Option<TestData> = client.get_json(key).await.unwrap();
        assert_eq!(result.as_ref(), Some(&data));

        // A fresh client starts with an empty local cache, so this read must
        // round-trip through Redis.
        let reader = get_client().await;
        let from_redis: Option<TestData> = reader.get_json(key).await.unwrap();
        assert_eq!(from_redis, Some(data));

        client.delete(key).await.unwrap();
    }

    #[tokio::test]
    async fn test_get_json_missing_key() {
        let client = get_client().await;
        let key = "test_json_missing_key";

        client.delete(key).await.unwrap();
        let result: Option<TestData> = get_client().await.get_json(key).await.unwrap();
        assert_eq!(result, None);
    }

    #[tokio::test]
    async fn test_exists() {
        let client = get_client().await;
        let key = "test_exists_key";

        client.set(key, "value").await.unwrap();
        let exists = client.exists(key).await.unwrap();
        assert!(exists);

        client.delete(key).await.unwrap();
        let exists_after_delete = client.exists(key).await.unwrap();
        assert!(!exists_after_delete);
    }

    #[tokio::test]
    async fn test_keys_pattern() {
        let client = get_client().await;
        let key1 = "test_key_pattern_1";
        let key2 = "test_key_pattern_2";

        client.set(key1, "value1").await.unwrap();
        client.set(key2, "value2").await.unwrap();

        let keys = client.keys("test_key_pattern_*").await.unwrap();
        assert!(keys.contains(&key1.to_string()));
        assert!(keys.contains(&key2.to_string()));

        client.delete(key1).await.unwrap();
        client.delete(key2).await.unwrap();
    }
}