// fastmcp_server/rate_limiting.rs
1//! Rate limiting middleware for protecting FastMCP servers from abuse.
2//!
3//! This module provides two rate limiting strategies:
4//!
5//! - [`RateLimitingMiddleware`]: Token bucket algorithm for burst-friendly limits
6//! - [`SlidingWindowRateLimitingMiddleware`]: Sliding window for precise tracking
7//!
8//! # Example
9//!
10//! ```ignore
11//! use fastmcp::prelude::*;
12//! use fastmcp_server::rate_limiting::RateLimitingMiddleware;
13//!
14//! // Allow 10 requests per second with bursts up to 20
15//! let rate_limiter = RateLimitingMiddleware::new(10.0)
16//!     .burst_capacity(20);
17//!
18//! Server::new("my-server", "1.0.0")
19//!     .middleware(rate_limiter)
20//!     .run_stdio();
21//! ```
22
23use std::collections::{HashMap, VecDeque};
24use std::sync::Mutex;
25use std::time::Instant;
26
27use fastmcp_core::{McpContext, McpError, McpErrorCode, McpResult};
28use fastmcp_protocol::JsonRpcRequest;
29
30use crate::{Middleware, MiddlewareDecision};
31
32/// Error code for rate limit exceeded (-32005).
33///
34/// This is in the MCP server error range (-32000 to -32099).
35pub const RATE_LIMIT_ERROR_CODE: i32 = -32005;
36
37/// Creates a rate limit exceeded error.
38#[must_use]
39pub fn rate_limit_error(message: impl Into<String>) -> McpError {
40    McpError::new(McpErrorCode::Custom(RATE_LIMIT_ERROR_CODE), message)
41}
42
/// Token bucket implementation for rate limiting.
///
/// The token bucket algorithm allows for burst traffic while maintaining
/// a sustainable long-term rate. Tokens are added at a constant rate and
/// consumed when requests arrive.
#[derive(Debug)]
pub struct TokenBucketRateLimiter {
    /// Maximum number of tokens in the bucket.
    capacity: usize,
    /// Tokens added per second.
    refill_rate: f64,
    /// Token count and refill timestamp behind a single lock, so the two
    /// values can never be observed or updated out of step (the previous
    /// version guarded them with two separate mutexes).
    state: Mutex<BucketState>,
}

/// Interior state of a [`TokenBucketRateLimiter`].
#[derive(Debug)]
struct BucketState {
    /// Current number of tokens (f64 to allow fractional refills).
    tokens: f64,
    /// Last time tokens were refilled.
    last_refill: Instant,
}

impl TokenBucketRateLimiter {
    /// Creates a new token bucket rate limiter.
    ///
    /// The bucket starts full (at `capacity` tokens).
    ///
    /// # Arguments
    ///
    /// * `capacity` - Maximum number of tokens (burst capacity)
    /// * `refill_rate` - Tokens added per second (sustained rate)
    #[must_use]
    pub fn new(capacity: usize, refill_rate: f64) -> Self {
        Self {
            capacity,
            refill_rate,
            state: Mutex::new(BucketState {
                tokens: capacity as f64,
                last_refill: Instant::now(),
            }),
        }
    }

    /// Locks the bucket state and credits tokens for the elapsed time,
    /// clamped at `capacity`. Returns the guard so callers can consume
    /// tokens under the same lock acquisition.
    fn refill(&self) -> std::sync::MutexGuard<'_, BucketState> {
        // A poisoned lock only means another thread panicked mid-update;
        // the f64/Instant state is still usable, so recover the guard.
        let mut state = self
            .state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        let now = Instant::now();
        let elapsed = now.duration_since(state.last_refill).as_secs_f64();
        state.tokens = (state.tokens + elapsed * self.refill_rate).min(self.capacity as f64);
        state.last_refill = now;
        state
    }

    /// Tries to consume tokens from the bucket.
    ///
    /// Returns `true` if tokens were available and consumed, `false` otherwise.
    pub fn try_consume(&self, tokens: usize) -> bool {
        let mut state = self.refill();
        let needed = tokens as f64;
        if state.tokens >= needed {
            state.tokens -= needed;
            true
        } else {
            false
        }
    }

    /// Returns the current number of available tokens (after crediting any
    /// refill earned since the last call), without consuming any.
    #[must_use]
    pub fn available_tokens(&self) -> f64 {
        self.refill().tokens
    }
}
128
/// Sliding window rate limiter implementation.
///
/// Tracks individual request timestamps within a time window for precise
/// rate limiting. More memory-intensive than token bucket but provides
/// exact request counting.
#[derive(Debug)]
pub struct SlidingWindowRateLimiter {
    /// Maximum requests allowed in the time window.
    max_requests: usize,
    /// Time window in seconds.
    window_seconds: u64,
    /// Timestamps of requests currently inside the window, oldest first.
    requests: Mutex<VecDeque<Instant>>,
}

