accelerator 0.1.1

MVP multi-level cache runtime with singleflight load de-duplication
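
The headline feature is singleflight load de-duplication: concurrent reads of the same cold key collapse into a single loader call instead of stampeding the backing store. The sketch below illustrates the idea with a minimal, local-only setup. It reuses the LevelCacheBuilder, Loader, and ReadOptions API that the integration test further down exercises, but CacheMode::Local, the local-only builder chain, and the SlowLoader helper are illustrative assumptions, not API confirmed by the test.

use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;

use accelerator::builder::LevelCacheBuilder;
use accelerator::cache::ReadOptions;
use accelerator::config::CacheMode;
use accelerator::loader::Loader;
use accelerator::{CacheResult, local};

// Illustrative loader: it sleeps to widen the race window and counts how
// often the backing store is actually hit.
#[derive(Clone)]
struct SlowLoader {
    calls: Arc<AtomicUsize>,
}

impl Loader<u64, String> for SlowLoader {
    async fn load(&self, key: &u64) -> CacheResult<Option<String>> {
        self.calls.fetch_add(1, Ordering::SeqCst);
        tokio::time::sleep(Duration::from_millis(50)).await;
        Ok(Some(format!("value-{key}")))
    }
}

#[tokio::main]
async fn main() {
    let loader = SlowLoader {
        calls: Arc::new(AtomicUsize::new(0)),
    };
    let probe = loader.clone();

    // Local-only configuration; CacheMode::Local is assumed here, while the
    // integration test below shows the two-level CacheMode::Both setup.
    let cache = LevelCacheBuilder::<u64, String, SlowLoader>::new()
        .area("singleflight-demo".to_string())
        .mode(CacheMode::Local)
        .local(local::moka::<String>().max_capacity(16).build().unwrap())
        .loader(loader)
        .local_ttl(Duration::from_secs(60))
        .build()
        .unwrap();

    // Two concurrent reads of the same cold key: singleflight should collapse
    // them into one loader call.
    let (a, b) = tokio::join!(
        cache.get(&1, &ReadOptions::default()),
        cache.get(&1, &ReadOptions::default()),
    );
    assert_eq!(a.unwrap(), b.unwrap());
    assert_eq!(probe.calls.load(Ordering::SeqCst), 1);
}

The integration test below exercises the full stack instead: a moka local level and a Redis remote level in CacheMode::Both, backed by a Postgres loader. It reads ACCELERATOR_TEST_REDIS_URL and ACCELERATOR_TEST_POSTGRES_DSN and skips itself when either backend is unreachable, so it can run in environments without the services.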
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::{Duration, SystemTime, UNIX_EPOCH};

use accelerator::builder::LevelCacheBuilder;
use accelerator::cache::ReadOptions;
use accelerator::config::CacheMode;
use accelerator::loader::{Loader, MLoader};
use accelerator::{CacheError, CacheResult, local, remote};
use redis::AsyncCommands;
use sqlx::PgPool;
use sqlx::postgres::PgPoolOptions;

#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
struct DbUser {
    id: u64,
    name: String,
}

// Postgres-backed loader; the atomic counters let the test assert how often
// the database is actually hit.
#[derive(Clone)]
struct PgUserLoader {
    pool: PgPool,
    load_calls: Arc<AtomicUsize>,
    mload_calls: Arc<AtomicUsize>,
}

impl PgUserLoader {
    fn new(pool: PgPool) -> Self {
        Self {
            pool,
            load_calls: Arc::new(AtomicUsize::new(0)),
            mload_calls: Arc::new(AtomicUsize::new(0)),
        }
    }

    fn load_calls(&self) -> usize {
        self.load_calls.load(Ordering::SeqCst)
    }

    fn mload_calls(&self) -> usize {
        self.mload_calls.load(Ordering::SeqCst)
    }
}

impl Loader<u64, DbUser> for PgUserLoader {
    async fn load(&self, key: &u64) -> CacheResult<Option<DbUser>> {
        self.load_calls.fetch_add(1, Ordering::SeqCst);

        let row = sqlx::query_as::<_, (i64, String)>(
            "SELECT id, name FROM accelerator_users WHERE id = $1",
        )
        .bind(*key as i64)
        .fetch_optional(&self.pool)
        .await
        .map_err(|err| CacheError::Loader(format!("sqlx load failed: {err}")))?;

        Ok(row.map(|(id, name)| DbUser {
            id: id as u64,
            name,
        }))
    }
}

impl MLoader<u64, DbUser> for PgUserLoader {
    async fn mload(&self, keys: &[u64]) -> CacheResult<HashMap<u64, Option<DbUser>>> {
        self.mload_calls.fetch_add(1, Ordering::SeqCst);

        if keys.is_empty() {
            return Ok(HashMap::new());
        }

        let ids = keys.iter().map(|key| *key as i64).collect::<Vec<_>>();
        let rows = sqlx::query_as::<_, (i64, String)>(
            "SELECT id, name FROM accelerator_users WHERE id = ANY($1::bigint[])",
        )
        .bind(&ids)
        .fetch_all(&self.pool)
        .await
        .map_err(|err| CacheError::Loader(format!("sqlx mload failed: {err}")))?;

        let mut found = HashMap::with_capacity(rows.len());
        for (id, name) in rows {
            found.insert(
                id as u64,
                DbUser {
                    id: id as u64,
                    name,
                },
            );
        }

        // Every requested key gets an entry; rows that do not exist map to None.
        let mut values = HashMap::with_capacity(keys.len());
        for key in keys {
            values.insert(*key, found.get(key).cloned());
        }

        Ok(values)
    }
}

fn redis_url() -> String {
    std::env::var("ACCELERATOR_TEST_REDIS_URL")
        .unwrap_or_else(|_| "redis://127.0.0.1:6379".to_string())
}

fn postgres_dsn() -> String {
    std::env::var("ACCELERATOR_TEST_POSTGRES_DSN").unwrap_or_else(|_| {
        "postgres://accelerator:accelerator@127.0.0.1:5432/accelerator".to_string()
    })
}

async fn redis_ready(url: &str) -> bool {
    let client = match redis::Client::open(url) {
        Ok(client) => client,
        Err(_) => return false,
    };

    let mut conn = match client.get_multiplexed_async_connection().await {
        Ok(conn) => conn,
        Err(_) => return false,
    };

    conn.ping::<String>().await.is_ok()
}

async fn postgres_connect(dsn: &str) -> Option<PgPool> {
    PgPoolOptions::new()
        .max_connections(5)
        .acquire_timeout(Duration::from_secs(2))
        .connect(dsn)
        .await
        .ok()
}

fn unique_scope(tag: &str) -> String {
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();
    format!("{tag}-{}-{nanos}", std::process::id())
}

async fn skip_if_stack_unavailable(test_name: &str) -> Option<(String, PgPool)> {
    let redis = redis_url();
    if !redis_ready(&redis).await {
        eprintln!("skip `{test_name}`: redis is not reachable at {redis}");
        return None;
    }

    let dsn = postgres_dsn();
    let Some(pg_pool) = postgres_connect(&dsn).await else {
        eprintln!("skip `{test_name}`: postgres is not reachable with dsn `{dsn}`");
        return None;
    };

    Some((redis, pg_pool))
}

#[tokio::test]
async fn cache_component_works_with_redis_and_postgres_loader() {
    let Some((redis, pg_pool)) =
        skip_if_stack_unavailable("cache_component_works_with_redis_and_postgres_loader").await
    else {
        return;
    };

    // Make sure the fixture table exists before seeding it.
    sqlx::query(
        "CREATE TABLE IF NOT EXISTS accelerator_users (id BIGINT PRIMARY KEY, name TEXT NOT NULL)",
    )
    .execute(&pg_pool)
    .await
    .unwrap();

    let user_1 = 91_001_i64;
    let user_2 = 91_002_i64;

    // Seed (or refresh) the two rows the cache will read back.
    sqlx::query(
        "INSERT INTO accelerator_users(id, name) VALUES($1, $2)
         ON CONFLICT(id) DO UPDATE SET name = EXCLUDED.name",
    )
    .bind(user_1)
    .bind("alice-v1")
    .execute(&pg_pool)
    .await
    .unwrap();
    sqlx::query(
        "INSERT INTO accelerator_users(id, name) VALUES($1, $2)
         ON CONFLICT(id) DO UPDATE SET name = EXCLUDED.name",
    )
    .bind(user_2)
    .bind("bob-v1")
    .execute(&pg_pool)
    .await
    .unwrap();

    // A unique scope (Redis key prefix) and area (metric attribute) keep this
    // run isolated from other runs against the same backends.
    let scope = unique_scope("stack-it");
    let area = format!("area-{scope}");
    let local_backend = local::moka::<DbUser>().max_capacity(128).build().unwrap();
    let remote_backend = remote::redis::<DbUser>()
        .url(redis)
        .key_prefix(scope.clone())
        .build()
        .unwrap();

    let loader = PgUserLoader::new(pg_pool.clone());
    let loader_probe = loader.clone();

    // Assemble the two-level cache (moka + Redis) around the Postgres loader.
    let cache = LevelCacheBuilder::<u64, DbUser, PgUserLoader>::new()
        .area(area.clone())
        .mode(CacheMode::Both)
        .local(local_backend)
        .remote(remote_backend)
        .penetration_protect(false)
        .loader(loader)
        .local_ttl(Duration::from_secs(60))
        .remote_ttl(Duration::from_secs(120))
        .null_ttl(Duration::from_secs(10))
        .build()
        .unwrap();

    // Cold read: the value must come from Postgres, so the loader runs exactly once.
    let first = cache
        .get(&(user_1 as u64), &ReadOptions::default())
        .await
        .unwrap();
    assert_eq!(
        first,
        Some(DbUser {
            id: user_1 as u64,
            name: "alice-v1".to_string()
        })
    );
    assert_eq!(loader_probe.load_calls(), 1);

    // Change the row behind the cache.
    sqlx::query("UPDATE accelerator_users SET name = $2 WHERE id = $1")
        .bind(user_1)
        .bind("alice-v2")
        .execute(&pg_pool)
        .await
        .unwrap();

    // Warm read: served from cache, so the stale name comes back and the
    // loader is not called again.
    let second = cache
        .get(&(user_1 as u64), &ReadOptions::default())
        .await
        .unwrap();
    assert_eq!(
        second,
        Some(DbUser {
            id: user_1 as u64,
            name: "alice-v1".to_string()
        })
    );
    assert_eq!(loader_probe.load_calls(), 1);

    // Batch read: the unknown id resolves to None and the batch loader runs
    // exactly once.
    let batch = cache
        .mget(
            &[user_1 as u64, user_2 as u64, 91_003_u64],
            &ReadOptions::default(),
        )
        .await
        .unwrap();

    assert_eq!(
        batch.get(&(user_1 as u64)).cloned().flatten(),
        Some(DbUser {
            id: user_1 as u64,
            name: "alice-v1".to_string()
        })
    );
    assert_eq!(
        batch.get(&(user_2 as u64)).cloned().flatten(),
        Some(DbUser {
            id: user_2 as u64,
            name: "bob-v1".to_string()
        })
    );
    assert_eq!(batch.get(&91_003_u64).cloned().flatten(), None);
    assert_eq!(loader_probe.mload_calls(), 1);

    // Every exported metric point carries this cache's area attribute.
    let otel = cache.otel_metric_points();
    assert!(!otel.is_empty());
    assert!(
        otel.iter()
            .all(|point| { point.attributes == vec![("area", area.clone())] })
    );

    // After an explicit delete, a read with loads disabled must miss.
    cache.del(&(user_1 as u64)).await.unwrap();
    let after_del = cache
        .get(
            &(user_1 as u64),
            &ReadOptions {
                allow_stale: false,
                disable_load: true,
            },
        )
        .await
        .unwrap();
    assert_eq!(after_del, None);

    // Remove the seeded rows.
    sqlx::query("DELETE FROM accelerator_users WHERE id = ANY($1::bigint[])")
        .bind(&vec![user_1, user_2])
        .execute(&pg_pool)
        .await
        .unwrap();
}