kindly_guard_server/
rate_limit.rs

1// Copyright 2025 Kindly Software Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//! Rate limiting for MCP server requests
15//! Implements token bucket algorithm with per-client and per-operation limits
16
17use anyhow::Result;
18use parking_lot::Mutex;
19use serde::{Deserialize, Serialize};
20use std::collections::{HashMap, HashSet};
21use std::sync::Arc;
22use std::time::{Duration, Instant};
23use tokio::sync::RwLock;
24
25/// Rate limiting configuration
26///
27/// # Security Implications
28///
29/// Rate limiting is essential for preventing abuse and DoS attacks:
30/// - **Prevents brute force attacks** - Limits authentication attempts
31/// - **Protects against resource exhaustion** - Controls request rates
32/// - **Mitigates data harvesting** - Slows down automated scraping
33/// - **Adaptive penalties** - Automatically restricts suspicious clients
34///
35/// # Example: Secure Production Configuration
36///
37/// ```toml
38/// [rate_limit]
39/// enabled = true
40/// default_rpm = 60           # 1 request per second average
41/// burst_capacity = 10        # Allow short bursts
42/// cleanup_interval_secs = 300
43/// adaptive = true            # Auto-adjust based on threats
44/// threat_penalty_multiplier = 0.5  # Halve limits for threats
45///
46/// [rate_limit.method_limits]
47/// "tools/list" = { rpm = 120, burst = 20 }     # Read operations
48/// "tools/call" = { rpm = 30, burst = 5 }       # Execution operations
49/// "security/neutralize" = { rpm = 10, burst = 2 }  # Sensitive operations
50///
51/// [rate_limit.client_limits]
52/// "trusted-app" = { rpm = 300, burst = 50, priority = "high" }
53/// "public-api" = { rpm = 30, burst = 5, priority = "low" }
54/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Enable rate limiting
    ///
    /// **Default**: false (for easier testing)
    /// **Security**: MUST be true in production to prevent abuse.
    /// Without rate limiting, attackers can overwhelm the service.
    /// **Warning**: Disabling exposes you to DoS and brute force attacks
    ///
    /// When false, `RateLimiter::check_limit` allows every request and
    /// reports `u32::MAX` for both the limit and the remaining count.
    pub enabled: bool,

    /// Default requests per minute
    ///
    /// **Default**: 60 (1 per second average)
    /// **Security**: Lower values are more secure but may impact usability.
    /// Consider your threat model and legitimate usage patterns.
    /// **Range**: 10-600 (recommend 30-120 for most APIs)
    ///
    /// Used when neither a client-specific nor a method-specific limit matches.
    pub default_rpm: u32,

    /// Burst capacity (tokens available immediately)
    ///
    /// **Default**: 10
    /// **Security**: Allows legitimate burst traffic while preventing abuse.
    /// Too high enables rapid attacks; too low impacts user experience.
    /// **Range**: 1-50 (recommend 5-20, should be < default_rpm/6)
    ///
    /// Used together with `default_rpm` as the fallback bucket size.
    pub burst_capacity: u32,

    /// Per-method rate limits (overrides default)
    ///
    /// **Default**: Sensible limits for common operations
    /// **Security**: Set stricter limits on sensitive operations.
    /// Read operations can have higher limits than write operations.
    /// **Note**: Entries live in a `HashMap`, so the order in which they are
    /// written in the configuration has no effect on lookup; ordering them
    /// from least to most sensitive is purely a readability convention.
    /// A matching entry in `client_limits` takes precedence over this map.
    pub method_limits: HashMap<String, MethodLimit>,

    /// Per-client rate limits (by client ID)
    ///
    /// **Default**: Empty (all clients use default limits)
    /// **Security**: Assign higher limits only to trusted clients.
    /// Use priority levels to ensure critical clients aren't blocked.
    /// **Warning**: Overly generous limits can be exploited
    ///
    /// A matching entry replaces both the method-specific and the default
    /// limits for that client.
    pub client_limits: HashMap<String, ClientLimit>,

    /// Clean up interval for expired buckets (seconds)
    ///
    /// **Default**: 300 (5 minutes)
    /// **Security**: Regular cleanup prevents memory exhaustion.
    /// Shorter intervals use more CPU but free memory faster.
    /// **Range**: 60-3600 (recommend 300-900)
    ///
    /// A value of 0 means no background cleanup task is started.
    pub cleanup_interval_secs: u64,

    /// Enable adaptive rate limiting based on load
    ///
    /// **Default**: false
    /// **Security**: Automatically tightens limits under attack.
    /// Reduces false positives during traffic spikes.
    /// **Trade-off**: Adds complexity but improves resilience
    ///
    /// NOTE(review): no code in this file reads this flag - confirm where
    /// (or whether) adaptive behaviour is implemented.
    pub adaptive: bool,

    /// Penalty for security threats (multiplier)
    ///
    /// **Default**: 0.5 (halve the rate limit)
    /// **Security**: Clients triggering security alerts get reduced limits.
    /// Helps contain attacks while allowing recovery for false positives.
    /// **Range**: 0.1-1.0 (0.1 = 90% reduction, 1.0 = no penalty)
    ///
    /// NOTE(review): `RateLimiter::apply_penalty` takes its factor from the
    /// caller (as `f64`) and does not read this field; presumably callers
    /// pass this value in - confirm.
    pub threat_penalty_multiplier: f32,

    /// Whitelist of client IDs exempt from rate limiting
    ///
    /// **Default**: Empty set
    /// **Security**: Only whitelist fully trusted internal clients.
    /// Whitelisted clients can still trigger other security measures.
    /// **Warning**: Use sparingly - prefer higher rate limits over exemption
    #[serde(default)]
    pub whitelist: HashSet<String>,

    /// Blacklist of client IDs to always block
    ///
    /// **Default**: Empty set
    /// **Security**: Immediately reject requests from blacklisted clients.
    /// Useful for blocking known malicious actors or compromised credentials.
    /// **Note**: Blacklist takes precedence over whitelist
    #[serde(default)]
    pub blacklist: HashSet<String>,

    /// IP-specific rate limits
    ///
    /// **Default**: Empty map
    /// **Security**: Set stricter limits for suspicious IP ranges.
    /// Useful for geographic restrictions or known problematic networks.
    /// **Format**: IP address or CIDR notation as key
    ///
    /// NOTE(review): not consulted by any code in this file - confirm where
    /// IP limits are enforced.
    #[serde(default)]
    pub ip_limits: HashMap<String, IpLimit>,

    /// Global requests per minute limit across all clients
    ///
    /// **Default**: None (no global limit)
    /// **Security**: Prevents total system overload regardless of client distribution.
    /// Individual client limits still apply within the global limit.
    /// **Range**: 100-10000 (recommend 10x expected peak traffic)
    ///
    /// NOTE(review): not consulted by any code in this file - confirm where
    /// the global limit is enforced.
    pub global_rpm: Option<u32>,

    /// Method for tracking clients
    ///
    /// **Default**: ClientId
    /// **Security**: Determines how rate limits are applied.
    /// IP-based tracking is more strict but may affect legitimate shared IPs.
    /// **Options**: ClientId, IpAddress, Combined (both must pass)
    ///
    /// NOTE(review): not consulted by any code in this file; buckets here
    /// are always keyed by client ID and method - confirm.
    #[serde(default)]
    pub track_by: TrackingMethod,
}
165
/// IP-specific rate limit
///
/// Value type of `RateLimitConfig::ip_limits`. No code in this file reads
/// these fields, so the descriptions below are inferred from the field names
/// and from the parallel `MethodLimit` / `ClientLimit` types.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IpLimit {
    /// Presumably the sustained rate in requests per minute - TODO confirm.
    pub rpm: u32,
    /// Presumably the burst capacity in tokens - TODO confirm.
    pub burst: u32,
    /// Presumably rejects all traffic from the address when true - TODO confirm.
    pub block: bool,
}
173
/// Client tracking method
///
/// Serialized in snake_case (`client_id`, `ip_address`, `combined`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum TrackingMethod {
    /// Track by client ID only
    ClientId,
    /// Track by IP address only
    IpAddress,
    /// Track by both (both must pass rate limits)
    Combined,
}
185
186impl Default for TrackingMethod {
187    fn default() -> Self {
188        Self::ClientId
189    }
190}
191
/// Method-specific rate limit
///
/// Value type of `RateLimitConfig::method_limits`, keyed by method name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MethodLimit {
    /// Sustained rate in requests per minute (converted to tokens per second
    /// when a bucket is created).
    pub rpm: u32,
    /// Bucket capacity, i.e. how many tokens are available immediately.
    pub burst: u32,
}
198
/// Client-specific rate limit
///
/// Value type of `RateLimitConfig::client_limits`, keyed by client ID.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientLimit {
    /// Sustained rate in requests per minute (converted to tokens per second
    /// when a bucket is created).
    pub rpm: u32,
    /// Bucket capacity, i.e. how many tokens are available immediately.
    pub burst: u32,
    /// Priority class of the client.
    /// NOTE(review): not read by any code in this file - confirm where
    /// priorities take effect.
    pub priority: ClientPriority,
}
206
/// Client priority levels
///
/// Variants are declared in ascending order, so the derived `Ord` gives
/// `Low < Normal < High < Premium`. Serialized in snake_case (`low`,
/// `normal`, `high`, `premium`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
#[serde(rename_all = "snake_case")]
pub enum ClientPriority {
    Low = 0,
    Normal = 1,
    High = 2,
    Premium = 3,
}
216
217impl Default for RateLimitConfig {
218    fn default() -> Self {
219        let mut method_limits = HashMap::new();
220
221        // Higher limits for read operations
222        method_limits.insert("tools/list".to_string(), MethodLimit { rpm: 60, burst: 10 });
223        method_limits.insert(
224            "resources/list".to_string(),
225            MethodLimit { rpm: 60, burst: 10 },
226        );
227
228        // Lower limits for execution operations
229        method_limits.insert("tools/call".to_string(), MethodLimit { rpm: 30, burst: 5 });
230
231        // Very low limits for security-sensitive operations
232        method_limits.insert(
233            "security/threats".to_string(),
234            MethodLimit { rpm: 10, burst: 2 },
235        );
236
237        Self {
238            enabled: false,
239            default_rpm: 60,
240            burst_capacity: 10,
241            method_limits,
242            client_limits: HashMap::new(),
243            cleanup_interval_secs: 300, // 5 minutes
244            adaptive: false,
245            threat_penalty_multiplier: 0.5, // Halve rate limit on threat detection
246            whitelist: HashSet::new(),
247            blacklist: HashSet::new(),
248            ip_limits: HashMap::new(),
249            global_rpm: None,
250            track_by: TrackingMethod::default(),
251        }
252    }
253}
254
/// Token bucket for rate limiting
///
/// Tokens accumulate at `refill_rate * penalty_factor` per second up to
/// `capacity`; each request consumes some number of tokens.
#[derive(Debug)]
struct TokenBucket {
    /// Maximum tokens (burst capacity)
    capacity: f64,

