use crate::hashtable::{ConcurrentLruCache, LruShard};
/// Filter over the string carried by a `NoCacheReason::Custom`.
///
/// When a `Predictor` is built with one of these and it returns `true` for a custom
/// reason, `mark_uncacheable` skips recording the key for that reason.
pub type CustomReasonPredicate = fn(&'static str) -> bool;
/// A [`CacheablePredictor`] that remembers keys which were marked uncacheable.
///
/// Marked keys are stored in a `ConcurrentLruCache` split into `N_SHARDS` shards. The
/// list is bounded, so once it is full the least recently marked key is evicted. After
/// eviction, that key is predicted cacheable again.
pub struct Predictor<const N_SHARDS: usize> {
// Set of keys (as `u128` hashes) currently predicted uncacheable; values are unused.
uncacheable_keys: ConcurrentLruCache<(), N_SHARDS>,
// Optional user filter; custom reasons it accepts are not recorded as uncacheable.
skip_custom_reasons_fn: Option<CustomReasonPredicate>,
}
use crate::{key::CacheHashKey, CacheKey, NoCacheReason};
use log::debug;
/// Predicts whether a cache key is cacheable, based on earlier reports that it was
/// (or was not) cacheable.
pub trait CacheablePredictor {
/// Returns `true` if `key` is predicted to be cacheable, `false` if it is currently
/// marked uncacheable.
fn cacheable_prediction(&self, key: &CacheKey) -> bool;
/// Clears any uncacheable mark on `key`.
///
/// Returns `false` if `key` was marked uncacheable before this call, `true` otherwise.
fn mark_cacheable(&self, key: &CacheKey) -> bool;
/// Marks `key` as uncacheable for future predictions. Implementations may decline to
/// mark for some `reason`s.
///
/// Returns `None` if marking was skipped because of `reason`. Returns `Some(true)` if
/// `key` was newly marked, and `Some(false)` if it was already marked.
fn mark_uncacheable(&self, key: &CacheKey, reason: NoCacheReason) -> Option<bool>;
}
impl<const N_SHARDS: usize> Predictor<N_SHARDS>
where
[LruShard<()>; N_SHARDS]: Default,
{
pub fn new(
shard_capacity: usize,
skip_custom_reasons_fn: Option<CustomReasonPredicate>,
) -> Predictor<N_SHARDS> {
Predictor {
uncacheable_keys: ConcurrentLruCache::<(), N_SHARDS>::new(shard_capacity),
skip_custom_reasons_fn,
}
}
}
/// Maps a [`CacheKey`] to the `u128` used to index the uncacheable-key LRU.
fn uncacheable_lru_key(key: &CacheKey) -> u128 {
    // The primary hash is only used as an opaque identifier, so byte order is irrelevant.
    u128::from_be_bytes(key.primary_bin())
}

impl<const N_SHARDS: usize> CacheablePredictor for Predictor<N_SHARDS>
where
    [LruShard<()>; N_SHARDS]: Default,
{
    /// Returns `true` unless `key` is currently held in the uncacheable-key LRU.
    fn cacheable_prediction(&self, key: &CacheKey) -> bool {
        let lru_key = uncacheable_lru_key(key);
        !self.uncacheable_keys.read(lru_key).contains(&lru_key)
    }

    /// Removes `key` from the uncacheable-key LRU.
    ///
    /// Returns `false` if this call removed an uncacheable mark. Returns `true` if the
    /// key was not marked, including when a concurrent caller removed it first.
    fn mark_cacheable(&self, key: &CacheKey) -> bool {
        let lru_key = uncacheable_lru_key(key);
        let shard = self.uncacheable_keys.get(lru_key);
        // Fast path: a shared lock is enough to see that the key is not marked. The read
        // guard is a temporary of the `if` condition, so it is released before the write
        // lock below is requested.
        if !shard.read().contains(&lru_key) {
            return true;
        }
        // The key may have been removed between releasing the read lock and taking the
        // write lock. Decide from what `pop` actually removed, so that only one caller
        // reports (and logs) the uncacheable -> cacheable transition.
        if shard.write().pop(&lru_key).is_none() {
            return true;
        }
        debug!("bypassed request became cacheable");
        false
    }

    /// Records `key` as uncacheable, unless `reason` is one that is never recorded or is
    /// a custom reason accepted by `skip_custom_reasons_fn`.
    ///
    /// Returns `None` when skipped. Otherwise returns `Some(true)` if the key was newly
    /// inserted, or `Some(false)` if it was already present.
    fn mark_uncacheable(&self, key: &CacheKey, reason: NoCacheReason) -> Option<bool> {
        use NoCacheReason::*;
        // The match is exhaustive (no `_` arm) on purpose: adding a variant forces an
        // explicit decision about whether it should be recorded.
        match reason {
            NeverEnabled | StorageError | InternalError | Deferred | CacheLockGiveUp
            | CacheLockTimeout => {
                return None;
            }
            // User-configured opt-out for selected custom reasons.
            Custom(reason) if self.skip_custom_reasons_fn.map_or(false, |f| f(reason)) => {
                return None;
            }
            // Only these reasons mark the key uncacheable.
            Custom(_) | OriginNotCache | ResponseTooLarge => {}
        }
        let lru_key = uncacheable_lru_key(key);
        // `put` returns the previous value, so `None` means the key was not marked yet.
        let newly_marked = self
            .uncacheable_keys
            .get(lru_key)
            .write()
            .put(lru_key, ())
            .is_none();
        if newly_marked {
            debug!("request marked uncacheable");
        }
        Some(newly_marked)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mark_cacheability() {
        let predictor = Predictor::<1>::new(10, None);
        let key = CacheKey::new("a", "b", "c");
        // Unseen keys are predicted cacheable.
        assert!(predictor.cacheable_prediction(&key));

        // Skipped reasons report `None` and leave the prediction unchanged.
        assert_eq!(
            predictor.mark_uncacheable(&key, NoCacheReason::InternalError),
            None
        );
        assert!(predictor.cacheable_prediction(&key));
        assert_eq!(
            predictor.mark_uncacheable(&key, NoCacheReason::StorageError),
            None
        );
        assert!(predictor.cacheable_prediction(&key));

        // A recorded reason marks the key; marking again reports it was already present.
        assert_eq!(
            predictor.mark_uncacheable(&key, NoCacheReason::OriginNotCache),
            Some(true)
        );
        assert!(!predictor.cacheable_prediction(&key));
        assert_eq!(
            predictor.mark_uncacheable(&key, NoCacheReason::OriginNotCache),
            Some(false)
        );
        assert!(!predictor.cacheable_prediction(&key));

        // Clearing reports that the key was previously uncacheable; a second clear is a no-op.
        assert!(!predictor.mark_cacheable(&key));
        assert!(predictor.cacheable_prediction(&key));
        assert!(predictor.mark_cacheable(&key));
        assert!(predictor.cacheable_prediction(&key));
    }

    #[test]
    fn test_custom_skip_predicate() {
        let predictor = Predictor::<1>::new(
            10,
            Some(|custom_reason| matches!(custom_reason, "Skipping")),
        );
        let key = CacheKey::new("a", "b", "c");
        assert!(predictor.cacheable_prediction(&key));
        assert_eq!(
            predictor.mark_uncacheable(&key, NoCacheReason::InternalError),
            None
        );
        assert!(predictor.cacheable_prediction(&key));
        // Custom reasons the predicate rejects are still recorded.
        assert_eq!(
            predictor.mark_uncacheable(&key, NoCacheReason::Custom("DontCacheMe")),
            Some(true)
        );
        assert!(!predictor.cacheable_prediction(&key));

        let key = CacheKey::new("a", "c", "d");
        assert!(predictor.cacheable_prediction(&key));
        // Custom reasons the predicate accepts are skipped.
        assert_eq!(
            predictor.mark_uncacheable(&key, NoCacheReason::Custom("Skipping")),
            None
        );
        assert!(predictor.cacheable_prediction(&key));
    }

    #[test]
    fn test_mark_uncacheable_lru() {
        // One shard of capacity 3, so every key competes for the same three slots.
        let predictor = Predictor::<1>::new(3, None);
        let key1 = CacheKey::new("a", "b", "c");
        assert_eq!(
            predictor.mark_uncacheable(&key1, NoCacheReason::OriginNotCache),
            Some(true)
        );
        assert!(!predictor.cacheable_prediction(&key1));
        let key2 = CacheKey::new("a", "bc", "c");
        assert_eq!(
            predictor.mark_uncacheable(&key2, NoCacheReason::OriginNotCache),
            Some(true)
        );
        assert!(!predictor.cacheable_prediction(&key2));
        let key3 = CacheKey::new("a", "cd", "c");
        assert_eq!(
            predictor.mark_uncacheable(&key3, NoCacheReason::OriginNotCache),
            Some(true)
        );
        assert!(!predictor.cacheable_prediction(&key3));

        // Re-marking key1 refreshes it, leaving key2 as the least recently marked entry.
        assert_eq!(
            predictor.mark_uncacheable(&key1, NoCacheReason::OriginNotCache),
            Some(false)
        );
        let key4 = CacheKey::new("a", "de", "c");
        assert_eq!(
            predictor.mark_uncacheable(&key4, NoCacheReason::OriginNotCache),
            Some(true)
        );
        assert!(!predictor.cacheable_prediction(&key4));

        // key2 was evicted to make room for key4, so it is predicted cacheable again.
        assert!(!predictor.cacheable_prediction(&key1));
        assert!(predictor.cacheable_prediction(&key2));
        assert!(!predictor.cacheable_prediction(&key3));
        assert!(!predictor.cacheable_prediction(&key4));
    }
}