1use std::collections::{HashMap, VecDeque};
24use std::sync::Mutex;
25use std::time::Instant;
26
27use fastmcp_core::{McpContext, McpError, McpErrorCode, McpResult};
28use fastmcp_protocol::JsonRpcRequest;
29
30use crate::{Middleware, MiddlewareDecision};
31
/// JSON-RPC error code attached to requests rejected for exceeding a rate limit.
///
/// The value lies in the JSON-RPC 2.0 range reserved for implementation-defined
/// server errors (-32000 to -32099).
pub const RATE_LIMIT_ERROR_CODE: i32 = -32_005;
36
37#[must_use]
39pub fn rate_limit_error(message: impl Into<String>) -> McpError {
40 McpError::new(McpErrorCode::Custom(RATE_LIMIT_ERROR_CODE), message)
41}
42
/// Mutable state of a [`TokenBucketRateLimiter`].
///
/// Both fields live behind a single lock so the token count and the refill timestamp are
/// always read and updated together.
#[derive(Debug)]
struct BucketState {
    /// Tokens currently available; fractional values carry partial refills forward.
    tokens: f64,
    /// The instant at which `tokens` was last brought up to date.
    last_refill: Instant,
}

/// Token-bucket rate limiter.
///
/// The bucket holds at most `capacity` tokens and starts full. It is refilled continuously
/// at `refill_rate` tokens per second, based on the time elapsed since the previous refill.
/// Each admitted operation removes tokens from the bucket.
#[derive(Debug)]
pub struct TokenBucketRateLimiter {
    /// Maximum number of tokens the bucket can hold (the burst size).
    capacity: usize,
    /// Tokens added per elapsed second.
    refill_rate: f64,
    /// Token count and last-refill timestamp, guarded by one mutex.
    state: Mutex<BucketState>,
}

impl TokenBucketRateLimiter {
    /// Creates a full bucket with room for `capacity` tokens, refilled at `refill_rate`
    /// tokens per second.
    #[must_use]
    pub fn new(capacity: usize, refill_rate: f64) -> Self {
        Self {
            capacity,
            refill_rate,
            state: Mutex::new(BucketState {
                tokens: capacity as f64,
                last_refill: Instant::now(),
            }),
        }
    }

    /// Attempts to take `tokens` tokens from the bucket after refilling it.
    ///
    /// Returns `true` and deducts the tokens if enough are available. Otherwise returns
    /// `false` and leaves the (refilled) balance untouched. Requesting more than `capacity`
    /// always fails; requesting `0` always succeeds.
    pub fn try_consume(&self, tokens: usize) -> bool {
        let mut state = self.refilled_state();
        let needed = tokens as f64;
        if state.tokens >= needed {
            state.tokens -= needed;
            true
        } else {
            false
        }
    }

    /// Returns the number of tokens available right now, after applying any pending refill.
    #[must_use]
    pub fn available_tokens(&self) -> f64 {
        self.refilled_state().tokens
    }

    /// Locks the bucket state, credits tokens for the time elapsed since the last refill
    /// (capped at `capacity`), and returns the still-held guard.
    fn refilled_state(&self) -> std::sync::MutexGuard<'_, BucketState> {
        // A poisoned lock still holds a valid numeric state, so keep using it.
        let mut state = self
            .state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let now = Instant::now();
        let elapsed = now.duration_since(state.last_refill).as_secs_f64();
        state.tokens = (state.tokens + elapsed * self.refill_rate).min(self.capacity as f64);
        state.last_refill = now;
        state
    }
}
128
/// Sliding-window rate limiter.
///
/// Admits at most `max_requests` requests within any trailing window of `window_seconds`
/// seconds. A request's timestamp stops counting once it is strictly older than the window.
/// Rejected requests are not recorded.
#[derive(Debug)]
pub struct SlidingWindowRateLimiter {
    /// Maximum number of admitted requests inside one window.
    max_requests: usize,
    /// Window length in seconds.
    window_seconds: u64,
    /// Timestamps of admitted requests, oldest first.
    requests: Mutex<VecDeque<Instant>>,
}

impl SlidingWindowRateLimiter {
    /// Creates a limiter admitting `max_requests` requests per `window_seconds`-second window.
    #[must_use]
    pub fn new(max_requests: usize, window_seconds: u64) -> Self {
        Self {
            max_requests,
            window_seconds,
            requests: Mutex::new(VecDeque::new()),
        }
    }

