//! rs-zero 0.2.8
//!
//! Rust-first microservice framework inspired by go-zero engineering
//! practices. This module implements the cache-aside pattern with
//! singleflight loading and negative caching.
use std::{collections::HashMap, future::Future, sync::Arc, time::Duration};

use serde::{Serialize, de::DeserializeOwned};
use tokio::sync::Mutex;

use crate::cache::{CacheKey, CacheResult, CacheStats, CacheStore, jitter_ttl};

/// Sentinel bytes cached in place of a value when the loader returns `None`,
/// so repeated lookups of a missing key are answered from cache
/// (negative caching) instead of hitting the source of truth.
const NOT_FOUND_PLACEHOLDER: &[u8] = b"__rs_zero_not_found__";

/// Cache-aside strategy configuration.
#[derive(Debug, Clone, PartialEq)]
pub struct CacheAsideConfig {
    /// TTL for positive values loaded from the source of truth.
    pub value_ttl: Duration,
    /// TTL for negative values returned as not found.
    /// Default is shorter than `value_ttl` (60s vs 300s) so newly created
    /// entries become visible reasonably quickly.
    pub not_found_ttl: Duration,
    /// Per-key TTL jitter ratio.
    /// Fed to `jitter_ttl` together with the rendered key so that entries
    /// written at the same moment do not all expire at once — presumably a
    /// fraction of the TTL (default 0.05); confirm against `jitter_ttl`.
    pub ttl_jitter_ratio: f64,
}

impl Default for CacheAsideConfig {
    fn default() -> Self {
        Self {
            value_ttl: Duration::from_secs(300),
            not_found_ttl: Duration::from_secs(60),
            ttl_jitter_ratio: 0.05,
        }
    }
}

/// High-level cache-aside client with singleflight and negative caching.
#[derive(Debug, Clone)]
pub struct CacheAside<S> {
    /// Backing cache store providing raw byte get/set/delete.
    store: S,
    /// TTL and jitter settings applied to every write.
    config: CacheAsideConfig,
    /// Shared hit/miss/error counters (see `stats()`, documented as shared).
    stats: CacheStats,
    /// Per-key mutexes that collapse concurrent misses into one loader call
    /// (singleflight). Entries are removed by `release_key_lock` once no
    /// other task holds a reference to the key's lock.
    locks: Arc<Mutex<HashMap<String, Arc<Mutex<()>>>>>,
    /// Optional metrics sink, present only with the `observability` feature.
    #[cfg(feature = "observability")]
    metrics: Option<crate::observability::MetricsRegistry>,
}

impl<S> CacheAside<S> {
    /// Creates a cache-aside helper around the given store.
    pub fn new(store: S, config: CacheAsideConfig) -> Self {
        Self {
            store,
            config,
            stats: CacheStats::default(),
            locks: Arc::new(Mutex::new(HashMap::new())),
            #[cfg(feature = "observability")]
            metrics: None,
        }
    }

    /// Returns shared cache statistics.
    pub fn stats(&self) -> CacheStats {
        self.stats.clone()
    }

    /// Attaches a metrics registry to this cache-aside helper.
    #[cfg(feature = "observability")]
    pub fn with_metrics(mut self, metrics: crate::observability::MetricsRegistry) -> Self {
        self.metrics = Some(metrics);
        self
    }

    fn record_event(&self, operation: &str, result: &str) {
        #[cfg(feature = "observability")]
        crate::observability::cache::record_cache_event(
            self.metrics.as_ref(),
            "cache_aside",
            operation,
            result,
        );

        #[cfg(not(feature = "observability"))]
        {
            let _ = (operation, result);
        }
    }
}