impl SlidingWindowRateLimiter {
    /// Creates a new sliding window rate limiter.
    ///
    /// # Arguments
    ///
    /// * `max_requests` - Maximum requests allowed in the time window
    /// * `window_seconds` - Time window duration in seconds
    #[must_use]
    pub fn new(max_requests: usize, window_seconds: u64) -> Self {
        Self {
            max_requests,
            window_seconds,
            requests: Mutex::new(VecDeque::new()),
        }
    }

    /// Locks the timestamp queue, drops entries that have aged out of the
    /// window, and returns the guard for further inspection/mutation.
    fn pruned(&self) -> std::sync::MutexGuard<'_, VecDeque<Instant>> {
        // Recover from lock poisoning: the queue is still structurally valid.
        let mut requests = self
            .requests
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        // `checked_sub` avoids the panic `now - window` can raise when the
        // monotonic clock's origin is less than `window` in the past
        // (e.g. shortly after boot); in that case nothing can be stale.
        let window = std::time::Duration::from_secs(self.window_seconds);
        if let Some(cutoff) = Instant::now().checked_sub(window) {
            while let Some(&oldest) = requests.front() {
                if oldest < cutoff {
                    requests.pop_front();
                } else {
                    break;
                }
            }
        }
        requests
    }

    /// Checks if a request is allowed under the rate limit.
    ///
    /// If allowed, records the request timestamp and returns `true`.
    /// Otherwise returns `false`.
    pub fn is_allowed(&self) -> bool {
        let mut requests = self.pruned();
        if requests.len() < self.max_requests {
            requests.push_back(Instant::now());
            true
        } else {
            false
        }
    }

    /// Returns the current number of requests in the window.
    #[must_use]
    pub fn current_requests(&self) -> usize {
        self.pruned().len()
    }
}
213
/// Function type for extracting client ID from request context.
///
/// Returning `None` means no identity could be determined; the middleware
/// then falls back to a single shared "global" bucket.
pub type ClientIdExtractor =
    Box<dyn Fn(&McpContext, &JsonRpcRequest) -> Option<String> + Send + Sync>;
