use std::time::Duration;
use redis::aio::ConnectionManager;
use redis::Script;
use tracing::{error, info, warn};
/// Rate limiter that keeps one request counter per client key in Redis.
///
/// Counters are updated through a server-side Lua script, so several
/// application instances pointing at the same Redis share the same limits.
#[derive(Clone)]
pub struct RedisRateLimiter {
// Redis connection handle; each operation works on its own clone of it.
conn: ConnectionManager,
// Maximum number of requests a client may make within one window.
max_requests: u32,
// Window length in whole seconds; used as the expiry of the counter key.
window_secs: u64,
// String prepended to the client key to form the Redis key.
key_prefix: String,
// Lua script that increments the counter and maintains its expiry.
script: Script,
}
/// Outcome of a single rate-limit check for one client key.
///
/// Implements `PartialEq`/`Eq` so results can be compared directly, for
/// example with `assert_eq!` in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RedisRateLimitResult {
    /// Whether the request is permitted.
    pub allowed: bool,
    /// Number of requests counted for the client in the current window.
    pub current_count: u32,
    /// The configured limit the count was compared against.
    pub max_requests: u32,
    /// Requests still available in the current window (0 once the limit is hit).
    pub remaining: u32,
}
/// Lua script that records one request against a counter key.
///
/// Inputs: `KEYS[1]` is the counter key, `ARGV[1]` is the request limit (read
/// into a local but not otherwise used by the script) and `ARGV[2]` is the
/// window length in seconds.
///
/// The script increments the counter, gives the key an expiry of `ARGV[2]`
/// seconds when the increment produced 1, re-applies that expiry if the key
/// reports a TTL of -1 (no expiry), and returns `{count, ttl}`.
const RATE_LIMIT_SCRIPT: &str = r#"
local key = KEYS[1]
local max_requests = tonumber(ARGV[1])
local window_secs = tonumber(ARGV[2])
local current = redis.call('INCR', key)
if current == 1 then
redis.call('EXPIRE', key, window_secs)
end
local ttl = redis.call('TTL', key)
-- Safety: if TTL is -1 (no expiry set, race condition), reset it
if ttl == -1 then
redis.call('EXPIRE', key, window_secs)
ttl = window_secs
end
return {current, ttl}
"#;
impl RedisRateLimiter {
    /// Connects to Redis at `redis_url` and builds a limiter that allows
    /// `max_requests` requests per `window` for each client key.
    ///
    /// The window is tracked in whole seconds: any sub-second part of `window`
    /// is dropped, and a window shorter than one second is raised to one
    /// second. A zero-second expiry would make Redis delete the counter as
    /// soon as it is created, so no request would ever be limited.
    ///
    /// # Errors
    ///
    /// Returns the underlying [`redis::RedisError`] if the client cannot be
    /// opened from `redis_url` or the connection manager cannot be created.
    pub async fn new(
        redis_url: &str,
        max_requests: u32,
        window: Duration,
    ) -> Result<Self, redis::RedisError> {
        let client = redis::Client::open(redis_url)?;
        let conn = ConnectionManager::new(client).await?;

        let window_secs = window.as_secs().max(1);
        if window.as_secs() == 0 {
            warn!(
                "Rate limit window {:?} is shorter than one second; using {}s instead",
                window, window_secs
            );
        }

        // Never log the raw URL: it may embed a password.
        info!(
            "Redis rate limiter connected to {} (max {} requests per {}s)",
            Self::redact_url(redis_url),
            max_requests,
            window_secs
        );

        Ok(Self {
            conn,
            max_requests,
            window_secs,
            key_prefix: "datasynth:ratelimit:".to_string(),
            script: Script::new(RATE_LIMIT_SCRIPT),
        })
    }

