// do_memory_mcp/server/rate_limiter.rs
//! Rate limiter for MCP server
//!
//! This module provides token bucket-based rate limiting to prevent DoS attacks.
//! Features:
//! - Per-client rate limiting (by IP or client ID)
//! - Token bucket algorithm for smooth rate limiting
//! - Different limits for read vs write operations
//! - Configurable via environment variables
//! - Rate limit headers in responses

use parking_lot::RwLock;
use std::collections::HashMap;
use std::time::{Duration, Instant};
use tracing::{debug, trace, warn};

mod types;
pub use types::*;

/// Token bucket for rate limiting
#[derive(Debug)]
struct TokenBucket {
    /// Current number of tokens (fractional, so refill accrues smoothly)
    tokens: f64,
    /// Maximum burst size
    capacity: u32,
    /// Tokens added per second
    refill_rate: f64,
    /// Last time tokens were refilled
    last_refill: Instant,
    /// Last time this bucket was accessed
    last_accessed: Instant,
}

impl TokenBucket {
    /// Create a new token bucket that starts full, so a fresh client gets
    /// its whole burst allowance immediately.
    fn new(requests_per_second: u32, burst_size: u32) -> Self {
        let now = Instant::now();
        Self {
            tokens: burst_size as f64,
            capacity: burst_size,
            refill_rate: requests_per_second as f64,
            last_refill: now,
            last_accessed: now,
        }
    }

    /// Refill tokens based on wall-clock time elapsed since the last refill,
    /// capped at `capacity`. Also marks the bucket as recently accessed.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        let tokens_to_add = elapsed * self.refill_rate;

        self.tokens = (self.tokens + tokens_to_add).min(self.capacity as f64);
        self.last_refill = now;
        self.last_accessed = now;
    }

    /// Try to consume `tokens` tokens from the bucket.
    /// Returns true if tokens were consumed, false if rate limited.
    fn try_consume(&mut self, tokens: u32) -> bool {
        self.refill();

        if self.tokens >= tokens as f64 {
            self.tokens -= tokens as f64;
            true
        } else {
            false
        }
    }

    /// Get current whole-token count (refills first, so the value is fresh)
    fn tokens(&mut self) -> u32 {
        self.refill();
        self.tokens as u32
    }

    /// Check if this bucket is stale (not accessed for `timeout`)
    #[allow(dead_code)] // Utility for bucket cleanup in rate limiter
    fn is_stale(&self, timeout: Duration) -> bool {
        Instant::now().duration_since(self.last_accessed) > timeout
    }

    /// Get time until the next token is available.
    ///
    /// BUG FIX: previously a zero `refill_rate` (requests_per_second == 0)
    /// produced `tokens_needed / 0.0 == inf`, and `Duration::from_secs_f64`
    /// panics on non-finite input. A non-positive refill rate now reports
    /// `Duration::MAX`, since such a bucket will never refill.
    fn time_until_next_token(&self) -> Duration {
        if self.tokens >= 1.0 {
            Duration::from_secs(0)
        } else if self.refill_rate <= 0.0 {
            // No refill will ever happen; avoid the division by zero below.
            Duration::MAX
        } else {
            let tokens_needed = 1.0 - self.tokens;
            let seconds = tokens_needed / self.refill_rate;
            Duration::from_secs_f64(seconds)
        }
    }
}

/// Rate limiter using token bucket algorithm
///
/// Keeps one token bucket per client, in separate maps for read and write
/// operations so each operation class gets independent limits.
#[derive(Debug)]
pub struct RateLimiter {
    /// Configuration (per-operation rates, burst sizes, enabled flag)
    pub config: RateLimitConfig,
    /// Token buckets for read operations per client
    read_buckets: RwLock<HashMap<ClientId, TokenBucket>>,
    /// Token buckets for write operations per client
    write_buckets: RwLock<HashMap<ClientId, TokenBucket>>,
}

