allframe_core/resilience/rate_limit_redis.rs

//! Redis-backed rate limiting for distributed deployments.
//!
//! Provides sliding window rate limiting using Redis as a backend,
//! allowing rate limits to be shared across multiple instances.
//!
//! # Example
//!
//! ```rust,ignore
//! use allframe_core::resilience::RedisRateLimiter;
//!
//! // Connect to Redis
//! let limiter = RedisRateLimiter::new("redis://localhost:6379", 100, 60).await?;
//!
//! // Check rate limit for a key
//! if limiter.check("user:123").await.is_ok() {
//!     // Process request
//! }
//! ```

use std::time::Duration;

use redis::{aio::ConnectionManager, AsyncCommands, Client};

use super::RateLimitError;

/// Configuration for Redis rate limiter.
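///
/// # Example
///
/// A minimal construction sketch using the builder methods defined below
/// (the limit values are illustrative):
///
/// ```rust,ignore
/// // 50 requests per 30-second window, under a custom key prefix.
/// let config = RedisRateLimiterConfig::new(50, 30).with_prefix("myapp");
/// ```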
#[derive(Debug, Clone)]
pub struct RedisRateLimiterConfig {
    /// Maximum requests allowed in the window.
    pub max_requests: u32,
    /// Time window in seconds.
    pub window_seconds: u64,
    /// Key prefix for Redis keys.
    pub key_prefix: String,
}

impl Default for RedisRateLimiterConfig {
    fn default() -> Self {
        Self {
            max_requests: 100,
            window_seconds: 60,
            key_prefix: "ratelimit".to_string(),
        }
    }
}

impl RedisRateLimiterConfig {
    /// Create a new config with specified limits.
    pub fn new(max_requests: u32, window_seconds: u64) -> Self {
        Self {
            max_requests,
            window_seconds,
            key_prefix: "ratelimit".to_string(),
        }
    }

    /// Set a custom key prefix.
    pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.key_prefix = prefix.into();
        self
    }
}

/// Error type for Redis rate limiter operations.
#[derive(Debug)]
pub enum RedisRateLimiterError {
    /// Redis connection error.
    Connection(String),
    /// Redis operation error.
    Redis(String),
}

impl std::fmt::Display for RedisRateLimiterError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            RedisRateLimiterError::Connection(msg) => write!(f, "Redis connection error: {}", msg),
            RedisRateLimiterError::Redis(msg) => write!(f, "Redis error: {}", msg),
        }
    }
}

impl std::error::Error for RedisRateLimiterError {}

impl From<redis::RedisError> for RedisRateLimiterError {
    fn from(err: redis::RedisError) -> Self {
        RedisRateLimiterError::Redis(err.to_string())
    }
}

/// Redis-backed sliding window rate limiter.
///
/// Uses Redis sorted sets to implement a sliding window rate limiter
/// that works across distributed deployments.
///
/// ## Algorithm
///
/// Uses the sliding window log algorithm:
/// 1. Remove timestamps older than the window
/// 2. Count remaining timestamps
/// 3. If under limit, add current timestamp
/// 4. Return allow/deny
///
/// ## Features
///
/// - **Distributed**: Works across multiple instances
/// - **Sliding window**: More accurate than fixed windows
/// - **Auto-cleanup**: Old entries are automatically removed
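///
/// ## Example
///
/// A usage sketch (the Redis URL and limits are illustrative, and a reachable
/// Redis instance is assumed):
///
/// ```rust,ignore
/// let config = RedisRateLimiterConfig::new(100, 60).with_prefix("api");
/// let limiter = RedisRateLimiter::with_config("redis://localhost:6379", config).await?;
///
/// if limiter.check("user:123").await.is_ok() {
///     // Process the request
/// }
/// ```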
pub struct RedisRateLimiter {
    conn: ConnectionManager,
    config: RedisRateLimiterConfig,
}

impl RedisRateLimiter {
    /// Create a new Redis rate limiter.
    ///
    /// # Arguments
    /// * `redis_url` - Redis connection URL (e.g., "redis://localhost:6379")
    /// * `max_requests` - Maximum requests per window
    /// * `window_seconds` - Window duration in seconds
    pub async fn new(
        redis_url: &str,
        max_requests: u32,
        window_seconds: u64,
    ) -> Result<Self, RedisRateLimiterError> {
        Self::with_config(
            redis_url,
            RedisRateLimiterConfig::new(max_requests, window_seconds),
        )
        .await
    }

    /// Create a new Redis rate limiter with custom configuration.
    pub async fn with_config(
        redis_url: &str,
        config: RedisRateLimiterConfig,
    ) -> Result<Self, RedisRateLimiterError> {
        let client = Client::open(redis_url)
            .map_err(|e| RedisRateLimiterError::Connection(e.to_string()))?;

        let conn = ConnectionManager::new(client)
            .await
            .map_err(|e| RedisRateLimiterError::Connection(e.to_string()))?;

        Ok(Self { conn, config })
    }

    /// Create from an existing Redis connection manager.
    pub fn from_connection(conn: ConnectionManager, config: RedisRateLimiterConfig) -> Self {
        Self { conn, config }
    }

    /// Check if a request for the given key is allowed.
    ///
    /// Returns `Ok(remaining)` with the number of remaining requests if allowed,
    /// or `Err(RateLimitError)` if rate limited.
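    ///
    /// # Example
    ///
    /// A sketch of handling both outcomes; it assumes a Tokio runtime and that
    /// the `retry_after` field of `RateLimitError` is readable by callers:
    ///
    /// ```rust,ignore
    /// match limiter.check("user:123").await {
    ///     Ok(remaining) => {
    ///         // Allowed: `remaining` requests are left in the current window.
    ///     }
    ///     Err(err) => {
    ///         // Denied: back off for the suggested duration before retrying.
    ///         tokio::time::sleep(err.retry_after).await;
    ///     }
    /// }
    /// ```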
    pub async fn check(&self, key: &str) -> Result<u32, RateLimitError> {
        let redis_key = format!("{}:{}", self.config.key_prefix, key);
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_millis() as f64;

        let window_start = now - (self.config.window_seconds as f64 * 1000.0);

        let mut conn = self.conn.clone();

        // Lua script for atomic rate limiting
        // This ensures the check-and-increment is atomic
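        // The script returns the remaining request count (>= 0) when allowed,
        // or a negative value encoding the suggested retry delay in
        // milliseconds when denied.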
        let script = redis::Script::new(
            r#"
            local key = KEYS[1]
            local now = tonumber(ARGV[1])
            local window_start = tonumber(ARGV[2])
            local max_requests = tonumber(ARGV[3])
            local window_ms = tonumber(ARGV[4])

            -- Remove old entries
            redis.call('ZREMRANGEBYSCORE', key, '-inf', window_start)

            -- Count current entries
            local count = redis.call('ZCARD', key)

            if count < max_requests then
                -- Add new entry
                redis.call('ZADD', key, now, now)
                -- Set expiry
                redis.call('PEXPIRE', key, window_ms)
                return max_requests - count - 1
            else
                -- Get oldest entry to calculate retry time
                local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')
                if #oldest > 0 then
                    return -(oldest[2] + window_ms - now)
                end
                return -1
            end
            "#,
        );

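        // Invoke the script; a Redis failure is mapped to a short retry hint
        // (fail closed) rather than surfacing a separate error type.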
        let result: i64 = script
            .key(&redis_key)
            .arg(now)
            .arg(window_start)
            .arg(self.config.max_requests)
            .arg(self.config.window_seconds * 1000)
            .invoke_async(&mut conn)
            .await
            .map_err(|_| RateLimitError {
                retry_after: Duration::from_secs(1),
            })?;

        if result >= 0 {
            Ok(result as u32)
        } else {
            let retry_ms = (-result) as u64;
            Err(RateLimitError {
                retry_after: Duration::from_millis(retry_ms.max(1)),
            })
        }
    }

    /// Get the current count for a key without incrementing.
    pub async fn get_count(&self, key: &str) -> Result<u32, RedisRateLimiterError> {
        let redis_key = format!("{}:{}", self.config.key_prefix, key);
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_millis() as f64;

        let window_start = now - (self.config.window_seconds as f64 * 1000.0);

        let mut conn = self.conn.clone();

        // Remove old entries first
        let _: () = conn
            .zrembyscore(&redis_key, "-inf", window_start)
            .await?;

        // Count current entries
        let count: u32 = conn.zcard(&redis_key).await?;

        Ok(count)
    }

    /// Get remaining requests for a key.
    pub async fn get_remaining(&self, key: &str) -> Result<u32, RedisRateLimiterError> {
        let count = self.get_count(key).await?;
        Ok(self.config.max_requests.saturating_sub(count))
    }

    /// Reset the rate limit for a key.
    pub async fn reset(&self, key: &str) -> Result<(), RedisRateLimiterError> {
        let redis_key = format!("{}:{}", self.config.key_prefix, key);
        let mut conn = self.conn.clone();
        let _: () = conn.del(&redis_key).await?;
        Ok(())
    }

    /// Get the configuration.
    pub fn config(&self) -> &RedisRateLimiterConfig {
        &self.config
    }
}

/// Keyed Redis rate limiter with per-key configuration.
///
/// Allows different rate limits for different keys (e.g., different
/// limits for different API tiers).
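///
/// # Example
///
/// A sketch of per-tier limits (key names and limit values are illustrative):
///
/// ```rust,ignore
/// let mut limiter = KeyedRedisRateLimiter::new(
///     "redis://localhost:6379",
///     RedisRateLimiterConfig::new(100, 60),
/// )
/// .await?;
///
/// // Give a specific key a higher limit than the default.
/// limiter.set_config("tenant:premium", RedisRateLimiterConfig::new(1000, 60));
///
/// assert!(limiter.check("tenant:premium").await.is_ok());
/// ```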
pub struct KeyedRedisRateLimiter {
    conn: ConnectionManager,
    default_config: RedisRateLimiterConfig,
    /// Custom configurations for specific keys (matched exactly).
    custom_configs: std::collections::HashMap<String, RedisRateLimiterConfig>,
}

impl KeyedRedisRateLimiter {
    /// Create a new keyed Redis rate limiter.
    pub async fn new(
        redis_url: &str,
        default_config: RedisRateLimiterConfig,
    ) -> Result<Self, RedisRateLimiterError> {
        let client = Client::open(redis_url)
            .map_err(|e| RedisRateLimiterError::Connection(e.to_string()))?;

        let conn = ConnectionManager::new(client)
            .await
            .map_err(|e| RedisRateLimiterError::Connection(e.to_string()))?;

        Ok(Self {
            conn,
            default_config,
            custom_configs: std::collections::HashMap::new(),
        })
    }

    /// Set a custom configuration for a specific key.
    pub fn set_config(&mut self, key: impl Into<String>, config: RedisRateLimiterConfig) {
        self.custom_configs.insert(key.into(), config);
    }

    /// Remove custom configuration for a key.
    pub fn remove_config(&mut self, key: &str) {
        self.custom_configs.remove(key);
    }

    /// Check if a request for the given key is allowed.
    pub async fn check(&self, key: &str) -> Result<u32, RateLimitError> {
        let config = self.custom_configs.get(key).unwrap_or(&self.default_config);
        let limiter = RedisRateLimiter::from_connection(self.conn.clone(), config.clone());
        limiter.check(key).await
    }

    /// Get remaining requests for a key.
    pub async fn get_remaining(&self, key: &str) -> Result<u32, RedisRateLimiterError> {
        let config = self.custom_configs.get(key).unwrap_or(&self.default_config);
        let limiter = RedisRateLimiter::from_connection(self.conn.clone(), config.clone());
        limiter.get_remaining(key).await
    }

    /// Reset rate limit for a key.
    pub async fn reset(&self, key: &str) -> Result<(), RedisRateLimiterError> {
        let config = self.custom_configs.get(key).unwrap_or(&self.default_config);
        let limiter = RedisRateLimiter::from_connection(self.conn.clone(), config.clone());
        limiter.reset(key).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        let config = RedisRateLimiterConfig::default();
        assert_eq!(config.max_requests, 100);
        assert_eq!(config.window_seconds, 60);
        assert_eq!(config.key_prefix, "ratelimit");
    }

    #[test]
    fn test_config_builder() {
        let config = RedisRateLimiterConfig::new(50, 30).with_prefix("myapp");

        assert_eq!(config.max_requests, 50);
        assert_eq!(config.window_seconds, 30);
        assert_eq!(config.key_prefix, "myapp");
    }

    #[test]
    fn test_error_display() {
        let err = RedisRateLimiterError::Connection("timeout".to_string());
        assert!(err.to_string().contains("timeout"));

        let err = RedisRateLimiterError::Redis("command failed".to_string());
        assert!(err.to_string().contains("command failed"));
    }

    // Integration tests require a running Redis instance
    // Run with: cargo test --features resilience-redis -- --ignored

    #[tokio::test]
    #[ignore = "requires Redis"]
    async fn test_redis_rate_limiter_basic() {
        let limiter = RedisRateLimiter::new("redis://localhost:6379", 5, 10)
            .await
            .expect("Failed to connect to Redis");

        // Reset any previous state
        limiter.reset("test:basic").await.ok();

        // Should allow 5 requests
        for i in 0..5 {
            let result = limiter.check("test:basic").await;
            assert!(result.is_ok(), "Request {} should be allowed", i);
        }

        // 6th request should be denied
        let result = limiter.check("test:basic").await;
        assert!(result.is_err(), "6th request should be denied");
    }

    #[tokio::test]
    #[ignore = "requires Redis"]
    async fn test_redis_rate_limiter_remaining() {
        let limiter = RedisRateLimiter::new("redis://localhost:6379", 10, 60)
            .await
            .expect("Failed to connect to Redis");

        limiter.reset("test:remaining").await.ok();

        // Initially should have 10 remaining
        let remaining = limiter.get_remaining("test:remaining").await.unwrap();
        assert_eq!(remaining, 10);

        // After 3 requests, should have 7 remaining
        for _ in 0..3 {
            limiter.check("test:remaining").await.ok();
        }

        let remaining = limiter.get_remaining("test:remaining").await.unwrap();
        assert_eq!(remaining, 7);
    }

    #[tokio::test]
    #[ignore = "requires Redis"]
    async fn test_redis_rate_limiter_reset() {
        let limiter = RedisRateLimiter::new("redis://localhost:6379", 2, 60)
            .await
            .expect("Failed to connect to Redis");

        limiter.reset("test:reset").await.ok();

        // Use up the limit
        limiter.check("test:reset").await.ok();
        limiter.check("test:reset").await.ok();

        // Should be denied
        assert!(limiter.check("test:reset").await.is_err());

        // Reset
        limiter.reset("test:reset").await.unwrap();

        // Should work again
        assert!(limiter.check("test:reset").await.is_ok());
    }
}