// kaspa_utils/expiring_cache.rs
use arc_swap::ArcSwapOption;
use std::{
    future::Future,
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    },
    time::{Duration, Instant},
};

/// A cached item paired with the instant at which it was stored.
struct Entry<T> {
    /// The cached value.
    item: T,
    /// When `item` was stored; its age is compared against the cache's refetch/expire durations.
    timestamp: Instant,
}

16pub struct ExpiringCache<T> {
18 store: ArcSwapOption<Entry<T>>,
19 refetch: Duration,
20 expire: Duration,
21 fetching: AtomicBool,
22}
23
24impl<T: Clone> ExpiringCache<T> {
25 pub fn new(refetch: Duration, expire: Duration) -> Self {
30 assert!(refetch <= expire);
31 Self { store: Default::default(), refetch, expire, fetching: Default::default() }
32 }
33
34 pub async fn get<F>(&self, refetch_future: F) -> T
37 where
38 F: Future<Output = T> + Send + 'static,
39 F::Output: Send + 'static,
40 {
41 let mut fetching = false;
42
43 {
44 let guard = self.store.load();
45 if let Some(entry) = guard.as_ref() {
46 if let Some(elapsed) = Instant::now().checked_duration_since(entry.timestamp) {
47 if elapsed < self.refetch {
48 return entry.item.clone();
49 }
50 fetching = self.fetching.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst).is_ok();
52 if !fetching && elapsed < self.expire {
55 return entry.item.clone();
56 }
57 }
58 }
60 }
61
62 let new_item = refetch_future.await;
64 let timestamp = Instant::now();
65 self.store.store(Some(Arc::new(Entry { item: new_item.clone(), timestamp })));
67
68 if fetching {
69 let result = self.fetching.compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst);
70 assert!(result.is_ok(), "refetching was captured")
71 }
72
73 new_item
74 }
75}
76
#[cfg(test)]
mod tests {
    use super::ExpiringCache;
    use std::time::Duration;
    use tokio::time::sleep;

    const REFETCH: Duration = Duration::from_millis(500);
    const EXPIRE: Duration = Duration::from_millis(1000);
    /// Wait after a fetch that lands between `REFETCH` and `EXPIRE` (stale but still served).
    const STALE_WAIT: Duration = Duration::from_millis(700);
    /// Wait after a fetch that goes past `EXPIRE`.
    const EXPIRED_WAIT: Duration = Duration::from_millis(1200);
    /// Simulated latency of a slow fetch.
    const SLOW_FETCH: Duration = Duration::from_millis(100);

    /// Fetch future that must never be polled; it panics if it ever is.
    async fn must_not_fetch() -> u64 {
        panic!("should not be called")
    }

    // NOTE(review): ignored by default, presumably because it waits on the real clock for
    // roughly two seconds and is timing-sensitive.
    #[tokio::test]
    #[ignore]
    async fn test_expiring_cache() {
        let cache: ExpiringCache<u64> = ExpiringCache::new(REFETCH, EXPIRE);

        // Empty cache: the first call fetches, and an immediate second call is served from cache.
        let first = cache
            .get(async move {
                println!("first call");
                1
            })
            .await;
        assert_eq!(1, first);
        let cached = cache.get(must_not_fetch()).await;
        assert_eq!(1, cached);

        // Stale entry: the first caller claims the refetch, the concurrent one gets the old value.
        sleep(STALE_WAIT).await;
        let slow_refetch = cache.get(async move {
            println!("third call before sleep");
            sleep(SLOW_FETCH).await;
            println!("third call after sleep");
            3
        });
        let concurrent = cache.get(must_not_fetch());
        let (refreshed, stale) = tokio::join!(slow_refetch, concurrent);
        println!("item 3: {}, item 4: {}", refreshed, stale);
        assert_eq!(3, refreshed);
        assert_eq!(1, stale);

        // Expired entry: both callers fetch; the slower fetch stores its value last.
        sleep(EXPIRED_WAIT).await;
        let slow_fetch = cache.get(async move {
            println!("5th call before sleep");
            sleep(SLOW_FETCH).await;
            println!("5th call after sleep");
            5
        });
        let fast_fetch = cache.get(async move { 6 });
        let (slow, fast) = tokio::join!(slow_fetch, fast_fetch);
        println!("item 5: {}, item 6: {}", slow, fast);
        assert_eq!(5, slow);
        assert_eq!(6, fast);

        // The value stored last is fresh and served without fetching.
        let latest = cache.get(must_not_fetch()).await;
        assert_eq!(5, latest);
    }
}