217
/// Rate limiting middleware using token bucket algorithm.
///
/// Uses a token bucket algorithm by default, allowing for burst traffic
/// while maintaining a sustainable long-term rate.
///
/// # Example
///
/// ```ignore
/// use fastmcp_server::rate_limiting::RateLimitingMiddleware;
///
/// // Allow 10 requests per second with bursts up to 20
/// let rate_limiter = RateLimitingMiddleware::new(10.0)
///     .burst_capacity(20);
/// ```
pub struct RateLimitingMiddleware {
    /// Sustained requests per second allowed.
    max_requests_per_second: f64,
    /// Maximum burst capacity.
    burst_capacity: usize,
    /// Function to extract client ID from context (for per-client limiting).
    get_client_id: Option<ClientIdExtractor>,
    /// If true, apply limit globally; if false, per-client.
    global_limit: bool,
    /// Storage for rate limiters per client, keyed by extracted client ID.
    /// NOTE(review): entries are never evicted, so memory grows with the
    /// number of distinct client IDs — fine when IDs are bounded.
    limiters: Mutex<HashMap<String, TokenBucketRateLimiter>>,
    /// Global rate limiter (populated by `global()`; used when
    /// `global_limit` is true).
    global_limiter: Option<TokenBucketRateLimiter>,
}
246
// Manual Debug impl: `get_client_id` is a boxed closure and cannot derive
// Debug, so only the plain configuration fields are shown.
impl std::fmt::Debug for RateLimitingMiddleware {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RateLimitingMiddleware")
            .field("max_requests_per_second", &self.max_requests_per_second)
            .field("burst_capacity", &self.burst_capacity)
            .field("global_limit", &self.global_limit)
            .finish()
    }
}
256
257impl RateLimitingMiddleware {
258    /// Creates a new rate limiting middleware with the specified rate.
259    ///
260    /// # Arguments
261    ///
262    /// * `max_requests_per_second` - Sustained requests per second allowed
263    ///
264    /// Burst capacity defaults to 2x the sustained rate.
265    #[must_use]
266    pub fn new(max_requests_per_second: f64) -> Self {
267        let burst_capacity = (max_requests_per_second * 2.0) as usize;
268        Self {
269            max_requests_per_second,
270            burst_capacity,
271            get_client_id: None,
272            global_limit: false,
273            limiters: Mutex::new(HashMap::new()),
274            global_limiter: None,
275        }
276    }
277
278    /// Sets the burst capacity (maximum tokens in the bucket).
279    #[must_use]
280    pub fn burst_capacity(mut self, capacity: usize) -> Self {
281        self.burst_capacity = capacity;
282        // Re-create global limiter if it exists
283        if self.global_limit {
284            self.global_limiter = Some(TokenBucketRateLimiter::new(
285                capacity,
286                self.max_requests_per_second,
287            ));
288        }
289        self
290    }
291
292    /// Sets a custom function to extract client ID from the request context.
293    ///
294    /// If not set, all clients share a single rate limit (global limiting).
295    #[must_use]
296    pub fn client_id_extractor<F>(mut self, extractor: F) -> Self
297    where
298        F: Fn(&McpContext, &JsonRpcRequest) -> Option<String> + Send + Sync + 'static,
299    {
300        self.get_client_id = Some(Box::new(extractor));
301        self
302    }
303
304    /// Enables global rate limiting (all clients share one limit).
305    ///
306    /// When enabled, all requests count against a single rate limit
307    /// regardless of client identity.
308    #[must_use]
309    pub fn global(mut self) -> Self {
310        self.global_limit = true;
311        self.global_limiter = Some(TokenBucketRateLimiter::new(
312            self.burst_capacity,
313            self.max_requests_per_second,
314        ));
315        self
316    }
317
318    fn get_client_identifier(&self, ctx: &McpContext, request: &JsonRpcRequest) -> String {
319        if let Some(ref extractor) = self.get_client_id {
320            if let Some(id) = extractor(ctx, request) {
321                return id;
322            }
323        }
324        "global".to_string()
325    }
326
327    fn get_or_create_limiter(&self, client_id: &str) -> bool {
328        let mut limiters = self
329            .limiters
330            .lock()
331            .unwrap_or_else(std::sync::PoisonError::into_inner);
332
333        if !limiters.contains_key(client_id) {
334            limiters.insert(
335                client_id.to_string(),
336                TokenBucketRateLimiter::new(self.burst_capacity, self.max_requests_per_second),
337            );
338        }
339
340        limiters.get(client_id).unwrap().try_consume(1)
341    }
342}
343
344impl Middleware for RateLimitingMiddleware {
345    fn on_request(
346        &self,
347        ctx: &McpContext,
348        request: &JsonRpcRequest,
349    ) -> McpResult<MiddlewareDecision> {
350        let allowed = if self.global_limit {
351            // Global rate limiting
352            if let Some(ref limiter) = self.global_limiter {
353                limiter.try_consume(1)
354            } else {
355                true
356            }
357        } else {
358            // Per-client rate limiting
359            let client_id = self.get_client_identifier(ctx, request);
360            self.get_or_create_limiter(&client_id)
361        };
362
363        if allowed {
364            Ok(MiddlewareDecision::Continue)
365        } else {
366            let msg = if self.global_limit {
367                "Global rate limit exceeded".to_string()
368            } else {
369                let client_id = self.get_client_identifier(ctx, request);
370                format!("Rate limit exceeded for client: {client_id}")
371            };
372            Err(rate_limit_error(msg))
373        }
374    }
375}
376
/// Rate limiting middleware using sliding window algorithm.
///
/// Uses a sliding window approach which provides more precise rate limiting
/// but uses more memory to track individual request timestamps.
///
/// # Example
///
/// ```ignore
/// use fastmcp_server::rate_limiting::SlidingWindowRateLimitingMiddleware;
///
/// // Allow 100 requests per minute
/// let rate_limiter = SlidingWindowRateLimitingMiddleware::new(100, 60);
/// ```
pub struct SlidingWindowRateLimitingMiddleware {
    /// Maximum requests allowed in the time window.
    max_requests: usize,
    /// Time window in seconds.
    window_seconds: u64,
    /// Function to extract client ID from context; `None` means all clients
    /// share the single "global" window.
    get_client_id: Option<ClientIdExtractor>,
    /// Storage for rate limiters per client, keyed by extracted client ID.
    /// NOTE(review): entries are never evicted — memory grows with the
    /// number of distinct client IDs.
    limiters: Mutex<HashMap<String, SlidingWindowRateLimiter>>,
}
400
// Manual Debug impl: `get_client_id` is a boxed closure and cannot derive
// Debug, so only the plain configuration fields are shown.
impl std::fmt::Debug for SlidingWindowRateLimitingMiddleware {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SlidingWindowRateLimitingMiddleware")
            .field("max_requests", &self.max_requests)
            .field("window_seconds", &self.window_seconds)
            .finish()
    }
}
409
410impl SlidingWindowRateLimitingMiddleware {
411    /// Creates a new sliding window rate limiting middleware.
412    ///
413    /// # Arguments
414    ///
415    /// * `max_requests` - Maximum requests allowed in the time window
416    /// * `window_seconds` - Time window duration in seconds
417    #[must_use]
418    pub fn new(max_requests: usize, window_seconds: u64) -> Self {
419        Self {
420            max_requests,
421            window_seconds,
422            get_client_id: None,
423            limiters: Mutex::new(HashMap::new()),
424        }
425    }
426
427    /// Creates a sliding window rate limiter with minutes-based window.
428    ///
429    /// # Arguments
430    ///
431    /// * `max_requests` - Maximum requests allowed in the time window
432    /// * `window_minutes` - Time window duration in minutes
433    #[must_use]
434    pub fn per_minute(max_requests: usize, window_minutes: u64) -> Self {
435        Self::new(max_requests, window_minutes * 60)
436    }
437
438    /// Sets a custom function to extract client ID from the request context.
439    #[must_use]
440    pub fn client_id_extractor<F>(mut self, extractor: F) -> Self
441    where
442        F: Fn(&McpContext, &JsonRpcRequest) -> Option<String> + Send + Sync + 'static,
443    {
444        self.get_client_id = Some(Box::new(extractor));
445        self
446    }
447
448    fn get_client_identifier(&self, ctx: &McpContext, request: &JsonRpcRequest) -> String {
449        if let Some(ref extractor) = self.get_client_id {
450            if let Some(id) = extractor(ctx, request) {
451                return id;
452            }
453        }
454        "global".to_string()
455    }
456
457    fn is_request_allowed(&self, client_id: &str) -> bool {
458        let mut limiters = self
459            .limiters
460            .lock()
461            .unwrap_or_else(std::sync::PoisonError::into_inner);
462
463        if !limiters.contains_key(client_id) {
464            limiters.insert(
465                client_id.to_string(),
466                SlidingWindowRateLimiter::new(self.max_requests, self.window_seconds),
467            );
468        }
469
470        limiters.get(client_id).unwrap().is_allowed()
471    }
472}
473
474impl Middleware for SlidingWindowRateLimitingMiddleware {
475    fn on_request(
476        &self,
477        ctx: &McpContext,
478        request: &JsonRpcRequest,
479    ) -> McpResult<MiddlewareDecision> {
480        let client_id = self.get_client_identifier(ctx, request);
481        let allowed = self.is_request_allowed(&client_id);
482
483        if allowed {
484            Ok(MiddlewareDecision::Continue)
485        } else {
486            let window_display = if self.window_seconds >= 60 {
487                format!("{} minute(s)", self.window_seconds / 60)
488            } else {
489                format!("{} second(s)", self.window_seconds)
490            };
491            Err(rate_limit_error(format!(
492                "Rate limit exceeded: {} requests per {} for client: {}",
493                self.max_requests, window_display, client_id
494            )))
495        }
496    }
497}
498
#[cfg(test)]
mod tests {
    use super::*;
    use asupersync::Cx;

