allframe_core/resilience/
rate_limit_redis.rs1use std::time::Duration;
21
22use redis::{aio::ConnectionManager, AsyncCommands, Client};
23
24use super::RateLimitError;
25
/// Settings for a sliding-window rate limiter backed by Redis.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RedisRateLimiterConfig {
    /// Maximum number of requests admitted per key within one window.
    pub max_requests: u32,
    /// Length of the sliding window, in seconds.
    pub window_seconds: u64,
    /// Prefix for every Redis key; limiter keys are formed as `"{prefix}:{key}"`.
    pub key_prefix: String,
}

impl Default for RedisRateLimiterConfig {
    /// 100 requests per 60-second window under the `"ratelimit"` prefix.
    fn default() -> Self {
        Self::new(100, 60)
    }
}

impl RedisRateLimiterConfig {
    /// Key prefix used by [`RedisRateLimiterConfig::new`] and by [`Default`].
    const DEFAULT_KEY_PREFIX: &'static str = "ratelimit";

    /// Creates a configuration admitting `max_requests` per `window_seconds`,
    /// using the default `"ratelimit"` key prefix.
    pub fn new(max_requests: u32, window_seconds: u64) -> Self {
        Self {
            max_requests,
            window_seconds,
            key_prefix: Self::DEFAULT_KEY_PREFIX.to_string(),
        }
    }

    /// Returns this configuration with its key prefix replaced by `prefix`.
    #[must_use]
    pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.key_prefix = prefix.into();
        self
    }
}

/// Errors from Redis-facing operations of the rate limiters (connecting,
/// inspecting or resetting counters).
#[derive(Debug)]
pub enum RedisRateLimiterError {
    /// Opening the Redis client or establishing the managed connection failed.
    Connection(String),
    /// Any other error reported by the `redis` crate (e.g. a failed command).
    Redis(String),
}

impl std::fmt::Display for RedisRateLimiterError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Connection(detail) => write!(f, "Redis connection error: {}", detail),
            Self::Redis(detail) => write!(f, "Redis error: {}", detail),
        }
    }
}

impl std::error::Error for RedisRateLimiterError {}

