use std::sync::Mutex;
use std::collections::hash_map::RandomState;
use std::collections::{HashMap, VecDeque};
use std::fmt;
use std::hash::Hash;
use std::sync::Arc;
/// A single-threaded cache that only admits values for keys that are looked
/// up repeatedly within a sliding window of the most recent lookups.
///
/// Every call to [`get`](Self::get) records the key in a bounded history of
/// `mem_len` lookups. A value passed to [`insert`](Self::insert) is stored
/// only if the key has been looked up at least twice while remembered by that
/// history; otherwise the value is handed back without being cached. A key
/// (and its cached value, if any) is forgotten once its most recent lookup
/// record falls out of the history.
pub struct DynamicCacheLocal<K, V, S = RandomState> {
    /// Per-key state: how many repeat lookups the key has had since it was
    /// first remembered (0 after the first lookup), and the cached value once
    /// one has been stored.
    map: HashMap<K, (u32, Option<Arc<V>>), S>,
    /// Lookup history, newest record at the front. Each record holds the key
    /// and the value its counter had at the time of that lookup.
    list: VecDeque<(K, u32)>,
    /// Maximum number of lookup records kept in `list`.
    mem_len: usize,
    /// Number of keys that currently hold a cached value.
    size: usize,
    /// Number of `get` calls that returned a cached value.
    hits: u64,
    /// Number of `get` calls that returned `None`.
    misses: u64,
}

impl<K: Clone + Eq + Hash, V, S> DynamicCacheLocal<K, V, S> {
    /// Creates an empty cache remembering the last `mem_len` lookups, using
    /// `hash_builder` to hash keys.
    ///
    /// `mem_len` is clamped to the range `2..=u32::MAX`.
    pub fn with_hasher(mem_len: usize, hash_builder: S) -> DynamicCacheLocal<K, V, S> {
        let mem_len = mem_len.clamp(2, u32::MAX as usize);
        Self {
            map: HashMap::with_hasher(hash_builder),
            list: VecDeque::with_capacity(mem_len),
            mem_len,
            size: 0,
            hits: 0,
            misses: 0,
        }
    }

    /// Returns the number of keys that currently hold a cached value.
    pub fn size(&self) -> usize {
        self.size
    }

    /// Returns the length of the lookup history window.
    pub fn mem_len(&self) -> usize {
        self.mem_len
    }

    /// Drops every cached value and the whole lookup history.
    ///
    /// The hit/miss counters are left untouched; use
    /// [`reset_metrics`](Self::reset_metrics) for those.
    pub fn clear_cache(&mut self) {
        self.size = 0;
        self.map.clear();
        self.list.clear();
    }

    /// Returns how many `get` calls returned a cached value.
    pub fn hits(&self) -> u64 {
        self.hits
    }

    /// Returns how many `get` calls returned `None`.
    pub fn misses(&self) -> u64 {
        self.misses
    }

    /// Resets the hit and miss counters to zero.
    pub fn reset_metrics(&mut self) {
        self.hits = 0;
        self.misses = 0;
    }
}

impl<K: Clone + Eq + Hash, V> DynamicCacheLocal<K, V> {
    /// Creates an empty cache remembering the last `mem_len` lookups, using
    /// the default hasher.
    ///
    /// `mem_len` is clamped to the range `2..=u32::MAX`.
    pub fn new(mem_len: usize) -> Self {
        Self::with_hasher(mem_len, RandomState::new())
    }
}

// The operations below hash keys, so they need `S: BuildHasher`. They are
// available for every hasher, not only the default one, so that a cache built
// with `with_hasher` is actually usable.
impl<K: Clone + Eq + Hash, V, S: std::hash::BuildHasher> DynamicCacheLocal<K, V, S> {
    /// Looks up `key`, recording the lookup in the history window.
    ///
    /// Returns the cached value if there is one. A `None` result counts as a
    /// miss, a `Some` result as a hit. If the history is full, the oldest
    /// lookup record is dropped first, which may forget a key entirely.
    pub fn get(&mut self, key: &K) -> Option<Arc<V>> {
        let (counter, ret) = match self.map.get_mut(key) {
            Some((counter, cached)) => {
                *counter += 1;
                (*counter, cached.clone())
            }
            None => {
                // First sighting: remember the key, but with counter 0 so
                // that `insert` will not store a value for it yet.
                self.map.insert(key.clone(), (0, None));
                (0, None)
            }
        };
        if self.list.len() == self.mem_len {
            self.evict_oldest();
        }
        self.list.push_front((key.clone(), counter));
        if ret.is_some() {
            self.hits += 1;
        } else {
            self.misses += 1;
        }
        ret
    }

