kindly_guard_server/rate_limit.rs
1// Copyright 2025 Kindly Software Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//! Rate limiting for MCP server requests
15//! Implements token bucket algorithm with per-client and per-operation limits
16
17use anyhow::Result;
18use parking_lot::Mutex;
19use serde::{Deserialize, Serialize};
20use std::collections::{HashMap, HashSet};
21use std::sync::Arc;
22use std::time::{Duration, Instant};
23use tokio::sync::RwLock;
24
25/// Rate limiting configuration
26///
27/// # Security Implications
28///
29/// Rate limiting is essential for preventing abuse and DoS attacks:
30/// - **Prevents brute force attacks** - Limits authentication attempts
31/// - **Protects against resource exhaustion** - Controls request rates
32/// - **Mitigates data harvesting** - Slows down automated scraping
33/// - **Adaptive penalties** - Automatically restricts suspicious clients
34///
35/// # Example: Secure Production Configuration
36///
37/// ```toml
38/// [rate_limit]
39/// enabled = true
40/// default_rpm = 60 # 1 request per second average
41/// burst_capacity = 10 # Allow short bursts
42/// cleanup_interval_secs = 300
43/// adaptive = true # Auto-adjust based on threats
44/// threat_penalty_multiplier = 0.5 # Halve limits for threats
45///
46/// [rate_limit.method_limits]
47/// "tools/list" = { rpm = 120, burst = 20 } # Read operations
48/// "tools/call" = { rpm = 30, burst = 5 } # Execution operations
49/// "security/neutralize" = { rpm = 10, burst = 2 } # Sensitive operations
50///
51/// [rate_limit.client_limits]
52/// "trusted-app" = { rpm = 300, burst = 50, priority = "high" }
53/// "public-api" = { rpm = 30, burst = 5, priority = "low" }
54/// ```
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct RateLimitConfig {
57 /// Enable rate limiting
58 ///
59 /// **Default**: false (for easier testing)
60 /// **Security**: MUST be true in production to prevent abuse.
61 /// Without rate limiting, attackers can overwhelm the service.
62 /// **Warning**: Disabling exposes you to DoS and brute force attacks
63 pub enabled: bool,
64
65 /// Default requests per minute
66 ///
67 /// **Default**: 60 (1 per second average)
68 /// **Security**: Lower values are more secure but may impact usability.
69 /// Consider your threat model and legitimate usage patterns.
70 /// **Range**: 10-600 (recommend 30-120 for most APIs)
71 pub default_rpm: u32,
72
73 /// Burst capacity (tokens available immediately)
74 ///
75 /// **Default**: 10
76 /// **Security**: Allows legitimate burst traffic while preventing abuse.
77 /// Too high enables rapid attacks; too low impacts user experience.
78 /// **Range**: 1-50 (recommend 5-20, should be < default_rpm/6)
79 pub burst_capacity: u32,
80
81 /// Per-method rate limits (overrides default)
82 ///
83 /// **Default**: Sensible limits for common operations
84 /// **Security**: Set stricter limits on sensitive operations.
85 /// Read operations can have higher limits than write operations.
86 /// **Best Practice**: Order from least to most sensitive
87 pub method_limits: HashMap<String, MethodLimit>,
88
89 /// Per-client rate limits (by client ID)
90 ///
91 /// **Default**: Empty (all clients use default limits)
92 /// **Security**: Assign higher limits only to trusted clients.
93 /// Use priority levels to ensure critical clients aren't blocked.
94 /// **Warning**: Overly generous limits can be exploited
95 pub client_limits: HashMap<String, ClientLimit>,
96
97 /// Clean up interval for expired buckets (seconds)
98 ///
99 /// **Default**: 300 (5 minutes)
100 /// **Security**: Regular cleanup prevents memory exhaustion.
101 /// Shorter intervals use more CPU but free memory faster.
102 /// **Range**: 60-3600 (recommend 300-900)
103 pub cleanup_interval_secs: u64,
104
105 /// Enable adaptive rate limiting based on load
106 ///
107 /// **Default**: false
108 /// **Security**: Automatically tightens limits under attack.
109 /// Reduces false positives during traffic spikes.
110 /// **Trade-off**: Adds complexity but improves resilience
111 pub adaptive: bool,
112
113 /// Penalty for security threats (multiplier)
114 ///
115 /// **Default**: 0.5 (halve the rate limit)
116 /// **Security**: Clients triggering security alerts get reduced limits.
117 /// Helps contain attacks while allowing recovery for false positives.
118 /// **Range**: 0.1-1.0 (0.1 = 90% reduction, 1.0 = no penalty)
119 pub threat_penalty_multiplier: f32,
120
121 /// Whitelist of client IDs exempt from rate limiting
122 ///
123 /// **Default**: Empty set
124 /// **Security**: Only whitelist fully trusted internal clients.
125 /// Whitelisted clients can still trigger other security measures.
126 /// **Warning**: Use sparingly - prefer higher rate limits over exemption
127 #[serde(default)]
128 pub whitelist: HashSet<String>,
129
130 /// Blacklist of client IDs to always block
131 ///
132 /// **Default**: Empty set
133 /// **Security**: Immediately reject requests from blacklisted clients.
134 /// Useful for blocking known malicious actors or compromised credentials.
135 /// **Note**: Blacklist takes precedence over whitelist
136 #[serde(default)]
137 pub blacklist: HashSet<String>,
138
139 /// IP-specific rate limits
140 ///
141 /// **Default**: Empty map
142 /// **Security**: Set stricter limits for suspicious IP ranges.
143 /// Useful for geographic restrictions or known problematic networks.
144 /// **Format**: IP address or CIDR notation as key
145 #[serde(default)]
146 pub ip_limits: HashMap<String, IpLimit>,
147
148 /// Global requests per minute limit across all clients
149 ///
150 /// **Default**: None (no global limit)
151 /// **Security**: Prevents total system overload regardless of client distribution.
152 /// Individual client limits still apply within the global limit.
153 /// **Range**: 100-10000 (recommend 10x expected peak traffic)
154 pub global_rpm: Option<u32>,
155
156 /// Method for tracking clients
157 ///
158 /// **Default**: ClientId
159 /// **Security**: Determines how rate limits are applied.
160 /// IP-based tracking is more strict but may affect legitimate shared IPs.
161 /// **Options**: ClientId, IpAddress, Combined (both must pass)
162 #[serde(default)]
163 pub track_by: TrackingMethod,
164}
165
166/// IP-specific rate limit
167#[derive(Debug, Clone, Serialize, Deserialize)]
168pub struct IpLimit {
169 pub rpm: u32,
170 pub burst: u32,
171 pub block: bool,
172}
173
174/// Client tracking method
175#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
176#[serde(rename_all = "snake_case")]
177pub enum TrackingMethod {
178 /// Track by client ID only
179 ClientId,
180 /// Track by IP address only
181 IpAddress,
182 /// Track by both (both must pass rate limits)
183 Combined,
184}
185
186impl Default for TrackingMethod {
187 fn default() -> Self {
188 Self::ClientId
189 }
190}
191
192/// Method-specific rate limit
193#[derive(Debug, Clone, Serialize, Deserialize)]
194pub struct MethodLimit {
195 pub rpm: u32,
196 pub burst: u32,
197}
198
199/// Client-specific rate limit
200#[derive(Debug, Clone, Serialize, Deserialize)]
201pub struct ClientLimit {
202 pub rpm: u32,
203 pub burst: u32,
204 pub priority: ClientPriority,
205}
206
207/// Client priority levels
208#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
209#[serde(rename_all = "snake_case")]
210pub enum ClientPriority {
211 Low = 0,
212 Normal = 1,
213 High = 2,
214 Premium = 3,
215}
216
217impl Default for RateLimitConfig {
218 fn default() -> Self {
219 let mut method_limits = HashMap::new();
220
221 // Higher limits for read operations
222 method_limits.insert("tools/list".to_string(), MethodLimit { rpm: 60, burst: 10 });
223 method_limits.insert(
224 "resources/list".to_string(),
225 MethodLimit { rpm: 60, burst: 10 },
226 );
227
228 // Lower limits for execution operations
229 method_limits.insert("tools/call".to_string(), MethodLimit { rpm: 30, burst: 5 });
230
231 // Very low limits for security-sensitive operations
232 method_limits.insert(
233 "security/threats".to_string(),
234 MethodLimit { rpm: 10, burst: 2 },
235 );
236
237 Self {
238 enabled: false,
239 default_rpm: 60,
240 burst_capacity: 10,
241 method_limits,
242 client_limits: HashMap::new(),
243 cleanup_interval_secs: 300, // 5 minutes
244 adaptive: false,
245 threat_penalty_multiplier: 0.5, // Halve rate limit on threat detection
246 whitelist: HashSet::new(),
247 blacklist: HashSet::new(),
248 ip_limits: HashMap::new(),
249 global_rpm: None,
250 track_by: TrackingMethod::default(),
251 }
252 }
253}
254
/// Token bucket enforcing a single rate limit.
///
/// Holds up to `capacity` tokens and refills continuously at
/// `refill_rate * penalty_factor` tokens per second. A request costing `n`
/// tokens is allowed only if at least `n` tokens are available.
#[derive(Debug)]
struct TokenBucket {
    /// Maximum tokens (burst capacity)
    capacity: f64,