    /// Records and admits a request if fewer than `max_requests` fall inside the window.
    ///
    /// Returns `true` when the request is admitted (and counted), `false` otherwise.
    pub fn is_allowed(&self) -> bool {
        let mut requests = self.lock_requests();
        // `now` is taken while holding the lock, so timestamps are pushed in order and the
        // deque stays sorted oldest-first.
        let now = Instant::now();
        self.evict_expired(&mut requests, now);

        if requests.len() < self.max_requests {
            requests.push_back(now);
            true
        } else {
            false
        }
    }

    /// Returns how many admitted requests currently fall inside the window.
    #[must_use]
    pub fn current_requests(&self) -> usize {
        let mut requests = self.lock_requests();
        self.evict_expired(&mut requests, Instant::now());
        requests.len()
    }

    /// Locks the timestamp queue, recovering the data if a previous holder panicked.
    fn lock_requests(&self) -> std::sync::MutexGuard<'_, VecDeque<Instant>> {
        self.requests
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Drops timestamps older than the window, measured back from `now`.
    ///
    /// Ages are compared via `duration_since` instead of computing `now - window`. That
    /// subtraction panics when the window reaches further back than `Instant` can represent
    /// (e.g. a very large window, or a window longer than the platform clock's epoch allows).
    fn evict_expired(&self, requests: &mut VecDeque<Instant>, now: Instant) {
        let window = std::time::Duration::from_secs(self.window_seconds);
        while let Some(&oldest) = requests.front() {
            if now.duration_since(oldest) <= window {
                break;
            }
            requests.pop_front();
        }
    }
}
213
214pub type ClientIdExtractor =
216 Box<dyn Fn(&McpContext, &JsonRpcRequest) -> Option<String> + Send + Sync>;
217
218pub struct RateLimitingMiddleware {
233 max_requests_per_second: f64,
235 burst_capacity: usize,
237 get_client_id: Option<ClientIdExtractor>,
239 global_limit: bool,
241 limiters: Mutex<HashMap<String, TokenBucketRateLimiter>>,
243 global_limiter: Option<TokenBucketRateLimiter>,
245}
246
247impl std::fmt::Debug for RateLimitingMiddleware {
248 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249 f.debug_struct("RateLimitingMiddleware")
250 .field("max_requests_per_second", &self.max_requests_per_second)
251 .field("burst_capacity", &self.burst_capacity)
252 .field("global_limit", &self.global_limit)
253 .finish()
254 }
255}
256
257impl RateLimitingMiddleware {
258 #[must_use]
266 pub fn new(max_requests_per_second: f64) -> Self {
267 let burst_capacity = (max_requests_per_second * 2.0) as usize;
268 Self {
269 max_requests_per_second,
270 burst_capacity,
271 get_client_id: None,
272 global_limit: false,
273 limiters: Mutex::new(HashMap::new()),
274 global_limiter: None,
275 }
276 }
277
278 #[must_use]
280 pub fn burst_capacity(mut self, capacity: usize) -> Self {
281 self.burst_capacity = capacity;
282 if self.global_limit {
284 self.global_limiter = Some(TokenBucketRateLimiter::new(
285 capacity,
286 self.max_requests_per_second,
287 ));
288 }
289 self
290 }
291
292 #[must_use]
296 pub fn client_id_extractor<F>(mut self, extractor: F) -> Self
297 where
298 F: Fn(&McpContext, &JsonRpcRequest) -> Option<String> + Send + Sync + 'static,
299 {
300 self.get_client_id = Some(Box::new(extractor));
301 self
302 }
303
304 #[must_use]
309 pub fn global(mut self) -> Self {
310 self.global_limit = true;
311 self.global_limiter = Some(TokenBucketRateLimiter::new(
312 self.burst_capacity,
313 self.max_requests_per_second,
314 ));
315 self
316 }
317
318 fn get_client_identifier(&self, ctx: &McpContext, request: &JsonRpcRequest) -> String {
319 if let Some(ref extractor) = self.get_client_id {
320 if let Some(id) = extractor(ctx, request) {
321 return id;
322 }
323 }
324 "global".to_string()
325 }
326
327 fn get_or_create_limiter(&self, client_id: &str) -> bool {
328 let mut limiters = self
329 .limiters
330 .lock()
331 .unwrap_or_else(std::sync::PoisonError::into_inner);
332
333 if !limiters.contains_key(client_id) {
334 limiters.insert(
335 client_id.to_string(),
336 TokenBucketRateLimiter::new(self.burst_capacity, self.max_requests_per_second),
337 );
338 }
339
340 limiters.get(client_id).unwrap().try_consume(1)
341 }
342}
343
344impl Middleware for RateLimitingMiddleware {
345 fn on_request(
346 &self,
347 ctx: &McpContext,
348 request: &JsonRpcRequest,
349 ) -> McpResult<MiddlewareDecision> {
350 let allowed = if self.global_limit {
351 if let Some(ref limiter) = self.global_limiter {
353 limiter.try_consume(1)
354 } else {
355 true
356 }
357 } else {
358 let client_id = self.get_client_identifier(ctx, request);
360 self.get_or_create_limiter(&client_id)
361 };
362
363 if allowed {
364 Ok(MiddlewareDecision::Continue)
365 } else {
366 let msg = if self.global_limit {
367 "Global rate limit exceeded".to_string()
368 } else {
369 let client_id = self.get_client_identifier(ctx, request);
370 format!("Rate limit exceeded for client: {client_id}")
371 };
372 Err(rate_limit_error(msg))
373 }
374 }
375}
376
377pub struct SlidingWindowRateLimitingMiddleware {
391 max_requests: usize,
393 window_seconds: u64,
395 get_client_id: Option<ClientIdExtractor>,
397 limiters: Mutex<HashMap<String, SlidingWindowRateLimiter>>,
399}
400
401impl std::fmt::Debug for SlidingWindowRateLimitingMiddleware {
402 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
403 f.debug_struct("SlidingWindowRateLimitingMiddleware")
404 .field("max_requests", &self.max_requests)
405 .field("window_seconds", &self.window_seconds)
406 .finish()
407 }
408}
409
410impl SlidingWindowRateLimitingMiddleware {
411 #[must_use]
418 pub fn new(max_requests: usize, window_seconds: u64) -> Self {
419 Self {
420 max_requests,
421 window_seconds,
422 get_client_id: None,
423 limiters: Mutex::new(HashMap::new()),
424 }
425 }
426
427 #[must_use]
434 pub fn per_minute(max_requests: usize, window_minutes: u64) -> Self {
435 Self::new(max_requests, window_minutes * 60)
436 }
437
438 #[must_use]
440 pub fn client_id_extractor<F>(mut self, extractor: F) -> Self
441 where
442 F: Fn(&McpContext, &JsonRpcRequest) -> Option<String> + Send + Sync + 'static,
443 {
444 self.get_client_id = Some(Box::new(extractor));
445 self
446 }
447
448 fn get_client_identifier(&self, ctx: &McpContext, request: &JsonRpcRequest) -> String {
449 if let Some(ref extractor) = self.get_client_id {
450 if let Some(id) = extractor(ctx, request) {
451 return id;
452 }
453 }
454 "global".to_string()
455 }
456
457 fn is_request_allowed(&self, client_id: &str) -> bool {
458 let mut limiters = self
459 .limiters
460 .lock()
461 .unwrap_or_else(std::sync::PoisonError::into_inner);
462
463 if !limiters.contains_key(client_id) {
464 limiters.insert(
465 client_id.to_string(),
466 SlidingWindowRateLimiter::new(self.max_requests, self.window_seconds),
467 );
468 }
469
470 limiters.get(client_id).unwrap().is_allowed()
471 }
472}
473
474impl Middleware for SlidingWindowRateLimitingMiddleware {
475 fn on_request(
476 &self,
477 ctx: &McpContext,
478 request: &JsonRpcRequest,
479 ) -> McpResult<MiddlewareDecision> {
480 let client_id = self.get_client_identifier(ctx, request);
481 let allowed = self.is_request_allowed(&client_id);
482
483 if allowed {
484 Ok(MiddlewareDecision::Continue)
485 } else {
486 let window_display = if self.window_seconds >= 60 {
487 format!("{} minute(s)", self.window_seconds / 60)
488 } else {
489 format!("{} second(s)", self.window_seconds)
490 };
491 Err(rate_limit_error(format!(
492 "Rate limit exceeded: {} requests per {} for client: {}",
493 self.max_requests, window_display, client_id
494 )))
495 }
496 }
497}
498
#[cfg(test)]
mod tests {
    use super::*;
    use asupersync::Cx;

