use crate::Stats;
use parking_lot::{Mutex, RwLock};
use shard::Shard;
use std::borrow::Borrow;
use std::hash::{BuildHasher, Hash};
use std::num::NonZero;
use std::time::Instant;
use std::{cmp, thread};
mod entry;
mod fixed_size_hash_table;
mod ring_buffer;
mod shard;
pub(crate) mod stats;
/// Hasher builder that `Cache` falls back to when no explicit `S` type parameter is given.
/// Alias for `ahash::RandomState`.
pub(crate) type RandomState = ahash::RandomState;
/// A cache whose entries are spread over several independently locked shards.
///
/// `K` is the key type, `V` the value type and `S` the hasher builder
/// (defaulting to `RandomState`).
#[derive(Debug)]
pub struct Cache<K, V, S = RandomState> {
/// Hashes keys to decide which shard owns them; a clone is also handed to every shard.
hash_builder: S,
/// The shards, each guarded by its own read/write lock.
shards: Vec<RwLock<Shard<K, V, S>>>,
/// Set at construction and overwritten by every `stats()` call; used to report
/// the time elapsed between statistics snapshots.
metrics_last_accessed: Mutex<Instant>,
}
impl<K, V> Cache<K, V, RandomState>
where
K: Clone + Eq + Hash,
V: Clone,
{
pub fn with_capacity(capacity: usize) -> Cache<K, V, RandomState> {
Cache::with_capacity_and_hasher(capacity, Default::default())
}
}
impl<K, V, S> Cache<K, V, S>
where
K: Clone + Eq + Hash,
V: Clone,
S: BuildHasher,
{
pub fn insert(&self, key: K, value: V) -> Option<V> {
let hash = self.hash_builder.hash_one(&key);
let shard_lock = self.get_shard(hash)?;
let mut shard = shard_lock.write();
shard.insert(key, value)
}
pub fn get<Q>(&self, key: &Q) -> Option<V>
where
K: Borrow<Q>,
Q: ?Sized + Hash + Eq,
{
let hash = self.hash_builder.hash_one(key);
let shard_lock = self.get_shard(hash)?;
let shard = shard_lock.read();
shard.get(key)
}
fn get_shard(&self, hash: u64) -> Option<&RwLock<Shard<K, V, S>>> {
let shard_idx = hash as usize % (cmp::max(self.shards.len(), 2) - 1);
self.shards.get(shard_idx)
}
}
impl<K, V, S> Cache<K, V, S>
where
K: Clone + Eq + Hash,
V: Clone,
S: Clone + BuildHasher,
{
pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Cache<K, V, S> {
let available_parallelism = thread::available_parallelism()
.map(NonZero::get)
.unwrap_or(1);
let number_of_shards = cmp::min(available_parallelism * 4, capacity);
let mut shards = Vec::with_capacity(number_of_shards);
let metrics_last_accessed = Mutex::new(Instant::now());
if number_of_shards == 0 {
return Self {
hash_builder,
shards,
metrics_last_accessed,
};
}
let capacity_per_shard = capacity.div_ceil(number_of_shards);
for _ in 0..number_of_shards {
let shard = Shard::with_capacity_and_hasher(capacity_per_shard, hash_builder.clone());
shards.push(RwLock::new(shard))
}
Self {
hash_builder,
shards,
metrics_last_accessed,
}
}
}
impl<K, V, S> Cache<K, V, S> {
    /// Returns the statistics gathered since the previous call and starts a
    /// new measurement period.
    ///
    /// `millis_elapsed` is the time since the cache was created or since
    /// `stats` was last called, whichever is more recent. Hit, miss and
    /// eviction counts are summed over all shards; each shard's counters are
    /// reset right after they have been read.
    pub fn stats(&self) -> Stats {
        let mut totals = Stats::default();

        // Keep the timestamp lock only for as long as it takes to read and
        // restart the clock.
        totals.millis_elapsed = {
            let mut last_accessed = self.metrics_last_accessed.lock();
            let elapsed = last_accessed.elapsed().as_millis();
            *last_accessed = Instant::now();
            elapsed
        };

        for guard in self.shards.iter().map(|lock| lock.read()) {
            totals.hit_count += guard.hit_count();
            totals.miss_count += guard.miss_count();
            totals.eviction_count += guard.eviction_count();
            guard.reset_counters();
        }

        totals
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// A stored value can be read back and an unknown key yields `None`.
    #[test]
    fn it_inserts_and_gets_basic_values() {
        let cache = Cache::with_capacity(100);
        cache.insert("key1", "value1");
        assert_eq!(cache.get("key1"), Some("value1"));
        assert_eq!(cache.get("key2"), None);
    }

    /// Re-inserting a key returns the old value and replaces it.
    #[test]
    fn it_updates_existing_value() {
        let cache = Cache::with_capacity(100);
        cache.insert("key1", "value1");
        let old_value = cache.insert("key1", "new_value");
        assert_eq!(old_value, Some("value1"));
        assert_eq!(cache.get("key1"), Some("new_value"));
    }

    /// A zero-capacity cache accepts inserts but never stores anything.
    #[test]
    fn it_handles_zero_capacity() {
        let cache = Cache::with_capacity(0);
        cache.insert("key1", "value1");
        assert_eq!(cache.get("key1"), None);
    }

    /// The smallest non-empty cache still stores and returns a value.
    #[test]
    fn it_handles_one_capacity() {
        let cache = Cache::with_capacity(1);
        cache.insert("key1", "value1");
        assert_eq!(cache.get("key1"), Some("value1"));
        assert_eq!(cache.get("key2"), None);
    }

    /// The cache works with a hasher other than the default one.
    #[test]
    fn it_works_with_custom_hasher() {
        use std::collections::hash_map::RandomState;
        let cache = Cache::with_capacity_and_hasher(100, RandomState::new());
        cache.insert("key1", "value1");
        assert_eq!(cache.get("key1"), Some("value1"));
    }

    /// Several threads can insert and read through a shared cache, and the
    /// values remain visible after the threads have finished.
    #[test]
    fn it_is_thread_safe() {
        let cache: Arc<Cache<String, String>> = Arc::new(Cache::with_capacity(1_000));

        let handles: Vec<_> = (0..5)
            .map(|i| {
                let cache_clone = Arc::clone(&cache);
                let key = format!("key{}", i);
                let value = format!("value{}", i);
                thread::spawn(move || {
                    cache_clone.insert(key.clone(), value.clone());
                    assert_eq!(cache_clone.get(&key), Some(value));
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        for i in 0..5 {
            let key = format!("key{}", i);
            let value = format!("value{}", i);
            assert_eq!(cache.get(&key), Some(value));
        }
    }

    /// Inserting more entries than fit evicts the oldest one.
    #[test]
    fn it_respects_capacity_limits() {
        // A capacity of one yields exactly one shard, created with capacity
        // one, so every key competes for the same shard. Whether "key1" gets
        // evicted therefore does not depend on how the randomly seeded hasher
        // spreads keys over shards.
        let cache = Cache::with_capacity(1);
        cache.insert("key1", "value1");
        cache.insert("key2", "value2");
        cache.insert("key3", "value3");
        cache.insert("key4", "value4");
        assert_eq!(cache.get("key1"), None);
    }

    /// `stats` reports hits and misses and resets the counters afterwards.
    #[test]
    fn it_returns_and_resets_stats() {
        let cache = Cache::with_capacity(1_000);
        for i in 0..10 {
            cache.insert(i, i);
        }
        // Five lookups of stored keys ...
        for i in 0..5 {
            cache.get(&i);
        }
        // ... and five lookups of keys that were never inserted.
        for i in 10..15 {
            cache.get(&i);
        }

        let stats = cache.stats();
        assert_eq!(stats.hit_count, 5);
        assert_eq!(stats.miss_count, 5);

        let stats = cache.stats();
        assert_eq!(stats.hit_count, 0);
        assert_eq!(stats.miss_count, 0);
    }
}