    /// Current tokens available
    tokens: f64,

    /// Refill rate (tokens per second)
    refill_rate: f64,

    /// Last refill time
    last_refill: Instant,

    /// Penalty factor (0.1 to 1.0, where 1.0 is normal)
    penalty_factor: f64,
}

impl TokenBucket {
    /// Lowest value `penalty_factor` may reach (10% of the normal rate).
    const MIN_PENALTY_FACTOR: f64 = 0.1;

    /// Wait reported when the requested tokens can never (or practically
    /// never) become available. Kept far below `Duration::MAX` so callers
    /// can still do arithmetic with the value.
    const MAX_WAIT: Duration = Duration::from_secs(u32::MAX as u64);

    /// Create a new token bucket
    ///
    /// `rpm` is the sustained rate in requests per minute and `burst` the
    /// capacity. The bucket starts full.
    fn new(rpm: u32, burst: u32) -> Self {
        let capacity = f64::from(burst);
        let refill_rate = f64::from(rpm) / 60.0; // Convert RPM to tokens per second

        Self {
            capacity,
            tokens: capacity, // Start with full bucket
            refill_rate,
            last_refill: Instant::now(),
            penalty_factor: 1.0,
        }
    }

    /// Try to consume tokens
    ///
    /// Refills the bucket first, then removes `tokens` if enough are
    /// available. Returns `true` when the tokens were consumed.
    ///
    /// A cost that is NaN, infinite or negative is always refused; a
    /// negative cost would otherwise add tokens beyond the capacity.
    fn try_consume(&mut self, tokens: f64) -> bool {
        self.refill();

        if !tokens.is_finite() || tokens < 0.0 {
            return false;
        }

        if self.tokens >= tokens {
            self.tokens -= tokens;
            true
        } else {
            false
        }
    }