    /// Builds a fresh `McpContext` from a testing `Cx`.
    // NOTE(review): the second argument `1` presumably is a request id — confirm.
    fn test_context() -> McpContext {
        let cx = Cx::for_testing();
        McpContext::new(cx, 1)
    }

    /// Builds a JSON-RPC request for `method` with no params and numeric id 1.
    fn test_request(method: &str) -> JsonRpcRequest {
        JsonRpcRequest {
            jsonrpc: std::borrow::Cow::Borrowed(fastmcp_protocol::JSONRPC_VERSION),
            method: method.to_string(),
            params: None,
            id: Some(fastmcp_protocol::RequestId::Number(1)),
        }
    }

    // ---- TokenBucketRateLimiter -------------------------------------------------------

    /// A fresh bucket starts full, so the whole capacity is usable immediately.
    #[test]
    fn test_token_bucket_allows_burst() {
        let limiter = TokenBucketRateLimiter::new(5, 1.0);

        assert!(limiter.try_consume(1));
        assert!(limiter.try_consume(1));
        assert!(limiter.try_consume(1));
        assert!(limiter.try_consume(1));
        assert!(limiter.try_consume(1));

        // Capacity exhausted; at 1 token/s too little time has passed to refill a token.
        assert!(!limiter.try_consume(1));
    }