106impl RateLimiter {
107    /// Create a new rate limiter with the given configuration
108    pub fn new(config: RateLimitConfig) -> Self {
109        let limiter = Self {
110            config,
111            read_buckets: RwLock::new(HashMap::new()),
112            write_buckets: RwLock::new(HashMap::new()),
113        };
114
115        // Spawn cleanup task if enabled
116        if limiter.config.enabled {
117            limiter.spawn_cleanup_task();
118        }
119
120        limiter
121    }
122
123    /// Create a rate limiter from environment configuration
124    pub fn from_env() -> Self {
125        Self::new(RateLimitConfig::from_env())
126    }
127
128    /// Check if rate limiting is enabled
129    pub fn is_enabled(&self) -> bool {
130        self.config.enabled
131    }
132
133    /// Check if a request should be rate limited
134    pub fn check_rate_limit(
135        &self,
136        client_id: &ClientId,
137        operation: OperationType,
138    ) -> RateLimitResult {
139        if !self.config.enabled {
140            return RateLimitResult {
141                allowed: true,
142                remaining: u32::MAX,
143                reset_after: Duration::from_secs(0),
144                limit: u32::MAX,
145                retry_after: None,
146            };
147        }
148
149        let (rps, burst) = match operation {
150            OperationType::Read => (
151                self.config.read_requests_per_second,
152                self.config.read_burst_size,
153            ),
154            OperationType::Write => (
155                self.config.write_requests_per_second,
156                self.config.write_burst_size,
157            ),
158        };
159
160        let buckets = match operation {
161            OperationType::Read => &self.read_buckets,
162            OperationType::Write => &self.write_buckets,
163        };
164
165        let mut buckets_guard = buckets.write();
166
167        // Get or create bucket for this client
168        let bucket = buckets_guard.entry(client_id.clone()).or_insert_with(|| {
169            trace!("Creating new rate limit bucket for client: {}", client_id);
170            TokenBucket::new(rps, burst)
171        });
172
173        // Try to consume one token
174        let allowed = bucket.try_consume(1);
175        let remaining = bucket.tokens();
176        let reset_after = bucket.time_until_next_token();
177
178        if allowed {
179            trace!(
180                "Rate limit check passed for client: {} (op: {:?}, remaining: {})",
181                client_id, operation, remaining
182            );
183            RateLimitResult {
184                allowed: true,
185                remaining,
186                reset_after,
187                limit: burst,
188                retry_after: None,
189            }
190        } else {
191            let retry_after = bucket.time_until_next_token();
192            warn!(
193                "Rate limit exceeded for client: {} (op: {:?}, retry_after: {:?})",
194                client_id, operation, retry_after
195            );
196            RateLimitResult {
197                allowed: false,
198                remaining: 0,
199                reset_after,
200                limit: burst,
201                retry_after: Some(retry_after),
202            }
203        }
204    }
205
206    /// Get rate limit headers for a successful request
207    pub fn get_headers(&self, result: &RateLimitResult) -> Vec<(String, String)> {
208        vec![
209            ("X-RateLimit-Limit".to_string(), result.limit.to_string()),
210            (
211                "X-RateLimit-Remaining".to_string(),
212                result.remaining.to_string(),
213            ),
214            (
215                "X-RateLimit-Reset".to_string(),
216                result.reset_after.as_secs().to_string(),
217            ),
218        ]
219    }
220
221    /// Get rate limit headers for a rate-limited response
222    pub fn get_rate_limited_headers(&self, result: &RateLimitResult) -> Vec<(String, String)> {
223        let mut headers = self.get_headers(result);
224        if let Some(retry_after) = result.retry_after {
225            headers.push(("Retry-After".to_string(), retry_after.as_secs().to_string()));
226        }
227        headers
228    }
229
230    /// Spawn a background task to clean up stale buckets
231    fn spawn_cleanup_task(&self) {
232        // This is a placeholder for the cleanup task
233        // In a production implementation, we would use a background task
234        // to periodically clean up stale buckets. For now, we rely on
235        // lazy cleanup during check_rate_limit calls.
236        debug!("Rate limiter cleanup task registered (lazy cleanup enabled)");
237    }
238
239    /// Get current statistics about the rate limiter
240    pub fn get_stats(&self) -> RateLimiterStats {
241        RateLimiterStats {
242            read_buckets_count: self.read_buckets.read().len(),
243            write_buckets_count: self.write_buckets.read().len(),
244            enabled: self.config.enabled,
245            read_config: (
246                self.config.read_requests_per_second,
247                self.config.read_burst_size,
248            ),
249            write_config: (
250                self.config.write_requests_per_second,
251                self.config.write_burst_size,
252            ),
253        }
254    }
255
256    /// Manually clean up stale buckets (for testing)
257    #[cfg(test)]
258    pub fn cleanup_stale_buckets(&self, stale_threshold: Duration) {
259        // Clean up stale read buckets
260        {
261            let mut read_guard = self.read_buckets.write();
262            let stale_clients: Vec<ClientId> = read_guard
263                .iter()
264                .filter(|(_, bucket)| bucket.is_stale(stale_threshold))
265                .map(|(client_id, _)| client_id.clone())
266                .collect();
267
268            for client_id in stale_clients {
269                debug!("Removing stale rate limit bucket for client: {}", client_id);
270                read_guard.remove(&client_id);
271            }
272        }
273
274        // Clean up stale write buckets
275        {
276            let mut write_guard = self.write_buckets.write();
277            let stale_clients: Vec<ClientId> = write_guard
278                .iter()
279                .filter(|(_, bucket)| bucket.is_stale(stale_threshold))
280                .map(|(client_id, _)| client_id.clone())
281                .collect();
282
283            for client_id in stale_clients {
284                debug!("Removing stale rate limit bucket for client: {}", client_id);
285                write_guard.remove(&client_id);
286            }
287        }
288    }
289}
290
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh bucket exposes its full burst, then refuses once drained.
    #[test]
    fn test_token_bucket_basic() {
        let mut b = TokenBucket::new(10, 20);
        assert_eq!(b.tokens(), 20);

        // Partial drain.
        assert!(b.try_consume(5));
        assert_eq!(b.tokens(), 15);

        // Drain the rest.
        assert!(b.try_consume(15));
        assert_eq!(b.tokens(), 0);

        // Empty bucket rejects further consumption.
        assert!(!b.try_consume(1));
    }

    /// With limiting disabled every request passes with "infinite" remaining.
    #[test]
    fn test_rate_limiter_disabled() {
        let limiter = RateLimiter::new(RateLimitConfig {
            enabled: false,
            ..Default::default()
        });

        let outcome =
            limiter.check_rate_limit(&ClientId::from_string("test"), OperationType::Read);

        assert!(outcome.allowed);
        assert_eq!(outcome.remaining, u32::MAX);
    }

    /// Requests within the burst pass; the one after is rejected with a hint.
    #[test]
    fn test_rate_limiter_basic() {
        let limiter = RateLimiter::new(RateLimitConfig {
            enabled: true,
            read_requests_per_second: 10,
            read_burst_size: 5,
            write_requests_per_second: 5,
            write_burst_size: 3,
            cleanup_interval: Duration::from_secs(60),
            stale_threshold: Duration::from_secs(300),
            client_id_header: "X-Client-ID".to_string(),
        });
        let client = ClientId::from_string("test");

        // The full read burst of 5 goes through...
        for attempt in 0..5 {
            assert!(
                limiter
                    .check_rate_limit(&client, OperationType::Read)
                    .allowed,
                "Request {} should be allowed",
                attempt
            );
        }

        // ...and the 6th is rate limited with a retry hint.
        let rejected = limiter.check_rate_limit(&client, OperationType::Read);
        assert!(!rejected.allowed);
        assert!(rejected.retry_after.is_some());
    }

    /// Standard rate-limit headers carry the limit, remaining, and reset values.
    #[test]
    fn test_rate_limit_headers() {
        let limiter = RateLimiter::new(RateLimitConfig::default());

        let result = RateLimitResult {
            allowed: true,
            remaining: 50,
            reset_after: Duration::from_secs(30),
            limit: 100,
            retry_after: None,
        };

        let headers = limiter.get_headers(&result);
        let value_of = |name: &str| {
            headers
                .iter()
                .find(|(k, _)| k == name)
                .map(|(_, v)| v.as_str())
        };

        assert_eq!(value_of("X-RateLimit-Limit"), Some("100"));
        assert_eq!(value_of("X-RateLimit-Remaining"), Some("50"));
        assert_eq!(value_of("X-RateLimit-Reset"), Some("30"));
    }

    /// Rate-limited responses additionally carry a Retry-After header.
    #[test]
    fn test_rate_limited_headers() {
        let limiter = RateLimiter::new(RateLimitConfig::default());

        let result = RateLimitResult {
            allowed: false,
            remaining: 0,
            reset_after: Duration::from_secs(60),
            limit: 100,
            retry_after: Some(Duration::from_secs(5)),
        };

        let headers = limiter.get_rate_limited_headers(&result);
        assert!(headers.iter().any(|(k, v)| k == "Retry-After" && v == "5"));
    }
}