    /// Tokens available as of `last_refill`
    tokens: f64,

    /// Unpenalized refill rate (tokens per second)
    refill_rate: f64,

    /// Last refill time; the idle cleanup also uses it as "last activity"
    last_refill: Instant,

    /// Penalty factor, kept within `[0.1, 1.0]`; 1.0 is the normal rate
    penalty_factor: f64,
}

impl TokenBucket {
    /// Floor for `penalty_factor`: a penalized bucket keeps 10% of its rate.
    const MIN_PENALTY_FACTOR: f64 = 0.1;
    /// Ceiling for `penalty_factor`: the unpenalized rate.
    const MAX_PENALTY_FACTOR: f64 = 1.0;

    /// Create a full bucket.
    ///
    /// * `rpm` - sustained requests per minute (converted to tokens/second).
    /// * `burst` - capacity, i.e. how many requests may be made at once.
    fn new(rpm: u32, burst: u32) -> Self {
        let capacity = f64::from(burst);

        Self {
            capacity,
            tokens: capacity, // Start with full bucket
            refill_rate: f64::from(rpm) / 60.0,
            last_refill: Instant::now(),
            penalty_factor: Self::MAX_PENALTY_FACTOR,
        }
    }

    /// Current refill rate in tokens per second, including any penalty.
    fn effective_rate(&self) -> f64 {
        self.refill_rate * self.penalty_factor
    }

