1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6use tokio::sync::RwLock;
7pub mod redis;
8pub use crate::redis::RedisCache;
9
/// Result type shared by the whole cache API.
///
/// Errors are boxed as thread-safe trait objects so failures from different
/// sources (for example `serde_json` encode/decode errors) can be propagated
/// with `?` without a dedicated error enum.
pub type Result<T> =
    std::result::Result<T, Box<dyn std::error::Error + Send + Sync + 'static>>;
11
12#[async_trait]
14pub trait Cache: Send + Sync {
15 async fn get<T>(&self, key: &str) -> Result<Option<T>>
16 where
17 T: for<'de> Deserialize<'de> + Send;
18
19 async fn set<T>(&self, key: &str, value: &T, ttl: Option<Duration>) -> Result<()>
20 where
21 T: Serialize + Send + Sync;
22
23 async fn delete(&self, key: &str) -> Result<()>;
24
25 async fn exists(&self, key: &str) -> Result<bool>;
26
27 async fn flush(&self) -> Result<()>;
28}
29
/// One stored value: its serialized bytes plus an optional expiry deadline.
#[derive(Clone)]
struct CacheEntry {
    /// Serialized bytes of the cached value.
    data: Vec<u8>,
    /// Instant at which the entry becomes expired; `None` means it never expires.
    expires_at: Option<Instant>,
}

impl CacheEntry {
    /// Wraps `data`, computing the expiry deadline as "now + `ttl`".
    ///
    /// A TTL so large that the deadline cannot be represented as an `Instant`
    /// (e.g. `Duration::MAX`) is treated as "never expires" instead of
    /// panicking on overflow, which `Instant + Duration` would do.
    fn new(data: Vec<u8>, ttl: Option<Duration>) -> Self {
        let expires_at = ttl.and_then(|d| Instant::now().checked_add(d));
        Self { data, expires_at }
    }

    /// Returns `true` once the deadline has been reached.
    ///
    /// An entry is valid on the half-open interval `[created, created + ttl)`,
    /// so a zero TTL is expired immediately. Entries without a deadline never expire.
    fn is_expired(&self) -> bool {
        matches!(self.expires_at, Some(deadline) if Instant::now() >= deadline)
    }
}
47
48pub struct MemoryCache {
50 store: Arc<RwLock<HashMap<String, CacheEntry>>>,
51 default_ttl: Option<Duration>,
52}
53
54impl MemoryCache {
55 pub fn new() -> Self {
56 Self {
57 store: Arc::new(RwLock::new(HashMap::new())),
58 default_ttl: Some(Duration::from_secs(3600)),
59 }
60 }
61
62 pub fn with_default_ttl(ttl: Duration) -> Self {
63 Self {
64 store: Arc::new(RwLock::new(HashMap::new())),
65 default_ttl: Some(ttl),
66 }
67 }
68
69 pub async fn remember<T, F, Fut>(&self, key: &str, ttl: Duration, f: F) -> Result<T>
71 where
72 T: Serialize + for<'de> Deserialize<'de> + Send + Sync,
73 F: FnOnce() -> Fut + Send,
74 Fut: std::future::Future<Output = Result<T>> + Send,
75 {
76 if let Some(value) = self.get::<T>(key).await? {
78 return Ok(value);
79 }
80
81 let value = f().await?;
83 self.set(key, &value, Some(ttl)).await?;
84 Ok(value)
85 }
86
87 async fn cleanup(&self) {
89 let mut store = self.store.write().await;
90 store.retain(|_, entry| !entry.is_expired());
91 }
92}
93
94impl Default for MemoryCache {
95 fn default() -> Self {
96 Self::new()
97 }
98}
99
100#[async_trait]
101impl Cache for MemoryCache {
102 async fn get<T>(&self, key: &str) -> Result<Option<T>>
103 where
104 T: for<'de> Deserialize<'de> + Send,
105 {
106 let store = self.store.read().await;
107
108 if let Some(entry) = store.get(key) {
109 if entry.is_expired() {
110 return Ok(None);
111 }
112
113 let value: T = serde_json::from_slice(&entry.data)?;
114 Ok(Some(value))
115 } else {
116 Ok(None)
117 }
118 }
119
120 async fn set<T>(&self, key: &str, value: &T, ttl: Option<Duration>) -> Result<()>
121 where
122 T: Serialize + Send + Sync,
123 {
124 let data = serde_json::to_vec(value)?;
125 let ttl = ttl.or(self.default_ttl);
126 let entry = CacheEntry::new(data, ttl);
127
128 let mut store = self.store.write().await;
129 store.insert(key.to_string(), entry);
130
131 Ok(())
132 }
133
134 async fn delete(&self, key: &str) -> Result<()> {
135 let mut store = self.store.write().await;
136 store.remove(key);
137 Ok(())
138 }
139
140 async fn exists(&self, key: &str) -> Result<bool> {
141 let store = self.store.read().await;
142 Ok(store.get(key).map(|e| !e.is_expired()).unwrap_or(false))
143 }
144
145 async fn flush(&self) -> Result<()> {
146 let mut store = self.store.write().await;
147 store.clear();
148 Ok(())
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// A value written without an explicit TTL can be read straight back.
    #[tokio::test]
    async fn test_set_and_get() {
        let mem = MemoryCache::new();
        mem.set("key1", &"value1", None).await.unwrap();

        let fetched: Option<String> = mem.get("key1").await.unwrap();
        assert_eq!(fetched.as_deref(), Some("value1"));
    }

    /// Once its TTL has elapsed, an entry is reported as a miss.
    #[tokio::test]
    async fn test_expiration() {
        let mem = MemoryCache::new();
        let short_ttl = Duration::from_millis(100);
        mem.set("key1", &"value1", Some(short_ttl)).await.unwrap();

        tokio::time::sleep(Duration::from_millis(200)).await;

        let fetched: Option<String> = mem.get("key1").await.unwrap();
        assert!(fetched.is_none());
    }

    /// `remember` runs the producer on a miss and serves later calls from the cache.
    #[tokio::test]
    async fn test_remember() {
        let mem = MemoryCache::new();
        let mut invocations = 0;

        let first = mem
            .remember("key1", Duration::from_secs(60), || async {
                invocations += 1;
                Ok::<_, Box<dyn std::error::Error + Send + Sync>>("computed".to_string())
            })
            .await
            .unwrap();
        assert_eq!(first, "computed");
        assert_eq!(invocations, 1);

        // The second producer must not run: the value is already cached.
        let second = mem
            .remember("key1", Duration::from_secs(60), || async {
                invocations += 1;
                Ok::<_, Box<dyn std::error::Error + Send + Sync>>("computed".to_string())
            })
            .await
            .unwrap();
        assert_eq!(second, "computed");
        assert_eq!(invocations, 1);
    }
}