84impl From<redis::RedisError> for RedisRateLimiterError {
85 fn from(err: redis::RedisError) -> Self {
86 RedisRateLimiterError::Redis(err.to_string())
87 }
88}
89
90pub struct RedisRateLimiter {
109 conn: ConnectionManager,
110 config: RedisRateLimiterConfig,
111}
112
113impl RedisRateLimiter {
114 pub async fn new(
121 redis_url: &str,
122 max_requests: u32,
123 window_seconds: u64,
124 ) -> Result<Self, RedisRateLimiterError> {
125 Self::with_config(
126 redis_url,
127 RedisRateLimiterConfig::new(max_requests, window_seconds),
128 )
129 .await
130 }
131
132 pub async fn with_config(
134 redis_url: &str,
135 config: RedisRateLimiterConfig,
136 ) -> Result<Self, RedisRateLimiterError> {
137 let client = Client::open(redis_url)
138 .map_err(|e| RedisRateLimiterError::Connection(e.to_string()))?;
139
140 let conn = ConnectionManager::new(client)
141 .await
142 .map_err(|e| RedisRateLimiterError::Connection(e.to_string()))?;
143
144 Ok(Self { conn, config })
145 }
146
147 pub fn from_connection(conn: ConnectionManager, config: RedisRateLimiterConfig) -> Self {
149 Self { conn, config }
150 }
151
152 pub async fn check(&self, key: &str) -> Result<u32, RateLimitError> {
157 let redis_key = format!("{}:{}", self.config.key_prefix, key);
158 let now = std::time::SystemTime::now()
159 .duration_since(std::time::UNIX_EPOCH)
160 .unwrap()
161 .as_millis() as f64;
162
163 let window_start = now - (self.config.window_seconds as f64 * 1000.0);
164
165 let mut conn = self.conn.clone();
166
167 let script = redis::Script::new(
170 r#"
171 local key = KEYS[1]
172 local now = tonumber(ARGV[1])
173 local window_start = tonumber(ARGV[2])
174 local max_requests = tonumber(ARGV[3])
175 local window_ms = tonumber(ARGV[4])
176
177 -- Remove old entries
178 redis.call('ZREMRANGEBYSCORE', key, '-inf', window_start)
179
180 -- Count current entries
181 local count = redis.call('ZCARD', key)
182
183 if count < max_requests then
184 -- Add new entry
185 redis.call('ZADD', key, now, now)
186 -- Set expiry
187 redis.call('PEXPIRE', key, window_ms)
188 return max_requests - count - 1
189 else
190 -- Get oldest entry to calculate retry time
191 local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')
192 if #oldest > 0 then
193 return -(oldest[2] + window_ms - now)
194 end
195 return -1
196 end
197 "#,
198 );
199
200 let result: i64 = script
201 .key(&redis_key)
202 .arg(now)
203 .arg(window_start)
204 .arg(self.config.max_requests)
205 .arg(self.config.window_seconds * 1000)
206 .invoke_async(&mut conn)
207 .await
208 .map_err(|_| RateLimitError {
209 retry_after: Duration::from_secs(1),
210 })?;
211
212 if result >= 0 {
213 Ok(result as u32)
214 } else {
215 let retry_ms = (-result) as u64;
216 Err(RateLimitError {
217 retry_after: Duration::from_millis(retry_ms.max(1)),
218 })
219 }
220 }
221
222 pub async fn get_count(&self, key: &str) -> Result<u32, RedisRateLimiterError> {
224 let redis_key = format!("{}:{}", self.config.key_prefix, key);
225 let now = std::time::SystemTime::now()
226 .duration_since(std::time::UNIX_EPOCH)
227 .unwrap()
228 .as_millis() as f64;
229
230 let window_start = now - (self.config.window_seconds as f64 * 1000.0);
231
232 let mut conn = self.conn.clone();
233
234 let _: () = conn
236 .zrembyscore(&redis_key, "-inf", window_start)
237 .await?;
238
239 let count: u32 = conn.zcard(&redis_key).await?;
241
242 Ok(count)
243 }
244
245 pub async fn get_remaining(&self, key: &str) -> Result<u32, RedisRateLimiterError> {
247 let count = self.get_count(key).await?;
248 Ok(self.config.max_requests.saturating_sub(count))
249 }
250
251 pub async fn reset(&self, key: &str) -> Result<(), RedisRateLimiterError> {
253 let redis_key = format!("{}:{}", self.config.key_prefix, key);
254 let mut conn = self.conn.clone();
255 let _: () = conn.del(&redis_key).await?;
256 Ok(())
257 }
258
259 pub fn config(&self) -> &RedisRateLimiterConfig {
261 &self.config
262 }
263}
264
265pub struct KeyedRedisRateLimiter {
270 conn: ConnectionManager,
271 default_config: RedisRateLimiterConfig,
272 custom_configs: std::collections::HashMap<String, RedisRateLimiterConfig>,
274}
275
276impl KeyedRedisRateLimiter {
277 pub async fn new(
279 redis_url: &str,
280 default_config: RedisRateLimiterConfig,
281 ) -> Result<Self, RedisRateLimiterError> {
282 let client = Client::open(redis_url)
283 .map_err(|e| RedisRateLimiterError::Connection(e.to_string()))?;
284
285 let conn = ConnectionManager::new(client)
286 .await
287 .map_err(|e| RedisRateLimiterError::Connection(e.to_string()))?;
288
289 Ok(Self {
290 conn,
291 default_config,
292 custom_configs: std::collections::HashMap::new(),
293 })
294 }
295
296 pub fn set_config(&mut self, key: impl Into<String>, config: RedisRateLimiterConfig) {
298 self.custom_configs.insert(key.into(), config);
299 }
300
301 pub fn remove_config(&mut self, key: &str) {
303 self.custom_configs.remove(key);
304 }
305
306 pub async fn check(&self, key: &str) -> Result<u32, RateLimitError> {
308 let config = self.custom_configs.get(key).unwrap_or(&self.default_config);
309 let limiter = RedisRateLimiter::from_connection(self.conn.clone(), config.clone());
310 limiter.check(key).await
311 }
312
313 pub async fn get_remaining(&self, key: &str) -> Result<u32, RedisRateLimiterError> {
315 let config = self.custom_configs.get(key).unwrap_or(&self.default_config);
316 let limiter = RedisRateLimiter::from_connection(self.conn.clone(), config.clone());
317 limiter.get_remaining(key).await
318 }
319
320 pub async fn reset(&self, key: &str) -> Result<(), RedisRateLimiterError> {
322 let config = self.custom_configs.get(key).unwrap_or(&self.default_config);
323 let limiter = RedisRateLimiter::from_connection(self.conn.clone(), config.clone());
324 limiter.reset(key).await
325 }
326}
327
#[cfg(test)]
mod tests {
    use super::*;