    /// Tokens come back with elapsed time (timing-dependent: relies on a real sleep).
    #[test]
    fn test_token_bucket_refills_over_time() {
        let limiter = TokenBucketRateLimiter::new(2, 100.0);
        assert!(limiter.try_consume(1));
        assert!(limiter.try_consume(1));
        assert!(!limiter.try_consume(1));

        // At 100 tokens/s, 15 ms should refill roughly 1.5 tokens.
        std::thread::sleep(std::time::Duration::from_millis(15));

        assert!(limiter.try_consume(1));
    }

    #[test]
    fn test_token_bucket_available_tokens() {
        let limiter = TokenBucketRateLimiter::new(10, 1.0);
        assert!((limiter.available_tokens() - 10.0).abs() < 0.1);

        limiter.try_consume(5);
        // Tolerance absorbs the tiny refill between the two calls.
        assert!((limiter.available_tokens() - 5.0).abs() < 0.1);
    }

    // ---- SlidingWindowRateLimiter -----------------------------------------------------

    #[test]
    fn test_sliding_window_allows_up_to_limit() {
        let limiter = SlidingWindowRateLimiter::new(3, 60);

        assert!(limiter.is_allowed());
        assert!(limiter.is_allowed());
        assert!(limiter.is_allowed());
        // Fourth request inside the 60 s window exceeds the limit of 3.
        assert!(!limiter.is_allowed());
    }

    #[test]
    fn test_sliding_window_current_requests() {
        let limiter = SlidingWindowRateLimiter::new(10, 60);

        assert_eq!(limiter.current_requests(), 0);
        limiter.is_allowed();
        assert_eq!(limiter.current_requests(), 1);
        limiter.is_allowed();
        assert_eq!(limiter.current_requests(), 2);
    }

    // ---- RateLimitingMiddleware ---------------------------------------------------------

    #[test]
    fn test_rate_limiting_middleware_allows_initial_requests() {
        let middleware = RateLimitingMiddleware::new(10.0).global();
        let ctx = test_context();
        let request = test_request("tools/call");

        let result = middleware.on_request(&ctx, &request);
        assert!(matches!(result, Ok(MiddlewareDecision::Continue)));
    }

    #[test]
    fn test_rate_limiting_middleware_denies_after_burst() {
        let middleware = RateLimitingMiddleware::new(10.0).burst_capacity(2).global();
        let ctx = test_context();
        let request = test_request("tools/call");

        // Burst of 2 is admitted.
        assert!(middleware.on_request(&ctx, &request).is_ok());
        assert!(middleware.on_request(&ctx, &request).is_ok());

        // Third request is rejected with the rate-limit error code and the global message.
        let result = middleware.on_request(&ctx, &request);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(i32::from(err.code), RATE_LIMIT_ERROR_CODE);
        assert!(err.message.contains("Global rate limit exceeded"));
    }

