1use std::sync::Mutex;
18use std::collections::hash_map::RandomState;
19use std::collections::{HashMap, VecDeque};
20use std::fmt;
21use std::hash::Hash;
22use std::sync::Arc;
23
/// Single-threaded cache that records the last `mem_len` lookups and only stores a value
/// for a key that has been looked up more than once while it was being tracked.
///
/// Every [`get`](Self::get) is pushed onto a bounded queue. When a key's most recent
/// lookup falls off the end of that queue, the key's entry (and any value cached for it)
/// is dropped.
pub struct DynamicCacheLocal<K, V, S = RandomState> {
    /// Per-key state: the lookup counter and the cached value, if one has been stored.
    map: HashMap<K, (u32, Option<Arc<V>>), S>,
    /// Recorded lookups, newest at the front, each tagged with the key's counter value
    /// at the time of the lookup.
    list: VecDeque<(K, u32)>,
    /// Maximum number of lookups kept in `list`.
    mem_len: usize,
    /// Number of entries in `map` that currently hold a value.
    size: usize,
    /// Lookups that returned a cached value.
    hits: u64,
    /// Lookups that returned nothing.
    misses: u64,
}

impl<K: Clone + Eq + Hash, V, S> DynamicCacheLocal<K, V, S> {
    /// Creates an empty cache that remembers the last `mem_len` lookups and hashes keys
    /// with `hash_builder`.
    ///
    /// `mem_len` is clamped to the range `2..=u32::MAX`.
    pub fn with_hasher(mem_len: usize, hash_builder: S) -> DynamicCacheLocal<K, V, S> {
        let mem_len = mem_len.clamp(2, u32::MAX as usize);

        Self {
            map: HashMap::with_hasher(hash_builder),
            list: VecDeque::with_capacity(mem_len),
            mem_len,
            size: 0,
            hits: 0,
            misses: 0,
        }
    }
}

impl<K: Clone + Eq + Hash, V> DynamicCacheLocal<K, V> {
    /// Creates an empty cache that remembers the last `mem_len` lookups, using the
    /// default hasher.
    ///
    /// `mem_len` is clamped to the range `2..=u32::MAX`.
    pub fn new(mem_len: usize) -> Self {
        Self::with_hasher(mem_len, RandomState::new())
    }
}

// The operations are available for every hasher, not only `RandomState`; otherwise a
// cache built through `with_hasher` could never be queried.
impl<K: Clone + Eq + Hash, V, S: std::hash::BuildHasher> DynamicCacheLocal<K, V, S> {
    /// Records a lookup of `key` and returns the cached value, if any.
    ///
    /// A key that is not tracked yet starts with a counter of zero; every later lookup
    /// increments the counter. Recording the lookup may push the oldest recorded lookup
    /// out of the window and thereby drop that key's entry. The call counts as a hit
    /// when a value is returned and as a miss otherwise.
    ///
    /// # Panics
    ///
    /// Panics if the lookup queue and the map disagree, which would be an internal bug.
    pub fn get(&mut self, key: &K) -> Option<Arc<V>> {
        let (counter, found) = match self.map.get_mut(key) {
            Some((counter, value)) => {
                // The counter is only ever compared for equality against the tags stored
                // in `list`, so wrapping around is harmless, whereas `+= 1` would panic
                // in overflow-checked builds.
                *counter = counter.wrapping_add(1);
                (*counter, value.clone())
            }
            None => {
                self.map.insert(key.clone(), (0, None));
                (0, None)
            }
        };

        if self.list.len() == self.mem_len {
            self.evict_oldest();
        }
        self.list.push_front((key.clone(), counter));

        if found.is_some() {
            self.hits += 1;
        } else {
            self.misses += 1;
        }

        found
    }

    /// Removes and returns the value cached for `key`, leaving the key's lookup history
    /// in place.
    ///
    /// Returns `None` when the key is not tracked or holds no value.
    pub fn pop(&mut self, key: &K) -> Option<Arc<V>> {
        let taken = self.map.get_mut(key)?.1.take();
        if taken.is_some() {
            // The entry no longer holds a value, so it must stop counting towards `size`.
            self.size -= 1;
        }
        taken
    }

    /// Offers `v` as the value for `key` and returns the value the caller should use.
    ///
    /// * If the key is not tracked, or has only been looked up once (counter zero), `v`
    ///   is wrapped and returned without being stored.
    /// * If a value is already cached for the key, that cached value is returned and `v`
    ///   is dropped.
    /// * Otherwise `v` is stored and returned.
    pub fn insert(&mut self, key: &K, v: V) -> Arc<V> {
        match self.map.get_mut(key) {
            None | Some((0, _)) => Arc::new(v),
            Some((_, Some(cached))) => cached.clone(),
            Some((_, slot @ None)) => {
                let fresh = Arc::new(v);
                *slot = Some(fresh.clone());
                self.size += 1;
                fresh
            }
        }
    }

