use std::{collections::HashMap, future::Future, sync::Arc, time::Duration};
use serde::{Serialize, de::DeserializeOwned};
use tokio::sync::Mutex;
use crate::cache::{CacheKey, CacheResult, CacheStats, CacheStore, jitter_ttl};
const NOT_FOUND_PLACEHOLDER: &[u8] = b"__rs_zero_not_found__";
/// Settings that control how long cache-aside entries are kept.
#[derive(Debug, Clone, PartialEq)]
pub struct CacheAsideConfig {
    /// Base time-to-live handed to the TTL jitter helper when a loaded value is stored.
    pub value_ttl: Duration,
    /// Base time-to-live handed to the TTL jitter helper when the not-found
    /// placeholder is stored.
    pub not_found_ttl: Duration,
    /// Ratio passed to the TTL jitter helper together with the base TTL.
    pub ttl_jitter_ratio: f64,
}

impl Default for CacheAsideConfig {
    /// Five-minute value TTL, one-minute not-found TTL, and a `0.05` jitter ratio.
    fn default() -> Self {
        let value_ttl = Duration::from_secs(5 * 60);
        let not_found_ttl = Duration::from_secs(60);
        Self {
            value_ttl,
            not_found_ttl,
            ttl_jitter_ratio: 0.05,
        }
    }
}
/// Cache-aside helper over a backing store `S`.
///
/// Reads go to the store first; on a miss a caller-supplied loader is run
/// while a per-key async mutex is held, and the result (or a not-found
/// placeholder) is written back to the store.
#[derive(Debug, Clone)]
pub struct CacheAside<S> {
/// Backing store that raw reads, writes, and deletes are issued against.
store: S,
/// TTL and jitter settings used when writing entries.
config: CacheAsideConfig,
/// Counters updated by the cache operations (hits, misses, errors, ...).
stats: CacheStats,
/// Registry of per-key mutexes, keyed by the rendered cache key. The outer
/// `Arc` means clones of this struct point at the same registry.
locks: Arc<Mutex<HashMap<String, Arc<Mutex<()>>>>>,
/// Optional metrics registry that events are reported to; only present
/// when the `observability` feature is enabled.
#[cfg(feature = "observability")]
metrics: Option<crate::observability::MetricsRegistry>,
}
impl<S> CacheAside<S> {
/// Builds a cache-aside helper around `store` with fresh statistics, an
/// empty per-key lock registry, and (when the `observability` feature is
/// enabled) no metrics registry attached.
pub fn new(store: S, config: CacheAsideConfig) -> Self {
    let stats = CacheStats::default();
    let locks = Arc::new(Mutex::new(HashMap::new()));
    Self {
        store,
        config,
        stats,
        locks,
        #[cfg(feature = "observability")]
        metrics: None,
    }
}
pub fn stats(&self) -> CacheStats {
self.stats.clone()
}
/// Attaches `metrics` as the registry cache events are reported to,
/// replacing any registry set earlier, and returns the updated helper.
#[cfg(feature = "observability")]
pub fn with_metrics(mut self, metrics: crate::observability::MetricsRegistry) -> Self {
    // Whatever registry was attached before is simply discarded.
    drop(self.metrics.replace(metrics));
    self
}
/// Reports one cache event (`operation` + `result`) under the
/// `"cache_aside"` label when the `observability` feature is enabled.
///
/// Without that feature this does nothing beyond marking its arguments
/// as used.
fn record_event(&self, operation: &str, result: &str) {
    #[cfg(not(feature = "observability"))]
    let _ = (operation, result);

    #[cfg(feature = "observability")]
    {
        let registry = self.metrics.as_ref();
        crate::observability::cache::record_cache_event(
            registry,
            "cache_aside",
            operation,
            result,
        );
    }
}
}
impl<S> CacheAside<S>
where
S: CacheStore,
{
/// Deletes `key` from the backing store.
///
/// A `"delete"` event is recorded either way; a failure additionally bumps
/// the delete-error counter before the store's error is handed back.
///
/// # Errors
///
/// Returns the error produced by the store's `delete` unchanged.
pub async fn delete(&self, key: &CacheKey) -> CacheResult<()> {
    let outcome = self.store.delete(key).await;
    if outcome.is_err() {
        self.stats.record_delete_error();
        self.record_event("delete", "error");
    } else {
        self.record_event("delete", "success");
    }
    outcome
}
/// Returns the JSON value cached under `key`, running `loader` on a miss.
///
/// The flow is:
/// 1. Read the cache; a hit (including a cached "not found") returns at once.
/// 2. On a miss, record it and take the per-key mutex for the rendered key so
///    concurrent callers for the same key wait instead of all running the
///    loader.
/// 3. Re-read the cache under the mutex, since another caller may have filled
///    it while this one was waiting.
/// 4. Otherwise run `loader` and store either the value or the not-found
///    placeholder.
///
/// `Ok(None)` means the loader (or a cached placeholder) reported that the
/// item does not exist.
///
/// # Errors
///
/// Propagates errors from the store, from `loader`, and from writing the
/// result back. The per-key lock entry is released on these paths too, so a
/// failing key does not leave an entry behind in the lock registry.
///
/// Note: if the returned future is dropped before completion, the release
/// step does not run and the registry entry stays until a later call for the
/// same key removes it.
pub async fn get_or_load_json<T, F, Fut>(
    &self,
    key: &CacheKey,
    loader: F,
) -> CacheResult<Option<T>>
where
    T: DeserializeOwned + Serialize + Send + Sync,
    F: FnOnce() -> Fut + Send,
    Fut: Future<Output = CacheResult<Option<T>>> + Send,
{
    if let Some(value) = self.read_cached_json(key).await? {
        return Ok(value);
    }
    self.stats.record_miss();
    self.record_event("get", "miss");

    let rendered = key.render();
    let lock = self.key_lock(&rendered).await;
    // Capture the result instead of using `?` here: every exit from the
    // locked section must fall through to `release_key_lock` below.
    let outcome = {
        let _guard = lock.lock().await;
        self.load_under_key_lock(key, loader).await
    };
    self.release_key_lock(&rendered, &lock).await;
    outcome
}

/// Runs the part of `get_or_load_json` that executes while the per-key mutex
/// is held: re-check the cache, then load and write the result back.
async fn load_under_key_lock<T, F, Fut>(
    &self,
    key: &CacheKey,
    loader: F,
) -> CacheResult<Option<T>>
where
    T: DeserializeOwned + Serialize + Send + Sync,
    F: FnOnce() -> Fut + Send,
    Fut: Future<Output = CacheResult<Option<T>>> + Send,
{
    // Another caller may have populated the entry while we waited.
    if let Some(value) = self.read_cached_json(key).await? {
        return Ok(value);
    }
    let loaded = loader().await.inspect_err(|_| {
        self.stats.record_loader_error();
        self.record_event("load", "error");
    })?;
    match loaded.as_ref() {
        Some(value) => self.write_json(key, value).await?,
        None => self.write_not_found(key).await?,
    }
    Ok(loaded)
}
/// Looks up `key` in the store and decodes what it finds.
///
/// The nested option distinguishes three outcomes:
/// - `Ok(None)`: nothing usable is cached (absent, or undecodable and
///   therefore deleted).
/// - `Ok(Some(None))`: the not-found placeholder is cached.
/// - `Ok(Some(Some(value)))`: a value was cached and decoded.
///
/// A payload that fails to decode is reported as `"corrupt"` and a delete is
/// attempted; a failure of that delete is counted and reported but not
/// returned to the caller.
///
/// # Errors
///
/// Returns the store's error if the raw read fails.
async fn read_cached_json<T>(&self, key: &CacheKey) -> CacheResult<Option<Option<T>>>
where
    T: DeserializeOwned + Send,
{
    let raw = match self.store.get_raw(key).await? {
        Some(raw) => raw,
        None => return Ok(None),
    };

    if raw == NOT_FOUND_PLACEHOLDER {
        self.stats.record_negative_hit();
        self.record_event("get", "negative_hit");
        return Ok(Some(None));
    }

    if let Ok(decoded) = serde_json::from_slice(&raw) {
        self.stats.record_hit();
        self.record_event("get", "hit");
        return Ok(Some(Some(decoded)));
    }

    // Undecodable payload: report it and try to evict it so the caller reloads.
    self.record_event("get", "corrupt");
    let eviction = match self.store.delete(key).await {
        Ok(_) => "corrupt",
        Err(_) => {
            self.stats.record_delete_error();
            "corrupt_error"
        }
    };
    self.record_event("delete", eviction);
    Ok(None)
}
async fn write_json<T>(&self, key: &CacheKey, value: &T) -> CacheResult<()>
where
T: Serialize + Sync,
{
let ttl = jitter_ttl(
self.config.value_ttl,
self.config.ttl_jitter_ratio,
key.render(),
);
let bytes = serde_json::to_vec(value)?;
match self.store.set_raw(key, bytes, Some(ttl)).await {
Ok(()) => {
self.record_event("set", "success");
Ok(())
}
Err(error) => {
self.stats.record_set_error();
self.record_event("set", "error");
Err(error)
}
}
}
/// Stores the not-found placeholder under `key` with a jittered TTL derived
/// from `config.not_found_ttl`.
///
/// # Errors
///
/// Returns the store's write error; the failure is also counted and
/// reported as a `"set"` error.
async fn write_not_found(&self, key: &CacheKey) -> CacheResult<()> {
    let base_ttl = self.config.not_found_ttl;
    let ratio = self.config.ttl_jitter_ratio;
    let ttl = jitter_ttl(base_ttl, ratio, key.render());
    let placeholder = NOT_FOUND_PLACEHOLDER.to_vec();

    self.store
        .set_raw(key, placeholder, Some(ttl))
        .await
        .inspect(|_| self.record_event("set", "negative"))
        .inspect_err(|_| {
            self.stats.record_set_error();
            self.record_event("set", "error");
        })
}
/// Returns the mutex registered for `rendered`, creating and registering a
/// new one when the registry has no entry for that key yet.
async fn key_lock(&self, rendered: &str) -> Arc<Mutex<()>> {
    let mut registry = self.locks.lock().await;
    if let Some(existing) = registry.get(rendered) {
        return Arc::clone(existing);
    }
    let fresh = Arc::new(Mutex::new(()));
    registry.insert(rendered.to_string(), Arc::clone(&fresh));
    fresh
}
/// Removes the registry entry for `rendered` once nobody else is using it.
///
/// The entry is dropped only when it is still the same mutex as `lock` and
/// exactly two strong handles exist: the registry's and the caller's.
///
/// NOTE(review): a caller that has finished but not yet dropped its own
/// handle still counts here, so two callers finishing close together can
/// both skip the removal and leave the entry until the key is used again.
/// Confirm that is acceptable.
async fn release_key_lock(&self, rendered: &str, lock: &Arc<Mutex<()>>) {
    let mut registry = self.locks.lock().await;
    let removable = match registry.get(rendered) {
        Some(current) => Arc::ptr_eq(current, lock) && Arc::strong_count(lock) == 2,
        None => false,
    };
    if removable {
        registry.remove(rendered);
    }
}
}
#[cfg(test)]
mod tests {
use std::{
sync::{
Arc,
atomic::{AtomicUsize, Ordering},
},
time::Duration,
};
use crate::cache::{CacheAside, CacheAsideConfig, CacheKey, CacheStore, MemoryCacheStore};
// Eight tasks miss on the same key at once; the loader must run only once
// and every task must observe the loaded value.
#[tokio::test]
async fn cache_aside_merges_concurrent_misses() {
    let config = CacheAsideConfig {
        value_ttl: Duration::from_secs(60),
        ..CacheAsideConfig::default()
    };
    let aside = CacheAside::new(MemoryCacheStore::new(), config);
    let user_key = CacheKey::new("app", ["user", "42"]);
    let loader_runs = Arc::new(AtomicUsize::new(0));

    let tasks: Vec<_> = (0..8)
        .map(|_| {
            let aside = aside.clone();
            let user_key = user_key.clone();
            let loader_runs = loader_runs.clone();
            tokio::spawn(async move {
                aside
                    .get_or_load_json(&user_key, || async move {
                        loader_runs.fetch_add(1, Ordering::SeqCst);
                        tokio::time::sleep(Duration::from_millis(20)).await;
                        Ok(Some(serde_json::json!({"id":42})))
                    })
                    .await
                    .expect("load")
            })
        })
        .collect();

    for task in tasks {
        let loaded = task.await.expect("join").expect("value");
        assert_eq!(loaded["id"], 42);
    }
    assert_eq!(loader_runs.load(Ordering::SeqCst), 1);
}
// A loader that reports "not found" must be called once; the second lookup
// is answered from the cached placeholder and counted as a negative hit.
#[tokio::test]
async fn cache_aside_uses_negative_cache() {
    let aside = CacheAside::new(MemoryCacheStore::new(), CacheAsideConfig::default());
    let missing_key = CacheKey::new("app", ["missing"]);
    let loader_runs = Arc::new(AtomicUsize::new(0));

    for _attempt in 0..2 {
        let counter = Arc::clone(&loader_runs);
        let found: Option<serde_json::Value> = aside
            .get_or_load_json(&missing_key, || async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok(None)
            })
            .await
            .expect("load");
        assert_eq!(found, None);
    }