    /// Different client ids (here: the method name) get independent buckets.
    #[test]
    fn test_rate_limiting_middleware_per_client() {
        let middleware = RateLimitingMiddleware::new(10.0)
            .burst_capacity(1)
            .client_id_extractor(|_ctx, req| Some(req.method.clone()));
        let ctx = test_context();

        let request1 = test_request("method_a");
        let request2 = test_request("method_b");

        // Each client's single-token bucket admits one request.
        assert!(middleware.on_request(&ctx, &request1).is_ok());
        assert!(middleware.on_request(&ctx, &request2).is_ok());

        // Both buckets are now empty.
        assert!(middleware.on_request(&ctx, &request1).is_err());
        assert!(middleware.on_request(&ctx, &request2).is_err());
    }

    // ---- SlidingWindowRateLimitingMiddleware ------------------------------------------

    #[test]
    fn test_sliding_window_middleware_allows_up_to_limit() {
        let middleware = SlidingWindowRateLimitingMiddleware::new(2, 60);
        let ctx = test_context();
        let request = test_request("tools/call");

        assert!(middleware.on_request(&ctx, &request).is_ok());
        assert!(middleware.on_request(&ctx, &request).is_ok());

        let result = middleware.on_request(&ctx, &request);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(i32::from(err.code), RATE_LIMIT_ERROR_CODE);
    }

    #[test]
    fn test_sliding_window_middleware_per_minute() {
        let middleware = SlidingWindowRateLimitingMiddleware::per_minute(100, 1);
        let ctx = test_context();
        let request = test_request("tools/call");

        // All 100 requests fit in the one-minute window.
        for _ in 0..100 {
            assert!(middleware.on_request(&ctx, &request).is_ok());
        }

        // The 101st is rejected.
        assert!(middleware.on_request(&ctx, &request).is_err());
    }

    // ---- rate_limit_error / RATE_LIMIT_ERROR_CODE -------------------------------------

    #[test]
    fn test_rate_limit_error_code() {
        let err = rate_limit_error("test");
        assert_eq!(i32::from(err.code), RATE_LIMIT_ERROR_CODE);
        assert_eq!(err.message, "test");
    }

    /// Pins the numeric value so accidental changes to the public constant are caught.
    #[test]
    fn rate_limit_error_code_value() {
        assert_eq!(RATE_LIMIT_ERROR_CODE, -32005);
    }

    /// `rate_limit_error` accepts owned `String`s as well as `&str`.
    #[test]
    fn rate_limit_error_from_string() {
        let err = rate_limit_error(String::from("custom message"));
        assert_eq!(err.message, "custom message");
        assert_eq!(i32::from(err.code), RATE_LIMIT_ERROR_CODE);
    }

    // ---- TokenBucketRateLimiter: edge cases -------------------------------------------

    #[test]
    fn token_bucket_debug() {
        let limiter = TokenBucketRateLimiter::new(10, 5.0);
        let debug = format!("{:?}", limiter);
        assert!(debug.contains("TokenBucketRateLimiter"));
        assert!(debug.contains("10"));
    }

    #[test]
    fn token_bucket_consume_multiple_at_once() {
        let limiter = TokenBucketRateLimiter::new(10, 1.0);
        assert!(limiter.try_consume(5));
        assert!(limiter.try_consume(5));
        // Bucket drained by the two 5-token withdrawals.
        assert!(!limiter.try_consume(1));
    }

    /// A request larger than capacity fails and does not consume anything.
    #[test]
    fn token_bucket_consume_more_than_capacity() {
        let limiter = TokenBucketRateLimiter::new(5, 1.0);
        assert!(!limiter.try_consume(6));
        assert!(limiter.try_consume(5));
    }

    /// Refill never pushes the balance above capacity, even with a very high rate.
    #[test]
    fn token_bucket_available_tokens_caps_at_capacity() {
        let limiter = TokenBucketRateLimiter::new(5, 1000.0);
        std::thread::sleep(std::time::Duration::from_millis(10));
        assert!(limiter.available_tokens() <= 5.0 + 0.1);
    }

    #[test]
    fn token_bucket_available_tokens_after_full_drain() {
        let limiter = TokenBucketRateLimiter::new(3, 1.0);
        limiter.try_consume(3);
        assert!(limiter.available_tokens() < 1.0);
    }

    // ---- SlidingWindowRateLimiter: edge cases -----------------------------------------

    #[test]
    fn sliding_window_debug() {
        let limiter = SlidingWindowRateLimiter::new(100, 60);
        let debug = format!("{:?}", limiter);
        assert!(debug.contains("SlidingWindowRateLimiter"));
        assert!(debug.contains("100"));
    }

