// allframe_core/resilience/rate_limit_redis.rs
use std::collections::hash_map::RandomState;
use std::hash::{BuildHasher, Hasher};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;

use redis::{aio::ConnectionManager, AsyncCommands, Client};

use super::RateLimitError;

/// Settings for a Redis-backed sliding-window rate limiter.
#[derive(Clone, Debug)]
pub struct RedisRateLimiterConfig {
    /// Maximum number of requests admitted per key within one window.
    pub max_requests: u32,
    /// Length of the sliding window, in seconds.
    pub window_seconds: u64,
    /// Namespace for Redis keys; limited keys are stored as `"{key_prefix}:{key}"`.
    pub key_prefix: String,
}

37impl Default for RedisRateLimiterConfig {
38 fn default() -> Self {
39 Self {
40 max_requests: 100,
41 window_seconds: 60,
42 key_prefix: "ratelimit".to_string(),
43 }
44 }
45}
46
47impl RedisRateLimiterConfig {
48 pub fn new(max_requests: u32, window_seconds: u64) -> Self {
50 Self {
51 max_requests,
52 window_seconds,
53 key_prefix: "ratelimit".to_string(),
54 }
55 }
56
57 pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
59 self.key_prefix = prefix.into();
60 self
61 }
62}
63
/// Errors returned by the Redis-backed limiters' setup and maintenance operations.
#[derive(Debug)]
pub enum RedisRateLimiterError {
    /// The Redis URL could not be opened or the connection could not be established.
    Connection(String),
    /// A Redis command failed; carries the rendered `redis::RedisError` message.
    Redis(String),
}

