use alloc::{boxed::Box, sync::Arc};
use core::{
any::TypeId,
fmt,
marker::PhantomData,
sync::atomic::{AtomicU64, Ordering},
};
/// A borrowed, type-erased reference that remembers the type it was created from.
///
/// Holds a thin pointer plus the referent's [`TypeId`]. [`AnyRef::downcast_ref`] hands the
/// reference back only when asked for a type whose id matches. The `'a` lifetime keeps the
/// handle from outliving the borrow it was built from.
#[derive(Clone, Copy)]
pub struct AnyRef<'a> {
    /// Address of the referent with its concrete type erased.
    ptr: *const (),
    /// `typeid::of::<T>()` for the referent type `T` passed to `new`.
    type_id: TypeId,
    /// Ties the handle to the `'a` borrow; no other field mentions `'a`.
    _marker: PhantomData<&'a ()>,
}

impl<'a> AnyRef<'a> {
    /// Erases the type of `value`, recording its type id for later downcasts.
    #[inline]
    pub(crate) fn new<T>(value: &'a T) -> Self {
        let type_id = typeid::of::<T>();
        let ptr = (value as *const T).cast::<()>();
        Self {
            ptr,
            type_id,
            _marker: PhantomData,
        }
    }

    /// Returns the raw, type-erased address of the referent.
    #[inline]
    pub fn as_ptr(&self) -> *const () {
        self.ptr
    }

    /// Returns the type id recorded when this handle was created.
    #[inline]
    pub fn type_id(&self) -> TypeId {
        self.type_id
    }

    /// Returns the referent as `&T` if `T`'s type id matches the recorded one, else `None`.
    #[inline]
    pub fn downcast_ref<T>(&self) -> Option<&'a T> {
        if self.type_id != typeid::of::<T>() {
            return None;
        }
        // SAFETY: `ptr` was produced from a `&'a U` in `new`. The id check above means `U` and
        // `T` are treated as the same type, so the pointer is non-null, aligned and valid for
        // reads of a `T` for `'a`.
        // NOTE(review): `typeid::of` supports non-`'static` types by erasing lifetimes. If so,
        // this check cannot distinguish types that differ only in lifetime parameters (e.g.
        // `&'x str` vs `&'static str`), and a caller could obtain a longer lifetime than the
        // original borrow. Confirm against the `typeid` docs.
        let typed = self.ptr.cast::<T>();
        Some(unsafe { &*typed })
    }
}
// Only the recorded type id is printed. The erased pointer is omitted and the output is marked
// non-exhaustive (trailing `..`).
impl fmt::Debug for AnyRef<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("AnyRef");
        builder.field("type_id", &self.type_id);
        builder.finish_non_exhaustive()
    }
}
/// Observer for cache events.
///
/// Every hook has a do-nothing default, so an implementation only overrides the events it cares
/// about. Handlers must be `Send + Sync`.
pub trait StatsHandler<K, V>: Send + Sync {
    /// Invoked on a cache hit for `key`, which currently maps to `value`.
    fn on_hit(&self, key: &K, value: &V) {
        let _ = (key, value);
    }

    /// Invoked on a cache miss.
    ///
    /// The lookup key is type-erased because it need not be a `K`: lookups may use a borrowed
    /// form of the key. Recover it with [`AnyRef::downcast_ref`].
    fn on_miss(&self, key: AnyRef<'_>) {
        let _ = key;
    }

    /// Deprecated collision hook: receives the incoming key and the existing entry.
    ///
    /// Superseded by [`StatsHandler::on_insert`]. An insertion whose evicted key differs from
    /// the inserted key is a collision.
    #[deprecated(note = "use on_insert and check that evicted key doesn't equal to new key")]
    fn on_collision(&self, new_key: AnyRef<'_>, existing_key: &K, existing_value: &V) {
        let _ = (new_key, existing_key, existing_value);
    }

    /// Invoked when `key` is stored with `value`.
    ///
    /// `evicted` is the entry displaced by this insertion, if any. If its key equals `key`, the
    /// insertion overwrote the value of an existing entry rather than adding a new one.
    fn on_insert(&self, key: &K, value: &V, evicted: Option<(&K, &V)>) {
        let _ = (key, value, evicted);
    }

