// fastmcp_server/rate_limiting.rs
//! Rate limiting middleware for protecting FastMCP servers from abuse.
//!
//! This module provides two rate limiting strategies:
//!
//! - [`RateLimitingMiddleware`]: Token bucket algorithm for burst-friendly limits
//! - [`SlidingWindowRateLimitingMiddleware`]: Sliding window for precise tracking
//!
//! # Example
//!
//! ```ignore
//! use fastmcp_rust::prelude::*;
//! use fastmcp_server::rate_limiting::RateLimitingMiddleware;
//!
//! // Allow 10 requests per second with bursts up to 20
//! let rate_limiter = RateLimitingMiddleware::new(10.0)
//!     .burst_capacity(20);
//!
//! Server::new("my-server", "1.0.0")
//!     .middleware(rate_limiter)
//!     .run_stdio();
//! ```

23use std::collections::{HashMap, VecDeque};
24use std::sync::Mutex;
25use std::time::Instant;
26
27use fastmcp_core::{McpContext, McpError, McpErrorCode, McpResult};
28use fastmcp_protocol::JsonRpcRequest;
29
30use crate::{Middleware, MiddlewareDecision};
31
32/// Error code for rate limit exceeded (-32005).
33///
34/// This is in the MCP server error range (-32000 to -32099).
35pub const RATE_LIMIT_ERROR_CODE: i32 = -32005;
36
37/// Creates a rate limit exceeded error.
38#[must_use]
39pub fn rate_limit_error(message: impl Into<String>) -> McpError {
40    McpError::new(McpErrorCode::Custom(RATE_LIMIT_ERROR_CODE), message)
41}
42
/// Token bucket implementation for rate limiting.
///
/// The token bucket algorithm allows for burst traffic while maintaining
/// a sustainable long-term rate. Tokens are added at a constant rate and
/// consumed when requests arrive.
#[derive(Debug)]
pub struct TokenBucketRateLimiter {
    /// Maximum number of tokens in the bucket.
    capacity: usize,
    /// Tokens added per second.
    refill_rate: f64,
    /// Current number of tokens (as f64 for fractional tokens).
    tokens: Mutex<f64>,
    /// Last time tokens were refilled.
    last_refill: Mutex<Instant>,
}

impl TokenBucketRateLimiter {
    /// Creates a new token bucket rate limiter.
    ///
    /// # Arguments
    ///
    /// * `capacity` - Maximum number of tokens (burst capacity)
    /// * `refill_rate` - Tokens added per second (sustained rate)
    ///
    /// The bucket starts full, so an initial burst of up to `capacity`
    /// tokens is available immediately.
    #[must_use]
    pub fn new(capacity: usize, refill_rate: f64) -> Self {
        Self {
            capacity,
            refill_rate,
            tokens: Mutex::new(capacity as f64),
            last_refill: Mutex::new(Instant::now()),
        }
    }

    /// Refills the bucket based on elapsed wall-clock time and returns the
    /// still-locked token counter.
    ///
    /// Shared by [`Self::try_consume`] and [`Self::available_tokens`] so the
    /// refill logic (and the `tokens` → `last_refill` lock order) lives in
    /// exactly one place. Poisoned locks are recovered with `into_inner`
    /// since the guarded data is always left in a valid state.
    fn refill(&self) -> std::sync::MutexGuard<'_, f64> {
        let mut tokens = self
            .tokens
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let mut last_refill = self
            .last_refill
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        let now = Instant::now();
        let elapsed = now.duration_since(*last_refill).as_secs_f64();

        // Add tokens for the elapsed time, capped at bucket capacity.
        *tokens = (*tokens + elapsed * self.refill_rate).min(self.capacity as f64);
        *last_refill = now;
        tokens
    }

    /// Tries to consume tokens from the bucket.
    ///
    /// Returns `true` if tokens were available and consumed, `false`
    /// otherwise. A failed attempt consumes nothing.
    pub fn try_consume(&self, tokens: usize) -> bool {
        let mut available = self.refill();
        let needed = tokens as f64;
        if *available >= needed {
            *available -= needed;
            true
        } else {
            false
        }
    }

    /// Returns the current number of available tokens (after refilling,
    /// without consuming any).
    #[must_use]
    pub fn available_tokens(&self) -> f64 {
        *self.refill()
    }
}
128
/// Sliding window rate limiter implementation.
///
/// Tracks individual request timestamps within a time window for precise
/// rate limiting. More memory-intensive than token bucket but provides
/// exact request counting.
#[derive(Debug)]
pub struct SlidingWindowRateLimiter {
    /// Maximum requests allowed in the time window.
    max_requests: usize,
    /// Time window in seconds.
    window_seconds: u64,
    /// Timestamps of admitted requests within the window, oldest first.
    requests: Mutex<VecDeque<Instant>>,
}