    /// Local Redis used by the `#[ignore]`d integration tests.
    const LOCAL_REDIS: &str = "redis://localhost:6379";

    /// Connects a limiter to the local Redis, panicking if it is unreachable.
    async fn local_limiter(max_requests: u32, window_seconds: u64) -> RedisRateLimiter {
        RedisRateLimiter::new(LOCAL_REDIS, max_requests, window_seconds)
            .await
            .expect("Failed to connect to Redis")
    }

    #[test]
    fn test_config_default() {
        let defaults = RedisRateLimiterConfig::default();
        assert_eq!(defaults.max_requests, 100);
        assert_eq!(defaults.window_seconds, 60);
        assert_eq!(defaults.key_prefix, "ratelimit");
    }

    #[test]
    fn test_config_builder() {
        let custom = RedisRateLimiterConfig::new(50, 30).with_prefix("myapp");

        assert_eq!(custom.max_requests, 50);
        assert_eq!(custom.window_seconds, 30);
        assert_eq!(custom.key_prefix, "myapp");
    }

    #[test]
    fn test_error_display() {
        let cases = [
            (RedisRateLimiterError::Connection("timeout".to_string()), "timeout"),
            (RedisRateLimiterError::Redis("command failed".to_string()), "command failed"),
        ];
        for (err, needle) in cases {
            assert!(err.to_string().contains(needle));
        }
    }

    #[tokio::test]
    #[ignore = "requires Redis"]
    async fn test_redis_rate_limiter_basic() {
        let limiter = local_limiter(5, 10).await;
        let key = "test:basic";
        let _ = limiter.reset(key).await;

        // The first five requests fit within the limit.
        for i in 0..5 {
            let outcome = limiter.check(key).await;
            assert!(outcome.is_ok(), "Request {} should be allowed", i);
        }

        // The sixth exceeds it.
        let outcome = limiter.check(key).await;
        assert!(outcome.is_err(), "6th request should be denied");
    }

    #[tokio::test]
    #[ignore = "requires Redis"]
    async fn test_redis_rate_limiter_remaining() {
        let limiter = local_limiter(10, 60).await;
        let key = "test:remaining";
        let _ = limiter.reset(key).await;

        assert_eq!(limiter.get_remaining(key).await.unwrap(), 10);

        for _ in 0..3 {
            let _ = limiter.check(key).await;
        }

        assert_eq!(limiter.get_remaining(key).await.unwrap(), 7);
    }

    #[tokio::test]
    #[ignore = "requires Redis"]
    async fn test_redis_rate_limiter_reset() {
        let limiter = local_limiter(2, 60).await;
        let key = "test:reset";
        let _ = limiter.reset(key).await;

        // Exhaust the quota of two.
        for _ in 0..2 {
            let _ = limiter.check(key).await;
        }
        assert!(limiter.check(key).await.is_err());

        // After a reset the key is admitted again.
        limiter.reset(key).await.unwrap();
        assert!(limiter.check(key).await.is_ok());
    }
}