use std::{
collections::hash_map::RandomState,
fmt,
hash::{self, BuildHasher, Hash},
};
use hash::Hasher;
use crate::unsharded;
/// Handle to a value stored in a sharded [`LruCache`].
///
/// Thin wrapper around the handle of the `unsharded` shard that owns the
/// entry; the cached value is read through [`CacheHandle::value`].
/// NOTE(review): presumably holding the handle pins the entry in its shard —
/// confirm against `unsharded::CacheHandle`.
pub struct CacheHandle<'a, K, V>(unsharded::CacheHandle<'a, K, V>)
where
    K: Hash + Eq + Clone;
impl<'a, K, V> CacheHandle<'a, K, V>
where
K: Hash + Eq + Clone,
{
pub fn value(&self) -> &V {
self.0.value()
}
}
/// Sharded LRU cache: keys are distributed across several independent
/// `unsharded::LruCache` shards by their hash, computed with the
/// `BuildHasher` `S` (`RandomState` by default).
pub struct LruCache<K, V, S = RandomState> {
    // One independent LRU per shard; a given key always maps to the same
    // shard because `hasher` is fixed for the lifetime of the cache.
    shards: Vec<unsharded::LruCache<K, V>>,
    // Builds the hasher used to pick the shard for each key.
    hasher: S,
}
impl<K, V> fmt::Debug for LruCache<K, V>
where
    K: fmt::Debug,
    V: fmt::Debug,
{
    /// Renders the cache as a debug list of its per-shard representations.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut list = f.debug_list();
        for shard in &self.shards {
            list.entry(shard);
        }
        list.finish()
    }
}
// SAFETY: the cache owns `K` and `V` values inside its shards, so it may only
// be sent to another thread when both are `Send`, and only be shared across
// threads when both are `Send + Sync` (shared access can drop evicted values
// on any thread and hand out `&V` references). The previous unconditional
// `unsafe impl<K, V>` claimed Send/Sync for *all* `K`/`V`, which is unsound
// for non-thread-safe types.
unsafe impl<K: Send, V: Send> Send for LruCache<K, V> {}
unsafe impl<K: Send + Sync, V: Send + Sync> Sync for LruCache<K, V> {}
/// Shard count used by [`LruCache::new`].
///
/// Sixteen shards keeps contention low for typical thread counts while
/// wasting little capacity to per-shard rounding.
fn default_shards() -> usize {
    16
}
impl<K, V, S> LruCache<K, V, S>
where
    K: Send + Sync + Hash + Eq + Clone,
    V: Send + Sync,
    S: BuildHasher,
{
    /// Builds a cache of `shards` independent LRU shards whose combined
    /// capacity is at least `capacity`, hashing keys with `hasher` to pick a
    /// shard.
    ///
    /// # Panics
    ///
    /// Panics if `shards` is zero. (Previously this surfaced as an obscure
    /// divide-by-zero here, or a modulo-by-zero on the first lookup.)
    pub fn with_shards_hasher(capacity: u64, shards: usize, hasher: S) -> Self {
        assert!(shards > 0, "shard count must be non-zero");
        let shards = shards as u64;
        // Ceiling division so the shards together hold at least `capacity`.
        // Written as div + remainder check to avoid the overflow that
        // `(capacity + shards - 1)` hits when `capacity` is near `u64::MAX`.
        let cap_per_shard = capacity / shards + u64::from(capacity % shards != 0);
        Self {
            hasher,
            shards: (0..shards)
                .map(|_| unsharded::LruCache::new(cap_per_shard))
                .collect(),
        }
    }
}
impl<K, V> LruCache<K, V, RandomState>
where
    K: Send + Sync + Hash + Eq + Clone,
    V: Send + Sync,
{
    /// Creates a cache with the default shard count and a randomly seeded
    /// `RandomState` hasher, holding at most `capacity` total charge.
    pub fn new(capacity: u64) -> Self {
        let hasher = RandomState::default();
        let shards = default_shards();
        Self::with_shards_hasher(capacity, shards, hasher)
    }
}
impl<K, V, S> LruCache<K, V, S>
where
K: Hash + Eq + Clone,
S: BuildHasher,
{
fn shard(&self, key: &K) -> &unsharded::LruCache<K, V> {
let mut hasher = self.hasher.build_hasher();
key.hash(&mut hasher);
let h = hasher.finish() as usize;
let shard_idx = h % self.shards.len();
&self.shards[shard_idx]
}
pub fn advice_evict(&self, key: K) {
self.shard(&key).advice_evict(key)
}
pub fn prune(&self) {
for s in &self.shards {
s.prune();
}
}
pub fn total_charge(&self) -> u64 {
self.shards.iter().map(|s| s.total_charge()).sum()
}
pub fn get(&self, key: K) -> Option<CacheHandle<'_, K, V>> {
self.shard(&key).get(key).map(|h| CacheHandle(h))
}
pub fn get_or_try_init<E>(
&self,
key: K,
charge: u64,
init: impl FnOnce(&K) -> Result<V, E>,
) -> Result<CacheHandle<'_, K, V>, E> {
self.shard(&key)
.get_or_try_init(key, charge, init)
.map(|h| CacheHandle(h))
}
pub fn get_or_init(
&self,
key: K,
charge: u64,
init: impl FnOnce(&K) -> V,
) -> CacheHandle<'_, K, V> {
CacheHandle(self.shard(&key).get_or_init(key, charge, init))
}
}
/// Compile-time checks that the crate's public types are `Send + Sync`.
mod compile_time_assertions {
    use super::*;

