use hash_map::RandomState;
use linked_hash_map::LinkedHashMap;
#[cfg(feature = "metrics")]
use prometheus::{
core::{AtomicU64, GenericCounter, GenericGauge},
Opts, Registry,
};
use std::{
collections::hash_map,
fmt,
hash::{BuildHasher, Hash},
num::NonZeroUsize,
};
/// Implemented by values that can report how much cache capacity they occupy.
///
/// [`WeightCache::put`] calls [`Weighable::measure`] once per inserted value
/// and stores the result alongside the value.
pub trait Weighable {
    /// Returns the weight of `value`, in the same unit the cache capacity is
    /// expressed in.
    fn measure(value: &Self) -> usize;
}
/// Error returned by `WeightCache::put` when a single value weighs more than
/// the cache's configured maximum and therefore can never be stored.
#[derive(Debug)]
pub struct ValueTooBigError;

impl fmt::Display for ValueTooBigError {
    /// Writes a fixed, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Value is bigger than the configured max size of the cache")
    }
}

impl std::error::Error for ValueTooBigError {}
/// A cached value paired with the weight that was measured for it when it was
/// inserted, so the weight never has to be recomputed on eviction.
struct ValueWithWeight<V> {
    /// The value handed to the cache by the caller.
    value: V,
    /// Result of `Weighable::measure` for `value` at insertion time.
    weight: usize,
}
/// A cache bounded by the summed weight of its values rather than by the
/// number of entries.
///
/// Entries live in a [`LinkedHashMap`]; when the total weight exceeds `max`,
/// entries are popped from the front of that map until it fits again.
pub struct WeightCache<K, V, S = hash_map::RandomState> {
    /// Maximum total weight the cache may hold.
    max: usize,
    /// Sum of the weights of all entries currently stored in `inner`.
    current: usize,
    /// Backing map; its iteration order decides which entry is evicted first.
    inner: LinkedHashMap<K, ValueWithWeight<V>, S>,
    /// Prometheus collectors describing cache activity.
    #[cfg(feature = "metrics")]
    metrics: Metrics,
}
impl<K, V, S> fmt::Debug for WeightCache<K, V, S> {
    /// Prints only the weight accounting (`max` and `current`); the entries
    /// themselves are omitted, so `K`, `V` and `S` need not implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { max, current, .. } = self;
        let mut out = f.debug_struct("WeightCache");
        out.field("max", max);
        out.field("current", current);
        out.finish()
    }
}
impl<K: Hash + Eq, V: Weighable> Default for WeightCache<K, V> {
fn default() -> Self {
WeightCache::<K, V, RandomState>::new(
NonZeroUsize::new(usize::max_value()).expect("MAX > 0"),
)
}
}
/// Prometheus collectors maintained by a `WeightCache` when the `metrics`
/// feature is enabled.
#[cfg(feature = "metrics")]
struct Metrics {
    /// Incremented by `get` when the key is present.
    hits: GenericCounter<AtomicU64>,
    /// Incremented by `get` when the key is absent.
    misses: GenericCounter<AtomicU64>,
    /// Incremented by `put` when the value was accepted.
    inserts: GenericCounter<AtomicU64>,
    /// Incremented by `put` when the value was rejected as too heavy.
    inserts_fail: GenericCounter<AtomicU64>,
    /// Total weight currently held by the cache.
    size: GenericGauge<AtomicU64>,
}
#[cfg(feature = "metrics")]
impl Metrics {
    /// Creates the (not yet registered) collectors for one cache.
    ///
    /// When `namespace` is `Some`, it is set as the namespace of every
    /// metric's options; when it is `None`, the options carry only name and
    /// help text. The size gauge is explicitly set to 0.
    ///
    /// # Panics
    ///
    /// Panics if prometheus rejects any of the metric options, because every
    /// constructor result is unwrapped.
    fn new(namespace: Option<&str>) -> Self {
        // Builds the options for one metric, adding the namespace only when
        // one was requested.
        let opts = |name: &str, help: &str| -> Opts {
            let base = Opts::new(name, help);
            match namespace {
                Some(ns) => base.namespace(ns),
                None => base,
            }
        };
        let counter = |name: &str, help: &str| -> GenericCounter<AtomicU64> {
            GenericCounter::with_opts(opts(name, help)).unwrap()
        };

        // The gauge is created first, then the counters in field order.
        let size: GenericGauge<AtomicU64> =
            GenericGauge::with_opts(opts("cache_size", "Current size of the cache")).unwrap();
        size.set(0);

        Self {
            hits: counter("cache_hit", "Number of cache hits"),
            misses: counter("cache_miss", "Number of cache misses"),
            inserts: counter("cache_insert", "Number of successful cache insertions"),
            inserts_fail: counter("cache_insert_fail", "Number of failed cache insertions"),
            size,
        }
    }
}
impl<K: Hash + Eq, V: Weighable> WeightCache<K, V> {
    /// Creates an empty cache that holds at most `capacity` units of weight
    /// and uses the default hasher.
    ///
    /// With the `metrics` feature enabled, the collectors are created without
    /// a namespace.
    pub fn new(capacity: NonZeroUsize) -> Self {
        let max = capacity.get();
        let inner = LinkedHashMap::new();
        #[cfg(feature = "metrics")]
        let metrics = Metrics::new(None);
        Self {
            max,
            current: 0,
            inner,
            #[cfg(feature = "metrics")]
            metrics,
        }
    }

