use std::time::Duration;
use redis::{aio::ConnectionManager, AsyncCommands, Client};
use super::RateLimitError;
/// Configuration for a sliding-window rate limiter backed by Redis.
///
/// A key may be hit at most `max_requests` times within any rolling window of
/// `window_seconds` seconds. The limiter stores each key in Redis under
/// `"{key_prefix}:{key}"`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RedisRateLimiterConfig {
    /// Maximum number of requests admitted per window.
    pub max_requests: u32,
    /// Length of the sliding window, in seconds.
    pub window_seconds: u64,
    /// Namespace prepended (followed by `:`) to every limited key.
    pub key_prefix: String,
}

impl Default for RedisRateLimiterConfig {
    /// 100 requests per 60-second window under the `"ratelimit"` prefix.
    fn default() -> Self {
        Self::new(100, 60)
    }
}

impl RedisRateLimiterConfig {
    /// Key prefix used unless overridden with [`with_prefix`](Self::with_prefix).
    const DEFAULT_KEY_PREFIX: &'static str = "ratelimit";

    /// Creates a config allowing `max_requests` per `window_seconds`, using the
    /// default `"ratelimit"` key prefix.
    pub fn new(max_requests: u32, window_seconds: u64) -> Self {
        Self {
            max_requests,
            window_seconds,
            key_prefix: Self::DEFAULT_KEY_PREFIX.to_string(),
        }
    }

    /// Replaces the key prefix and returns the updated config (builder style).
    #[must_use]
    pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.key_prefix = prefix.into();
        self
    }
}
/// Errors raised while connecting to or issuing commands against Redis.
#[derive(Debug)]
pub enum RedisRateLimiterError {
    /// Opening the client or establishing the managed connection failed.
    Connection(String),
    /// A Redis command returned an error.
    Redis(String),
}

impl std::fmt::Display for RedisRateLimiterError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Both variants render as "<label>: <detail>"; only the label differs.
        let (label, detail) = match self {
            Self::Connection(detail) => ("Redis connection error", detail),
            Self::Redis(detail) => ("Redis error", detail),
        };
        write!(f, "{}: {}", label, detail)
    }
}

impl std::error::Error for RedisRateLimiterError {}
/// Lets `?` lift any `redis` crate error into [`RedisRateLimiterError::Redis`],
/// keeping only the error's rendered message.
impl From<redis::RedisError> for RedisRateLimiterError {
    fn from(source: redis::RedisError) -> Self {
        Self::Redis(source.to_string())
    }
}
/// Sliding-window rate limiter whose per-key state lives in a Redis sorted set.
pub struct RedisRateLimiter {
    /// Limits and key namespace applied to every operation.
    config: RedisRateLimiterConfig,
    /// Redis connection handle; a clone is taken for each operation.
    conn: ConnectionManager,
}
impl RedisRateLimiter {
    /// Lua script that performs the sliding-window check atomically in Redis.
    ///
    /// `KEYS[1]` is the sorted-set key. `ARGV` is: now (ms), window start (ms),
    /// max requests, window length (ms), and a member id unique to this request.
    /// Returns the remaining quota (`>= 0`) when the request is admitted, or the
    /// negated number of milliseconds until a slot frees up (always `<= -1`)
    /// when it is rejected.
    const CHECK_SCRIPT: &'static str = r#"
local key = KEYS[1]
local now = tonumber(ARGV[1])
local window_start = tonumber(ARGV[2])
local max_requests = tonumber(ARGV[3])
local window_ms = tonumber(ARGV[4])
local member = ARGV[5]
-- Drop entries that have slid out of the window
redis.call('ZREMRANGEBYSCORE', key, '-inf', window_start)
local count = redis.call('ZCARD', key)
if count < max_requests then
    -- The member must be unique per request: sorted-set members are unique, so
    -- reusing the timestamp would merge requests made in the same millisecond
    -- into one entry and under-count.
    redis.call('ZADD', key, now, member)
    redis.call('PEXPIRE', key, window_ms)
    return max_requests - count - 1
end
-- Rejected. A slot frees once enough of the oldest entries expire for the
-- count to drop below max_requests, i.e. when the entry at rank
-- (count - max_requests) expires; that is rank 0 when the set is exactly full.
local rank = count - max_requests
local blocker = redis.call('ZRANGE', key, rank, rank, 'WITHSCORES')
if #blocker > 0 then
    -- ceil keeps the reply strictly negative so it is never mistaken for an
    -- admitted request with 0 remaining.
    return -math.ceil(tonumber(blocker[2]) + window_ms - now)
end
return -1
"#;