73impl std::fmt::Display for RedisRateLimiterError {
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 match self {
76 RedisRateLimiterError::Connection(msg) => write!(f, "Redis connection error: {}", msg),
77 RedisRateLimiterError::Redis(msg) => write!(f, "Redis error: {}", msg),
78 }
79 }
80}
81
82impl std::error::Error for RedisRateLimiterError {}
83
84impl From<redis::RedisError> for RedisRateLimiterError {
85 fn from(err: redis::RedisError) -> Self {
86 RedisRateLimiterError::Redis(err.to_string())
87 }
88}
89
90pub struct RedisRateLimiter {
109 conn: ConnectionManager,
110 config: RedisRateLimiterConfig,
111}
112
113impl RedisRateLimiter {
114 pub async fn new(
121 redis_url: &str,
122 max_requests: u32,
123 window_seconds: u64,
124 ) -> Result<Self, RedisRateLimiterError> {
125 Self::with_config(
126 redis_url,
127 RedisRateLimiterConfig::new(max_requests, window_seconds),
128 )
129 .await
130 }
131
132 pub async fn with_config(
134 redis_url: &str,
135 config: RedisRateLimiterConfig,
136 ) -> Result<Self, RedisRateLimiterError> {
137 let client = Client::open(redis_url)
138 .map_err(|e| RedisRateLimiterError::Connection(e.to_string()))?;
139
140 let conn = ConnectionManager::new(client)
141 .await
142 .map_err(|e| RedisRateLimiterError::Connection(e.to_string()))?;
143
144 Ok(Self { conn, config })
145 }
146
147 pub fn from_connection(conn: ConnectionManager, config: RedisRateLimiterConfig) -> Self {
149 Self { conn, config }
150 }
151
152 pub async fn check(&self, key: &str) -> Result<u32, RateLimitError> {
157 let redis_key = format!("{}:{}", self.config.key_prefix, key);
158 let now = std::time::SystemTime::now()
159 .duration_since(std::time::UNIX_EPOCH)
160 .unwrap()
161 .as_millis() as f64;
162
163 let window_start = now - (self.config.window_seconds as f64 * 1000.0);
164
165 let mut conn = self.conn.clone();
166
167 let script = redis::Script::new(
170 r#"
171 local key = KEYS[1]
172 local now = tonumber(ARGV[1])
173 local window_start = tonumber(ARGV[2])
174 local max_requests = tonumber(ARGV[3])
175 local window_ms = tonumber(ARGV[4])
176
177 -- Remove old entries
178 redis.call('ZREMRANGEBYSCORE', key, '-inf', window_start)
179
180 -- Count current entries
181 local count = redis.call('ZCARD', key)
182
183 if count < max_requests then
184 -- Add new entry
185 redis.call('ZADD', key, now, now)
186 -- Set expiry
187 redis.call('PEXPIRE', key, window_ms)
188 return max_requests - count - 1
189 else
190 -- Get oldest entry to calculate retry time
191 local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')
192 if #oldest > 0 then
193 return -(oldest[2] + window_ms - now)
194 end
195 return -1
196 end
197 "#,
198 );
199
200 let result: i64 = script
201 .key(&redis_key)
202 .arg(now)
203 .arg(window_start)
204 .arg(self.config.max_requests)
205 .arg(self.config.window_seconds * 1000)
206 .invoke_async(&mut conn)
207 .await
208 .map_err(|_| RateLimitError {
209 retry_after: Duration::from_secs(1),
210 })?;
211
212 if result >= 0 {
213 Ok(result as u32)
214 } else {
215 let retry_ms = (-result) as u64;
216 Err(RateLimitError {
217 retry_after: Duration::from_millis(retry_ms.max(1)),
218 })
219 }
220 }
221
222 pub async fn get_count(&self, key: &str) -> Result<u32, RedisRateLimiterError> {
224 let redis_key = format!("{}:{}", self.config.key_prefix, key);
225 let now = std::time::SystemTime::now()
226 .duration_since(std::time::UNIX_EPOCH)
227 .unwrap()
228 .as_millis() as f64;
229
230 let window_start = now - (self.config.window_seconds as f64 * 1000.0);
231
232 let mut conn = self.conn.clone();
233
234 let _: () = conn.zrembyscore(&redis_key, "-inf", window_start).await?;
236
237 let count: u32 = conn.zcard(&redis_key).await?;
239
240 Ok(count)
241 }
242
243 pub async fn get_remaining(&self, key: &str) -> Result<u32, RedisRateLimiterError> {
245 let count = self.get_count(key).await?;
246 Ok(self.config.max_requests.saturating_sub(count))
247 }
248
249 pub async fn reset(&self, key: &str) -> Result<(), RedisRateLimiterError> {
251 let redis_key = format!("{}:{}", self.config.key_prefix, key);
252 let mut conn = self.conn.clone();
253 let _: () = conn.del(&redis_key).await?;
254 Ok(())
255 }
256
257 pub fn config(&self) -> &RedisRateLimiterConfig {
259 &self.config
260 }
261}
262
263pub struct KeyedRedisRateLimiter {
268 conn: ConnectionManager,
269 default_config: RedisRateLimiterConfig,
270 custom_configs: std::collections::HashMap<String, RedisRateLimiterConfig>,
272}
273
274impl KeyedRedisRateLimiter {
275 pub async fn new(
277 redis_url: &str,
278 default_config: RedisRateLimiterConfig,
279 ) -> Result<Self, RedisRateLimiterError> {
280 let client = Client::open(redis_url)
281 .map_err(|e| RedisRateLimiterError::Connection(e.to_string()))?;
282
283 let conn = ConnectionManager::new(client)
284 .await
285 .map_err(|e| RedisRateLimiterError::Connection(e.to_string()))?;
286
287 Ok(Self {
288 conn,
289 default_config,
290 custom_configs: std::collections::HashMap::new(),
291 })
292 }
293
294 pub fn set_config(&mut self, key: impl Into<String>, config: RedisRateLimiterConfig) {
296 self.custom_configs.insert(key.into(), config);
297 }
298
299 pub fn remove_config(&mut self, key: &str) {
301 self.custom_configs.remove(key);
302 }
303
304 pub async fn check(&self, key: &str) -> Result<u32, RateLimitError> {
306 let config = self.custom_configs.get(key).unwrap_or(&self.default_config);
307 let limiter = RedisRateLimiter::from_connection(self.conn.clone(), config.clone());
308 limiter.check(key).await
309 }
310
311 pub async fn get_remaining(&self, key: &str) -> Result<u32, RedisRateLimiterError> {
313 let config = self.custom_configs.get(key).unwrap_or(&self.default_config);
314 let limiter = RedisRateLimiter::from_connection(self.conn.clone(), config.clone());
315 limiter.get_remaining(key).await
316 }
317
318 pub async fn reset(&self, key: &str) -> Result<(), RedisRateLimiterError> {
320 let config = self.custom_configs.get(key).unwrap_or(&self.default_config);
321 let limiter = RedisRateLimiter::from_connection(self.conn.clone(), config.clone());
322 limiter.reset(key).await
323 }
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// Local Redis instance used by the `#[ignore]`d integration tests.
    const REDIS_URL: &str = "redis://localhost:6379";

    /// Connects a limiter to the local test Redis, panicking if it is unreachable.
    async fn connect(max_requests: u32, window_seconds: u64) -> RedisRateLimiter {
        RedisRateLimiter::new(REDIS_URL, max_requests, window_seconds)
            .await
            .expect("Failed to connect to Redis")
    }

    #[test]
    fn test_config_default() {
        let RedisRateLimiterConfig {
            max_requests,
            window_seconds,
            key_prefix,
        } = RedisRateLimiterConfig::default();

        assert_eq!(max_requests, 100);
        assert_eq!(window_seconds, 60);
        assert_eq!(key_prefix, "ratelimit");
    }

    #[test]
    fn test_config_builder() {
        let built = RedisRateLimiterConfig::new(50, 30).with_prefix("myapp");

        assert_eq!(built.max_requests, 50);
        assert_eq!(built.window_seconds, 30);
        assert_eq!(built.key_prefix, "myapp");
    }

    #[test]
    fn test_error_display() {
        let cases = [
            (RedisRateLimiterError::Connection("timeout".to_string()), "timeout"),
            (RedisRateLimiterError::Redis("command failed".to_string()), "command failed"),
        ];
        for (err, needle) in &cases {
            assert!(err.to_string().contains(*needle));
        }
    }

    #[tokio::test]
    #[ignore = "requires Redis"]
    async fn test_redis_rate_limiter_basic() {
        const KEY: &str = "test:basic";
        let limiter = connect(5, 10).await;
        limiter.reset(KEY).await.ok();

        for i in 0..5 {
            assert!(limiter.check(KEY).await.is_ok(), "Request {} should be allowed", i);
        }

        assert!(limiter.check(KEY).await.is_err(), "6th request should be denied");
    }

    #[tokio::test]
    #[ignore = "requires Redis"]
    async fn test_redis_rate_limiter_remaining() {
        const KEY: &str = "test:remaining";
        let limiter = connect(10, 60).await;
        limiter.reset(KEY).await.ok();

        assert_eq!(limiter.get_remaining(KEY).await.unwrap(), 10);

        for _ in 0..3 {
            limiter.check(KEY).await.ok();
        }

        assert_eq!(limiter.get_remaining(KEY).await.unwrap(), 7);
    }

    #[tokio::test]
    #[ignore = "requires Redis"]
    async fn test_redis_rate_limiter_reset() {
        const KEY: &str = "test:reset";
        let limiter = connect(2, 60).await;
        limiter.reset(KEY).await.ok();

        // Exhaust the quota of two.
        for _ in 0..2 {
            limiter.check(KEY).await.ok();
        }
        assert!(limiter.check(KEY).await.is_err());

        limiter.reset(KEY).await.unwrap();
        assert!(limiter.check(KEY).await.is_ok());
    }
}