    /// Like [`WeightCache::new`], but creates the collectors with
    /// `metrics_namespace` as their namespace (or none when it is `None`).
    #[cfg(feature = "metrics")]
    pub fn new_with_namespace(capacity: NonZeroUsize, metrics_namespace: Option<&str>) -> Self {
        let max = capacity.get();
        let inner = LinkedHashMap::new();
        let metrics = Metrics::new(metrics_namespace);
        Self {
            max,
            current: 0,
            inner,
            metrics,
        }
    }
}
impl<K: Hash + Eq, V: Weighable, S: BuildHasher> WeightCache<K, V, S> {
    /// Creates an empty cache that holds at most `capacity` units of weight
    /// and hashes keys with `hasher`.
    ///
    /// With the `metrics` feature enabled, the collectors are created without
    /// a namespace.
    pub fn with_hasher(capacity: NonZeroUsize, hasher: S) -> Self {
        Self {
            max: capacity.get(),
            current: 0,
            inner: LinkedHashMap::with_hasher(hasher),
            #[cfg(feature = "metrics")]
            metrics: Metrics::new(None),
        }
    }

    /// Registers all of this cache's collectors with `registry`.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by `registry`. Collectors registered
    /// before the failing one are not unregistered.
    #[cfg(feature = "metrics")]
    pub fn register(&self, registry: &Registry) -> Result<(), prometheus::Error> {
        registry.register(Box::new(self.metrics.hits.clone()))?;
        registry.register(Box::new(self.metrics.misses.clone()))?;
        registry.register(Box::new(self.metrics.inserts.clone()))?;
        registry.register(Box::new(self.metrics.inserts_fail.clone()))?;
        registry.register(Box::new(self.metrics.size.clone()))?;
        Ok(())
    }

    /// Looks up `k` and returns a reference to its value, if present.
    ///
    /// The lookup goes through `LinkedHashMap::get_refresh`, which per the
    /// linked-hash-map documentation moves a found entry to the end of the
    /// map, i.e. furthest away from eviction. With the `metrics` feature
    /// enabled, the hit or miss counter is incremented.
    pub fn get(&mut self, k: &K) -> Option<&V> {
        if let Some(entry) = self.inner.get_refresh(k) {
            #[cfg(feature = "metrics")]
            self.metrics.hits.inc();
            Some(&entry.value)
        } else {
            #[cfg(feature = "metrics")]
            self.metrics.misses.inc();
            None
        }
    }