impl SlidingWindowRateLimiter {
    /// Creates a new sliding window rate limiter.
    ///
    /// # Arguments
    ///
    /// * `max_requests` - Maximum requests allowed in the time window
    /// * `window_seconds` - Time window duration in seconds
    #[must_use]
    pub fn new(max_requests: usize, window_seconds: u64) -> Self {
        Self {
            max_requests,
            window_seconds,
            requests: Mutex::new(VecDeque::new()),
        }
    }

    /// Drops timestamps that fell out of the window and returns the
    /// still-locked request queue.
    ///
    /// Uses `checked_sub` because `Instant - Duration` panics if the result
    /// would precede the platform's monotonic-clock epoch (possible shortly
    /// after boot with a large window); in that case nothing can be stale,
    /// so nothing is pruned.
    fn prune_expired(&self) -> std::sync::MutexGuard<'_, VecDeque<Instant>> {
        let mut requests = self
            .requests
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        if let Some(cutoff) =
            Instant::now().checked_sub(std::time::Duration::from_secs(self.window_seconds))
        {
            // Timestamps are stored oldest-first, so stop at the first
            // entry still inside the window.
            while let Some(&oldest) = requests.front() {
                if oldest < cutoff {
                    requests.pop_front();
                } else {
                    break;
                }
            }
        }
        requests
    }

    /// Checks if a request is allowed under the rate limit.
    ///
    /// If allowed, records the request timestamp and returns `true`.
    /// Otherwise returns `false`; denied requests are not recorded.
    pub fn is_allowed(&self) -> bool {
        let mut requests = self.prune_expired();
        if requests.len() < self.max_requests {
            requests.push_back(Instant::now());
            true
        } else {
            false
        }
    }

    /// Returns the current number of requests in the window.
    #[must_use]
    pub fn current_requests(&self) -> usize {
        self.prune_expired().len()
    }
}
213
/// Function type for extracting a per-client identifier from the request
/// context.
///
/// Returning `None` makes the request fall back to the shared "global"
/// rate-limit bucket.
pub type ClientIdExtractor =
    Box<dyn Fn(&McpContext, &JsonRpcRequest) -> Option<String> + Send + Sync>;