    #[test]
    fn sliding_window_current_requests_starts_at_zero() {
        let limiter = SlidingWindowRateLimiter::new(10, 60);
        assert_eq!(limiter.current_requests(), 0);
    }

    /// Rejected requests are not recorded in the window.
    #[test]
    fn sliding_window_denied_request_not_counted() {
        let limiter = SlidingWindowRateLimiter::new(2, 60);
        assert!(limiter.is_allowed());
        assert!(limiter.is_allowed());
        assert!(!limiter.is_allowed());
        assert_eq!(limiter.current_requests(), 2);
    }

    // ---- RateLimitingMiddleware: configuration ----------------------------------------

    /// Default burst capacity is twice the per-second rate.
    #[test]
    fn rate_limiting_middleware_default_burst_capacity() {
        let m = RateLimitingMiddleware::new(10.0);
        assert_eq!(m.burst_capacity, 20);
        assert!(!m.global_limit);
        assert!(m.global_limiter.is_none());
        assert!(m.get_client_id.is_none());
    }

    #[test]
    fn rate_limiting_middleware_debug() {
        let m = RateLimitingMiddleware::new(10.0)
            .burst_capacity(30)
            .global();
        let debug = format!("{:?}", m);
        assert!(debug.contains("RateLimitingMiddleware"));
        assert!(debug.contains("30"));
        // `global_limit: true`.
        assert!(debug.contains("true"));
    }

    #[test]
    fn rate_limiting_middleware_global_creates_limiter() {
        let m = RateLimitingMiddleware::new(5.0).global();
        assert!(m.global_limit);
        assert!(m.global_limiter.is_some());
    }

    /// Outside global mode, changing the burst capacity does not create a shared bucket.
    #[test]
    fn rate_limiting_middleware_burst_capacity_without_global() {
        let m = RateLimitingMiddleware::new(10.0).burst_capacity(50);
        assert!(m.global_limiter.is_none());
        assert_eq!(m.burst_capacity, 50);
    }

    /// Calling `burst_capacity` after `global` rebuilds the shared bucket with the new size.
    #[test]
    fn rate_limiting_middleware_burst_capacity_with_global_recreates_limiter() {
        let m = RateLimitingMiddleware::new(10.0).global().burst_capacity(3);
        assert_eq!(m.burst_capacity, 3);
        assert!(m.global_limiter.is_some());

        let ctx = test_context();
        let req = test_request("test");
        // Exactly 3 requests fit in the rebuilt bucket.
        assert!(m.on_request(&ctx, &req).is_ok());
        assert!(m.on_request(&ctx, &req).is_ok());
        assert!(m.on_request(&ctx, &req).is_ok());
        assert!(m.on_request(&ctx, &req).is_err());
    }

    // ---- RateLimitingMiddleware: client identification --------------------------------

    #[test]
    fn rate_limiting_middleware_no_extractor_uses_global_key() {
        let m = RateLimitingMiddleware::new(10.0);
        let ctx = test_context();
        let req = test_request("tools/call");
        let id = m.get_client_identifier(&ctx, &req);
        assert_eq!(id, "global");
    }

    #[test]
    fn rate_limiting_middleware_extractor_returning_none_uses_global() {
        let m = RateLimitingMiddleware::new(10.0).client_id_extractor(|_ctx, _req| None);
        let ctx = test_context();
        let req = test_request("tools/call");
        let id = m.get_client_identifier(&ctx, &req);
        assert_eq!(id, "global");
    }

    #[test]
    fn rate_limiting_middleware_extractor_returning_some() {
        let m = RateLimitingMiddleware::new(10.0)
            .client_id_extractor(|_ctx, _req| Some("user-42".to_string()));
        let ctx = test_context();
        let req = test_request("tools/call");
        let id = m.get_client_identifier(&ctx, &req);
        assert_eq!(id, "user-42");
    }

    /// Without an extractor, every request shares the `"global"` per-client bucket.
    #[test]
    fn rate_limiting_middleware_per_client_no_extractor_all_share_global_key() {
        let m = RateLimitingMiddleware::new(10.0).burst_capacity(2);
        let ctx = test_context();
        let req_a = test_request("method_a");
        let req_b = test_request("method_b");

        assert!(m.on_request(&ctx, &req_a).is_ok());
        assert!(m.on_request(&ctx, &req_b).is_ok());
        assert!(m.on_request(&ctx, &req_a).is_err());
    }