    /// Builds a minimal context for exercising the middleware.
    fn test_context() -> McpContext {
        let cx = Cx::for_testing();
        McpContext::new(cx, 1)
    }

    /// Builds a minimal JSON-RPC request with the given method name.
    fn test_request(method: &str) -> JsonRpcRequest {
        JsonRpcRequest {
            jsonrpc: std::borrow::Cow::Borrowed(fastmcp_protocol::JSONRPC_VERSION),
            method: method.to_string(),
            params: None,
            id: Some(fastmcp_protocol::RequestId::Number(1)),
        }
    }

    // ========================================
    // TokenBucketRateLimiter tests
    // ========================================

    #[test]
    fn test_token_bucket_allows_burst() {
        let limiter = TokenBucketRateLimiter::new(5, 1.0);

        // Should allow burst up to capacity
        assert!(limiter.try_consume(1));
        assert!(limiter.try_consume(1));
        assert!(limiter.try_consume(1));
        assert!(limiter.try_consume(1));
        assert!(limiter.try_consume(1));

        // Should deny once capacity exhausted
        assert!(!limiter.try_consume(1));
    }

    #[test]
    fn test_token_bucket_refills_over_time() {
        let limiter = TokenBucketRateLimiter::new(2, 100.0); // 100 tokens per second

        // Exhaust tokens
        assert!(limiter.try_consume(1));
        assert!(limiter.try_consume(1));
        assert!(!limiter.try_consume(1));

        // Wait for refill (10ms should add ~1 token at 100 t/s)
        // NOTE(review): wall-clock dependent — could flake on a heavily
        // loaded CI host if the sleep overshoots are ever shortened.
        std::thread::sleep(std::time::Duration::from_millis(15));

        // Should have refilled
        assert!(limiter.try_consume(1));
    }

