aa-storage-redis 0.0.1-beta.2

Redis L2 shared-cache storage driver (SessionStore, RateLimitCounter, PolicyStore) for Agent Assembly
Documentation
//! Redis-backed [`RateLimitCounter`] implementation using an atomic Lua `INCRBY` + `EXPIRE` script.

use aa_storage::{RateLimitCounter, Result};
use async_trait::async_trait;
use deadpool_redis::Pool;

use crate::error::backend;

/// Lua source executed atomically by Redis for each
/// [`increment`](RateLimitCounter::increment).
///
/// `KEYS[1]` is the counter key, `ARGV[1]` the amount to add and `ARGV[2]`
/// the window length in seconds.
///
/// The script `INCRBY`s the counter, then arms `EXPIRE` whenever the key has
/// no TTL (`TTL` returns `-1`). Immediately after the first increment of a
/// window that is exactly the "key was just created" case, so a **fixed**
/// window still starts at the first increment and later increments never
/// extend it.
///
/// Checking the TTL, rather than comparing the new total with the amount just
/// added, also covers two cases that comparison gets wrong:
/// - a zero-amount increment on a zero-valued key would re-arm (slide) the
///   window on every call;
/// - a key that already existed without a TTL (e.g. written outside this
///   driver) would never expire, so the counter would never reset.
///
/// Running every command inside one script keeps the read-modify-write atomic
/// with respect to concurrent callers.
const INCREMENT_SCRIPT: &str = r"
local current = redis.call('INCRBY', KEYS[1], ARGV[1])
if redis.call('TTL', KEYS[1]) == -1 then
    redis.call('EXPIRE', KEYS[1], ARGV[2])
end
return current
";

/// [`RateLimitCounter`] implementation that keeps its counters in Redis.
///
/// Each counter is stored under the Redis key `aa:ratelimit:<key>`.
/// Cloning is cheap: every clone holds a handle to the same shared [`Pool`]
/// rather than opening new connections.
#[derive(Clone)]
pub struct RedisRateLimitCounter {
    /// Connection pool that every counter operation checks a connection out of.
    pool: Pool,
}

impl RedisRateLimitCounter {
    /// Create a counter over an existing connection pool.
    pub fn new(pool: Pool) -> Self {
        Self { pool }
    }
}

/// Map a caller-supplied counter name to its namespaced Redis key
/// (`aa:ratelimit:` followed by `key`, verbatim).
fn counter_key(key: &str) -> String {
    ["aa:ratelimit:", key].concat()
}

#[async_trait]
impl RateLimitCounter for RedisRateLimitCounter {
    /// Atomically add `amount` to the counter for `key` and return the new
    /// total.
    ///
    /// The first increment of a window also arms a `window_secs` expiry on
    /// the key, so the counter disappears (and reads as `0`) once that fixed
    /// window has elapsed.
    ///
    /// # Errors
    ///
    /// Returns a backend error if a pooled connection cannot be obtained or
    /// the script invocation fails.
    async fn increment(&self, key: &str, amount: u64, window_secs: u64) -> Result<u64> {
        // `Script::new` hashes the source to build the `EVALSHA` digest. The
        // source is a constant, so build the `Script` once and reuse it
        // instead of re-hashing on every call.
        static SCRIPT: std::sync::OnceLock<redis::Script> = std::sync::OnceLock::new();
        let script = SCRIPT.get_or_init(|| redis::Script::new(INCREMENT_SCRIPT));

        let mut conn = self.pool.get().await.map_err(backend)?;
        let total: i64 = script
            .key(counter_key(key))
            .arg(amount)
            .arg(window_secs)
            .invoke_async(&mut conn)
            .await
            .map_err(backend)?;
        // `INCRBY` replies with a signed integer. A negative total can only
        // arise if the key was written outside this driver; it is
        // deliberately clamped to 0 rather than failing the call.
        Ok(u64::try_from(total).unwrap_or(0))
    }

    /// Read the current total for `key` without modifying it.
    ///
    /// A missing key (never incremented, reset, or expired) reads as `0`.
    ///
    /// # Errors
    ///
    /// Returns a backend error if a pooled connection cannot be obtained,
    /// the `GET` fails, or the stored value cannot be decoded as a `u64`.
    async fn current(&self, key: &str) -> Result<u64> {
        let mut conn = self.pool.get().await.map_err(backend)?;
        let value: Option<u64> = redis::cmd("GET")
            .arg(counter_key(key))
            .query_async(&mut conn)
            .await
            .map_err(backend)?;
        Ok(value.unwrap_or(0))
    }

    /// Delete the counter for `key`, so the next increment starts from zero
    /// in a fresh window. Deleting a key that does not exist is not an error.
    ///
    /// # Errors
    ///
    /// Returns a backend error if a pooled connection cannot be obtained or
    /// the `DEL` fails.
    async fn reset(&self, key: &str) -> Result<()> {
        let mut conn = self.pool.get().await.map_err(backend)?;
        // The integer reply (number of keys removed) is irrelevant here.
        let _: () = redis::cmd("DEL")
            .arg(counter_key(key))
            .query_async(&mut conn)
            .await
            .map_err(backend)?;
        Ok(())
    }
}