use crate::{shard::VersionedCacheShard, DefaultHashBuilder};
use parking_lot::RwLock;
use std::{
borrow::Borrow,
hash::{BuildHasher, Hash, Hasher},
};
/// A concurrent cache addressed by a `(key, version)` pair.
///
/// Entries are distributed over a power-of-two number of independently locked
/// shards; an entry's shard is chosen from the low bits of the combined hash of
/// its key and version.
pub struct VersionedCache<Key, Ver, Val, B = DefaultHashBuilder> {
    /// The shards, each behind its own reader-writer lock. The constructor always
    /// creates a power-of-two number of them.
    // The fully spelled-out nested generic type trips clippy's `type_complexity` lint.
    #[allow(clippy::type_complexity)]
    shards: Box<[RwLock<VersionedCacheShard<Key, Ver, Val, B>>]>,
    /// `shards.len() - 1`; AND-ed with a hash to obtain a shard index.
    shards_mask: usize,
    /// Builds the hasher used to route a `(key, version)` pair to its shard.
    hash_builder: B,
}
impl<Key: Eq + Hash, Ver: Eq + Hash, Val: Clone> VersionedCache<Key, Ver, Val, DefaultHashBuilder> {
    /// Creates a cache that uses a default-constructed [`DefaultHashBuilder`].
    ///
    /// See [`VersionedCache::with_hasher`] for how the capacities are split
    /// across shards.
    ///
    /// # Panics
    ///
    /// Panics if `initial_capacity > max_capacity`.
    pub fn new(initial_capacity: usize, max_capacity: usize) -> Self {
        let hasher = DefaultHashBuilder::default();
        Self::with_hasher(initial_capacity, max_capacity, hasher)
    }
}
impl<Key: Eq + Hash, Ver: Eq + Hash, Val: Clone, B: BuildHasher + Clone>
    VersionedCache<Key, Ver, Val, B>
{
    /// Creates a cache whose shard routing uses `hasher`.
    ///
    /// The shard count starts at twice [`std::thread::available_parallelism`]
    /// (or 2 if that cannot be determined), is capped at `max_capacity`, and is
    /// rounded up to a power of two. It is then halved while a shard would be
    /// limited to fewer than 32 entries, so small caches use fewer shards (down
    /// to one). `initial_capacity` and `max_capacity` are each divided across the
    /// shards, rounding up. Every shard receives a clone of `hasher`.
    ///
    /// # Panics
    ///
    /// Panics if `initial_capacity > max_capacity`.
    pub fn with_hasher(initial_capacity: usize, max_capacity: usize, hasher: B) -> Self {
        assert!(initial_capacity <= max_capacity);
        let mut num_shards = std::thread::available_parallelism()
            .map_or(2, |n| n.get() * 2)
            .min(max_capacity)
            .next_power_of_two();
        // Ceiling division (saturating near `usize::MAX`) so that integer division
        // does not shave capacity off the total.
        let mut shard_max_capacity = max_capacity.saturating_add(num_shards - 1) / num_shards;
        // Halving keeps `num_shards` a power of two, which `shards_mask` relies on.
        while shard_max_capacity < 32 && num_shards > 1 {
            num_shards /= 2;
            shard_max_capacity = max_capacity.saturating_add(num_shards - 1) / num_shards;
        }
        let shard_initial_capacity = initial_capacity.saturating_add(num_shards - 1) / num_shards;
        let shards = (0..num_shards)
            .map(|_| {
                RwLock::new(VersionedCacheShard::new(
                    shard_initial_capacity,
                    shard_max_capacity,
                    hasher.clone(),
                ))
            })
            .collect::<Vec<_>>();
        Self {
            shards: shards.into_boxed_slice(),
            hash_builder: hasher,
            // `num_shards` is a power of two, so masking a hash with this always
            // yields an index in `0..num_shards`.
            shards_mask: num_shards - 1,
        }
    }

    /// Returns `true` if no shard holds any entry.
    ///
    /// Shards are read-locked one at a time, so under concurrent modification
    /// the answer reflects each shard at a slightly different moment.
    pub fn is_empty(&self) -> bool {
        // The cache is empty only if *every* shard is empty; one empty shard says
        // nothing about the others.
        self.shards.iter().all(|s| s.read().len() == 0)
    }

    /// Returns the total number of entries across all shards.
    ///
    /// Each shard is read-locked in turn, so under concurrent writes the sum is
    /// not an atomic snapshot.
    pub fn len(&self) -> usize {
        self.shards.iter().map(|s| s.read().len()).sum()
    }

    /// Returns the sum of every shard's reported capacity.
    pub fn capacity(&self) -> usize {
        self.shards.iter().map(|s| s.read().capacity()).sum()
    }

    /// Returns the sum of the miss counters reported by each shard.
    pub fn misses(&self) -> u64 {
        self.shards.iter().map(|s| s.read().misses()).sum()
    }

    /// Returns the sum of the hit counters reported by each shard.
    pub fn hits(&self) -> u64 {
        self.shards.iter().map(|s| s.read().hits()).sum()
    }

    /// Hashes `key` followed by `version` and returns the owning shard together
    /// with that hash, so the shard does not need to rehash.
    ///
    /// Because `shards_mask == shards.len() - 1`, the lookup always succeeds; the
    /// `Option` only avoids an indexing panic path.
    // The returned tuple type spells out the nested lock/shard generics.
    #[allow(clippy::type_complexity)]
    #[inline]
    fn shard_for<Q: ?Sized, W: ?Sized>(
        &self,
        key: &Q,
        version: &W,
    ) -> Option<(&RwLock<VersionedCacheShard<Key, Ver, Val, B>>, u64)>
    where
        Key: Borrow<Q>,
        Q: Hash + Eq,
        Ver: Borrow<W>,
        W: Hash + Eq,
    {
        let mut hasher = self.hash_builder.build_hasher();
        key.hash(&mut hasher);
        version.hash(&mut hasher);
        let hash = hasher.finish();
        self.shards
            .get(hash as usize & self.shards_mask)
            .map(|s| (s, hash))
    }

    /// Looks up `(key, version)` via the shard's `get` under a read lock and
    /// returns a clone of the value if present.
    pub fn get<Q: ?Sized, W: ?Sized>(&self, key: &Q, version: &W) -> Option<Val>
    where
        Key: Borrow<Q>,
        Q: Hash + Eq,
        Ver: Borrow<W>,
        W: Hash + Eq,
    {
        let (shard, hash) = self.shard_for(key, version)?;
        shard.read().get(hash, key, version).cloned()
    }

    /// Looks up `(key, version)` via the shard's `peek` under a read lock and
    /// returns a clone of the value if present.
    ///
    /// NOTE(review): presumably `peek` differs from `get` by not recording a hit
    /// or touching eviction state — confirm against `VersionedCacheShard::peek`.
    pub fn peek<Q: ?Sized, W: ?Sized>(&self, key: &Q, version: &W) -> Option<Val>
    where
        Key: Borrow<Q>,
        Q: Hash + Eq,
        Ver: Borrow<W>,
        W: Hash + Eq,
    {
        let (shard, hash) = self.shard_for(key, version)?;
        shard.read().peek(hash, key, version).cloned()
    }

    /// Removes `(key, version)` under the shard's write lock.
    ///
    /// Returns `true` only when the shard's `remove` yields `Some(Ok(_))`.
    /// NOTE(review): the meaning of an `Err` result is defined by
    /// `VersionedCacheShard::remove` — confirm there.
    pub fn remove<Q: ?Sized, W: ?Sized>(&self, key: &Q, version: &W) -> bool
    where
        Key: Borrow<Q>,
        Q: Hash + Eq,
        Ver: Borrow<W>,
        W: Hash + Eq,
    {
        if let Some((shard, hash)) = self.shard_for(key, version) {
            let evicted = shard.write().remove(hash, key, version);
            matches!(evicted, Some(Ok(_)))
        } else {
            false
        }
    }

    /// Inserts `value` under `(key, version)` while holding the shard's write
    /// lock. Whatever the shard's `insert` returns is discarded.
    pub fn insert(&self, key: Key, version: Ver, value: Val) {
        if let Some((shard, hash)) = self.shard_for(&key, &version) {
            let _evicted = shard.write().insert(hash, key, version, value);
        }
    }
}
impl<Key: Eq + Hash, Ver: Eq + Hash, Val: Clone> std::fmt::Debug for VersionedCache<Key, Ver, Val> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("VersionedCache").finish_non_exhaustive()
}
}
/// A concurrent cache addressed by `Key` alone.
///
/// A thin wrapper over [`VersionedCache`] that stores every entry under the
/// unit version `()`.
pub struct Cache<Key, Val, B = DefaultHashBuilder>(
    /// The underlying versioned cache; every entry uses version `()`.
    VersionedCache<Key, (), Val, B>,
);
impl<Key: Eq + Hash, Val: Clone> Cache<Key, Val, DefaultHashBuilder> {
    /// Creates a cache that uses a default-constructed [`DefaultHashBuilder`].
    ///
    /// # Panics
    ///
    /// Panics if `initial_capacity > max_capacity` (checked by
    /// [`VersionedCache::with_hasher`]).
    pub fn new(initial_capacity: usize, max_capacity: usize) -> Self {
        let inner = VersionedCache::new(initial_capacity, max_capacity);
        Self(inner)
    }
}
impl<Key: Eq + Hash, Val: Clone, B: Clone + BuildHasher> Cache<Key, Val, B> {
pub fn with_hasher(initial_capacity: usize, max_capacity: usize, hash_builder: B) -> Self {
Self(VersionedCache::with_hasher(
initial_capacity,
max_capacity,
hash_builder,
))
}
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
pub fn len(&self) -> usize {
self.0.len()
}
pub fn capacity(&self) -> usize {
self.0.capacity()
}
pub fn misses(&self) -> u64 {
self.0.misses()
}
pub fn hits(&self) -> u64 {
self.0.hits()
}
pub fn get<Q: ?Sized>(&self, key: &Q) -> Option<Val>
where
Key: Borrow<Q>,
Q: Eq + Hash,
{
self.0.get(key, &())
}
pub fn peek<Q: ?Sized>(&self, key: &Q) -> Option<Val>
where
Key: Borrow<Q>,
Q: Eq + Hash,
{
self.0.peek(key, &())
}
pub fn remove<Q: ?Sized>(&self, key: &Q) -> bool
where
Key: Borrow<Q>,
Q: Eq + Hash,
{
self.0.remove(key, &())
}
pub fn insert(&self, key: Key, value: Val) {
self.0.insert(key, (), value);
}
}
impl<Key: Eq + Hash, Val: Clone> std::fmt::Debug for Cache<Key, Val> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Cache").finish_non_exhaustive()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        sync::{Arc, Barrier},
        thread,
    };

    /// Hammers one shared `Cache` with concurrent writers and readers.
    ///
    /// Each writer/reader pair owns a disjoint key range. Writers repeatedly
    /// store `i -> i`, and readers assert that any value they find for `i` is
    /// exactly `i`. The cache is sized for a tenth of all keys, so readers are
    /// expected to miss often. A barrier releases all threads at once to
    /// maximise contention.
    #[test]
    fn test_multiple_threads() {
        const N_THREAD_PAIRS: usize = 8;
        const N_ROUNDS: usize = 1_000;
        const ITEMS_PER_THREAD: usize = 1_000;

        let barrier = Arc::new(Barrier::new(N_THREAD_PAIRS * 2));
        let cache = Arc::new(Cache::new(0, N_THREAD_PAIRS * ITEMS_PER_THREAD / 10));

        // Every thread must be spawned before any is joined, otherwise the
        // barrier would never be released.
        let mut handles = Vec::with_capacity(N_THREAD_PAIRS * 2);
        for pair in 0..N_THREAD_PAIRS {
            let first = ITEMS_PER_THREAD * pair;
            let keys = first..first + ITEMS_PER_THREAD;

            let writer = {
                let barrier = Arc::clone(&barrier);
                let cache = Arc::clone(&cache);
                let keys = keys.clone();
                thread::spawn(move || {
                    barrier.wait();
                    for _ in 0..N_ROUNDS {
                        keys.clone().for_each(|i| cache.insert(i, i));
                    }
                })
            };
            handles.push(writer);

            let reader = {
                let barrier = Arc::clone(&barrier);
                let cache = Arc::clone(&cache);
                thread::spawn(move || {
                    barrier.wait();
                    for _ in 0..N_ROUNDS {
                        for i in keys.clone() {
                            if let Some(found) = cache.get(&i) {
                                assert_eq!(found, i);
                            }
                        }
                    }
                })
            };
            handles.push(reader);
        }

        handles
            .into_iter()
            .for_each(|handle| handle.join().unwrap());
    }
}