    /// Connects to `redis_url` and builds a limiter allowing `max_requests`
    /// per `window_seconds` under the default key prefix.
    ///
    /// # Errors
    /// [`RedisRateLimiterError::Connection`] if the URL is invalid or the
    /// connection cannot be established.
    pub async fn new(
        redis_url: &str,
        max_requests: u32,
        window_seconds: u64,
    ) -> Result<Self, RedisRateLimiterError> {
        Self::with_config(
            redis_url,
            RedisRateLimiterConfig::new(max_requests, window_seconds),
        )
        .await
    }

    /// Connects to `redis_url` and builds a limiter using `config`.
    ///
    /// # Errors
    /// [`RedisRateLimiterError::Connection`] if the URL is invalid or the
    /// connection cannot be established.
    pub async fn with_config(
        redis_url: &str,
        config: RedisRateLimiterConfig,
    ) -> Result<Self, RedisRateLimiterError> {
        let client = Client::open(redis_url)
            .map_err(|e| RedisRateLimiterError::Connection(e.to_string()))?;
        let conn = ConnectionManager::new(client)
            .await
            .map_err(|e| RedisRateLimiterError::Connection(e.to_string()))?;
        Ok(Self { conn, config })
    }

    /// Builds a limiter on an existing connection; performs no I/O.
    pub fn from_connection(conn: ConnectionManager, config: RedisRateLimiterConfig) -> Self {
        Self { conn, config }
    }

    /// Namespaced Redis key for `key`: `"{key_prefix}:{key}"`.
    fn redis_key(&self, key: &str) -> String {
        format!("{}:{}", self.config.key_prefix, key)
    }

    /// Window length in milliseconds, saturating instead of overflowing.
    fn window_millis(&self) -> u64 {
        self.config.window_seconds.saturating_mul(1000)
    }