    /// Looks `key` up with [`get`](Self::get) and, on a miss, computes a value with `f`
    /// and passes it to [`insert`](Self::insert).
    pub fn get_or_insert<F: FnOnce() -> V>(&mut self, key: &K, f: F) -> Arc<V> {
        match self.get(key) {
            Some(hit) => hit,
            None => self.insert(key, f()),
        }
    }

    /// Number of values currently stored.
    pub fn size(&self) -> usize {
        self.size
    }

    /// Number of lookups the cache remembers.
    pub fn mem_len(&self) -> usize {
        self.mem_len
    }

    /// Changes the number of remembered lookups to `new_len`, clamped to
    /// `2..=u32::MAX`, forgetting the oldest lookups if the window shrinks.
    ///
    /// # Panics
    ///
    /// Panics if the lookup queue and the map disagree, which would be an internal bug.
    pub fn set_mem_len(&mut self, new_len: usize) {
        let new_len = new_len.clamp(2, u32::MAX as usize);
        while self.list.len() > new_len {
            self.evict_oldest();
        }
        self.mem_len = new_len;
    }

    /// Drops every stored value and all lookup history. Hit/miss counters are kept.
    pub fn clear_cache(&mut self) {
        self.size = 0;
        self.map.clear();
        self.list.clear();
    }

    /// Number of lookups that returned a cached value.
    pub fn hits(&self) -> u64 {
        self.hits
    }

    /// Number of lookups that returned nothing.
    pub fn misses(&self) -> u64 {
        self.misses
    }

    /// Resets the hit and miss counters to zero.
    pub fn reset_metrics(&mut self) {
        self.hits = 0;
        self.misses = 0;
    }

    /// Forgets the oldest recorded lookup and, if it was the key's most recent one,
    /// drops the key's entry together with any value stored for it.
    fn evict_oldest(&mut self) {
        let (key, last_count) = self
            .list
            .pop_back()
            .expect("Cache memory queue should be non-empty at this point");
        let current = self
            .map
            .get(&key)
            .expect("Cache hashmap should contain the key from the memory queue")
            .0;
        // A differing counter means the key was looked up again more recently, so a
        // newer record of it is still inside the window and the entry must stay.
        if current == last_count {
            if let Some((_, Some(_))) = self.map.remove(&key) {
                self.size -= 1;
            }
        }
    }
}
208
209impl<K: fmt::Debug, V: fmt::Debug, S> fmt::Debug for DynamicCacheLocal<K, V, S> {
210 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
211 f.debug_struct("DynamicCacheLocal")
212 .field("map", &format!("{} entries", self.map.len()))
213 .field("list", &format!("{} long", self.list.len()))
214 .field("mem_len", &self.mem_len)
215 .field("size", &self.size)
216 .finish()
217 }
218}
219
220#[derive(Clone, Debug)]
228pub struct DynamicCache<K, V, S = RandomState> {
229 cache: Arc<Mutex<DynamicCacheLocal<K, V, S>>>,
230}
231
232impl<K: Clone + Eq + Hash, V, S> DynamicCache<K, V, S> {
233 pub fn with_hasher(mem_len: usize, hash_builder: S) -> DynamicCache<K, V, S> {
236 Self {
237 cache: Arc::new(Mutex::new(DynamicCacheLocal::with_hasher(
238 mem_len,
239 hash_builder,
240 ))),
241 }
242 }
243}
244
245impl<K: Clone + Eq + Hash, V> DynamicCache<K, V> {
246 pub fn new(mem_len: usize) -> Self {
248 Self {
249 cache: Arc::new(Mutex::new(DynamicCacheLocal::new(mem_len))),
250 }
251 }
252
253 pub fn get(&self, key: &K) -> Option<Arc<V>> {
256 self.cache.lock().unwrap().get(key)
257 }
258
259 pub fn pop(&self, key: &K) -> Option<Arc<V>> {
265 self.cache.lock().unwrap().pop(key)
266 }
267
268 pub fn insert(&self, key: &K, value: V) -> Arc<V> {
273 self.cache.lock().unwrap().insert(key, value)
274 }
275
276 pub fn get_or_insert<F: FnOnce() -> V>(&self, key: &K, f: F) -> Arc<V> {
280 self.get(key).unwrap_or_else(|| self.insert(key, f()))
281 }
282
283 pub fn size(&self) -> usize {
285 self.cache.lock().unwrap().size()
286 }
287
288 pub fn mem_len(&self) -> usize {
290 self.cache.lock().unwrap().mem_len()
291 }
292
293 pub fn set_mem_len(&self, new_len: usize) {
296 self.cache.lock().unwrap().set_mem_len(new_len)
297 }
298
299 pub fn clear_cache(&self) {
301 self.cache.lock().unwrap().clear_cache()
302 }
303
304 pub fn hits_misses(&self) -> (u64, u64) {
306 let cache = self.cache.lock().unwrap();
307 (cache.hits(), cache.misses())
308 }
309
310 pub fn reset_metrics(&self) {
312 self.cache.lock().unwrap().reset_metrics()
313 }
314}
315
#[cfg(test)]
mod test {
    use super::*;
    use rand::prelude::*;

