1use std::future::Future;
26use std::sync::Arc;
27use std::time::{Duration, Instant};
28use tokio::sync::RwLock;
29use crate::policy::Policy;
30
/// A cached value paired with the deadline after which it is considered stale.
#[derive(Clone)]
struct CachedValue<T> {
    /// The cached payload.
    value: T,
    /// Deadline at which the entry becomes expired, or `None` when the TTL is
    /// too large for `Instant + ttl` to be representable (the entry never expires).
    expires_at: Option<Instant>,
}

impl<T> CachedValue<T> {
    /// Wraps `value` so that it expires `ttl` after the current instant.
    ///
    /// A `ttl` too large to add to the current instant (e.g. `Duration::MAX`)
    /// is treated as "never expires". The plain `Instant + Duration` operator
    /// panics on overflow instead.
    fn new(value: T, ttl: Duration) -> Self {
        Self {
            value,
            expires_at: Instant::now().checked_add(ttl),
        }
    }

    /// Returns `true` once the current instant has reached the deadline.
    ///
    /// Because the comparison is `>=`, a zero TTL yields an entry that is
    /// already expired.
    fn is_expired(&self) -> bool {
        matches!(self.expires_at, Some(deadline) if Instant::now() >= deadline)
    }
}
50
51pub struct Cache<T> {
82 ttl: Duration,
83 cached: Arc<RwLock<Option<CachedValue<T>>>>,
84}
85
86impl<T> Clone for Cache<T> {
87 fn clone(&self) -> Self {
88 Self {
89 ttl: self.ttl,
90 cached: Arc::clone(&self.cached),
91 }
92 }
93}
94
95impl<T> Cache<T>
96where
97 T: Clone + Send + Sync,
98{
99 pub fn new(ttl: Duration) -> Self {
118 Self {
119 ttl,
120 cached: Arc::new(RwLock::new(None)),
121 }
122 }
123
124 pub async fn invalidate(&self) {
128 let mut cached = self.cached.write().await;
129 *cached = None;
130 }
131
132 pub async fn has_cached_value(&self) -> bool {
134 let cached = self.cached.read().await;
135 matches!(&*cached, Some(cv) if !cv.is_expired())
136 }
137}
138
139#[async_trait::async_trait]
140impl<T, E> Policy<E> for Cache<T>
141where
142 T: Clone + Send + Sync,
143 E: Send + Sync,
144{
145 async fn execute<F, Fut, R>(&self, f: F) -> Result<R, E>
146 where
147 F: Fn() -> Fut + Send + Sync,
148 Fut: Future<Output = Result<R, E>> + Send,
149 R: Send,
150 {
151 f().await
154 }
155}
156
157pub struct TypedCache<T> {
159 ttl: Duration,
160 cached: Arc<RwLock<Option<CachedValue<T>>>>,
161}
162
163impl<T: Clone> Clone for TypedCache<T> {
164 fn clone(&self) -> Self {
165 Self {
166 ttl: self.ttl,
167 cached: Arc::clone(&self.cached),
168 }
169 }
170}
171
172impl<T> TypedCache<T>
173where
174 T: Clone + Send + Sync,
175{
176 pub fn new(ttl: Duration) -> Self {
178 Self {
179 ttl,
180 cached: Arc::new(RwLock::new(None)),
181 }
182 }
183
184 pub async fn execute<F, Fut, E>(&self, f: F) -> Result<T, E>
186 where
187 F: Fn() -> Fut + Send + Sync,
188 Fut: Future<Output = Result<T, E>> + Send,
189 E: Send + Sync,
190 {
191 {
193 let cached = self.cached.read().await;
194 if let Some(cv) = &*cached {
195 if !cv.is_expired() {
196 return Ok(cv.value.clone());
197 }
198 }
199 }
200
201 let result = f().await?;
203 {
204 let mut cached = self.cached.write().await;
205 *cached = Some(CachedValue::new(result.clone(), self.ttl));
206 }
207 Ok(result)
208 }
209
210 pub async fn invalidate(&self) {
212 let mut cached = self.cached.write().await;
213 *cached = None;
214 }
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[tokio::test]
    async fn test_typed_cache_caches_result() {
        let cache = TypedCache::<String>::new(Duration::from_secs(60));
        let call_count = Arc::new(AtomicUsize::new(0));

        let cc = Arc::clone(&call_count);
        let result = cache
            .execute(|| {
                let count = Arc::clone(&cc);
                async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, String>("result".to_string())
                }
            })
            .await;
        assert_eq!(result.unwrap(), "result");
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        // A second call within the TTL must return the cached value without running the closure.
        let cc = Arc::clone(&call_count);
        let result = cache
            .execute(|| {
                let count = Arc::clone(&cc);
                async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, String>("new_result".to_string())
                }
            })
            .await;
        assert_eq!(result.unwrap(), "result");
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_typed_cache_invalidate() {
        let cache = TypedCache::<String>::new(Duration::from_secs(60));
        let call_count = Arc::new(AtomicUsize::new(0));

        let cc = Arc::clone(&call_count);
        let _ = cache
            .execute(|| {
                let count = Arc::clone(&cc);
                async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, String>("first".to_string())
                }
            })
            .await;
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        cache.invalidate().await;

        let cc = Arc::clone(&call_count);
        let result = cache
            .execute(|| {
                let count = Arc::clone(&cc);
                async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, String>("second".to_string())
                }
            })
            .await;
        assert_eq!(result.unwrap(), "second");
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_typed_cache_does_not_cache_errors() {
        let cache = TypedCache::<String>::new(Duration::from_secs(60));
        let call_count = Arc::new(AtomicUsize::new(0));

        let cc = Arc::clone(&call_count);
        let result = cache
            .execute(|| {
                let count = Arc::clone(&cc);
                async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    Err::<String, _>("boom".to_string())
                }
            })
            .await;
        assert_eq!(result.unwrap_err(), "boom");

        // The failure must not have been stored: the next call runs the closure again.
        let cc = Arc::clone(&call_count);
        let result = cache
            .execute(|| {
                let count = Arc::clone(&cc);
                async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, String>("recovered".to_string())
                }
            })
            .await;
        assert_eq!(result.unwrap(), "recovered");
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_typed_cache_zero_ttl_always_recomputes() {
        let cache = TypedCache::<String>::new(Duration::ZERO);
        let call_count = Arc::new(AtomicUsize::new(0));

        for expected_calls in 1..=2 {
            let cc = Arc::clone(&call_count);
            let result = cache
                .execute(|| {
                    let count = Arc::clone(&cc);
                    async move {
                        count.fetch_add(1, Ordering::SeqCst);
                        Ok::<_, String>("fresh".to_string())
                    }
                })
                .await;
            assert_eq!(result.unwrap(), "fresh");
            assert_eq!(call_count.load(Ordering::SeqCst), expected_calls);
        }
    }
}