    /// Removes and returns the cached value for `key`, if any.
    ///
    /// The key's lookup counter and history records stay in place, so a later
    /// `insert` can cache a value for it again.
    pub fn pop(&mut self, key: &K) -> Option<Arc<V>> {
        let (_, slot) = self.map.get_mut(key)?;
        let taken = slot.take();
        if taken.is_some() {
            // Keep `size` equal to the number of stored values.
            self.size -= 1;
        }
        taken
    }

    /// Offers `v` as the value for `key` and returns the value to use.
    ///
    /// * If the key is not remembered, or has only been looked up once, `v`
    ///   is wrapped and returned without being cached.
    /// * If a value is already cached, that value is returned and `v` is
    ///   dropped.
    /// * Otherwise `v` is cached and returned.
    pub fn insert(&mut self, key: &K, v: V) -> Arc<V> {
        match self.map.get_mut(key) {
            None | Some((0, _)) => Arc::new(v),
            Some((_, Some(existing))) => existing.clone(),
            Some((_, slot @ None)) => {
                let stored = Arc::new(v);
                *slot = Some(stored.clone());
                self.size += 1;
                stored
            }
        }
    }

    /// Returns the cached value for `key`, or computes one with `f` and
    /// offers it to the cache via [`insert`](Self::insert).
    ///
    /// `f` is only called when the lookup misses.
    pub fn get_or_insert<F: FnOnce() -> V>(&mut self, key: &K, f: F) -> Arc<V> {
        match self.get(key) {
            Some(hit) => hit,
            None => self.insert(key, f()),
        }
    }

    /// Changes the length of the lookup history window.
    ///
    /// `new_len` is clamped to the range `2..=u32::MAX`. When the window
    /// shrinks, the oldest lookup records are dropped until the history fits.
    pub fn set_mem_len(&mut self, new_len: usize) {
        let new_len = new_len.clamp(2, u32::MAX as usize);
        while self.list.len() > new_len {
            self.evict_oldest();
        }
        self.mem_len = new_len;
    }

    /// Drops the oldest lookup record and forgets its key if that record was
    /// the key's most recent lookup.
    ///
    /// Panics if the history is empty.
    fn evict_oldest(&mut self) {
        let (key, recorded) = self
            .list
            .pop_back()
            .expect("Cache memory queue should be non-empty at this point");
        let (current, cached) = self
            .map
            .get(&key)
            .expect("Cache hashmap should contain the key from the memory queue");
        // A matching counter means no newer record for this key remains in
        // the history, so the key has aged out completely.
        if *current == recorded {
            if cached.is_some() {
                self.size -= 1;
            }
            self.map.remove(&key);
        }
    }
}

/// Prints a summary (entry counts and sizes) rather than the cache contents.
impl<K: fmt::Debug, V: fmt::Debug, S> fmt::Debug for DynamicCacheLocal<K, V, S> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("DynamicCacheLocal")
            .field("map", &format!("{} entries", self.map.len()))
            .field("list", &format!("{} long", self.list.len()))
            .field("mem_len", &self.mem_len)
            .field("size", &self.size)
            .finish()
    }
}
/// A shareable handle to a [`DynamicCacheLocal`] guarded by a mutex.
///
/// Cloning the handle is cheap: every clone refers to the same underlying
/// cache through a shared `Arc`.
#[derive(Debug)]
pub struct DynamicCache<K, V, S = RandomState> {
    cache: Arc<Mutex<DynamicCacheLocal<K, V, S>>>,
}