    /// Number of lookups performed in each measured phase of `stress_test`.
    const SAMPLE_SIZE: u32 = 1 << 12;

    /// Looks `key` up and, on a miss, offers the key's decimal rendering to the cache.
    ///
    /// Asserts that the value obtained either way matches the key, and returns whether
    /// the lookup was a hit.
    fn lookup(cache: &DynamicCache<u16, String>, key: u16) -> bool {
        let expected = key.to_string();
        let (cached, hit) = match cache.get(&key) {
            Some(found) => (found, true),
            None => (cache.insert(&key, expected.clone()), false),
        };
        assert_eq!(expected.as_str(), cached.as_str());
        hit
    }

    /// Percentage of `requests` that were not misses; zero when there were no requests.
    fn hit_rate(requests: u32, misses: u32) -> f64 {
        100.0 * f64::from(requests - misses) / f64::from(requests.max(1))
    }

    #[test]
    fn fetch_test() {
        let (key, val) = (0, String::from("0"));
        let cache = DynamicCache::new(8);
        assert_eq!(cache.size(), 0);
        assert_eq!(cache.mem_len(), 8);
        assert_eq!(cache.hits_misses(), (0, 0));

        assert!(
            cache.get(&key).is_none(),
            "First `get` should have nothing in cache"
        );
        assert_eq!(cache.size(), 0);
        assert_eq!(cache.hits_misses(), (0, 1));

        // The key has only been requested once so far, so the value is not stored yet.
        assert!(
            cache.insert(&key, val.clone()).as_ref() == &val,
            "Insert should return right value"
        );
        assert_eq!(cache.size(), 0);
        assert_eq!(cache.hits_misses(), (0, 1));

        assert!(
            cache.get(&key).is_none(),
            "Second `get` should still have nothing in cache"
        );
        assert_eq!(cache.size(), 0);
        assert_eq!(cache.hits_misses(), (0, 2));

        // After the second request the value is stored.
        assert!(
            cache.insert(&key, val.clone()).as_ref() == &val,
            "Insert should return right value"
        );
        assert_eq!(cache.size(), 1);
        assert_eq!(cache.hits_misses(), (0, 2));

        assert!(
            cache.get(&key).map_or(false, |x| x.as_ref() == &val),
            "Third `get` should have a value in cache"
        );
        assert_eq!(cache.size(), 1);
        assert_eq!(cache.hits_misses(), (1, 2));

        assert_eq!(cache.mem_len(), 8);
    }

    #[test]
    fn stress_test() {
        let cache = DynamicCache::new(128);

        let mut rng = thread_rng();

        // Fixed warm-up sequence with verbose output.
        let seq: Vec<u16> = vec![
            0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2,
        ];

        for key in seq {
            println!("Write {}", key);
            if lookup(&cache, key) {
                println!("Hit");
            } else {
                println!("Miss");
            }
        }
        println!("Cache size: {}", cache.size());

        // Uniformly random keys drawn from progressively smaller ranges.
        for i in (3..=9).rev() {
            let mut misses = 0;
            for _ in 0..SAMPLE_SIZE {
                let key: u16 = rng.gen_range(0, 1 << i);
                if !lookup(&cache, key) {
                    misses += 1;
                }
            }
            println!(
                "With range of (0..{:3}), Cache size: {:3}, hit rate = {:4.1}%",
                1 << i,
                cache.size(),
                hit_rate(SAMPLE_SIZE, misses)
            );
        }

        // Uniformly random keys over the whole `u16` range.
        let mut misses = 0;
        for _ in 0..SAMPLE_SIZE {
            let key: u16 = rng.gen();
            if !lookup(&cache, key) {
                misses += 1;
            }
        }
        println!(
            "With range of full u16, Cache size: {:3}, hit rate = {:4.1}%",
            cache.size(),
            hit_rate(SAMPLE_SIZE, misses)
        );

        // Half of the requests come from a small weighted "main" key set, the other half
        // are uniformly random noise.
        let weights: Vec<u32> = vec![16, 8, 4, 2, 1];
        let dist = rand::distributions::WeightedIndex::new(&weights).unwrap();
        let mut main_requests = 0;
        let mut main_misses = 0;
        for _ in 0..SAMPLE_SIZE {
            let is_main = rng.gen_bool(0.5);
            let key: u16 = if is_main {
                dist.sample(&mut rng) as u16
            } else {
                rng.gen()
            };
            let hit = lookup(&cache, key);
            if is_main {
                main_requests += 1;
                if !hit {
                    main_misses += 1;
                }
            }
        }
        // Only main requests are counted on both sides of the ratio; dividing by the
        // whole sample would overstate the hit rate.
        let hit_rate = hit_rate(main_requests, main_misses);
        println!("Random u16 with log2 frequent requests, Cache size: {:3}, hit rate for main data = {:4.1}%", cache.size(), hit_rate);
    }
}