rs-zero 0.2.8

A Rust-first microservice framework inspired by go-zero's engineering practices.
The integration test below exercises the cache feature end to end: cache-aside
single-flight under hot-key pressure, recovery from injected loader and store
faults, LRU capacity and TTL enforcement, and two-level (L1/L2) backfill
invariants.
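
For orientation before the test listing, here is a minimal sketch of driving the
same cache-aside API from application code. It uses only the types and methods
the tests below exercise; the key namespace, JSON shape, and jitter value are
illustrative, not prescribed by rs-zero.

use std::time::Duration;

use rs_zero::cache::{CacheAside, CacheAsideConfig, CacheKey, MemoryCacheStore};

#[tokio::main]
async fn main() -> Result<(), rs_zero::cache::CacheError> {
    // Wrap any cache store in the cache-aside facade.
    let cache = CacheAside::new(
        MemoryCacheStore::new(),
        CacheAsideConfig {
            value_ttl: Duration::from_secs(30),
            not_found_ttl: Duration::from_secs(5),
            ttl_jitter_ratio: 0.1, // illustrative; the tests below pin this to 0.0
        },
    );

    // Namespaced key: a prefix plus segments, as in the tests.
    let key = CacheKey::new("user", ["42"]);

    // On a miss the loader runs (concurrent callers on the same key are merged,
    // as the hot-key test asserts); Ok(None) would be cached as a negative
    // entry for not_found_ttl.
    let profile: Option<serde_json::Value> = cache
        .get_or_load_json(&key, || async {
            // Stand-in for a real lookup, e.g. a database query.
            Ok(Some(serde_json::json!({"id": 42, "name": "Ada"})))
        })
        .await?;

    println!("{profile:?}");
    Ok(())
}

The full stress test follows, starting with the crate-level feature gate.
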
#![cfg(feature = "cache")]

mod support;

use std::{
    sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    },
    time::Duration,
};

use rs_zero::cache::{
    CacheAside, CacheAsideConfig, CacheKey, CacheStore, LruCacheStore, MemoryCacheStore,
    TwoLevelCacheStore,
};
use tokio::sync::Barrier;

use support::stress::{FaultyCacheStore, cache_key};

// Shared cache-aside configuration; jitter is disabled so TTL timing stays
// deterministic across the stress tests.
fn cache_aside_config() -> CacheAsideConfig {
    CacheAsideConfig {
        value_ttl: Duration::from_secs(30),
        not_found_ttl: Duration::from_secs(5),
        ttl_jitter_ratio: 0.0,
    }
}

async fn join_all<T>(handles: Vec<tokio::task::JoinHandle<T>>) -> Vec<T> {
    let mut values = Vec::with_capacity(handles.len());
    for handle in handles {
        values.push(handle.await.expect("join"));
    }
    values
}

// A failing loader must surface its error, count it, and release the per-key
// single-flight lock so a subsequent call can load and cache the value.
async fn assert_loader_failure_releases_lock(cache: &CacheAside<FaultyCacheStore>, key: &CacheKey) {
    let failure = cache
        .get_or_load_json::<serde_json::Value, _, _>(key, || async {
            Err(rs_zero::cache::CacheError::Backend(
                "loader unavailable".to_string(),
            ))
        })
        .await
        .expect_err("loader failure");
    assert!(failure.to_string().contains("loader unavailable"));
    assert_eq!(cache.stats().snapshot().loader_errors, 1);

    let recovered: serde_json::Value = cache
        .get_or_load_json(key, || async { Ok(Some(serde_json::json!({"ok": true}))) })
        .await
        .expect("recovered")
        .expect("value");
    assert_eq!(recovered["ok"], true);
}

// A failing backend set must surface the error, count it, and still release
// the per-key lock so the next call reloads cleanly.
async fn assert_set_failure_releases_lock() {
    let set_fail_store = FaultyCacheStore::new();
    set_fail_store.fail_next_sets(1).await;
    let failing_cache = CacheAside::new(set_fail_store, cache_aside_config());
    let key = CacheKey::new("stress", ["set-fail"]);
    let set_error = failing_cache
        .get_or_load_json(&key, || async {
            Ok(Some(serde_json::json!({"first": true})))
        })
        .await
        .expect_err("set failure");
    assert!(set_error.to_string().contains("injected set failure"));
    assert_eq!(failing_cache.stats().snapshot().set_errors, 1);

    let value: serde_json::Value = failing_cache
        .get_or_load_json(&key, || async {
            Ok(Some(serde_json::json!({"second": true})))
        })
        .await
        .expect("second load")
        .expect("value");
    assert_eq!(value["second"], true);
}

// 64 tasks each write and read back 4 distinct keys (256 in total) through a
// capacity-16 store, forcing heavy eviction under contention.
async fn exercise_lru_concurrent_writes(store: LruCacheStore) {
    let barrier = Arc::new(Barrier::new(64));
    let mut handles = Vec::new();
    for task in 0..64 {
        let store = store.clone();
        let barrier = barrier.clone();
        handles.push(tokio::spawn(async move {
            barrier.wait().await;
            for round in 0..4 {
                let key = cache_key("lru", task * 4 + round);
                store
                    .set_raw(&key, vec![task as u8, round as u8], None)
                    .await
                    .expect("set");
                let _ = store.get_raw(&key).await.expect("get");
            }
        }));
    }
    for handle in handles {
        handle.await.expect("join");
    }
}

// An entry written with a 5 ms TTL must be gone after expiry and counted as an
// expired removal.
async fn assert_lru_ttl_removal(store: &LruCacheStore) {
    let ttl_key = CacheKey::new("lru", ["ttl"]);
    store
        .set_raw(&ttl_key, b"ttl".to_vec(), Some(Duration::from_millis(5)))
        .await
        .expect("ttl set");
    tokio::time::sleep(Duration::from_millis(15)).await;
    assert!(store.get_raw(&ttl_key).await.expect("ttl get").is_none());
    assert!(store.snapshot().await.expired_removals >= 1);
}

// Populate only the L2 store, so every first read must miss L1 and backfill.
async fn seed_two_level_l2(l2: &MemoryCacheStore) {
    for index in 0..16 {
        let key = cache_key("two", index);
        l2.set_raw(&key, format!("value-{index}").into_bytes(), None)
            .await
            .expect("seed l2");
    }
}

// 32 tasks read (and occasionally delete) 16 L2-seeded keys, exercising the
// L1 backfill path under contention.
async fn exercise_two_level_concurrent_backfill(
    store: TwoLevelCacheStore<LruCacheStore, MemoryCacheStore>,
) {
    let barrier = Arc::new(Barrier::new(32));
    let mut handles = Vec::new();
    for task in 0..32 {
        let store = store.clone();
        let barrier = barrier.clone();
        handles.push(tokio::spawn(async move {
            barrier.wait().await;
            let key = cache_key("two", task % 16);
            if let Some(value) = store.get_raw(&key).await.expect("get") {
                assert!(value.starts_with(b"value-"));
            }
            if task % 5 == 0 {
                store.delete(&key).await.expect("delete");
            }
        }));
    }
    for handle in handles {
        handle.await.expect("join");
    }
}

// When the L2 write fails, the value must not be left visible in L1.
async fn assert_l2_set_failure_does_not_pollute_l1() {
    let failing_l2 = FaultyCacheStore::new();
    failing_l2.fail_next_sets(1).await;
    let protected_l1 = MemoryCacheStore::new();
    let failing_store = TwoLevelCacheStore::new(protected_l1.clone(), failing_l2);
    let failed_key = CacheKey::new("two", ["set-failure"]);
    let error = failing_store
        .set_raw(&failed_key, b"should-not-reach-l1".to_vec(), None)
        .await
        .expect_err("l2 set failure");
    assert!(error.to_string().contains("injected set failure"));
    assert!(
        protected_l1
            .get_raw(&failed_key)
            .await
            .expect("l1")
            .is_none()
    );
}

// Spawn 32 tasks that hit the same key simultaneously; single-flight should
// collapse them into one slow loader call.
fn spawn_hot_key_loaders(
    cache: CacheAside<FaultyCacheStore>,
    key: CacheKey,
    calls: Arc<AtomicUsize>,
    barrier: Arc<Barrier>,
) -> Vec<tokio::task::JoinHandle<Option<serde_json::Value>>> {
    let mut handles = Vec::new();
    for _ in 0..32 {
        let cache = cache.clone();
        let key = key.clone();
        let calls = calls.clone();
        let barrier = barrier.clone();
        handles.push(tokio::spawn(async move {
            barrier.wait().await;
            cache
                .get_or_load_json(&key, || async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    tokio::time::sleep(Duration::from_millis(20)).await;
                    Ok(Some(serde_json::json!({"id": 7})))
                })
                .await
                .expect("cache aside load")
        }));
    }
    handles
}

// All 32 tasks race one key: each records a miss, single-flight lets exactly
// one loader (and one backend set) run, and the 31 followers hit the freshly
// cached value.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn cache_aside_hot_key_pressure_merges_loader_calls() {
    let store = FaultyCacheStore::new();
    store
        .set_delays(
            Duration::from_millis(1),
            Duration::from_millis(1),
            Duration::ZERO,
        )
        .await;
    let cache = CacheAside::new(store.clone(), cache_aside_config());
    let key = CacheKey::new("stress", ["hot"]);
    let calls = Arc::new(AtomicUsize::new(0));
    let handles = spawn_hot_key_loaders(
        cache.clone(),
        key,
        calls.clone(),
        Arc::new(Barrier::new(32)),
    );

    for value in join_all(handles).await {
        assert_eq!(value.expect("value")["id"], 7);
    }

    let stats = cache.stats().snapshot();
    assert_eq!(calls.load(Ordering::SeqCst), 1);
    assert_eq!(stats.misses, 32);
    assert!(stats.hits >= 31, "expected lock followers to hit cache");
    assert_eq!(stats.loader_errors, 0);
    assert_eq!(store.snapshot().await.set_calls, 1);
}

// 40 tasks spread over 8 keys: single-flight is per key, so exactly 8 loader
// calls run and the remaining 32 lookups hit the cache.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn cache_aside_multi_key_pressure_shards_singleflight() {
    let cache = CacheAside::new(MemoryCacheStore::new(), cache_aside_config());
    let loader_calls = Arc::new(AtomicUsize::new(0));
    let barrier = Arc::new(Barrier::new(40));

    let mut handles = Vec::new();
    for task in 0..40 {
        let cache = cache.clone();
        let loader_calls = loader_calls.clone();
        let barrier = barrier.clone();
        handles.push(tokio::spawn(async move {
            barrier.wait().await;
            let key_index = task % 8;
            let key = cache_key("multi", key_index);
            let value: serde_json::Value = cache
                .get_or_load_json(&key, || async move {
                    loader_calls.fetch_add(1, Ordering::SeqCst);
                    tokio::time::sleep(Duration::from_millis(5)).await;
                    Ok(Some(serde_json::json!({"key": key_index})))
                })
                .await
                .expect("cache aside")
                .expect("value");
            value["key"].as_u64().expect("key")
        }));
    }

    for value in join_all(handles).await {
        assert!(value < 8);
    }

    assert_eq!(loader_calls.load(Ordering::SeqCst), 8);
    let stats = cache.stats().snapshot();
    assert_eq!(stats.loader_errors, 0);
    assert!(stats.hits >= 32);
}

#[tokio::test]
async fn cache_aside_loader_failure_and_set_failure_release_hot_key_lock() {
    let store = FaultyCacheStore::new();
    let cache = CacheAside::new(store, cache_aside_config());
    let key = CacheKey::new("stress", ["recover"]);

    assert_loader_failure_releases_lock(&cache, &key).await;
    assert_set_failure_releases_lock().await;
}

// Ok(None) is cached as a negative entry (the loader must not rerun), and a
// corrupt payload is reloaded even when deleting the bad entry fails.
#[tokio::test]
async fn cache_aside_negative_and_corrupt_entries_remain_diagnostic_under_faults() {
    let store = FaultyCacheStore::new();
    let cache = CacheAside::new(store.clone(), cache_aside_config());
    let missing = CacheKey::new("stress", ["missing"]);

    let first: Option<serde_json::Value> = cache
        .get_or_load_json(&missing, || async { Ok(None) })
        .await
        .expect("negative load");
    assert!(first.is_none());
    let second: Option<serde_json::Value> = cache
        .get_or_load_json(&missing, || async {
            panic!("negative cache hit must not call loader")
        })
        .await
        .expect("negative hit");
    assert!(second.is_none());
    assert_eq!(cache.stats().snapshot().negative_hits, 1);

    let corrupt = CacheKey::new("stress", ["corrupt"]);
    store.insert_raw(&corrupt, b"not-json".to_vec()).await;
    store.fail_next_deletes(1).await;
    let value: serde_json::Value = cache
        .get_or_load_json(&corrupt, || async {
            Ok(Some(serde_json::json!({"fixed": 1})))
        })
        .await
        .expect("reload after corrupt")
        .expect("value");
    assert_eq!(value["fixed"], 1);
    assert_eq!(cache.stats().snapshot().delete_errors, 1);
}

// 256 distinct keys written through 16 slots: the store stays at capacity and
// records at least 240 evictions; TTL expiry is checked afterwards.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn lru_concurrent_writes_keep_capacity_and_expire_entries() {
    let store = LruCacheStore::new(16).expect("lru");
    exercise_lru_concurrent_writes(store.clone()).await;

    let snapshot = store.snapshot().await;
    assert_eq!(snapshot.capacity, 16);
    assert!(snapshot.entries <= 16);
    assert!(snapshot.evictions >= 240, "snapshot: {snapshot:?}");

    assert_lru_ttl_removal(&store).await;
}

// Concurrent backfill keeps the per-level counters and the L1 capacity bound
// intact, and a failed L2 set must not leak into L1.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn two_level_concurrent_backfill_delete_and_l2_set_failure_preserve_invariants() {
    let l1 = LruCacheStore::new(8).expect("l1");
    let l2 = MemoryCacheStore::new();
    seed_two_level_l2(&l2).await;

    let store = TwoLevelCacheStore::new(l1.clone(), l2);
    exercise_two_level_concurrent_backfill(store.clone()).await;

    let snapshot = store.stats().snapshot().await;
    assert!(snapshot.l1_misses >= 16, "snapshot: {snapshot:?}");
    assert!(snapshot.l2_hits >= 16, "snapshot: {snapshot:?}");
    assert!(snapshot.backfills >= 16, "snapshot: {snapshot:?}");
    assert!(l1.snapshot().await.entries <= 8);

    assert_l2_set_failure_does_not_pollute_l1().await;
}
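
The two-level store composes the same way outside the test harness. A minimal
sketch, assuming TwoLevelCacheStore exposes the same CacheStore surface as its
two halves (the fault-isolation test above suggests set_raw commits to L2
before L1); the capacity and TTL values are illustrative:

use std::time::Duration;

use rs_zero::cache::{
    CacheKey, CacheStore, LruCacheStore, MemoryCacheStore, TwoLevelCacheStore,
};

async fn two_level_demo() -> Result<(), rs_zero::cache::CacheError> {
    // An in-process L1 in front of a shared L2 (an in-memory stand-in here).
    let l1 = LruCacheStore::new(1024).expect("l1 capacity");
    let l2 = MemoryCacheStore::new();

    // Seed only L2, as the backfill fixture above does.
    let key = CacheKey::new("user", ["42"]);
    l2.set_raw(&key, b"payload".to_vec(), Some(Duration::from_secs(60)))
        .await?;

    let store = TwoLevelCacheStore::new(l1, l2);

    // First read: L1 miss, L2 hit, backfill into L1.
    assert!(store.get_raw(&key).await?.is_some());
    // Second read is served from L1.
    assert!(store.get_raw(&key).await?.is_some());

    let snapshot = store.stats().snapshot().await;
    println!(
        "l1_misses={} l2_hits={} backfills={}",
        snapshot.l1_misses, snapshot.l2_hits, snapshot.backfills
    );
    Ok(())
}

Wrapping the composed store in CacheAside should then layer single-flight,
negative caching, and TTL handling on top, exactly as the cache-aside tests
above exercise against single-level stores.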