1use crate::cache::Cache;
23use crate::http::{HttpResponse, Request, Response};
24use crate::middleware::{Middleware, Next};
25use async_trait::async_trait;
26use dashmap::DashMap;
27use std::sync::{Arc, OnceLock};
28use std::time::{Duration, SystemTime, UNIX_EPOCH};
29
30type LimiterFn = Arc<dyn Fn(&Request) -> Vec<Limit> + Send + Sync>;
32
33fn limiter_registry() -> &'static DashMap<String, LimiterFn> {
35 static REGISTRY: OnceLock<DashMap<String, LimiterFn>> = OnceLock::new();
36 REGISTRY.get_or_init(DashMap::new)
37}
38
39pub struct Limit {
61 pub max_requests: u32,
63 pub window_seconds: u64,
65 key: Option<String>,
67 response_fn: Option<Arc<dyn Fn() -> HttpResponse + Send + Sync>>,
69}
70
71impl Limit {
72 pub fn per_second(max: u32) -> Self {
74 Self {
75 max_requests: max,
76 window_seconds: 1,
77 key: None,
78 response_fn: None,
79 }
80 }
81
82 pub fn per_minute(max: u32) -> Self {
84 Self {
85 max_requests: max,
86 window_seconds: 60,
87 key: None,
88 response_fn: None,
89 }
90 }
91
92 pub fn per_hour(max: u32) -> Self {
94 Self {
95 max_requests: max,
96 window_seconds: 3600,
97 key: None,
98 response_fn: None,
99 }
100 }
101
102 pub fn per_day(max: u32) -> Self {
104 Self {
105 max_requests: max,
106 window_seconds: 86400,
107 key: None,
108 response_fn: None,
109 }
110 }
111
112 pub fn by(mut self, key: impl Into<String>) -> Self {
123 self.key = Some(key.into());
124 self
125 }
126
127 pub fn response<F>(mut self, f: F) -> Self
137 where
138 F: Fn() -> HttpResponse + Send + Sync + 'static,
139 {
140 self.response_fn = Some(Arc::new(f));
141 self
142 }
143}
144
145pub enum LimiterResponse {
149 Single(Limit),
151 Multiple(Vec<Limit>),
153}
154
155impl From<Limit> for LimiterResponse {
156 fn from(limit: Limit) -> Self {
157 LimiterResponse::Single(limit)
158 }
159}
160
161impl From<Vec<Limit>> for LimiterResponse {
162 fn from(limits: Vec<Limit>) -> Self {
163 LimiterResponse::Multiple(limits)
164 }
165}
166
167impl LimiterResponse {
168 fn into_vec(self) -> Vec<Limit> {
169 match self {
170 LimiterResponse::Single(limit) => vec![limit],
171 LimiterResponse::Multiple(limits) => limits,
172 }
173 }
174}
175
176pub struct RateLimiter;
209
210impl RateLimiter {
211 pub fn define<F, T>(name: &str, f: F)
215 where
216 F: Fn(&Request) -> T + Send + Sync + 'static,
217 T: Into<LimiterResponse>,
218 {
219 let wrapped: LimiterFn = Arc::new(move |req| {
220 let response: LimiterResponse = f(req).into();
221 response.into_vec()
222 });
223 limiter_registry().insert(name.to_string(), wrapped);
224 }
225
226 pub fn resolve(name: &str, req: &Request) -> Option<Vec<Limit>> {
230 limiter_registry().get(name).map(|f| f(req))
231 }
232
233 pub fn per_second(max: u32) -> Limit {
235 Limit::per_second(max)
236 }
237
238 pub fn per_minute(max: u32) -> Limit {
240 Limit::per_minute(max)
241 }
242
243 pub fn per_hour(max: u32) -> Limit {
245 Limit::per_hour(max)
246 }
247
248 pub fn per_day(max: u32) -> Limit {
250 Limit::per_day(max)
251 }
252}
253
/// Outcome of counting one request against a single [`Limit`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct RateLimitResult {
    /// Whether the request fits within the limit.
    allowed: bool,
    /// The limit's `max_requests`, reported in `X-RateLimit-Limit`.
    limit: u32,
    /// Requests left in the current window (saturates at 0).
    remaining: u32,
    /// Seconds until the current window ends.
    retry_after: u64,
}
261
262fn get_client_ip(request: &Request) -> String {
266 request
267 .header("X-Forwarded-For")
268 .and_then(|s| s.split(',').next())
269 .map(|s| s.trim().to_string())
270 .or_else(|| request.header("X-Real-IP").map(|s| s.to_string()))
271 .unwrap_or_else(|| "unknown".to_string())
272}
273
274async fn check_rate_limit(limit: &Limit, name: &str, identifier: &str) -> RateLimitResult {
279 let now_secs = SystemTime::now()
280 .duration_since(UNIX_EPOCH)
281 .unwrap_or_default()
282 .as_secs();
283 let window_number = now_secs / limit.window_seconds;
284 let key = format!("rate_limit:{name}:{identifier}:{window_number}");
285
286 let count = match Cache::increment(&key, 1).await {
288 Ok(c) => c as u32,
289 Err(e) => {
290 eprintln!("[ferro] Rate limiter cache error (fail-open): {e}");
291 return RateLimitResult {
292 allowed: true,
293 limit: limit.max_requests,
294 remaining: limit.max_requests,
295 retry_after: limit.window_seconds,
296 };
297 }
298 };
299
300 if count == 1 {
302 let ttl = Duration::from_secs(limit.window_seconds + 1);
303 if let Err(e) = Cache::expire(&key, ttl).await {
304 eprintln!("[ferro] Rate limiter expire error: {e}");
305 }
306 }
307
308 let remaining = limit.max_requests.saturating_sub(count);
309 let retry_after = limit.window_seconds - (now_secs % limit.window_seconds);
310
311 RateLimitResult {
312 allowed: count <= limit.max_requests,
313 limit: limit.max_requests,
314 remaining,
315 retry_after,
316 }
317}
318
319fn add_rate_limit_headers(
321 response: HttpResponse,
322 limit: u32,
323 remaining: u32,
324 retry_after: u64,
325) -> HttpResponse {
326 response
327 .header("X-RateLimit-Limit", limit.to_string())
328 .header("X-RateLimit-Remaining", remaining.to_string())
329 .header("X-RateLimit-Reset", retry_after.to_string())
330}
331
332pub struct Throttle {
353 name: Option<String>,
355 inline_limits: Vec<Limit>,
357}
358
359impl Throttle {
360 pub fn named(name: &str) -> Self {
365 Self {
366 name: Some(name.to_string()),
367 inline_limits: Vec::new(),
368 }
369 }
370
371 pub fn per_second(max: u32) -> Self {
373 Self {
374 name: None,
375 inline_limits: vec![Limit::per_second(max)],
376 }
377 }
378
379 pub fn per_minute(max: u32) -> Self {
381 Self {
382 name: None,
383 inline_limits: vec![Limit::per_minute(max)],
384 }
385 }
386
387 pub fn per_hour(max: u32) -> Self {
389 Self {
390 name: None,
391 inline_limits: vec![Limit::per_hour(max)],
392 }
393 }
394
395 pub fn per_day(max: u32) -> Self {
397 Self {
398 name: None,
399 inline_limits: vec![Limit::per_day(max)],
400 }
401 }
402}
403
404#[async_trait]
405impl Middleware for Throttle {
406 async fn handle(&self, request: Request, next: Next) -> Response {
407 let (limiter_name, limits) = if let Some(ref name) = self.name {
409 match RateLimiter::resolve(name, &request) {
410 Some(limits) => (name.clone(), limits),
411 None => {
412 eprintln!(
413 "[ferro] Rate limiter '{name}' not registered (fail-open, allowing request)"
414 );
415 return next(request).await;
416 }
417 }
418 } else {
419 let limits: Vec<Limit> = self
422 .inline_limits
423 .iter()
424 .map(|l| Limit {
425 max_requests: l.max_requests,
426 window_seconds: l.window_seconds,
427 key: l.key.clone(),
428 response_fn: l.response_fn.clone(),
429 })
430 .collect();
431 ("inline".to_string(), limits)
432 };
433
434 let client_ip = get_client_ip(&request);
436
437 let mut most_restrictive: Option<(
439 RateLimitResult,
440 Option<Arc<dyn Fn() -> HttpResponse + Send + Sync>>,
441 )> = None;
442
443 for limit in &limits {
445 let identifier = limit.key.as_deref().unwrap_or(&client_ip);
446 let result = check_rate_limit(limit, &limiter_name, identifier).await;
447
448 if !result.allowed {
449 let error_response = if let Some(ref response_fn) = limit.response_fn {
451 response_fn()
452 } else {
453 HttpResponse::json(serde_json::json!({
454 "error": "Too Many Requests",
455 "message": "Rate limit exceeded. Please try again later.",
456 "retry_after": result.retry_after
457 }))
458 .status(429)
459 };
460
461 let error_response =
462 add_rate_limit_headers(error_response, result.limit, 0, result.retry_after)
463 .header("Retry-After", result.retry_after.to_string());
464
465 return Err(error_response);
466 }
467
468 let is_more_restrictive = most_restrictive
470 .as_ref()
471 .map(|(prev, _)| result.remaining < prev.remaining)
472 .unwrap_or(true);
473
474 if is_more_restrictive {
475 most_restrictive = Some((result, limit.response_fn.clone()));
476 }
477 }
478
479 let response = next(request).await;
481
482 if let Some((result, _)) = most_restrictive {
484 match response {
485 Ok(http_response) => Ok(add_rate_limit_headers(
486 http_response,
487 result.limit,
488 result.remaining,
489 result.retry_after,
490 )),
491 Err(http_response) => Err(add_rate_limit_headers(
492 http_response,
493 result.limit,
494 result.remaining,
495 result.retry_after,
496 )),
497 }
498 } else {
499 response
500 }
501 }
502}
503
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cache::{CacheStore, InMemoryCache};
    use crate::container::App;
    use serial_test::serial;
    use std::sync::Arc;

    /// Binds a new `InMemoryCache` as the container's `CacheStore`.
    // NOTE(review): presumably this is the store `Cache::increment` / `Cache::expire`
    // read from, and binding replaces any earlier store; confirm `App::bind` semantics.
    fn setup_test_cache() {
        App::bind::<dyn CacheStore>(Arc::new(InMemoryCache::new()));
    }

    /// Produces a real `Request` value by round-tripping one HTTP/1 request for `/test`
    /// through a throwaway local hyper server.
    ///
    /// The server task wraps the first incoming hyper request in `Request::new` and
    /// hands it back over a oneshot channel.
    async fn test_request() -> Request {
        use hyper_util::rt::TokioIo;
        use std::sync::Mutex;
        use tokio::sync::oneshot;

        // Port 0 lets the OS choose a free port.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let (tx, rx) = oneshot::channel();
        // `service_fn` takes a closure that may run more than once, but a oneshot sender
        // can only be used once, so it lives in shared state and is `take`n out.
        let tx_holder = Arc::new(Mutex::new(Some(tx)));

        tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let io = TokioIo::new(stream);

            let tx_holder = tx_holder.clone();
            let service =
                hyper::service::service_fn(move |req: hyper::Request<hyper::body::Incoming>| {
                    let tx_holder = tx_holder.clone();
                    async move {
                        if let Some(tx) = tx_holder.lock().unwrap().take() {
                            let _ = tx.send(Request::new(req));
                        }
                        // Reply with an empty body so the client's request completes.
                        Ok::<_, hyper::Error>(hyper::Response::new(http_body_util::Empty::<
                            bytes::Bytes,
                        >::new(
                        )))
                    }
                });

            hyper::server::conn::http1::Builder::new()
                .serve_connection(io, service)
                .await
                .ok();
        });

        let stream = tokio::net::TcpStream::connect(addr).await.unwrap();
        let io = TokioIo::new(stream);

        let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
        // The client connection must be driven on its own task for `send_request` to
        // make progress.
        tokio::spawn(async move {
            conn.await.ok();
        });

        let req = hyper::Request::builder()
            .uri("/test")
            .body(http_body_util::Empty::<bytes::Bytes>::new())
            .unwrap();

        // Only the server-side capture matters; the client's response is ignored.
        let _ = sender.send_request(req).await;
        rx.await.unwrap()
    }

    #[test]
    fn test_limit_per_minute() {
        let limit = Limit::per_minute(60);
        assert_eq!(limit.max_requests, 60);
        assert_eq!(limit.window_seconds, 60);
        assert!(limit.key.is_none());
        assert!(limit.response_fn.is_none());
    }

    #[test]
    fn test_limit_per_hour() {
        let limit = Limit::per_hour(1000);
        assert_eq!(limit.max_requests, 1000);
        assert_eq!(limit.window_seconds, 3600);
    }

    #[test]
    fn test_limit_per_second() {
        let limit = Limit::per_second(10);
        assert_eq!(limit.max_requests, 10);
        assert_eq!(limit.window_seconds, 1);
    }

    #[test]
    fn test_limit_per_day() {
        let limit = Limit::per_day(10000);
        assert_eq!(limit.max_requests, 10000);
        assert_eq!(limit.window_seconds, 86400);
    }

    #[test]
    fn test_limit_by_key() {
        let limit = Limit::per_minute(60).by("user:1");
        assert_eq!(limit.key, Some("user:1".to_string()));
    }

    #[test]
    fn test_limit_response_factory() {
        let limit = Limit::per_minute(60)
            .response(|| HttpResponse::json(serde_json::json!({"error": "custom"})).status(429));
        assert!(limit.response_fn.is_some());
    }

    // Registry tests are `#[serial]` because they clear and repopulate the shared
    // global limiter registry.
    #[tokio::test]
    #[serial]
    async fn test_define_and_resolve() {
        limiter_registry().clear();

        RateLimiter::define("test", |_req| Limit::per_minute(100));

        let req = test_request().await;
        let limits = RateLimiter::resolve("test", &req);
        assert!(limits.is_some(), "defined limiter should resolve");

        let limits = limits.unwrap();
        assert_eq!(limits.len(), 1);
        assert_eq!(limits[0].max_requests, 100);
        assert_eq!(limits[0].window_seconds, 60);
    }

    #[tokio::test]
    #[serial]
    async fn test_resolve_undefined() {
        limiter_registry().clear();

        let req = test_request().await;
        let result = RateLimiter::resolve("nonexistent", &req);
        assert!(result.is_none(), "undefined limiter should resolve to None");
    }

    #[tokio::test]
    #[serial]
    async fn test_define_multiple_limits() {
        limiter_registry().clear();

        RateLimiter::define("login", |_req| {
            vec![Limit::per_minute(500), Limit::per_minute(5).by("email")]
        });

        let req = test_request().await;
        let limits = RateLimiter::resolve("login", &req).unwrap();
        assert_eq!(limits.len(), 2);
        assert_eq!(limits[0].max_requests, 500);
        assert!(limits[0].key.is_none());
        assert_eq!(limits[1].max_requests, 5);
        assert_eq!(limits[1].key, Some("email".to_string()));
    }

    // Counter tests use distinct limiter names / identifiers so their cache keys do not
    // collide with each other.
    #[tokio::test]
    #[serial]
    async fn test_allows_within_limit() {
        setup_test_cache();

        let limit = Limit::per_minute(10);
        for i in 1..=5 {
            let result = check_rate_limit(&limit, "test_allow", "ip:127.0.0.1").await;
            assert!(result.allowed, "request {i} should be allowed");
            assert_eq!(result.remaining, 10 - i);
            assert_eq!(result.limit, 10);
        }
    }

    #[tokio::test]
    #[serial]
    async fn test_exceeds_limit() {
        setup_test_cache();

        let limit = Limit::per_minute(3);
        // Use up the whole allowance.
        for i in 1..=3 {
            let result = check_rate_limit(&limit, "test_exceed", "ip:10.0.0.1").await;
            assert!(result.allowed, "request {i} should be allowed");
        }
        // One more in the same window must be rejected.
        let result = check_rate_limit(&limit, "test_exceed", "ip:10.0.0.1").await;
        assert!(!result.allowed, "request 4 should be rate limited");
        assert_eq!(result.remaining, 0);
    }

    #[tokio::test]
    #[serial]
    async fn test_separate_keys_independent() {
        setup_test_cache();

        let limit = Limit::per_minute(2);
        // Exhaust key_a only.
        for _ in 0..2 {
            check_rate_limit(&limit, "test_sep", "key_a").await;
        }
        let result_a = check_rate_limit(&limit, "test_sep", "key_a").await;
        assert!(!result_a.allowed, "key_a should be exhausted");

        // key_b has its own counter and is unaffected by key_a's traffic.
        let result_b = check_rate_limit(&limit, "test_sep", "key_b").await;
        assert!(result_b.allowed, "key_b should still be allowed");
        assert_eq!(result_b.remaining, 1);
    }

    // NOTE(review): as visible here, this test neither binds a failing cache store nor
    // calls `setup_test_cache`. It only asserts that a first request under the limit is
    // allowed with whatever store is currently bound, so it does not exercise the
    // fail-open branch deterministically. Consider binding a store whose increment
    // errors.
    #[tokio::test]
    #[serial]
    async fn test_cache_failure_allows_request() {
        let limit = Limit::per_minute(5);
        let result = check_rate_limit(&limit, "failopen", "test").await;
        assert!(result.allowed);
    }

    #[test]
    fn test_throttle_per_minute() {
        let throttle = Throttle::per_minute(60);
        assert!(throttle.name.is_none());
        assert_eq!(throttle.inline_limits.len(), 1);
        assert_eq!(throttle.inline_limits[0].max_requests, 60);
        assert_eq!(throttle.inline_limits[0].window_seconds, 60);
    }

    #[test]
    fn test_throttle_per_second() {
        let throttle = Throttle::per_second(10);
        assert_eq!(throttle.inline_limits[0].max_requests, 10);
        assert_eq!(throttle.inline_limits[0].window_seconds, 1);
    }

    #[test]
    fn test_throttle_per_hour() {
        let throttle = Throttle::per_hour(1000);
        assert_eq!(throttle.inline_limits[0].max_requests, 1000);
        assert_eq!(throttle.inline_limits[0].window_seconds, 3600);
    }

    #[test]
    fn test_throttle_per_day() {
        let throttle = Throttle::per_day(5000);
        assert_eq!(throttle.inline_limits[0].max_requests, 5000);
        assert_eq!(throttle.inline_limits[0].window_seconds, 86400);
    }

    #[test]
    fn test_throttle_named() {
        let throttle = Throttle::named("api");
        assert_eq!(throttle.name, Some("api".to_string()));
        assert!(throttle.inline_limits.is_empty());
    }

    #[test]
    fn test_limiter_response_single() {
        let response: LimiterResponse = Limit::per_minute(60).into();
        let limits = response.into_vec();
        assert_eq!(limits.len(), 1);
        assert_eq!(limits[0].max_requests, 60);
    }

    #[test]
    fn test_limiter_response_multiple() {
        let response: LimiterResponse = vec![Limit::per_minute(60), Limit::per_hour(1000)].into();
        let limits = response.into_vec();
        assert_eq!(limits.len(), 2);
        assert_eq!(limits[0].max_requests, 60);
        assert_eq!(limits[1].max_requests, 1000);
    }
}