    // ---- RateLimitingMiddleware: error messages ---------------------------------------

    #[test]
    fn rate_limiting_middleware_error_msg_per_client() {
        let m = RateLimitingMiddleware::new(10.0)
            .burst_capacity(1)
            .client_id_extractor(|_ctx, _req| Some("alice".to_string()));
        let ctx = test_context();
        let req = test_request("tools/call");

        m.on_request(&ctx, &req).unwrap();
        let err = m.on_request(&ctx, &req).unwrap_err();
        assert!(
            err.message
                .contains("Rate limit exceeded for client: alice")
        );
    }

    #[test]
    fn rate_limiting_middleware_error_msg_global() {
        let m = RateLimitingMiddleware::new(10.0).burst_capacity(1).global();
        let ctx = test_context();
        let req = test_request("tools/call");

        m.on_request(&ctx, &req).unwrap();
        let err = m.on_request(&ctx, &req).unwrap_err();
        assert!(err.message.contains("Global rate limit exceeded"));
    }

    // ---- SlidingWindowRateLimitingMiddleware: configuration ---------------------------

    #[test]
    fn sliding_window_middleware_new_fields() {
        let m = SlidingWindowRateLimitingMiddleware::new(50, 120);
        assert_eq!(m.max_requests, 50);
        assert_eq!(m.window_seconds, 120);
        assert!(m.get_client_id.is_none());
    }

    /// `per_minute` converts minutes to seconds.
    #[test]
    fn sliding_window_middleware_per_minute_converts() {
        let m = SlidingWindowRateLimitingMiddleware::per_minute(100, 5);
        assert_eq!(m.max_requests, 100);
        assert_eq!(m.window_seconds, 300);
    }

    #[test]
    fn sliding_window_middleware_debug() {
        let m = SlidingWindowRateLimitingMiddleware::new(50, 120);
        let debug = format!("{:?}", m);
        assert!(debug.contains("SlidingWindowRateLimitingMiddleware"));
        assert!(debug.contains("50"));
        assert!(debug.contains("120"));
    }

    // ---- SlidingWindowRateLimitingMiddleware: client identification -------------------

    #[test]
    fn sliding_window_middleware_no_extractor_uses_global() {
        let m = SlidingWindowRateLimitingMiddleware::new(10, 60);
        let ctx = test_context();
        let req = test_request("tools/call");
        let id = m.get_client_identifier(&ctx, &req);
        assert_eq!(id, "global");
    }

    #[test]
    fn sliding_window_middleware_extractor_returning_none_uses_global() {
        let m =
            SlidingWindowRateLimitingMiddleware::new(10, 60).client_id_extractor(|_ctx, _req| None);
        let ctx = test_context();
        let req = test_request("tools/call");
        let id = m.get_client_identifier(&ctx, &req);
        assert_eq!(id, "global");
    }

    #[test]
    fn sliding_window_middleware_extractor_returning_some() {
        let m = SlidingWindowRateLimitingMiddleware::new(10, 60)
            .client_id_extractor(|_ctx, _req| Some("bob".to_string()));
        let ctx = test_context();
        let req = test_request("tools/call");
        let id = m.get_client_identifier(&ctx, &req);
        assert_eq!(id, "bob");
    }

    /// Different client ids (here: the method name) get independent windows.
    #[test]
    fn sliding_window_middleware_per_client() {
        let m = SlidingWindowRateLimitingMiddleware::new(1, 60)
            .client_id_extractor(|_ctx, req| Some(req.method.clone()));
        let ctx = test_context();
        let req_a = test_request("method_a");
        let req_b = test_request("method_b");

        assert!(m.on_request(&ctx, &req_a).is_ok());
        assert!(m.on_request(&ctx, &req_b).is_ok());

        assert!(m.on_request(&ctx, &req_a).is_err());
        assert!(m.on_request(&ctx, &req_b).is_err());
    }

    // ---- SlidingWindowRateLimitingMiddleware: error messages --------------------------