    /// Refill, then try to take `tokens` tokens from the bucket.
    ///
    /// Returns `true` (and deducts the tokens) when enough are available.
    /// A negative or non-finite cost is rejected: subtracting a negative
    /// amount would otherwise push the bucket above its capacity.
    fn try_consume(&mut self, tokens: f64) -> bool {
        self.refill();

        if !tokens.is_finite() || tokens < 0.0 {
            return false;
        }

        if self.tokens >= tokens {
            self.tokens -= tokens;
            true
        } else {
            false
        }
    }

    /// Add the tokens accrued since `last_refill`, capped at `capacity`.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();

        self.tokens = (self.tokens + elapsed * self.effective_rate()).min(self.capacity);
        self.last_refill = now;
    }

    /// Multiply the refill rate by `factor`, keeping the result in `[0.1, 1.0]`.
    fn apply_penalty(&mut self, factor: f64) {
        // Not `clamp`: `f64::max` ignores NaN, so a NaN factor lands on the
        // floor instead of turning the refill arithmetic into NaN.
        self.penalty_factor = (self.penalty_factor * factor)
            .max(Self::MIN_PENALTY_FACTOR)
            .min(Self::MAX_PENALTY_FACTOR);
    }

    /// Time until `tokens` tokens will be available (without refilling first).
    ///
    /// Returns `Duration::MAX` when they never will be, e.g. for a bucket
    /// configured with `rpm = 0`.
    fn time_until_available(&self, tokens: f64) -> Duration {
        if self.tokens >= tokens {
            return Duration::ZERO;
        }

        let seconds = (tokens - self.tokens) / self.effective_rate();
        // A zero rate yields +inf (and a NaN cost yields NaN);
        // `Duration::from_secs_f64` would panic on either.
        Duration::try_from_secs_f64(seconds).unwrap_or(Duration::MAX)
    }
}
328
/// Identifies one token bucket: a client, optionally narrowed to one method.
///
/// `method: None` is the bucket used for requests checked without a method name.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct RateLimitKey {
    /// Client the bucket belongs to.
    client_id: String,
    /// Method name, or `None` for method-less checks.
    method: Option<String>,
}
335
336/// Rate limiter
337pub struct RateLimiter {
338 config: RateLimitConfig,
339 buckets: Arc<RwLock<HashMap<RateLimitKey, Arc<Mutex<TokenBucket>>>>>,
340 #[allow(dead_code)] // Reserved for periodic bucket cleanup
341 last_cleanup: Arc<Mutex<Instant>>,
342}
343
/// Outcome of a rate-limit check, as reported by `RateLimiter`.
#[derive(Debug)]
pub struct RateLimitResult {
    /// `true` when the request may proceed.
    pub allowed: bool,