217
/// Rate limiting middleware using token bucket algorithm.
///
/// Uses a token bucket algorithm by default, allowing for burst traffic
/// while maintaining a sustainable long-term rate.
///
/// # Example
///
/// ```ignore
/// use fastmcp_server::rate_limiting::RateLimitingMiddleware;
///
/// // Allow 10 requests per second with bursts up to 20
/// let rate_limiter = RateLimitingMiddleware::new(10.0)
///     .burst_capacity(20);
/// ```
pub struct RateLimitingMiddleware {
    /// Sustained requests per second allowed.
    max_requests_per_second: f64,
    /// Maximum burst capacity (bucket size for each token bucket).
    burst_capacity: usize,
    /// Function to extract client ID from context (for per-client limiting).
    /// `None` means all clients share the "global" bucket.
    get_client_id: Option<ClientIdExtractor>,
    /// If true, apply limit globally; if false, per-client.
    global_limit: bool,
    /// Lazily-populated per-client token buckets, keyed by client ID.
    limiters: Mutex<HashMap<String, TokenBucketRateLimiter>>,
    /// Global rate limiter (populated only when `global_limit` is true).
    global_limiter: Option<TokenBucketRateLimiter>,
}
246
247impl std::fmt::Debug for RateLimitingMiddleware {
248    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249        f.debug_struct("RateLimitingMiddleware")
250            .field("max_requests_per_second", &self.max_requests_per_second)
251            .field("burst_capacity", &self.burst_capacity)
252            .field("global_limit", &self.global_limit)
253            .finish()
254    }
255}
256
257impl RateLimitingMiddleware {
258    /// Creates a new rate limiting middleware with the specified rate.
259    ///
260    /// # Arguments
261    ///
262    /// * `max_requests_per_second` - Sustained requests per second allowed
263    ///
264    /// Burst capacity defaults to 2x the sustained rate.
265    #[must_use]
266    pub fn new(max_requests_per_second: f64) -> Self {
267        let burst_capacity = (max_requests_per_second * 2.0) as usize;
268        Self {
269            max_requests_per_second,
270            burst_capacity,
271            get_client_id: None,
272            global_limit: false,
273            limiters: Mutex::new(HashMap::new()),
274            global_limiter: None,
275        }
276    }
277
278    /// Sets the burst capacity (maximum tokens in the bucket).
279    #[must_use]
280    pub fn burst_capacity(mut self, capacity: usize) -> Self {
281        self.burst_capacity = capacity;
282        // Re-create global limiter if it exists
283        if self.global_limit {
284            self.global_limiter = Some(TokenBucketRateLimiter::new(
285                capacity,
286                self.max_requests_per_second,
287            ));
288        }
289        self
290    }
291
292    /// Sets a custom function to extract client ID from the request context.
293    ///
294    /// If not set, all clients share a single rate limit (global limiting).
295    #[must_use]
296    pub fn client_id_extractor<F>(mut self, extractor: F) -> Self
297    where
298        F: Fn(&McpContext, &JsonRpcRequest) -> Option<String> + Send + Sync + 'static,
299    {
300        self.get_client_id = Some(Box::new(extractor));
301        self
302    }
303
304    /// Enables global rate limiting (all clients share one limit).
305    ///
306    /// When enabled, all requests count against a single rate limit
307    /// regardless of client identity.
308    #[must_use]
309    pub fn global(mut self) -> Self {
310        self.global_limit = true;
311        self.global_limiter = Some(TokenBucketRateLimiter::new(
312            self.burst_capacity,
313            self.max_requests_per_second,
314        ));
315        self
316    }
317
318    fn get_client_identifier(&self, ctx: &McpContext, request: &JsonRpcRequest) -> String {
319        if let Some(ref extractor) = self.get_client_id {
320            if let Some(id) = extractor(ctx, request) {
321                return id;
322            }
323        }
324        "global".to_string()
325    }
326
327    fn get_or_create_limiter(&self, client_id: &str) -> bool {
328        let mut limiters = self
329            .limiters
330            .lock()
331            .unwrap_or_else(std::sync::PoisonError::into_inner);
332
333        if !limiters.contains_key(client_id) {
334            limiters.insert(
335                client_id.to_string(),
336                TokenBucketRateLimiter::new(self.burst_capacity, self.max_requests_per_second),
337            );
338        }
339
340        limiters.get(client_id).unwrap().try_consume(1)
341    }
342}
343
344impl Middleware for RateLimitingMiddleware {
345    fn on_request(
346        &self,
347        ctx: &McpContext,
348        request: &JsonRpcRequest,
349    ) -> McpResult<MiddlewareDecision> {
350        let allowed = if self.global_limit {
351            // Global rate limiting
352            if let Some(ref limiter) = self.global_limiter {
353                limiter.try_consume(1)
354            } else {
355                true
356            }
357        } else {
358            // Per-client rate limiting
359            let client_id = self.get_client_identifier(ctx, request);
360            self.get_or_create_limiter(&client_id)
361        };
362
363        if allowed {
364            Ok(MiddlewareDecision::Continue)
365        } else {
366            let msg = if self.global_limit {
367                "Global rate limit exceeded".to_string()
368            } else {
369                let client_id = self.get_client_identifier(ctx, request);
370                format!("Rate limit exceeded for client: {client_id}")
371            };
372            Err(rate_limit_error(msg))
373        }
374    }
375}
376
/// Rate limiting middleware using sliding window algorithm.
///
/// Uses a sliding window approach which provides more precise rate limiting
/// but uses more memory to track individual request timestamps.
///
/// # Example
///
/// ```ignore
/// use fastmcp_server::rate_limiting::SlidingWindowRateLimitingMiddleware;
///
/// // Allow 100 requests per minute
/// let rate_limiter = SlidingWindowRateLimitingMiddleware::new(100, 60);
/// ```
pub struct SlidingWindowRateLimitingMiddleware {
    /// Maximum requests allowed in the time window.
    max_requests: usize,
    /// Time window in seconds.
    window_seconds: u64,
    /// Function to extract client ID from context.
    /// `None` means all clients share the "global" window.
    get_client_id: Option<ClientIdExtractor>,
    /// Lazily-populated per-client sliding windows, keyed by client ID.
    limiters: Mutex<HashMap<String, SlidingWindowRateLimiter>>,
}
400
401impl std::fmt::Debug for SlidingWindowRateLimitingMiddleware {
402    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
403        f.debug_struct("SlidingWindowRateLimitingMiddleware")
404            .field("max_requests", &self.max_requests)
405            .field("window_seconds", &self.window_seconds)
406            .finish()
407    }
408}
409
410impl SlidingWindowRateLimitingMiddleware {
411    /// Creates a new sliding window rate limiting middleware.
412    ///
413    /// # Arguments
414    ///
415    /// * `max_requests` - Maximum requests allowed in the time window
416    /// * `window_seconds` - Time window duration in seconds
417    #[must_use]
418    pub fn new(max_requests: usize, window_seconds: u64) -> Self {
419        Self {
420            max_requests,
421            window_seconds,
422            get_client_id: None,
423            limiters: Mutex::new(HashMap::new()),
424        }
425    }
426
427    /// Creates a sliding window rate limiter with minutes-based window.
428    ///
429    /// # Arguments
430    ///
431    /// * `max_requests` - Maximum requests allowed in the time window
432    /// * `window_minutes` - Time window duration in minutes
433    #[must_use]
434    pub fn per_minute(max_requests: usize, window_minutes: u64) -> Self {
435        Self::new(max_requests, window_minutes * 60)
436    }
437
438    /// Sets a custom function to extract client ID from the request context.
439    #[must_use]
440    pub fn client_id_extractor<F>(mut self, extractor: F) -> Self
441    where
442        F: Fn(&McpContext, &JsonRpcRequest) -> Option<String> + Send + Sync + 'static,
443    {
444        self.get_client_id = Some(Box::new(extractor));
445        self
446    }
447
448    fn get_client_identifier(&self, ctx: &McpContext, request: &JsonRpcRequest) -> String {
449        if let Some(ref extractor) = self.get_client_id {
450            if let Some(id) = extractor(ctx, request) {
451                return id;
452            }
453        }
454        "global".to_string()
455    }
456
457    fn is_request_allowed(&self, client_id: &str) -> bool {
458        let mut limiters = self
459            .limiters
460            .lock()
461            .unwrap_or_else(std::sync::PoisonError::into_inner);
462
463        if !limiters.contains_key(client_id) {
464            limiters.insert(
465                client_id.to_string(),
466                SlidingWindowRateLimiter::new(self.max_requests, self.window_seconds),
467            );
468        }
469
470        limiters.get(client_id).unwrap().is_allowed()
471    }
472}
473
474impl Middleware for SlidingWindowRateLimitingMiddleware {
475    fn on_request(
476        &self,
477        ctx: &McpContext,
478        request: &JsonRpcRequest,
479    ) -> McpResult<MiddlewareDecision> {
480        let client_id = self.get_client_identifier(ctx, request);
481        let allowed = self.is_request_allowed(&client_id);
482
483        if allowed {
484            Ok(MiddlewareDecision::Continue)
485        } else {
486            let window_display = if self.window_seconds >= 60 {
487                format!("{} minute(s)", self.window_seconds / 60)
488            } else {
489                format!("{} second(s)", self.window_seconds)
490            };
491            Err(rate_limit_error(format!(
492                "Rate limit exceeded: {} requests per {} for client: {}",
493                self.max_requests, window_display, client_id
494            )))
495        }
496    }
497}
498
499#[cfg(test)]
500mod tests {
501    use super::*;
502    use asupersync::Cx;
503
504    fn test_context() -> McpContext {
505        let cx = Cx::for_testing();
506        McpContext::new(cx, 1)
507    }
508
509    fn test_request(method: &str) -> JsonRpcRequest {
510        JsonRpcRequest {
511            jsonrpc: std::borrow::Cow::Borrowed(fastmcp_protocol::JSONRPC_VERSION),
512            method: method.to_string(),
513            params: None,
514            id: Some(fastmcp_protocol::RequestId::Number(1)),
515        }
516    }
517
518    // ========================================
519    // TokenBucketRateLimiter tests
520    // ========================================
521
522    #[test]
523    fn test_token_bucket_allows_burst() {
524        let limiter = TokenBucketRateLimiter::new(5, 1.0);
525
526        // Should allow burst up to capacity
527        assert!(limiter.try_consume(1));
528        assert!(limiter.try_consume(1));
529        assert!(limiter.try_consume(1));
530        assert!(limiter.try_consume(1));
531        assert!(limiter.try_consume(1));
532
533        // Should deny once capacity exhausted
534        assert!(!limiter.try_consume(1));
535    }
536
537    #[test]
538    fn test_token_bucket_refills_over_time() {
539        let limiter = TokenBucketRateLimiter::new(2, 100.0); // 100 tokens per second
540
541        // Exhaust tokens
542        assert!(limiter.try_consume(1));
543        assert!(limiter.try_consume(1));
544        assert!(!limiter.try_consume(1));
545
546        // Wait for refill (10ms should add ~1 token at 100 t/s)
547        std::thread::sleep(std::time::Duration::from_millis(15));
548
549        // Should have refilled
550        assert!(limiter.try_consume(1));
551    }
552
553    #[test]
554    fn test_token_bucket_available_tokens() {
555        let limiter = TokenBucketRateLimiter::new(10, 1.0);
556        assert!((limiter.available_tokens() - 10.0).abs() < 0.1);
557
558        limiter.try_consume(5);
559        assert!((limiter.available_tokens() - 5.0).abs() < 0.1);
560    }
561
562    // ========================================
563    // SlidingWindowRateLimiter tests
564    // ========================================
565
566    #[test]
567    fn test_sliding_window_allows_up_to_limit() {
568        let limiter = SlidingWindowRateLimiter::new(3, 60);
569
570        assert!(limiter.is_allowed());
571        assert!(limiter.is_allowed());
572        assert!(limiter.is_allowed());
573        assert!(!limiter.is_allowed()); // Fourth request denied
574    }
575
576    #[test]
577    fn test_sliding_window_current_requests() {
578        let limiter = SlidingWindowRateLimiter::new(10, 60);
579
580        assert_eq!(limiter.current_requests(), 0);
581        limiter.is_allowed();
582        assert_eq!(limiter.current_requests(), 1);
583        limiter.is_allowed();
584        assert_eq!(limiter.current_requests(), 2);
585    }
586
587    // ========================================
588    // RateLimitingMiddleware tests
589    // ========================================
590
591    #[test]
592    fn test_rate_limiting_middleware_allows_initial_requests() {
593        let middleware = RateLimitingMiddleware::new(10.0).global();
594        let ctx = test_context();
595        let request = test_request("tools/call");
596
597        let result = middleware.on_request(&ctx, &request);
598        assert!(matches!(result, Ok(MiddlewareDecision::Continue)));
599    }
600
601    #[test]
602    fn test_rate_limiting_middleware_denies_after_burst() {
603        let middleware = RateLimitingMiddleware::new(10.0).burst_capacity(2).global();
604        let ctx = test_context();
605        let request = test_request("tools/call");
606
607        // First two should succeed (burst capacity = 2)
608        assert!(middleware.on_request(&ctx, &request).is_ok());
609        assert!(middleware.on_request(&ctx, &request).is_ok());
610
611        // Third should fail
612        let result = middleware.on_request(&ctx, &request);
613        assert!(result.is_err());
614        let err = result.unwrap_err();
615        assert_eq!(i32::from(err.code), RATE_LIMIT_ERROR_CODE);
616        assert!(err.message.contains("Global rate limit exceeded"));
617    }
618
619    #[test]
620    fn test_rate_limiting_middleware_per_client() {
621        let middleware = RateLimitingMiddleware::new(10.0)
622            .burst_capacity(1)
623            .client_id_extractor(|_ctx, req| Some(req.method.clone()));
624        let ctx = test_context();
625
626        let request1 = test_request("method_a");
627        let request2 = test_request("method_b");
628
629        // Each "client" (method) gets their own bucket
630        assert!(middleware.on_request(&ctx, &request1).is_ok());
631        assert!(middleware.on_request(&ctx, &request2).is_ok());
632
633        // Now both are exhausted
634        assert!(middleware.on_request(&ctx, &request1).is_err());
635        assert!(middleware.on_request(&ctx, &request2).is_err());
636    }
637
638    // ========================================
639    // SlidingWindowRateLimitingMiddleware tests
640    // ========================================
641
642    #[test]
643    fn test_sliding_window_middleware_allows_up_to_limit() {
644        let middleware = SlidingWindowRateLimitingMiddleware::new(2, 60);
645        let ctx = test_context();
646        let request = test_request("tools/call");
647
648        assert!(middleware.on_request(&ctx, &request).is_ok());
649        assert!(middleware.on_request(&ctx, &request).is_ok());
650
651        let result = middleware.on_request(&ctx, &request);
652        assert!(result.is_err());
653        let err = result.unwrap_err();
654        assert_eq!(i32::from(err.code), RATE_LIMIT_ERROR_CODE);
655    }
656
657    #[test]
658    fn test_sliding_window_middleware_per_minute() {
659        let middleware = SlidingWindowRateLimitingMiddleware::per_minute(100, 1);
660        let ctx = test_context();
661        let request = test_request("tools/call");
662
663        // Should allow many requests
664        for _ in 0..100 {
665            assert!(middleware.on_request(&ctx, &request).is_ok());
666        }
667
668        // 101st should fail
669        assert!(middleware.on_request(&ctx, &request).is_err());
670    }
671
672    #[test]
673    fn test_rate_limit_error_code() {
674        let err = rate_limit_error("test");
675        assert_eq!(i32::from(err.code), RATE_LIMIT_ERROR_CODE);
676        assert_eq!(err.message, "test");
677    }
678
679    // ========================================
680    // rate_limit_error / RATE_LIMIT_ERROR_CODE
681    // ========================================
682
683    #[test]
684    fn rate_limit_error_code_value() {
685        assert_eq!(RATE_LIMIT_ERROR_CODE, -32005);
686    }
687
688    #[test]
689    fn rate_limit_error_from_string() {
690        let err = rate_limit_error(String::from("custom message"));
691        assert_eq!(err.message, "custom message");
692        assert_eq!(i32::from(err.code), RATE_LIMIT_ERROR_CODE);
693    }
694
695    // ========================================
696    // TokenBucketRateLimiter — additional
697    // ========================================
698
699    #[test]
700    fn token_bucket_debug() {
701        let limiter = TokenBucketRateLimiter::new(10, 5.0);
702        let debug = format!("{:?}", limiter);
703        assert!(debug.contains("TokenBucketRateLimiter"));
704        assert!(debug.contains("10"));
705    }
706
707    #[test]
708    fn token_bucket_consume_multiple_at_once() {
709        let limiter = TokenBucketRateLimiter::new(10, 1.0);
710        // Consume 5 at once — should succeed
711        assert!(limiter.try_consume(5));
712        // Consume another 5 — should succeed (exactly 10 tokens)
713        assert!(limiter.try_consume(5));
714        // No tokens left
715        assert!(!limiter.try_consume(1));
716    }
717
718    #[test]
719    fn token_bucket_consume_more_than_capacity() {
720        let limiter = TokenBucketRateLimiter::new(5, 1.0);
721        // Request more than capacity — should fail immediately
722        assert!(!limiter.try_consume(6));
723        // Bucket still has tokens (nothing was consumed on failure)
724        assert!(limiter.try_consume(5));
725    }
726
727    #[test]
728    fn token_bucket_available_tokens_caps_at_capacity() {
729        let limiter = TokenBucketRateLimiter::new(5, 1000.0); // Very high refill
730        // Even with high refill rate, wait a bit — should not exceed capacity
731        std::thread::sleep(std::time::Duration::from_millis(10));
732        assert!(limiter.available_tokens() <= 5.0 + 0.1);
733    }
734
735    #[test]
736    fn token_bucket_available_tokens_after_full_drain() {
737        let limiter = TokenBucketRateLimiter::new(3, 1.0);
738        limiter.try_consume(3);
739        assert!(limiter.available_tokens() < 1.0);
740    }
741
742    // ========================================
743    // SlidingWindowRateLimiter — additional
744    // ========================================
745
746    #[test]
747    fn sliding_window_debug() {
748        let limiter = SlidingWindowRateLimiter::new(100, 60);
749        let debug = format!("{:?}", limiter);
750        assert!(debug.contains("SlidingWindowRateLimiter"));
751        assert!(debug.contains("100"));
752    }
753
754    #[test]
755    fn sliding_window_current_requests_starts_at_zero() {
756        let limiter = SlidingWindowRateLimiter::new(10, 60);
757        assert_eq!(limiter.current_requests(), 0);
758    }
759
760    #[test]
761    fn sliding_window_denied_request_not_counted() {
762        let limiter = SlidingWindowRateLimiter::new(2, 60);
763        assert!(limiter.is_allowed());
764        assert!(limiter.is_allowed());
765        assert!(!limiter.is_allowed()); // denied
766        // Only 2 requests counted (not the denied one)
767        assert_eq!(limiter.current_requests(), 2);
768    }
769
770    // ========================================
771    // RateLimitingMiddleware — construction/Debug
772    // ========================================
773
774    #[test]
775    fn rate_limiting_middleware_default_burst_capacity() {
776        let m = RateLimitingMiddleware::new(10.0);
777        // Default burst capacity is 2x rate = 20
778        assert_eq!(m.burst_capacity, 20);
779        assert!(!m.global_limit);
780        assert!(m.global_limiter.is_none());
781        assert!(m.get_client_id.is_none());
782    }
783
784    #[test]
785    fn rate_limiting_middleware_debug() {
786        let m = RateLimitingMiddleware::new(10.0)
787            .burst_capacity(30)
788            .global();
789        let debug = format!("{:?}", m);
790        assert!(debug.contains("RateLimitingMiddleware"));
791        assert!(debug.contains("30"));
792        assert!(debug.contains("true")); // global_limit
793    }
794
795    #[test]
796    fn rate_limiting_middleware_global_creates_limiter() {
797        let m = RateLimitingMiddleware::new(5.0).global();
798        assert!(m.global_limit);
799        assert!(m.global_limiter.is_some());
800    }
801
802    #[test]
803    fn rate_limiting_middleware_burst_capacity_without_global() {
804        let m = RateLimitingMiddleware::new(10.0).burst_capacity(50);
805        // No global limiter created when not in global mode
806        assert!(m.global_limiter.is_none());
807        assert_eq!(m.burst_capacity, 50);
808    }
809
810    #[test]
811    fn rate_limiting_middleware_burst_capacity_with_global_recreates_limiter() {
812        let m = RateLimitingMiddleware::new(10.0).global().burst_capacity(3);
813        assert_eq!(m.burst_capacity, 3);
814        // Global limiter should exist with new capacity
815        assert!(m.global_limiter.is_some());
816
817        let ctx = test_context();
818        let req = test_request("test");
819        // Should allow exactly 3 requests (burst capacity)
820        assert!(m.on_request(&ctx, &req).is_ok());
821        assert!(m.on_request(&ctx, &req).is_ok());
822        assert!(m.on_request(&ctx, &req).is_ok());
823        assert!(m.on_request(&ctx, &req).is_err());
824    }
825
826    // ========================================
827    // RateLimitingMiddleware — client ID extraction
828    // ========================================
829
830    #[test]
831    fn rate_limiting_middleware_no_extractor_uses_global_key() {
832        let m = RateLimitingMiddleware::new(10.0);
833        let ctx = test_context();
834        let req = test_request("tools/call");
835        let id = m.get_client_identifier(&ctx, &req);
836        assert_eq!(id, "global");
837    }
838
839    #[test]
840    fn rate_limiting_middleware_extractor_returning_none_uses_global() {
841        let m = RateLimitingMiddleware::new(10.0).client_id_extractor(|_ctx, _req| None);
842        let ctx = test_context();
843        let req = test_request("tools/call");
844        let id = m.get_client_identifier(&ctx, &req);
845        assert_eq!(id, "global");
846    }
847
848    #[test]
849    fn rate_limiting_middleware_extractor_returning_some() {
850        let m = RateLimitingMiddleware::new(10.0)
851            .client_id_extractor(|_ctx, _req| Some("user-42".to_string()));
852        let ctx = test_context();
853        let req = test_request("tools/call");
854        let id = m.get_client_identifier(&ctx, &req);
855        assert_eq!(id, "user-42");
856    }
857
858    // ========================================
859    // RateLimitingMiddleware — per-client without extractor
860    // ========================================
861
862    #[test]
863    fn rate_limiting_middleware_per_client_no_extractor_all_share_global_key() {
864        // Without an extractor, per-client mode defaults all to "global"
865        let m = RateLimitingMiddleware::new(10.0).burst_capacity(2);
866        let ctx = test_context();
867        let req_a = test_request("method_a");
868        let req_b = test_request("method_b");
869
870        // Both methods share the same "global" bucket
871        assert!(m.on_request(&ctx, &req_a).is_ok());
872        assert!(m.on_request(&ctx, &req_b).is_ok());
873        // Bucket exhausted for both
874        assert!(m.on_request(&ctx, &req_a).is_err());
875    }
876
877    #[test]
878    fn rate_limiting_middleware_error_msg_per_client() {
879        let m = RateLimitingMiddleware::new(10.0)
880            .burst_capacity(1)
881            .client_id_extractor(|_ctx, _req| Some("alice".to_string()));
882        let ctx = test_context();
883        let req = test_request("tools/call");
884
885        m.on_request(&ctx, &req).unwrap();
886        let err = m.on_request(&ctx, &req).unwrap_err();
887        assert!(
888            err.message
889                .contains("Rate limit exceeded for client: alice")
890        );
891    }
892
893    #[test]
894    fn rate_limiting_middleware_error_msg_global() {
895        let m = RateLimitingMiddleware::new(10.0).burst_capacity(1).global();
896        let ctx = test_context();
897        let req = test_request("tools/call");
898
899        m.on_request(&ctx, &req).unwrap();
900        let err = m.on_request(&ctx, &req).unwrap_err();
901        assert!(err.message.contains("Global rate limit exceeded"));
902    }
903
904    // ========================================
905    // SlidingWindowRateLimitingMiddleware — construction/Debug
906    // ========================================
907
908    #[test]
909    fn sliding_window_middleware_new_fields() {
910        let m = SlidingWindowRateLimitingMiddleware::new(50, 120);
911        assert_eq!(m.max_requests, 50);
912        assert_eq!(m.window_seconds, 120);
913        assert!(m.get_client_id.is_none());
914    }
915
916    #[test]
917    fn sliding_window_middleware_per_minute_converts() {
918        let m = SlidingWindowRateLimitingMiddleware::per_minute(100, 5);
919        assert_eq!(m.max_requests, 100);
920        assert_eq!(m.window_seconds, 300); // 5 * 60
921    }
922
923    #[test]
924    fn sliding_window_middleware_debug() {
925        let m = SlidingWindowRateLimitingMiddleware::new(50, 120);
926        let debug = format!("{:?}", m);
927        assert!(debug.contains("SlidingWindowRateLimitingMiddleware"));
928        assert!(debug.contains("50"));
929        assert!(debug.contains("120"));
930    }
931
932    // ========================================
933    // SlidingWindowRateLimitingMiddleware — client ID
934    // ========================================
935
936    #[test]
937    fn sliding_window_middleware_no_extractor_uses_global() {
938        let m = SlidingWindowRateLimitingMiddleware::new(10, 60);
939        let ctx = test_context();
940        let req = test_request("tools/call");
941        let id = m.get_client_identifier(&ctx, &req);
942        assert_eq!(id, "global");
943    }
944
945    #[test]
946    fn sliding_window_middleware_extractor_returning_none_uses_global() {
947        let m =
948            SlidingWindowRateLimitingMiddleware::new(10, 60).client_id_extractor(|_ctx, _req| None);
949        let ctx = test_context();
950        let req = test_request("tools/call");
951        let id = m.get_client_identifier(&ctx, &req);
952        assert_eq!(id, "global");
953    }
954
955    #[test]
956    fn sliding_window_middleware_extractor_returning_some() {
957        let m = SlidingWindowRateLimitingMiddleware::new(10, 60)
958            .client_id_extractor(|_ctx, _req| Some("bob".to_string()));
959        let ctx = test_context();
960        let req = test_request("tools/call");
961        let id = m.get_client_identifier(&ctx, &req);
962        assert_eq!(id, "bob");
963    }
964
965    // ========================================
966    // SlidingWindowRateLimitingMiddleware — per-client
967    // ========================================
968
969    #[test]
970    fn sliding_window_middleware_per_client() {
971        let m = SlidingWindowRateLimitingMiddleware::new(1, 60)
972            .client_id_extractor(|_ctx, req| Some(req.method.clone()));
973        let ctx = test_context();
974        let req_a = test_request("method_a");
975        let req_b = test_request("method_b");
976
977        // Each client gets their own window
978        assert!(m.on_request(&ctx, &req_a).is_ok());
979        assert!(m.on_request(&ctx, &req_b).is_ok());
980
981        // Both exhausted
982        assert!(m.on_request(&ctx, &req_a).is_err());
983        assert!(m.on_request(&ctx, &req_b).is_err());
984    }
985
986    // ========================================
987    // SlidingWindowRateLimitingMiddleware — error messages
988    // ========================================
989
990    #[test]
991    fn sliding_window_middleware_error_msg_seconds() {
992        let m = SlidingWindowRateLimitingMiddleware::new(1, 30);
993        let ctx = test_context();
994        let req = test_request("tools/call");
995
996        m.on_request(&ctx, &req).unwrap();
997        let err = m.on_request(&ctx, &req).unwrap_err();
998        assert!(err.message.contains("30 second(s)"));
999        assert!(err.message.contains("client: global"));
1000    }
1001
1002    #[test]
1003    fn sliding_window_middleware_error_msg_minutes() {
1004        let m = SlidingWindowRateLimitingMiddleware::new(1, 120);
1005        let ctx = test_context();
1006        let req = test_request("tools/call");
1007
1008        m.on_request(&ctx, &req).unwrap();
1009        let err = m.on_request(&ctx, &req).unwrap_err();
1010        assert!(err.message.contains("2 minute(s)"));
1011    }
1012
1013    #[test]
1014    fn sliding_window_middleware_error_msg_with_client_id() {
1015        let m = SlidingWindowRateLimitingMiddleware::new(1, 60)
1016            .client_id_extractor(|_ctx, _req| Some("alice".to_string()));
1017        let ctx = test_context();
1018        let req = test_request("tools/call");
1019
1020        m.on_request(&ctx, &req).unwrap();
1021        let err = m.on_request(&ctx, &req).unwrap_err();
1022        assert!(err.message.contains("client: alice"));
1023        assert_eq!(i32::from(err.code), RATE_LIMIT_ERROR_CODE);
1024    }
1025
1026    // ========================================
1027    // Edge cases
1028    // ========================================
1029
1030    #[test]
1031    fn rate_limiting_middleware_get_or_create_limiter_creates_new() {
1032        let m = RateLimitingMiddleware::new(10.0).burst_capacity(2);
1033        // First call for a new client creates a limiter
1034        assert!(m.get_or_create_limiter("new-client"));
1035        // Second call reuses the same limiter
1036        assert!(m.get_or_create_limiter("new-client"));
1037        // Third call exhausts it
1038        assert!(!m.get_or_create_limiter("new-client"));
1039    }
1040
1041    #[test]
1042    fn sliding_window_middleware_is_request_allowed_creates_new() {
1043        let m = SlidingWindowRateLimitingMiddleware::new(2, 60);
1044        assert!(m.is_request_allowed("c1"));
1045        assert!(m.is_request_allowed("c1"));
1046        assert!(!m.is_request_allowed("c1"));
1047
1048        // Different client gets its own limiter
1049        assert!(m.is_request_allowed("c2"));
1050    }
1051
1052    #[test]
1053    fn sliding_window_requests_expire_after_window() {
1054        let limiter = SlidingWindowRateLimiter::new(2, 1); // 2 requests per 1 second
1055        assert!(limiter.is_allowed());
1056        assert!(limiter.is_allowed());
1057        assert!(!limiter.is_allowed()); // exhausted
1058
1059        // Wait for window to expire
1060        std::thread::sleep(std::time::Duration::from_millis(1100));
1061
1062        // Requests should be allowed again
1063        assert!(limiter.is_allowed());
1064    }
1065
1066    #[test]
1067    fn sliding_window_current_requests_resets_after_window() {
1068        let limiter = SlidingWindowRateLimiter::new(5, 1); // 1 second window
1069        limiter.is_allowed();
1070        limiter.is_allowed();
1071        assert_eq!(limiter.current_requests(), 2);
1072
1073        std::thread::sleep(std::time::Duration::from_millis(1100));
1074
1075        // Old requests should have expired
1076        assert_eq!(limiter.current_requests(), 0);
1077    }
1078
1079    #[test]
1080    fn sliding_window_error_exactly_60_seconds_shows_minutes() {
1081        let m = SlidingWindowRateLimitingMiddleware::new(1, 60);
1082        let ctx = test_context();
1083        let req = test_request("tools/call");
1084
1085        m.on_request(&ctx, &req).unwrap();
1086        let err = m.on_request(&ctx, &req).unwrap_err();
1087        assert!(
1088            err.message.contains("1 minute(s)"),
1089            "60 seconds should display as minutes: {}",
1090            err.message
1091        );
1092    }
1093
1094    #[test]
1095    fn token_bucket_try_consume_zero_always_succeeds() {
1096        let limiter = TokenBucketRateLimiter::new(3, 1.0);
1097        // Drain all tokens
1098        limiter.try_consume(3);
1099        assert!(!limiter.try_consume(1)); // exhausted
1100
1101        // Consuming zero should still succeed
1102        assert!(limiter.try_consume(0));
1103    }
1104
1105    #[test]
1106    fn token_bucket_refill_rate_zero_never_refills() {
1107        let limiter = TokenBucketRateLimiter::new(2, 0.0); // zero refill rate
1108        assert!(limiter.try_consume(2));
1109        assert!(!limiter.try_consume(1));
1110
1111        // Even after waiting, no refill
1112        std::thread::sleep(std::time::Duration::from_millis(50));
1113        assert!(!limiter.try_consume(1));
1114    }
1115}