    /// Windows shorter than a minute are reported in seconds.
    #[test]
    fn sliding_window_middleware_error_msg_seconds() {
        let m = SlidingWindowRateLimitingMiddleware::new(1, 30);
        let ctx = test_context();
        let req = test_request("tools/call");

        m.on_request(&ctx, &req).unwrap();
        let err = m.on_request(&ctx, &req).unwrap_err();
        assert!(err.message.contains("30 second(s)"));
        assert!(err.message.contains("client: global"));
    }

    #[test]
    fn sliding_window_middleware_error_msg_minutes() {
        let m = SlidingWindowRateLimitingMiddleware::new(1, 120);
        let ctx = test_context();
        let req = test_request("tools/call");

        m.on_request(&ctx, &req).unwrap();
        let err = m.on_request(&ctx, &req).unwrap_err();
        assert!(err.message.contains("2 minute(s)"));
    }

    #[test]
    fn sliding_window_middleware_error_msg_with_client_id() {
        let m = SlidingWindowRateLimitingMiddleware::new(1, 60)
            .client_id_extractor(|_ctx, _req| Some("alice".to_string()));
        let ctx = test_context();
        let req = test_request("tools/call");

        m.on_request(&ctx, &req).unwrap();
        let err = m.on_request(&ctx, &req).unwrap_err();
        assert!(err.message.contains("client: alice"));
        assert_eq!(i32::from(err.code), RATE_LIMIT_ERROR_CODE);
    }

    // ---- Private helpers --------------------------------------------------------------

    /// The first call for an unknown client creates a full bucket of `burst_capacity`.
    #[test]
    fn rate_limiting_middleware_get_or_create_limiter_creates_new() {
        let m = RateLimitingMiddleware::new(10.0).burst_capacity(2);
        assert!(m.get_or_create_limiter("new-client"));
        assert!(m.get_or_create_limiter("new-client"));
        assert!(!m.get_or_create_limiter("new-client"));
    }

    /// Windows are created lazily per client and are independent of each other.
    #[test]
    fn sliding_window_middleware_is_request_allowed_creates_new() {
        let m = SlidingWindowRateLimitingMiddleware::new(2, 60);
        assert!(m.is_request_allowed("c1"));
        assert!(m.is_request_allowed("c1"));
        assert!(!m.is_request_allowed("c1"));

        assert!(m.is_request_allowed("c2"));
    }

    // ---- Time-based behavior (uses real sleeps; slow) ---------------------------------

    #[test]
    fn sliding_window_requests_expire_after_window() {
        let limiter = SlidingWindowRateLimiter::new(2, 1);
        assert!(limiter.is_allowed());
        assert!(limiter.is_allowed());
        assert!(!limiter.is_allowed());
        // Sleep past the 1 s window so both recorded requests expire.
        std::thread::sleep(std::time::Duration::from_millis(1100));

        assert!(limiter.is_allowed());
    }

    #[test]
    fn sliding_window_current_requests_resets_after_window() {
        let limiter = SlidingWindowRateLimiter::new(5, 1);
        limiter.is_allowed();
        limiter.is_allowed();
        assert_eq!(limiter.current_requests(), 2);

        std::thread::sleep(std::time::Duration::from_millis(1100));

        assert_eq!(limiter.current_requests(), 0);
    }

    /// The minutes/seconds switch is inclusive at exactly 60 seconds.
    #[test]
    fn sliding_window_error_exactly_60_seconds_shows_minutes() {
        let m = SlidingWindowRateLimitingMiddleware::new(1, 60);
        let ctx = test_context();
        let req = test_request("tools/call");

        m.on_request(&ctx, &req).unwrap();
        let err = m.on_request(&ctx, &req).unwrap_err();
        assert!(
            err.message.contains("1 minute(s)"),
            "60 seconds should display as minutes: {}",
            err.message
        );
    }

    /// Consuming zero tokens succeeds even from an empty bucket.
    #[test]
    fn token_bucket_try_consume_zero_always_succeeds() {
        let limiter = TokenBucketRateLimiter::new(3, 1.0);
        limiter.try_consume(3);
        assert!(!limiter.try_consume(1));
        assert!(limiter.try_consume(0));
    }

    /// With a zero refill rate the bucket never recovers.
    #[test]
    fn token_bucket_refill_rate_zero_never_refills() {
        let limiter = TokenBucketRateLimiter::new(2, 0.0);
        assert!(limiter.try_consume(2));
        assert!(!limiter.try_consume(1));

        std::thread::sleep(std::time::Duration::from_millis(50));
        assert!(!limiter.try_consume(1));
    }
}