1#![warn(
2 clippy::all,
3 clippy::dbg_macro,
4 clippy::todo,
5 clippy::empty_enum,
6 clippy::enum_glob_use,
7 clippy::mem_forget,
8 clippy::unused_self,
9 clippy::filter_map_next,
10 clippy::needless_continue,
11 clippy::needless_borrow,
12 clippy::match_wildcard_for_single_variants,
13 clippy::if_let_mutex,
14 clippy::mismatched_target_os,
15 clippy::await_holding_lock,
16 clippy::match_on_vec_items,
17 clippy::imprecise_flops,
18 clippy::suboptimal_flops,
19 clippy::lossy_float_literal,
20 clippy::rest_pat_in_fully_bound_structs,
21 clippy::fn_params_excessive_bools,
22 clippy::exit,
23 clippy::inefficient_to_string,
24 clippy::linkedlist,
25 clippy::macro_use_imports,
26 clippy::option_option,
27 clippy::verbose_file_reads,
28 clippy::unnested_or_patterns,
29 clippy::str_to_string,
30 rust_2018_idioms,
31 future_incompatible,
32 nonstandard_style
33)]
34#![deny(unreachable_pub, private_in_public)]
35
36use std::fmt::Debug;
37use std::future::Future;
38use std::hash::Hash;
39use std::sync::atomic::{AtomicU64, Ordering};
40use std::sync::Arc;
41use std::time::{Duration, Instant};
42
43use parking_lot::Mutex;
44use schnellru::{ByLength, LruMap};
45use tokio::sync::watch;
46
47pub struct DedCache<Req, Res, Err>(Arc<SharedState<Req, Res, Err>>);
49
50impl<Req, Res, Err> Clone for DedCache<Req, Res, Err> {
51 #[inline]
52 fn clone(&self) -> Self {
53 Self(self.0.clone())
54 }
55}
56
57impl<Req, Res, Err> DedCache<Req, Res, Err>
58where
59 Req: Hash + Eq + Clone + Debug,
60 Res: Clone,
61 Err: Clone,
62{
63 pub fn new(lifetime: Duration, size: u32) -> Self {
67 Self(Arc::new(SharedState::new(lifetime, size)))
68 }
69
70 pub async fn get_or_update<F, Fut>(&self, key: Req, f: F) -> Result<Res, CoalesceError<Err>>
81 where
82 F: FnOnce() -> Fut,
83 Fut: Future<Output = Result<Res, Err>>,
84 {
85 self.0.get_or_update(key, f).await
86 }
87
88 pub fn fetch_stats(&self) -> Stats {
90 self.0.fetch_stats()
91 }
92}
93
94struct SharedState<Req, Res, Err> {
95 cache: RequestLru<Req, Res, Err>,
96 lifetime: Duration,
97 total_requests: AtomicU64,
98 calls_made: AtomicU64,
99 cache_hit: AtomicU64,
100}
101
102impl<Req, Res, Err> SharedState<Req, Res, Err>
103where
104 Req: Hash + Eq + Clone + Debug,
105 Res: Clone,
106 Err: Clone,
107{
108 fn new(lifetime: Duration, size: u32) -> Self {
109 Self {
110 cache: Mutex::new(LruMap::new(ByLength::new(size))),
111 lifetime,
112 total_requests: Default::default(),
113 cache_hit: Default::default(),
114 calls_made: Default::default(),
115 }
116 }
117
118 async fn get_or_update<F, Fut>(&self, key: Req, f: F) -> Result<Res, CoalesceError<Err>>
119 where
120 F: FnOnce() -> Fut,
121 Fut: Future<Output = Result<Res, Err>>,
122 {
123 struct RemoveWatchOnDrop<'a, Req, Res, Err>
124 where
125 Req: Hash + Eq,
126 {
127 key: Option<&'a Req>,
128 cache: &'a RequestLru<Req, Res, Err>,
129 }
130
131 impl<Req, Res, Err> RemoveWatchOnDrop<'_, Req, Res, Err>
132 where
133 Req: Hash + Eq,
134 {
135 fn disarm(mut self) {
136 self.key = None;
137 }
138 }
139
140 impl<Req, Res, Err> Drop for RemoveWatchOnDrop<'_, Req, Res, Err>
141 where
142 Req: Hash + Eq,
143 {
144 fn drop(&mut self) {
145 if let Some(key) = self.key.take() {
146 self.cache.lock().remove(key);
147 }
148 }
149 }
150
151 enum Task {
152 New,
153 Existing,
154 }
155
156 self.update_request_number();
157
158 let (tx, task) = 'task: {
160 let mut cache = self.cache.lock();
161
162 if let Some(entry) = cache.get(&key) {
164 let result = entry.borrow();
165 if let Some(result) = &*result {
166 if result.since.elapsed() <= self.lifetime {
168 self.update_cache_hit();
169 return result.data.clone().map_err(CoalesceError::Indirect);
170 }
171 } else {
172 break 'task (entry.clone(), Task::Existing);
173 }
174 }
175
176 let (tx, _) = watch::channel(None);
178 let tx = Arc::new(tx);
179 cache.insert(key.clone(), tx.clone());
180 (tx, Task::New)
181 };
182
183 let drop_guard = RemoveWatchOnDrop {
185 key: Some(&key),
186 cache: &self.cache,
187 };
188
189 match task {
191 Task::New => {
192 let result = f().await;
194 self.update_calls_number();
195
196 if result.is_ok() {
197 drop_guard.disarm();
199 }
200
201 tx.send_modify(|value| {
203 *value = Some(Entry {
204 data: result.clone(),
205 since: Instant::now(),
206 })
207 });
208
209 result.map_err(CoalesceError::Direct)
211 }
212 Task::Existing => {
213 let mut rx = tx.subscribe();
214
215 {
217 let result = rx.borrow();
218 if let Some(result) = &*result {
219 return result.data.clone().map_err(CoalesceError::Indirect);
220 }
221 }
222
223 rx.changed().await.unwrap();
225
226 let result = rx.borrow();
227 let result = result.as_ref().unwrap().data.clone();
228
229 self.update_cache_hit();
230 result.map_err(CoalesceError::Indirect)
231 }
232 }
233 }
234
235 fn fetch_stats(&self) -> Stats {
236 let (memory_usage, len) = {
237 let cache = self.cache.lock();
238 (cache.memory_usage(), cache.len())
239 };
240
241 let total_requests = self.total_requests.load(Ordering::Relaxed);
242 let calls_made = self.calls_made.load(Ordering::Relaxed);
243 let cache_hit = self.cache_hit.load(Ordering::Relaxed);
244
245 let cache_hit_ratio = if total_requests == 0 {
246 0.0
247 } else {
248 cache_hit as f64 / total_requests as f64
249 };
250
251 Stats {
252 total_requests,
253 cache_hit,
254 memory_usage,
255 len,
256 cache_hit_ratio,
257 calls_made,
258 }
259 }
260
261 fn update_request_number(&self) {
262 self.total_requests.fetch_add(1, Ordering::Relaxed);
263 }
264
265 fn update_calls_number(&self) {
266 self.calls_made.fetch_add(1, Ordering::Relaxed);
267 }
268
269 fn update_cache_hit(&self) {
270 self.cache_hit.fetch_add(1, Ordering::Relaxed);
271 }
272}
273
274#[derive(thiserror::Error, Debug)]
275pub enum CoalesceError<E> {
276 #[error("request failed")]
278 Direct(#[source] E),
279
280 #[error("inflight request failed")]
282 Indirect(#[source] E),
283}
284
285impl<E> CoalesceError<E> {
286 pub fn into_inner(self) -> E {
287 match self {
288 Self::Direct(e) => e,
289 Self::Indirect(e) => e,
290 }
291 }
292}
293
/// Point-in-time snapshot of a cache's size and request counters.
#[derive(Clone, Debug)]
pub struct Stats {
    /// Memory used by the LRU map, as reported by its `memory_usage` method.
    pub memory_usage: usize,
    /// Number of keys currently held in the LRU map.
    pub len: usize,

    /// Total number of `get_or_update` calls received.
    pub total_requests: u64,
    /// Number of times the user-supplied function ran to completion.
    pub calls_made: u64,
    /// Number of calls answered with a result they did not compute themselves.
    pub cache_hit: u64,
    /// `cache_hit / total_requests`, or `0.0` before any request was made.
    pub cache_hit_ratio: f64,
}
308
/// A completed computation's outcome together with the moment it was stored.
struct Entry<T, E> {
    /// Outcome of the computation.
    data: Result<T, E>,
    /// When the outcome was stored; compared against the cache lifetime.
    since: Instant,
}
313
314type ResultTx<T, E> = watch::Sender<Option<Entry<T, E>>>;
315type RequestLru<K, V, E> = Mutex<LruMap<K, Arc<ResultTx<V, E>>, ByLength>>;
316
#[cfg(test)]
mod test {
    use std::convert::Infallible;
    use std::time::Duration;

    use super::*;

    /// Awaits `work` and reports how long it took.
    async fn timed<T>(work: impl Future<Output = T>) -> (T, Duration) {
        let started = Instant::now();
        let output = work.await;
        (output, started.elapsed())
    }

    #[tokio::test]
    async fn test_cache() {
        let cache = DedCache::<_, _, Infallible>::new(Duration::from_secs(1), 1024);
        let key = "key";

        // Cold cache: the value has to be computed.
        assert_eq!(cache.get_or_update(key, fut).await.unwrap(), "value");

        // Warm cache: answered without running `fut` again.
        let (value, took) = timed(cache.get_or_update(key, fut)).await;
        assert_eq!(value.unwrap(), "value");
        assert!(took < Duration::from_secs(1));

        // Let the entry outlive its one-second lifetime.
        tokio::time::sleep(Duration::from_secs(2)).await;
        {
            let mut lru = cache.0.cache.lock();
            let tx = lru.get(&key).unwrap();
            let published = tx.borrow();
            let entry = published.as_ref().unwrap();
            assert!(entry.since.elapsed() > Duration::from_secs(1));
        }

        // Expired entry: the value is computed again.
        let (value, took) = timed(cache.get_or_update(key, fut)).await;
        assert_eq!(value.unwrap(), "value");
        assert!(took > Duration::from_secs(1));
    }

    async fn fut() -> Result<&'static str, Infallible> {
        tokio::time::sleep(Duration::from_secs(2)).await;
        Ok("value")
    }

    /// Spawns a task looking up `key`, computing `1337` after `delay` on a miss.
    fn spawn_lookup(
        cache: &DedCache<i32, u32, Infallible>,
        key: i32,
        delay: Duration,
    ) -> tokio::task::JoinHandle<Result<u32, CoalesceError<Infallible>>> {
        let cache = cache.clone();
        tokio::spawn(async move { cache.get_or_update(key, || fut2(delay, 1337)).await })
    }

    #[tokio::test]
    async fn test_with_eviction() {
        let cache = DedCache::<_, _, Infallible>::new(Duration::from_secs(1), 2);

        let (key1, key2, key3) = (1, 2, 3);

        let value1 = spawn_lookup(&cache, key1, Duration::from_secs(1));
        let value2 = spawn_lookup(&cache, key2, Duration::from_secs(0));
        let value3 = spawn_lookup(&cache, key3, Duration::from_secs(0));
        let value1_second_get = spawn_lookup(&cache, key1, Duration::from_secs(0));

        println!("Val1 = {:?}", value1.await.unwrap().unwrap());
        println!("Val2 = {:?}", value2.await.unwrap().unwrap());
        println!("Val3 = {:?}", value3.await.unwrap().unwrap());

        println!("Val1 = {:?}", value1_second_get.await.unwrap().unwrap());

        for (i, entry) in cache.0.cache.lock().iter() {
            println!("{i}: {:?}", entry.borrow().is_some());
        }
    }

    async fn fut2(time: Duration, retval: u32) -> Result<u32, Infallible> {
        tokio::time::sleep(time).await;
        println!("fut2 finished after {time:?}");
        Ok(retval)
    }

    #[tokio::test]
    async fn test_under_load() {
        let cache = DedCache::<_, _, Infallible>::new(Duration::from_secs(1), 100);

        let start = Instant::now();
        // 100 keys, each requested by 100 concurrently spawned tasks.
        let groups: Vec<Vec<_>> = (0..100)
            .map(|key| {
                (0..100)
                    .map(|_| {
                        let cache = cache.clone();
                        tokio::spawn(async move {
                            cache
                                .get_or_update(key, || fut2(Duration::from_secs(1), key as u32))
                                .await
                        })
                    })
                    .collect()
            })
            .collect();

        for (group_id, handles) in groups.into_iter().enumerate() {
            for handle in handles {
                let res = handle.await.unwrap().unwrap();
                assert_eq!(res, group_id as u32);
            }
        }

        assert!(start.elapsed() < Duration::from_secs(6));
        let stats = cache.fetch_stats();
        println!("Stats: {:?}", stats);
        assert!(stats.cache_hit_ratio > 0.9);
        assert_eq!(stats.calls_made, 100);
    }
}