1use std::collections::HashMap;
25use std::sync::Mutex;
26use std::time::{Duration, Instant};
27
/// Token bucket holding up to `capacity` tokens that regenerates continuously
/// at `refill_rate` tokens per second.
///
/// Regeneration is lazy: elapsed time is only credited when
/// [`TokenBucket::refill`] runs (directly, or through `try_consume` /
/// `available`). Methods taking `&self` therefore see the count as of the
/// most recent refill.
struct TokenBucket {
    /// Tokens currently held; fractional values are expected.
    tokens: f64,
    /// Maximum number of tokens the bucket can hold.
    capacity: f64,
    /// Tokens regained per second of elapsed time.
    refill_rate: f64,
    /// Moment up to which `tokens` has been credited.
    last_refill: Instant,
}

impl TokenBucket {
    /// Builds a bucket that starts full (`tokens == capacity`).
    fn new(capacity: f64, refill_rate: f64) -> Self {
        let created_at = Instant::now();
        Self {
            last_refill: created_at,
            refill_rate,
            capacity,
            tokens: capacity,
        }
    }

    /// Credits tokens for the time elapsed since the previous refill (never
    /// exceeding `capacity`) and advances `last_refill` to now.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill);
        let earned = self.refill_rate * elapsed.as_secs_f64();
        self.last_refill = now;
        self.tokens = self.capacity.min(self.tokens + earned);
    }

    /// Refills, then takes `n` tokens if at least that many are held.
    ///
    /// Returns `true` when the tokens were taken; on `false` the refilled
    /// count is left untouched.
    fn try_consume(&mut self, n: f64) -> bool {
        self.refill();
        let sufficient = self.tokens >= n;
        if sufficient {
            self.tokens -= n;
        }
        sufficient
    }

    /// Refills and reports the current token count.
    #[allow(dead_code)] // not called by RateLimiter in this file; used by the unit tests
    fn available(&mut self) -> f64 {
        self.refill();
        self.tokens
    }

    /// Milliseconds (rounded up) until `n` tokens would be held, computed from
    /// the *stored* count — call `refill` first for an up-to-date answer.
    ///
    /// Returns 0 when `n` tokens are already held. With a `refill_rate` of
    /// zero the division yields infinity, which the `as u64` cast saturates
    /// to `u64::MAX`.
    fn ms_until_available(&self, n: f64) -> u64 {
        if self.tokens >= n {
            0
        } else {
            let wait_secs = (n - self.tokens) / self.refill_rate;
            (wait_secs * 1000.0).ceil() as u64
        }
    }
}
92
/// Tuning knobs for the rate limiter.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Sustained per-client rate in requests per second (the refill rate of
    /// each client's bucket).
    pub rps: f64,
    /// Per-client bucket capacity: how many requests a client may issue
    /// back-to-back before being throttled.
    pub burst: f64,
    /// Size of the client table at which `RateLimiter::check_and_consume`
    /// starts evicting entries.
    pub max_clients: usize,
    /// Clients not seen for at least this long are considered stale and may
    /// be evicted.
    pub client_ttl: Duration,
    /// Optional cap across all clients, in requests per second; `None`
    /// disables the global limit.
    pub global_rps: Option<f64>,
}

impl Default for RateLimitConfig {
    /// 10 rps with a burst of 20 per client, up to 10 000 tracked clients,
    /// a five-minute client TTL, and no global cap.
    fn default() -> Self {
        let five_minutes = Duration::from_secs(5 * 60);
        RateLimitConfig {
            global_rps: None,
            client_ttl: five_minutes,
            max_clients: 10_000,
            burst: 20.0,
            rps: 10.0,
        }
    }
}
121
/// Outcome of a rate-limit check.
#[derive(Debug, Clone, PartialEq)]
pub enum RateLimitDecision {
    /// The request may proceed.
    Allow,
    /// The request was rejected.
    Deny {
        /// Suggested wait, in milliseconds, before retrying.
        retry_after_ms: u64,
    },
}

