// allframe_core/resilience/rate_limit_redis.rs
1//! Redis-backed rate limiting for distributed deployments.
2//!
3//! Provides sliding window rate limiting using Redis as a backend,
4//! allowing rate limits to be shared across multiple instances.
5//!
6//! # Example
7//!
8//! ```rust,ignore
9//! use allframe_core::resilience::RedisRateLimiter;
10//!
11//! // Connect to Redis
12//! let limiter = RedisRateLimiter::new("redis://localhost:6379", 100, 60).await?;
13//!
14//! // Check rate limit for a key
15//! if limiter.check("user:123").await.is_ok() {
16//!     // Process request
17//! }
18//! ```
19
20use std::time::Duration;
21
22use redis::{aio::ConnectionManager, AsyncCommands, Client};
23
24use super::RateLimitError;
25
/// Configuration for Redis rate limiter.
#[derive(Debug, Clone)]
pub struct RedisRateLimiterConfig {
    /// Maximum requests allowed in the window.
    pub max_requests: u32,
    /// Time window in seconds.
    pub window_seconds: u64,
    /// Key prefix for Redis keys; final keys are `"{key_prefix}:{key}"`.
    pub key_prefix: String,
}
36
37impl Default for RedisRateLimiterConfig {
38    fn default() -> Self {
39        Self {
40            max_requests: 100,
41            window_seconds: 60,
42            key_prefix: "ratelimit".to_string(),
43        }
44    }
45}
46
47impl RedisRateLimiterConfig {
48    /// Create a new config with specified limits.
49    pub fn new(max_requests: u32, window_seconds: u64) -> Self {
50        Self {
51            max_requests,
52            window_seconds,
53            key_prefix: "ratelimit".to_string(),
54        }
55    }
56
57    /// Set a custom key prefix.
58    pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
59        self.key_prefix = prefix.into();
60        self
61    }
62}
63
/// Error type for Redis rate limiter operations.
#[derive(Debug)]
pub enum RedisRateLimiterError {
    /// Failed to establish or maintain the Redis connection.
    Connection(String),
    /// A Redis command or script invocation failed.
    Redis(String),
}
72
73impl std::fmt::Display for RedisRateLimiterError {
74    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75        match self {
76            RedisRateLimiterError::Connection(msg) => write!(f, "Redis connection error: {}", msg),
77            RedisRateLimiterError::Redis(msg) => write!(f, "Redis error: {}", msg),
78        }
79    }
80}
81
82impl std::error::Error for RedisRateLimiterError {}
83
84impl From<redis::RedisError> for RedisRateLimiterError {
85    fn from(err: redis::RedisError) -> Self {
86        RedisRateLimiterError::Redis(err.to_string())
87    }
88}
89
/// Redis-backed sliding window rate limiter.
///
/// Uses Redis sorted sets to implement a sliding window rate limiter
/// that works across distributed deployments.
///
/// ## Algorithm
///
/// Uses the sliding window log algorithm:
/// 1. Remove timestamps older than the window
/// 2. Count remaining timestamps
/// 3. If under limit, add current timestamp
/// 4. Return allow/deny
///
/// ## Features
///
/// - **Distributed**: Works across multiple instances
/// - **Sliding window**: More accurate than fixed windows
/// - **Auto-cleanup**: Old entries are automatically removed
pub struct RedisRateLimiter {
    // Multiplexed connection; methods clone it per call (the clone shares
    // the underlying connection), which is why `check` takes `&self`.
    conn: ConnectionManager,
    // Limits and key prefix applied to every check.
    config: RedisRateLimiterConfig,
}
112
113impl RedisRateLimiter {
114    /// Create a new Redis rate limiter.
115    ///
116    /// # Arguments
117    /// * `redis_url` - Redis connection URL (e.g., "redis://localhost:6379")
118    /// * `max_requests` - Maximum requests per window
119    /// * `window_seconds` - Window duration in seconds
120    pub async fn new(
121        redis_url: &str,
122        max_requests: u32,
123        window_seconds: u64,
124    ) -> Result<Self, RedisRateLimiterError> {
125        Self::with_config(
126            redis_url,
127            RedisRateLimiterConfig::new(max_requests, window_seconds),
128        )
129        .await
130    }
131
132    /// Create a new Redis rate limiter with custom configuration.
133    pub async fn with_config(
134        redis_url: &str,
135        config: RedisRateLimiterConfig,
136    ) -> Result<Self, RedisRateLimiterError> {
137        let client = Client::open(redis_url)
138            .map_err(|e| RedisRateLimiterError::Connection(e.to_string()))?;
139
140        let conn = ConnectionManager::new(client)
141            .await
142            .map_err(|e| RedisRateLimiterError::Connection(e.to_string()))?;
143
144        Ok(Self { conn, config })
145    }
146
147    /// Create from an existing Redis connection manager.
148    pub fn from_connection(conn: ConnectionManager, config: RedisRateLimiterConfig) -> Self {
149        Self { conn, config }
150    }
151
152    /// Check if a request for the given key is allowed.
153    ///
154    /// Returns `Ok(remaining)` with the number of remaining requests if
155    /// allowed, or `Err(RateLimitError)` if rate limited.
156    pub async fn check(&self, key: &str) -> Result<u32, RateLimitError> {
157        let redis_key = format!("{}:{}", self.config.key_prefix, key);
158        let now = std::time::SystemTime::now()
159            .duration_since(std::time::UNIX_EPOCH)
160            .unwrap()
161            .as_millis() as f64;
162
163        let window_start = now - (self.config.window_seconds as f64 * 1000.0);
164
165        let mut conn = self.conn.clone();
166
167        // Lua script for atomic rate limiting
168        // This ensures the check-and-increment is atomic
169        let script = redis::Script::new(
170            r#"
171            local key = KEYS[1]
172            local now = tonumber(ARGV[1])
173            local window_start = tonumber(ARGV[2])
174            local max_requests = tonumber(ARGV[3])
175            local window_ms = tonumber(ARGV[4])
176
177            -- Remove old entries
178            redis.call('ZREMRANGEBYSCORE', key, '-inf', window_start)
179
180            -- Count current entries
181            local count = redis.call('ZCARD', key)
182
183            if count < max_requests then
184                -- Add new entry
185                redis.call('ZADD', key, now, now)
186                -- Set expiry
187                redis.call('PEXPIRE', key, window_ms)
188                return max_requests - count - 1
189            else
190                -- Get oldest entry to calculate retry time
191                local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')
192                if #oldest > 0 then
193                    return -(oldest[2] + window_ms - now)
194                end
195                return -1
196            end
197            "#,
198        );
199
200        let result: i64 = script
201            .key(&redis_key)
202            .arg(now)
203            .arg(window_start)
204            .arg(self.config.max_requests)
205            .arg(self.config.window_seconds * 1000)
206            .invoke_async(&mut conn)
207            .await
208            .map_err(|_| RateLimitError {
209                retry_after: Duration::from_secs(1),
210            })?;
211
212        if result >= 0 {
213            Ok(result as u32)
214        } else {
215            let retry_ms = (-result) as u64;
216            Err(RateLimitError {
217                retry_after: Duration::from_millis(retry_ms.max(1)),
218            })
219        }
220    }
221
222    /// Get the current count for a key without incrementing.
223    pub async fn get_count(&self, key: &str) -> Result<u32, RedisRateLimiterError> {
224        let redis_key = format!("{}:{}", self.config.key_prefix, key);
225        let now = std::time::SystemTime::now()
226            .duration_since(std::time::UNIX_EPOCH)
227            .unwrap()
228            .as_millis() as f64;
229
230        let window_start = now - (self.config.window_seconds as f64 * 1000.0);
231
232        let mut conn = self.conn.clone();
233
234        // Remove old entries first
235        let _: () = conn.zrembyscore(&redis_key, "-inf", window_start).await?;
236
237        // Count current entries
238        let count: u32 = conn.zcard(&redis_key).await?;
239
240        Ok(count)
241    }
242
243    /// Get remaining requests for a key.
244    pub async fn get_remaining(&self, key: &str) -> Result<u32, RedisRateLimiterError> {
245        let count = self.get_count(key).await?;
246        Ok(self.config.max_requests.saturating_sub(count))
247    }
248
249    /// Reset the rate limit for a key.
250    pub async fn reset(&self, key: &str) -> Result<(), RedisRateLimiterError> {
251        let redis_key = format!("{}:{}", self.config.key_prefix, key);
252        let mut conn = self.conn.clone();
253        let _: () = conn.del(&redis_key).await?;
254        Ok(())
255    }
256
257    /// Get the configuration.
258    pub fn config(&self) -> &RedisRateLimiterConfig {
259        &self.config
260    }
261}
262
/// Keyed Redis rate limiter with per-key configuration.
///
/// Allows different rate limits for different keys (e.g., different
/// limits for different API tiers).
pub struct KeyedRedisRateLimiter {
    // Shared multiplexed Redis connection (cheap to clone per check).
    conn: ConnectionManager,
    // Config used for any key without a custom entry.
    default_config: RedisRateLimiterConfig,
    /// Custom configs per key pattern
    // NOTE(review): lookups below use exact key equality, not pattern
    // matching — "pattern" here may overstate the behavior; confirm intent.
    custom_configs: std::collections::HashMap<String, RedisRateLimiterConfig>,
}
273
274impl KeyedRedisRateLimiter {
275    /// Create a new keyed Redis rate limiter.
276    pub async fn new(
277        redis_url: &str,
278        default_config: RedisRateLimiterConfig,
279    ) -> Result<Self, RedisRateLimiterError> {
280        let client = Client::open(redis_url)
281            .map_err(|e| RedisRateLimiterError::Connection(e.to_string()))?;
282
283        let conn = ConnectionManager::new(client)
284            .await
285            .map_err(|e| RedisRateLimiterError::Connection(e.to_string()))?;
286
287        Ok(Self {
288            conn,
289            default_config,
290            custom_configs: std::collections::HashMap::new(),
291        })
292    }
293
294    /// Set a custom configuration for a specific key.
295    pub fn set_config(&mut self, key: impl Into<String>, config: RedisRateLimiterConfig) {
296        self.custom_configs.insert(key.into(), config);
297    }
298
299    /// Remove custom configuration for a key.
300    pub fn remove_config(&mut self, key: &str) {
301        self.custom_configs.remove(key);
302    }
303
304    /// Check if a request for the given key is allowed.
305    pub async fn check(&self, key: &str) -> Result<u32, RateLimitError> {
306        let config = self.custom_configs.get(key).unwrap_or(&self.default_config);
307        let limiter = RedisRateLimiter::from_connection(self.conn.clone(), config.clone());
308        limiter.check(key).await
309    }
310
311    /// Get remaining requests for a key.
312    pub async fn get_remaining(&self, key: &str) -> Result<u32, RedisRateLimiterError> {
313        let config = self.custom_configs.get(key).unwrap_or(&self.default_config);
314        let limiter = RedisRateLimiter::from_connection(self.conn.clone(), config.clone());
315        limiter.get_remaining(key).await
316    }
317
318    /// Reset rate limit for a key.
319    pub async fn reset(&self, key: &str) -> Result<(), RedisRateLimiterError> {
320        let config = self.custom_configs.get(key).unwrap_or(&self.default_config);
321        let limiter = RedisRateLimiter::from_connection(self.conn.clone(), config.clone());
322        limiter.reset(key).await
323    }
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        let config = RedisRateLimiterConfig::default();
        assert_eq!(config.max_requests, 100);
        assert_eq!(config.window_seconds, 60);
        assert_eq!(config.key_prefix, "ratelimit");
    }

    #[test]
    fn test_config_builder() {
        let config = RedisRateLimiterConfig::new(50, 30).with_prefix("myapp");

        assert_eq!(config.max_requests, 50);
        assert_eq!(config.window_seconds, 30);
        assert_eq!(config.key_prefix, "myapp");
    }

    #[test]
    fn test_error_display() {
        let connection = RedisRateLimiterError::Connection("timeout".to_string());
        assert!(connection.to_string().contains("timeout"));

        let command = RedisRateLimiterError::Redis("command failed".to_string());
        assert!(command.to_string().contains("command failed"));
    }

    // Integration tests require a running Redis instance
    // Run with: cargo test --features resilience-redis -- --ignored

    #[tokio::test]
    #[ignore = "requires Redis"]
    async fn test_redis_rate_limiter_basic() {
        let limiter = RedisRateLimiter::new("redis://localhost:6379", 5, 10)
            .await
            .expect("Failed to connect to Redis");

        // Start from a clean slate
        limiter.reset("test:basic").await.ok();

        // The first 5 requests fit inside the limit
        for i in 0..5 {
            assert!(
                limiter.check("test:basic").await.is_ok(),
                "Request {} should be allowed",
                i
            );
        }

        // One more pushes past the limit
        let result = limiter.check("test:basic").await;
        assert!(result.is_err(), "6th request should be denied");
    }

    #[tokio::test]
    #[ignore = "requires Redis"]
    async fn test_redis_rate_limiter_remaining() {
        let limiter = RedisRateLimiter::new("redis://localhost:6379", 10, 60)
            .await
            .expect("Failed to connect to Redis");

        limiter.reset("test:remaining").await.ok();

        // Full quota before any requests
        assert_eq!(limiter.get_remaining("test:remaining").await.unwrap(), 10);

        // Consume three slots
        for _ in 0..3 {
            limiter.check("test:remaining").await.ok();
        }

        // Quota drops accordingly
        assert_eq!(limiter.get_remaining("test:remaining").await.unwrap(), 7);
    }

    #[tokio::test]
    #[ignore = "requires Redis"]
    async fn test_redis_rate_limiter_reset() {
        let limiter = RedisRateLimiter::new("redis://localhost:6379", 2, 60)
            .await
            .expect("Failed to connect to Redis");

        limiter.reset("test:reset").await.ok();

        // Exhaust the 2-request budget
        for _ in 0..2 {
            limiter.check("test:reset").await.ok();
        }

        // Next request must be rejected
        assert!(limiter.check("test:reset").await.is_err());

        // After an explicit reset the budget is restored
        limiter.reset("test:reset").await.unwrap();
        assert!(limiter.check("test:reset").await.is_ok());
    }
}