use crate::{shard::KQCacheShard, DefaultHashBuilder, UnitWeighter, Weighter};
use parking_lot::RwLock;
use std::{
borrow::Borrow,
hash::{BuildHasher, Hash, Hasher},
};
/// A concurrent, sharded cache keyed by a `(Key, Qey)` pair.
///
/// `We` computes per-entry weights (defaults to [`UnitWeighter`], one unit per
/// entry) and `B` builds the hashers used both to select a shard and to probe
/// within it (defaults to [`DefaultHashBuilder`]).
pub struct KQCache<Key, Qey, Val, We = UnitWeighter, B = DefaultHashBuilder> {
    // Builds the hasher that folds (key, qey) into a single u64.
    hash_builder: B,
    #[allow(clippy::type_complexity)]
    // Independently locked shards; the slice length is always a power of two.
    shards: Box<[RwLock<KQCacheShard<Key, Qey, Val, We, B>>]>,
    // `shards.len() - 1`; usable as a bit mask because the length is a power of two.
    shards_mask: u64,
}
impl<Key, Qey, Val> KQCache<Key, Qey, Val, UnitWeighter, DefaultHashBuilder>
where
    Key: Eq + Hash,
    Qey: Eq + Hash,
    Val: Clone,
{
    /// Creates a cache that holds at most `items_capacity` items, counting
    /// every entry as exactly one unit of weight.
    pub fn new(items_capacity: usize) -> Self {
        // With the unit weighter, the weight budget equals the item count.
        let weight_capacity = items_capacity as u64;
        Self::with(
            items_capacity,
            weight_capacity,
            UnitWeighter,
            DefaultHashBuilder::default(),
        )
    }
}
impl<Key, Qey, Val, We> KQCache<Key, Qey, Val, We, DefaultHashBuilder>
where
    Key: Eq + Hash,
    Qey: Eq + Hash,
    Val: Clone,
    We: Weighter<Key, Qey, Val> + Clone,
{
    /// Creates a cache with a custom weighter and a total weight budget,
    /// using the default hash builder.
    pub fn with_weighter(
        estimated_items_capacity: usize,
        weight_capacity: u64,
        weighter: We,
    ) -> KQCache<Key, Qey, Val, We, DefaultHashBuilder> {
        Self::with(
            estimated_items_capacity,
            weight_capacity,
            weighter,
            DefaultHashBuilder::default(),
        )
    }
}
impl<Key, Qey, Val, We, B> KQCache<Key, Qey, Val, We, B>
where
    Key: Eq + Hash,
    Qey: Eq + Hash,
    Val: Clone,
    We: Weighter<Key, Qey, Val> + Clone,
    B: BuildHasher + Clone,
{
    /// Builds a cache with an estimated item capacity, a total weight budget,
    /// a custom weighter, and a custom hash builder.
    ///
    /// Both capacity and weight budget are split evenly (ceiling division)
    /// across a power-of-two number of shards derived from the available
    /// parallelism; the shard count is halved until each shard can hold at
    /// least 32 items (or only one shard remains).
    pub fn with(
        estimated_items_capacity: usize,
        weight_capacity: u64,
        weighter: We,
        hasher: B,
    ) -> Self {
        // Aim for 4 shards per hardware thread (fallback: 4 total), never more
        // shards than items, rounded up to a power of two so a bit mask can
        // select a shard. Note 0.next_power_of_two() == 1, so num_shards >= 1.
        let mut num_shards = std::thread::available_parallelism()
            .map_or(4, |n| n.get() * 4)
            .min(estimated_items_capacity)
            .next_power_of_two() as u64;
        let estimated_items_capacity = estimated_items_capacity as u64;
        // Ceiling division: each shard gets at least its fair share.
        let mut shard_items_capacity =
            estimated_items_capacity.saturating_add(num_shards - 1) / num_shards;
        let mut shard_max_weight = weight_capacity.saturating_add(num_shards - 1) / num_shards;
        // Tiny shards evict pathologically; coalesce until each holds >= 32 items.
        while shard_items_capacity < 32 && num_shards > 1 {
            num_shards /= 2;
            shard_items_capacity =
                estimated_items_capacity.saturating_add(num_shards - 1) / num_shards;
            shard_max_weight = weight_capacity.saturating_add(num_shards - 1) / num_shards;
        }
        let shards = (0..num_shards)
            .map(|_| {
                RwLock::new(KQCacheShard::new(
                    shard_items_capacity as usize,
                    shard_max_weight,
                    weighter.clone(),
                    hasher.clone(),
                ))
            })
            .collect::<Vec<_>>();
        Self {
            shards: shards.into_boxed_slice(),
            hash_builder: hasher,
            // Valid as a mask because num_shards is a power of two.
            shards_mask: num_shards - 1,
        }
    }

    /// Cross-checks internal invariants of every shard (fuzzing builds only).
    #[cfg(fuzzing)]
    pub fn validate(&self) {
        for s in &*self.shards {
            s.read().validate()
        }
    }

    /// Returns `true` if the cache holds no items.
    ///
    /// Bug fix: a sharded cache is empty only when *all* shards are empty.
    /// The previous implementation used `any(...)`, which reported the whole
    /// cache as empty as soon as a single shard happened to be.
    pub fn is_empty(&self) -> bool {
        self.shards.iter().all(|s| s.read().len() == 0)
    }

    /// Total number of cached items across all shards.
    pub fn len(&self) -> usize {
        self.shards.iter().map(|s| s.read().len()).sum()
    }

    /// Total weight of cached items across all shards.
    pub fn weight(&self) -> u64 {
        self.shards.iter().map(|s| s.read().weight()).sum()
    }

    /// Total weight capacity (sum of the per-shard budgets, which may round
    /// slightly above the requested `weight_capacity`).
    pub fn capacity(&self) -> u64 {
        self.shards.iter().map(|s| s.read().capacity()).sum()
    }

    /// Total number of cache misses recorded across all shards.
    pub fn misses(&self) -> u64 {
        self.shards.iter().map(|s| s.read().misses()).sum()
    }

    /// Total number of cache hits recorded across all shards.
    pub fn hits(&self) -> u64 {
        self.shards.iter().map(|s| s.read().hits()).sum()
    }

    /// Hashes `(key, qey)` once and returns the shard responsible for that
    /// hash along with the hash itself, so callers don't hash twice.
    /// Returns `None` only if the masked index is somehow out of range.
    #[allow(clippy::type_complexity)]
    #[inline]
    fn shard_for<Q: ?Sized, W: ?Sized>(
        &self,
        key: &Q,
        qey: &W,
    ) -> Option<(&RwLock<KQCacheShard<Key, Qey, Val, We, B>>, u64)>
    where
        Key: Borrow<Q>,
        Q: Hash + Eq,
        Qey: Borrow<W>,
        W: Hash + Eq,
    {
        // Key is hashed before qey; shards must probe with the same order.
        let mut hasher = self.hash_builder.build_hasher();
        key.hash(&mut hasher);
        qey.hash(&mut hasher);
        let hash = hasher.finish();
        self.shards
            .get((hash & self.shards_mask) as usize)
            .map(|s| (s, hash))
    }

    /// Reserves room for `additional` items, spread evenly over the shards
    /// (ceiling division). Takes `&mut self`, so no concurrent access exists,
    /// but each shard is still locked for its own `reserve`.
    pub fn reserve(&mut self, additional: usize) {
        let additional_per_shard =
            additional.saturating_add(self.shards.len() - 1) / self.shards.len();
        for s in &*self.shards {
            s.write().reserve(additional_per_shard);
        }
    }

    /// Fetches a clone of the value for `(key, qey)`, promoting the entry per
    /// the shard's replacement policy. Returns `None` on a miss.
    pub fn get<Q: ?Sized, W: ?Sized>(&self, key: &Q, qey: &W) -> Option<Val>
    where
        Key: Borrow<Q>,
        Q: Hash + Eq,
        Qey: Borrow<W>,
        W: Hash + Eq,
    {
        let (shard, hash) = self.shard_for(key, qey)?;
        shard.read().get(hash, key, qey).cloned()
    }

    /// Like [`Self::get`] but without affecting the entry's recency/statistics.
    pub fn peek<Q: ?Sized, W: ?Sized>(&self, key: &Q, qey: &W) -> Option<Val>
    where
        Key: Borrow<Q>,
        Q: Hash + Eq,
        Qey: Borrow<W>,
        W: Hash + Eq,
    {
        let (shard, hash) = self.shard_for(key, qey)?;
        shard.read().peek(hash, key, qey).cloned()
    }

    /// Removes the entry for `(key, qey)`. Returns `true` only when a resident
    /// entry was actually evicted (the shard signals that with `Some(Ok(_))`).
    pub fn remove<Q: ?Sized, W: ?Sized>(&self, key: &Q, qey: &W) -> bool
    where
        Key: Borrow<Q>,
        Q: Hash + Eq,
        Qey: Borrow<W>,
        W: Hash + Eq,
    {
        if let Some((shard, hash)) = self.shard_for(key, qey) {
            let evicted = shard.write().remove(hash, key, qey);
            matches!(evicted, Some(Ok(_)))
        } else {
            false
        }
    }

    /// Inserts `value` under `(key, qey)`, possibly evicting other entries to
    /// stay within the shard's weight budget. Evicted entries are dropped.
    pub fn insert(&self, key: Key, qey: Qey, value: Val) {
        if let Some((shard, hash)) = self.shard_for(&key, &qey) {
            let _evicted = shard.write().insert(hash, key, qey, value);
        }
    }
}
impl<Key: Eq + Hash, Qey: Eq + Hash, Val: Clone> std::fmt::Debug for KQCache<Key, Qey, Val> {
    /// Opaque debug representation; cache contents are deliberately not dumped.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        out.debug_struct("KQCache").finish_non_exhaustive()
    }
}
/// A sharded cache keyed by a single `Key`: a thin newtype over [`KQCache`]
/// with the qey fixed to `()`.
pub struct Cache<Key, Val, We = UnitWeighter, B = DefaultHashBuilder>(KQCache<Key, (), Val, We, B>);
impl<Key: Eq + Hash, Val: Clone> Cache<Key, Val, UnitWeighter, DefaultHashBuilder> {
    /// Creates a cache for up to `items_capacity` items, one weight unit each.
    pub fn new(items_capacity: usize) -> Self {
        let inner = KQCache::new(items_capacity);
        Self(inner)
    }
}
impl<Key, Val, We> Cache<Key, Val, We, DefaultHashBuilder>
where
    Key: Eq + Hash,
    Val: Clone,
    We: Weighter<Key, (), Val> + Clone,
{
    /// Creates a cache with a custom weighter and a total weight budget,
    /// using the default hash builder.
    pub fn with_weighter(
        estimated_items_capacity: usize,
        weight_capacity: u64,
        weighter: We,
    ) -> Cache<Key, Val, We, DefaultHashBuilder> {
        Self::with(
            estimated_items_capacity,
            weight_capacity,
            weighter,
            DefaultHashBuilder::default(),
        )
    }
}
impl<Key, Val, We, B> Cache<Key, Val, We, B>
where
    Key: Eq + Hash,
    Val: Clone,
    We: Weighter<Key, (), Val> + Clone,
    B: BuildHasher + Clone,
{
    /// Builds a cache with an estimated item capacity, a total weight budget,
    /// a custom weighter, and a custom hash builder.
    pub fn with(
        estimated_items_capacity: usize,
        weight_capacity: u64,
        weighter: We,
        hash_builder: B,
    ) -> Self {
        let inner = KQCache::with(
            estimated_items_capacity,
            weight_capacity,
            weighter,
            hash_builder,
        );
        Self(inner)
    }

    /// Returns `true` if the cache holds no items.
    pub fn is_empty(&self) -> bool {
        let Self(inner) = self;
        inner.is_empty()
    }

    /// Number of cached items.
    pub fn len(&self) -> usize {
        let Self(inner) = self;
        inner.len()
    }

    /// Total weight of cached items.
    pub fn weight(&self) -> u64 {
        let Self(inner) = self;
        inner.weight()
    }

    /// Total weight capacity.
    pub fn capacity(&self) -> u64 {
        let Self(inner) = self;
        inner.capacity()
    }

    /// Number of recorded cache misses.
    pub fn misses(&self) -> u64 {
        let Self(inner) = self;
        inner.misses()
    }

    /// Number of recorded cache hits.
    pub fn hits(&self) -> u64 {
        let Self(inner) = self;
        inner.hits()
    }

    /// Reserves room for `additional` items.
    pub fn reserve(&mut self, additional: usize) {
        self.0.reserve(additional)
    }

    /// Fetches a clone of the value for `key`, updating recency on a hit.
    pub fn get<Q: ?Sized>(&self, key: &Q) -> Option<Val>
    where
        Key: Borrow<Q>,
        Q: Eq + Hash,
    {
        self.0.get(key, &())
    }

    /// Like [`Self::get`] but without touching recency or statistics.
    pub fn peek<Q: ?Sized>(&self, key: &Q) -> Option<Val>
    where
        Key: Borrow<Q>,
        Q: Eq + Hash,
    {
        self.0.peek(key, &())
    }

    /// Removes the entry for `key`; returns `true` if something was removed.
    pub fn remove<Q: ?Sized>(&self, key: &Q) -> bool
    where
        Key: Borrow<Q>,
        Q: Eq + Hash,
    {
        self.0.remove(key, &())
    }

    /// Inserts `value` under `key`, possibly evicting other entries.
    pub fn insert(&self, key: Key, value: Val) {
        self.0.insert(key, (), value);
    }
}
impl<Key: Eq + Hash, Val: Clone> std::fmt::Debug for Cache<Key, Val> {
    /// Opaque debug representation; cache contents are deliberately not dumped.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        out.debug_struct("Cache").finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        sync::{Arc, Barrier},
        thread,
    };

    /// Hammers one undersized cache from N writer threads and N reader
    /// threads at once; any value read back for a key must equal that key.
    #[test]
    fn test_multiple_threads() {
        const N_THREAD_PAIRS: usize = 8;
        const N_ROUNDS: usize = 1_000;
        const ITEMS_PER_THREAD: usize = 1_000;
        // All 2N threads start together once everyone reaches the barrier.
        let barrier = Arc::new(Barrier::new(N_THREAD_PAIRS * 2));
        // Deliberately undersized (10%) so evictions happen constantly.
        let cache = Arc::new(Cache::new(N_THREAD_PAIRS * ITEMS_PER_THREAD / 10));
        let mut handles = Vec::with_capacity(N_THREAD_PAIRS * 2);
        // Writers: each repeatedly inserts its own disjoint key range.
        for t in 0..N_THREAD_PAIRS {
            let (barrier, cache) = (barrier.clone(), cache.clone());
            handles.push(thread::spawn(move || {
                let first = ITEMS_PER_THREAD * t;
                barrier.wait();
                for _ in 0..N_ROUNDS {
                    for key in first..first + ITEMS_PER_THREAD {
                        cache.insert(key, key);
                    }
                }
            }));
        }
        // Readers: probe the same ranges concurrently; misses are fine,
        // but every hit must return the matching value.
        for t in 0..N_THREAD_PAIRS {
            let (barrier, cache) = (barrier.clone(), cache.clone());
            handles.push(thread::spawn(move || {
                let first = ITEMS_PER_THREAD * t;
                barrier.wait();
                for _ in 0..N_ROUNDS {
                    for key in first..first + ITEMS_PER_THREAD {
                        if let Some(found) = cache.get(&key) {
                            assert_eq!(found, key);
                        }
                    }
                }
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }
    }
}