impl RateLimitDecision {
    /// Returns `true` only for [`RateLimitDecision::Allow`].
    pub fn is_allowed(&self) -> bool {
        match self {
            Self::Allow => true,
            Self::Deny { .. } => false,
        }
    }

    /// Returns the retry hint carried by a `Deny`, or `None` for `Allow`.
    pub fn retry_after_ms(&self) -> Option<u64> {
        if let Self::Deny { retry_after_ms } = self {
            Some(*retry_after_ms)
        } else {
            None
        }
    }
}
150
151pub struct RateLimiter {
157 config: RateLimitConfig,
158 clients: Mutex<HashMap<String, (TokenBucket, Instant)>>,
160 global: Option<Mutex<TokenBucket>>,
162}
163
164impl RateLimiter {
165 pub fn new(config: RateLimitConfig) -> Self {
167 let global = config.global_rps.map(|rps| {
168 Mutex::new(TokenBucket::new(rps * 2.0, rps))
170 });
171 Self {
172 config,
173 clients: Mutex::new(HashMap::new()),
174 global,
175 }
176 }
177
178 pub fn check(&self, client_id: &str) -> RateLimitDecision {
182 if let Some(ref global_mutex) = self.global {
184 let global = global_mutex
185 .lock()
186 .expect("global rate limiter mutex poisoned");
187 if global.tokens < 1.0 {
188 let retry_ms = global.ms_until_available(1.0);
189 return RateLimitDecision::Deny {
190 retry_after_ms: retry_ms.max(1),
191 };
192 }
193 }
194
195 let mut clients = self
197 .clients
198 .lock()
199 .expect("client rate limiter mutex poisoned");
200
201 if let Some((bucket, _last_seen)) = clients.get_mut(client_id) {
202 bucket.refill();
204 if bucket.tokens < 1.0 {
205 let retry_ms = bucket.ms_until_available(1.0);
206 return RateLimitDecision::Deny {
207 retry_after_ms: retry_ms.max(1),
208 };
209 }
210 }
211 RateLimitDecision::Allow
213 }
214
215 pub fn check_and_consume(&self, client_id: &str) -> RateLimitDecision {
220 if let Some(ref global_mutex) = self.global {
222 let mut global = global_mutex
223 .lock()
224 .expect("global rate limiter mutex poisoned");
225 if !global.try_consume(1.0) {
226 let retry_ms = global.ms_until_available(1.0);
227 return RateLimitDecision::Deny {
228 retry_after_ms: retry_ms.max(1),
229 };
230 }
231 }
232
233 let mut clients = self
234 .clients
235 .lock()
236 .expect("client rate limiter mutex poisoned");
237
238 if clients.len() >= self.config.max_clients {
240 let ttl = self.config.client_ttl;
241 let now = Instant::now();
242 clients.retain(|_, (_, last_seen)| now.duration_since(*last_seen) < ttl);
243 }
244
245 let bucket = clients.entry(client_id.to_owned()).or_insert_with(|| {
246 (
247 TokenBucket::new(self.config.burst, self.config.rps),
248 Instant::now(),
249 )
250 });
251
252 let (token_bucket, last_seen) = bucket;
253 *last_seen = Instant::now();
254
255 if token_bucket.try_consume(1.0) {
256 RateLimitDecision::Allow
257 } else {
258 let retry_ms = token_bucket.ms_until_available(1.0);
259 RateLimitDecision::Deny {
260 retry_after_ms: retry_ms.max(1),
261 }
262 }
263 }
264
265 pub fn evict_stale(&self) {
267 let ttl = self.config.client_ttl;
268 let now = Instant::now();
269 let mut clients = self
270 .clients
271 .lock()
272 .expect("client rate limiter mutex poisoned");
273 clients.retain(|_, (_, last_seen)| now.duration_since(*last_seen) < ttl);
274 }
275
276 pub fn active_clients(&self) -> usize {
278 self.clients
279 .lock()
280 .expect("client rate limiter mutex poisoned")
281 .len()
282 }
283
284 pub fn reset_client(&self, client_id: &str) {
286 self.clients
287 .lock()
288 .expect("client rate limiter mutex poisoned")
289 .remove(client_id);
290 }
291
292 pub fn is_global_limited(&self) -> bool {
294 match &self.global {
295 None => false,
296 Some(global_mutex) => {
297 let global = global_mutex
298 .lock()
299 .expect("global rate limiter mutex poisoned");
300 global.tokens < 1.0
301 }
302 }
303 }
304}
305
306pub fn rate_limit_middleware(
313 limiter: std::sync::Arc<RateLimiter>,
314 client_id: &str,
315) -> RateLimitDecision {
316 limiter.check_and_consume(client_id)
317}
318
#[cfg(feature = "server")]
/// Derives a rate-limiting key from request headers.
///
/// Precedence:
/// 1. the first comma-separated entry of `x-forwarded-for`, trimmed;
/// 2. otherwise `x-real-ip`, trimmed;
/// 3. otherwise the literal `"unknown"`.
///
/// Headers that are missing, not valid visible-ASCII text, or blank after
/// trimming are skipped. A blank first `x-forwarded-for` entry falls through
/// to `x-real-ip`; later forwarded entries are not consulted.
///
/// NOTE(review): both headers are client-supplied unless a trusted reverse
/// proxy overwrites them. If this service is reachable directly, a client can
/// rotate `x-forwarded-for` to get a fresh bucket per request — confirm the
/// deployment sanitizes these headers.
pub fn extract_client_id(headers: &axum::http::HeaderMap) -> String {
    // Header value as text, or None if absent / not representable as &str.
    fn header_text<'a>(headers: &'a axum::http::HeaderMap, name: &str) -> Option<&'a str> {
        headers.get(name)?.to_str().ok()
    }

    let forwarded = header_text(headers, "x-forwarded-for")
        .and_then(|list| list.split(',').next())
        .map(str::trim)
        .filter(|ip| !ip.is_empty());
    if let Some(ip) = forwarded {
        return ip.to_owned();
    }

    header_text(headers, "x-real-ip")
        .map(str::trim)
        .filter(|ip| !ip.is_empty())
        .unwrap_or("unknown")
        .to_owned()
}
349
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_token_bucket_initial_full() {
        let mut bucket = TokenBucket::new(10.0, 1.0);
        assert!((bucket.available() - 10.0).abs() < 1e-6);
    }

    #[test]
    fn test_token_bucket_consume_success() {
        let mut bucket = TokenBucket::new(10.0, 1.0);
        assert!(bucket.try_consume(5.0));
        let remaining = bucket.available();
        assert!((4.9..=5.1).contains(&remaining), "remaining={remaining}");
    }

    #[test]
    fn test_token_bucket_consume_fail_insufficient() {
        let mut bucket = TokenBucket::new(3.0, 0.01);
        assert!(bucket.try_consume(3.0));
        assert!(!bucket.try_consume(1.0));
    }

    #[test]
    fn test_token_bucket_refills_over_time() {
        let mut bucket = TokenBucket::new(10.0, 1000.0);
        assert!(bucket.try_consume(10.0));
        thread::sleep(Duration::from_millis(20));
        let available = bucket.available();
        assert!(
            available > 1.0,
            "bucket should have refilled; got {available}"
        );
    }

    #[test]
    fn test_rate_limiter_allows_first_request() {
        let config = RateLimitConfig {
            rps: 10.0,
            burst: 10.0,
            ..Default::default()
        };
        let limiter = RateLimiter::new(config);
        let decision = limiter.check_and_consume("client-1");
        assert_eq!(decision, RateLimitDecision::Allow);
    }

    #[test]
    fn test_rate_limiter_denies_after_burst() {
        let config = RateLimitConfig {
            rps: 1.0,
            burst: 3.0,
            ..Default::default()
        };
        let limiter = RateLimiter::new(config);

        for i in 0..3 {
            let d = limiter.check_and_consume("client-burst");
            assert_eq!(d, RateLimitDecision::Allow, "request {i} should be allowed");
        }

        let denied = limiter.check_and_consume("client-burst");
        assert!(
            denied.retry_after_ms().is_some(),
            "4th request should be denied"
        );
    }

    #[test]
    fn test_rate_limiter_different_clients_independent() {
        let config = RateLimitConfig {
            rps: 1.0,
            burst: 1.0,
            ..Default::default()
        };
        let limiter = RateLimiter::new(config);

        assert_eq!(
            limiter.check_and_consume("client-a"),
            RateLimitDecision::Allow
        );
        let denied = limiter.check_and_consume("client-a");
        assert!(!denied.is_allowed());

        assert_eq!(
            limiter.check_and_consume("client-b"),
            RateLimitDecision::Allow
        );
    }

    #[test]
    fn test_rate_limit_decision_is_allowed() {
        assert!(RateLimitDecision::Allow.is_allowed());
        assert_eq!(RateLimitDecision::Allow.retry_after_ms(), None);

        let denied = RateLimitDecision::Deny {
            retry_after_ms: 500,
        };
        assert!(!denied.is_allowed());
        assert_eq!(denied.retry_after_ms(), Some(500));
    }

    // `extract_client_id` (and the `axum` dependency) only exist with the
    // "server" feature; without this gate `cargo test` fails to compile.
    #[cfg(feature = "server")]
    #[test]
    fn test_extract_client_id_x_forwarded_for() {
        use axum::http::HeaderMap;
        use axum::http::HeaderValue;

        let mut headers = HeaderMap::new();
        headers.insert(
            "x-forwarded-for",
            HeaderValue::from_static("203.0.113.42, 10.0.0.1"),
        );
        let id = extract_client_id(&headers);
        assert_eq!(id, "203.0.113.42");
    }

    #[cfg(feature = "server")]
    #[test]
    fn test_extract_client_id_fallback() {
        use axum::http::HeaderMap;
        let headers = HeaderMap::new();
        let id = extract_client_id(&headers);
        assert_eq!(id, "unknown");
    }

    #[test]
    fn test_rate_limiter_active_clients_tracked() {
        let limiter = RateLimiter::new(RateLimitConfig::default());
        limiter.check_and_consume("alpha");
        limiter.check_and_consume("beta");
        assert_eq!(limiter.active_clients(), 2);
        limiter.reset_client("alpha");
        assert_eq!(limiter.active_clients(), 1);
    }

    #[test]
    fn test_rate_limiter_no_global_limit_by_default() {
        let limiter = RateLimiter::new(RateLimitConfig::default());
        assert!(!limiter.is_global_limited());
    }

    #[test]
    fn test_check_does_not_consume_tokens() {
        let limiter = RateLimiter::new(RateLimitConfig {
            rps: 1.0,
            burst: 1.0,
            ..Default::default()
        });
        for _ in 0..3 {
            assert!(limiter.check("peek").is_allowed());
        }
        assert!(limiter.check_and_consume("peek").is_allowed());
        assert!(!limiter.check("peek").is_allowed());
    }

    #[test]
    fn test_global_limit_applies_across_clients() {
        // Global bucket capacity is 2 * global_rps = 2 tokens.
        let limiter = RateLimiter::new(RateLimitConfig {
            global_rps: Some(1.0),
            ..Default::default()
        });
        assert!(limiter.check_and_consume("g-1").is_allowed());
        assert!(limiter.check_and_consume("g-2").is_allowed());
        let denied = limiter.check_and_consume("g-3");
        assert!(matches!(denied.retry_after_ms(), Some(ms) if ms >= 1));
        assert!(limiter.is_global_limited());
    }
}