use crate::{
linked_slab::Token,
options::*,
shard::{self, CacheShard, InsertStrategy},
DefaultHashBuilder, Equivalent, Lifecycle, MemoryUsed, UnitWeighter, Weighter,
};
use std::hash::{BuildHasher, Hash};
/// A single-shard, unsynchronized cache: every operation that mutates it takes
/// `&mut self`, and all entries live in one [`CacheShard`].
///
/// Type parameters:
/// * `Key`, `Val` — the key and value types.
/// * `We` — the [`Weighter`] used to weigh entries (defaults to [`UnitWeighter`]).
/// * `B` — the [`BuildHasher`] used to hash keys (defaults to [`DefaultHashBuilder`]).
/// * `L` — the [`Lifecycle`] hooks (defaults to [`DefaultLifecycle`]).
#[derive(Clone)]
pub struct Cache<
    Key,
    Val,
    We = UnitWeighter,
    B = DefaultHashBuilder,
    L = DefaultLifecycle<Key, Val>,
> {
    /// The only shard; it owns the entries, the weighter, the hasher and the lifecycle.
    shard: CacheShard<Key, Val, We, B, L, SharedPlaceholder>,
}
impl<Key: Eq + Hash, Val> Cache<Key, Val> {
pub fn new(items_capacity: usize) -> Self {
Self::with(
items_capacity,
items_capacity as u64,
Default::default(),
Default::default(),
Default::default(),
)
}
}
impl<Key: Eq + Hash, Val, We: Weighter<Key, Val>> Cache<Key, Val, We> {
pub fn with_weighter(
estimated_items_capacity: usize,
weight_capacity: u64,
weighter: We,
) -> Self {
Self::with(
estimated_items_capacity,
weight_capacity,
weighter,
Default::default(),
Default::default(),
)
}
}
impl<Key: Eq + Hash, Val, We: Weighter<Key, Val>, B: BuildHasher, L: Lifecycle<Key, Val>>
Cache<Key, Val, We, B, L>
{
pub fn with(
estimated_items_capacity: usize,
weight_capacity: u64,
weighter: We,
hash_builder: B,
lifecycle: L,
) -> Self {
Self::with_options(
OptionsBuilder::new()
.estimated_items_capacity(estimated_items_capacity)
.weight_capacity(weight_capacity)
.build()
.unwrap(),
weighter,
hash_builder,
lifecycle,
)
}
pub fn with_options(options: Options, weighter: We, hash_builder: B, lifecycle: L) -> Self {
let shard = CacheShard::new(
options.hot_allocation,
options.ghost_allocation,
options.estimated_items_capacity,
options.weight_capacity,
weighter,
hash_builder,
lifecycle,
);
Self { shard }
}
pub fn is_empty(&self) -> bool {
self.shard.len() == 0
}
pub fn len(&self) -> usize {
self.shard.len()
}
pub fn weight(&self) -> u64 {
self.shard.weight()
}
pub fn capacity(&self) -> u64 {
self.shard.capacity()
}
#[cfg(feature = "stats")]
pub fn misses(&self) -> u64 {
self.shard.misses()
}
#[cfg(feature = "stats")]
pub fn hits(&self) -> u64 {
self.shard.hits()
}
pub fn reserve(&mut self, additional: usize) {
self.shard.reserve(additional);
}
pub fn contains_key<Q>(&self, key: &Q) -> bool
where
Q: Hash + Equivalent<Key> + ?Sized,
{
self.shard.contains(self.shard.hash(key), key)
}
pub fn get<Q>(&self, key: &Q) -> Option<&Val>
where
Q: Hash + Equivalent<Key> + ?Sized,
{
self.shard.get(self.shard.hash(key), key)
}
pub fn get_mut<Q>(&mut self, key: &Q) -> Option<RefMut<'_, Key, Val, We, B, L>>
where
Q: Hash + Equivalent<Key> + ?Sized,
{
self.shard.get_mut(self.shard.hash(key), key).map(RefMut)
}
pub fn peek<Q>(&self, key: &Q) -> Option<&Val>
where
Q: Hash + Equivalent<Key> + ?Sized,
{
self.shard.peek(self.shard.hash(key), key)
}
pub fn peek_mut<Q>(&mut self, key: &Q) -> Option<RefMut<'_, Key, Val, We, B, L>>
where
Q: Hash + Equivalent<Key> + ?Sized,
{
self.shard.peek_mut(self.shard.hash(key), key).map(RefMut)
}
pub fn remove<Q>(&mut self, key: &Q) -> Option<(Key, Val)>
where
Q: Hash + Equivalent<Key> + ?Sized,
{
self.shard.remove(self.shard.hash(key), key)
}
pub fn remove_if<Q, F>(&mut self, key: &Q, f: F) -> Option<(Key, Val)>
where
Q: Hash + Equivalent<Key> + ?Sized,
F: FnOnce(&Val) -> bool,
{
self.shard.remove_if(self.shard.hash(key), key, f)
}
pub fn replace(&mut self, key: Key, value: Val, soft: bool) -> Result<(), (Key, Val)> {
let lcs = self.replace_with_lifecycle(key, value, soft)?;
self.shard.lifecycle.end_request(lcs);
Ok(())
}
pub fn replace_with_lifecycle(
&mut self,
key: Key,
value: Val,
soft: bool,
) -> Result<L::RequestState, (Key, Val)> {
let mut lcs = self.shard.lifecycle.begin_request();
self.shard.insert(
&mut lcs,
self.shard.hash(&key),
key,
value,
InsertStrategy::Replace { soft },
)?;
Ok(lcs)
}
pub fn retain<F>(&mut self, f: F)
where
F: Fn(&Key, &Val) -> bool,
{
self.shard.retain(f);
}
pub fn get_or_insert_with<Q, E>(
&mut self,
key: &Q,
with: impl FnOnce() -> Result<Val, E>,
) -> Result<Option<&Val>, E>
where
Q: Hash + Equivalent<Key> + ToOwned<Owned = Key> + ?Sized,
{
let idx = match self.shard.get_or_placeholder(self.shard.hash(key), key) {
Ok((idx, _)) => idx,
Err((plh, _)) => {
let v = with()?;
let mut lcs = self.shard.lifecycle.begin_request();
let replaced = self.shard.replace_placeholder(&mut lcs, &plh, false, v);
self.shard.lifecycle.end_request(lcs);
debug_assert!(replaced.is_ok(), "unsync replace_placeholder can't fail");
plh.idx
}
};
Ok(self.shard.peek_token(idx))
}
pub fn get_mut_or_insert_with<'a, Q, E>(
&'a mut self,
key: &Q,
with: impl FnOnce() -> Result<Val, E>,
) -> Result<Option<RefMut<'a, Key, Val, We, B, L>>, E>
where
Q: Hash + Equivalent<Key> + ToOwned<Owned = Key> + ?Sized,
{
let idx = match self.shard.get_or_placeholder(self.shard.hash(key), key) {
Ok((idx, _)) => idx,
Err((plh, _)) => {
let v = with()?;
let mut lcs = self.shard.lifecycle.begin_request();
let replaced = self.shard.replace_placeholder(&mut lcs, &plh, false, v);
debug_assert!(replaced.is_ok(), "unsync replace_placeholder can't fail");
self.shard.lifecycle.end_request(lcs);
plh.idx
}
};
Ok(self.shard.peek_token_mut(idx).map(RefMut))
}
pub fn get_ref_or_guard<Q>(&mut self, key: &Q) -> Result<&Val, Guard<'_, Key, Val, We, B, L>>
where
Q: Hash + Equivalent<Key> + ToOwned<Owned = Key> + ?Sized,
{
match self.shard.get_or_placeholder(self.shard.hash(key), key) {
Ok((_, v)) => unsafe {
let v: *const Val = v;
Ok(&*v)
},
Err((placeholder, _)) => Err(Guard {
cache: self,
placeholder,
inserted: false,
}),
}
}
pub fn get_mut_or_guard<'a, Q>(
&'a mut self,
key: &Q,
) -> Result<Option<RefMut<'a, Key, Val, We, B, L>>, Guard<'a, Key, Val, We, B, L>>
where
Q: Hash + Equivalent<Key> + ToOwned<Owned = Key> + ?Sized,
{
match self.shard.get_or_placeholder(self.shard.hash(key), key) {
Ok((idx, _)) => Ok(self.shard.peek_token_mut(idx).map(RefMut)),
Err((placeholder, _)) => Err(Guard {
cache: self,
placeholder,
inserted: false,
}),
}
}
pub fn insert(&mut self, key: Key, value: Val) {
let lcs = self.insert_with_lifecycle(key, value);
self.shard.lifecycle.end_request(lcs);
}
pub fn insert_with_lifecycle(&mut self, key: Key, value: Val) -> L::RequestState {
let mut lcs = self.shard.lifecycle.begin_request();
let result = self.shard.insert(
&mut lcs,
self.shard.hash(&key),
key,
value,
InsertStrategy::Insert,
);
debug_assert!(result.is_ok());
lcs
}
pub fn clear(&mut self) {
self.shard.clear();
}
pub fn iter(&self) -> impl Iterator<Item = (&'_ Key, &'_ Val)> + '_ {
self.shard.iter()
}
pub fn drain(&mut self) -> impl Iterator<Item = (Key, Val)> + '_ {
self.shard.drain()
}
pub fn set_capacity(&mut self, new_weight_capacity: u64) {
self.shard.set_capacity(new_weight_capacity);
}
#[cfg(any(fuzzing, test))]
pub fn validate(&self, accept_overweight: bool) {
self.shard.validate(accept_overweight);
}
pub fn memory_used(&self) -> MemoryUsed {
self.shard.memory_used()
}
}
impl<Key, Val, We, B, L> std::fmt::Debug for Cache<Key, Val, We, B, L> {
    /// Formats as `Cache { .. }`.
    ///
    /// No entries are printed, so no `Debug` bounds are needed on the type parameters.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Cache");
        builder.finish_non_exhaustive()
    }
}
/// The no-op [`Lifecycle`] used when none is specified.
///
/// Its request state is `()`, and `on_evict` simply drops the evicted pair. The type
/// holds no data, so `Debug`, `Default` and `Clone` are written by hand rather than
/// derived. This avoids requiring those traits on `Key` and `Val`.
pub struct DefaultLifecycle<Key, Val>(std::marker::PhantomData<(Key, Val)>);

