use aa_storage::{RateLimitCounter, Result};
use async_trait::async_trait;
use deadpool_redis::Pool;
use crate::error::backend;
/// Lua script behind `RateLimitCounter::increment` for Redis.
///
/// Arguments:
/// - `KEYS[1]`: the counter key.
/// - `ARGV[1]`: the amount to add.
/// - `ARGV[2]`: the window length in seconds.
///
/// Returns the counter value after the `INCRBY`.
///
/// The expiry is set only while the key has none, i.e. when `TTL` replies `-1`.
/// This anchors the window at the first increment, and later increments never
/// push it out.
///
/// Checking the TTL, rather than comparing the new total with the amount, also
/// covers two other cases:
/// - A key that already held `0` (e.g. after a zero-amount increment). The old
///   comparison would re-arm its window on the next increment.
/// - A key that ended up without an expiry. It would otherwise count forever.
const INCREMENT_SCRIPT: &str = r"
local current = redis.call('INCRBY', KEYS[1], ARGV[1])
if redis.call('TTL', KEYS[1]) == -1 then
    redis.call('EXPIRE', KEYS[1], ARGV[2])
end
return current
";
/// `RateLimitCounter` implementation that stores its counters in Redis.
///
/// Each operation checks out a connection from the wrapped `deadpool_redis` pool.
#[derive(Clone)]
pub struct RedisRateLimitCounter {
    /// Pool that every counter operation borrows a connection from.
    pool: Pool,
}

impl RedisRateLimitCounter {
    /// Creates a counter backed by `pool`.
    ///
    /// The pool is only stored here; no connection is checked out until an
    /// operation runs.
    pub fn new(pool: Pool) -> Self {
        RedisRateLimitCounter { pool }
    }
}
/// Maps a caller-supplied bucket name to its namespaced Redis key.
///
/// The result is `aa:ratelimit:` followed by `key` verbatim; no escaping is applied.
fn counter_key(key: &str) -> String {
    const PREFIX: &str = "aa:ratelimit:";
    let mut namespaced = String::with_capacity(PREFIX.len() + key.len());
    namespaced.push_str(PREFIX);
    namespaced.push_str(key);
    namespaced
}
#[async_trait]
impl RateLimitCounter for RedisRateLimitCounter {
    /// Adds `amount` to the counter for `key` and returns the resulting total.
    ///
    /// The increment and the expiry handling both happen inside
    /// `INCREMENT_SCRIPT`. `key` is passed as `KEYS[1]`, `amount` as `ARGV[1]`
    /// and `window_secs` as `ARGV[2]`. A negative total from Redis is reported
    /// as `0`.
    ///
    /// # Errors
    /// Returns a backend error if no pooled connection is available or the
    /// script invocation fails.
    async fn increment(&self, key: &str, amount: u64, window_secs: u64) -> Result<u64> {
        let redis_key = counter_key(key);
        let mut connection = self.pool.get().await.map_err(backend)?;
        let script = redis::Script::new(INCREMENT_SCRIPT);
        let total: i64 = script
            .key(redis_key)
            .arg(amount)
            .arg(window_secs)
            .invoke_async(&mut connection)
            .await
            .map_err(backend)?;
        // Negative totals cannot be represented as u64; they collapse to zero.
        Ok(u64::try_from(total).unwrap_or_default())
    }

    /// Reads the current counter value for `key` with `GET`.
    ///
    /// Returns `0` when the key does not exist.
    ///
    /// # Errors
    /// Returns a backend error if no pooled connection is available, or if the
    /// `GET` fails (including a stored value that does not parse as `u64`).
    async fn current(&self, key: &str) -> Result<u64> {
        let redis_key = counter_key(key);
        let mut connection = self.pool.get().await.map_err(backend)?;
        let stored: Option<u64> = redis::cmd("GET")
            .arg(redis_key)
            .query_async(&mut connection)
            .await
            .map_err(backend)?;
        Ok(stored.unwrap_or_default())
    }

    /// Removes the counter for `key` with `DEL`.
    ///
    /// The command's reply is discarded.
    ///
    /// # Errors
    /// Returns a backend error if no pooled connection is available or the
    /// `DEL` fails.
    async fn reset(&self, key: &str) -> Result<()> {
        let redis_key = counter_key(key);
        let mut connection = self.pool.get().await.map_err(backend)?;
        let _: () = redis::cmd("DEL")
            .arg(redis_key)
            .query_async(&mut connection)
            .await
            .map_err(backend)?;
        Ok(())
    }
}