    /// Replaces the default key prefix and returns the updated limiter.
    ///
    /// Accepts anything convertible into a `String` (`String`, `&str`, ...).
    pub fn with_key_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.key_prefix = prefix.into();
        self
    }

    /// Builds the Redis key for `client_key` by prepending the key prefix.
    fn make_key(&self, client_key: &str) -> String {
        format!("{}{}", self.key_prefix, client_key)
    }

    /// Returns `url` with credentials removed so it is safe to log.
    ///
    /// The query string and fragment are dropped (a password can be passed
    /// there as `pass=...`), and a `user:password@` userinfo section in the
    /// authority is replaced by `***@`.
    fn redact_url(url: &str) -> String {
        let base = url
            .split(|c: char| c == '?' || c == '#')
            .next()
            .unwrap_or(url);
        let (scheme, rest) = match base.split_once("://") {
            Some(parts) => parts,
            None => return base.to_string(),
        };
        // Userinfo can only appear in the authority, i.e. before the first '/'.
        let authority = match rest.find('/') {
            Some(slash) => &rest[..slash],
            None => rest,
        };
        match authority.rfind('@') {
            // `authority` is a prefix of `rest`, so `at` is a valid index into `rest`.
            Some(at) => format!("{}://***@{}", scheme, &rest[at + 1..]),
            None => base.to_string(),
        }
    }

    /// Converts a raw Redis integer into a request count, saturating at the
    /// bounds of `u32` instead of failing on negative or oversized values.
    fn clamp_count(raw: i64) -> u32 {
        // Lossless: the value has just been clamped into the `u32` range.
        raw.clamp(0, i64::from(u32::MAX)) as u32
    }

    /// Result returned when Redis gives no usable answer: the request is
    /// allowed and reported as if nothing had been counted (fail open).
    fn fail_open(&self) -> RedisRateLimitResult {
        RedisRateLimitResult {
            allowed: true,
            current_count: 0,
            max_requests: self.max_requests,
            remaining: self.max_requests,
        }
    }

    /// Records one request for `client_key` and reports whether it is allowed.
    ///
    /// The counter is incremented on every call, including calls that end up
    /// being denied. If the script fails or returns an unexpected shape, the
    /// problem is logged and the request is allowed (fail open).
    pub async fn check_rate_limit(&self, client_key: &str) -> RedisRateLimitResult {
        let key = self.make_key(client_key);
        // The script invocation needs a mutable connection; use a clone of the handle.
        let mut conn = self.conn.clone();

        // Parse as `i64` so that any integer reply is accepted and clamped,
        // rather than turning an out-of-range value into a fail-open error.
        let outcome = self
            .script
            .key(&key)
            .arg(self.max_requests)
            .arg(self.window_secs)
            .invoke_async::<Vec<i64>>(&mut conn)
            .await;

        match outcome {
            Ok(values) => match values.as_slice() {
                &[count, _ttl] => {
                    let current_count = Self::clamp_count(count);
                    RedisRateLimitResult {
                        allowed: current_count <= self.max_requests,
                        current_count,
                        max_requests: self.max_requests,
                        // Saturates to 0 once the limit has been exceeded.
                        remaining: self.max_requests.saturating_sub(current_count),
                    }
                }
                _ => {
                    warn!(
                        "Unexpected Redis rate limit script result length: {}",
                        values.len()
                    );
                    self.fail_open()
                }
            },
            Err(e) => {
                error!("Redis rate limit check failed (failing open): {}", e);
                self.fail_open()
            }
        }
    }

    /// Returns how many requests `client_key` has left in the current window
    /// without counting a request.
    ///
    /// Yields the full limit when no counter exists for the key or when the
    /// Redis lookup fails (the failure is logged).
    pub async fn remaining(&self, client_key: &str) -> u32 {
        let key = self.make_key(client_key);
        let mut conn = self.conn.clone();

        match redis::cmd("GET")
            .arg(&key)
            .query_async::<Option<i64>>(&mut conn)
            .await
        {
            Ok(Some(count)) => self.max_requests.saturating_sub(Self::clamp_count(count)),
            Ok(None) => self.max_requests,
            Err(e) => {
                error!("Redis remaining check failed (failing open): {}", e);
                self.max_requests
            }
        }
    }

    /// Maximum number of requests allowed per window.
    pub fn max_requests(&self) -> u32 {
        self.max_requests
    }

    /// Window length in whole seconds.
    pub fn window_secs(&self) -> u64 {
        self.window_secs
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the "prefix followed by client id" key layout using plain
    /// formatting; it does not call `make_key` itself.
    #[test]
    fn test_make_key() {
        let prefix = String::from("datasynth:ratelimit:");
        let built = format!("{}192.168.1.1", prefix);
        assert_eq!(built, "datasynth:ratelimit:192.168.1.1");
    }

    /// A result under the limit is allowed and reports the leftover budget.
    #[test]
    fn test_rate_limit_result() {
        let RedisRateLimitResult {
            allowed, remaining, ..
        } = RedisRateLimitResult {
            allowed: true,
            current_count: 5,
            max_requests: 100,
            remaining: 95,
        };
        assert!(allowed);
        assert_eq!(remaining, 95);
    }

    /// A result over the limit is denied and has nothing remaining.
    #[test]
    fn test_rate_limit_result_exceeded() {
        let RedisRateLimitResult {
            allowed, remaining, ..
        } = RedisRateLimitResult {
            allowed: false,
            current_count: 101,
            max_requests: 100,
            remaining: 0,
        };
        assert!(!allowed);
        assert_eq!(remaining, 0);
    }
}