1use crate::mcp_server::auth::AuthContext;
7use crate::security::{audit::AuditLogger, SecurityError};
8use anyhow::Result;
9use governor::{
10 clock::DefaultClock,
11 middleware::NoOpMiddleware,
12 state::{InMemoryState, NotKeyed},
13 Quota, RateLimiter as GovernorRateLimiter,
14};
15use nonzero_ext::nonzero;
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::env;
19use std::num::NonZeroU32;
20use std::sync::Arc;
21use tokio::sync::RwLock;
22use tracing::{debug, warn};
23
/// Configuration for MCP request rate limiting.
///
/// The `Default` implementation reads most fields from `MCP_*` environment
/// variables and falls back to built-in values when a variable is unset or
/// cannot be parsed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPRateLimitConfig {
    /// Master switch. When `false`, `MCPRateLimiter` builds no global or tool
    /// limiters and `check_rate_limit` returns `Ok(())` immediately.
    pub enabled: bool,
    /// Sustained quota of the single limiter shared by all clients and tools.
    pub global_requests_per_minute: u32,
    /// Burst allowance of the global limiter.
    pub global_burst_size: u32,
    /// Sustained quota given to each client's limiter.
    pub per_client_requests_per_minute: u32,
    /// Burst allowance given to each client's limiter.
    pub per_client_burst_size: u32,
    /// Tool name -> requests per minute. Tools without an entry get no
    /// tool-level limiter.
    pub per_tool_requests_per_minute: HashMap<String, u32>,
    /// Tool name -> burst size. A tool that has a rate but no burst entry uses
    /// `rate / 10`.
    pub per_tool_burst_size: HashMap<String, u32>,
    /// Factor applied to the per-client rate and burst when a client limiter is
    /// created for a request made in silent mode.
    pub silent_mode_multiplier: f64,
    /// Client ids that bypass every rate limit.
    pub whitelist_clients: Vec<String>,
    /// A warning is logged when a single check takes longer than this many
    /// milliseconds.
    pub performance_target_ms: u64,
}
38
39impl Default for MCPRateLimitConfig {
40 fn default() -> Self {
41 Self {
42 enabled: env::var("MCP_RATE_LIMIT_ENABLED")
43 .map(|s| s.parse().unwrap_or(true))
44 .unwrap_or(true),
45 global_requests_per_minute: env::var("MCP_GLOBAL_RATE_LIMIT")
46 .ok()
47 .and_then(|s| s.parse().ok())
48 .unwrap_or(1000),
49 global_burst_size: env::var("MCP_GLOBAL_BURST_SIZE")
50 .ok()
51 .and_then(|s| s.parse().ok())
52 .unwrap_or(50),
53 per_client_requests_per_minute: env::var("MCP_CLIENT_RATE_LIMIT")
54 .ok()
55 .and_then(|s| s.parse().ok())
56 .unwrap_or(100),
57 per_client_burst_size: env::var("MCP_CLIENT_BURST_SIZE")
58 .ok()
59 .and_then(|s| s.parse().ok())
60 .unwrap_or(10),
61 per_tool_requests_per_minute: Self::load_tool_rates_from_env(),
62 per_tool_burst_size: Self::load_tool_bursts_from_env(),
63 silent_mode_multiplier: env::var("MCP_SILENT_MODE_MULTIPLIER")
64 .ok()
65 .and_then(|s| s.parse().ok())
66 .unwrap_or(0.5), whitelist_clients: env::var("MCP_RATE_LIMIT_WHITELIST")
68 .map(|s| s.split(',').map(|c| c.trim().to_string()).collect())
69 .unwrap_or_default(),
70 performance_target_ms: 5, }
72 }
73}
74
75impl MCPRateLimitConfig {
76 fn load_tool_rates_from_env() -> HashMap<String, u32> {
78 let mut rates = HashMap::new();
79
80 rates.insert("store_memory".to_string(), 50);
82 rates.insert("search_memory".to_string(), 200);
83 rates.insert("get_statistics".to_string(), 20);
84 rates.insert("what_did_you_remember".to_string(), 30);
85 rates.insert("harvest_conversation".to_string(), 100);
86 rates.insert("get_harvester_metrics".to_string(), 10);
87 rates.insert("migrate_memory".to_string(), 20);
88 rates.insert("delete_memory".to_string(), 10);
89
90 if let Ok(custom_rates) = env::var("MCP_TOOL_RATE_LIMITS") {
92 if let Ok(parsed) = serde_json::from_str::<HashMap<String, u32>>(&custom_rates) {
93 rates.extend(parsed);
94 }
95 }
96
97 rates
98 }
99
100 fn load_tool_bursts_from_env() -> HashMap<String, u32> {
102 let mut bursts = HashMap::new();
103
104 bursts.insert("store_memory".to_string(), 5);
106 bursts.insert("search_memory".to_string(), 20);
107 bursts.insert("get_statistics".to_string(), 2);
108 bursts.insert("what_did_you_remember".to_string(), 3);
109 bursts.insert("harvest_conversation".to_string(), 10);
110 bursts.insert("get_harvester_metrics".to_string(), 1);
111 bursts.insert("migrate_memory".to_string(), 2);
112 bursts.insert("delete_memory".to_string(), 1);
113
114 if let Ok(custom_bursts) = env::var("MCP_TOOL_BURST_SIZES") {
116 if let Ok(parsed) = serde_json::from_str::<HashMap<String, u32>>(&custom_bursts) {
117 bursts.extend(parsed);
118 }
119 }
120
121 bursts
122 }
123
124 pub fn from_env() -> Self {
126 Self::default()
127 }
128}
129
/// Counters and timing information collected by `MCPRateLimiter`.
#[derive(Debug, Clone, Serialize)]
pub struct RateLimitStats {
    /// Number of checks performed while limiting was enabled, excluding
    /// whitelisted clients.
    pub total_requests: u64,
    /// Number of checks rejected by any limiter.
    pub rejected_requests: u64,
    /// `rejected_requests / total_requests`, as stored when the most recent
    /// rejection was recorded.
    pub rejection_rate: f64,
    /// Rejection count keyed by client id.
    pub per_client_rejections: HashMap<String, u64>,
    /// Rejection count keyed by tool name.
    pub per_tool_rejections: HashMap<String, u64>,
    /// Running average duration of a check, updated when a check passes.
    pub avg_check_duration_ms: f64,
    /// Initialised to 0; nothing in this file updates it.
    pub peak_requests_per_minute: u64,
}
141
/// A single governor rate limiter together with the settings it was built from.
pub struct ScopedRateLimiter {
    /// The shared limiter state; cloning the `Arc` yields a handle onto the
    /// same quota.
    limiter: Arc<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock, NoOpMiddleware>>,
    /// Requests per minute exactly as passed to `new` (before zero is replaced
    /// by one for the quota).
    requests_per_minute: u32,
    /// Burst size exactly as passed to `new` (before zero is replaced by one
    /// for the quota).
    burst_size: u32,
    /// Label for the scope, e.g. `"global"`, `"client:<id>"` or `"tool:<name>"`.
    name: String,
}
149
150impl ScopedRateLimiter {
151 fn new(requests_per_minute: u32, burst_size: u32, name: String) -> Self {
152 let rate = NonZeroU32::new(requests_per_minute).unwrap_or(nonzero!(1u32));
153 let burst = NonZeroU32::new(burst_size).unwrap_or(nonzero!(1u32));
154 let quota = Quota::per_minute(rate).allow_burst(burst);
155 let limiter = Arc::new(GovernorRateLimiter::direct(quota));
156
157 Self {
158 limiter,
159 requests_per_minute,
160 burst_size,
161 name,
162 }
163 }
164
165 fn check_rate_limit(&self) -> Result<(), governor::NotUntil<governor::clock::QuantaInstant>> {
166 self.limiter.check()
167 }
168}
169
/// Rate limiter applying global, per-client and per-tool limits to MCP requests.
pub struct MCPRateLimiter {
    /// Active configuration.
    config: MCPRateLimitConfig,
    /// Limiter shared by every request; `None` when limiting is disabled.
    global_limiter: Option<ScopedRateLimiter>,
    /// Limiters keyed by client id, created on a client's first checked request.
    client_limiters: Arc<RwLock<HashMap<String, ScopedRateLimiter>>>,
    /// Limiters keyed by tool name, built from the configuration up front.
    tool_limiters: HashMap<String, ScopedRateLimiter>,
    /// Request and rejection statistics.
    stats: Arc<RwLock<RateLimitStats>>,
    /// Receives a record of every rate limit violation.
    audit_logger: Arc<AuditLogger>,
}
179
180impl MCPRateLimiter {
181 pub fn new(config: MCPRateLimitConfig, audit_logger: Arc<AuditLogger>) -> Self {
183 let global_limiter = if config.enabled {
184 Some(ScopedRateLimiter::new(
185 config.global_requests_per_minute,
186 config.global_burst_size,
187 "global".to_string(),
188 ))
189 } else {
190 None
191 };
192
193 let mut tool_limiters = HashMap::new();
194 if config.enabled {
195 for (tool_name, &rate) in &config.per_tool_requests_per_minute {
196 let burst = config
197 .per_tool_burst_size
198 .get(tool_name)
199 .copied()
200 .unwrap_or(rate / 10); tool_limiters.insert(
203 tool_name.clone(),
204 ScopedRateLimiter::new(rate, burst, format!("tool:{tool_name}")),
205 );
206 }
207 }
208
209 let stats = Arc::new(RwLock::new(RateLimitStats {
210 total_requests: 0,
211 rejected_requests: 0,
212 rejection_rate: 0.0,
213 per_client_rejections: HashMap::new(),
214 per_tool_rejections: HashMap::new(),
215 avg_check_duration_ms: 0.0,
216 peak_requests_per_minute: 0,
217 }));
218
219 Self {
220 config,
221 global_limiter,
222 client_limiters: Arc::new(RwLock::new(HashMap::new())),
223 tool_limiters,
224 stats,
225 audit_logger,
226 }
227 }
228
229 pub async fn check_rate_limit(
231 &self,
232 auth_context: Option<&AuthContext>,
233 tool_name: &str,
234 silent_mode: bool,
235 ) -> Result<()> {
236 let start_time = std::time::Instant::now();
237
238 if !self.config.enabled {
239 return Ok(());
240 }
241
242 let client_id = auth_context
243 .map(|ctx| ctx.client_id.as_str())
244 .unwrap_or("anonymous");
245
246 if self
248 .config
249 .whitelist_clients
250 .contains(&client_id.to_string())
251 {
252 debug!("Client {} is whitelisted, skipping rate limits", client_id);
253 return Ok(());
254 }
255
256 {
258 let mut stats = self.stats.write().await;
259 stats.total_requests += 1;
260 }
261
262 let rate_multiplier = if silent_mode {
264 self.config.silent_mode_multiplier
265 } else {
266 1.0
267 };
268
269 if let Some(ref global_limiter) = self.global_limiter {
271 if global_limiter.check_rate_limit().is_err() {
272 self.handle_rate_limit_violation("global", client_id, tool_name)
273 .await;
274 return Err(SecurityError::RateLimitExceeded.into());
275 }
276 }
277
278 let client_limiter = self
280 .get_or_create_client_limiter(client_id, rate_multiplier)
281 .await;
282 if client_limiter.check_rate_limit().is_err() {
283 self.handle_rate_limit_violation("client", client_id, tool_name)
284 .await;
285 return Err(SecurityError::RateLimitExceeded.into());
286 }
287
288 if let Some(tool_limiter) = self.tool_limiters.get(tool_name) {
290 if tool_limiter.check_rate_limit().is_err() {
291 self.handle_rate_limit_violation("tool", client_id, tool_name)
292 .await;
293 return Err(SecurityError::RateLimitExceeded.into());
294 }
295 }
296
297 let elapsed = start_time.elapsed();
298
299 if elapsed.as_millis() > self.config.performance_target_ms as u128 {
301 warn!(
302 "Rate limit check took {}ms, exceeding target of {}ms",
303 elapsed.as_millis(),
304 self.config.performance_target_ms
305 );
306 }
307
308 {
310 let mut stats = self.stats.write().await;
311 let total_ms = stats.avg_check_duration_ms * (stats.total_requests - 1) as f64;
312 stats.avg_check_duration_ms =
313 (total_ms + elapsed.as_millis() as f64) / stats.total_requests as f64;
314 }
315
316 debug!(
317 "Rate limit check passed for client: {}, tool: {}",
318 client_id, tool_name
319 );
320 Ok(())
321 }
322
323 async fn get_or_create_client_limiter(
325 &self,
326 client_id: &str,
327 rate_multiplier: f64,
328 ) -> ScopedRateLimiter {
329 {
330 let limiters = self.client_limiters.read().await;
331 if let Some(limiter) = limiters.get(client_id) {
332 return ScopedRateLimiter {
333 limiter: limiter.limiter.clone(),
334 requests_per_minute: limiter.requests_per_minute,
335 burst_size: limiter.burst_size,
336 name: limiter.name.clone(),
337 };
338 }
339 }
340
341 let adjusted_rate =
343 (self.config.per_client_requests_per_minute as f64 * rate_multiplier) as u32;
344 let adjusted_burst = (self.config.per_client_burst_size as f64 * rate_multiplier) as u32;
345
346 let limiter = ScopedRateLimiter::new(
347 adjusted_rate.max(1),
348 adjusted_burst.max(1),
349 format!("client:{client_id}"),
350 );
351
352 {
354 let mut limiters = self.client_limiters.write().await;
355 limiters.insert(
356 client_id.to_string(),
357 ScopedRateLimiter {
358 limiter: limiter.limiter.clone(),
359 requests_per_minute: limiter.requests_per_minute,
360 burst_size: limiter.burst_size,
361 name: limiter.name.clone(),
362 },
363 );
364 }
365
366 limiter
367 }
368
369 async fn handle_rate_limit_violation(
371 &self,
372 limit_type: &str,
373 client_id: &str,
374 tool_name: &str,
375 ) {
376 warn!(
377 "Rate limit violation - Type: {}, Client: {}, Tool: {}",
378 limit_type, client_id, tool_name
379 );
380
381 {
383 let mut stats = self.stats.write().await;
384 stats.rejected_requests += 1;
385 stats.rejection_rate = stats.rejected_requests as f64 / stats.total_requests as f64;
386
387 *stats
388 .per_client_rejections
389 .entry(client_id.to_string())
390 .or_insert(0) += 1;
391 *stats
392 .per_tool_rejections
393 .entry(tool_name.to_string())
394 .or_insert(0) += 1;
395 }
396
397 self.audit_logger
399 .log_rate_limit_violation(client_id, tool_name, limit_type)
400 .await;
401 }
402
403 pub async fn reset_client_limits(&self, client_id: &str) -> Result<()> {
405 let mut limiters = self.client_limiters.write().await;
406 limiters.remove(client_id);
407
408 let fresh_limiter = ScopedRateLimiter::new(
410 self.config.per_client_requests_per_minute,
411 self.config.per_client_burst_size,
412 format!("client:{client_id}"),
413 );
414
415 limiters.insert(client_id.to_string(), fresh_limiter);
416 debug!("Reset rate limits for client: {}", client_id);
417 Ok(())
418 }
419
420 pub async fn get_stats(&self) -> RateLimitStats {
422 self.stats.read().await.clone()
423 }
424
425 pub async fn update_config(&mut self, new_config: MCPRateLimitConfig) -> Result<()> {
427 debug!("Updating rate limiter configuration");
428
429 self.global_limiter = if new_config.enabled {
431 Some(ScopedRateLimiter::new(
432 new_config.global_requests_per_minute,
433 new_config.global_burst_size,
434 "global".to_string(),
435 ))
436 } else {
437 None
438 };
439
440 self.tool_limiters.clear();
442 if new_config.enabled {
443 for (tool_name, &rate) in &new_config.per_tool_requests_per_minute {
444 let burst = new_config
445 .per_tool_burst_size
446 .get(tool_name)
447 .copied()
448 .unwrap_or(rate / 10);
449
450 self.tool_limiters.insert(
451 tool_name.clone(),
452 ScopedRateLimiter::new(rate, burst, format!("tool:{tool_name}")),
453 );
454 }
455 }
456
457 {
459 let mut limiters = self.client_limiters.write().await;
460 limiters.clear();
461 }
462
463 self.config = new_config;
464 Ok(())
465 }
466
467 pub async fn get_status(&self) -> serde_json::Value {
469 let stats = self.get_stats().await;
470 let client_count = self.client_limiters.read().await.len();
471
472 serde_json::json!({
473 "enabled": self.config.enabled,
474 "global_limits": {
475 "requests_per_minute": self.config.global_requests_per_minute,
476 "burst_size": self.config.global_burst_size,
477 },
478 "per_client_limits": {
479 "requests_per_minute": self.config.per_client_requests_per_minute,
480 "burst_size": self.config.per_client_burst_size,
481 "active_clients": client_count,
482 },
483 "tool_limits": self.config.per_tool_requests_per_minute,
484 "statistics": stats,
485 "performance": {
486 "target_ms": self.config.performance_target_ms,
487 "avg_check_duration_ms": stats.avg_check_duration_ms,
488 },
489 "silent_mode_multiplier": self.config.silent_mode_multiplier,
490 "whitelist_clients": self.config.whitelist_clients.len(),
491 })
492 }
493}
494
495#[cfg(test)]
496mod tests {
497 use super::*;
498 use crate::mcp_server::auth::AuthMethod;
499 use crate::security::AuditConfig;
500 use tempfile::tempdir;
501
502 fn create_test_config() -> MCPRateLimitConfig {
503 MCPRateLimitConfig {
504 enabled: true,
505 global_requests_per_minute: 60, global_burst_size: 5,
507 per_client_requests_per_minute: 30, per_client_burst_size: 3,
509 per_tool_requests_per_minute: {
510 let mut map = HashMap::new();
511 map.insert("store_memory".to_string(), 12); map.insert("search_memory".to_string(), 60); map
514 },
515 per_tool_burst_size: {
516 let mut map = HashMap::new();
517 map.insert("store_memory".to_string(), 2);
518 map.insert("search_memory".to_string(), 5);
519 map
520 },
521 silent_mode_multiplier: 0.5,
522 whitelist_clients: vec!["whitelisted-client".to_string()],
523 performance_target_ms: 5,
524 }
525 }
526
527 async fn create_test_rate_limiter() -> MCPRateLimiter {
528 let config = create_test_config();
529 let temp_dir = tempdir().unwrap();
530 let audit_config = AuditConfig {
531 enabled: true,
532 log_all_requests: true,
533 log_data_access: true,
534 log_modifications: true,
535 log_auth_events: true,
536 retention_days: 30,
537 };
538 let audit_logger = Arc::new(AuditLogger::new(audit_config).unwrap());
539 MCPRateLimiter::new(config, audit_logger)
540 }
541
542 fn create_test_auth_context(client_id: &str) -> AuthContext {
543 AuthContext {
544 client_id: client_id.to_string(),
545 user_id: "test-user".to_string(),
546 method: AuthMethod::ApiKey,
547 scopes: vec!["mcp:read".to_string(), "mcp:write".to_string()],
548 expires_at: None,
549 request_id: "test-request".to_string(),
550 }
551 }
552
553 #[tokio::test]
554 async fn test_rate_limit_allows_normal_requests() {
555 let limiter = create_test_rate_limiter().await;
556 let auth_context = create_test_auth_context("test-client");
557
558 let result = limiter
560 .check_rate_limit(Some(&auth_context), "search_memory", false)
561 .await;
562 assert!(result.is_ok());
563 }
564
565 #[tokio::test]
566 async fn test_rate_limit_blocks_excessive_requests() {
567 let limiter = create_test_rate_limiter().await;
568 let auth_context = create_test_auth_context("test-client");
569
570 assert!(limiter
572 .check_rate_limit(Some(&auth_context), "store_memory", false)
573 .await
574 .is_ok());
575 assert!(limiter
576 .check_rate_limit(Some(&auth_context), "store_memory", false)
577 .await
578 .is_ok());
579
580 let result = limiter
582 .check_rate_limit(Some(&auth_context), "store_memory", false)
583 .await;
584 assert!(result.is_err());
585
586 let error = result.unwrap_err();
588 assert!(error.to_string().contains("Rate limit exceeded"));
589 }
590
591 #[tokio::test]
592 async fn test_different_clients_have_separate_limits() {
593 let limiter = create_test_rate_limiter().await;
594 let auth_context1 = create_test_auth_context("client-1");
595 let auth_context2 = create_test_auth_context("client-2");
596
597 for _ in 0..3 {
599 let result = limiter
600 .check_rate_limit(Some(&auth_context1), "search_memory", false)
601 .await;
602 if result.is_err() {
603 break;
604 }
605 }
606
607 let result = limiter
609 .check_rate_limit(Some(&auth_context2), "search_memory", false)
610 .await;
611 assert!(result.is_ok());
612 }
613
614 #[tokio::test]
615 async fn test_whitelisted_clients_bypass_limits() {
616 let limiter = create_test_rate_limiter().await;
617 let auth_context = create_test_auth_context("whitelisted-client");
618
619 for _ in 0..10 {
621 let result = limiter
622 .check_rate_limit(Some(&auth_context), "store_memory", false)
623 .await;
624 assert!(result.is_ok());
625 }
626 }
627
628 #[tokio::test]
629 async fn test_silent_mode_reduces_limits() {
630 let limiter = create_test_rate_limiter().await;
631 let auth_context = create_test_auth_context("test-client");
632
633 assert!(limiter
636 .check_rate_limit(Some(&auth_context), "store_memory", true)
637 .await
638 .is_ok());
639
640 let result = limiter
642 .check_rate_limit(Some(&auth_context), "store_memory", true)
643 .await;
644 }
647
648 #[tokio::test]
649 async fn test_disabled_rate_limiting() {
650 let mut config = create_test_config();
651 config.enabled = false;
652
653 let temp_dir = tempdir().unwrap();
654 let audit_config = AuditConfig {
655 enabled: true,
656 log_all_requests: true,
657 log_data_access: true,
658 log_modifications: true,
659 log_auth_events: true,
660 retention_days: 30,
661 };
662 let audit_logger = Arc::new(AuditLogger::new(audit_config).unwrap());
663 let limiter = MCPRateLimiter::new(config, audit_logger);
664
665 let auth_context = create_test_auth_context("test-client");
666
667 for _ in 0..20 {
669 let result = limiter
670 .check_rate_limit(Some(&auth_context), "store_memory", false)
671 .await;
672 assert!(result.is_ok());
673 }
674 }
675
676 #[tokio::test]
677 async fn test_statistics_tracking() {
678 let limiter = create_test_rate_limiter().await;
679 let auth_context = create_test_auth_context("test-client");
680
681 let _ = limiter
683 .check_rate_limit(Some(&auth_context), "search_memory", false)
684 .await;
685 let _ = limiter
686 .check_rate_limit(Some(&auth_context), "search_memory", false)
687 .await;
688
689 let stats = limiter.get_stats().await;
690 assert_eq!(stats.total_requests, 2);
691 assert!(stats.avg_check_duration_ms >= 0.0);
692 }
693
694 #[tokio::test]
695 async fn test_client_limit_reset() {
696 let limiter = create_test_rate_limiter().await;
697 let auth_context = create_test_auth_context("test-client");
698
699 for _ in 0..3 {
702 let result = limiter
703 .check_rate_limit(Some(&auth_context), "search_memory", false)
704 .await;
705 assert!(result.is_ok());
707 }
708
709 let result = limiter
711 .check_rate_limit(Some(&auth_context), "search_memory", false)
712 .await;
713 assert!(
714 result.is_err(),
715 "4th request should be rate limited by client limits"
716 );
717
718 limiter.reset_client_limits("test-client").await.unwrap();
720
721 let result = limiter
723 .check_rate_limit(Some(&auth_context), "search_memory", false)
724 .await;
725 assert!(
726 result.is_ok(),
727 "Request should succeed after client limit reset"
728 );
729 }
730}