1use std::collections::HashMap;
32use std::hash::{Hash, Hasher};
33use std::sync::Mutex;
34use std::time::{Duration, Instant};
35
36use fastmcp_core::{McpContext, McpError, McpResult};
37use fastmcp_protocol::JsonRpcRequest;
38
39use crate::{Middleware, MiddlewareDecision};
40
/// Default time-to-live, in seconds, applied to the list methods (5 minutes).
pub const DEFAULT_LIST_TTL_SECS: u64 = 5 * 60;

/// Default time-to-live, in seconds, applied to call/read/get methods (1 hour).
pub const DEFAULT_CALL_TTL_SECS: u64 = 60 * 60;

/// Default upper bound, in bytes, on the size of a single cached item (1 MiB).
pub const DEFAULT_MAX_ITEM_SIZE: usize = 1 << 20;
49
50#[derive(Debug, Clone)]
52struct CacheEntry {
53 value: serde_json::Value,
54 expires_at: Instant,
55 size_bytes: usize,
56}
57
58impl CacheEntry {
59 fn new(value: serde_json::Value, ttl: Duration) -> Self {
60 let size_bytes = value.to_string().len();
61 Self {
62 value,
63 expires_at: Instant::now() + ttl,
64 size_bytes,
65 }
66 }
67
68 fn is_expired(&self) -> bool {
69 Instant::now() > self.expires_at
70 }
71}
72
73#[derive(Debug, Clone, PartialEq, Eq, Hash)]
75struct CacheKey {
76 method: String,
77 params_hash: u64,
78}
79
80impl CacheKey {
81 fn new(method: &str, params: Option<&serde_json::Value>) -> Self {
82 let params_hash = match params {
83 Some(v) => hash_json_value(v),
84 None => 0,
85 };
86 Self {
87 method: method.to_string(),
88 params_hash,
89 }
90 }
91}
92
93fn hash_json_value(value: &serde_json::Value) -> u64 {
95 use std::collections::hash_map::DefaultHasher;
96
97 let mut hasher = DefaultHasher::new();
98
99 let json_str = serde_json::to_string(value).unwrap_or_default();
102 json_str.hash(&mut hasher);
103
104 hasher.finish()
105}
106
107#[derive(Debug, Clone)]
109pub struct MethodCacheConfig {
110 pub enabled: bool,
112 pub ttl_secs: u64,
114}
115
116impl Default for MethodCacheConfig {
117 fn default() -> Self {
118 Self {
119 enabled: true,
120 ttl_secs: DEFAULT_CALL_TTL_SECS,
121 }
122 }
123}
124
125#[derive(Debug, Clone, Default)]
127pub struct ToolCallCacheConfig {
128 pub base: MethodCacheConfig,
130 pub included_tools: Vec<String>,
132 pub excluded_tools: Vec<String>,
134}
135
136impl ToolCallCacheConfig {
137 fn should_cache_tool(&self, tool_name: &str) -> bool {
139 if !self.base.enabled {
140 return false;
141 }
142
143 if self.excluded_tools.contains(&tool_name.to_string()) {
145 return false;
146 }
147
148 if !self.included_tools.is_empty() {
150 return self.included_tools.contains(&tool_name.to_string());
151 }
152
153 true
155 }
156}
157
158#[derive(Debug)]
160struct LruCache {
161 entries: HashMap<CacheKey, CacheEntry>,
163 order: Vec<CacheKey>,
165 max_entries: usize,
167 max_size_bytes: usize,
169 max_item_size: usize,
171 current_size_bytes: usize,
173}
174
175impl LruCache {
176 fn new(max_entries: usize, max_size_bytes: usize, max_item_size: usize) -> Self {
177 Self {
178 entries: HashMap::new(),
179 order: Vec::new(),
180 max_entries,
181 max_size_bytes,
182 max_item_size,
183 current_size_bytes: 0,
184 }
185 }
186
187 fn get(&mut self, key: &CacheKey) -> Option<serde_json::Value> {
188 if let Some(entry) = self.entries.get(key) {
190 if entry.is_expired() {
191 self.remove(key);
193 return None;
194 }
195
196 if let Some(pos) = self.order.iter().position(|k| k == key) {
198 let k = self.order.remove(pos);
199 self.order.push(k);
200 }
201
202 return Some(entry.value.clone());
203 }
204 None
205 }
206
207 fn insert(&mut self, key: CacheKey, value: serde_json::Value, ttl: Duration) {
208 let entry = CacheEntry::new(value, ttl);
209
210 if entry.size_bytes > self.max_item_size {
212 return;
214 }
215
216 if self.entries.contains_key(&key) {
218 self.remove(&key);
219 }
220
221 while self.entries.len() >= self.max_entries
223 || self.current_size_bytes + entry.size_bytes > self.max_size_bytes
224 {
225 if self.order.is_empty() {
226 break;
227 }
228 let oldest_key = self.order.remove(0);
230 if let Some(old_entry) = self.entries.remove(&oldest_key) {
231 self.current_size_bytes -= old_entry.size_bytes;
232 }
233 }
234
235 self.evict_expired();
237
238 self.current_size_bytes += entry.size_bytes;
240 self.entries.insert(key.clone(), entry);
241 self.order.push(key);
242 }
243
244 fn remove(&mut self, key: &CacheKey) {
245 if let Some(entry) = self.entries.remove(key) {
246 self.current_size_bytes -= entry.size_bytes;
247 if let Some(pos) = self.order.iter().position(|k| k == key) {
248 self.order.remove(pos);
249 }
250 }
251 }
252
253 fn evict_expired(&mut self) {
254 let expired_keys: Vec<CacheKey> = self
255 .entries
256 .iter()
257 .filter(|(_, entry)| entry.is_expired())
258 .map(|(key, _)| key.clone())
259 .collect();
260
261 for key in expired_keys {
262 self.remove(&key);
263 }
264 }
265
266 fn clear(&mut self) {
267 self.entries.clear();
268 self.order.clear();
269 self.current_size_bytes = 0;
270 }
271
272 fn len(&self) -> usize {
273 self.entries.len()
274 }
275
276 #[allow(dead_code)]
277 fn is_empty(&self) -> bool {
278 self.entries.is_empty()
279 }
280}
281
/// Snapshot of cache counters.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Number of lookups answered from the cache.
    pub hits: u64,
    /// Number of lookups that found nothing usable.
    pub misses: u64,
    /// Number of entries stored.
    pub entries: usize,
    /// Total size of the stored entries, in bytes.
    pub size_bytes: usize,
}