    /// Refill tokens based on elapsed time
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();

        // Add tokens based on refill rate and penalty factor
        let new_tokens = elapsed * self.refill_rate * self.penalty_factor;
        self.tokens = (self.tokens + new_tokens).min(self.capacity);
        self.last_refill = now;
    }

    /// Apply penalty (scale the refill rate)
    ///
    /// Multiplies the current penalty factor by `factor` and keeps the
    /// result within `[0.1, 1.0]`, so a factor above 1.0 can undo earlier
    /// penalties but never speeds the bucket beyond its configured rate.
    fn apply_penalty(&mut self, factor: f64) {
        // `max`/`min` rather than `clamp`: `f64::max` ignores NaN, so a NaN
        // factor falls to the minimum rate instead of poisoning the bucket.
        self.penalty_factor = (self.penalty_factor * factor)
            .max(Self::MIN_PENALTY_FACTOR)
            .min(1.0);
    }

    /// Get time until the requested number of tokens is available
    ///
    /// Returns `Duration::ZERO` when the tokens are already there and
    /// `MAX_WAIT` when they can never be: the request exceeds the capacity,
    /// the effective refill rate is zero, or the computed wait is not
    /// representable. Never panics.
    fn time_until_available(&self, tokens: f64) -> Duration {
        if self.tokens >= tokens {
            return Duration::ZERO;
        }

        // The bucket never holds more than `capacity` tokens.
        if tokens.is_nan() || tokens > self.capacity {
            return Self::MAX_WAIT;
        }

        // A zero rate (rpm = 0) would divide by zero below.
        let rate = self.refill_rate * self.penalty_factor;
        if rate.is_nan() || rate <= 0.0 {
            return Self::MAX_WAIT;
        }

        let seconds = (tokens - self.tokens) / rate;
        if seconds.is_nan() || seconds >= Self::MAX_WAIT.as_secs_f64() {
            Self::MAX_WAIT
        } else {
            Duration::from_secs_f64(seconds)
        }
    }
}
328
/// Rate limiter key
///
/// Identifies one token bucket: each distinct (client, method) pair gets its
/// own bucket.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct RateLimitKey {
    /// Client identifier as passed to `RateLimiter::check_limit`.
    client_id: String,
    /// Method name, or `None` when the request was checked without one
    /// (reported as "default" by `RateLimiter::get_status`).
    method: Option<String>,
}
335
/// Rate limiter
///
/// Holds the configuration and one token bucket per (client, method) key.
pub struct RateLimiter {
    /// Configuration captured at construction time.
    config: RateLimitConfig,
    /// All live buckets. The outer async lock guards the map; each bucket has
    /// its own mutex so buckets can be updated while the map is only read.
    buckets: Arc<RwLock<HashMap<RateLimitKey, Arc<Mutex<TokenBucket>>>>>,
    /// Set to the construction time and not read anywhere in this file.
    #[allow(dead_code)] // Reserved for periodic bucket cleanup
    last_cleanup: Arc<Mutex<Instant>>,
}
343
/// Rate limit result
///
/// Outcome of a rate-limit check, also used to report per-method status.
#[derive(Debug)]
pub struct RateLimitResult {
    /// Whether the request is allowed
    pub allowed: bool,