    // Monomorphizing `_assert_send_sync` for each public type forces the
    // compiler to prove the auto-trait bounds; nothing here ever runs, so no
    // `unreachable!()` value (and no `#[allow(unreachable_code)]`) is needed.
    // The previous unused `_assert_send` helper has been dropped.
    fn _assert_public_types_send_sync() {
        _assert_send_sync::<LruCache<u32, u32>>();
        _assert_send_sync::<CacheHandle<'static, u32, u32>>();
    }

    fn _assert_send_sync<T: Send + Sync>() {}
}
#[cfg(test)]
mod tests {
    use crate::override_lifetime;
    use super::*;
    use rand::{distributions::Uniform, prelude::*};
    use std::{
        sync::atomic::{AtomicU64, Ordering},
        thread,
    };

    /// Hammers the cache from several threads and then checks that every
    /// unit of charge that was initialized is either still cached or has
    /// been dropped — i.e. nothing leaks and nothing double-drops.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn sharded_stress() {
        // Adds its charge to `counter` when dropped so the test can account
        // for evicted entries.
        struct IncCounterOnDrop<'a> {
            charge: u64,
            counter: &'a AtomicU64,
        }
        impl<'a> Drop for IncCounterOnDrop<'a> {
            fn drop(&mut self) {
                self.counter.fetch_add(self.charge, Ordering::Relaxed);
            }
        }

        let capacity = 128;
        let threads = 8;
        let per_thread_count = 10000;
        let yield_interval = 1000;
        // Total charge ever initialized vs. total charge ever dropped.
        let init_charge = AtomicU64::new(0);
        let drop_charge = AtomicU64::new(0);
        let lru = LruCache::new(capacity);
        let mut handles = vec![];
        for _ in 0..threads {
            handles.push(thread::spawn({
                // SAFETY(test): all spawned threads are joined below, before
                // `lru` and the counters go out of scope.
                let lru = unsafe { override_lifetime(&lru) };
                let init_counter = unsafe { override_lifetime(&init_charge) };
                let drop_counter = unsafe { override_lifetime(&drop_charge) };
                move || {
                    let mut rng = StdRng::from_entropy();
                    for iteration in 0..per_thread_count {
                        let i = rng.sample(Uniform::new(0, 100));
                        let charge = rng.sample(Uniform::new(1, 5));
                        let fail = rng.sample(Uniform::new(0, 10)) >= 8;
                        let res = lru.get_or_try_init(i, charge, |_| {
                            if fail {
                                Err(())
                            } else {
                                init_counter.fetch_add(charge, Ordering::Relaxed);
                                Ok(IncCounterOnDrop {
                                    charge,
                                    counter: &drop_counter,
                                })
                            }
                        });
                        if !fail {
                            assert!(res.is_ok());
                        }
                        // FIX: yield every `yield_interval` iterations.
                        // Previously this tested the random key `i` (always
                        // < 100) against 1000, so it only yielded on the
                        // ~1% of iterations where `i == 0`.
                        if iteration % yield_interval == 0 {
                            thread::yield_now();
                        }
                    }
                }
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
        assert!(lru.total_charge() <= capacity);
        // Everything initialized is either still resident or already dropped.
        assert_eq!(
            init_charge.load(Ordering::Relaxed),
            lru.total_charge() + drop_charge.load(Ordering::Relaxed)
        );
        lru.prune();
        // After pruning, every initialized entry must have been dropped.
        assert_eq!(
            init_charge.load(Ordering::Relaxed),
            drop_charge.load(Ordering::Relaxed)
        );
    }
}