1use std::collections::HashMap;
32use std::hash::{Hash, Hasher};
33use std::sync::Mutex;
34use std::time::{Duration, Instant};
35
36use fastmcp_core::{McpContext, McpError, McpResult};
37use fastmcp_protocol::JsonRpcRequest;
38
39use crate::{Middleware, MiddlewareDecision};
40
/// Default time-to-live, in seconds, for cached `tools/list`, `resources/list`
/// and `prompts/list` responses (5 minutes).
pub const DEFAULT_LIST_TTL_SECS: u64 = 5 * 60;

/// Default time-to-live, in seconds, for cached `tools/call`, `resources/read`
/// and `prompts/get` responses (1 hour).
pub const DEFAULT_CALL_TTL_SECS: u64 = 60 * 60;

/// Default limit on a single cached response: 1 MiB of serialized JSON.
pub const DEFAULT_MAX_ITEM_SIZE: usize = 1 << 20;
49
50#[derive(Debug, Clone)]
52struct CacheEntry {
53 value: serde_json::Value,
54 expires_at: Instant,
55 size_bytes: usize,
56}
57
58impl CacheEntry {
59 fn new(value: serde_json::Value, ttl: Duration) -> Self {
60 let size_bytes = value.to_string().len();
61 Self {
62 value,
63 expires_at: Instant::now() + ttl,
64 size_bytes,
65 }
66 }
67
68 fn is_expired(&self) -> bool {
69 Instant::now() > self.expires_at
70 }
71}
72
73#[derive(Debug, Clone, PartialEq, Eq, Hash)]
75struct CacheKey {
76 method: String,
77 params_hash: u64,
78}
79
80impl CacheKey {
81 fn new(method: &str, params: Option<&serde_json::Value>) -> Self {
82 let params_hash = match params {
83 Some(v) => hash_json_value(v),
84 None => 0,
85 };
86 Self {
87 method: method.to_string(),
88 params_hash,
89 }
90 }
91}
92
93fn hash_json_value(value: &serde_json::Value) -> u64 {
95 use std::collections::hash_map::DefaultHasher;
96
97 let mut hasher = DefaultHasher::new();
98
99 let json_str = serde_json::to_string(value).unwrap_or_default();
102 json_str.hash(&mut hasher);
103
104 hasher.finish()
105}
106
107#[derive(Debug, Clone)]
109pub struct MethodCacheConfig {
110 pub enabled: bool,
112 pub ttl_secs: u64,
114}
115
116impl Default for MethodCacheConfig {
117 fn default() -> Self {
118 Self {
119 enabled: true,
120 ttl_secs: DEFAULT_CALL_TTL_SECS,
121 }
122 }
123}
124
125#[derive(Debug, Clone, Default)]
127pub struct ToolCallCacheConfig {
128 pub base: MethodCacheConfig,
130 pub included_tools: Vec<String>,
132 pub excluded_tools: Vec<String>,
134}
135
136impl ToolCallCacheConfig {
137 fn should_cache_tool(&self, tool_name: &str) -> bool {
139 if !self.base.enabled {
140 return false;
141 }
142
143 if self.excluded_tools.contains(&tool_name.to_string()) {
145 return false;
146 }
147
148 if !self.included_tools.is_empty() {
150 return self.included_tools.contains(&tool_name.to_string());
151 }
152
153 true
155 }
156}
157
158#[derive(Debug)]
160struct LruCache {
161 entries: HashMap<CacheKey, CacheEntry>,
163 order: Vec<CacheKey>,
165 max_entries: usize,
167 max_size_bytes: usize,
169 max_item_size: usize,
171 current_size_bytes: usize,
173}
174
175impl LruCache {
176 fn new(max_entries: usize, max_size_bytes: usize, max_item_size: usize) -> Self {
177 Self {
178 entries: HashMap::new(),
179 order: Vec::new(),
180 max_entries,
181 max_size_bytes,
182 max_item_size,
183 current_size_bytes: 0,
184 }
185 }
186
187 fn get(&mut self, key: &CacheKey) -> Option<serde_json::Value> {
188 if let Some(entry) = self.entries.get(key) {
190 if entry.is_expired() {
191 self.remove(key);
193 return None;
194 }
195
196 if let Some(pos) = self.order.iter().position(|k| k == key) {
198 let k = self.order.remove(pos);
199 self.order.push(k);
200 }
201
202 return Some(entry.value.clone());
203 }
204 None
205 }
206
207 fn insert(&mut self, key: CacheKey, value: serde_json::Value, ttl: Duration) {
208 let entry = CacheEntry::new(value, ttl);
209
210 if entry.size_bytes > self.max_item_size {
212 return;
214 }
215
216 if self.entries.contains_key(&key) {
218 self.remove(&key);
219 }
220
221 while self.entries.len() >= self.max_entries
223 || self.current_size_bytes + entry.size_bytes > self.max_size_bytes
224 {
225 if self.order.is_empty() {
226 break;
227 }
228 let oldest_key = self.order.remove(0);
230 if let Some(old_entry) = self.entries.remove(&oldest_key) {
231 self.current_size_bytes -= old_entry.size_bytes;
232 }
233 }
234
235 self.evict_expired();
237
238 self.current_size_bytes += entry.size_bytes;
240 self.entries.insert(key.clone(), entry);
241 self.order.push(key);
242 }
243
244 fn remove(&mut self, key: &CacheKey) {
245 if let Some(entry) = self.entries.remove(key) {
246 self.current_size_bytes -= entry.size_bytes;
247 if let Some(pos) = self.order.iter().position(|k| k == key) {
248 self.order.remove(pos);
249 }
250 }
251 }
252
253 fn evict_expired(&mut self) {
254 let expired_keys: Vec<CacheKey> = self
255 .entries
256 .iter()
257 .filter(|(_, entry)| entry.is_expired())
258 .map(|(key, _)| key.clone())
259 .collect();
260
261 for key in expired_keys {
262 self.remove(&key);
263 }
264 }
265
266 fn clear(&mut self) {
267 self.entries.clear();
268 self.order.clear();
269 self.current_size_bytes = 0;
270 }
271
272 fn len(&self) -> usize {
273 self.entries.len()
274 }
275
276 #[allow(dead_code)]
277 fn is_empty(&self) -> bool {
278 self.entries.is_empty()
279 }
280}
281
/// Snapshot of cache effectiveness and occupancy.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Lookups answered from the cache.
    pub hits: u64,
    /// Cacheable lookups that found no live entry.
    pub misses: u64,
    /// Entries currently stored.
    pub entries: usize,
    /// Total serialized size, in bytes, of the stored entries.
    pub size_bytes: usize,
}

