codex_memory/mcp_server/
rate_limiter.rs

1//! MCP Rate Limiting System
2//!
3//! This module provides configurable rate limiting for MCP requests,
4//! supporting per-client, per-tool, and global rate limits.
5
6use crate::mcp_server::auth::AuthContext;
7use crate::security::{audit::AuditLogger, SecurityError};
8use anyhow::Result;
9use governor::{
10    clock::DefaultClock,
11    middleware::NoOpMiddleware,
12    state::{InMemoryState, NotKeyed},
13    Quota, RateLimiter as GovernorRateLimiter,
14};
15use nonzero_ext::nonzero;
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::env;
19use std::num::NonZeroU32;
20use std::sync::Arc;
21use tokio::sync::RwLock;
22use tracing::{debug, warn};
23
24/// Rate limiting configuration for MCP
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct MCPRateLimitConfig {
27    pub enabled: bool,
28    pub global_requests_per_minute: u32,
29    pub global_burst_size: u32,
30    pub per_client_requests_per_minute: u32,
31    pub per_client_burst_size: u32,
32    pub per_tool_requests_per_minute: HashMap<String, u32>,
33    pub per_tool_burst_size: HashMap<String, u32>,
34    pub silent_mode_multiplier: f64,
35    pub whitelist_clients: Vec<String>,
36    pub performance_target_ms: u64,
37}
38
39impl Default for MCPRateLimitConfig {
40    fn default() -> Self {
41        Self {
42            enabled: env::var("MCP_RATE_LIMIT_ENABLED")
43                .map(|s| s.parse().unwrap_or(true))
44                .unwrap_or(true),
45            global_requests_per_minute: env::var("MCP_GLOBAL_RATE_LIMIT")
46                .ok()
47                .and_then(|s| s.parse().ok())
48                .unwrap_or(1000),
49            global_burst_size: env::var("MCP_GLOBAL_BURST_SIZE")
50                .ok()
51                .and_then(|s| s.parse().ok())
52                .unwrap_or(50),
53            per_client_requests_per_minute: env::var("MCP_CLIENT_RATE_LIMIT")
54                .ok()
55                .and_then(|s| s.parse().ok())
56                .unwrap_or(100),
57            per_client_burst_size: env::var("MCP_CLIENT_BURST_SIZE")
58                .ok()
59                .and_then(|s| s.parse().ok())
60                .unwrap_or(10),
61            per_tool_requests_per_minute: Self::load_tool_rates_from_env(),
62            per_tool_burst_size: Self::load_tool_bursts_from_env(),
63            silent_mode_multiplier: env::var("MCP_SILENT_MODE_MULTIPLIER")
64                .ok()
65                .and_then(|s| s.parse().ok())
66                .unwrap_or(0.5), // Reduce limits by 50% in silent mode
67            whitelist_clients: env::var("MCP_RATE_LIMIT_WHITELIST")
68                .map(|s| s.split(',').map(|c| c.trim().to_string()).collect())
69                .unwrap_or_default(),
70            performance_target_ms: 5, // Must be <5ms per requirement
71        }
72    }
73}
74
75impl MCPRateLimitConfig {
76    /// Load per-tool rate limits from environment variables
77    fn load_tool_rates_from_env() -> HashMap<String, u32> {
78        let mut rates = HashMap::new();
79
80        // Default tool-specific rates
81        rates.insert("store_memory".to_string(), 50);
82        rates.insert("search_memory".to_string(), 200);
83        rates.insert("get_statistics".to_string(), 20);
84        rates.insert("what_did_you_remember".to_string(), 30);
85        rates.insert("harvest_conversation".to_string(), 100);
86        rates.insert("get_harvester_metrics".to_string(), 10);
87        rates.insert("migrate_memory".to_string(), 20);
88        rates.insert("delete_memory".to_string(), 10);
89
90        // Load custom rates from environment
91        if let Ok(custom_rates) = env::var("MCP_TOOL_RATE_LIMITS") {
92            if let Ok(parsed) = serde_json::from_str::<HashMap<String, u32>>(&custom_rates) {
93                rates.extend(parsed);
94            }
95        }
96
97        rates
98    }
99
100    /// Load per-tool burst sizes from environment variables
101    fn load_tool_bursts_from_env() -> HashMap<String, u32> {
102        let mut bursts = HashMap::new();
103
104        // Default tool-specific burst sizes
105        bursts.insert("store_memory".to_string(), 5);
106        bursts.insert("search_memory".to_string(), 20);
107        bursts.insert("get_statistics".to_string(), 2);
108        bursts.insert("what_did_you_remember".to_string(), 3);
109        bursts.insert("harvest_conversation".to_string(), 10);
110        bursts.insert("get_harvester_metrics".to_string(), 1);
111        bursts.insert("migrate_memory".to_string(), 2);
112        bursts.insert("delete_memory".to_string(), 1);
113
114        // Load custom burst sizes from environment
115        if let Ok(custom_bursts) = env::var("MCP_TOOL_BURST_SIZES") {
116            if let Ok(parsed) = serde_json::from_str::<HashMap<String, u32>>(&custom_bursts) {
117                bursts.extend(parsed);
118            }
119        }
120
121        bursts
122    }
123
124    /// Create configuration from environment variables
125    pub fn from_env() -> Self {
126        Self::default()
127    }
128}
129
130/// Rate limiting statistics
131#[derive(Debug, Clone, Serialize)]
132pub struct RateLimitStats {
133    pub total_requests: u64,
134    pub rejected_requests: u64,
135    pub rejection_rate: f64,
136    pub per_client_rejections: HashMap<String, u64>,
137    pub per_tool_rejections: HashMap<String, u64>,
138    pub avg_check_duration_ms: f64,
139    pub peak_requests_per_minute: u64,
140}
141
142/// Individual rate limiter for a specific scope
143pub struct ScopedRateLimiter {
144    limiter: Arc<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock, NoOpMiddleware>>,
145    requests_per_minute: u32,
146    burst_size: u32,
147    name: String,
148}
149
150impl ScopedRateLimiter {
151    fn new(requests_per_minute: u32, burst_size: u32, name: String) -> Self {
152        let rate = NonZeroU32::new(requests_per_minute).unwrap_or(nonzero!(1u32));
153        let burst = NonZeroU32::new(burst_size).unwrap_or(nonzero!(1u32));
154        let quota = Quota::per_minute(rate).allow_burst(burst);
155        let limiter = Arc::new(GovernorRateLimiter::direct(quota));
156
157        Self {
158            limiter,
159            requests_per_minute,
160            burst_size,
161            name,
162        }
163    }
164
165    fn check_rate_limit(&self) -> Result<(), governor::NotUntil<governor::clock::QuantaInstant>> {
166        self.limiter.check()
167    }
168}
169
170/// MCP Rate Limiter implementation
171pub struct MCPRateLimiter {
172    config: MCPRateLimitConfig,
173    global_limiter: Option<ScopedRateLimiter>,
174    client_limiters: Arc<RwLock<HashMap<String, ScopedRateLimiter>>>,
175    tool_limiters: HashMap<String, ScopedRateLimiter>,
176    stats: Arc<RwLock<RateLimitStats>>,
177    audit_logger: Arc<AuditLogger>,
178}
179
180impl MCPRateLimiter {
181    /// Create a new rate limiter
182    pub fn new(config: MCPRateLimitConfig, audit_logger: Arc<AuditLogger>) -> Self {
183        let global_limiter = if config.enabled {
184            Some(ScopedRateLimiter::new(
185                config.global_requests_per_minute,
186                config.global_burst_size,
187                "global".to_string(),
188            ))
189        } else {
190            None
191        };
192
193        let mut tool_limiters = HashMap::new();
194        if config.enabled {
195            for (tool_name, &rate) in &config.per_tool_requests_per_minute {
196                let burst = config
197                    .per_tool_burst_size
198                    .get(tool_name)
199                    .copied()
200                    .unwrap_or(rate / 10); // Default burst is 10% of rate
201
202                tool_limiters.insert(
203                    tool_name.clone(),
204                    ScopedRateLimiter::new(rate, burst, format!("tool:{tool_name}")),
205                );
206            }
207        }
208
209        let stats = Arc::new(RwLock::new(RateLimitStats {
210            total_requests: 0,
211            rejected_requests: 0,
212            rejection_rate: 0.0,
213            per_client_rejections: HashMap::new(),
214            per_tool_rejections: HashMap::new(),
215            avg_check_duration_ms: 0.0,
216            peak_requests_per_minute: 0,
217        }));
218
219        Self {
220            config,
221            global_limiter,
222            client_limiters: Arc::new(RwLock::new(HashMap::new())),
223            tool_limiters,
224            stats,
225            audit_logger,
226        }
227    }
228
229    /// Check rate limits for an MCP request
230    pub async fn check_rate_limit(
231        &self,
232        auth_context: Option<&AuthContext>,
233        tool_name: &str,
234        silent_mode: bool,
235    ) -> Result<()> {
236        let start_time = std::time::Instant::now();
237
238        if !self.config.enabled {
239            return Ok(());
240        }
241
242        let client_id = auth_context
243            .map(|ctx| ctx.client_id.as_str())
244            .unwrap_or("anonymous");
245
246        // Check if client is whitelisted
247        if self
248            .config
249            .whitelist_clients
250            .contains(&client_id.to_string())
251        {
252            debug!("Client {} is whitelisted, skipping rate limits", client_id);
253            return Ok(());
254        }
255
256        // Update statistics
257        {
258            let mut stats = self.stats.write().await;
259            stats.total_requests += 1;
260        }
261
262        // Apply silent mode multiplier if needed
263        let rate_multiplier = if silent_mode {
264            self.config.silent_mode_multiplier
265        } else {
266            1.0
267        };
268
269        // Check global rate limit
270        if let Some(ref global_limiter) = self.global_limiter {
271            if global_limiter.check_rate_limit().is_err() {
272                self.handle_rate_limit_violation("global", client_id, tool_name)
273                    .await;
274                return Err(SecurityError::RateLimitExceeded.into());
275            }
276        }
277
278        // Check per-client rate limit
279        let client_limiter = self
280            .get_or_create_client_limiter(client_id, rate_multiplier)
281            .await;
282        if client_limiter.check_rate_limit().is_err() {
283            self.handle_rate_limit_violation("client", client_id, tool_name)
284                .await;
285            return Err(SecurityError::RateLimitExceeded.into());
286        }
287
288        // Check per-tool rate limit
289        if let Some(tool_limiter) = self.tool_limiters.get(tool_name) {
290            if tool_limiter.check_rate_limit().is_err() {
291                self.handle_rate_limit_violation("tool", client_id, tool_name)
292                    .await;
293                return Err(SecurityError::RateLimitExceeded.into());
294            }
295        }
296
297        let elapsed = start_time.elapsed();
298
299        // Check performance requirement
300        if elapsed.as_millis() > self.config.performance_target_ms as u128 {
301            warn!(
302                "Rate limit check took {}ms, exceeding target of {}ms",
303                elapsed.as_millis(),
304                self.config.performance_target_ms
305            );
306        }
307
308        // Update average check duration
309        {
310            let mut stats = self.stats.write().await;
311            let total_ms = stats.avg_check_duration_ms * (stats.total_requests - 1) as f64;
312            stats.avg_check_duration_ms =
313                (total_ms + elapsed.as_millis() as f64) / stats.total_requests as f64;
314        }
315
316        debug!(
317            "Rate limit check passed for client: {}, tool: {}",
318            client_id, tool_name
319        );
320        Ok(())
321    }
322
323    /// Get or create a client-specific rate limiter
324    async fn get_or_create_client_limiter(
325        &self,
326        client_id: &str,
327        rate_multiplier: f64,
328    ) -> ScopedRateLimiter {
329        {
330            let limiters = self.client_limiters.read().await;
331            if let Some(limiter) = limiters.get(client_id) {
332                return ScopedRateLimiter {
333                    limiter: limiter.limiter.clone(),
334                    requests_per_minute: limiter.requests_per_minute,
335                    burst_size: limiter.burst_size,
336                    name: limiter.name.clone(),
337                };
338            }
339        }
340
341        // Create new limiter for this client
342        let adjusted_rate =
343            (self.config.per_client_requests_per_minute as f64 * rate_multiplier) as u32;
344        let adjusted_burst = (self.config.per_client_burst_size as f64 * rate_multiplier) as u32;
345
346        let limiter = ScopedRateLimiter::new(
347            adjusted_rate.max(1),
348            adjusted_burst.max(1),
349            format!("client:{client_id}"),
350        );
351
352        // Store the limiter for future use
353        {
354            let mut limiters = self.client_limiters.write().await;
355            limiters.insert(
356                client_id.to_string(),
357                ScopedRateLimiter {
358                    limiter: limiter.limiter.clone(),
359                    requests_per_minute: limiter.requests_per_minute,
360                    burst_size: limiter.burst_size,
361                    name: limiter.name.clone(),
362                },
363            );
364        }
365
366        limiter
367    }
368
369    /// Handle rate limit violations
370    async fn handle_rate_limit_violation(
371        &self,
372        limit_type: &str,
373        client_id: &str,
374        tool_name: &str,
375    ) {
376        warn!(
377            "Rate limit violation - Type: {}, Client: {}, Tool: {}",
378            limit_type, client_id, tool_name
379        );
380
381        // Update rejection statistics
382        {
383            let mut stats = self.stats.write().await;
384            stats.rejected_requests += 1;
385            stats.rejection_rate = stats.rejected_requests as f64 / stats.total_requests as f64;
386
387            *stats
388                .per_client_rejections
389                .entry(client_id.to_string())
390                .or_insert(0) += 1;
391            *stats
392                .per_tool_rejections
393                .entry(tool_name.to_string())
394                .or_insert(0) += 1;
395        }
396
397        // Log the violation for security auditing
398        self.audit_logger
399            .log_rate_limit_violation(client_id, tool_name, limit_type)
400            .await;
401    }
402
403    /// Reset rate limits for a specific client (admin function)
404    pub async fn reset_client_limits(&self, client_id: &str) -> Result<()> {
405        let mut limiters = self.client_limiters.write().await;
406        limiters.remove(client_id);
407
408        // Create a fresh limiter with default rates to ensure the Governor state is reset
409        let fresh_limiter = ScopedRateLimiter::new(
410            self.config.per_client_requests_per_minute,
411            self.config.per_client_burst_size,
412            format!("client:{client_id}"),
413        );
414
415        limiters.insert(client_id.to_string(), fresh_limiter);
416        debug!("Reset rate limits for client: {}", client_id);
417        Ok(())
418    }
419
420    /// Get current rate limiting statistics
421    pub async fn get_stats(&self) -> RateLimitStats {
422        self.stats.read().await.clone()
423    }
424
425    /// Update configuration dynamically
426    pub async fn update_config(&mut self, new_config: MCPRateLimitConfig) -> Result<()> {
427        debug!("Updating rate limiter configuration");
428
429        // Update global limiter
430        self.global_limiter = if new_config.enabled {
431            Some(ScopedRateLimiter::new(
432                new_config.global_requests_per_minute,
433                new_config.global_burst_size,
434                "global".to_string(),
435            ))
436        } else {
437            None
438        };
439
440        // Update tool limiters
441        self.tool_limiters.clear();
442        if new_config.enabled {
443            for (tool_name, &rate) in &new_config.per_tool_requests_per_minute {
444                let burst = new_config
445                    .per_tool_burst_size
446                    .get(tool_name)
447                    .copied()
448                    .unwrap_or(rate / 10);
449
450                self.tool_limiters.insert(
451                    tool_name.clone(),
452                    ScopedRateLimiter::new(rate, burst, format!("tool:{tool_name}")),
453                );
454            }
455        }
456
457        // Clear existing client limiters to force recreation with new rates
458        {
459            let mut limiters = self.client_limiters.write().await;
460            limiters.clear();
461        }
462
463        self.config = new_config;
464        Ok(())
465    }
466
467    /// Get rate limiter configuration and status
468    pub async fn get_status(&self) -> serde_json::Value {
469        let stats = self.get_stats().await;
470        let client_count = self.client_limiters.read().await.len();
471
472        serde_json::json!({
473            "enabled": self.config.enabled,
474            "global_limits": {
475                "requests_per_minute": self.config.global_requests_per_minute,
476                "burst_size": self.config.global_burst_size,
477            },
478            "per_client_limits": {
479                "requests_per_minute": self.config.per_client_requests_per_minute,
480                "burst_size": self.config.per_client_burst_size,
481                "active_clients": client_count,
482            },
483            "tool_limits": self.config.per_tool_requests_per_minute,
484            "statistics": stats,
485            "performance": {
486                "target_ms": self.config.performance_target_ms,
487                "avg_check_duration_ms": stats.avg_check_duration_ms,
488            },
489            "silent_mode_multiplier": self.config.silent_mode_multiplier,
490            "whitelist_clients": self.config.whitelist_clients.len(),
491        })
492    }
493}
494
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp_server::auth::AuthMethod;
    use crate::security::AuditConfig;

    /// Small limits so that bursts can be exhausted with a handful of calls.
    fn create_test_config() -> MCPRateLimitConfig {
        MCPRateLimitConfig {
            enabled: true,
            global_requests_per_minute: 60, // 1 per second for testing
            global_burst_size: 5,
            per_client_requests_per_minute: 30, // 0.5 per second for testing
            per_client_burst_size: 3,
            per_tool_requests_per_minute: {
                let mut map = HashMap::new();
                map.insert("store_memory".to_string(), 12); // 0.2 per second
                map.insert("search_memory".to_string(), 60); // 1 per second
                map
            },
            per_tool_burst_size: {
                let mut map = HashMap::new();
                map.insert("store_memory".to_string(), 2);
                map.insert("search_memory".to_string(), 5);
                map
            },
            silent_mode_multiplier: 0.5,
            whitelist_clients: vec!["whitelisted-client".to_string()],
            performance_target_ms: 5,
        }
    }

    /// Audit logger with every logging category switched on.
    fn create_test_audit_logger() -> Arc<AuditLogger> {
        let audit_config = AuditConfig {
            enabled: true,
            log_all_requests: true,
            log_data_access: true,
            log_modifications: true,
            log_auth_events: true,
            retention_days: 30,
        };
        Arc::new(AuditLogger::new(audit_config).unwrap())
    }

    async fn create_test_rate_limiter() -> MCPRateLimiter {
        MCPRateLimiter::new(create_test_config(), create_test_audit_logger())
    }

    fn create_test_auth_context(client_id: &str) -> AuthContext {
        AuthContext {
            client_id: client_id.to_string(),
            user_id: "test-user".to_string(),
            method: AuthMethod::ApiKey,
            scopes: vec!["mcp:read".to_string(), "mcp:write".to_string()],
            expires_at: None,
            request_id: "test-request".to_string(),
        }
    }

    #[tokio::test]
    async fn test_rate_limit_allows_normal_requests() {
        let limiter = create_test_rate_limiter().await;
        let auth_context = create_test_auth_context("test-client");

        // Should allow normal requests
        let result = limiter
            .check_rate_limit(Some(&auth_context), "search_memory", false)
            .await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_rate_limit_blocks_excessive_requests() {
        let limiter = create_test_rate_limiter().await;
        let auth_context = create_test_auth_context("test-client");

        // Exhaust the burst limit for store_memory (2 requests)
        assert!(limiter
            .check_rate_limit(Some(&auth_context), "store_memory", false)
            .await
            .is_ok());
        assert!(limiter
            .check_rate_limit(Some(&auth_context), "store_memory", false)
            .await
            .is_ok());

        // Third request should be rate limited
        let result = limiter
            .check_rate_limit(Some(&auth_context), "store_memory", false)
            .await;
        assert!(result.is_err());

        // Check that it's specifically a rate limit error
        let error = result.unwrap_err();
        assert!(error.to_string().contains("Rate limit exceeded"));
    }

    #[tokio::test]
    async fn test_different_clients_have_separate_limits() {
        let limiter = create_test_rate_limiter().await;
        let auth_context1 = create_test_auth_context("client-1");
        let auth_context2 = create_test_auth_context("client-2");

        // Exhaust client-1's limits
        for _ in 0..3 {
            let result = limiter
                .check_rate_limit(Some(&auth_context1), "search_memory", false)
                .await;
            if result.is_err() {
                break;
            }
        }

        // client-2 should still be able to make requests
        let result = limiter
            .check_rate_limit(Some(&auth_context2), "search_memory", false)
            .await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_whitelisted_clients_bypass_limits() {
        let limiter = create_test_rate_limiter().await;
        let auth_context = create_test_auth_context("whitelisted-client");

        // Should be able to make many requests without being rate limited
        for _ in 0..10 {
            let result = limiter
                .check_rate_limit(Some(&auth_context), "store_memory", false)
                .await;
            assert!(result.is_ok());
        }
    }

    #[tokio::test]
    async fn test_silent_mode_reduces_limits() {
        let limiter = create_test_rate_limiter().await;
        let auth_context = create_test_auth_context("test-client");

        // The multiplier (0.5) scales the per-client limits when the client's
        // limiter is created: burst 3 -> 1 and rate 30 -> 15 per minute.
        assert!(limiter
            .check_rate_limit(Some(&auth_context), "store_memory", true)
            .await
            .is_ok());

        // With a client burst of 1, an immediate second request is refused
        // even though the global and store_memory bursts still have room.
        let result = limiter
            .check_rate_limit(Some(&auth_context), "store_memory", true)
            .await;
        assert!(
            result.is_err(),
            "second silent-mode request should exceed the reduced client burst"
        );
    }

    #[tokio::test]
    async fn test_disabled_rate_limiting() {
        let mut config = create_test_config();
        config.enabled = false;

        let limiter = MCPRateLimiter::new(config, create_test_audit_logger());
        let auth_context = create_test_auth_context("test-client");

        // Should allow unlimited requests when disabled
        for _ in 0..20 {
            let result = limiter
                .check_rate_limit(Some(&auth_context), "store_memory", false)
                .await;
            assert!(result.is_ok());
        }
    }

    #[tokio::test]
    async fn test_statistics_tracking() {
        let limiter = create_test_rate_limiter().await;
        let auth_context = create_test_auth_context("test-client");

        // Make some requests
        let _ = limiter
            .check_rate_limit(Some(&auth_context), "search_memory", false)
            .await;
        let _ = limiter
            .check_rate_limit(Some(&auth_context), "search_memory", false)
            .await;

        let stats = limiter.get_stats().await;
        assert_eq!(stats.total_requests, 2);
        assert!(stats.avg_check_duration_ms >= 0.0);
    }

    #[tokio::test]
    async fn test_client_limit_reset() {
        let limiter = create_test_rate_limiter().await;
        let auth_context = create_test_auth_context("test-client");

        // Exhaust only the client limits (3 requests = client burst size)
        // Use a tool that has higher limits than client limits to avoid tool limit conflicts
        for _ in 0..3 {
            let result = limiter
                .check_rate_limit(Some(&auth_context), "search_memory", false)
                .await;
            // First 3 should succeed due to burst
            assert!(result.is_ok());
        }

        // 4th request should fail due to client rate limit
        let result = limiter
            .check_rate_limit(Some(&auth_context), "search_memory", false)
            .await;
        assert!(
            result.is_err(),
            "4th request should be rate limited by client limits"
        );

        // Reset limits for this client
        limiter.reset_client_limits("test-client").await.unwrap();

        // Should be able to make requests again after reset
        let result = limiter
            .check_rate_limit(Some(&auth_context), "search_memory", false)
            .await;
        assert!(
            result.is_ok(),
            "Request should succeed after client limit reset"
        );
    }
}