// Implemented by hand: `#[derive(Clone)]` would require `K: Clone`, `V: Clone`
// and `S: Clone`, although cloning only copies the `Arc` and never touches
// keys, values or the hasher.
impl<K, V, S> Clone for DynamicCache<K, V, S> {
    fn clone(&self) -> Self {
        Self {
            cache: Arc::clone(&self.cache),
        }
    }
}
impl<K: Clone + Eq + Hash, V, S> DynamicCache<K, V, S> {
    /// Builds a shared cache whose history window holds `mem_len` lookups and
    /// whose keys are hashed with `hash_builder`.
    ///
    /// Both arguments are passed on unchanged to
    /// `DynamicCacheLocal::with_hasher`.
    pub fn with_hasher(mem_len: usize, hash_builder: S) -> DynamicCache<K, V, S> {
        let local = DynamicCacheLocal::with_hasher(mem_len, hash_builder);
        let shared = Arc::new(Mutex::new(local));
        Self { cache: shared }
    }
}
/// Locking front-end: each method takes the mutex, forwards to the wrapped
/// `DynamicCacheLocal`, and releases the mutex on return.
///
/// Every method panics if the mutex is poisoned.
impl<K: Clone + Eq + Hash, V> DynamicCache<K, V> {
    /// Builds a shared cache with the default hasher; `mem_len` is passed on
    /// to `DynamicCacheLocal::new`.
    pub fn new(mem_len: usize) -> Self {
        let local = DynamicCacheLocal::new(mem_len);
        Self {
            cache: Arc::new(Mutex::new(local)),
        }
    }

    /// Looks up `key` in the shared cache.
    pub fn get(&self, key: &K) -> Option<Arc<V>> {
        let mut guard = self.cache.lock().unwrap();
        guard.get(key)
    }

    /// Removes and returns the cached value for `key`, if any.
    pub fn pop(&self, key: &K) -> Option<Arc<V>> {
        let mut guard = self.cache.lock().unwrap();
        guard.pop(key)
    }

    /// Offers `value` for `key` and returns the value the cache settled on.
    pub fn insert(&self, key: &K, value: V) -> Arc<V> {
        let mut guard = self.cache.lock().unwrap();
        guard.insert(key, value)
    }

    /// Returns the cached value for `key`, or computes one with `f` and
    /// offers it to the cache.
    ///
    /// The lookup and the insert lock the mutex separately, so `f` runs
    /// without the lock held.
    pub fn get_or_insert<F: FnOnce() -> V>(&self, key: &K, f: F) -> Arc<V> {
        match self.get(key) {
            Some(hit) => hit,
            None => self.insert(key, f()),
        }
    }

    /// Returns the wrapped cache's `size()`.
    pub fn size(&self) -> usize {
        let guard = self.cache.lock().unwrap();
        guard.size()
    }

    /// Returns the wrapped cache's `mem_len()`.
    pub fn mem_len(&self) -> usize {
        let guard = self.cache.lock().unwrap();
        guard.mem_len()
    }

    /// Forwards `new_len` to the wrapped cache's `set_mem_len`.
    pub fn set_mem_len(&self, new_len: usize) {
        let mut guard = self.cache.lock().unwrap();
        guard.set_mem_len(new_len)
    }

    /// Forwards to the wrapped cache's `clear_cache`.
    pub fn clear_cache(&self) {
        let mut guard = self.cache.lock().unwrap();
        guard.clear_cache()
    }

    /// Returns `(hits, misses)`, both read under a single lock acquisition.
    pub fn hits_misses(&self) -> (u64, u64) {
        let guard = self.cache.lock().unwrap();
        let hits = guard.hits();
        let misses = guard.misses();
        (hits, misses)
    }