impl CacheStats {
    /// Hit rate as a percentage (0–100); `0.0` when no lookups were recorded.
    #[must_use]
    pub fn hit_rate(&self) -> f64 {
        let lookups = self.hits + self.misses;
        if lookups == 0 {
            return 0.0;
        }
        (self.hits as f64 / lookups as f64) * 100.0
    }
}
307
308pub struct ResponseCachingMiddleware {
313 cache: Mutex<LruCache>,
315 list_ttl: Duration,
317 call_ttl: Duration,
319 tools_list_config: MethodCacheConfig,
321 resources_list_config: MethodCacheConfig,
323 prompts_list_config: MethodCacheConfig,
325 tools_call_config: ToolCallCacheConfig,
327 resources_read_config: MethodCacheConfig,
329 prompts_get_config: MethodCacheConfig,
331 stats: Mutex<CacheStats>,
333}
334
335impl std::fmt::Debug for ResponseCachingMiddleware {
336 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
337 f.debug_struct("ResponseCachingMiddleware")
338 .field("list_ttl", &self.list_ttl)
339 .field("call_ttl", &self.call_ttl)
340 .finish_non_exhaustive()
341 }
342}
343
344impl Default for ResponseCachingMiddleware {
345 fn default() -> Self {
346 Self::new()
347 }
348}
349
350impl ResponseCachingMiddleware {
351 #[must_use]
353 pub fn new() -> Self {
354 Self {
355 cache: Mutex::new(LruCache::new(
356 1000,
357 100 * 1024 * 1024,
358 DEFAULT_MAX_ITEM_SIZE,
359 )),
360 list_ttl: Duration::from_secs(DEFAULT_LIST_TTL_SECS),
361 call_ttl: Duration::from_secs(DEFAULT_CALL_TTL_SECS),
362 tools_list_config: MethodCacheConfig {
363 enabled: true,
364 ttl_secs: DEFAULT_LIST_TTL_SECS,
365 },
366 resources_list_config: MethodCacheConfig {
367 enabled: true,
368 ttl_secs: DEFAULT_LIST_TTL_SECS,
369 },
370 prompts_list_config: MethodCacheConfig {
371 enabled: true,
372 ttl_secs: DEFAULT_LIST_TTL_SECS,
373 },
374 tools_call_config: ToolCallCacheConfig {
375 base: MethodCacheConfig {
376 enabled: true,
377 ttl_secs: DEFAULT_CALL_TTL_SECS,
378 },
379 included_tools: Vec::new(),
380 excluded_tools: Vec::new(),
381 },
382 resources_read_config: MethodCacheConfig {
383 enabled: true,
384 ttl_secs: DEFAULT_CALL_TTL_SECS,
385 },
386 prompts_get_config: MethodCacheConfig {
387 enabled: true,
388 ttl_secs: DEFAULT_CALL_TTL_SECS,
389 },
390 stats: Mutex::new(CacheStats::default()),
391 }
392 }
393
394 #[must_use]
396 pub fn max_entries(self, max: usize) -> Self {
397 let max_size = {
398 let cache = self
399 .cache
400 .lock()
401 .unwrap_or_else(std::sync::PoisonError::into_inner);
402 cache.max_size_bytes
403 };
404 let max_item_size = {
405 let cache = self
406 .cache
407 .lock()
408 .unwrap_or_else(std::sync::PoisonError::into_inner);
409 cache.max_item_size
410 };
411 Self {
412 cache: Mutex::new(LruCache::new(max, max_size, max_item_size)),
413 ..self
414 }
415 }
416
417 #[must_use]
419 pub fn max_size_bytes(self, max: usize) -> Self {
420 let max_entries = {
421 let cache = self
422 .cache
423 .lock()
424 .unwrap_or_else(std::sync::PoisonError::into_inner);
425 cache.max_entries
426 };
427 let max_item_size = {
428 let cache = self
429 .cache
430 .lock()
431 .unwrap_or_else(std::sync::PoisonError::into_inner);
432 cache.max_item_size
433 };
434 Self {
435 cache: Mutex::new(LruCache::new(max_entries, max, max_item_size)),
436 ..self
437 }
438 }
439
440 #[must_use]
442 pub fn max_item_size(self, max: usize) -> Self {
443 let max_entries = {
444 let cache = self
445 .cache
446 .lock()
447 .unwrap_or_else(std::sync::PoisonError::into_inner);
448 cache.max_entries
449 };
450 let max_size = {
451 let cache = self
452 .cache
453 .lock()
454 .unwrap_or_else(std::sync::PoisonError::into_inner);
455 cache.max_size_bytes
456 };
457 Self {
458 cache: Mutex::new(LruCache::new(max_entries, max_size, max)),
459 ..self
460 }
461 }
462
463 #[must_use]
465 pub fn list_ttl_secs(mut self, secs: u64) -> Self {
466 self.list_ttl = Duration::from_secs(secs);
467 self.tools_list_config.ttl_secs = secs;
468 self.resources_list_config.ttl_secs = secs;
469 self.prompts_list_config.ttl_secs = secs;
470 self
471 }
472
473 #[must_use]
475 pub fn call_ttl_secs(mut self, secs: u64) -> Self {
476 self.call_ttl = Duration::from_secs(secs);
477 self.tools_call_config.base.ttl_secs = secs;
478 self.resources_read_config.ttl_secs = secs;
479 self.prompts_get_config.ttl_secs = secs;
480 self
481 }
482
483 #[must_use]
485 pub fn disable_tools_list(mut self) -> Self {
486 self.tools_list_config.enabled = false;
487 self
488 }
489
490 #[must_use]
492 pub fn disable_resources_list(mut self) -> Self {
493 self.resources_list_config.enabled = false;
494 self
495 }
496
497 #[must_use]
499 pub fn disable_prompts_list(mut self) -> Self {
500 self.prompts_list_config.enabled = false;
501 self
502 }
503
504 #[must_use]
506 pub fn disable_tools_call(mut self) -> Self {
507 self.tools_call_config.base.enabled = false;
508 self
509 }
510
511 #[must_use]
513 pub fn disable_resources_read(mut self) -> Self {
514 self.resources_read_config.enabled = false;
515 self
516 }
517
518 #[must_use]
520 pub fn disable_prompts_get(mut self) -> Self {
521 self.prompts_get_config.enabled = false;
522 self
523 }
524
525 #[must_use]
527 pub fn include_tools(mut self, tools: Vec<String>) -> Self {
528 self.tools_call_config.included_tools = tools;
529 self
530 }
531
532 #[must_use]
534 pub fn exclude_tools(mut self, tools: Vec<String>) -> Self {
535 self.tools_call_config.excluded_tools = tools;
536 self
537 }
538
539 #[must_use]
541 pub fn stats(&self) -> CacheStats {
542 let cache = self
543 .cache
544 .lock()
545 .unwrap_or_else(std::sync::PoisonError::into_inner);
546 let mut stats = self
547 .stats
548 .lock()
549 .unwrap_or_else(std::sync::PoisonError::into_inner)
550 .clone();
551 stats.entries = cache.len();
552 stats.size_bytes = cache.current_size_bytes;
553 stats
554 }
555
556 pub fn clear(&self) {
558 let mut cache = self
559 .cache
560 .lock()
561 .unwrap_or_else(std::sync::PoisonError::into_inner);
562 cache.clear();
563 }
564
565 pub fn invalidate(&self, method: &str, params: Option<&serde_json::Value>) {
567 let key = CacheKey::new(method, params);
568 let mut cache = self
569 .cache
570 .lock()
571 .unwrap_or_else(std::sync::PoisonError::into_inner);
572 cache.remove(&key);
573 }
574
575 fn should_cache_method(&self, method: &str, params: Option<&serde_json::Value>) -> bool {
577 match method {
578 "tools/list" => self.tools_list_config.enabled,
579 "resources/list" => self.resources_list_config.enabled,
580 "prompts/list" => self.prompts_list_config.enabled,
581 "resources/read" => self.resources_read_config.enabled,
582 "prompts/get" => self.prompts_get_config.enabled,
583 "tools/call" => {
584 if !self.tools_call_config.base.enabled {
585 return false;
586 }
587 if let Some(params) = params {
589 if let Some(tool_name) = params.get("name").and_then(|v| v.as_str()) {
590 return self.tools_call_config.should_cache_tool(tool_name);
591 }
592 }
593 false
594 }
595 _ => false,
596 }
597 }
598
599 fn get_ttl(&self, method: &str) -> Duration {
601 match method {
602 "tools/list" => Duration::from_secs(self.tools_list_config.ttl_secs),
603 "resources/list" => Duration::from_secs(self.resources_list_config.ttl_secs),
604 "prompts/list" => Duration::from_secs(self.prompts_list_config.ttl_secs),
605 "tools/call" => Duration::from_secs(self.tools_call_config.base.ttl_secs),
606 "resources/read" => Duration::from_secs(self.resources_read_config.ttl_secs),
607 "prompts/get" => Duration::from_secs(self.prompts_get_config.ttl_secs),
608 _ => self.call_ttl,
609 }
610 }
611
612 fn record_hit(&self) {
613 let mut stats = self
614 .stats
615 .lock()
616 .unwrap_or_else(std::sync::PoisonError::into_inner);
617 stats.hits += 1;
618 }
619
620 fn record_miss(&self) {
621 let mut stats = self
622 .stats
623 .lock()
624 .unwrap_or_else(std::sync::PoisonError::into_inner);
625 stats.misses += 1;
626 }
627}
628
629impl Middleware for ResponseCachingMiddleware {
630 fn on_request(
631 &self,
632 _ctx: &McpContext,
633 request: &JsonRpcRequest,
634 ) -> McpResult<MiddlewareDecision> {
635 if !self.should_cache_method(&request.method, request.params.as_ref()) {
637 return Ok(MiddlewareDecision::Continue);
638 }
639
640 let key = CacheKey::new(&request.method, request.params.as_ref());
642 let mut cache = self
643 .cache
644 .lock()
645 .unwrap_or_else(std::sync::PoisonError::into_inner);
646
647 if let Some(value) = cache.get(&key) {
648 self.record_hit();
649 return Ok(MiddlewareDecision::Respond(value));
650 }
651
652 self.record_miss();
653 Ok(MiddlewareDecision::Continue)
654 }
655
656 fn on_response(
657 &self,
658 _ctx: &McpContext,
659 request: &JsonRpcRequest,
660 response: serde_json::Value,
661 ) -> McpResult<serde_json::Value> {
662 if !self.should_cache_method(&request.method, request.params.as_ref()) {
664 return Ok(response);
665 }
666
667 let key = CacheKey::new(&request.method, request.params.as_ref());
669 let ttl = self.get_ttl(&request.method);
670
671 let mut cache = self
672 .cache
673 .lock()
674 .unwrap_or_else(std::sync::PoisonError::into_inner);
675
676 cache.insert(key, response.clone(), ttl);
677
678 Ok(response)
679 }
680
681 fn on_error(&self, _ctx: &McpContext, _request: &JsonRpcRequest, error: McpError) -> McpError {
682 error
684 }
685}
686
#[cfg(test)]
mod tests {
    use super::*;
    use asupersync::Cx;