    let runs = loader_runs.load(Ordering::SeqCst);
    assert_eq!(runs, 1);
    let snapshot = aside.stats().snapshot();
    assert_eq!(snapshot.negative_hits, 1);
}
// An undecodable cached payload must be discarded, the loader run once, and
// the fresh value written back to the store.
#[tokio::test]
async fn cache_aside_deletes_corrupt_value_and_reloads() {
    let backing = MemoryCacheStore::new();
    let aside = CacheAside::new(backing.clone(), CacheAsideConfig::default());
    let corrupt_key = CacheKey::new("app", ["corrupt"]);
    let loader_runs = Arc::new(AtomicUsize::new(0));

    // Seed the store with bytes that are not valid JSON.
    backing
        .set_raw(&corrupt_key, b"{not-json".to_vec(), None)
        .await
        .expect("set corrupt");

    let counter = Arc::clone(&loader_runs);
    let reloaded: Option<serde_json::Value> = aside
        .get_or_load_json(&corrupt_key, || async move {
            counter.fetch_add(1, Ordering::SeqCst);
            Ok(Some(serde_json::json!({"fresh": true})))
        })
        .await
        .expect("reload");
    let reloaded = reloaded.expect("value");
    assert_eq!(reloaded["fresh"], true);
    assert_eq!(loader_runs.load(Ordering::SeqCst), 1);

    let cached: serde_json::Value = backing
        .get_json(&corrupt_key)
        .await
        .expect("cache")
        .expect("value");
    assert_eq!(cached["fresh"], true);
}
}