    /// Whole tokens left in the bucket after the check
    /// (`u32::MAX` when no limit applies).
    pub remaining: u32,

    /// Time until at least one token is available again
    /// (`Duration::ZERO` if one already is).
    pub reset_after: Duration,

    /// Applicable limit in requests per minute.
    pub limit: u32,
}
359
360impl RateLimiter {
361 /// Create a new rate limiter
362 pub fn new(config: RateLimitConfig) -> Self {
363 let limiter = Self {
364 config,
365 buckets: Arc::new(RwLock::new(HashMap::new())),
366 last_cleanup: Arc::new(Mutex::new(Instant::now())),
367 };
368
369 // Start cleanup task if enabled
370 if limiter.config.enabled && limiter.config.cleanup_interval_secs > 0 {
371 let buckets = limiter.buckets.clone();
372 let interval = limiter.config.cleanup_interval_secs;
373
374 tokio::spawn(async move {
375 let mut interval = tokio::time::interval(Duration::from_secs(interval));
376 loop {
377 interval.tick().await;
378 Self::cleanup_buckets(buckets.clone()).await;
379 }
380 });
381 }
382
383 limiter
384 }
385
386 /// Check rate limit for a request
387 pub async fn check_limit(
388 &self,
389 client_id: &str,
390 method: Option<&str>,
391 tokens: f64,
392 ) -> Result<RateLimitResult> {
393 if !self.config.enabled {
394 return Ok(RateLimitResult {
395 allowed: true,
396 remaining: u32::MAX,
397 reset_after: Duration::ZERO,
398 limit: u32::MAX,
399 });
400 }
401
402 // Check blacklist first
403 if self.config.blacklist.contains(client_id) {
404 return Ok(RateLimitResult {
405 allowed: false,
406 remaining: 0,
407 reset_after: Duration::from_secs(3600), // 1 hour
408 limit: 0,
409 });
410 }
411
412 // Check whitelist - exempt from rate limiting
413 if self.config.whitelist.contains(client_id) {
414 return Ok(RateLimitResult {
415 allowed: true,
416 remaining: u32::MAX,
417 reset_after: Duration::ZERO,
418 limit: u32::MAX,
419 });
420 }
421
422 // Get applicable limits
423 let (rpm, burst) = self.get_limits(client_id, method);
424
425 // Create key for this check
426 let key = RateLimitKey {
427 client_id: client_id.to_string(),
428 method: method.map(String::from),
429 };
430
431 // Get or create bucket
432 let bucket = self.get_or_create_bucket(&key, rpm, burst).await;
433
434 // Try to consume tokens
435 let mut bucket = bucket.lock();
436 let allowed = bucket.try_consume(tokens);
437 let remaining = bucket.tokens as u32;
438 let reset_after = bucket.time_until_available(1.0);
439
440 Ok(RateLimitResult {
441 allowed,
442 remaining,
443 reset_after,
444 limit: rpm,
445 })
446 }
447
448 /// Apply penalty to a client (e.g., for security threats)
449 pub async fn apply_penalty(&self, client_id: &str, factor: f64) -> Result<()> {
450 if !self.config.enabled {
451 return Ok(());
452 }
453
454 let buckets = self.buckets.read().await;
455
456 // Apply penalty to all buckets for this client
457 for (key, bucket) in buckets.iter() {
458 if key.client_id == client_id {
459 bucket.lock().apply_penalty(factor);
460 }
461 }
462
463 Ok(())
464 }
465
466 /// Get rate limit status for a client
467 pub async fn get_status(&self, client_id: &str) -> Result<HashMap<String, RateLimitResult>> {
468 let mut status = HashMap::new();
469
470 if !self.config.enabled {
471 return Ok(status);
472 }
473
474 let buckets = self.buckets.read().await;
475
476 for (key, bucket) in buckets.iter() {
477 if key.client_id == client_id {
478 let bucket = bucket.lock();
479 let method = key.method.as_deref().unwrap_or("default");
480
481 status.insert(
482 method.to_string(),
483 RateLimitResult {
484 allowed: bucket.tokens >= 1.0,
485 remaining: bucket.tokens as u32,
486 reset_after: bucket.time_until_available(1.0),
487 limit: (bucket.refill_rate * 60.0) as u32,
488 },
489 );
490 }
491 }
492
493 Ok(status)
494 }
495
496 /// Get or create a token bucket
497 async fn get_or_create_bucket(
498 &self,
499 key: &RateLimitKey,
500 rpm: u32,
501 burst: u32,
502 ) -> Arc<Mutex<TokenBucket>> {
503 let mut buckets = self.buckets.write().await;
504
505 buckets
506 .entry(key.clone())
507 .or_insert_with(|| Arc::new(Mutex::new(TokenBucket::new(rpm, burst))))
508 .clone()
509 }
510
511 /// Get applicable rate limits for a client/method
512 fn get_limits(&self, client_id: &str, method: Option<&str>) -> (u32, u32) {
513 // Check client-specific limits first
514 if let Some(client_limit) = self.config.client_limits.get(client_id) {
515 return (client_limit.rpm, client_limit.burst);
516 }
517
518 // Check method-specific limits
519 if let Some(method) = method {
520 if let Some(method_limit) = self.config.method_limits.get(method) {
521 return (method_limit.rpm, method_limit.burst);
522 }
523 }
524
525 // Use defaults
526 (self.config.default_rpm, self.config.burst_capacity)
527 }
528
529 /// Clean up expired buckets
530 async fn cleanup_buckets(buckets: Arc<RwLock<HashMap<RateLimitKey, Arc<Mutex<TokenBucket>>>>>) {
531 let mut buckets = buckets.write().await;
532 let now = Instant::now();
533
534 // Remove buckets that haven't been used in 10 minutes
535 buckets.retain(|_, bucket| {
536 let bucket = bucket.lock();
537 now.duration_since(bucket.last_refill) < Duration::from_secs(600)
538 });
539 }
540
541 /// Create rate limit headers for HTTP responses
542 pub fn create_headers(result: &RateLimitResult) -> HashMap<String, String> {
543 let mut headers = HashMap::new();
544
545 headers.insert("X-RateLimit-Limit".to_string(), result.limit.to_string());
546 headers.insert(
547 "X-RateLimit-Remaining".to_string(),
548 result.remaining.to_string(),
549 );
550 headers.insert(
551 "X-RateLimit-Reset".to_string(),
552 (Instant::now() + result.reset_after)
553 .duration_since(Instant::now())
554 .as_secs()
555 .to_string(),
556 );
557
558 if !result.allowed {
559 headers.insert(
560 "Retry-After".to_string(),
561 result.reset_after.as_secs().to_string(),
562 );
563 }
564
565 headers
566 }
567}
568
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_bucket() {
        let mut bucket = TokenBucket::new(60, 10);

        // A new bucket starts full, so the whole burst is available at once.
        assert!(bucket.try_consume(10.0));
        assert!(!bucket.try_consume(1.0)); // Should be empty

        // Simulate 1.1 s of idle time by back-dating the last refill rather
        // than sleeping, which keeps the test fast and deterministic.
        bucket.last_refill = Instant::now()
            .checked_sub(Duration::from_millis(1100))
            .expect("monotonic clock should be at least 1.1 s past its origin");
        bucket.refill();

        // 60 rpm = 1 token/s, so roughly 1.1 tokens should have accrued.
        assert!(
            bucket.tokens >= 1.0 && bucket.tokens < 1.5,
            "unexpected token count: {}",
            bucket.tokens
        );
    }

    #[tokio::test]
    async fn test_rate_limiter() {
        let config = RateLimitConfig {
            enabled: true,
            default_rpm: 60,
            burst_capacity: 10,
            ..Default::default()
        };

        let limiter = RateLimiter::new(config);

        // Should allow burst
        for _ in 0..10 {
            let result = limiter.check_limit("test-client", None, 1.0).await.unwrap();
            assert!(result.allowed);
        }

        // Should be rate limited
        let result = limiter.check_limit("test-client", None, 1.0).await.unwrap();
        assert!(!result.allowed);
        assert!(result.reset_after > Duration::ZERO);
    }

    #[tokio::test]
    async fn test_method_limits_use_separate_buckets() {
        let limiter = RateLimiter::new(RateLimitConfig {
            enabled: true,
            ..Default::default()
        });

        // Default "tools/call" limit has a burst of 5.
        for _ in 0..5 {
            let result = limiter.check_limit("client", Some("tools/call"), 1.0).await.unwrap();
            assert!(result.allowed);
        }
        let result = limiter.check_limit("client", Some("tools/call"), 1.0).await.unwrap();
        assert!(!result.allowed);

        // A different method for the same client has its own bucket.
        let result = limiter.check_limit("client", Some("tools/list"), 1.0).await.unwrap();
        assert!(result.allowed);
    }

    #[tokio::test]
    async fn test_blacklist_takes_precedence_over_whitelist() {
        let mut config = RateLimitConfig {
            enabled: true,
            ..Default::default()
        };
        config.whitelist.insert("dual-listed".to_string());
        config.blacklist.insert("dual-listed".to_string());

        let limiter = RateLimiter::new(config);
        let result = limiter.check_limit("dual-listed", None, 1.0).await.unwrap();
        assert!(!result.allowed);
        assert_eq!(result.limit, 0);
    }

    #[test]
    fn test_penalty_application() {
        let mut bucket = TokenBucket::new(60, 10);
        bucket.apply_penalty(0.5);

        // Refill rate should be halved
        assert_eq!(bucket.penalty_factor, 0.5);
    }
}