    /// Returns the number of entries (not their total weight) in the cache.
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Returns `true` if the cache holds no entries.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Inserts `value` under `key`, replacing any previous value for that key
    /// and evicting entries from the front of the map until the total weight
    /// fits within the configured maximum.
    ///
    /// # Errors
    ///
    /// Returns [`ValueTooBigError`] (and stores nothing) if the weight of
    /// `value` alone exceeds the maximum.
    pub fn put(&mut self, key: K, value: V) -> Result<(), ValueTooBigError> {
        let weight = V::measure(&value);
        if weight > self.max {
            #[cfg(feature = "metrics")]
            self.metrics.inserts_fail.inc();
            return Err(ValueTooBigError);
        }

        // Release the weight of a replaced entry *before* accounting for the
        // new one. From here until the addition below, `self.current` covers
        // every entry except the one just inserted.
        if let Some(replaced) = self.inner.insert(key, ValueWithWeight { value, weight }) {
            self.current -= replaced.weight;
        }

        // Evict against the headroom left for the new entry instead of adding
        // first and shrinking afterwards: `self.current + weight` can exceed
        // `usize::MAX` when `max` is large (the `Default` capacity is
        // `usize::MAX`). `weight <= self.max` was checked above, so the
        // subtraction cannot underflow.
        self.evict_down_to(self.max - weight);
        self.current += weight;

        // Nothing is left to evict at this point; the call publishes the new
        // total to the size gauge.
        self.shrink_to_fit();
        #[cfg(feature = "metrics")]
        self.metrics.inserts.inc();
        Ok(())
    }

    /// Pops entries from the front of the map until the tracked weight is at
    /// most `limit` or the map is empty.
    fn evict_down_to(&mut self, limit: usize) {
        while self.current > limit {
            match self.inner.pop_front() {
                Some((_, evicted)) => self.current -= evicted.weight,
                None => break,
            }
        }
    }

    /// Evicts entries until the total weight is within `max`, then publishes
    /// the resulting weight to the size gauge (with the `metrics` feature).
    fn shrink_to_fit(&mut self) {
        self.evict_down_to(self.max);
        #[cfg(feature = "metrics")]
        self.metrics.size.set(self.current as u64);
    }
}
impl<K: Hash + Eq + 'static, V: Weighable + 'static, S: BuildHasher> WeightCache<K, V, S> {
    /// Consumes the cache and returns an iterator over its `(key, value)`
    /// pairs in the backing map's iteration order, with the stored weights
    /// dropped.
    ///
    /// With the `metrics` feature enabled, the size gauge is reset to 0.
    pub fn consume(self) -> Box<dyn Iterator<Item = (K, V)> + 'static> {
        #[cfg(feature = "metrics")]
        self.metrics.size.set(0);
        let entries = self
            .inner
            .into_iter()
            .map(|(key, entry)| (key, entry.value));
        Box::new(entries)
    }
}
#[cfg(test)]
mod test {
    use std::convert::TryInto;

    use super::*;
    use quickcheck::{Arbitrary, Gen};
    use quickcheck_macros::quickcheck;

    /// Test value whose weight is the wrapped number.
    #[derive(Clone, Debug, PartialEq)]
    struct HeavyWeight(usize);

    impl Weighable for HeavyWeight {
        fn measure(v: &Self) -> usize {
            v.0
        }
    }

    impl Arbitrary for HeavyWeight {
        fn arbitrary(g: &mut Gen) -> Self {
            Self(usize::arbitrary(g))
        }

        fn shrink(&self) -> Box<dyn Iterator<Item = Self>> {
            Box::new(usize::shrink(&self.0).map(HeavyWeight))
        }
    }

    /// Test value that always weighs exactly 1.
    #[derive(Clone, Debug, PartialEq)]
    struct UnitWeight;

    impl Weighable for UnitWeight {
        fn measure(_: &Self) -> usize {
            1
        }
    }