    /// Invoked when the entry `key` -> `value` is removed.
    fn on_remove(&self, key: &K, value: &V) {
        let _ = (key, value);
    }
}
/// A [`StatsHandler`] that tallies events in atomic counters.
///
/// Counters can be read at any time and cleared with `reset`. To share one with a cache, wrap it
/// in an `Arc`, hand one clone to the cache, and keep another to read the counts.
#[derive(Debug)]
pub struct CountingStatsHandler {
    /// Lookups that found an entry.
    hits: AtomicU64,
    /// Lookups that found nothing.
    misses: AtomicU64,
    /// Insertions that did not overwrite an entry with an equal key (includes collisions).
    inserts: AtomicU64,
    /// Insertions that overwrote the value of an entry with an equal key.
    updates: AtomicU64,
    /// Entries removed.
    removes: AtomicU64,
    /// Insertions that evicted an entry with a different key.
    collisions: AtomicU64,
}
impl CountingStatsHandler {
    /// Creates a handler with all six counters at zero.
    ///
    /// This is a `const fn`, so the handler can also be placed in a `static`.
    pub const fn new() -> Self {
        Self {
            hits: AtomicU64::new(0),
            misses: AtomicU64::new(0),
            inserts: AtomicU64::new(0),
            updates: AtomicU64::new(0),
            removes: AtomicU64::new(0),
            collisions: AtomicU64::new(0),
        }
    }

    /// Number of hits recorded since creation or the last `reset`.
    pub fn hits(&self) -> u64 {
        Self::load_counter(&self.hits)
    }

    /// Number of misses recorded since creation or the last `reset`.
    pub fn misses(&self) -> u64 {
        Self::load_counter(&self.misses)
    }

    /// Number of inserts recorded since creation or the last `reset`.
    pub fn inserts(&self) -> u64 {
        Self::load_counter(&self.inserts)
    }

    /// Number of updates recorded since creation or the last `reset`.
    pub fn updates(&self) -> u64 {
        Self::load_counter(&self.updates)
    }

    /// Number of removals recorded since creation or the last `reset`.
    pub fn removes(&self) -> u64 {
        Self::load_counter(&self.removes)
    }

    /// Number of collisions recorded since creation or the last `reset`.
    pub fn collisions(&self) -> u64 {
        Self::load_counter(&self.collisions)
    }

    /// Sets every counter back to zero.
    ///
    /// The counters are cleared one after another, not as a single atomic step. Events recorded
    /// concurrently may therefore survive in some counters and not in others.
    pub fn reset(&self) {
        let counters = [
            &self.hits,
            &self.misses,
            &self.inserts,
            &self.updates,
            &self.removes,
            &self.collisions,
        ];
        for counter in counters {
            counter.store(0, Ordering::Relaxed);
        }
    }

    /// Reads one counter.
    ///
    /// Relaxed ordering is used because each counter is an independent statistic that is never
    /// used to synchronize other memory.
    #[inline]
    fn load_counter(counter: &AtomicU64) -> u64 {
        counter.load(Ordering::Relaxed)
    }
}
impl Default for CountingStatsHandler {
fn default() -> Self {
Self::new()
}
}
// `K: PartialEq` is needed so that `on_insert` can tell an overwrite of the same key apart from
// a collision with a different key.
impl<K: PartialEq, V> StatsHandler<K, V> for CountingStatsHandler {
    fn on_hit(&self, _key: &K, _value: &V) {
        self.hits.fetch_add(1, Ordering::Relaxed);
    }

    fn on_miss(&self, _key: AnyRef<'_>) {
        self.misses.fetch_add(1, Ordering::Relaxed);
    }

    /// Classifies an insertion:
    /// - evicted key equal to `key`: an overwrite, counted only as an update;
    /// - evicted key different from `key`: counted as an insert and as a collision;
    /// - nothing evicted: counted as an insert.
    fn on_insert(&self, key: &K, _value: &V, evicted: Option<(&K, &V)>) {
        let overwrote_same_key = matches!(evicted, Some((old_key, _)) if old_key == key);
        if overwrote_same_key {
            self.updates.fetch_add(1, Ordering::Relaxed);
            return;
        }
        self.inserts.fetch_add(1, Ordering::Relaxed);
        if evicted.is_some() {
            // A different key was displaced to make room.
            self.collisions.fetch_add(1, Ordering::Relaxed);
        }
    }