impl CacheStats {
    /// Returns hits as a percentage (0–100) of all lookups.
    ///
    /// Yields `0.0` when no lookups have been recorded.
    #[must_use]
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => (self.hits as f64 / lookups as f64) * 100.0,
        }
    }
}
307
308pub struct ResponseCachingMiddleware {
313 cache: Mutex<LruCache>,
315 list_ttl: Duration,
317 call_ttl: Duration,
319 tools_list_config: MethodCacheConfig,
321 resources_list_config: MethodCacheConfig,
323 prompts_list_config: MethodCacheConfig,
325 tools_call_config: ToolCallCacheConfig,
327 resources_read_config: MethodCacheConfig,
329 prompts_get_config: MethodCacheConfig,
331 stats: Mutex<CacheStats>,
333}
334
335impl std::fmt::Debug for ResponseCachingMiddleware {
336 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
337 f.debug_struct("ResponseCachingMiddleware")
338 .field("list_ttl", &self.list_ttl)
339 .field("call_ttl", &self.call_ttl)
340 .finish_non_exhaustive()
341 }
342}
343
344impl Default for ResponseCachingMiddleware {
345 fn default() -> Self {
346 Self::new()
347 }
348}
349
350impl ResponseCachingMiddleware {
351 #[must_use]
353 pub fn new() -> Self {
354 Self {
355 cache: Mutex::new(LruCache::new(
356 1000,
357 100 * 1024 * 1024,
358 DEFAULT_MAX_ITEM_SIZE,
359 )),
360 list_ttl: Duration::from_secs(DEFAULT_LIST_TTL_SECS),
361 call_ttl: Duration::from_secs(DEFAULT_CALL_TTL_SECS),
362 tools_list_config: MethodCacheConfig {
363 enabled: true,
364 ttl_secs: DEFAULT_LIST_TTL_SECS,
365 },
366 resources_list_config: MethodCacheConfig {
367 enabled: true,
368 ttl_secs: DEFAULT_LIST_TTL_SECS,
369 },
370 prompts_list_config: MethodCacheConfig {
371 enabled: true,
372 ttl_secs: DEFAULT_LIST_TTL_SECS,
373 },
374 tools_call_config: ToolCallCacheConfig {
375 base: MethodCacheConfig {
376 enabled: true,
377 ttl_secs: DEFAULT_CALL_TTL_SECS,
378 },
379 included_tools: Vec::new(),
380 excluded_tools: Vec::new(),
381 },
382 resources_read_config: MethodCacheConfig {
383 enabled: true,
384 ttl_secs: DEFAULT_CALL_TTL_SECS,
385 },
386 prompts_get_config: MethodCacheConfig {
387 enabled: true,
388 ttl_secs: DEFAULT_CALL_TTL_SECS,
389 },
390 stats: Mutex::new(CacheStats::default()),
391 }
392 }
393
394 #[must_use]
396 pub fn max_entries(self, max: usize) -> Self {
397 let max_size = {
398 let cache = self
399 .cache
400 .lock()
401 .unwrap_or_else(std::sync::PoisonError::into_inner);
402 cache.max_size_bytes
403 };
404 let max_item_size = {
405 let cache = self
406 .cache
407 .lock()
408 .unwrap_or_else(std::sync::PoisonError::into_inner);
409 cache.max_item_size
410 };
411 Self {
412 cache: Mutex::new(LruCache::new(max, max_size, max_item_size)),
413 ..self
414 }
415 }
416
417 #[must_use]
419 pub fn max_size_bytes(self, max: usize) -> Self {
420 let max_entries = {
421 let cache = self
422 .cache
423 .lock()
424 .unwrap_or_else(std::sync::PoisonError::into_inner);
425 cache.max_entries
426 };
427 let max_item_size = {
428 let cache = self
429 .cache
430 .lock()
431 .unwrap_or_else(std::sync::PoisonError::into_inner);
432 cache.max_item_size
433 };
434 Self {
435 cache: Mutex::new(LruCache::new(max_entries, max, max_item_size)),
436 ..self
437 }
438 }
439
440 #[must_use]
442 pub fn max_item_size(self, max: usize) -> Self {
443 let max_entries = {
444 let cache = self
445 .cache
446 .lock()
447 .unwrap_or_else(std::sync::PoisonError::into_inner);
448 cache.max_entries
449 };
450 let max_size = {
451 let cache = self
452 .cache
453 .lock()
454 .unwrap_or_else(std::sync::PoisonError::into_inner);
455 cache.max_size_bytes
456 };
457 Self {
458 cache: Mutex::new(LruCache::new(max_entries, max_size, max)),
459 ..self
460 }
461 }
462
463 #[must_use]
465 pub fn list_ttl_secs(mut self, secs: u64) -> Self {
466 self.list_ttl = Duration::from_secs(secs);
467 self.tools_list_config.ttl_secs = secs;
468 self.resources_list_config.ttl_secs = secs;
469 self.prompts_list_config.ttl_secs = secs;
470 self
471 }
472
473 #[must_use]
475 pub fn call_ttl_secs(mut self, secs: u64) -> Self {
476 self.call_ttl = Duration::from_secs(secs);
477 self.tools_call_config.base.ttl_secs = secs;
478 self.resources_read_config.ttl_secs = secs;
479 self.prompts_get_config.ttl_secs = secs;
480 self
481 }
482
483 #[must_use]
485 pub fn disable_tools_list(mut self) -> Self {
486 self.tools_list_config.enabled = false;
487 self
488 }
489
490 #[must_use]
492 pub fn disable_resources_list(mut self) -> Self {
493 self.resources_list_config.enabled = false;
494 self
495 }
496
497 #[must_use]
499 pub fn disable_prompts_list(mut self) -> Self {
500 self.prompts_list_config.enabled = false;
501 self
502 }
503
504 #[must_use]
506 pub fn disable_tools_call(mut self) -> Self {
507 self.tools_call_config.base.enabled = false;
508 self
509 }
510
511 #[must_use]
513 pub fn disable_resources_read(mut self) -> Self {
514 self.resources_read_config.enabled = false;
515 self
516 }
517
518 #[must_use]
520 pub fn disable_prompts_get(mut self) -> Self {
521 self.prompts_get_config.enabled = false;
522 self
523 }
524
525 #[must_use]
527 pub fn include_tools(mut self, tools: Vec<String>) -> Self {
528 self.tools_call_config.included_tools = tools;
529 self
530 }
531
532 #[must_use]
534 pub fn exclude_tools(mut self, tools: Vec<String>) -> Self {
535 self.tools_call_config.excluded_tools = tools;
536 self
537 }
538
539 #[must_use]
541 pub fn stats(&self) -> CacheStats {
542 let cache = self
543 .cache
544 .lock()
545 .unwrap_or_else(std::sync::PoisonError::into_inner);
546 let mut stats = self
547 .stats
548 .lock()
549 .unwrap_or_else(std::sync::PoisonError::into_inner)
550 .clone();
551 stats.entries = cache.len();
552 stats.size_bytes = cache.current_size_bytes;
553 stats
554 }
555
556 pub fn clear(&self) {
558 let mut cache = self
559 .cache
560 .lock()
561 .unwrap_or_else(std::sync::PoisonError::into_inner);
562 cache.clear();
563 }
564
565 pub fn invalidate(&self, method: &str, params: Option<&serde_json::Value>) {
567 let key = CacheKey::new(method, params);
568 let mut cache = self
569 .cache
570 .lock()
571 .unwrap_or_else(std::sync::PoisonError::into_inner);
572 cache.remove(&key);
573 }
574
575 fn should_cache_method(&self, method: &str, params: Option<&serde_json::Value>) -> bool {
577 match method {
578 "tools/list" => self.tools_list_config.enabled,
579 "resources/list" => self.resources_list_config.enabled,
580 "prompts/list" => self.prompts_list_config.enabled,
581 "resources/read" => self.resources_read_config.enabled,
582 "prompts/get" => self.prompts_get_config.enabled,
583 "tools/call" => {
584 if !self.tools_call_config.base.enabled {
585 return false;
586 }
587 if let Some(params) = params {
589 if let Some(tool_name) = params.get("name").and_then(|v| v.as_str()) {
590 return self.tools_call_config.should_cache_tool(tool_name);
591 }
592 }
593 false
594 }
595 _ => false,
596 }
597 }
598
599 fn get_ttl(&self, method: &str) -> Duration {
601 match method {
602 "tools/list" => Duration::from_secs(self.tools_list_config.ttl_secs),
603 "resources/list" => Duration::from_secs(self.resources_list_config.ttl_secs),
604 "prompts/list" => Duration::from_secs(self.prompts_list_config.ttl_secs),
605 "tools/call" => Duration::from_secs(self.tools_call_config.base.ttl_secs),
606 "resources/read" => Duration::from_secs(self.resources_read_config.ttl_secs),
607 "prompts/get" => Duration::from_secs(self.prompts_get_config.ttl_secs),
608 _ => self.call_ttl,
609 }
610 }
611
612 fn record_hit(&self) {
613 let mut stats = self
614 .stats
615 .lock()
616 .unwrap_or_else(std::sync::PoisonError::into_inner);
617 stats.hits += 1;
618 }
619
620 fn record_miss(&self) {
621 let mut stats = self
622 .stats
623 .lock()
624 .unwrap_or_else(std::sync::PoisonError::into_inner);
625 stats.misses += 1;
626 }
627}
628
629impl Middleware for ResponseCachingMiddleware {
630 fn on_request(
631 &self,
632 _ctx: &McpContext,
633 request: &JsonRpcRequest,
634 ) -> McpResult<MiddlewareDecision> {
635 if !self.should_cache_method(&request.method, request.params.as_ref()) {
637 return Ok(MiddlewareDecision::Continue);
638 }
639
640 let key = CacheKey::new(&request.method, request.params.as_ref());
642 let mut cache = self
643 .cache
644 .lock()
645 .unwrap_or_else(std::sync::PoisonError::into_inner);
646
647 if let Some(value) = cache.get(&key) {
648 self.record_hit();
649 return Ok(MiddlewareDecision::Respond(value));
650 }
651
652 self.record_miss();
653 Ok(MiddlewareDecision::Continue)
654 }
655
656 fn on_response(
657 &self,
658 _ctx: &McpContext,
659 request: &JsonRpcRequest,
660 response: serde_json::Value,
661 ) -> McpResult<serde_json::Value> {
662 if !self.should_cache_method(&request.method, request.params.as_ref()) {
664 return Ok(response);
665 }
666
667 let key = CacheKey::new(&request.method, request.params.as_ref());
669 let ttl = self.get_ttl(&request.method);
670
671 let mut cache = self
672 .cache
673 .lock()
674 .unwrap_or_else(std::sync::PoisonError::into_inner);
675
676 cache.insert(key, response.clone(), ttl);
677
678 Ok(response)
679 }
680
681 fn on_error(&self, _ctx: &McpContext, _request: &JsonRpcRequest, error: McpError) -> McpError {
682 error
684 }
685}
686
687#[cfg(test)]
688mod tests {
689 use super::*;
690 use asupersync::Cx;
691
692 fn test_context() -> McpContext {
693 let cx = Cx::for_testing();
694 McpContext::new(cx, 1)
695 }
696
697 fn test_request(method: &str, params: Option<serde_json::Value>) -> JsonRpcRequest {
698 JsonRpcRequest {
699 jsonrpc: std::borrow::Cow::Borrowed(fastmcp_protocol::JSONRPC_VERSION),
700 method: method.to_string(),
701 params,
702 id: Some(fastmcp_protocol::RequestId::Number(1)),
703 }
704 }
705
#[test]
fn test_lru_cache_basic_operations() {
    let key = CacheKey::new("test", None);
    let expected = serde_json::json!({"result": "cached"});

    let mut lru = LruCache::new(10, 1024 * 1024, 1024);
    lru.insert(key.clone(), expected.clone(), Duration::from_secs(60));

    // A freshly inserted, unexpired value comes back unchanged.
    assert_eq!(lru.get(&key), Some(expected));
}

#[test]
fn test_lru_cache_expiration() {
    let key = CacheKey::new("test", None);
    let mut lru = LruCache::new(10, 1024 * 1024, 1024);

    lru.insert(
        key.clone(),
        serde_json::json!({"result": "cached"}),
        Duration::from_millis(1),
    );

    // Wait well past the 1 ms TTL.
    std::thread::sleep(std::time::Duration::from_millis(10));

    assert!(lru.get(&key).is_none());
}

#[test]
fn test_lru_cache_eviction() {
    let ttl = Duration::from_secs(60);
    let mut lru = LruCache::new(2, 1024 * 1024, 1024);

    let first = CacheKey::new("test1", None);
    let second = CacheKey::new("test2", None);
    let third = CacheKey::new("test3", None);

    // The third insert exceeds the two-entry limit.
    for (key, payload) in [(&first, "v1"), (&second, "v2"), (&third, "v3")] {
        lru.insert(key.clone(), serde_json::json!(payload), ttl);
    }

    // The oldest entry is the one that was evicted.
    assert!(lru.get(&first).is_none());
    assert!(lru.get(&second).is_some());
    assert!(lru.get(&third).is_some());
}

#[test]
fn test_lru_cache_size_limit() {
    let ttl = Duration::from_secs(60);
    // Total size is capped at 50 bytes.
    let mut lru = LruCache::new(100, 50, 1024);

    lru.insert(CacheKey::new("test1", None), serde_json::json!("short"), ttl);
    assert_eq!(lru.len(), 1);

    lru.insert(CacheKey::new("test2", None), serde_json::json!("another"), ttl);
    assert!(lru.len() <= 2);
}

#[test]
fn test_lru_cache_oversized_item_rejected() {
    let key = CacheKey::new("test", None);
    // Single items are capped at 10 bytes.
    let mut lru = LruCache::new(10, 1024 * 1024, 10);

    let too_large = serde_json::json!({"data": "this is much longer than 10 bytes"});
    lru.insert(key.clone(), too_large, Duration::from_secs(60));

    assert!(lru.get(&key).is_none());
}
809
#[test]
fn test_caching_middleware_caches_tools_list() {
    let cache_layer = ResponseCachingMiddleware::new();
    let ctx = test_context();
    let list_request = test_request("tools/list", None);

    // Nothing is cached yet, so the first lookup is a miss.
    let first = cache_layer.on_request(&ctx, &list_request).unwrap();
    assert!(matches!(first, MiddlewareDecision::Continue));

    let payload = serde_json::json!({"tools": []});
    cache_layer
        .on_response(&ctx, &list_request, payload.clone())
        .unwrap();

    // The same request is now served from the cache.
    let second = cache_layer.on_request(&ctx, &list_request).unwrap();
    assert!(
        matches!(second, MiddlewareDecision::Respond(_)),
        "Expected cache hit"
    );
    let MiddlewareDecision::Respond(cached) = second else {
        return;
    };
    assert_eq!(cached, payload);

    let counters = cache_layer.stats();
    assert_eq!(counters.hits, 1);
    assert_eq!(counters.misses, 1);
}

#[test]
fn test_caching_middleware_skips_non_cacheable_methods() {
    let cache_layer = ResponseCachingMiddleware::new();
    let ctx = test_context();
    let init_request = test_request("initialize", None);

    let before = cache_layer.on_request(&ctx, &init_request).unwrap();
    assert!(matches!(before, MiddlewareDecision::Continue));

    cache_layer
        .on_response(&ctx, &init_request, serde_json::json!({}))
        .unwrap();

    // Even after a response went through, the method stays uncached.
    let after = cache_layer.on_request(&ctx, &init_request).unwrap();
    assert!(matches!(after, MiddlewareDecision::Continue));
}
865
#[test]
fn test_caching_middleware_different_params_different_keys() {
    let cache_layer = ResponseCachingMiddleware::new();
    let ctx = test_context();

    let call_a = test_request(
        "tools/call",
        Some(serde_json::json!({"name": "tool_a", "arguments": {}})),
    );
    let call_b = test_request(
        "tools/call",
        Some(serde_json::json!({"name": "tool_b", "arguments": {}})),
    );

    // Prime the cache with the response for tool_a only.
    cache_layer.on_request(&ctx, &call_a).unwrap();
    let payload_a = serde_json::json!({"result": "a"});
    cache_layer
        .on_response(&ctx, &call_a, payload_a.clone())
        .unwrap();

    // tool_b carries different params, so it must miss.
    let for_b = cache_layer.on_request(&ctx, &call_b).unwrap();
    assert!(matches!(for_b, MiddlewareDecision::Continue));

    let for_a = cache_layer.on_request(&ctx, &call_a).unwrap();
    assert!(
        matches!(for_a, MiddlewareDecision::Respond(_)),
        "Expected cache hit"
    );
    let MiddlewareDecision::Respond(cached) = for_a else {
        return;
    };
    assert_eq!(cached, payload_a);
}

#[test]
fn test_caching_middleware_tool_exclusion() {
    let cache_layer =
        ResponseCachingMiddleware::new().exclude_tools(vec!["excluded_tool".to_string()]);
    let ctx = test_context();

    let skipped_call = test_request(
        "tools/call",
        Some(serde_json::json!({"name": "excluded_tool", "arguments": {}})),
    );
    let cached_call = test_request(
        "tools/call",
        Some(serde_json::json!({"name": "included_tool", "arguments": {}})),
    );

    // The excluded tool goes through a full round trip but is never stored.
    cache_layer.on_request(&ctx, &skipped_call).unwrap();
    cache_layer
        .on_response(&ctx, &skipped_call, serde_json::json!({}))
        .unwrap();

    let skipped_again = cache_layer.on_request(&ctx, &skipped_call).unwrap();
    assert!(matches!(skipped_again, MiddlewareDecision::Continue));

    // Any other tool is cached as usual.
    cache_layer.on_request(&ctx, &cached_call).unwrap();
    let payload = serde_json::json!({"result": "included"});
    cache_layer
        .on_response(&ctx, &cached_call, payload.clone())
        .unwrap();

    let cached_again = cache_layer.on_request(&ctx, &cached_call).unwrap();
    assert!(
        matches!(cached_again, MiddlewareDecision::Respond(_)),
        "Expected cache hit for included tool"
    );
    let MiddlewareDecision::Respond(cached) = cached_again else {
        return;
    };
    assert_eq!(cached, payload);
}
944
#[test]
fn test_caching_middleware_disable_method() {
    let cache_layer = ResponseCachingMiddleware::new().disable_tools_list();
    let ctx = test_context();
    let list_request = test_request("tools/list", None);

    cache_layer.on_request(&ctx, &list_request).unwrap();
    cache_layer
        .on_response(&ctx, &list_request, serde_json::json!({}))
        .unwrap();

    // With tools/list disabled, the response was never stored.
    let outcome = cache_layer.on_request(&ctx, &list_request).unwrap();
    assert!(matches!(outcome, MiddlewareDecision::Continue));
}

#[test]
fn test_caching_middleware_clear() {
    let cache_layer = ResponseCachingMiddleware::new();
    let ctx = test_context();
    let list_request = test_request("tools/list", None);

    cache_layer.on_request(&ctx, &list_request).unwrap();
    cache_layer
        .on_response(&ctx, &list_request, serde_json::json!({}))
        .unwrap();

    // The response is cached before the cache is cleared.
    let before_clear = cache_layer.on_request(&ctx, &list_request).unwrap();
    assert!(matches!(before_clear, MiddlewareDecision::Respond(_)));

    cache_layer.clear();

    let after_clear = cache_layer.on_request(&ctx, &list_request).unwrap();
    assert!(matches!(after_clear, MiddlewareDecision::Continue));
}

#[test]
fn test_caching_middleware_invalidate() {
    let cache_layer = ResponseCachingMiddleware::new();
    let ctx = test_context();
    let list_request = test_request("tools/list", None);

    cache_layer.on_request(&ctx, &list_request).unwrap();
    cache_layer
        .on_response(&ctx, &list_request, serde_json::json!({}))
        .unwrap();

    // Remove exactly the entry that was just stored.
    cache_layer.invalidate("tools/list", None);

    let outcome = cache_layer.on_request(&ctx, &list_request).unwrap();
    assert!(matches!(outcome, MiddlewareDecision::Continue));
}
1004
#[test]
fn test_cache_stats_hit_rate() {
    // 75 hits out of 100 lookups.
    let snapshot = CacheStats {
        hits: 75,
        misses: 25,
        entries: 10,
        size_bytes: 1000,
    };

    let deviation = (snapshot.hit_rate() - 75.0).abs();
    assert!(deviation < 0.001);
}

#[test]
fn cache_stats_hit_rate_zero_total() {
    // No lookups recorded: the rate is zero instead of a division by zero.
    let rate = CacheStats::default().hit_rate();
    assert!(rate.abs() < f64::EPSILON);
}

#[test]
fn cache_stats_debug() {
    let rendered = format!("{:?}", CacheStats::default());
    assert!(rendered.contains("CacheStats"));
}
1031
#[test]
fn cache_key_same_method_same_params_are_equal() {
    let params = serde_json::json!({"a": 1});
    let left = CacheKey::new("tools/list", Some(&params));
    let right = CacheKey::new("tools/list", Some(&serde_json::json!({"a": 1})));
    assert_eq!(left, right);
}

#[test]
fn cache_key_different_params_differ() {
    let left = CacheKey::new("tools/list", Some(&serde_json::json!({"a": 1})));
    let right = CacheKey::new("tools/list", Some(&serde_json::json!({"a": 2})));
    assert_ne!(left, right);
}

#[test]
fn cache_key_none_params_hash_is_zero() {
    // Absent params use the fixed hash 0.
    assert_eq!(CacheKey::new("test", None).params_hash, 0);
}

#[test]
fn cache_key_debug_and_clone() {
    let original = CacheKey::new("test", None);

    let rendered = format!("{:?}", original);
    assert!(rendered.contains("test"));

    let copy = original.clone();
    assert_eq!(original, copy);
}

#[test]
fn hash_json_value_deterministic() {
    let sample = serde_json::json!({"key": "value", "num": 42});
    // Hashing the same value twice gives the same result.
    assert_eq!(hash_json_value(&sample), hash_json_value(&sample));
}

#[test]
fn hash_json_value_different_values_differ() {
    let one = hash_json_value(&serde_json::json!(1));
    let two = hash_json_value(&serde_json::json!(2));
    assert_ne!(one, two);
}
1079
#[test]
fn lru_cache_clear() {
    let ttl = Duration::from_secs(60);
    let mut lru = LruCache::new(10, 1024 * 1024, 1024);

    lru.insert(CacheKey::new("a", None), serde_json::json!(1), ttl);
    lru.insert(CacheKey::new("b", None), serde_json::json!(2), ttl);
    assert_eq!(lru.len(), 2);
    assert!(!lru.is_empty());

    lru.clear();

    // Both the entries and the size accounting are reset.
    assert_eq!(lru.len(), 0);
    assert!(lru.is_empty());
    assert_eq!(lru.current_size_bytes, 0);
}

#[test]
fn lru_cache_remove_nonexistent() {
    let mut lru = LruCache::new(10, 1024 * 1024, 1024);

    // Removing an unknown key is a no-op.
    lru.remove(&CacheKey::new("nonexistent", None));

    assert_eq!(lru.len(), 0);
}

#[test]
fn lru_cache_insert_duplicate_replaces() {
    let key = CacheKey::new("test", None);
    let mut lru = LruCache::new(10, 1024 * 1024, 1024);

    for payload in ["v1", "v2"] {
        lru.insert(key.clone(), serde_json::json!(payload), Duration::from_secs(60));
    }

    // The second insert replaced the first instead of adding an entry.
    assert_eq!(lru.len(), 1);
    assert_eq!(lru.get(&key), Some(serde_json::json!("v2")));
}

#[test]
fn lru_cache_get_miss_returns_none() {
    let mut lru = LruCache::new(10, 1024 * 1024, 1024);
    let absent = CacheKey::new("missing", None);
    assert!(lru.get(&absent).is_none());
}

#[test]
fn lru_cache_lru_order_updated_on_access() {
    let ttl = Duration::from_secs(60);
    let mut lru = LruCache::new(2, 1024 * 1024, 1024);

    let key_a = CacheKey::new("a", None);
    let key_b = CacheKey::new("b", None);
    lru.insert(key_a.clone(), serde_json::json!(1), ttl);
    lru.insert(key_b.clone(), serde_json::json!(2), ttl);

    // Reading "a" makes "b" the least recently used entry.
    let _ = lru.get(&key_a);

    let key_c = CacheKey::new("c", None);
    lru.insert(key_c.clone(), serde_json::json!(3), ttl);

    assert!(lru.get(&key_a).is_some());
    assert!(lru.get(&key_b).is_none());
    assert!(lru.get(&key_c).is_some());
}
1154
#[test]
fn should_cache_tool_disabled_returns_false() {
    let disabled = MethodCacheConfig {
        enabled: false,
        ttl_secs: 60,
    };
    let config = ToolCallCacheConfig {
        base: disabled,
        ..ToolCallCacheConfig::default()
    };

    assert!(!config.should_cache_tool("any_tool"));
}

#[test]
fn should_cache_tool_excluded_returns_false() {
    let config = ToolCallCacheConfig {
        base: MethodCacheConfig {
            enabled: true,
            ttl_secs: 60,
        },
        included_tools: Vec::new(),
        excluded_tools: vec!["excluded".to_string()],
    };

    assert!(!config.should_cache_tool("excluded"));
    assert!(config.should_cache_tool("other"));
}

#[test]
fn should_cache_tool_include_list_filters() {
    let config = ToolCallCacheConfig {
        base: MethodCacheConfig {
            enabled: true,
            ttl_secs: 60,
        },
        included_tools: vec!["allowed".to_string()],
        excluded_tools: Vec::new(),
    };

    // A non-empty include list rejects tools that are not on it.
    assert!(config.should_cache_tool("allowed"));
    assert!(!config.should_cache_tool("not_allowed"));
}

#[test]
fn should_cache_tool_exclude_takes_precedence_over_include() {
    let both_lists = vec!["tool".to_string()];
    let config = ToolCallCacheConfig {
        base: MethodCacheConfig {
            enabled: true,
            ttl_secs: 60,
        },
        included_tools: both_lists.clone(),
        excluded_tools: both_lists,
    };

    assert!(!config.should_cache_tool("tool"));
}
1209
#[test]
fn method_cache_config_default() {
    let MethodCacheConfig { enabled, ttl_secs } = MethodCacheConfig::default();
    assert!(enabled);
    assert_eq!(ttl_secs, DEFAULT_CALL_TTL_SECS);
}

#[test]
fn method_cache_config_debug() {
    let rendered = format!("{:?}", MethodCacheConfig::default());
    assert!(rendered.contains("MethodCacheConfig"));
}

#[test]
fn default_equals_new() {
    let from_default = ResponseCachingMiddleware::default();
    let from_new = ResponseCachingMiddleware::new();

    assert_eq!(from_default.list_ttl, from_new.list_ttl);
    assert_eq!(from_default.call_ttl, from_new.call_ttl);
}

#[test]
fn debug_output() {
    let rendered = format!("{:?}", ResponseCachingMiddleware::new());

    // The type name and both TTL fields appear in the output.
    for needle in ["ResponseCachingMiddleware", "list_ttl", "call_ttl"] {
        assert!(rendered.contains(needle));
    }
}
1244
#[test]
fn list_ttl_secs_updates_all_list_configs() {
    let configured = ResponseCachingMiddleware::new().list_ttl_secs(600);

    assert_eq!(configured.list_ttl, Duration::from_secs(600));
    assert_eq!(configured.tools_list_config.ttl_secs, 600);
    assert_eq!(configured.resources_list_config.ttl_secs, 600);
    assert_eq!(configured.prompts_list_config.ttl_secs, 600);
}

#[test]
fn call_ttl_secs_updates_all_call_configs() {
    let configured = ResponseCachingMiddleware::new().call_ttl_secs(7200);

    assert_eq!(configured.call_ttl, Duration::from_secs(7200));
    assert_eq!(configured.tools_call_config.base.ttl_secs, 7200);
    assert_eq!(configured.resources_read_config.ttl_secs, 7200);
    assert_eq!(configured.prompts_get_config.ttl_secs, 7200);
}

#[test]
fn max_entries_setter() {
    let configured = ResponseCachingMiddleware::new().max_entries(50);
    let guard = configured
        .cache
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    assert_eq!(guard.max_entries, 50);
}

#[test]
fn max_size_bytes_setter() {
    let configured = ResponseCachingMiddleware::new().max_size_bytes(2048);
    let guard = configured
        .cache
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    assert_eq!(guard.max_size_bytes, 2048);
}

#[test]
fn max_item_size_setter() {
    let configured = ResponseCachingMiddleware::new().max_item_size(512);
    let guard = configured
        .cache
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    assert_eq!(guard.max_item_size, 512);
}
1294
#[test]
fn disable_resources_list() {
    let configured = ResponseCachingMiddleware::new().disable_resources_list();

    assert!(!configured.resources_list_config.enabled);
    // Other methods are unaffected.
    assert!(configured.tools_list_config.enabled);
}

#[test]
fn disable_prompts_list() {
    let configured = ResponseCachingMiddleware::new().disable_prompts_list();
    assert!(!configured.prompts_list_config.enabled);
}

#[test]
fn disable_tools_call() {
    let configured = ResponseCachingMiddleware::new().disable_tools_call();
    assert!(!configured.tools_call_config.base.enabled);
}

#[test]
fn disable_resources_read() {
    let configured = ResponseCachingMiddleware::new().disable_resources_read();
    assert!(!configured.resources_read_config.enabled);
}

#[test]
fn disable_prompts_get() {
    let configured = ResponseCachingMiddleware::new().disable_prompts_get();
    assert!(!configured.prompts_get_config.enabled);
}
1327
1328 #[test]
1331 fn include_tools_restricts_caching() {
1332 let m = ResponseCachingMiddleware::new().include_tools(vec!["allowed_tool".to_string()]);
1333 let _ctx = test_context();
1334
1335 let req = test_request(
1337 "tools/call",
1338 Some(serde_json::json!({"name": "allowed_tool"})),
1339 );
1340 assert!(m.should_cache_method(&req.method, req.params.as_ref()));
1341
1342 let req2 = test_request(
1344 "tools/call",
1345 Some(serde_json::json!({"name": "other_tool"})),
1346 );
1347 assert!(!m.should_cache_method(&req2.method, req2.params.as_ref()));
1348
1349 let req3 = test_request("tools/list", None);
1351 assert!(m.should_cache_method(&req3.method, req3.params.as_ref()));
1352 }
1353
/// A `tools/call` whose params carry no `name` key is expected to be treated as non-cacheable.
#[test]
fn should_cache_tools_call_without_name_returns_false() {
    let middleware = ResponseCachingMiddleware::new();
    // Params contain only `arguments`; there is no tool name to check against the config.
    let nameless_params = serde_json::json!({"arguments": {}});
    assert!(!middleware.should_cache_method("tools/call", Some(&nameless_params)));
}
1362
/// A `tools/call` with no params at all is expected to be treated as non-cacheable.
#[test]
fn should_cache_tools_call_with_no_params_returns_false() {
    let middleware = ResponseCachingMiddleware::new();
    let cacheable = middleware.should_cache_method("tools/call", None);
    assert!(!cacheable);
}
1368
/// A method name outside the known cacheable set is expected to be non-cacheable.
#[test]
fn should_cache_unknown_method_returns_false() {
    let middleware = ResponseCachingMiddleware::new();
    let cacheable = middleware.should_cache_method("unknown/method", None);
    assert!(!cacheable);
}
1374
/// With a default middleware, each of these param-less methods is expected to be cacheable.
#[test]
fn should_cache_all_known_cacheable_methods() {
    let middleware = ResponseCachingMiddleware::new();
    // Every check uses `None` params; keep one assert per method so a failure names its line.
    let cacheable = |method: &str| middleware.should_cache_method(method, None);

    assert!(cacheable("tools/list"));
    assert!(cacheable("resources/list"));
    assert!(cacheable("prompts/list"));
    assert!(cacheable("resources/read"));
    assert!(cacheable("prompts/get"));
}
1384
/// `list_ttl_secs` is expected to set the TTL reported for all three `*/list` methods.
#[test]
fn get_ttl_list_methods() {
    let middleware = ResponseCachingMiddleware::new().list_ttl_secs(120);
    let expected = Duration::from_secs(120);

    assert_eq!(middleware.get_ttl("tools/list"), expected);
    assert_eq!(middleware.get_ttl("resources/list"), expected);
    assert_eq!(middleware.get_ttl("prompts/list"), expected);
}
1394
/// `call_ttl_secs` is expected to set the TTL reported for `tools/call`, `resources/read`
/// and `prompts/get`.
#[test]
fn get_ttl_call_methods() {
    let middleware = ResponseCachingMiddleware::new().call_ttl_secs(900);
    let expected = Duration::from_secs(900);

    assert_eq!(middleware.get_ttl("tools/call"), expected);
    assert_eq!(middleware.get_ttl("resources/read"), expected);
    assert_eq!(middleware.get_ttl("prompts/get"), expected);
}
1402
/// An unrecognised method name is expected to report the configured call TTL.
#[test]
fn get_ttl_unknown_method_uses_call_ttl() {
    let middleware = ResponseCachingMiddleware::new().call_ttl_secs(999);
    let ttl = middleware.get_ttl("unknown/method");
    assert_eq!(ttl, Duration::from_secs(999));
}
1408
/// `on_error` is expected to hand back an error that still carries the original message.
#[test]
fn on_error_passes_through() {
    let middleware = ResponseCachingMiddleware::new();
    let ctx = test_context();
    let request = test_request("tools/list", None);

    let returned = middleware.on_error(&ctx, &request, McpError::internal_error("test error"));

    assert!(returned.message.contains("test error"));
}
1420
/// `stats()` starts at zero entries/bytes and reflects one stored response afterwards.
#[test]
fn stats_tracks_entries_and_size() {
    let middleware = ResponseCachingMiddleware::new();
    let ctx = test_context();

    // Fresh middleware: nothing stored yet.
    let before = middleware.stats();
    assert_eq!(before.entries, 0);
    assert_eq!(before.size_bytes, 0);

    // One request followed by its response.
    let request = test_request("tools/list", None);
    middleware.on_request(&ctx, &request).unwrap();
    middleware
        .on_response(&ctx, &request, serde_json::json!({"tools": []}))
        .unwrap();

    let after = middleware.stats();
    assert_eq!(after.entries, 1);
    assert!(after.size_bytes > 0);
    // The single request above is expected to have been counted as a miss.
    assert_eq!(after.misses, 1);
}
1442
/// A repeated `resources/list` request is expected to be answered with `Respond`.
#[test]
fn caches_resources_list() {
    let middleware = ResponseCachingMiddleware::new();
    let ctx = test_context();
    let request = test_request("resources/list", None);

    // Prime: issue the request once, then record its response.
    middleware.on_request(&ctx, &request).unwrap();
    middleware
        .on_response(&ctx, &request, serde_json::json!({"resources": []}))
        .unwrap();

    // The identical request should now short-circuit.
    let repeat = middleware.on_request(&ctx, &request).unwrap();
    assert!(matches!(repeat, MiddlewareDecision::Respond(_)));
}
1458
/// A repeated `prompts/list` request is expected to be answered with `Respond`.
#[test]
fn caches_prompts_list() {
    let middleware = ResponseCachingMiddleware::new();
    let ctx = test_context();
    let request = test_request("prompts/list", None);

    // Prime: issue the request once, then record its response.
    middleware.on_request(&ctx, &request).unwrap();
    middleware
        .on_response(&ctx, &request, serde_json::json!({"prompts": []}))
        .unwrap();

    // The identical request should now short-circuit.
    let repeat = middleware.on_request(&ctx, &request).unwrap();
    assert!(matches!(repeat, MiddlewareDecision::Respond(_)));
}
1472
/// `CacheEntry` must format via `Debug` and produce a clone carrying the same value.
#[test]
fn cache_entry_debug_and_clone() {
    let original = CacheEntry::new(serde_json::json!(42), Duration::from_secs(60));

    // The Debug output should mention the type name.
    assert!(format!("{:?}", original).contains("CacheEntry"));

    let duplicate = original.clone();
    assert_eq!(duplicate.value, serde_json::json!(42));
}
1483
/// An entry created with a 60-second TTL must not report itself expired right away.
#[test]
fn cache_entry_not_expired_initially() {
    let ttl = Duration::from_secs(60);
    let fresh = CacheEntry::new(serde_json::json!(1), ttl);
    assert!(!fresh.is_expired());
}
1489
/// A repeated `resources/read` for the same URI is expected to be answered with `Respond`.
#[test]
fn caches_resources_read() {
    let middleware = ResponseCachingMiddleware::new();
    let ctx = test_context();
    let params = serde_json::json!({"uri": "file:///a.txt"});
    let request = test_request("resources/read", Some(params));

    // Prime: issue the request once, then record its response.
    middleware.on_request(&ctx, &request).unwrap();
    middleware
        .on_response(&ctx, &request, serde_json::json!({"contents": []}))
        .unwrap();

    // The identical request should now short-circuit.
    let repeat = middleware.on_request(&ctx, &request).unwrap();
    assert!(matches!(repeat, MiddlewareDecision::Respond(_)));
}
1506
/// A repeated `prompts/get` for the same prompt name is expected to be answered with `Respond`.
#[test]
fn caches_prompts_get() {
    let middleware = ResponseCachingMiddleware::new();
    let ctx = test_context();
    let params = serde_json::json!({"name": "greeting"});
    let request = test_request("prompts/get", Some(params));

    // Prime: issue the request once, then record its response.
    middleware.on_request(&ctx, &request).unwrap();
    middleware
        .on_response(&ctx, &request, serde_json::json!({"messages": []}))
        .unwrap();

    // The identical request should now short-circuit.
    let repeat = middleware.on_request(&ctx, &request).unwrap();
    assert!(matches!(repeat, MiddlewareDecision::Respond(_)));
}
1520
/// `evict_expired` must drop entries whose TTL has elapsed and reset the byte counter.
#[test]
fn lru_cache_evict_expired_frees_entries() {
    let short_ttl = Duration::from_millis(1);
    let mut cache = LruCache::new(10, 1024 * 1024, 1024);

    cache.insert(CacheKey::new("a", None), serde_json::json!(1), short_ttl);
    cache.insert(CacheKey::new("b", None), serde_json::json!(2), short_ttl);
    assert_eq!(cache.len(), 2);

    // Sleep well past the 1 ms TTL so both entries are expired before eviction runs.
    std::thread::sleep(Duration::from_millis(10));
    cache.evict_expired();

    assert_eq!(cache.len(), 0);
    assert_eq!(cache.current_size_bytes, 0);
}
1543
/// Re-inserting under an existing key must replace the entry and re-account its size.
///
/// The byte total after the replacement is compared against a second cache that only ever
/// held the replacement value, so stale bytes from the overwritten entry are detected.
#[test]
fn lru_cache_insert_replaces_updates_size() {
    let ttl = Duration::from_secs(60);
    let key = CacheKey::new("k", None);
    let replacement = serde_json::json!("much longer value here");

    let mut cache = LruCache::new(10, 1024 * 1024, 1024);
    cache.insert(key.clone(), serde_json::json!("short"), ttl);
    let size_after_first = cache.current_size_bytes;

    cache.insert(key.clone(), replacement.clone(), ttl);
    let size_after_second = cache.current_size_bytes;

    // Reference cache: same limits, holding only the replacement value.
    let mut reference = LruCache::new(10, 1024 * 1024, 1024);
    reference.insert(key, replacement, ttl);

    assert_ne!(size_after_first, size_after_second);
    // `assert_ne!` alone would also pass if the old entry's bytes were never subtracted;
    // the total must equal what a single insert of the replacement accounts for.
    assert_eq!(size_after_second, reference.current_size_bytes);
    assert_eq!(cache.len(), 1);
}
1566
/// `ToolCallCacheConfig` must format via `Debug` and clone both of its tool lists.
#[test]
fn tool_call_cache_config_debug_and_clone() {
    let base = MethodCacheConfig {
        enabled: true,
        ttl_secs: 120,
    };
    let original = ToolCallCacheConfig {
        base,
        included_tools: vec![String::from("t1")],
        excluded_tools: vec![String::from("t2")],
    };

    // The Debug output should mention the type name.
    assert!(format!("{:?}", original).contains("ToolCallCacheConfig"));

    let duplicate = original.clone();
    assert_eq!(duplicate.included_tools, vec![String::from("t1")]);
    assert_eq!(duplicate.excluded_tools, vec![String::from("t2")]);
}
1583
/// Cloning `CacheStats` must copy every counter unchanged.
#[test]
fn cache_stats_clone() {
    let original = CacheStats {
        hits: 10,
        misses: 5,
        entries: 3,
        size_bytes: 100,
    };

    // Pull the clone apart so each field is checked by name.
    let CacheStats {
        hits,
        misses,
        entries,
        size_bytes,
    } = original.clone();

    assert_eq!(hits, 10);
    assert_eq!(misses, 5);
    assert_eq!(entries, 3);
    assert_eq!(size_bytes, 100);
}
1598
/// With both the include and exclude lists empty, every tool name is expected to be cacheable.
#[test]
fn should_cache_tool_empty_lists_allows_all() {
    let base = MethodCacheConfig {
        enabled: true,
        ttl_secs: 60,
    };
    let unrestricted = ToolCallCacheConfig {
        base,
        included_tools: Vec::new(),
        excluded_tools: Vec::new(),
    };

    assert!(unrestricted.should_cache_tool("any_tool"));
    assert!(unrestricted.should_cache_tool("another_tool"));
}
1612}