    /// Whole tokens left in the bucket (fractional part dropped);
    /// `u32::MAX` when limiting is disabled or the client is whitelisted,
    /// 0 for blacklisted clients.
    pub remaining: u32,

    /// Estimated wait until tokens are available again, derived from the
    /// bucket's refill rate. `Duration::ZERO` when limiting is disabled or
    /// the client is whitelisted; one hour for blacklisted clients.
    pub reset_after: Duration,

    /// Current limit (requests per minute); `u32::MAX` when limiting is
    /// disabled or the client is whitelisted, 0 for blacklisted clients.
    pub limit: u32,
}
359
360impl RateLimiter {
361    /// Create a new rate limiter
362    pub fn new(config: RateLimitConfig) -> Self {
363        let limiter = Self {
364            config,
365            buckets: Arc::new(RwLock::new(HashMap::new())),
366            last_cleanup: Arc::new(Mutex::new(Instant::now())),
367        };
368
369        // Start cleanup task if enabled
370        if limiter.config.enabled && limiter.config.cleanup_interval_secs > 0 {
371            let buckets = limiter.buckets.clone();
372            let interval = limiter.config.cleanup_interval_secs;
373
374            tokio::spawn(async move {
375                let mut interval = tokio::time::interval(Duration::from_secs(interval));
376                loop {
377                    interval.tick().await;
378                    Self::cleanup_buckets(buckets.clone()).await;
379                }
380            });
381        }
382
383        limiter
384    }
385
386    /// Check rate limit for a request
387    pub async fn check_limit(
388        &self,
389        client_id: &str,
390        method: Option<&str>,
391        tokens: f64,
392    ) -> Result<RateLimitResult> {
393        if !self.config.enabled {
394            return Ok(RateLimitResult {
395                allowed: true,
396                remaining: u32::MAX,
397                reset_after: Duration::ZERO,
398                limit: u32::MAX,
399            });
400        }
401
402        // Check blacklist first
403        if self.config.blacklist.contains(client_id) {
404            return Ok(RateLimitResult {
405                allowed: false,
406                remaining: 0,
407                reset_after: Duration::from_secs(3600), // 1 hour
408                limit: 0,
409            });
410        }
411
412        // Check whitelist - exempt from rate limiting
413        if self.config.whitelist.contains(client_id) {
414            return Ok(RateLimitResult {
415                allowed: true,
416                remaining: u32::MAX,
417                reset_after: Duration::ZERO,
418                limit: u32::MAX,
419            });
420        }
421
422        // Get applicable limits
423        let (rpm, burst) = self.get_limits(client_id, method);
424
425        // Create key for this check
426        let key = RateLimitKey {
427            client_id: client_id.to_string(),
428            method: method.map(String::from),
429        };
430
431        // Get or create bucket
432        let bucket = self.get_or_create_bucket(&key, rpm, burst).await;
433
434        // Try to consume tokens
435        let mut bucket = bucket.lock();
436        let allowed = bucket.try_consume(tokens);
437        let remaining = bucket.tokens as u32;
438        let reset_after = bucket.time_until_available(1.0);
439
440        Ok(RateLimitResult {
441            allowed,
442            remaining,
443            reset_after,
444            limit: rpm,
445        })
446    }
447
448    /// Apply penalty to a client (e.g., for security threats)
449    pub async fn apply_penalty(&self, client_id: &str, factor: f64) -> Result<()> {
450        if !self.config.enabled {
451            return Ok(());
452        }
453
454        let buckets = self.buckets.read().await;
455
456        // Apply penalty to all buckets for this client
457        for (key, bucket) in buckets.iter() {
458            if key.client_id == client_id {
459                bucket.lock().apply_penalty(factor);
460            }
461        }
462
463        Ok(())
464    }
465
466    /// Get rate limit status for a client
467    pub async fn get_status(&self, client_id: &str) -> Result<HashMap<String, RateLimitResult>> {
468        let mut status = HashMap::new();
469
470        if !self.config.enabled {
471            return Ok(status);
472        }
473
474        let buckets = self.buckets.read().await;
475
476        for (key, bucket) in buckets.iter() {
477            if key.client_id == client_id {
478                let bucket = bucket.lock();
479                let method = key.method.as_deref().unwrap_or("default");
480
481                status.insert(
482                    method.to_string(),
483                    RateLimitResult {
484                        allowed: bucket.tokens >= 1.0,
485                        remaining: bucket.tokens as u32,
486                        reset_after: bucket.time_until_available(1.0),
487                        limit: (bucket.refill_rate * 60.0) as u32,
488                    },
489                );
490            }
491        }
492
493        Ok(status)
494    }
495
496    /// Get or create a token bucket
497    async fn get_or_create_bucket(
498        &self,
499        key: &RateLimitKey,
500        rpm: u32,
501        burst: u32,
502    ) -> Arc<Mutex<TokenBucket>> {
503        let mut buckets = self.buckets.write().await;
504
505        buckets
506            .entry(key.clone())
507            .or_insert_with(|| Arc::new(Mutex::new(TokenBucket::new(rpm, burst))))
508            .clone()
509    }
510
511    /// Get applicable rate limits for a client/method
512    fn get_limits(&self, client_id: &str, method: Option<&str>) -> (u32, u32) {
513        // Check client-specific limits first
514        if let Some(client_limit) = self.config.client_limits.get(client_id) {
515            return (client_limit.rpm, client_limit.burst);
516        }
517
518        // Check method-specific limits
519        if let Some(method) = method {
520            if let Some(method_limit) = self.config.method_limits.get(method) {
521                return (method_limit.rpm, method_limit.burst);
522            }
523        }
524
525        // Use defaults
526        (self.config.default_rpm, self.config.burst_capacity)
527    }
528
529    /// Clean up expired buckets
530    async fn cleanup_buckets(buckets: Arc<RwLock<HashMap<RateLimitKey, Arc<Mutex<TokenBucket>>>>>) {
531        let mut buckets = buckets.write().await;
532        let now = Instant::now();
533
534        // Remove buckets that haven't been used in 10 minutes
535        buckets.retain(|_, bucket| {
536            let bucket = bucket.lock();
537            now.duration_since(bucket.last_refill) < Duration::from_secs(600)
538        });
539    }
540
541    /// Create rate limit headers for HTTP responses
542    pub fn create_headers(result: &RateLimitResult) -> HashMap<String, String> {
543        let mut headers = HashMap::new();
544
545        headers.insert("X-RateLimit-Limit".to_string(), result.limit.to_string());
546        headers.insert(
547            "X-RateLimit-Remaining".to_string(),
548            result.remaining.to_string(),
549        );
550        headers.insert(
551            "X-RateLimit-Reset".to_string(),
552            (Instant::now() + result.reset_after)
553                .duration_since(Instant::now())
554                .as_secs()
555                .to_string(),
556        );
557
558        if !result.allowed {
559            headers.insert(
560                "Retry-After".to_string(),
561                result.reset_after.as_secs().to_string(),
562            );
563        }
564
565        headers
566    }
567}
568
#[cfg(test)]
mod tests {
    use super::*;