    /// Forwards to the wrapped cache's `reset_metrics`.
    pub fn reset_metrics(&self) {
        let mut guard = self.cache.lock().unwrap();
        guard.reset_metrics()
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use rand::prelude::*;

    /// Walks a single key through the cache: the first lookup and insert do
    /// not store anything, the second insert does, and the third lookup hits.
    #[test]
    fn fetch_test() {
        let (key, val) = (0, String::from("0"));
        let cache = DynamicCache::new(8);
        assert_eq!(cache.size(), 0);
        assert_eq!(cache.mem_len(), 8);
        assert_eq!(cache.hits_misses(), (0, 0));

        assert!(
            cache.get(&key).is_none(),
            "First `get` should have nothing in cache"
        );
        assert_eq!(cache.size(), 0);
        assert_eq!(cache.hits_misses(), (0, 1));

        // The key has only been looked up once, so the value is not stored.
        assert!(
            cache.insert(&key, val.clone()).as_ref() == &val,
            "Insert should return right value"
        );
        assert_eq!(cache.size(), 0);
        assert_eq!(cache.hits_misses(), (0, 1));

        assert!(
            cache.get(&key).is_none(),
            "Second `get` should still have nothing in cache"
        );
        assert_eq!(cache.size(), 0);
        assert_eq!(cache.hits_misses(), (0, 2));

        // After the second lookup the value is stored.
        assert!(
            cache.insert(&key, val.clone()).as_ref() == &val,
            "Insert should return right value"
        );
        assert_eq!(cache.size(), 1);
        assert_eq!(cache.hits_misses(), (0, 2));

        assert!(
            cache.get(&key).map_or(false, |x| x.as_ref() == &val),
            "Third `get` should have a value in cache"
        );
        assert_eq!(cache.size(), 1);
        assert_eq!(cache.hits_misses(), (1, 2));
        assert_eq!(cache.mem_len(), 8);
    }

    /// Looks `key` up, inserting its decimal representation on a miss, and
    /// checks that the value handed back matches the key.
    ///
    /// Returns `true` if the lookup missed.
    fn fetch_or_insert(cache: &DynamicCache<u16, String>, key: u16) -> bool {
        let val = format!("{}", key);
        let (cache_val, missed) = match cache.get(&key) {
            Some(v) => (v, false),
            None => (cache.insert(&key, val.clone()), true),
        };
        assert_eq!(val.as_str(), cache_val.as_str());
        missed
    }

    /// Percentage of `requests` that did not miss; 0 when there were none.
    fn hit_rate(requests: u32, misses: u32) -> f64 {
        if requests == 0 {
            return 0.0;
        }
        100.0 * f64::from(requests - misses) / f64::from(requests)
    }

    /// Drives the cache with a fixed warm-up sequence and then several random
    /// key distributions, checking every returned value and printing the
    /// observed hit rates.
    #[test]
    fn stress_test() {
        let sample_size: u32 = 1 << 12;
        let cache = DynamicCache::new(128);
        let mut rng = thread_rng();

        // Fixed warm-up sequence.
        let seq: Vec<u16> = vec![
            0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2,
        ];
        for key in seq {
            println!("Write {}", key);
            if fetch_or_insert(&cache, key) {
                println!("Miss");
            } else {
                println!("Hit");
            }
        }
        println!("Cache size: {}", cache.size());

        // Uniform keys over shrinking ranges.
        for i in (3..=9).rev() {
            let mut misses = 0;
            for _ in 0..sample_size {
                let key: u16 = rng.gen_range(0, 1 << i);
                if fetch_or_insert(&cache, key) {
                    misses += 1;
                }
            }
            println!(
                "With range of (0..{:3}), Cache size: {:3}, hit rate = {:4.1}%",
                (1 << i),
                cache.size(),
                hit_rate(sample_size, misses)
            );
        }

        // Uniform keys over the whole `u16` range.
        let mut misses = 0;
        for _ in 0..sample_size {
            let key: u16 = rng.gen();
            if fetch_or_insert(&cache, key) {
                misses += 1;
            }
        }
        println!(
            "With range of full u16, Cache size: {:3}, hit rate = {:4.1}%",
            cache.size(),
            hit_rate(sample_size, misses)
        );

        // Half the requests hit a small weighted set of "main" keys, the
        // other half are uniform noise.
        let weights: Vec<u32> = vec![16, 8, 4, 2, 1];
        let dist = rand::distributions::WeightedIndex::new(&weights).unwrap();
        let mut main_requests = 0;
        let mut main_misses = 0;
        for _ in 0..sample_size {
            let is_main = rng.gen_bool(0.5);
            let key: u16 = if is_main {
                dist.sample(&mut rng) as u16
            } else {
                rng.gen()
            };
            let missed = fetch_or_insert(&cache, key);
            if is_main {
                main_requests += 1;
                if missed {
                    main_misses += 1;
                }
            }
        }
        // Measured over main-key requests only, so noise requests do not
        // inflate the figure.
        println!("Random u16 with log2 frequent requests, Cache size: {:3}, hit rate for main data = {:4.1}%", cache.size(), hit_rate(main_requests, main_misses));
    }
}