1use std::time::Duration;
41use serde::{Serialize, de::DeserializeOwned};
42use async_trait::async_trait;
43use thiserror::Error;
44
45pub mod backends;
46pub mod config;
47pub mod tagging;
48pub mod invalidation;
49pub mod warming;
50
51#[cfg(feature = "http-cache")]
52pub mod http_cache;
53
54#[cfg(feature = "http-cache")]
55pub mod middleware;
56
57pub use backends::*;
58pub use config::*;
59pub use tagging::*;
60pub use invalidation::*;
61pub use warming::*;
62
63#[derive(Error, Debug)]
65pub enum CacheError {
66 #[error("Serialization error: {0}")]
67 Serialization(#[from] serde_json::Error),
68
69 #[error("Backend error: {0}")]
70 Backend(String),
71
72 #[error("Key not found: {0}")]
73 KeyNotFound(String),
74
75 #[error("Cache configuration error: {0}")]
76 Configuration(String),
77
78 #[error("Network error: {0}")]
79 Network(String),
80
81 #[error("Timeout error")]
82 Timeout,
83}
84
85pub type CacheResult<T> = Result<T, CacheError>;
87
88pub type CacheKey = String;
90
91pub type CacheTag = String;
93
94#[async_trait]
96pub trait CacheBackend: Send + Sync {
97 async fn get(&self, key: &str) -> CacheResult<Option<Vec<u8>>>;
99
100 async fn put(&self, key: &str, value: Vec<u8>, ttl: Option<Duration>) -> CacheResult<()>;
102
103 async fn forget(&self, key: &str) -> CacheResult<bool>;
105
106 async fn exists(&self, key: &str) -> CacheResult<bool>;
108
109 async fn flush(&self) -> CacheResult<()>;
111
112 async fn get_many(&self, keys: &[&str]) -> CacheResult<Vec<Option<Vec<u8>>>> {
114 let mut results = Vec::with_capacity(keys.len());
115 for key in keys {
116 results.push(self.get(key).await?);
117 }
118 Ok(results)
119 }
120
121 async fn put_many(&self, entries: &[(&str, Vec<u8>, Option<Duration>)]) -> CacheResult<()> {
123 for (key, value, ttl) in entries {
124 self.put(key, value.clone(), *ttl).await?;
125 }
126 Ok(())
127 }
128
129 async fn forget_many(&self, keys: &[&str]) -> CacheResult<usize> {
131 let mut removed_count = 0;
132 for key in keys {
133 if self.forget(key).await? {
134 removed_count += 1;
135 }
136 }
137 Ok(removed_count)
138 }
139
140 async fn stats(&self) -> CacheResult<CacheStats> {
142 Ok(CacheStats::default())
143 }
144}
145
/// Point-in-time statistics reported by a cache backend.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Number of lookups that found a value.
    pub hits: u64,
    /// Number of lookups that found nothing.
    pub misses: u64,
    /// Number of keys currently stored.
    pub total_keys: u64,
    /// Memory used by the cache, as reported by the backend.
    pub memory_usage: u64,
}