    /// A full bucket serves its burst, then runs dry, then refills over time.
    #[test]
    fn test_token_bucket() {
        let mut bucket = TokenBucket::new(60, 10);

        // The whole burst is available up front...
        let burst_served = bucket.try_consume(10.0);
        assert!(burst_served);

        // ...and nothing is left right after it.
        let extra_served = bucket.try_consume(1.0);
        assert!(!extra_served);

        // At 60 rpm, about one token comes back after 1.1 seconds.
        let pause = Duration::from_millis(1100);
        std::thread::sleep(pause);
        bucket.refill();
        assert!(bucket.tokens > 0.0);
    }

    /// The limiter lets a burst through and then denies the next request.
    #[tokio::test]
    async fn test_rate_limiter() {
        let limiter = RateLimiter::new(RateLimitConfig {
            enabled: true,
            default_rpm: 60,
            burst_capacity: 10,
            ..Default::default()
        });

        // Ten requests fit into the burst capacity.
        let mut served = 0;
        while served < 10 {
            let outcome = limiter.check_limit("test-client", None, 1.0).await.unwrap();
            assert!(outcome.allowed);
            served += 1;
        }

        // The eleventh is rejected and told to wait.
        let outcome = limiter.check_limit("test-client", None, 1.0).await.unwrap();
        assert!(!outcome.allowed);
        assert!(outcome.reset_after > Duration::ZERO);
    }

    /// A penalty of 0.5 halves the bucket's refill factor.
    #[test]
    fn test_penalty_application() {
        let mut bucket = TokenBucket::new(60, 10);

        bucket.apply_penalty(0.5);

        assert_eq!(bucket.penalty_factor, 0.5);
    }
}