    /// Current wall-clock time in whole milliseconds since the UNIX epoch.
    ///
    /// # Panics
    /// If the system clock reports a time before 1970-01-01.
    fn now_millis() -> u128 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("system clock is set before the UNIX epoch")
            .as_millis()
    }

    /// Sorted-set member id for one request. It stays distinct even when many
    /// requests, possibly from different processes, land in the same millisecond.
    fn unique_member(now_ms: u128) -> String {
        use std::collections::hash_map::RandomState;
        use std::hash::{BuildHasher, Hasher};

        // A process-local sequence number is mixed through a randomly keyed
        // hasher, so independent processes do not emit the same suffixes.
        static SEQUENCE: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
        let mut hasher = RandomState::new().build_hasher();
        hasher.write_u64(SEQUENCE.fetch_add(1, std::sync::atomic::Ordering::Relaxed));
        format!("{}-{:016x}", now_ms, hasher.finish())
    }

    /// Records one request against `key` if it fits in the current window.
    ///
    /// Returns `Ok(remaining)`: how many more requests the window admits after
    /// this one.
    ///
    /// # Errors
    /// Returns a [`RateLimitError`] whose `retry_after` is the time until a slot
    /// frees up when the limit is reached. Any Redis failure is also reported as
    /// a `RateLimitError`, with a fixed one-second `retry_after`, so the limiter
    /// fails closed.
    pub async fn check(&self, key: &str) -> Result<u32, RateLimitError> {
        let redis_key = self.redis_key(key);
        let now_ms = Self::now_millis();
        let window_ms = self.window_millis();
        let now = now_ms as f64;
        let window_start = now - window_ms as f64;
        let mut conn = self.conn.clone();

        let script = redis::Script::new(Self::CHECK_SCRIPT);
        let result: i64 = script
            .key(&redis_key)
            .arg(now)
            .arg(window_start)
            .arg(self.config.max_requests)
            .arg(window_ms)
            .arg(Self::unique_member(now_ms))
            .invoke_async(&mut conn)
            .await
            .map_err(|_| RateLimitError {
                retry_after: Duration::from_secs(1),
            })?;

        if result >= 0 {
            // The script returns at most max_requests - 1 here, so it fits in u32.
            Ok(result as u32)
        } else {
            Err(RateLimitError {
                retry_after: Duration::from_millis(result.unsigned_abs().max(1)),
            })
        }
    }

    /// Returns how many requests for `key` fall inside the current window.
    ///
    /// This is read-only and takes one round-trip. It counts entries scored
    /// strictly after the window start, which are the same entries `check` keeps
    /// after trimming.
    ///
    /// # Errors
    /// [`RedisRateLimiterError::Redis`] if the Redis command fails.
    pub async fn get_count(&self, key: &str) -> Result<u32, RedisRateLimiterError> {
        let redis_key = self.redis_key(key);
        let window_start = Self::now_millis() as f64 - self.window_millis() as f64;
        let mut conn = self.conn.clone();
        // A "(" prefix makes the lower bound exclusive. This mirrors the
        // inclusive '-inf'..window_start trim performed by `check`.
        let count: u32 = conn
            .zcount(&redis_key, format!("({}", window_start), "+inf")
            .await?;
        Ok(count)
    }

    /// Returns how many more requests `key` may make in the current window
    /// (never negative).
    ///
    /// # Errors
    /// [`RedisRateLimiterError::Redis`] if the Redis command fails.
    pub async fn get_remaining(&self, key: &str) -> Result<u32, RedisRateLimiterError> {
        let count = self.get_count(key).await?;
        Ok(self.config.max_requests.saturating_sub(count))
    }

    /// Deletes all recorded requests for `key`.
    ///
    /// # Errors
    /// [`RedisRateLimiterError::Redis`] if the Redis command fails.
    pub async fn reset(&self, key: &str) -> Result<(), RedisRateLimiterError> {
        let redis_key = self.redis_key(key);
        let mut conn = self.conn.clone();
        let _: () = conn.del(&redis_key).await?;
        Ok(())
    }

    /// The configuration this limiter applies.
    pub fn config(&self) -> &RedisRateLimiterConfig {
        &self.config
    }
}
/// Redis rate limiter that can apply a different [`RedisRateLimiterConfig`] per
/// key, falling back to a default for keys without an override.
pub struct KeyedRedisRateLimiter {
    /// Config used for any key that has no entry in `custom_configs`.
    default_config: RedisRateLimiterConfig,
    /// Per-key overrides, looked up by the exact key string.
    custom_configs: std::collections::HashMap<String, RedisRateLimiterConfig>,
    /// Shared Redis connection; cloned into each per-call limiter.
    conn: ConnectionManager,
}
impl KeyedRedisRateLimiter {
pub async fn new(
redis_url: &str,
default_config: RedisRateLimiterConfig,
) -> Result<Self, RedisRateLimiterError> {
let client = Client::open(redis_url)
.map_err(|e| RedisRateLimiterError::Connection(e.to_string()))?;
let conn = ConnectionManager::new(client)
.await
.map_err(|e| RedisRateLimiterError::Connection(e.to_string()))?;
Ok(Self {
conn,
default_config,
custom_configs: std::collections::HashMap::new(),
})
}
pub fn set_config(&mut self, key: impl Into<String>, config: RedisRateLimiterConfig) {
self.custom_configs.insert(key.into(), config);
}
pub fn remove_config(&mut self, key: &str) {
self.custom_configs.remove(key);
}
pub async fn check(&self, key: &str) -> Result<u32, RateLimitError> {
let config = self.custom_configs.get(key).unwrap_or(&self.default_config);
let limiter = RedisRateLimiter::from_connection(self.conn.clone(), config.clone());
limiter.check(key).await
}
pub async fn get_remaining(&self, key: &str) -> Result<u32, RedisRateLimiterError> {
let config = self.custom_configs.get(key).unwrap_or(&self.default_config);
let limiter = RedisRateLimiter::from_connection(self.conn.clone(), config.clone());
limiter.get_remaining(key).await
}
pub async fn reset(&self, key: &str) -> Result<(), RedisRateLimiterError> {
let config = self.custom_configs.get(key).unwrap_or(&self.default_config);
let limiter = RedisRateLimiter::from_connection(self.conn.clone(), config.clone());
limiter.reset(key).await
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Local Redis instance used by the `#[ignore]`d integration tests.
    const REDIS_URL: &str = "redis://localhost:6379";

    /// Connects a limiter to the local test instance, panicking on failure.
    async fn connect(max_requests: u32, window_seconds: u64) -> RedisRateLimiter {
        RedisRateLimiter::new(REDIS_URL, max_requests, window_seconds)
            .await
            .expect("Failed to connect to Redis")
    }

    #[test]
    fn test_config_default() {
        let RedisRateLimiterConfig {
            max_requests,
            window_seconds,
            key_prefix,
        } = RedisRateLimiterConfig::default();
        assert_eq!(max_requests, 100);
        assert_eq!(window_seconds, 60);
        assert_eq!(key_prefix, "ratelimit");
    }

    #[test]
    fn test_config_builder() {
        let cfg = RedisRateLimiterConfig::new(50, 30).with_prefix("myapp");
        assert_eq!(
            (cfg.max_requests, cfg.window_seconds, cfg.key_prefix.as_str()),
            (50, 30, "myapp")
        );
    }

    #[test]
    fn test_error_display() {
        let cases = [
            (RedisRateLimiterError::Connection("timeout".to_string()), "timeout"),
            (
                RedisRateLimiterError::Redis("command failed".to_string()),
                "command failed",
            ),
        ];
        for (err, needle) in cases.iter() {
            assert!(err.to_string().contains(*needle));
        }
    }

    #[tokio::test]
    #[ignore = "requires Redis"]
    async fn test_redis_rate_limiter_basic() {
        const KEY: &str = "test:basic";
        let limiter = connect(5, 10).await;
        let _ = limiter.reset(KEY).await;
        for attempt in 0..5 {
            let outcome = limiter.check(KEY).await;
            assert!(outcome.is_ok(), "Request {} should be allowed", attempt);
        }
        let outcome = limiter.check(KEY).await;
        assert!(outcome.is_err(), "6th request should be denied");
    }

    #[tokio::test]
    #[ignore = "requires Redis"]
    async fn test_redis_rate_limiter_remaining() {
        const KEY: &str = "test:remaining";
        let limiter = connect(10, 60).await;
        let _ = limiter.reset(KEY).await;
        assert_eq!(limiter.get_remaining(KEY).await.unwrap(), 10);
        for _ in 0..3 {
            let _ = limiter.check(KEY).await;
        }
        assert_eq!(limiter.get_remaining(KEY).await.unwrap(), 7);
    }

    #[tokio::test]
    #[ignore = "requires Redis"]
    async fn test_redis_rate_limiter_reset() {
        const KEY: &str = "test:reset";
        let limiter = connect(2, 60).await;
        let _ = limiter.reset(KEY).await;
        for _ in 0..2 {
            let _ = limiter.check(KEY).await;
        }
        assert!(limiter.check(KEY).await.is_err());
        limiter.reset(KEY).await.unwrap();
        assert!(limiter.check(KEY).await.is_ok());
    }
}