impl CacheStats {
    /// Fraction of lookups that were hits, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when no lookups have been recorded at all.
    pub fn hit_ratio(&self) -> f64 {
        let total = self.hits + self.misses;
        if total == 0 {
            0.0
        } else {
            self.hits as f64 / total as f64
        }
    }
}
164
/// High-level, typed cache facade over a [`CacheBackend`].
///
/// Handles JSON (de)serialization of values and an optional default TTL.
/// The trait bound lives on the `impl` block rather than the struct
/// definition, per Rust convention (bounds on struct definitions force
/// every mention of the type to repeat them without adding safety).
pub struct Cache<B> {
    /// Storage backend every operation is delegated to.
    backend: B,
    /// TTL applied by `put_default` / `remember_default` when set.
    default_ttl: Option<Duration>,
}
170
171impl<B: CacheBackend> Cache<B> {
172 pub fn new(backend: B) -> Self {
174 Self {
175 backend,
176 default_ttl: None,
177 }
178 }
179
180 pub fn with_default_ttl(backend: B, ttl: Duration) -> Self {
182 Self {
183 backend,
184 default_ttl: Some(ttl),
185 }
186 }
187
188 pub async fn get<T>(&self, key: &str) -> CacheResult<Option<T>>
190 where
191 T: DeserializeOwned,
192 {
193 match self.backend.get(key).await? {
194 Some(bytes) => {
195 let value = serde_json::from_slice(&bytes)?;
196 Ok(Some(value))
197 }
198 None => Ok(None),
199 }
200 }
201
202 pub async fn put<T>(&self, key: &str, value: &T, ttl: Duration) -> CacheResult<()>
204 where
205 T: Serialize,
206 {
207 let bytes = serde_json::to_vec(value)?;
208 self.backend.put(key, bytes, Some(ttl)).await
209 }
210
211 pub async fn put_default<T>(&self, key: &str, value: &T) -> CacheResult<()>
213 where
214 T: Serialize,
215 {
216 let bytes = serde_json::to_vec(value)?;
217 self.backend.put(key, bytes, self.default_ttl).await
218 }
219
220 pub async fn forget(&self, key: &str) -> CacheResult<bool> {
222 self.backend.forget(key).await
223 }
224
225 pub async fn exists(&self, key: &str) -> CacheResult<bool> {
227 self.backend.exists(key).await
228 }
229
230 pub async fn flush(&self) -> CacheResult<()> {
232 self.backend.flush().await
233 }
234
235 pub async fn remember<T, F, Fut>(&self, key: &str, ttl: Duration, compute: F) -> CacheResult<T>
237 where
238 T: Serialize + DeserializeOwned,
239 F: FnOnce() -> Fut,
240 Fut: std::future::Future<Output = T>,
241 {
242 if let Some(cached) = self.get(key).await? {
243 return Ok(cached);
244 }
245
246 let value = compute().await;
247 self.put(key, &value, ttl).await?;
248 Ok(value)
249 }
250
251 pub async fn remember_default<T, F, Fut>(&self, key: &str, compute: F) -> CacheResult<T>
253 where
254 T: Serialize + DeserializeOwned,
255 F: FnOnce() -> Fut,
256 Fut: std::future::Future<Output = T>,
257 {
258 if let Some(cached) = self.get(key).await? {
259 return Ok(cached);
260 }
261
262 let value = compute().await;
263
264 if let Some(ttl) = self.default_ttl {
265 self.put(key, &value, ttl).await?;
266 } else {
267 return Err(CacheError::Configuration("No default TTL configured".to_string()));
268 }
269
270 Ok(value)
271 }
272
273 pub async fn stats(&self) -> CacheResult<CacheStats> {
275 self.backend.stats().await
276 }
277}
278
279static GLOBAL_CACHE: std::sync::OnceLock<Box<dyn CacheBackend>> = std::sync::OnceLock::new();
281
282pub fn set_global_cache<B: CacheBackend + 'static>(backend: B) -> Result<(), Box<dyn CacheBackend>> {
284 GLOBAL_CACHE.set(Box::new(backend))
285}
286
287pub fn global_cache() -> Option<&'static Box<dyn CacheBackend>> {
289 GLOBAL_CACHE.get()
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backends::MemoryBackend;
    use std::time::Duration;

    /// End-to-end check of put / get / exists / forget against the
    /// in-memory backend.
    #[tokio::test]
    async fn test_cache_basic_operations() {
        let cache = Cache::new(MemoryBackend::new(CacheConfig::default()));

        cache
            .put("test_key", &"test_value", Duration::from_secs(60))
            .await
            .unwrap();
        assert_eq!(
            cache.get::<String>("test_key").await.unwrap(),
            Some("test_value".to_string())
        );

        assert!(cache.exists("test_key").await.unwrap());
        assert!(!cache.exists("nonexistent").await.unwrap());

        assert!(cache.forget("test_key").await.unwrap());
        assert_eq!(cache.get::<String>("test_key").await.unwrap(), None);
    }

    /// `remember` must compute on the first call and serve the cached value
    /// (without recomputing) on subsequent calls.
    #[tokio::test]
    async fn test_cache_remember_pattern() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicU32, Ordering};

        let cache = Cache::new(MemoryBackend::new(CacheConfig::default()));
        let calls = Arc::new(AtomicU32::new(0));

        let counter = calls.clone();
        let first = cache
            .remember("remember_test", Duration::from_secs(60), move || {
                let n = counter.fetch_add(1, Ordering::Relaxed) + 1;
                async move { format!("computed_{}", n) }
            })
            .await
            .unwrap();
        assert_eq!(first, "computed_1");

        let second = cache
            .remember("remember_test", Duration::from_secs(60), || async {
                "should_not_be_called".to_string()
            })
            .await
            .unwrap();
        assert_eq!(second, "computed_1");

        // The compute closure ran exactly once.
        assert_eq!(calls.load(Ordering::Relaxed), 1);
    }

    /// `put_default` must honor the TTL configured at construction time.
    #[tokio::test]
    async fn test_cache_with_default_ttl() {
        let cache = Cache::with_default_ttl(
            MemoryBackend::new(CacheConfig::default()),
            Duration::from_secs(3600),
        );

        cache.put_default("default_ttl_test", &42i32).await.unwrap();
        assert_eq!(
            cache.get::<i32>("default_ttl_test").await.unwrap(),
            Some(42)
        );
    }
}
353}