mod layer;
#[cfg(feature = "cache-moka")]
mod moka_impl;
pub use layer::{CacheResponseLayer, CacheResponseService};
#[cfg(feature = "cache-moka")]
pub use moka_impl::MokaCache;
use std::any::Any;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::{Arc, RwLock};
// Process-wide slot holding the currently installed cache, if any.
// Starts out as `None`; written by `set_global_cache` / `clear_global_cache`
// and read by `global_cache`.
static GLOBAL_CACHE: RwLock<Option<Arc<dyn Cache>>> = RwLock::new(None);
/// Installs `cache` as the process-wide cache, replacing whatever was stored before.
///
/// # Panics
///
/// Panics if the global cache lock has been poisoned.
pub fn set_global_cache(cache: Arc<dyn Cache>) {
    let mut slot = GLOBAL_CACHE.write().expect("global cache lock poisoned");
    *slot = Some(cache);
}
/// Returns a new handle to the process-wide cache, or `None` when no cache is installed.
///
/// The returned `Arc` is a clone of the stored handle, so it stays usable even if
/// the global slot is replaced or cleared afterwards.
///
/// # Panics
///
/// Panics if the global cache lock has been poisoned.
#[must_use]
pub fn global_cache() -> Option<Arc<dyn Cache>> {
    let guard = GLOBAL_CACHE.read().expect("global cache lock poisoned");
    guard.as_ref().map(Arc::clone)
}
/// Removes the process-wide cache, leaving the global slot empty.
///
/// Handles previously obtained from the slot are unaffected; only the slot's own
/// reference is dropped. Calling this when nothing is installed is a no-op.
///
/// # Panics
///
/// Panics if the global cache lock has been poisoned.
pub fn clear_global_cache() {
    let mut slot = GLOBAL_CACHE.write().expect("global cache lock poisoned");
    *slot = None;
}
/// Serialized form of a cached value.
///
/// [`get_cached`] falls back to decoding the wrapped bytes as JSON when the entry
/// stored under a key is a `RawCacheBytes` rather than the requested type.
///
/// `Debug`, `PartialEq` and `Eq` are derived alongside `Clone` so the type can be
/// logged and compared like any other public value type.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RawCacheBytes(pub Vec<u8>);
/// Type-erased key/value cache backend.
///
/// Values are exchanged as `Arc<dyn Any + Send + Sync>`; the typed helpers in this
/// module (`get`, `insert`, `get_cached`, `insert_cached`) downcast them back to
/// concrete types. Implementations must be `Send + Sync + 'static` so a single
/// instance can be shared behind an `Arc`.
pub trait Cache: Send + Sync + 'static {
/// Returns the value stored under `key`, or `None` when there is no entry.
fn get_value(&self, key: &str) -> Option<Arc<dyn Any + Send + Sync>>;
/// Stores `value` under `key`.
fn insert_value(&self, key: &str, value: Arc<dyn Any + Send + Sync>);
/// Expected to drop the entry for `key`; exact semantics are up to the implementation.
fn invalidate(&self, key: &str);
/// Expected to drop every entry; exact semantics are up to the implementation.
fn clear(&self);
/// Optional hook receiving a serialized copy of a value together with an optional TTL.
///
/// `insert_cached` calls this with the JSON encoding of the value it just stored.
/// The default implementation ignores all arguments and does nothing.
fn insert_raw_bytes(&self, _key: &str, _bytes: Vec<u8>, _ttl: Option<std::time::Duration>) {}
}
/// Fetches the entry stored under `key` and returns a clone of it as a `V`.
///
/// Returns `None` when the key is absent or when the stored value is not of type `V`.
pub fn get<V>(cache: &dyn Cache, key: &str) -> Option<V>
where
    V: Clone + Send + Sync + 'static,
{
    let stored = cache.get_value(key)?;
    stored.downcast_ref::<V>().cloned()
}
pub fn insert<V: Clone + Send + Sync + 'static>(cache: &dyn Cache, key: &str, value: V) {
cache.insert_value(key, Arc::new(value));
}
/// Reads the entry under `key` as a `V`, accepting either a typed or a serialized value.
///
/// * If the stored value already is a `V`, a clone of it is returned.
/// * Otherwise, if it is a [`RawCacheBytes`], its bytes are decoded as JSON into a `V`.
///
/// Returns `None` when the key is absent, when the entry is of neither form, or
/// when JSON decoding fails.
pub fn get_cached<V>(cache: &dyn Cache, key: &str) -> Option<V>
where
    V: Clone + serde::de::DeserializeOwned + Send + Sync + 'static,
{
    let entry = cache.get_value(key)?;
    match entry.downcast_ref::<V>() {
        // Fast path: the value is held in its concrete form.
        Some(typed) => Some(typed.clone()),
        // Slow path: fall back to a JSON payload; a decode error counts as a miss.
        None => {
            let raw = entry.downcast_ref::<RawCacheBytes>()?;
            serde_json::from_slice::<V>(&raw.0).ok()
        }
    }
}
/// Stores `value` under `key` in typed form and, when it serializes, also hands its
/// JSON encoding to [`Cache::insert_raw_bytes`] together with `ttl`.
///
/// The typed insert always happens. The raw-bytes insert is best-effort: if the
/// value cannot be encoded as JSON it is silently skipped.
///
/// The `Clone` bound is kept for interface compatibility; the value itself is no
/// longer cloned.
pub fn insert_cached<V>(cache: &dyn Cache, key: &str, value: V, ttl: Option<std::time::Duration>)
where
    V: Clone + serde::Serialize + Send + Sync + 'static,
{
    // Encode from a borrow first so `value` can then be moved into the `Arc`
    // instead of being deep-cloned just to keep a copy around for serialization.
    let encoded = serde_json::to_vec(&value);
    cache.insert_value(key, Arc::new(value));
    // Same call order as before: typed entry first, serialized copy second.
    if let Ok(bytes) = encoded {
        cache.insert_raw_bytes(key, bytes, ttl);
    }
}
/// Conversion between a result-like type and a plain `Result`.
///
/// Lets generic code split a value into its success/error parts and rebuild it
/// from a success value alone (presumably so a cached `Ok` payload can be turned
/// back into the function's original return type — confirm against callers).
pub trait CacheableResult {
/// Success payload; must be `Clone`.
type Ok: Clone;
/// Error payload.
type Err;
/// Consumes `self` and exposes it as a standard `Result`.
fn into_result(self) -> Result<Self::Ok, Self::Err>;
/// Builds a successful `Self` from a success payload.
fn from_ok(ok: Self::Ok) -> Self;
}
/// `Result<T, E>` is trivially cacheable: it already is its own `Result` form, and a
/// success payload maps straight onto the `Ok` variant.
impl<T, E> CacheableResult for Result<T, E>
where
    T: Clone,
{
    type Ok = T;
    type Err = E;

    /// Identity conversion.
    fn into_result(self) -> Result<T, E> {
        self
    }

    /// Wraps `ok` in the `Ok` variant.
    fn from_ok(ok: T) -> Self {
        Result::Ok(ok)
    }
}
/// Builds a cache key of the form `<fn_name>:<hex digest of args>`.
///
/// The digest is the 64-bit `DefaultHasher` hash of `args`, rendered as lowercase
/// hexadecimal without padding. Identical `fn_name`/`args` pairs yield identical
/// keys; arguments whose hashes collide share a key.
///
/// NOTE(review): the standard library does not guarantee `DefaultHasher`'s algorithm
/// across Rust releases, so keys should not be assumed stable between differently
/// built binaries.
#[must_use]
pub fn make_cache_key<K: Hash>(fn_name: &str, args: &K) -> String {
    let digest = {
        let mut state = DefaultHasher::new();
        args.hash(&mut state);
        state.finish()
    };
    format!("{fn_name}:{digest:x}")
}
// Unit tests for the key builder, the typed/serialized cache helpers and the
// `CacheableResult` impl for `Result`. Tests that need a concrete cache backend
// are compiled only with the `cache-moka` feature.
#[cfg(test)]
mod tests {
use super::*;
// Same function name and arguments must always produce the same key.
#[test]
fn cache_key_deterministic() {
let k1 = make_cache_key("get_user", &(42_i64,));
let k2 = make_cache_key("get_user", &(42_i64,));
assert_eq!(k1, k2);
}
// The function name is part of the key, so equal arguments do not collide
// across different functions.
#[test]
fn cache_key_differs_by_fn_name() {
let k1 = make_cache_key("get_user", &(42_i64,));
let k2 = make_cache_key("find_user", &(42_i64,));
assert_ne!(k1, k2);
}
// Different arguments for the same function yield different keys.
#[test]
fn cache_key_differs_by_args() {
let k1 = make_cache_key("get_user", &(1_i64,));
let k2 = make_cache_key("get_user", &(2_i64,));
assert_ne!(k1, k2);
}
// The unit tuple is accepted as "no arguments"; the key still carries the
// `<fn_name>:` prefix.
#[test]
fn cache_key_no_args() {
let k = make_cache_key("get_config", &());
assert!(k.starts_with("get_config:"));
}
// A value written with `insert_cached` can be read back with `get_cached`.
#[cfg(feature = "cache-moka")]
#[test]
fn insert_cached_and_get_cached_round_trip() {
let cache = MokaCache::new(10, None);
insert_cached(&cache, "key", "hello".to_string(), None);
let val: Option<String> = get_cached(&cache, "key");
assert_eq!(val.as_deref(), Some("hello"));
}
// When the stored entry is a `RawCacheBytes` holding JSON, `get_cached`
// decodes it into the requested type.
#[cfg(feature = "cache-moka")]
#[test]
fn get_cached_raw_bytes_slow_path() {
let cache = MokaCache::new(10, None);
let bytes = serde_json::to_vec(&42_i32).unwrap();
cache.insert_value("k", Arc::new(RawCacheBytes(bytes)));
let val: Option<i32> = get_cached(&cache, "k");
assert_eq!(val, Some(42));
}
// Looking up a key that was never inserted yields `None`.
#[cfg(feature = "cache-moka")]
#[test]
fn get_cached_miss_returns_none() {
let cache = MokaCache::new(10, None);
let val: Option<String> = get_cached(&cache, "missing");
assert!(val.is_none());
}
// `from_ok` builds an `Ok`, and `into_result` hands it back unchanged.
#[test]
fn cacheable_result_ok_round_trips() {
let r: Result<i32, &str> = Result::from_ok(42);
assert_eq!(r, Ok(42));
assert_eq!(r.into_result(), Ok(42));
}
// `into_result` leaves an `Err` untouched.
#[test]
fn cacheable_result_err_passes_through() {
let r: Result<i32, &str> = Err("oops");
assert_eq!(r.into_result(), Err("oops"));
}
}