    /// Builds a throwaway request context for driving the middleware directly.
    fn test_context() -> McpContext {
        let cx = Cx::for_testing();
        McpContext::new(cx, 1)
    }

    /// Builds a JSON-RPC request with the given method and params and id 1.
    fn test_request(method: &str, params: Option<serde_json::Value>) -> JsonRpcRequest {
        JsonRpcRequest {
            jsonrpc: std::borrow::Cow::Borrowed(fastmcp_protocol::JSONRPC_VERSION),
            method: method.to_string(),
            params,
            id: Some(fastmcp_protocol::RequestId::Number(1)),
        }
    }

    #[test]
    fn test_lru_cache_basic_operations() {
        // A freshly inserted value is returned unchanged.
        let mut cache = LruCache::new(10, 1024 * 1024, 1024);

        let key = CacheKey::new("test", None);
        let value = serde_json::json!({"result": "cached"});

        cache.insert(key.clone(), value.clone(), Duration::from_secs(60));
        let retrieved = cache.get(&key);
        assert_eq!(retrieved, Some(value));
    }

    #[test]
    fn test_lru_cache_expiration() {
        let mut cache = LruCache::new(10, 1024 * 1024, 1024);

        let key = CacheKey::new("test", None);
        let value = serde_json::json!({"result": "cached"});

        // 1 ms TTL, then sleep well past it so `get` sees the entry as expired.
        cache.insert(key.clone(), value, Duration::from_millis(1));

        std::thread::sleep(std::time::Duration::from_millis(10));

        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn test_lru_cache_eviction() {
        // Room for two entries only: inserting a third evicts the least recently
        // used one (key1), keeping key2 and key3.
        let mut cache = LruCache::new(2, 1024 * 1024, 1024);

        let key1 = CacheKey::new("test1", None);
        let key2 = CacheKey::new("test2", None);
        let key3 = CacheKey::new("test3", None);

        cache.insert(
            key1.clone(),
            serde_json::json!("v1"),
            Duration::from_secs(60),
        );
        cache.insert(
            key2.clone(),
            serde_json::json!("v2"),
            Duration::from_secs(60),
        );

        cache.insert(
            key3.clone(),
            serde_json::json!("v3"),
            Duration::from_secs(60),
        );

        assert!(cache.get(&key1).is_none());
        assert!(cache.get(&key2).is_some());
        assert!(cache.get(&key3).is_some());
    }

    #[test]
    fn test_lru_cache_size_limit() {
        // 50-byte total budget. The final assertion is deliberately loose: it only
        // checks that the cache never holds more than the two inserted entries.
        let mut cache = LruCache::new(100, 50, 1024);

        let key1 = CacheKey::new("test1", None);
        let key2 = CacheKey::new("test2", None);

        cache.insert(
            key1.clone(),
            serde_json::json!("short"),
            Duration::from_secs(60),
        );
        assert_eq!(cache.len(), 1);

        cache.insert(
            key2.clone(),
            serde_json::json!("another"),
            Duration::from_secs(60),
        );
        assert!(cache.len() <= 2);
    }

    #[test]
    fn test_lru_cache_oversized_item_rejected() {
        // max_item_size is 10 bytes; the serialized value is longer, so it is never stored.
        let mut cache = LruCache::new(10, 1024 * 1024, 10);
        let key = CacheKey::new("test", None);
        let large_value = serde_json::json!({"data": "this is much longer than 10 bytes"});

        cache.insert(key.clone(), large_value, Duration::from_secs(60));

        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn test_caching_middleware_caches_tools_list() {
        let middleware = ResponseCachingMiddleware::new();
        let ctx = test_context();
        let request = test_request("tools/list", None);

        // First request: nothing cached yet, so it continues (counted as a miss).
        let decision = middleware.on_request(&ctx, &request).unwrap();
        assert!(matches!(decision, MiddlewareDecision::Continue));

        // The handler's response is stored on the way out.
        let response = serde_json::json!({"tools": []});
        middleware
            .on_response(&ctx, &request, response.clone())
            .unwrap();

        // Second identical request is answered from the cache (counted as a hit).
        let decision = middleware.on_request(&ctx, &request).unwrap();
        match decision {
            MiddlewareDecision::Respond(cached) => assert_eq!(cached, response),
            MiddlewareDecision::Continue => panic!("Expected cache hit"),
        }

        let stats = middleware.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
    }

    #[test]
    fn test_caching_middleware_skips_non_cacheable_methods() {
        // `initialize` is not a cacheable method, so it always continues.
        let middleware = ResponseCachingMiddleware::new();
        let ctx = test_context();
        let request = test_request("initialize", None);

        let decision = middleware.on_request(&ctx, &request).unwrap();
        assert!(matches!(decision, MiddlewareDecision::Continue));

        middleware
            .on_response(&ctx, &request, serde_json::json!({}))
            .unwrap();

        let decision = middleware.on_request(&ctx, &request).unwrap();
        assert!(matches!(decision, MiddlewareDecision::Continue));
    }

    #[test]
    fn test_caching_middleware_different_params_different_keys() {
        // Calls to different tools must not share a cache entry.
        let middleware = ResponseCachingMiddleware::new();
        let ctx = test_context();

        let request1 = test_request(
            "tools/call",
            Some(serde_json::json!({"name": "tool_a", "arguments": {}})),
        );
        let request2 = test_request(
            "tools/call",
            Some(serde_json::json!({"name": "tool_b", "arguments": {}})),
        );

        middleware.on_request(&ctx, &request1).unwrap();
        let response1 = serde_json::json!({"result": "a"});
        middleware
            .on_response(&ctx, &request1, response1.clone())
            .unwrap();

        // tool_b was never cached.
        let decision = middleware.on_request(&ctx, &request2).unwrap();
        assert!(matches!(decision, MiddlewareDecision::Continue));

        // tool_a still hits its own entry.
        let decision = middleware.on_request(&ctx, &request1).unwrap();
        match decision {
            MiddlewareDecision::Respond(cached) => assert_eq!(cached, response1),
            MiddlewareDecision::Continue => panic!("Expected cache hit"),
        }
    }

    #[test]
    fn test_caching_middleware_tool_exclusion() {
        // Excluded tools are never cached; other tools still are.
        let middleware =
            ResponseCachingMiddleware::new().exclude_tools(vec!["excluded_tool".to_string()]);
        let ctx = test_context();

        let excluded_request = test_request(
            "tools/call",
            Some(serde_json::json!({"name": "excluded_tool", "arguments": {}})),
        );
        let included_request = test_request(
            "tools/call",
            Some(serde_json::json!({"name": "included_tool", "arguments": {}})),
        );

        middleware.on_request(&ctx, &excluded_request).unwrap();
        middleware
            .on_response(&ctx, &excluded_request, serde_json::json!({}))
            .unwrap();

        let decision = middleware.on_request(&ctx, &excluded_request).unwrap();
        assert!(matches!(decision, MiddlewareDecision::Continue));

        middleware.on_request(&ctx, &included_request).unwrap();
        let response = serde_json::json!({"result": "included"});
        middleware
            .on_response(&ctx, &included_request, response.clone())
            .unwrap();

        let decision = middleware.on_request(&ctx, &included_request).unwrap();
        match decision {
            MiddlewareDecision::Respond(cached) => assert_eq!(cached, response),
            MiddlewareDecision::Continue => panic!("Expected cache hit for included tool"),
        }
    }

    #[test]
    fn test_caching_middleware_disable_method() {
        // With tools/list disabled, responses are not stored and never served.
        let middleware = ResponseCachingMiddleware::new().disable_tools_list();
        let ctx = test_context();
        let request = test_request("tools/list", None);

        middleware.on_request(&ctx, &request).unwrap();
        middleware
            .on_response(&ctx, &request, serde_json::json!({}))
            .unwrap();

        let decision = middleware.on_request(&ctx, &request).unwrap();
        assert!(matches!(decision, MiddlewareDecision::Continue));
    }

    #[test]
    fn test_caching_middleware_clear() {
        let middleware = ResponseCachingMiddleware::new();
        let ctx = test_context();
        let request = test_request("tools/list", None);

        middleware.on_request(&ctx, &request).unwrap();
        middleware
            .on_response(&ctx, &request, serde_json::json!({}))
            .unwrap();

        // Cached before clearing...
        let decision = middleware.on_request(&ctx, &request).unwrap();
        assert!(matches!(decision, MiddlewareDecision::Respond(_)));

        middleware.clear();

        // ...and gone afterwards.
        let decision = middleware.on_request(&ctx, &request).unwrap();
        assert!(matches!(decision, MiddlewareDecision::Continue));
    }

    #[test]
    fn test_caching_middleware_invalidate() {
        // Invalidating (method, params) removes exactly that entry.
        let middleware = ResponseCachingMiddleware::new();
        let ctx = test_context();
        let request = test_request("tools/list", None);

        middleware.on_request(&ctx, &request).unwrap();
        middleware
            .on_response(&ctx, &request, serde_json::json!({}))
            .unwrap();

        middleware.invalidate("tools/list", None);

        let decision = middleware.on_request(&ctx, &request).unwrap();
        assert!(matches!(decision, MiddlewareDecision::Continue));
    }

    #[test]
    fn test_cache_stats_hit_rate() {
        // 75 hits out of 100 lookups is a 75% hit rate.
        let stats = CacheStats {
            hits: 75,
            misses: 25,
            entries: 10,
            size_bytes: 1000,
        };

        assert!((stats.hit_rate() - 75.0).abs() < 0.001);
    }
}