    impl Arbitrary for UnitWeight {
        fn arbitrary(_: &mut Gen) -> Self {
            Self
        }
    }

    #[test]
    fn should_not_evict_under_max_size() {
        let xs: Vec<_> = (0..10000).map(HeavyWeight).collect();
        let mut cache = WeightCache::<usize, HeavyWeight>::new(usize::MAX.try_into().unwrap());
        for (k, v) in xs.iter().enumerate() {
            cache.put(k, v.clone()).expect("empty")
        }
        let cached = cache.consume().map(|x| x.1).collect::<Vec<_>>();
        assert_eq!(xs, cached);
    }

    /// Fills a capacity-3 cache with five unit-weight entries, reads all five
    /// keys back and checks every gathered metric.
    #[cfg(feature = "metrics")]
    fn metrics_test(namespace: Option<&str>) {
        let mut cache =
            WeightCache::<usize, UnitWeight>::new_with_namespace(3.try_into().unwrap(), namespace);
        let registry = Registry::new();
        cache.register(&registry).unwrap();
        for i in 0usize..5 {
            cache.put(i, UnitWeight).unwrap();
        }
        for i in 0usize..5 {
            cache.get(&i);
        }

        // The tests expect gathered names to carry the namespace as a
        // `<namespace>_` prefix.
        let prefix = namespace.map(|y| format!("{}_", y)).unwrap_or_default();
        let full_name = |suffix: &str| format!("{}{}", prefix, suffix);

        for metric in registry.gather() {
            println!("{} {:?}", metric.get_name(), metric.get_metric()[0]);
            let sample = &metric.get_metric()[0];
            match metric.get_name() {
                x if x == full_name("cache_size") => {
                    assert_eq!(3, sample.get_gauge().get_value() as usize)
                }
                x if x == full_name("cache_insert") => {
                    assert_eq!(5, sample.get_counter().get_value() as usize)
                }
                x if x == full_name("cache_insert_fail") => {
                    assert_eq!(0, sample.get_counter().get_value() as usize)
                }
                x if x == full_name("cache_hit") => {
                    assert_eq!(3, sample.get_counter().get_value() as usize)
                }
                x if x == full_name("cache_miss") => {
                    assert_eq!(2, sample.get_counter().get_value() as usize)
                }
                x => panic!("unknown metrics {}", x),
            }
        }
    }

    #[cfg(feature = "metrics")]
    #[test]
    fn should_gather_metrics() {
        metrics_test(None);
        metrics_test(Some("test"));
    }

    /// A value is rejected only when it is strictly heavier than the
    /// capacity; a value that weighs exactly the capacity must be accepted.
    #[quickcheck]
    fn should_reject_too_heavy_values(total_size: NonZeroUsize, input: HeavyWeight) -> bool {
        let weight = input.0;
        let limit = total_size.get();
        let mut cache = WeightCache::<usize, HeavyWeight>::new(total_size);
        match cache.put(42, input) {
            Ok(()) => weight <= limit,
            Err(ValueTooBigError) => weight > limit,
        }
    }

    /// Once the cache is full, each unit-weight insertion must evict exactly
    /// one entry, so the length stops growing.
    #[quickcheck]
    fn should_evict_once_the_size_target_is_hit(
        input: Vec<UnitWeight>,
        max_size: NonZeroUsize,
    ) -> bool {
        let mut cache_size = 0usize;
        let mut cache = WeightCache::<usize, UnitWeight>::new(max_size);
        for (k, v) in input.into_iter().enumerate() {
            let weight = UnitWeight::measure(&v);
            cache_size += weight;
            let len_before = cache.len();
            cache.put(k, v).unwrap();
            let len_after = cache.len();
            if cache_size > max_size.get() {
                assert_eq!(len_before, len_after);
                cache_size -= weight;
            } else {
                assert_eq!(len_before + 1, len_after);
            }
        }
        true
    }
}