    fn on_remove(&self, _key: &K, _value: &V) {
        self.removes.fetch_add(1, Ordering::Relaxed);
    }
}
/// Forwards every hook to the shared handler.
///
/// This lets an `Arc<T>` be passed to [`Stats::new`] while the caller keeps another clone of the
/// `Arc` to inspect `T`, for example to read a [`CountingStatsHandler`]'s counters. `T` may be
/// unsized, so `Arc<dyn StatsHandler<K, V>>` is also a handler.
impl<K, V, T: StatsHandler<K, V> + ?Sized> StatsHandler<K, V> for Arc<T> {
    #[inline]
    fn on_hit(&self, key: &K, value: &V) {
        (**self).on_hit(key, value);
    }

    #[inline]
    fn on_miss(&self, key: AnyRef<'_>) {
        (**self).on_miss(key);
    }

    // Forwarded rather than left to the no-op default, so that a wrapped handler which still
    // overrides this deprecated hook observes it when it is called through the `Arc`.
    #[inline]
    #[allow(deprecated)]
    fn on_collision(&self, new_key: AnyRef<'_>, existing_key: &K, existing_value: &V) {
        (**self).on_collision(new_key, existing_key, existing_value);
    }

    #[inline]
    fn on_insert(&self, key: &K, value: &V, evicted: Option<(&K, &V)>) {
        (**self).on_insert(key, value, evicted);
    }

    #[inline]
    fn on_remove(&self, key: &K, value: &V) {
        (**self).on_remove(key, value);
    }
}
/// Holds the statistics handler attached to a cache.
///
/// The handler is stored as a boxed trait object, so `Stats<K, V>` has the same type whichever
/// handler is plugged in.
pub struct Stats<K, V> {
    /// The handler that receives every recorded event.
    handler: Box<dyn StatsHandler<K, V>>,
}

