1use crate::{BufferCache, BufferCacheStrategy, CacheEntry};
2use std::any::Any;
3use std::collections::BTreeMap;
4use std::fmt::Debug;
5use std::hash::RandomState;
6use std::marker::PhantomData;
7use std::ops::RangeBounds;
8use std::sync::Mutex;
9
/// A thread-safe least-recently-used cache bounded by total entry cost.
///
/// All operations go through a single internal `Mutex`, so the cache acts
/// as exactly one shard (see `DEFAULT_SHARDS`). The hasher parameter `S`
/// exists for API compatibility only: the backing store is a `BTreeMap`,
/// so keys are ordered, never hashed.
pub struct LruCache<K, V, S = RandomState> {
    // All mutable state (entries, LRU order, cost accounting) behind one lock.
    inner: Mutex<CacheInner<K, V>>,
    // `fn() -> S` retains the type parameter without storing an `S`, and
    // keeps this struct Send/Sync regardless of `S`'s own auto traits.
    marker: PhantomData<fn() -> S>,
}
17
/// Lock-free interior of the cache; always accessed under `LruCache::inner`.
///
/// Invariants (verified by `check_invariants`):
/// * `cache` and `lru` describe the same set of entries — `cache[k].serial`
///   maps back to `k` in `lru`, and vice versa;
/// * `cur_cost` equals the sum of `cost()` over all cached values.
struct CacheInner<K, V> {
    // Primary store: key -> (value, recency serial).
    cache: BTreeMap<K, CacheValue<V>>,
    // Recency index: serial -> key. Smallest serial = least recently used.
    lru: BTreeMap<u64, K>,
    // Monotonically increasing serial; every touch/insert consumes one.
    next_serial: u64,
    // Current total cost of everything in `cache`.
    cur_cost: usize,
    // Cost budget; evictions bring `cur_cost` back under this.
    max_cost: usize,
}
31
/// A cached value plus the recency serial that links it into the LRU index.
struct CacheValue<V> {
    // The user-supplied value; its `cost()` feeds the cache's budget.
    aux: V,
    // Key into `CacheInner::lru`; refreshed on every `get` of this entry.
    serial: u64,
}
39
impl<K, V, S> LruCache<K, V, S> {
    /// Number of shards this implementation exposes. The cache is backed by
    /// a single mutex-protected map, so there is always exactly one shard.
    pub const DEFAULT_SHARDS: usize = 1;
}
44
45impl<K, V> LruCache<K, V, RandomState>
46where
47 K: Ord + Clone + Debug,
48 V: CacheEntry + Clone,
49{
50 pub fn new(max_cost: usize) -> Self {
52 Self::with_hasher(max_cost, RandomState::new())
53 }
54}
55
56#[allow(clippy::len_without_is_empty)]
58impl<K, V, S> LruCache<K, V, S>
59where
60 K: Ord + Clone + Debug,
61 V: CacheEntry + Clone,
62{
63 pub fn with_hasher(max_cost: usize, _hash_builder: S) -> Self {
65 Self {
66 inner: Mutex::new(CacheInner::new(max_cost)),
67 marker: PhantomData,
68 }
69 }
70
71 pub fn insert(&self, key: K, value: V) {
73 self.inner.lock().unwrap().insert(key, value);
74 }
75
76 pub fn get(&self, key: &K) -> Option<V> {
78 self.inner.lock().unwrap().get(key.clone())
79 }
80
81 pub fn remove(&self, key: &K) -> Option<V> {
83 self.inner.lock().unwrap().remove(key)
84 }
85
86 pub fn remove_if<F>(&self, predicate: F)
88 where
89 F: Fn(&K) -> bool,
90 {
91 self.inner.lock().unwrap().remove_if(predicate)
92 }
93
94 pub fn remove_range<R>(&self, range: R) -> usize
96 where
97 R: RangeBounds<K>,
98 {
99 self.inner.lock().unwrap().remove_range(range)
100 }
101
102 pub fn contains_key(&self, key: &K) -> bool {
104 self.inner.lock().unwrap().contains_key(key)
105 }
106
107 pub fn len(&self) -> usize {
109 self.inner.lock().unwrap().len()
110 }
111
112 pub fn total_charge(&self) -> usize {
114 self.inner.lock().unwrap().cur_cost
115 }
116
117 pub fn total_capacity(&self) -> usize {
119 self.inner.lock().unwrap().max_cost
120 }
121
122 pub fn shard_count(&self) -> usize {
124 Self::DEFAULT_SHARDS
125 }
126
127 #[cfg(test)]
133 pub fn shard_usage(&self, idx: usize) -> (usize, usize) {
134 assert_eq!(idx, 0, "shard index out of bounds");
135 let inner = self.inner.lock().unwrap();
136 (inner.cur_cost, inner.max_cost)
137 }
138
139 #[cfg(test)]
140 pub(crate) fn validate_invariants(&self) {
141 self.inner.lock().unwrap().check_invariants();
142 }
143}
144
145impl<K, V, S> BufferCache<K, V> for LruCache<K, V, S>
146where
147 K: Ord + Clone + Debug + Send + Sync + 'static,
148 V: CacheEntry + Clone + Send + Sync + 'static,
149 S: Send + Sync + 'static,
150{
151 fn as_any(&self) -> &dyn Any {
152 self
153 }
154
155 fn strategy(&self) -> BufferCacheStrategy {
156 BufferCacheStrategy::Lru
157 }
158
159 fn insert(&self, key: K, value: V) {
160 self.insert(key, value);
161 }
162
163 fn get(&self, key: K) -> Option<V> {
164 self.inner.lock().unwrap().get(key)
165 }
166
167 fn remove(&self, key: &K) -> Option<V> {
168 self.remove(key)
169 }
170
171 fn remove_if(&self, predicate: &dyn Fn(&K) -> bool) {
172 self.remove_if(|key| predicate(key))
173 }
174
175 fn contains_key(&self, key: &K) -> bool {
176 self.contains_key(key)
177 }
178
179 fn len(&self) -> usize {
180 self.len()
181 }
182
183 fn total_charge(&self) -> usize {
184 self.total_charge()
185 }
186
187 fn total_capacity(&self) -> usize {
188 self.total_capacity()
189 }
190
191 fn shard_count(&self) -> usize {
192 self.shard_count()
193 }
194
195 #[cfg(test)]
196 fn shard_usage(&self, idx: usize) -> (usize, usize) {
197 self.shard_usage(idx)
198 }
199}
200
201impl<K, V> CacheInner<K, V>
202where
203 K: Ord + Clone + Debug,
204 V: CacheEntry + Clone,
205{
206 fn new(max_cost: usize) -> Self {
208 Self {
209 cache: BTreeMap::new(),
210 lru: BTreeMap::new(),
211 next_serial: 0,
212 cur_cost: 0,
213 max_cost,
214 }
215 }
216
217 #[cfg(any(test, debug_assertions))]
219 fn check_invariants(&self) {
220 assert_eq!(self.cache.len(), self.lru.len());
221 let mut cost = 0;
222 for (key, value) in self.cache.iter() {
223 assert_eq!(self.lru.get(&value.serial), Some(key));
224 cost += value.aux.cost();
225 }
226 for (serial, key) in self.lru.iter() {
227 assert_eq!(self.cache.get(key).unwrap().serial, *serial);
228 }
229 assert_eq!(cost, self.cur_cost);
230 }
231
232 fn debug_check_invariants(&self) {
234 #[cfg(debug_assertions)]
235 self.check_invariants()
236 }
237
238 fn get(&mut self, key: K) -> Option<V> {
240 if let Some(value) = self.cache.get_mut(&key) {
241 self.lru.remove(&value.serial);
242 value.serial = self.next_serial;
243 self.lru.insert(value.serial, key);
244 self.next_serial += 1;
245 Some(value.aux.clone())
246 } else {
247 None
248 }
249 }
250
251 fn evict_to(&mut self, max_cost: usize) {
253 while self.cur_cost > max_cost {
254 let (_serial, key) = self.lru.pop_first().unwrap();
257 let value = self.cache.remove(&key).unwrap();
258 self.cur_cost -= value.aux.cost();
259 }
260 self.debug_check_invariants();
261 }
262
263 fn insert(&mut self, key: K, aux: V) {
265 let cost = aux.cost();
266 self.evict_to(self.max_cost.saturating_sub(cost));
267 if let Some(old_value) = self.cache.insert(
268 key.clone(),
269 CacheValue {
270 aux,
271 serial: self.next_serial,
272 },
273 ) {
274 self.lru.remove(&old_value.serial);
275 self.cur_cost -= old_value.aux.cost();
276 }
277 self.lru.insert(self.next_serial, key);
278 self.cur_cost += cost;
279 self.next_serial += 1;
280 self.debug_check_invariants();
281 }
282
283 fn remove(&mut self, key: &K) -> Option<V> {
285 let value = self.cache.remove(key)?;
286 self.lru.remove(&value.serial).unwrap();
287 self.cur_cost -= value.aux.cost();
288 self.debug_check_invariants();
289 Some(value.aux)
290 }
291
292 fn remove_if<F>(&mut self, predicate: F)
294 where
295 F: Fn(&K) -> bool,
296 {
297 let keys: Vec<K> = self
298 .cache
299 .keys()
300 .filter(|key| predicate(key))
301 .cloned()
302 .collect();
303 for key in keys {
304 let _ = self.remove(&key);
305 }
306 }
307
308 fn remove_range<R>(&mut self, range: R) -> usize
310 where
311 R: RangeBounds<K>,
312 {
313 let victims: Vec<(K, u64)> = self
314 .cache
315 .range(range)
316 .map(|(key, value)| (key.clone(), value.serial))
317 .collect();
318
319 let removed = victims.len();
320 for (key, serial) in victims {
321 self.lru.remove(&serial).unwrap();
322 self.cur_cost -= self.cache.remove(&key).unwrap().aux.cost();
323 }
324 self.debug_check_invariants();
325 removed
326 }
327
328 fn contains_key(&self, key: &K) -> bool {
330 self.cache.contains_key(key)
331 }
332
333 fn len(&self) -> usize {
335 self.cache.len()
336 }
337}