1use crate::error::Error;
4use crate::tagged::TaggedCache;
5use async_trait::async_trait;
6use serde::{de::DeserializeOwned, Serialize};
7use std::env;
8use std::future::Future;
9use std::sync::Arc;
10use std::time::Duration;
11
/// Settings shared by a [`Cache`] and the tagged views it creates via `Cache::tags`.
///
/// Build one with [`CacheConfig::new`] / [`Default`], the `with_*` builder
/// methods, or [`CacheConfig::from_env`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CacheConfig {
    /// TTL used by `Cache::put_default` when the caller does not supply one.
    pub default_ttl: Duration,
    /// Namespace that `Cache` joins to every key as `"{prefix}:{key}"`;
    /// an empty string means keys are used unchanged.
    pub prefix: String,
}
20
21impl Default for CacheConfig {
22 fn default() -> Self {
23 Self {
24 default_ttl: Duration::from_secs(3600),
25 prefix: String::new(),
26 }
27 }
28}
29
30impl CacheConfig {
31 pub fn new() -> Self {
33 Self::default()
34 }
35
36 pub fn from_env() -> Self {
50 let prefix = env::var("CACHE_PREFIX").unwrap_or_default();
51 let default_ttl = env::var("CACHE_TTL")
52 .ok()
53 .and_then(|v| v.parse().ok())
54 .map(Duration::from_secs)
55 .unwrap_or_else(|| Duration::from_secs(3600));
56
57 Self {
58 default_ttl,
59 prefix,
60 }
61 }
62
63 pub fn with_ttl(mut self, ttl: Duration) -> Self {
65 self.default_ttl = ttl;
66 self
67 }
68
69 pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
71 self.prefix = prefix.into();
72 self
73 }
74}
75
76#[async_trait]
78pub trait CacheStore: Send + Sync {
79 async fn get_raw(&self, key: &str) -> Result<Option<Vec<u8>>, Error>;
81
82 async fn put_raw(&self, key: &str, value: Vec<u8>, ttl: Duration) -> Result<(), Error>;
84
85 async fn has(&self, key: &str) -> Result<bool, Error>;
87
88 async fn forget(&self, key: &str) -> Result<bool, Error>;
90
91 async fn flush(&self) -> Result<(), Error>;
93
94 async fn increment(&self, key: &str, value: i64) -> Result<i64, Error>;
96
97 async fn decrement(&self, key: &str, value: i64) -> Result<i64, Error>;
99
100 async fn tag_add(&self, tag: &str, key: &str) -> Result<(), Error>;
102
103 async fn tag_members(&self, tag: &str) -> Result<Vec<String>, Error>;
105
106 async fn tag_flush(&self, tag: &str) -> Result<(), Error>;
108}
109
110#[derive(Clone)]
112pub struct Cache {
113 store: Arc<dyn CacheStore>,
114 config: CacheConfig,
115}
116
117impl Cache {
118 pub fn new(store: Arc<dyn CacheStore>) -> Self {
120 Self {
121 store,
122 config: CacheConfig::default(),
123 }
124 }
125
126 pub fn with_config(store: Arc<dyn CacheStore>, config: CacheConfig) -> Self {
128 Self { store, config }
129 }
130
131 #[cfg(feature = "memory")]
133 pub fn memory() -> Self {
134 Self::new(Arc::new(crate::stores::MemoryStore::new()))
135 }
136
137 #[cfg(feature = "redis-backend")]
139 pub async fn redis(url: &str) -> Result<Self, Error> {
140 let store = crate::stores::RedisStore::new(url).await?;
141 Ok(Self::new(Arc::new(store)))
142 }
143
144 #[cfg(feature = "memory")]
161 pub async fn from_env() -> Result<Self, Error> {
162 let driver = env::var("CACHE_DRIVER").unwrap_or_else(|_| "memory".to_string());
163 let config = CacheConfig::from_env();
164
165 let store: Arc<dyn CacheStore> = match driver.as_str() {
166 #[cfg(feature = "redis-backend")]
167 "redis" => {
168 let url =
169 env::var("REDIS_URL").unwrap_or_else(|_| "redis://127.0.0.1:6379".to_string());
170 Arc::new(crate::stores::RedisStore::new(&url).await?)
171 }
172 _ => {
173 let capacity = env::var("CACHE_MEMORY_CAPACITY")
175 .ok()
176 .and_then(|v| v.parse().ok())
177 .unwrap_or(10_000);
178 Arc::new(crate::stores::MemoryStore::with_capacity(capacity))
179 }
180 };
181
182 Ok(Self::with_config(store, config))
183 }
184
185 fn prefixed_key(&self, key: &str) -> String {
187 if self.config.prefix.is_empty() {
188 key.to_string()
189 } else {
190 format!("{}:{}", self.config.prefix, key)
191 }
192 }
193
194 pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, Error> {
196 let key = self.prefixed_key(key);
197 match self.store.get_raw(&key).await? {
198 Some(bytes) => {
199 let value = serde_json::from_slice(&bytes)
200 .map_err(|e| Error::deserialization(e.to_string()))?;
201 Ok(Some(value))
202 }
203 None => Ok(None),
204 }
205 }
206
207 pub async fn put<T: Serialize>(
209 &self,
210 key: &str,
211 value: &T,
212 ttl: Duration,
213 ) -> Result<(), Error> {
214 let key = self.prefixed_key(key);
215 let bytes = serde_json::to_vec(value).map_err(|e| Error::serialization(e.to_string()))?;
216 self.store.put_raw(&key, bytes, ttl).await
217 }
218
219 pub async fn put_default<T: Serialize>(&self, key: &str, value: &T) -> Result<(), Error> {
221 self.put(key, value, self.config.default_ttl).await
222 }
223
224 pub async fn forever<T: Serialize>(&self, key: &str, value: &T) -> Result<(), Error> {
226 self.put(key, value, Duration::from_secs(315_360_000)).await }
228
229 pub async fn has(&self, key: &str) -> Result<bool, Error> {
231 let key = self.prefixed_key(key);
232 self.store.has(&key).await
233 }
234
235 pub async fn forget(&self, key: &str) -> Result<bool, Error> {
237 let key = self.prefixed_key(key);
238 self.store.forget(&key).await
239 }
240
241 pub async fn flush(&self) -> Result<(), Error> {
243 self.store.flush().await
244 }
245
246 pub async fn remember<T, F, Fut>(&self, key: &str, ttl: Duration, f: F) -> Result<T, Error>
248 where
249 T: Serialize + DeserializeOwned,
250 F: FnOnce() -> Fut,
251 Fut: Future<Output = T>,
252 {
253 if let Some(value) = self.get(key).await? {
254 return Ok(value);
255 }
256
257 let value = f().await;
258 self.put(key, &value, ttl).await?;
259 Ok(value)
260 }
261
262 pub async fn remember_forever<T, F, Fut>(&self, key: &str, f: F) -> Result<T, Error>
264 where
265 T: Serialize + DeserializeOwned,
266 F: FnOnce() -> Fut,
267 Fut: Future<Output = T>,
268 {
269 self.remember(key, Duration::from_secs(315_360_000), f)
270 .await
271 }
272
273 pub async fn pull<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, Error> {
275 let value = self.get(key).await?;
276 if value.is_some() {
277 self.forget(key).await?;
278 }
279 Ok(value)
280 }
281
282 pub async fn increment(&self, key: &str, value: i64) -> Result<i64, Error> {
284 let key = self.prefixed_key(key);
285 self.store.increment(&key, value).await
286 }
287
288 pub async fn decrement(&self, key: &str, value: i64) -> Result<i64, Error> {
290 let key = self.prefixed_key(key);
291 self.store.decrement(&key, value).await
292 }
293
294 pub fn tags(&self, tags: &[&str]) -> TaggedCache {
296 TaggedCache::new(
297 self.store.clone(),
298 tags.iter().map(|s| s.to_string()).collect(),
299 self.config.clone(),
300 )
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_config_builder() {
        let config = CacheConfig::new()
            .with_ttl(Duration::from_secs(1800))
            .with_prefix("myapp");

        assert_eq!(config.default_ttl, Duration::from_secs(1800));
        assert_eq!(config.prefix, "myapp");
    }

    #[test]
    fn test_cache_config_default() {
        let config = CacheConfig::default();
        assert_eq!(config.default_ttl, Duration::from_secs(3600));
        assert!(config.prefix.is_empty());
    }

    // Everywhere else in this file `MemoryStore` is only referenced under
    // `#[cfg(feature = "memory")]`, so it presumably exists only with that
    // feature; gate these tests so `--no-default-features` test builds compile.
    #[cfg(feature = "memory")]
    #[test]
    fn test_prefixed_key() {
        let config = CacheConfig::new().with_prefix("test");
        let cache = Cache::with_config(Arc::new(crate::stores::MemoryStore::new()), config);

        assert_eq!(cache.prefixed_key("key"), "test:key");
    }

    #[cfg(feature = "memory")]
    #[test]
    fn test_prefixed_key_no_prefix() {
        let cache = Cache::new(Arc::new(crate::stores::MemoryStore::new()));
        assert_eq!(cache.prefixed_key("key"), "key");
    }
}