use std::time::Duration;
use async_trait::async_trait;
use redis::aio::ConnectionManager;
use super::{RateLimitOutcome, RateLimiter};
use crate::error::DataError;
/// Lua fixed-window counter script, evaluated atomically by Redis.
///
/// * `KEYS[1]` — the counter key.
/// * `ARGV[1]` — window length in milliseconds, applied with `PEXPIRE`.
///
/// Increments the counter and (re)arms its expiry when this is the first hit of
/// a window, or when the key carries no expiry at all (e.g. it was created or
/// `PERSIST`ed outside this limiter). Without the second condition such a key
/// would never expire and every caller would stay limited indefinitely.
///
/// Returns `{count, pttl_ms}`, where `pttl_ms` is the key's remaining TTL as
/// reported by `PTTL` (negative when the key has no expiry or no longer exists).
const CHECK_SCRIPT: &str = "\
local current = redis.call('INCR', KEYS[1])
local ttl = redis.call('PTTL', KEYS[1])
if current == 1 or ttl < 0 then
  redis.call('PEXPIRE', KEYS[1], ARGV[1])
  ttl = redis.call('PTTL', KEYS[1])
end
return {current, ttl}";
/// Fixed-window rate limiter that keeps one Redis counter per key.
///
/// Redis keys are built as `{prefix}{key}`.
#[derive(Clone)]
pub struct RedisRateLimiter {
    /// Redis connection handle used to run the limiter script.
    conn: ConnectionManager,
    /// Namespace string prepended to every caller-supplied key.
    prefix: String,
}

impl RedisRateLimiter {
    /// Namespace used by [`RedisRateLimiter::new`].
    const DEFAULT_PREFIX: &str = "ratelimit:";

    /// Builds a limiter whose Redis keys start with `"ratelimit:"`.
    #[must_use]
    pub fn new(conn: ConnectionManager) -> Self {
        Self::with_prefix(conn, Self::DEFAULT_PREFIX)
    }

    /// Builds a limiter whose Redis keys start with the given `prefix`.
    #[must_use]
    pub fn with_prefix(conn: ConnectionManager, prefix: impl Into<String>) -> Self {
        let prefix: String = prefix.into();
        Self { conn, prefix }
    }
}
#[async_trait]
impl RateLimiter for RedisRateLimiter {
    /// Records one hit for `key` and decides whether it fits in the current
    /// fixed window.
    ///
    /// Every call, allowed or not, increments the counter stored at
    /// `{prefix}{key}`; the script arms the key's expiry on the first hit of a
    /// window. Hits `1..=max` are allowed and report how many remain; later hits
    /// are limited until the key expires.
    ///
    /// * `max` is clamped to at least 1.
    /// * `window` is rounded **up** to whole milliseconds (the resolution of
    ///   Redis expiries). A zero window deletes the counter right away, so every
    ///   call is allowed.
    ///
    /// # Errors
    ///
    /// Propagates any failure of the Redis script invocation as a [`DataError`].
    async fn check(
        &self,
        key: &str,
        max: u32,
        window: Duration,
    ) -> Result<RateLimitOutcome, DataError> {
        let max = i64::from(max.max(1));
        // Round up rather than truncate: a sub-millisecond window would
        // otherwise become 0 ms, and `PEXPIRE key 0` deletes the counter
        // immediately, which silently disables limiting.
        let window_ms =
            i64::try_from((window.as_nanos() + 999_999) / 1_000_000).unwrap_or(i64::MAX);
        let redis_key = format!("{}{key}", self.prefix);
        let mut conn = self.conn.clone();
        let (count, pttl_ms): (i64, i64) = redis::Script::new(CHECK_SCRIPT)
            .key(redis_key)
            .arg(window_ms)
            .invoke_async(&mut conn)
            .await?;

        if count <= max {
            // Saturate rather than truncate in case the stored counter was
            // modified outside this limiter (e.g. set to a negative value).
            let remaining = u32::try_from(max.saturating_sub(count).max(0)).unwrap_or(u32::MAX);
            Ok(RateLimitOutcome::Allowed { remaining })
        } else {
            // PTTL is negative when the key has no expiry or is already gone;
            // fall back to a full window in that case.
            let retry_ms = if pttl_ms > 0 { pttl_ms } else { window_ms };
            Ok(RateLimitOutcome::Limited {
                retry_after: Duration::from_millis(u64::try_from(retry_ms).unwrap_or(0)),
            })
        }
    }
}
/// Lua token-bucket script, evaluated atomically by Redis.
///
/// * `KEYS[1]` — hash holding the bucket state in fields `tokens` (may be
///   fractional) and `ts` (time of the last update, in milliseconds).
/// * `ARGV[1]` — bucket capacity.
/// * `ARGV[2]` — refill rate in tokens per millisecond.
/// * `ARGV[3]` — TTL in milliseconds, passed to `PEXPIRE`.
///
/// The clock is the Redis server's `TIME`, converted to milliseconds with
/// microsecond precision. A missing bucket starts full, and a negative elapsed
/// time is treated as zero. If at least one token is available it is consumed;
/// otherwise the script computes how many milliseconds remain until one token
/// accrues. The state and its expiry are rewritten on every call, allowed or not.
///
/// Returns `{allowed (1 or 0), floor(tokens left), retry_ms (0 when allowed)}`.
///
/// NOTE(review): `TIME` is non-deterministic and is called before the writes.
/// Servers still using verbatim script replication (the default before Redis
/// 5.0) reject that unless the script first calls `redis.replicate_commands()`.
/// Confirm the minimum supported Redis version.
const TOKEN_BUCKET_SCRIPT: &str = "\
local capacity = tonumber(ARGV[1])
local refill = tonumber(ARGV[2])
local ttl = tonumber(ARGV[3])
local t = redis.call('TIME')
local now = (tonumber(t[1]) * 1000) + (tonumber(t[2]) / 1000)
local data = redis.call('HMGET', KEYS[1], 'tokens', 'ts')
local tokens = tonumber(data[1])
local ts = tonumber(data[2])
if tokens == nil then tokens = capacity; ts = now end
local elapsed = now - ts
if elapsed < 0 then elapsed = 0 end
tokens = math.min(capacity, tokens + (elapsed * refill))
local allowed = 0
local retry = 0
if tokens >= 1 then
tokens = tokens - 1
allowed = 1
else
retry = math.ceil((1 - tokens) / refill)
end
redis.call('HSET', KEYS[1], 'tokens', tokens, 'ts', now)
redis.call('PEXPIRE', KEYS[1], ttl)
return {allowed, math.floor(tokens), retry}";
/// Token-bucket rate limiter that keeps one Redis hash per key.
///
/// Redis keys are built as `{prefix}{key}`.
#[derive(Clone)]
pub struct RedisTokenBucket {
    /// Redis connection handle used to run the bucket script.
    conn: ConnectionManager,
    /// Namespace string prepended to every caller-supplied key.
    prefix: String,
}