impl<Key, Val> std::fmt::Debug for DefaultLifecycle<Key, Val> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same output as a field-less `debug_tuple`: just the type name.
        f.write_str("DefaultLifecycle")
    }
}

impl<Key, Val> Default for DefaultLifecycle<Key, Val> {
    #[inline]
    fn default() -> Self {
        DefaultLifecycle(std::marker::PhantomData)
    }
}

impl<Key, Val> Clone for DefaultLifecycle<Key, Val> {
    #[inline]
    fn clone(&self) -> Self {
        DefaultLifecycle(std::marker::PhantomData)
    }
}

impl<Key, Val> Lifecycle<Key, Val> for DefaultLifecycle<Key, Val> {
    type RequestState = ();

    #[inline]
    fn begin_request(&self) -> Self::RequestState {}

    #[inline]
    fn on_evict(&self, _state: &mut Self::RequestState, _key: Key, _val: Val) {}
}
/// Placeholder handle used by the unsynchronized cache for a key whose value is
/// being loaded.
#[derive(Debug, Clone)]
pub(crate) struct SharedPlaceholder {
    /// Hash of the key the placeholder was reserved for.
    hash: u64,
    /// Slab token of the placeholder entry inside the shard.
    idx: Token,
}
/// Reservation for a key that was missing from the cache, returned by
/// [`Cache::get_ref_or_guard`] and [`Cache::get_mut_or_guard`].
///
/// Call [`Guard::insert`] (or [`Guard::insert_with_lifecycle`]) to store a value in
/// the reserved slot. If the guard is dropped without inserting, its `Drop`
/// implementation removes the placeholder from the cache.
pub struct Guard<'a, Key, Val, We, B, L> {
    /// Exclusive borrow of the cache that holds the placeholder.
    cache: &'a mut Cache<Key, Val, We, B, L>,
    /// The placeholder entry reserved for the missing key.
    placeholder: SharedPlaceholder,
    /// Set once a value has been stored, so `Drop` leaves the entry alone.
    inserted: bool,
}
impl<Key: Eq + Hash, Val, We: Weighter<Key, Val>, B: BuildHasher, L: Lifecycle<Key, Val>>
    Guard<'_, Key, Val, We, B, L>
{
    /// Stores `value` in the reserved placeholder and ends the lifecycle request.
    pub fn insert(mut self, value: Val) {
        let lcs = self.insert_internal(value);
        self.cache.shard.lifecycle.end_request(lcs);
    }

    /// Stores `value` in the reserved placeholder and returns the lifecycle request
    /// state without ending it. The caller is responsible for passing it to
    /// `end_request`.
    pub fn insert_with_lifecycle(mut self, value: Val) -> L::RequestState {
        self.insert_internal(value)
    }

    /// Replaces the placeholder with `value` and marks the guard as used, so that
    /// `Drop` does not remove the entry. Returns the request state that was begun
    /// for this insertion.
    #[inline]
    fn insert_internal(&mut self, value: Val) -> L::RequestState {
        let mut lcs = self.cache.shard.lifecycle.begin_request();
        let replaced = self
            .cache
            .shard
            .replace_placeholder(&mut lcs, &self.placeholder, false, value);
        debug_assert!(replaced.is_ok(), "unsync replace_placeholder can't fail");
        self.inserted = true;
        lcs
    }
}
impl<Key, Val, We, B, L> Drop for Guard<'_, Key, Val, We, B, L> {
    /// Removes the reserved placeholder unless a value was inserted through the guard.
    #[inline]
    fn drop(&mut self) {
        // Kept out of line and marked cold: the common path is that the guard was
        // filled, in which case there is nothing to do.
        #[cold]
        fn release_placeholder<Key, Val, We, B, L>(guard: &mut Guard<'_, Key, Val, We, B, L>) {
            guard.cache.shard.remove_placeholder(&guard.placeholder);
        }

        if self.inserted {
            return;
        }
        release_placeholder(self);
    }
}
/// Mutable handle to a value stored in a [`Cache`], returned by
/// [`Cache::get_mut`] and related methods.
///
/// Dereferences to the value. Edits made through it are reflected in
/// [`Cache::weight`].
pub struct RefMut<'cache, Key, Val, We: Weighter<Key, Val>, B, L>(
    crate::shard::RefMut<'cache, Key, Val, We, B, L, SharedPlaceholder>,
);

