1use crate::error::CoolResult;
6use async_trait::async_trait;
7use parking_lot::RwLock;
8use redis::aio::MultiplexedConnection;
9use redis::AsyncCommands;
10use serde::{de::DeserializeOwned, Serialize};
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14
15#[async_trait]
17pub trait CacheStore: Send + Sync {
18 async fn get<T: DeserializeOwned + Send>(&self, key: &str) -> CoolResult<Option<T>>;
20
21 async fn set<T: Serialize + Send + Sync>(
23 &self,
24 key: &str,
25 value: &T,
26 ttl: Option<Duration>,
27 ) -> CoolResult<()>;
28
29 async fn del(&self, key: &str) -> CoolResult<()>;
31
32 async fn exists(&self, key: &str) -> CoolResult<bool>;
34
35 async fn clear(&self) -> CoolResult<()>;
37
38 async fn keys(&self, pattern: &str) -> CoolResult<Vec<String>>;
40}
41
/// A single entry held by the in-memory cache.
struct CacheItem {
    /// The cached payload, kept as text.
    value: String,
    /// Deadline after which the entry counts as expired; `None` means the
    /// entry never expires.
    expire_at: Option<Instant>,
}

impl CacheItem {
    /// Returns `true` once the current instant is strictly later than the
    /// entry's deadline. Entries without a deadline are never expired.
    fn is_expired(&self) -> bool {
        match self.expire_at {
            Some(deadline) => Instant::now() > deadline,
            None => false,
        }
    }
}
55
56pub struct MemoryCache {
58 store: Arc<RwLock<HashMap<String, CacheItem>>>,
59}
60
61impl MemoryCache {
62 pub fn new() -> Self {
63 Self {
64 store: Arc::new(RwLock::new(HashMap::new())),
65 }
66 }
67
68 pub fn cleanup(&self) {
70 let mut store = self.store.write();
71 store.retain(|_, item| !item.is_expired());
72 }
73}
74
75impl Default for MemoryCache {
76 fn default() -> Self {
77 Self::new()
78 }
79}
80
81#[async_trait]
82impl CacheStore for MemoryCache {
83 async fn get<T: DeserializeOwned + Send>(&self, key: &str) -> CoolResult<Option<T>> {
84 let store = self.store.read();
85 match store.get(key) {
86 Some(item) if !item.is_expired() => {
87 let value: T = serde_json::from_str(&item.value)?;
88 Ok(Some(value))
89 }
90 _ => Ok(None),
91 }
92 }
93
94 async fn set<T: Serialize + Send + Sync>(
95 &self,
96 key: &str,
97 value: &T,
98 ttl: Option<Duration>,
99 ) -> CoolResult<()> {
100 let mut store = self.store.write();
101 let value_str = serde_json::to_string(value)?;
102 let item = CacheItem {
103 value: value_str,
104 expire_at: ttl.map(|d| Instant::now() + d),
105 };
106 store.insert(key.to_string(), item);
107 Ok(())
108 }
109
110 async fn del(&self, key: &str) -> CoolResult<()> {
111 let mut store = self.store.write();
112 store.remove(key);
113 Ok(())
114 }
115
116 async fn exists(&self, key: &str) -> CoolResult<bool> {
117 let store = self.store.read();
118 match store.get(key) {
119 Some(item) => Ok(!item.is_expired()),
120 None => Ok(false),
121 }
122 }
123
124 async fn clear(&self) -> CoolResult<()> {
125 let mut store = self.store.write();
126 store.clear();
127 Ok(())
128 }
129
130 async fn keys(&self, pattern: &str) -> CoolResult<Vec<String>> {
131 let store = self.store.read();
132 let pattern = pattern.replace('*', "");
133 let keys: Vec<String> = store
134 .keys()
135 .filter(|k| {
136 if pattern.is_empty() {
137 true
138 } else {
139 k.contains(&pattern)
140 }
141 })
142 .cloned()
143 .collect();
144 Ok(keys)
145 }
146}
147
148pub struct RedisCache {
150 conn: MultiplexedConnection,
151 prefix: String,
152}
153
154impl RedisCache {
155 pub async fn new(url: &str, prefix: impl Into<String>) -> CoolResult<Self> {
156 let client = redis::Client::open(url)?;
157 let conn = client.get_multiplexed_async_connection().await?;
158 Ok(Self {
159 conn,
160 prefix: prefix.into(),
161 })
162 }
163
164 fn full_key(&self, key: &str) -> String {
165 if self.prefix.is_empty() {
166 key.to_string()
167 } else {
168 format!("{}:{}", self.prefix, key)
169 }
170 }
171}
172
173#[async_trait]
174impl CacheStore for RedisCache {
175 async fn get<T: DeserializeOwned + Send>(&self, key: &str) -> CoolResult<Option<T>> {
176 let mut conn = self.conn.clone();
177 let full_key = self.full_key(key);
178 let value: Option<String> = conn.get(&full_key).await?;
179 match value {
180 Some(v) => {
181 let result: T = serde_json::from_str(&v)?;
182 Ok(Some(result))
183 }
184 None => Ok(None),
185 }
186 }
187
188 async fn set<T: Serialize + Send + Sync>(
189 &self,
190 key: &str,
191 value: &T,
192 ttl: Option<Duration>,
193 ) -> CoolResult<()> {
194 let mut conn = self.conn.clone();
195 let full_key = self.full_key(key);
196 let value_str = serde_json::to_string(value)?;
197
198 match ttl {
199 Some(duration) => {
200 conn.set_ex::<_, _, ()>(&full_key, &value_str, duration.as_secs())
201 .await?;
202 }
203 None => {
204 conn.set::<_, _, ()>(&full_key, &value_str).await?;
205 }
206 }
207 Ok(())
208 }
209
210 async fn del(&self, key: &str) -> CoolResult<()> {
211 let mut conn = self.conn.clone();
212 let full_key = self.full_key(key);
213 conn.del::<_, ()>(&full_key).await?;
214 Ok(())
215 }
216
217 async fn exists(&self, key: &str) -> CoolResult<bool> {
218 let mut conn = self.conn.clone();
219 let full_key = self.full_key(key);
220 let exists: bool = conn.exists(&full_key).await?;
221 Ok(exists)
222 }
223
224 async fn clear(&self) -> CoolResult<()> {
225 let mut conn = self.conn.clone();
226 let pattern = self.full_key("*");
227 let keys: Vec<String> = conn.keys(&pattern).await?;
228 if !keys.is_empty() {
229 conn.del::<_, ()>(keys).await?;
230 }
231 Ok(())
232 }
233
234 async fn keys(&self, pattern: &str) -> CoolResult<Vec<String>> {
235 let mut conn = self.conn.clone();
236 let full_pattern = self.full_key(pattern);
237 let keys: Vec<String> = conn.keys(&full_pattern).await?;
238 let prefix_len = if self.prefix.is_empty() {
240 0
241 } else {
242 self.prefix.len() + 1
243 };
244 let keys: Vec<String> = keys
245 .into_iter()
246 .map(|k| k[prefix_len..].to_string())
247 .collect();
248 Ok(keys)
249 }
250}
251
252pub struct CacheFactory;
254
255impl CacheFactory {
256 pub fn memory() -> Arc<MemoryCache> {
258 Arc::new(MemoryCache::new())
259 }
260
261 pub async fn redis(url: &str, prefix: impl Into<String>) -> CoolResult<Arc<RedisCache>> {
263 let cache = RedisCache::new(url, prefix).await?;
264 Ok(Arc::new(cache))
265 }
266}