// ded/lib.rs

1#![warn(
2    clippy::all,
3    clippy::dbg_macro,
4    clippy::todo,
5    clippy::empty_enum,
6    clippy::enum_glob_use,
7    clippy::mem_forget,
8    clippy::unused_self,
9    clippy::filter_map_next,
10    clippy::needless_continue,
11    clippy::needless_borrow,
12    clippy::match_wildcard_for_single_variants,
13    clippy::if_let_mutex,
14    clippy::mismatched_target_os,
15    clippy::await_holding_lock,
16    clippy::match_on_vec_items,
17    clippy::imprecise_flops,
18    clippy::suboptimal_flops,
19    clippy::lossy_float_literal,
20    clippy::rest_pat_in_fully_bound_structs,
21    clippy::fn_params_excessive_bools,
22    clippy::exit,
23    clippy::inefficient_to_string,
24    clippy::linkedlist,
25    clippy::macro_use_imports,
26    clippy::option_option,
27    clippy::verbose_file_reads,
28    clippy::unnested_or_patterns,
29    clippy::str_to_string,
30    rust_2018_idioms,
31    future_incompatible,
32    nonstandard_style
33)]
34#![deny(unreachable_pub, private_in_public)]
35
36use std::fmt::Debug;
37use std::future::Future;
38use std::hash::Hash;
39use std::sync::atomic::{AtomicU64, Ordering};
40use std::sync::Arc;
41use std::time::{Duration, Instant};
42
43use parking_lot::Mutex;
44use schnellru::{ByLength, LruMap};
45use tokio::sync::watch;
46
/// Deduplication cache for requests.
///
/// A cheap-to-clone handle: all clones share the same underlying state via an
/// [`Arc`]. Concurrent [`DedCache::get_or_update`] calls for the same key are
/// coalesced so that only one upstream request runs at a time.
pub struct DedCache<Req, Res, Err>(Arc<SharedState<Req, Res, Err>>);
49
50impl<Req, Res, Err> Clone for DedCache<Req, Res, Err> {
51    #[inline]
52    fn clone(&self) -> Self {
53        Self(self.0.clone())
54    }
55}
56
57impl<Req, Res, Err> DedCache<Req, Res, Err>
58where
59    Req: Hash + Eq + Clone + Debug,
60    Res: Clone,
61    Err: Clone,
62{
63    /// # Arguments
64    /// * `lifetime` - The lifetime of a cached value.
65    /// * `size` - The maximum number of cached values.
66    pub fn new(lifetime: Duration, size: u32) -> Self {
67        Self(Arc::new(SharedState::new(lifetime, size)))
68    }
69
70    /// Calling this method can cause 3 different things to happen:
71    /// 1. If the key is not in the cache, `f` will be called, all other callers will wait for it to finish and then return the result.
72    /// 2. If the key is in the cache and the value is older than `lifetime` then `1.` will happen.
73    /// 3. If the key is in the cache and the value is newer than `lifetime` then the cached value will be returned.
74    ///
75    /// Cache hits are tracked and can be retrieved using `stats`.
76    ///
77    /// # Arguments
78    /// * `key` - The key to lookup in the cache.
79    /// * `f` - The future to call if the key is not in the cache or the value is older than `lifetime`.
80    pub async fn get_or_update<F, Fut>(&self, key: Req, f: F) -> Result<Res, CoalesceError<Err>>
81    where
82        F: FnOnce() -> Fut,
83        Fut: Future<Output = Result<Res, Err>>,
84    {
85        self.0.get_or_update(key, f).await
86    }
87
88    /// Returns cache performance statistics.
89    pub fn fetch_stats(&self) -> Stats {
90        self.0.fetch_stats()
91    }
92}
93
/// State shared by every clone of a [`DedCache`].
struct SharedState<Req, Res, Err> {
    // Per-key watch senders for in-flight / completed requests (LRU-bounded).
    cache: RequestLru<Req, Res, Err>,
    // How long a completed entry is considered fresh.
    lifetime: Duration,
    // Counter: total `get_or_update` calls.
    total_requests: AtomicU64,
    // Counter: upstream requests actually executed.
    calls_made: AtomicU64,
    // Counter: requests served from the cache (fresh value or coalesced).
    cache_hit: AtomicU64,
}
101
102impl<Req, Res, Err> SharedState<Req, Res, Err>
103where
104    Req: Hash + Eq + Clone + Debug,
105    Res: Clone,
106    Err: Clone,
107{
108    fn new(lifetime: Duration, size: u32) -> Self {
109        Self {
110            cache: Mutex::new(LruMap::new(ByLength::new(size))),
111            lifetime,
112            total_requests: Default::default(),
113            cache_hit: Default::default(),
114            calls_made: Default::default(),
115        }
116    }
117
118    async fn get_or_update<F, Fut>(&self, key: Req, f: F) -> Result<Res, CoalesceError<Err>>
119    where
120        F: FnOnce() -> Fut,
121        Fut: Future<Output = Result<Res, Err>>,
122    {
123        struct RemoveWatchOnDrop<'a, Req, Res, Err>
124        where
125            Req: Hash + Eq,
126        {
127            key: Option<&'a Req>,
128            cache: &'a RequestLru<Req, Res, Err>,
129        }
130
131        impl<Req, Res, Err> RemoveWatchOnDrop<'_, Req, Res, Err>
132        where
133            Req: Hash + Eq,
134        {
135            fn disarm(mut self) {
136                self.key = None;
137            }
138        }
139
140        impl<Req, Res, Err> Drop for RemoveWatchOnDrop<'_, Req, Res, Err>
141        where
142            Req: Hash + Eq,
143        {
144            fn drop(&mut self) {
145                if let Some(key) = self.key.take() {
146                    self.cache.lock().remove(key);
147                }
148            }
149        }
150
151        enum Task {
152            New,
153            Existing,
154        }
155
156        self.update_request_number();
157
158        // Determine what to do based on cache
159        let (tx, task) = 'task: {
160            let mut cache = self.cache.lock();
161
162            // Check and existing entry
163            if let Some(entry) = cache.get(&key) {
164                let result = entry.borrow();
165                if let Some(result) = &*result {
166                    // Return cached response if it is still valid
167                    if result.since.elapsed() <= self.lifetime {
168                        self.update_cache_hit();
169                        return result.data.clone().map_err(CoalesceError::Indirect);
170                    }
171                } else {
172                    break 'task (entry.clone(), Task::Existing);
173                }
174            }
175
176            // Create a new entry
177            let (tx, _) = watch::channel(None);
178            let tx = Arc::new(tx);
179            cache.insert(key.clone(), tx.clone());
180            (tx, Task::New)
181        };
182
183        // Ensure that task will be dropped when the future is dropped
184        let drop_guard = RemoveWatchOnDrop {
185            key: Some(&key),
186            cache: &self.cache,
187        };
188
189        // Execute task
190        match task {
191            Task::New => {
192                // Execute request
193                let result = f().await;
194                self.update_calls_number();
195
196                if result.is_ok() {
197                    // Prevent watch from being removed from the cache
198                    drop_guard.disarm();
199                }
200
201                // Notify all waiters with result
202                tx.send_modify(|value| {
203                    *value = Some(Entry {
204                        data: result.clone(),
205                        since: Instant::now(),
206                    })
207                });
208
209                // Done
210                result.map_err(CoalesceError::Direct)
211            }
212            Task::Existing => {
213                let mut rx = tx.subscribe();
214
215                // Check if notify was already resolved
216                {
217                    let result = rx.borrow();
218                    if let Some(result) = &*result {
219                        return result.data.clone().map_err(CoalesceError::Indirect);
220                    }
221                }
222
223                // Wait for an existing response
224                rx.changed().await.unwrap();
225
226                let result = rx.borrow();
227                let result = result.as_ref().unwrap().data.clone();
228
229                self.update_cache_hit();
230                result.map_err(CoalesceError::Indirect)
231            }
232        }
233    }
234
235    fn fetch_stats(&self) -> Stats {
236        let (memory_usage, len) = {
237            let cache = self.cache.lock();
238            (cache.memory_usage(), cache.len())
239        };
240
241        let total_requests = self.total_requests.load(Ordering::Relaxed);
242        let calls_made = self.calls_made.load(Ordering::Relaxed);
243        let cache_hit = self.cache_hit.load(Ordering::Relaxed);
244
245        let cache_hit_ratio = if total_requests == 0 {
246            0.0
247        } else {
248            cache_hit as f64 / total_requests as f64
249        };
250
251        Stats {
252            total_requests,
253            cache_hit,
254            memory_usage,
255            len,
256            cache_hit_ratio,
257            calls_made,
258        }
259    }
260
261    fn update_request_number(&self) {
262        self.total_requests.fetch_add(1, Ordering::Relaxed);
263    }
264
265    fn update_calls_number(&self) {
266        self.calls_made.fetch_add(1, Ordering::Relaxed);
267    }
268
269    fn update_cache_hit(&self) {
270        self.cache_hit.fetch_add(1, Ordering::Relaxed);
271    }
272}
273
/// Error returned by [`DedCache::get_or_update`].
#[derive(thiserror::Error, Debug)]
pub enum CoalesceError<E> {
    /// Our own (leading) request failed.
    #[error("request failed")]
    Direct(#[source] E),

    /// A coalesced in-flight request failed. A single upstream failure
    /// surfaces to every waiter, so logging each occurrence may be noisy.
    #[error("inflight request failed")]
    Indirect(#[source] E),
}
284
285impl<E> CoalesceError<E> {
286    pub fn into_inner(self) -> E {
287        match self {
288            Self::Direct(e) => e,
289            Self::Indirect(e) => e,
290        }
291    }
292}
293
/// Cache performance counters returned by [`DedCache::fetch_stats`].
#[derive(Clone, Debug)]
pub struct Stats {
    /// memory usage as reported by the underlying LRU map
    pub memory_usage: usize,
    /// number of entries currently in the cache
    pub len: usize,

    /// number of requests for `get_or_update`
    pub total_requests: u64,
    /// number of requests which were called upstream
    pub calls_made: u64,
    /// number of requests which were served from the cache
    pub cache_hit: u64,
    /// ratio of cache hits to total requests in range [0.0, 1.0]
    /// (computed as `cache_hit / total_requests`)
    pub cache_hit_ratio: f64,
}
308
/// A completed request: the shared result plus the moment it was produced.
struct Entry<T, E> {
    // The result shared with all waiters.
    data: Result<T, E>,
    // When the result was produced; entries older than `lifetime` are stale.
    since: Instant,
}
313
/// Watch sender carrying `None` while a request is in flight and
/// `Some(Entry)` once it has completed.
type ResultTx<T, E> = watch::Sender<Option<Entry<T, E>>>;
/// Mutex-guarded LRU of per-key watch senders.
type RequestLru<K, V, E> = Mutex<LruMap<K, Arc<ResultTx<V, E>>, ByLength>>;
316
#[cfg(test)]
mod test {
    use std::convert::Infallible;
    use std::time::Duration;

    use super::*;

    /// Basic lifecycle: insert, fresh hit, expiry, re-fetch after expiry.
    #[tokio::test]
    async fn test_cache() {
        let cache = DedCache::<_, _, Infallible>::new(Duration::from_secs(1), 1024);

        let key = "key";

        // inserting a value (cold path: `fut` sleeps 2s before returning)
        let value = cache.get_or_update(key, fut).await.unwrap();
        assert_eq!(value, "value"); // value is returned

        let start = Instant::now();

        let value = cache.get_or_update(key, fut).await.unwrap();
        // value is returned immediately (served from cache, no 2s sleep)
        assert_eq!(value, "value");
        assert!(start.elapsed() < Duration::from_secs(1));

        tokio::time::sleep(Duration::from_secs(2)).await;
        // at this point the value is expired (lifetime is 1 second)

        {
            // peek at internals: the stale entry is still stored
            let mut cache = cache.0.cache.lock();
            let val = cache.get(&key).unwrap();
            let val = val.borrow();
            let val = val.as_ref().unwrap();

            // last update is more than 1 second ago
            assert!(val.since.elapsed() > Duration::from_secs(1));
        }

        let start = std::time::Instant::now();
        let value = cache.get_or_update(key, fut).await.unwrap();
        assert_eq!(value, "value");
        // the update took more than 1 second because the entry had expired
        // and `fut` had to run again
        assert!(start.elapsed() > Duration::from_secs(1));
    }

    /// Producer that takes 2 seconds and always yields `"value"`.
    async fn fut() -> Result<&'static str, Infallible> {
        tokio::time::sleep(Duration::from_secs(2)).await;
        Ok("value")
    }

    /// Exercises LRU eviction: capacity 2 with 3 keys in flight.
    #[tokio::test]
    async fn test_with_eviction() {
        let cache = DedCache::<_, _, Infallible>::new(Duration::from_secs(1), 2);

        // creating 3 updates so that the first one is evicted
        let key1 = 1;
        let key2 = 2;
        let key3 = 3;

        // key1: slow update (1s), started first
        let value1 = {
            let cache = cache.clone();
            tokio::spawn(async move {
                cache
                    .get_or_update(key1, || fut2(Duration::from_secs(1), 1337))
                    .await
            })
        };
        // key2: instant update
        let value2 = {
            let cache = cache.clone();
            tokio::spawn(async move {
                cache
                    .get_or_update(key2, || fut2(Duration::from_secs(0), 1337))
                    .await
            })
        };
        // key3: instant update; inserting it can evict key1's entry
        let value3 = {
            let cache = cache.clone();
            tokio::spawn(async move {
                cache
                    .get_or_update(key3, || fut2(Duration::from_secs(0), 1337))
                    .await
            })
        };

        // second request for key1 while (or after) the first one runs
        let value1_second_get = {
            let cache = cache.clone();
            tokio::spawn(async move {
                cache
                    .get_or_update(key1, || fut2(Duration::from_secs(0), 1337))
                    .await
            })
        };

        // waiting for the first update to finish
        println!("Val1 = {:?}", value1.await.unwrap().unwrap());
        // waiting for the second update to finish
        println!("Val2 = {:?}", value2.await.unwrap().unwrap());
        // waiting for the third update to finish
        println!("Val3 = {:?}", value3.await.unwrap().unwrap());

        // waiting for the first update to finish
        println!("Val1 = {:?}", value1_second_get.await.unwrap().unwrap());

        // dump what survived in the cache (entry may be pending => None)
        let lock = cache.0.cache.lock();
        for (i, entry) in lock.iter() {
            let entry = entry.borrow();
            println!("{i}: {:?}", entry.is_some());
        }
    }

    /// Producer that sleeps for `time` and then returns `retval`.
    async fn fut2(time: Duration, retval: u32) -> Result<u32, Infallible> {
        tokio::time::sleep(time).await;
        println!("fut2 finished after {time:?}");
        Ok(retval)
    }

    /// Coalescing under load: 100 keys x 100 concurrent callers each should
    /// result in only 100 upstream calls.
    #[tokio::test]
    async fn test_under_load() {
        let cache = DedCache::<_, _, Infallible>::new(Duration::from_secs(1), 100);

        let start = Instant::now();
        // spawning 100 task groups which will try to get the same key within group

        let mut futures_list = Vec::new();
        for key in 0..100 {
            let mut futures = Vec::new();
            for _ in 0..100 {
                let cache = cache.clone();
                let handle = tokio::spawn(async move {
                    cache
                        .get_or_update(key, || fut2(Duration::from_secs(1), key as u32))
                        .await
                });
                futures.push(handle);
            }

            futures_list.push(futures);
        }
        // every caller in a group must observe the group's value
        for (group_id, futures) in futures_list.into_iter().enumerate() {
            for future in futures {
                let res = future.await.unwrap().unwrap();
                assert_eq!(res, group_id as u32);
            }
        }

        // all groups ran concurrently, so total wall time stays bounded
        assert!(start.elapsed() < Duration::from_secs(6));
        println!("Stats: {:?}", cache.fetch_stats());
        // 99 of 100 callers per key should be cache hits
        assert!(cache.fetch_stats().cache_hit_ratio > 0.9);
        // exactly one upstream call per key
        assert_eq!(cache.fetch_stats().calls_made, 100);
    }
}
466}