impl<S> CacheAside<S>
where
    S: CacheStore,
{
    /// Deletes a cached value and records delete failures in shared stats.
    pub async fn delete(&self, key: &CacheKey) -> CacheResult<()> {
        match self.store.delete(key).await {
            Ok(()) => {
                self.record_event("delete", "success");
                Ok(())
            }
            Err(error) => {
                self.stats.record_delete_error();
                self.record_event("delete", "error");
                Err(error)
            }
        }
    }

    /// Reads JSON from cache or loads it once for concurrent misses.
    ///
    /// Returns `Ok(None)` when the loader reported "not found"; that outcome
    /// is cached under `not_found_ttl` so repeated misses stay cheap.
    /// Concurrent misses on the same key queue on a per-key mutex, so the
    /// loader runs at most once per miss window (singleflight).
    ///
    /// # Errors
    /// Propagates store and loader errors. The per-key lock-map entry is
    /// released on every exit path, including errors.
    pub async fn get_or_load_json<T, F, Fut>(
        &self,
        key: &CacheKey,
        loader: F,
    ) -> CacheResult<Option<T>>
    where
        T: DeserializeOwned + Serialize + Send + Sync,
        F: FnOnce() -> Fut + Send,
        Fut: Future<Output = CacheResult<Option<T>>> + Send,
    {
        // Fast path: answer from cache without touching the lock map.
        if let Some(value) = self.read_cached_json(key).await? {
            return Ok(value);
        }

        self.stats.record_miss();
        self.record_event("get", "miss");
        let rendered = key.render();
        let lock = self.key_lock(&rendered).await;

        // Do the guarded work in a helper so the cleanup below runs no matter
        // how it exits. The previous inline version skipped
        // `release_key_lock` whenever the double-check read, the loader, or
        // the cache write returned an error, leaking one lock-map entry per
        // failing key.
        let guard = lock.lock().await;
        let result = self.load_under_lock(key, loader).await;
        drop(guard);
        self.release_key_lock(&rendered, &lock).await;
        result
    }

    /// Double-checks the cache and, if still missing, runs the loader and
    /// writes the positive or negative result back.
    /// The caller must hold this key's singleflight lock.
    async fn load_under_lock<T, F, Fut>(
        &self,
        key: &CacheKey,
        loader: F,
    ) -> CacheResult<Option<T>>
    where
        T: DeserializeOwned + Serialize + Send + Sync,
        F: FnOnce() -> Fut + Send,
        Fut: Future<Output = CacheResult<Option<T>>> + Send,
    {
        // A concurrent task may have populated the cache while we queued on
        // the lock; serve its result instead of invoking the loader again.
        if let Some(value) = self.read_cached_json(key).await? {
            return Ok(value);
        }

        let loaded = loader().await.inspect_err(|_| {
            self.stats.record_loader_error();
            self.record_event("load", "error");
        })?;
        match loaded.as_ref() {
            Some(value) => self.write_json(key, value).await?,
            None => self.write_not_found(key).await?,
        }
        Ok(loaded)
    }

    /// Reads and decodes a cached entry.
    ///
    /// Returns:
    /// - `Ok(None)` — nothing usable cached (including a corrupt entry that
    ///   was just evicted);
    /// - `Ok(Some(None))` — the negative-cache placeholder was hit;
    /// - `Ok(Some(Some(value)))` — a positive value decoded successfully.
    async fn read_cached_json<T>(&self, key: &CacheKey) -> CacheResult<Option<Option<T>>>
    where
        T: DeserializeOwned + Send,
    {
        let Some(bytes) = self.store.get_raw(key).await? else {
            return Ok(None);
        };

        if bytes == NOT_FOUND_PLACEHOLDER {
            self.stats.record_negative_hit();
            self.record_event("get", "negative_hit");
            return Ok(Some(None));
        }

        match serde_json::from_slice(&bytes) {
            Ok(value) => {
                self.stats.record_hit();
                self.record_event("get", "hit");
                Ok(Some(Some(value)))
            }
            Err(_) => {
                // Corrupt payload: evict it (best effort) and report a miss
                // so the caller reloads from the source of truth.
                self.record_event("get", "corrupt");
                if self.store.delete(key).await.is_err() {
                    self.stats.record_delete_error();
                    self.record_event("delete", "corrupt_error");
                } else {
                    self.record_event("delete", "corrupt");
                }
                Ok(None)
            }
        }
    }

    /// Serializes `value` and stores it under a jittered positive TTL.
    async fn write_json<T>(&self, key: &CacheKey, value: &T) -> CacheResult<()>
    where
        T: Serialize + Sync,
    {
        let ttl = jitter_ttl(
            self.config.value_ttl,
            self.config.ttl_jitter_ratio,
            key.render(),
        );
        let bytes = serde_json::to_vec(value)?;
        match self.store.set_raw(key, bytes, Some(ttl)).await {
            Ok(()) => {
                self.record_event("set", "success");
                Ok(())
            }
            Err(error) => {
                self.stats.record_set_error();
                self.record_event("set", "error");
                Err(error)
            }
        }
    }

    /// Stores the negative-cache placeholder under a jittered not-found TTL.
    async fn write_not_found(&self, key: &CacheKey) -> CacheResult<()> {
        let ttl = jitter_ttl(
            self.config.not_found_ttl,
            self.config.ttl_jitter_ratio,
            key.render(),
        );
        match self
            .store
            .set_raw(key, NOT_FOUND_PLACEHOLDER.to_vec(), Some(ttl))
            .await
        {
            Ok(()) => {
                self.record_event("set", "negative");
                Ok(())
            }
            Err(error) => {
                self.stats.record_set_error();
                self.record_event("set", "error");
                Err(error)
            }
        }
    }

    /// Returns (creating on first use) the singleflight mutex for `rendered`.
    async fn key_lock(&self, rendered: &str) -> Arc<Mutex<()>> {
        let mut locks = self.locks.lock().await;
        locks
            .entry(rendered.to_string())
            .or_insert_with(|| Arc::new(Mutex::new(())))
            .clone()
    }

    /// Drops the map entry for `rendered` once we hold the last outside
    /// reference. A strong count of exactly 2 means "the map's copy plus the
    /// caller's local `lock` clone", i.e. no other task is waiting on this
    /// key; the pointer-equality check guards against a racing re-insert.
    async fn release_key_lock(&self, rendered: &str, lock: &Arc<Mutex<()>>) {
        let mut locks = self.locks.lock().await;
        if locks
            .get(rendered)
            .is_some_and(|current| Arc::ptr_eq(current, lock) && Arc::strong_count(lock) == 2)
        {
            locks.remove(rendered);
        }
    }
}