impl<Key, Val, We: Weighter<Key, Val>, B, L> std::ops::Deref for RefMut<'_, Key, Val, We, B, L> {
    type Target = Val;

    #[inline]
    fn deref(&self) -> &Self::Target {
        // `pair()` exposes the key and the value; only the value is the target.
        let pair = self.0.pair();
        pair.1
    }
}

impl<Key, Val, We: Weighter<Key, Val>, B, L> std::ops::DerefMut for RefMut<'_, Key, Val, We, B, L> {
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        let inner = &mut self.0;
        inner.value_mut()
    }
}
impl shard::SharedPlaceholder for SharedPlaceholder {
#[inline]
fn new(hash: u64, idx: Token) -> Self {
Self { hash, idx }
}
#[inline]
fn same_as(&self, _other: &Self) -> bool {
true
}
#[inline]
fn hash(&self) -> u64 {
self.hash
}
#[inline]
fn idx(&self) -> Token {
self.idx
}
}
#[cfg(test)]
mod tests {
use super::*;
// Test-local weighter; inside this module it shadows the crate's `Weighter` trait
// name. An entry weighs exactly its value, so `(k, 0)` entries weigh nothing.
struct Weighter;
impl crate::Weighter<u32, u32> for Weighter {
fn weight(&self, _key: &u32, val: &u32) -> u64 {
*val as u64
}
}
/// A zero-weight entry survives insert pressure. Raising its weight through
/// `get_mut` is accounted in `weight()` and makes it evictable. Lowering the weight
/// of a re-inserted entry back to zero lets it survive the same pressure again.
#[test]
fn test_zero_weights() {
let mut cache = Cache::with_weighter(100, 100, Weighter);
cache.insert(0, 0);
assert_eq!(cache.weight(), 0);
// Apply pressure: each key in 1..100 is inserted twice.
for i in 1..100 {
cache.insert(i, i);
cache.insert(i, i);
}
assert_eq!(cache.get(&0).copied(), Some(0));
assert!(cache.contains_key(&0));
// Mutating through `get_mut` must be reflected in the cache weight.
let a = cache.weight();
*cache.get_mut(&0).unwrap() += 1;
assert_eq!(cache.weight(), a + 1);
for i in 1..100 {
cache.insert(i, i);
cache.insert(i, i);
}
// With weight 1, entry 0 no longer survives the pressure.
assert_eq!(cache.get(&0), None);
assert!(!cache.contains_key(&0));
cache.insert(0, 1);
let a = cache.weight();
*cache.get_mut(&0).unwrap() -= 1;
assert_eq!(cache.weight(), a - 1);
for i in 1..100 {
cache.insert(i, i);
cache.insert(i, i);
}
// Back at weight 0, entry 0 survives again.
assert_eq!(cache.get(&0).copied(), Some(0));
assert!(cache.contains_key(&0));
}
/// Shrinking the capacity evicts down to the new limit. Growing it updates
/// `capacity()` and admits more items while staying within the weight limit.
#[test]
fn test_set_capacity() {
let mut cache = Cache::new(100);
for i in 0..80 {
cache.insert(i, i);
}
let initial_len = cache.len();
assert!(initial_len <= 80);
cache.set_capacity(50);
assert!(cache.len() <= 50);
assert!(cache.weight() <= 50);
cache.validate(false);
cache.set_capacity(200);
assert_eq!(cache.capacity(), 200);
cache.validate(false);
for i in 100..180 {
cache.insert(i, i);
}
assert!(cache.len() <= 180);
assert!(cache.weight() <= 200);
cache.validate(false);
}
/// Resizes a cache that has already evicted items: 100 inserts into capacity 50.
/// The test name suggests ghost entries exist at this point. Invariants are
/// validated after each step.
#[test]
fn test_set_capacity_with_ghosts() {
let mut cache = Cache::new(50);
for i in 0..100 {
cache.insert(i, i);
}
cache.validate(false);
cache.set_capacity(25);
assert!(cache.weight() <= 25);
cache.validate(false);
cache.set_capacity(100);
assert_eq!(cache.capacity(), 100);
cache.validate(false);
for i in 100..150 {
cache.insert(i, i);
}
cache.validate(false);
}
/// `remove_if` removes and returns the pair only when the predicate holds. It
/// leaves the item in place when the predicate is false, and returns `None` for
/// missing keys.
#[test]
fn test_remove_if() {
let mut cache = Cache::new(100);
cache.insert(1, 10);
cache.insert(2, 20);
cache.insert(3, 30);
let removed = cache.remove_if(&2, |v| *v == 20);
assert_eq!(removed, Some((2, 20)));
assert_eq!(cache.get(&2), None);
let not_removed = cache.remove_if(&3, |v| *v == 999);
assert_eq!(not_removed, None);
assert_eq!(cache.get(&3), Some(&30));
let not_found = cache.remove_if(&999, |_| true);
assert_eq!(not_found, None);
cache.validate(false);
}
}