kaspa_utils/expiring_cache.rs

1use arc_swap::ArcSwapOption;
2use std::{
3    future::Future,
4    sync::{
5        atomic::{AtomicBool, Ordering},
6        Arc,
7    },
8    time::{Duration, Instant},
9};
10
/// A cached item paired with the instant at which it was stored.
struct Entry<T> {
    /// The cached value; callers receive clones of it.
    item: T,
    /// Moment the item was stored (taken right after its fetch completed). The item's age relative
    /// to this instant decides whether it is served, refetched or discarded.
    timestamp: Instant,
}
15
16/// An expiring cache for a single object
17pub struct ExpiringCache<T> {
18    store: ArcSwapOption<Entry<T>>,
19    refetch: Duration,
20    expire: Duration,
21    fetching: AtomicBool,
22}
23
24impl<T: Clone> ExpiringCache<T> {
25    /// Constructs a new expiring cache where `fetch` is the amount of time required to trigger a data
26    /// refetch and `expire` is the time duration after which the stored item is guaranteed not to be returned.
27    ///
28    /// Panics if `refetch > expire`.
29    pub fn new(refetch: Duration, expire: Duration) -> Self {
30        assert!(refetch <= expire);
31        Self { store: Default::default(), refetch, expire, fetching: Default::default() }
32    }
33
34    /// Returns the cached item or possibly fetches a new one using the `refetch_future` task. The
35    /// decision whether to refetch depends on the configured expiration and refetch times for this cache.  
36    pub async fn get<F>(&self, refetch_future: F) -> T
37    where
38        F: Future<Output = T> + Send + 'static,
39        F::Output: Send + 'static,
40    {
41        let mut fetching = false;
42
43        {
44            let guard = self.store.load();
45            if let Some(entry) = guard.as_ref() {
46                if let Some(elapsed) = Instant::now().checked_duration_since(entry.timestamp) {
47                    if elapsed < self.refetch {
48                        return entry.item.clone();
49                    }
50                    // Refetch is triggered, attempt to capture the task
51                    fetching = self.fetching.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst).is_ok();
52                    // If the fetch task is not captured and expire time is not over yet, return with prev value. Another
53                    // thread is refetching the data but we can return with the not-too-old value
54                    if !fetching && elapsed < self.expire {
55                        return entry.item.clone();
56                    }
57                }
58                // else -- In rare cases where now < timestamp, fall through to re-update the cache
59            }
60        }
61
62        // We reach here if either we are the refetching thread or the current data has fully expired
63        let new_item = refetch_future.await;
64        let timestamp = Instant::now();
65        // Update the store even if we were not in charge of refetching - let the last thread make the final update
66        self.store.store(Some(Arc::new(Entry { item: new_item.clone(), timestamp })));
67
68        if fetching {
69            let result = self.fetching.compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst);
70            assert!(result.is_ok(), "refetching was captured")
71        }
72
73        new_item
74    }
75}
76
#[cfg(test)]
mod tests {
    use super::ExpiringCache;
    use std::time::Duration;

    // Tested during development but can be sensitive to runtime machine times so there's no point
    // in keeping it part of CI. The test should be activated if the ExpiringCache struct changes.
    #[tokio::test]
    #[ignore]
    async fn test_expiring_cache() {
        let refetch_after = Duration::from_millis(500);
        let expire_after = Duration::from_millis(1000);
        // Lands between the refetch point and the expire point.
        let before_expiry = Duration::from_millis(700);
        // Lands past the expire point.
        let after_expiry = Duration::from_millis(1200);
        let cache: ExpiringCache<u64> = ExpiringCache::new(refetch_after, expire_after);

        // Phase 1: the empty cache fetches, then an immediate second call is served from the cache.
        let first = cache
            .get(async move {
                println!("first call");
                1
            })
            .await;
        assert_eq!(1, first);
        let cached = cache
            .get(async move {
                // The entry was stored a moment ago, so no refetch may happen.
                panic!("should not be called");
            })
            .await;
        assert_eq!(1, cached);

        // Phase 2: past the refetch point but before expiry. One caller refetches while the other
        // keeps getting the previous item.
        tokio::time::sleep(before_expiry).await;
        let refetching = cache.get(async move {
            println!("third call before sleep");
            // Stay busy so the concurrent caller still observes the first item.
            tokio::time::sleep(Duration::from_millis(100)).await;
            println!("third call after sleep");
            3
        });
        let stale_reader = cache.get(async move {
            // The refetch is owned by the other caller and the entry has not expired yet.
            panic!("should not be called");
        });
        let (third, fourth) = tokio::join!(refetching, stale_reader);
        println!("item 3: {}, item 4: {}", third, fourth);
        assert_eq!(3, third);
        assert_eq!(1, fourth);

        // Phase 3: past expiry, each caller fetches on its own.
        tokio::time::sleep(after_expiry).await;
        let slow_fetch = cache.get(async move {
            println!("5th call before sleep");
            tokio::time::sleep(Duration::from_millis(100)).await;
            println!("5th call after sleep");
            5
        });
        let fast_fetch = cache.get(async move { 6 });
        let (fifth, sixth) = tokio::join!(slow_fetch, fast_fetch);
        println!("item 5: {}, item 6: {}", fifth, sixth);
        assert_eq!(5, fifth);
        assert_eq!(6, sixth);

        // The slow fetch (5) finished after the fast one (6), so its entry is the one kept.
        let latest = cache
            .get(async move {
                // The cache was just updated, so no refetch may happen.
                panic!("should not be called");
            })
            .await;
        assert_eq!(5, latest);
    }
}