impl<K, V> Stats<K, V> {
    /// Wraps `handler`, boxing it behind `dyn StatsHandler`.
    #[inline]
    pub fn new<H: StatsHandler<K, V> + 'static>(handler: H) -> Self {
        let handler: Box<dyn StatsHandler<K, V>> = Box::new(handler);
        Stats { handler }
    }

    /// Borrows the handler as a trait object.
    #[inline]
    pub fn handler(&self) -> &dyn StatsHandler<K, V> {
        self.handler.as_ref()
    }

    /// Reports a hit on `key` (mapped to `value`) to the handler.
    #[inline]
    pub(crate) fn record_hit(&self, key: &K, value: &V) {
        self.handler().on_hit(key, value);
    }

    /// Reports a miss for the type-erased lookup `key` to the handler.
    #[inline]
    pub(crate) fn record_miss(&self, key: AnyRef<'_>) {
        self.handler().on_miss(key);
    }

    /// Reports that `key` was stored with `value`, along with any entry it evicted.
    #[inline]
    pub(crate) fn record_insert(&self, key: &K, value: &V, evicted: Option<(&K, &V)>) {
        self.handler().on_insert(key, value, evicted);
    }

    /// Reports the removal of `key` (mapped to `value`) to the handler.
    #[inline]
    pub(crate) fn record_remove(&self, key: &K, value: &V) {
        self.handler().on_remove(key, value);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Cache;
    use std::hash::{Hash, Hasher};
    use std::sync::Arc;

    type BH = std::hash::BuildHasherDefault<rapidhash::fast::RapidHasher<'static>>;

    /// Shared fixture: the hash depends only on `.0`, so keys that share `.0` but differ in `.1`
    /// are unequal yet hash identically. The collision tests below use this to make two
    /// distinct keys compete for the same cache slot.
    #[derive(Clone, Debug, Eq, PartialEq)]
    struct CollidingKey(u64, u64);

    impl Hash for CollidingKey {
        fn hash<H: Hasher>(&self, state: &mut H) {
            self.0.hash(state);
        }
    }

    #[test]
    fn counting_stats_handler_basic() {
        let handler = CountingStatsHandler::new();
        assert_eq!(handler.hits(), 0);
        assert_eq!(handler.misses(), 0);
        assert_eq!(handler.inserts(), 0);
        assert_eq!(handler.updates(), 0);
        assert_eq!(handler.removes(), 0);
        assert_eq!(handler.collisions(), 0);
        StatsHandler::<u64, u64>::on_hit(&handler, &1, &2);
        assert_eq!(handler.hits(), 1);
        StatsHandler::<u64, u64>::on_miss(&handler, AnyRef::new(&1u64));
        assert_eq!(handler.misses(), 1);
        StatsHandler::<u64, u64>::on_insert(&handler, &1, &2, None);
        assert_eq!(handler.inserts(), 1);
        assert_eq!(handler.collisions(), 0);
        StatsHandler::<u64, u64>::on_insert(&handler, &3, &4, Some((&5, &6)));
        assert_eq!(handler.inserts(), 2);
        assert_eq!(handler.collisions(), 1);
        StatsHandler::<u64, u64>::on_insert(&handler, &1, &2, Some((&1, &3)));
        assert_eq!(handler.updates(), 1);
        StatsHandler::<u64, u64>::on_remove(&handler, &1, &2);
        assert_eq!(handler.removes(), 1);
        handler.reset();
        assert_eq!(handler.hits(), 0);
        assert_eq!(handler.misses(), 0);
        assert_eq!(handler.inserts(), 0);
        assert_eq!(handler.updates(), 0);
        assert_eq!(handler.removes(), 0);
        assert_eq!(handler.collisions(), 0);
    }

    #[test]
    fn cache_with_stats_hits_and_misses() {
        let handler = Arc::new(CountingStatsHandler::new());
        let stats = Stats::new(Arc::clone(&handler));
        let cache: Cache<u64, u64, BH> = Cache::new(64, Default::default()).with_stats(Some(stats));
        assert_eq!(cache.get(&42), None);
        assert_eq!(handler.misses(), 1);
        cache.insert(42, 100);
        assert_eq!(cache.get(&42), Some(100));
        assert_eq!(handler.hits(), 1);
        assert_eq!(handler.misses(), 1);
        assert_eq!(cache.get(&99), None);
        assert_eq!(handler.misses(), 2);
    }

    #[test]
    fn cache_with_stats_no_collision_on_get_miss() {
        let handler = Arc::new(CountingStatsHandler::new());
        let stats = Stats::new(Arc::clone(&handler));
        let cache: Cache<CollidingKey, u64, BH> =
            Cache::new(64, Default::default()).with_stats(Some(stats));
        cache.insert(CollidingKey(1, 1), 100);
        let result = cache.get(&CollidingKey(1, 2));
        assert!(result.is_none());
        assert_eq!(handler.collisions(), 0);
        assert_eq!(handler.misses(), 1);
    }

    #[test]
    fn cache_with_stats_collisions_on_insert() {
        let handler = Arc::new(CountingStatsHandler::new());
        let stats = Stats::new(Arc::clone(&handler));
        let cache: Cache<CollidingKey, u64, BH> =
            Cache::new(64, Default::default()).with_stats(Some(stats));
        cache.insert(CollidingKey(1, 1), 100);
        assert_eq!(handler.collisions(), 0);
        cache.insert(CollidingKey(1, 2), 200);
        assert_eq!(handler.collisions(), 1);
        cache.insert(CollidingKey(1, 2), 300);
        assert_eq!(handler.collisions(), 1);
    }

    #[test]
    fn get_or_insert_with_stats() {
        let handler = Arc::new(CountingStatsHandler::new());
        let stats = Stats::new(Arc::clone(&handler));
        let cache: Cache<u64, u64, BH> = Cache::new(64, Default::default()).with_stats(Some(stats));
        let value = cache.get_or_insert_with(42, |&k| k * 2);
        assert_eq!(value, 84);
        assert_eq!(handler.misses(), 1);
        let value = cache.get_or_insert_with(42, |&k| k * 3);
        assert_eq!(value, 84);
        assert_eq!(handler.hits(), 1);
    }

    #[test]
    fn boxed_stats_handler() {
        let handler = Arc::new(CountingStatsHandler::new());
        let stats = Stats::new(Arc::clone(&handler));
        let cache: Cache<u64, u64, BH> = Cache::new(64, Default::default()).with_stats(Some(stats));
        cache.insert(1, 100);
        assert_eq!(cache.get(&1), Some(100));
        assert_eq!(cache.get(&2), None);
        assert_eq!(handler.hits(), 1);
        assert_eq!(handler.misses(), 1);
    }

    #[test]
    fn cache_without_stats() {
        let cache: Cache<u64, u64, BH> = Cache::new(64, Default::default());
        cache.insert(1, 100);
        assert_eq!(cache.get(&1), Some(100));
        assert!(cache.stats().is_none());
    }

    #[test]
    fn anyref_downcast_in_handler() {
        use std::sync::Mutex;

        struct CapturingHandler {
            missed_keys: Mutex<Vec<String>>,
        }

        impl CapturingHandler {
            fn new() -> Self {
                Self { missed_keys: Mutex::new(Vec::new()) }
            }
        }

        impl StatsHandler<String, u64> for CapturingHandler {
            fn on_miss(&self, key: AnyRef<'_>) {
                if let Some(k) = key.downcast_ref::<&str>() {
                    self.missed_keys.lock().unwrap().push((*k).to_string());
                } else if let Some(k) = key.downcast_ref::<&String>() {
                    self.missed_keys.lock().unwrap().push((*k).clone());
                }
            }
        }

        let handler = Arc::new(CapturingHandler::new());
        let stats = Stats::new(Arc::clone(&handler));
        let cache: Cache<String, u64, BH> =
            Cache::new(64, Default::default()).with_stats(Some(stats));
        assert_eq!(cache.get("hello"), None);
        assert_eq!(cache.get("world"), None);
        assert_eq!(cache.get(&"foo".to_string()), None);
        assert_eq!(*handler.missed_keys.lock().unwrap(), vec!["hello", "world", "foo"]);
    }

    #[test]
    fn insert_update_collision_stats() {
        let handler = Arc::new(CountingStatsHandler::new());
        let stats = Stats::new(Arc::clone(&handler));
        let cache: Cache<CollidingKey, u64, BH> =
            Cache::new(64, Default::default()).with_stats(Some(stats));
        cache.insert(CollidingKey(1, 1), 100);
        assert_eq!(handler.inserts(), 1);
        assert_eq!(handler.updates(), 0);
        assert_eq!(handler.collisions(), 0);
        cache.insert(CollidingKey(1, 1), 200);
        assert_eq!(handler.inserts(), 1);
        assert_eq!(handler.updates(), 1);
        assert_eq!(handler.collisions(), 0);
        cache.insert(CollidingKey(1, 2), 300);
        assert_eq!(handler.inserts(), 2);
        assert_eq!(handler.updates(), 1);
        assert_eq!(handler.collisions(), 1);
        assert_eq!(handler.removes(), 0);
        cache.remove(&CollidingKey(1, 2));
        assert_eq!(handler.removes(), 1);
        cache.remove(&CollidingKey(99, 99));
        assert_eq!(handler.removes(), 1);
    }

    #[test]
    fn concurrent_stats() {
        if cfg!(miri) {
            return;
        }
        let handler = Arc::new(CountingStatsHandler::new());
        let stats = Stats::new(Arc::clone(&handler));
        let cache: Cache<u64, u64, BH> =
            Cache::new(1024, Default::default()).with_stats(Some(stats));
        std::thread::scope(|s| {
            for t in 0..4 {
                let cache = &cache;
                s.spawn(move || {
                    for i in 0..1000u64 {
                        match t {
                            0 => cache.insert(i % 100, i),
                            1 => _ = cache.get(&(i % 100)),
                            2 => _ = cache.get_or_insert_with(i % 100, |&k| k * 2),
                            _ => _ = cache.remove(&(i % 100)),
                        }
                    }
                });
            }
        });
        let total = handler.hits()
            + handler.misses()
            + handler.inserts()
            + handler.updates()
            + handler.removes()
            + handler.collisions();
        assert!(total > 0);
        // Threads 1 and 2 each perform 1000 lookups, and every lookup is either a hit or a miss.
        assert_eq!(handler.hits() + handler.misses(), 1000 + 1000);
    }
}