impl RedisTokenBucket {
    /// Namespace used by [`RedisTokenBucket::new`].
    const DEFAULT_PREFIX: &str = "ratelimit:tb:";

    /// Builds a bucket limiter whose Redis keys start with `"ratelimit:tb:"`.
    #[must_use]
    pub fn new(conn: ConnectionManager) -> Self {
        Self::with_prefix(conn, Self::DEFAULT_PREFIX)
    }

    /// Builds a bucket limiter whose Redis keys start with the given `prefix`.
    #[must_use]
    pub fn with_prefix(conn: ConnectionManager, prefix: impl Into<String>) -> Self {
        let prefix: String = prefix.into();
        Self { conn, prefix }
    }
}
#[async_trait]
impl RateLimiter for RedisTokenBucket {
    /// Tries to take one token from the bucket stored at `{prefix}{key}`.
    ///
    /// The bucket holds up to `max` tokens (clamped to at least 1) and refills
    /// continuously at `max / window`, so an idle bucket is full again after one
    /// `window`. Allowed calls report the whole tokens left; limited calls report
    /// how long until one token becomes available.
    ///
    /// # Errors
    ///
    /// Propagates any failure of the Redis script invocation as a [`DataError`].
    async fn check(
        &self,
        key: &str,
        max: u32,
        window: Duration,
    ) -> Result<RateLimitOutcome, DataError> {
        let capacity = f64::from(max.max(1));
        // Floor the window at the smallest positive f64 to avoid dividing by
        // zero for `Duration::ZERO`.
        let window_ms = window.as_secs_f64().max(f64::MIN_POSITIVE) * 1000.0;
        let refill_per_ms = capacity / window_ms;
        // Idle expiry for the bucket state. After one `window` without calls the
        // bucket is full again, so dropping the state then is lossless, but only
        // if the TTL is not shorter than `window`. Round up: a shorter TTL would
        // drop the state before it has refilled. A 0 ms TTL (`PEXPIRE key 0`)
        // deletes the state immediately, so every call would see a full bucket.
        let ttl_ms =
            i64::try_from((window.as_nanos() + 999_999) / 1_000_000).unwrap_or(i64::MAX);
        let redis_key = format!("{}{key}", self.prefix);
        let mut conn = self.conn.clone();
        let (allowed, remaining, retry_ms): (i64, i64, i64) =
            redis::Script::new(TOKEN_BUCKET_SCRIPT)
                .key(redis_key)
                .arg(capacity)
                .arg(refill_per_ms)
                .arg(ttl_ms)
                .invoke_async(&mut conn)
                .await?;

        if allowed == 1 {
            // The script reports floor(tokens) <= capacity <= u32::MAX.
            let remaining = u32::try_from(remaining.max(0)).unwrap_or(u32::MAX);
            Ok(RateLimitOutcome::Allowed { remaining })
        } else {
            Ok(RateLimitOutcome::Limited {
                retry_after: Duration::from_millis(u64::try_from(retry_ms).unwrap_or(0)),
            })
        }
    }
}