    #[test]
    fn test_token_bucket_available_tokens() {
        let limiter = TokenBucketRateLimiter::new(10, 1.0);
        // 0.1-token tolerance absorbs the tiny refill accrued between calls.
        assert!((limiter.available_tokens() - 10.0).abs() < 0.1);

        limiter.try_consume(5);
        assert!((limiter.available_tokens() - 5.0).abs() < 0.1);
    }

    // ========================================
    // SlidingWindowRateLimiter tests
    // ========================================

    #[test]
    fn test_sliding_window_allows_up_to_limit() {
        let limiter = SlidingWindowRateLimiter::new(3, 60);

        assert!(limiter.is_allowed());
        assert!(limiter.is_allowed());
        assert!(limiter.is_allowed());
        assert!(!limiter.is_allowed()); // Fourth request denied
    }

    #[test]
    fn test_sliding_window_current_requests() {
        let limiter = SlidingWindowRateLimiter::new(10, 60);

        assert_eq!(limiter.current_requests(), 0);
        limiter.is_allowed();
        assert_eq!(limiter.current_requests(), 1);
        limiter.is_allowed();
        assert_eq!(limiter.current_requests(), 2);
    }

    // ========================================
    // RateLimitingMiddleware tests
    // ========================================

    #[test]
    fn test_rate_limiting_middleware_allows_initial_requests() {
        let middleware = RateLimitingMiddleware::new(10.0).global();
        let ctx = test_context();
        let request = test_request("tools/call");

        let result = middleware.on_request(&ctx, &request);
        assert!(matches!(result, Ok(MiddlewareDecision::Continue)));
    }

    #[test]
    fn test_rate_limiting_middleware_denies_after_burst() {
        let middleware = RateLimitingMiddleware::new(10.0).burst_capacity(2).global();
        let ctx = test_context();
        let request = test_request("tools/call");

        // First two should succeed (burst capacity = 2)
        assert!(middleware.on_request(&ctx, &request).is_ok());
        assert!(middleware.on_request(&ctx, &request).is_ok());

        // Third should fail
        let result = middleware.on_request(&ctx, &request);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(i32::from(err.code), RATE_LIMIT_ERROR_CODE);
        assert!(err.message.contains("Global rate limit exceeded"));
    }

    #[test]
    fn test_rate_limiting_middleware_per_client() {
        let middleware = RateLimitingMiddleware::new(10.0)
            .burst_capacity(1)
            .client_id_extractor(|_ctx, req| Some(req.method.clone()));
        let ctx = test_context();

        let request1 = test_request("method_a");
        let request2 = test_request("method_b");

        // Each "client" (method) gets their own bucket
        assert!(middleware.on_request(&ctx, &request1).is_ok());
        assert!(middleware.on_request(&ctx, &request2).is_ok());

        // Now both are exhausted
        assert!(middleware.on_request(&ctx, &request1).is_err());
        assert!(middleware.on_request(&ctx, &request2).is_err());
    }

    // ========================================
    // SlidingWindowRateLimitingMiddleware tests
    // ========================================

    #[test]
    fn test_sliding_window_middleware_allows_up_to_limit() {
        let middleware = SlidingWindowRateLimitingMiddleware::new(2, 60);
        let ctx = test_context();
        let request = test_request("tools/call");

        assert!(middleware.on_request(&ctx, &request).is_ok());
        assert!(middleware.on_request(&ctx, &request).is_ok());

        let result = middleware.on_request(&ctx, &request);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(i32::from(err.code), RATE_LIMIT_ERROR_CODE);
    }

    #[test]
    fn test_sliding_window_middleware_per_minute() {
        let middleware = SlidingWindowRateLimitingMiddleware::per_minute(100, 1);
        let ctx = test_context();
        let request = test_request("tools/call");

        // Should allow many requests
        for _ in 0..100 {
            assert!(middleware.on_request(&ctx, &request).is_ok());
        }

        // 101st should fail
        assert!(middleware.on_request(&ctx, &request).is_err());
    }

    #[test]
    fn test_rate_limit_error_code() {
        let err = rate_limit_error("test");
        assert_eq!(i32::from(err.code), RATE_LIMIT_ERROR_CODE);
        assert_eq!(err.message, "test");
    }
}