1use std::borrow::Borrow;
2use std::hash::Hash;
3use std::num::NonZeroUsize;
4
5use lru::LruCache;
6use sha1::Digest;
7use sha1::Sha1;
8use tokio::sync::Mutex;
9use tokio::sync::MutexGuard;
10
/// An LRU cache behind a Tokio mutex, exposed through synchronous (blocking) methods.
///
/// Every operation first tries to acquire the inner lock. When the lock cannot be
/// obtained (for instance when no Tokio runtime is available on the calling thread),
/// operations bypass the cache: lookups miss and writes are dropped.
pub struct BlockingLruCache<K, V> {
    /// The underlying LRU store; all access goes through this mutex.
    inner: Mutex<LruCache<K, V>>,
}
16
17impl<K, V> BlockingLruCache<K, V>
18where
19 K: Eq + Hash,
20{
21 #[must_use]
23 pub fn new(capacity: NonZeroUsize) -> Self {
24 Self {
25 inner: Mutex::new(LruCache::new(capacity)),
26 }
27 }
28
29 pub fn get_or_insert_with(&self, key: K, value: impl FnOnce() -> V) -> V
31 where
32 V: Clone,
33 {
34 if let Some(mut guard) = lock_if_runtime(&self.inner) {
35 if let Some(v) = guard.get(&key) {
36 return v.clone();
37 }
38 let v = value();
39 guard.put(key, v.clone());
41 return v;
42 }
43 value()
44 }
45
46 pub fn get_or_try_insert_with<E>(
48 &self,
49 key: K,
50 value: impl FnOnce() -> Result<V, E>,
51 ) -> Result<V, E>
52 where
53 V: Clone,
54 {
55 if let Some(mut guard) = lock_if_runtime(&self.inner) {
56 if let Some(v) = guard.get(&key) {
57 return Ok(v.clone());
58 }
59 let v = value()?;
60 guard.put(key, v.clone());
61 return Ok(v);
62 }
63 value()
64 }
65
66 #[must_use]
68 pub fn try_with_capacity(capacity: usize) -> Option<Self> {
69 NonZeroUsize::new(capacity).map(Self::new)
70 }
71
72 pub fn get<Q>(&self, key: &Q) -> Option<V>
74 where
75 K: Borrow<Q>,
76 Q: Hash + Eq + ?Sized,
77 V: Clone,
78 {
79 let mut guard = lock_if_runtime(&self.inner)?;
80 guard.get(key).cloned()
81 }
82
83 pub fn insert(&self, key: K, value: V) -> Option<V> {
85 let mut guard = lock_if_runtime(&self.inner)?;
86 guard.put(key, value)
87 }
88
89 pub fn remove<Q>(&self, key: &Q) -> Option<V>
91 where
92 K: Borrow<Q>,
93 Q: Hash + Eq + ?Sized,
94 {
95 let mut guard = lock_if_runtime(&self.inner)?;
96 guard.pop(key)
97 }
98
99 pub fn clear(&self) {
101 if let Some(mut guard) = lock_if_runtime(&self.inner) {
102 guard.clear();
103 }
104 }
105
106 pub fn with_mut<R>(&self, callback: impl FnOnce(&mut LruCache<K, V>) -> R) -> R {
108 if let Some(mut guard) = lock_if_runtime(&self.inner) {
109 callback(&mut guard)
110 } else {
111 let mut disabled = LruCache::unbounded();
112 callback(&mut disabled)
113 }
114 }
115
116 pub fn blocking_lock(&self) -> Option<MutexGuard<'_, LruCache<K, V>>> {
118 lock_if_runtime(&self.inner)
119 }
120}
121
122fn lock_if_runtime<K, V>(m: &Mutex<LruCache<K, V>>) -> Option<MutexGuard<'_, LruCache<K, V>>>
123where
124 K: Eq + Hash,
125{
126 tokio::runtime::Handle::try_current().ok()?;
127 Some(tokio::task::block_in_place(|| m.blocking_lock()))
128}
129
130#[must_use]
135pub fn sha1_digest(bytes: &[u8]) -> [u8; 20] {
136 let mut hasher = Sha1::new();
137 hasher.update(bytes);
138 let result = hasher.finalize();
139 let mut out = [0; 20];
140 out.copy_from_slice(&result);
141 out
142}
143
#[cfg(test)]
mod tests {
    use super::BlockingLruCache;
    use std::num::NonZeroUsize;

    /// Builds a cache with room for two string-keyed integer entries.
    fn two_slot_cache() -> BlockingLruCache<&'static str, i32> {
        let capacity = NonZeroUsize::new(2).expect("capacity");
        BlockingLruCache::new(capacity)
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn stores_and_retrieves_values() {
        let cache = two_slot_cache();

        // Misses before the insert, hits afterwards.
        assert!(cache.get(&"first").is_none());
        cache.insert("first", 1);
        assert_eq!(cache.get(&"first"), Some(1));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn evicts_least_recently_used() {
        let cache = two_slot_cache();
        for (key, value) in [("a", 1), ("b", 2)] {
            cache.insert(key, value);
        }
        // Reading "a" leaves "b" as the least recently used entry.
        assert_eq!(cache.get(&"a"), Some(1));

        cache.insert("c", 3);

        assert!(cache.get(&"b").is_none());
        assert_eq!(cache.get(&"a"), Some(1));
        assert_eq!(cache.get(&"c"), Some(3));
    }

    #[test]
    fn disabled_without_runtime() {
        let cache = two_slot_cache();

        // Without a runtime, writes are dropped and reads miss.
        cache.insert("first", 1);
        assert!(cache.get(&"first").is_none());

        // The producer still runs, but its result is not kept.
        assert_eq!(cache.get_or_insert_with("first", || 2), 2);
        assert!(cache.get(&"first").is_none());

        assert!(cache.remove(&"first").is_none());
        cache.clear();

        // `with_mut` hands out a throwaway cache instead of the real one.
        let scratch = cache.with_mut(|inner| {
            inner.put("tmp", 3);
            inner.get(&"tmp").copied()
        });
        assert_eq!(scratch, Some(3));
        assert!(cache.get(&"tmp").is_none());

        assert!(cache.blocking_lock().is_none());
    }
}
194