#[cfg(test)]
mod tests {
    use std::{
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
        time::Duration,
    };

    use crate::cache::{CacheAside, CacheAsideConfig, CacheKey, CacheStore, MemoryCacheStore};

    /// Eight tasks miss the same key concurrently; the per-key singleflight
    /// lock must collapse them into a single loader invocation.
    #[tokio::test]
    async fn cache_aside_merges_concurrent_misses() {
        let client = CacheAside::new(
            MemoryCacheStore::new(),
            CacheAsideConfig {
                value_ttl: Duration::from_secs(60),
                ..CacheAsideConfig::default()
            },
        );
        let key = CacheKey::new("app", ["user", "42"]);
        // Counts how many times the loader actually ran.
        let calls = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::new();
        for _ in 0..8 {
            let client = client.clone();
            let key = key.clone();
            let calls = calls.clone();
            handles.push(tokio::spawn(async move {
                client
                    .get_or_load_json(&key, || async move {
                        calls.fetch_add(1, Ordering::SeqCst);
                        // Sleep keeps the first load in flight while the
                        // other tasks arrive and queue on the key lock.
                        tokio::time::sleep(Duration::from_millis(20)).await;
                        Ok(Some(serde_json::json!({"id":42})))
                    })
                    .await
                    .expect("load")
            }));
        }

        // Every task observes the same value; only one loader call happened.
        for handle in handles {
            assert_eq!(handle.await.expect("join").expect("value")["id"], 42);
        }
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// A loader returning `None` should be cached negatively: the second
    /// lookup is served by the placeholder, not another loader call.
    #[tokio::test]
    async fn cache_aside_uses_negative_cache() {
        let client = CacheAside::new(MemoryCacheStore::new(), CacheAsideConfig::default());
        let key = CacheKey::new("app", ["missing"]);
        let calls = Arc::new(AtomicUsize::new(0));

        for _ in 0..2 {
            let calls = calls.clone();
            let value: Option<serde_json::Value> = client
                .get_or_load_json(&key, || async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Ok(None)
                })
                .await
                .expect("load");
            assert!(value.is_none());
        }

        // One loader call on the first miss, one negative hit on the second.
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert_eq!(client.stats().snapshot().negative_hits, 1);
    }

    /// A cached payload that fails JSON decoding must be evicted and reloaded
    /// from the loader, and the fresh value written back to the store.
    #[tokio::test]
    async fn cache_aside_deletes_corrupt_value_and_reloads() {
        let store = MemoryCacheStore::new();
        let client = CacheAside::new(store.clone(), CacheAsideConfig::default());
        let key = CacheKey::new("app", ["corrupt"]);
        let calls = Arc::new(AtomicUsize::new(0));

        // Seed the store with bytes that are not valid JSON.
        store
            .set_raw(&key, b"{not-json".to_vec(), None)
            .await
            .expect("set corrupt");

        let value: Option<serde_json::Value> = client
            .get_or_load_json(&key, || {
                let calls = calls.clone();
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Ok(Some(serde_json::json!({"fresh": true})))
                }
            })
            .await
            .expect("reload");

        assert_eq!(value.expect("value")["fresh"], true);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        // The reloaded value replaced the corrupt bytes in the store.
        let cached: serde_json::Value = store.get_json(&key).await.expect("cache").expect("value");
        assert_eq!(cached["fresh"], true);
    }
}