// fastmcp_server/caching.rs

1//! Response caching middleware for MCP servers.
2//!
3//! This module provides a caching middleware that can cache responses for
4//! various MCP methods to improve performance and reduce redundant processing.
5//!
6//! # Cached Methods
7//!
8//! By default, the middleware caches responses for:
9//! - `tools/list` - 5 minute TTL
10//! - `resources/list` - 5 minute TTL
11//! - `prompts/list` - 5 minute TTL
12//! - `resources/read` - 1 hour TTL
13//! - `prompts/get` - 1 hour TTL
14//! - `tools/call` - 1 hour TTL (configurable per tool)
15//!
16//! # Example
17//!
18//! ```ignore
19//! use fastmcp::prelude::*;
20//! use fastmcp_server::caching::ResponseCachingMiddleware;
21//!
22//! let caching = ResponseCachingMiddleware::new()
23//!     .list_ttl_secs(600)  // 10 minute TTL for list operations
24//!     .call_ttl_secs(3600); // 1 hour TTL for call operations
25//!
26//! Server::new("my-server", "1.0.0")
27//!     .middleware(caching)
28//!     .run_stdio();
29//! ```
30
31use std::collections::HashMap;
32use std::hash::{Hash, Hasher};
33use std::sync::Mutex;
34use std::time::{Duration, Instant};
35
36use fastmcp_core::{McpContext, McpError, McpResult};
37use fastmcp_protocol::JsonRpcRequest;
38
39use crate::{Middleware, MiddlewareDecision};
40
/// Default TTL for list operations (5 minutes).
pub const DEFAULT_LIST_TTL_SECS: u64 = 300;

/// Default TTL for call/get/read operations (1 hour).
pub const DEFAULT_CALL_TTL_SECS: u64 = 3600;

/// Maximum cache item size in bytes (1 MB).
///
/// Entries whose serialized JSON exceeds this limit are silently skipped
/// rather than cached.
pub const DEFAULT_MAX_ITEM_SIZE: usize = 1024 * 1024;
49
50/// A cached response with expiration time.
51#[derive(Debug, Clone)]
52struct CacheEntry {
53    value: serde_json::Value,
54    expires_at: Instant,
55    size_bytes: usize,
56}
57
58impl CacheEntry {
59    fn new(value: serde_json::Value, ttl: Duration) -> Self {
60        let size_bytes = value.to_string().len();
61        Self {
62            value,
63            expires_at: Instant::now() + ttl,
64            size_bytes,
65        }
66    }
67
68    fn is_expired(&self) -> bool {
69        Instant::now() > self.expires_at
70    }
71}
72
73/// Cache key derived from method and parameters.
74#[derive(Debug, Clone, PartialEq, Eq, Hash)]
75struct CacheKey {
76    method: String,
77    params_hash: u64,
78}
79
80impl CacheKey {
81    fn new(method: &str, params: Option<&serde_json::Value>) -> Self {
82        let params_hash = match params {
83            Some(v) => hash_json_value(v),
84            None => 0,
85        };
86        Self {
87            method: method.to_string(),
88            params_hash,
89        }
90    }
91}
92
93/// Computes a stable hash of a JSON value.
94fn hash_json_value(value: &serde_json::Value) -> u64 {
95    use std::collections::hash_map::DefaultHasher;
96
97    let mut hasher = DefaultHasher::new();
98
99    // Convert to canonical JSON string for consistent hashing
100    // This handles key ordering in objects
101    let json_str = serde_json::to_string(value).unwrap_or_default();
102    json_str.hash(&mut hasher);
103
104    hasher.finish()
105}
106
107/// Configuration for caching specific methods.
108#[derive(Debug, Clone)]
109pub struct MethodCacheConfig {
110    /// Whether caching is enabled for this method.
111    pub enabled: bool,
112    /// Time to live in seconds.
113    pub ttl_secs: u64,
114}
115
116impl Default for MethodCacheConfig {
117    fn default() -> Self {
118        Self {
119            enabled: true,
120            ttl_secs: DEFAULT_CALL_TTL_SECS,
121        }
122    }
123}
124
125/// Configuration for `tools/call` caching.
126#[derive(Debug, Clone, Default)]
127pub struct ToolCallCacheConfig {
128    /// Base configuration.
129    pub base: MethodCacheConfig,
130    /// Tools to include (if empty, include all).
131    pub included_tools: Vec<String>,
132    /// Tools to exclude (takes precedence over included).
133    pub excluded_tools: Vec<String>,
134}
135
136impl ToolCallCacheConfig {
137    /// Checks if a specific tool should be cached.
138    fn should_cache_tool(&self, tool_name: &str) -> bool {
139        if !self.base.enabled {
140            return false;
141        }
142
143        // Check exclusions first (takes precedence)
144        if self.excluded_tools.contains(&tool_name.to_string()) {
145            return false;
146        }
147
148        // If include list is specified, tool must be in it
149        if !self.included_tools.is_empty() {
150            return self.included_tools.contains(&tool_name.to_string());
151        }
152
153        // Default: include all
154        true
155    }
156}
157
158/// Simple LRU cache with TTL support.
159#[derive(Debug)]
160struct LruCache {
161    /// Map of keys to entries.
162    entries: HashMap<CacheKey, CacheEntry>,
163    /// Order of keys for LRU eviction (most recent at the end).
164    order: Vec<CacheKey>,
165    /// Maximum number of entries.
166    max_entries: usize,
167    /// Maximum total size in bytes.
168    max_size_bytes: usize,
169    /// Maximum size per item in bytes.
170    max_item_size: usize,
171    /// Current total size in bytes.
172    current_size_bytes: usize,
173}
174
175impl LruCache {
176    fn new(max_entries: usize, max_size_bytes: usize, max_item_size: usize) -> Self {
177        Self {
178            entries: HashMap::new(),
179            order: Vec::new(),
180            max_entries,
181            max_size_bytes,
182            max_item_size,
183            current_size_bytes: 0,
184        }
185    }
186
187    fn get(&mut self, key: &CacheKey) -> Option<serde_json::Value> {
188        // Check if entry exists and is not expired
189        if let Some(entry) = self.entries.get(key) {
190            if entry.is_expired() {
191                // Remove expired entry
192                self.remove(key);
193                return None;
194            }
195
196            // Move to end of order (most recently used)
197            if let Some(pos) = self.order.iter().position(|k| k == key) {
198                let k = self.order.remove(pos);
199                self.order.push(k);
200            }
201
202            return Some(entry.value.clone());
203        }
204        None
205    }
206
207    fn insert(&mut self, key: CacheKey, value: serde_json::Value, ttl: Duration) {
208        let entry = CacheEntry::new(value, ttl);
209
210        // Check item size limit
211        if entry.size_bytes > self.max_item_size {
212            // Silently skip oversized items (matching Python behavior)
213            return;
214        }
215
216        // Remove old entry if it exists
217        if self.entries.contains_key(&key) {
218            self.remove(&key);
219        }
220
221        // Evict entries if needed to make room
222        while self.entries.len() >= self.max_entries
223            || self.current_size_bytes + entry.size_bytes > self.max_size_bytes
224        {
225            if self.order.is_empty() {
226                break;
227            }
228            // Evict least recently used (first in order)
229            let oldest_key = self.order.remove(0);
230            if let Some(old_entry) = self.entries.remove(&oldest_key) {
231                self.current_size_bytes -= old_entry.size_bytes;
232            }
233        }
234
235        // Also remove expired entries opportunistically
236        self.evict_expired();
237
238        // Insert new entry
239        self.current_size_bytes += entry.size_bytes;
240        self.entries.insert(key.clone(), entry);
241        self.order.push(key);
242    }
243
244    fn remove(&mut self, key: &CacheKey) {
245        if let Some(entry) = self.entries.remove(key) {
246            self.current_size_bytes -= entry.size_bytes;
247            if let Some(pos) = self.order.iter().position(|k| k == key) {
248                self.order.remove(pos);
249            }
250        }
251    }
252
253    fn evict_expired(&mut self) {
254        let expired_keys: Vec<CacheKey> = self
255            .entries
256            .iter()
257            .filter(|(_, entry)| entry.is_expired())
258            .map(|(key, _)| key.clone())
259            .collect();
260
261        for key in expired_keys {
262            self.remove(&key);
263        }
264    }
265
266    fn clear(&mut self) {
267        self.entries.clear();
268        self.order.clear();
269        self.current_size_bytes = 0;
270    }
271
272    fn len(&self) -> usize {
273        self.entries.len()
274    }
275
276    #[allow(dead_code)]
277    fn is_empty(&self) -> bool {
278        self.entries.is_empty()
279    }
280}
281
/// Cache statistics.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Number of cache hits.
    pub hits: u64,
    /// Number of cache misses.
    pub misses: u64,
    /// Number of entries currently in cache.
    pub entries: usize,
    /// Current cache size in bytes.
    pub size_bytes: usize,
}

impl CacheStats {
    /// Returns the hit rate as a percentage.
    ///
    /// Defined as 0.0 when no lookups have been recorded.
    #[must_use]
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            total => (self.hits as f64 / total as f64) * 100.0,
        }
    }
}
307
/// Response caching middleware for MCP servers.
///
/// Caches responses for list and call operations with configurable TTL.
/// Uses an LRU eviction strategy when the cache is full.
///
/// The cache storage and statistics are each guarded by a `Mutex`, so a
/// shared instance can be used from multiple threads.
pub struct ResponseCachingMiddleware {
    /// Cache storage.
    cache: Mutex<LruCache>,
    /// TTL for list operations.
    list_ttl: Duration,
    /// TTL for call/get/read operations; also the fallback TTL for methods
    /// without a dedicated config.
    call_ttl: Duration,
    /// Configuration for tools/list caching.
    tools_list_config: MethodCacheConfig,
    /// Configuration for resources/list caching.
    resources_list_config: MethodCacheConfig,
    /// Configuration for prompts/list caching.
    prompts_list_config: MethodCacheConfig,
    /// Configuration for tools/call caching (supports per-tool
    /// include/exclude filtering).
    tools_call_config: ToolCallCacheConfig,
    /// Configuration for resources/read caching.
    resources_read_config: MethodCacheConfig,
    /// Configuration for prompts/get caching.
    prompts_get_config: MethodCacheConfig,
    /// Statistics tracking (hits/misses).
    stats: Mutex<CacheStats>,
}
334
335impl std::fmt::Debug for ResponseCachingMiddleware {
336    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
337        f.debug_struct("ResponseCachingMiddleware")
338            .field("list_ttl", &self.list_ttl)
339            .field("call_ttl", &self.call_ttl)
340            .finish_non_exhaustive()
341    }
342}
343
344impl Default for ResponseCachingMiddleware {
345    fn default() -> Self {
346        Self::new()
347    }
348}
349
350impl ResponseCachingMiddleware {
351    /// Creates a new response caching middleware with default settings.
352    #[must_use]
353    pub fn new() -> Self {
354        Self {
355            cache: Mutex::new(LruCache::new(
356                1000,
357                100 * 1024 * 1024,
358                DEFAULT_MAX_ITEM_SIZE,
359            )),
360            list_ttl: Duration::from_secs(DEFAULT_LIST_TTL_SECS),
361            call_ttl: Duration::from_secs(DEFAULT_CALL_TTL_SECS),
362            tools_list_config: MethodCacheConfig {
363                enabled: true,
364                ttl_secs: DEFAULT_LIST_TTL_SECS,
365            },
366            resources_list_config: MethodCacheConfig {
367                enabled: true,
368                ttl_secs: DEFAULT_LIST_TTL_SECS,
369            },
370            prompts_list_config: MethodCacheConfig {
371                enabled: true,
372                ttl_secs: DEFAULT_LIST_TTL_SECS,
373            },
374            tools_call_config: ToolCallCacheConfig {
375                base: MethodCacheConfig {
376                    enabled: true,
377                    ttl_secs: DEFAULT_CALL_TTL_SECS,
378                },
379                included_tools: Vec::new(),
380                excluded_tools: Vec::new(),
381            },
382            resources_read_config: MethodCacheConfig {
383                enabled: true,
384                ttl_secs: DEFAULT_CALL_TTL_SECS,
385            },
386            prompts_get_config: MethodCacheConfig {
387                enabled: true,
388                ttl_secs: DEFAULT_CALL_TTL_SECS,
389            },
390            stats: Mutex::new(CacheStats::default()),
391        }
392    }
393
394    /// Sets the maximum number of cache entries.
395    #[must_use]
396    pub fn max_entries(self, max: usize) -> Self {
397        let max_size = {
398            let cache = self
399                .cache
400                .lock()
401                .unwrap_or_else(std::sync::PoisonError::into_inner);
402            cache.max_size_bytes
403        };
404        let max_item_size = {
405            let cache = self
406                .cache
407                .lock()
408                .unwrap_or_else(std::sync::PoisonError::into_inner);
409            cache.max_item_size
410        };
411        Self {
412            cache: Mutex::new(LruCache::new(max, max_size, max_item_size)),
413            ..self
414        }
415    }
416
417    /// Sets the maximum cache size in bytes.
418    #[must_use]
419    pub fn max_size_bytes(self, max: usize) -> Self {
420        let max_entries = {
421            let cache = self
422                .cache
423                .lock()
424                .unwrap_or_else(std::sync::PoisonError::into_inner);
425            cache.max_entries
426        };
427        let max_item_size = {
428            let cache = self
429                .cache
430                .lock()
431                .unwrap_or_else(std::sync::PoisonError::into_inner);
432            cache.max_item_size
433        };
434        Self {
435            cache: Mutex::new(LruCache::new(max_entries, max, max_item_size)),
436            ..self
437        }
438    }
439
440    /// Sets the maximum size per cache item in bytes.
441    #[must_use]
442    pub fn max_item_size(self, max: usize) -> Self {
443        let max_entries = {
444            let cache = self
445                .cache
446                .lock()
447                .unwrap_or_else(std::sync::PoisonError::into_inner);
448            cache.max_entries
449        };
450        let max_size = {
451            let cache = self
452                .cache
453                .lock()
454                .unwrap_or_else(std::sync::PoisonError::into_inner);
455            cache.max_size_bytes
456        };
457        Self {
458            cache: Mutex::new(LruCache::new(max_entries, max_size, max)),
459            ..self
460        }
461    }
462
463    /// Sets the TTL for list operations (tools/list, resources/list, prompts/list).
464    #[must_use]
465    pub fn list_ttl_secs(mut self, secs: u64) -> Self {
466        self.list_ttl = Duration::from_secs(secs);
467        self.tools_list_config.ttl_secs = secs;
468        self.resources_list_config.ttl_secs = secs;
469        self.prompts_list_config.ttl_secs = secs;
470        self
471    }
472
473    /// Sets the TTL for call/get/read operations.
474    #[must_use]
475    pub fn call_ttl_secs(mut self, secs: u64) -> Self {
476        self.call_ttl = Duration::from_secs(secs);
477        self.tools_call_config.base.ttl_secs = secs;
478        self.resources_read_config.ttl_secs = secs;
479        self.prompts_get_config.ttl_secs = secs;
480        self
481    }
482
483    /// Disables caching for tools/list.
484    #[must_use]
485    pub fn disable_tools_list(mut self) -> Self {
486        self.tools_list_config.enabled = false;
487        self
488    }
489
490    /// Disables caching for resources/list.
491    #[must_use]
492    pub fn disable_resources_list(mut self) -> Self {
493        self.resources_list_config.enabled = false;
494        self
495    }
496
497    /// Disables caching for prompts/list.
498    #[must_use]
499    pub fn disable_prompts_list(mut self) -> Self {
500        self.prompts_list_config.enabled = false;
501        self
502    }
503
504    /// Disables caching for tools/call.
505    #[must_use]
506    pub fn disable_tools_call(mut self) -> Self {
507        self.tools_call_config.base.enabled = false;
508        self
509    }
510
511    /// Disables caching for resources/read.
512    #[must_use]
513    pub fn disable_resources_read(mut self) -> Self {
514        self.resources_read_config.enabled = false;
515        self
516    }
517
518    /// Disables caching for prompts/get.
519    #[must_use]
520    pub fn disable_prompts_get(mut self) -> Self {
521        self.prompts_get_config.enabled = false;
522        self
523    }
524
525    /// Sets the list of tools to include in caching (empty = all).
526    #[must_use]
527    pub fn include_tools(mut self, tools: Vec<String>) -> Self {
528        self.tools_call_config.included_tools = tools;
529        self
530    }
531
532    /// Sets the list of tools to exclude from caching.
533    #[must_use]
534    pub fn exclude_tools(mut self, tools: Vec<String>) -> Self {
535        self.tools_call_config.excluded_tools = tools;
536        self
537    }
538
539    /// Returns current cache statistics.
540    #[must_use]
541    pub fn stats(&self) -> CacheStats {
542        let cache = self
543            .cache
544            .lock()
545            .unwrap_or_else(std::sync::PoisonError::into_inner);
546        let mut stats = self
547            .stats
548            .lock()
549            .unwrap_or_else(std::sync::PoisonError::into_inner)
550            .clone();
551        stats.entries = cache.len();
552        stats.size_bytes = cache.current_size_bytes;
553        stats
554    }
555
556    /// Clears the entire cache.
557    pub fn clear(&self) {
558        let mut cache = self
559            .cache
560            .lock()
561            .unwrap_or_else(std::sync::PoisonError::into_inner);
562        cache.clear();
563    }
564
565    /// Invalidates a specific cache entry by method and params.
566    pub fn invalidate(&self, method: &str, params: Option<&serde_json::Value>) {
567        let key = CacheKey::new(method, params);
568        let mut cache = self
569            .cache
570            .lock()
571            .unwrap_or_else(std::sync::PoisonError::into_inner);
572        cache.remove(&key);
573    }
574
575    /// Checks if a method should be cached.
576    fn should_cache_method(&self, method: &str, params: Option<&serde_json::Value>) -> bool {
577        match method {
578            "tools/list" => self.tools_list_config.enabled,
579            "resources/list" => self.resources_list_config.enabled,
580            "prompts/list" => self.prompts_list_config.enabled,
581            "resources/read" => self.resources_read_config.enabled,
582            "prompts/get" => self.prompts_get_config.enabled,
583            "tools/call" => {
584                if !self.tools_call_config.base.enabled {
585                    return false;
586                }
587                // Extract tool name from params
588                if let Some(params) = params {
589                    if let Some(tool_name) = params.get("name").and_then(|v| v.as_str()) {
590                        return self.tools_call_config.should_cache_tool(tool_name);
591                    }
592                }
593                false
594            }
595            _ => false,
596        }
597    }
598
599    /// Gets the TTL for a specific method.
600    fn get_ttl(&self, method: &str) -> Duration {
601        match method {
602            "tools/list" => Duration::from_secs(self.tools_list_config.ttl_secs),
603            "resources/list" => Duration::from_secs(self.resources_list_config.ttl_secs),
604            "prompts/list" => Duration::from_secs(self.prompts_list_config.ttl_secs),
605            "tools/call" => Duration::from_secs(self.tools_call_config.base.ttl_secs),
606            "resources/read" => Duration::from_secs(self.resources_read_config.ttl_secs),
607            "prompts/get" => Duration::from_secs(self.prompts_get_config.ttl_secs),
608            _ => self.call_ttl,
609        }
610    }
611
612    fn record_hit(&self) {
613        let mut stats = self
614            .stats
615            .lock()
616            .unwrap_or_else(std::sync::PoisonError::into_inner);
617        stats.hits += 1;
618    }
619
620    fn record_miss(&self) {
621        let mut stats = self
622            .stats
623            .lock()
624            .unwrap_or_else(std::sync::PoisonError::into_inner);
625        stats.misses += 1;
626    }
627}
628
629impl Middleware for ResponseCachingMiddleware {
630    fn on_request(
631        &self,
632        _ctx: &McpContext,
633        request: &JsonRpcRequest,
634    ) -> McpResult<MiddlewareDecision> {
635        // Check if this method should be cached
636        if !self.should_cache_method(&request.method, request.params.as_ref()) {
637            return Ok(MiddlewareDecision::Continue);
638        }
639
640        // Try to get cached response
641        let key = CacheKey::new(&request.method, request.params.as_ref());
642        let mut cache = self
643            .cache
644            .lock()
645            .unwrap_or_else(std::sync::PoisonError::into_inner);
646
647        if let Some(value) = cache.get(&key) {
648            self.record_hit();
649            return Ok(MiddlewareDecision::Respond(value));
650        }
651
652        self.record_miss();
653        Ok(MiddlewareDecision::Continue)
654    }
655
656    fn on_response(
657        &self,
658        _ctx: &McpContext,
659        request: &JsonRpcRequest,
660        response: serde_json::Value,
661    ) -> McpResult<serde_json::Value> {
662        // Only cache if this method is cacheable
663        if !self.should_cache_method(&request.method, request.params.as_ref()) {
664            return Ok(response);
665        }
666
667        // Store in cache
668        let key = CacheKey::new(&request.method, request.params.as_ref());
669        let ttl = self.get_ttl(&request.method);
670
671        let mut cache = self
672            .cache
673            .lock()
674            .unwrap_or_else(std::sync::PoisonError::into_inner);
675
676        cache.insert(key, response.clone(), ttl);
677
678        Ok(response)
679    }
680
681    fn on_error(&self, _ctx: &McpContext, _request: &JsonRpcRequest, error: McpError) -> McpError {
682        // Don't cache errors, just pass them through
683        error
684    }
685}
686
#[cfg(test)]
mod tests {
    use super::*;
    use asupersync::Cx;

    // Builds a minimal McpContext backed by a testing Cx.
    fn test_context() -> McpContext {
        let cx = Cx::for_testing();
        McpContext::new(cx, 1)
    }

    // Builds a JSON-RPC request with id 1 for the given method and params.
    fn test_request(method: &str, params: Option<serde_json::Value>) -> JsonRpcRequest {
        JsonRpcRequest {
            jsonrpc: std::borrow::Cow::Borrowed(fastmcp_protocol::JSONRPC_VERSION),
            method: method.to_string(),
            params,
            id: Some(fastmcp_protocol::RequestId::Number(1)),
        }
    }

    // ========================================
    // LruCache tests
    // ========================================

    #[test]
    fn test_lru_cache_basic_operations() {
        let mut cache = LruCache::new(10, 1024 * 1024, 1024);

        let key = CacheKey::new("test", None);
        let value = serde_json::json!({"result": "cached"});

        // Insert and retrieve
        cache.insert(key.clone(), value.clone(), Duration::from_secs(60));
        let retrieved = cache.get(&key);
        assert_eq!(retrieved, Some(value));
    }

    #[test]
    fn test_lru_cache_expiration() {
        let mut cache = LruCache::new(10, 1024 * 1024, 1024);

        let key = CacheKey::new("test", None);
        let value = serde_json::json!({"result": "cached"});

        // Insert with very short TTL
        cache.insert(key.clone(), value, Duration::from_millis(1));

        // Wait for expiration
        std::thread::sleep(std::time::Duration::from_millis(10));

        // Should be expired
        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn test_lru_cache_eviction() {
        // Capacity of 2 entries forces LRU eviction on the third insert.
        let mut cache = LruCache::new(2, 1024 * 1024, 1024);

        let key1 = CacheKey::new("test1", None);
        let key2 = CacheKey::new("test2", None);
        let key3 = CacheKey::new("test3", None);

        cache.insert(
            key1.clone(),
            serde_json::json!("v1"),
            Duration::from_secs(60),
        );
        cache.insert(
            key2.clone(),
            serde_json::json!("v2"),
            Duration::from_secs(60),
        );

        // Should evict key1 (LRU)
        cache.insert(
            key3.clone(),
            serde_json::json!("v3"),
            Duration::from_secs(60),
        );

        assert!(cache.get(&key1).is_none());
        assert!(cache.get(&key2).is_some());
        assert!(cache.get(&key3).is_some());
    }

    #[test]
    fn test_lru_cache_size_limit() {
        // Very small size limit
        let mut cache = LruCache::new(100, 50, 1024);

        let key1 = CacheKey::new("test1", None);
        let key2 = CacheKey::new("test2", None);

        // First entry should fit
        cache.insert(
            key1.clone(),
            serde_json::json!("short"),
            Duration::from_secs(60),
        );
        assert_eq!(cache.len(), 1);

        // Second entry should cause eviction
        cache.insert(
            key2.clone(),
            serde_json::json!("another"),
            Duration::from_secs(60),
        );
        // Cache should have evicted the first entry to make room
        assert!(cache.len() <= 2);
    }

    #[test]
    fn test_lru_cache_oversized_item_rejected() {
        let mut cache = LruCache::new(10, 1024 * 1024, 10); // max 10 bytes per item

        let key = CacheKey::new("test", None);
        let large_value = serde_json::json!({"data": "this is much longer than 10 bytes"});

        cache.insert(key.clone(), large_value, Duration::from_secs(60));

        // Should not be stored
        assert!(cache.get(&key).is_none());
    }

    // ========================================
    // ResponseCachingMiddleware tests
    // ========================================

    #[test]
    fn test_caching_middleware_caches_tools_list() {
        let middleware = ResponseCachingMiddleware::new();
        let ctx = test_context();
        let request = test_request("tools/list", None);

        // First request: miss, continue
        let decision = middleware.on_request(&ctx, &request).unwrap();
        assert!(matches!(decision, MiddlewareDecision::Continue));

        // Simulate response
        let response = serde_json::json!({"tools": []});
        middleware
            .on_response(&ctx, &request, response.clone())
            .unwrap();

        // Second request: hit, respond from cache
        let decision = middleware.on_request(&ctx, &request).unwrap();
        match decision {
            MiddlewareDecision::Respond(cached) => assert_eq!(cached, response),
            MiddlewareDecision::Continue => panic!("Expected cache hit"),
        }

        // Check stats
        let stats = middleware.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
    }

    #[test]
    fn test_caching_middleware_skips_non_cacheable_methods() {
        let middleware = ResponseCachingMiddleware::new();
        let ctx = test_context();
        let request = test_request("initialize", None);

        // Should continue (not cached)
        let decision = middleware.on_request(&ctx, &request).unwrap();
        assert!(matches!(decision, MiddlewareDecision::Continue));

        // Even after response, next request should not hit cache
        middleware
            .on_response(&ctx, &request, serde_json::json!({}))
            .unwrap();

        let decision = middleware.on_request(&ctx, &request).unwrap();
        assert!(matches!(decision, MiddlewareDecision::Continue));
    }

    #[test]
    fn test_caching_middleware_different_params_different_keys() {
        let middleware = ResponseCachingMiddleware::new();
        let ctx = test_context();

        // Same method, different tool names: the params hash must differ.
        let request1 = test_request(
            "tools/call",
            Some(serde_json::json!({"name": "tool_a", "arguments": {}})),
        );
        let request2 = test_request(
            "tools/call",
            Some(serde_json::json!({"name": "tool_b", "arguments": {}})),
        );

        // Cache response for request1
        middleware.on_request(&ctx, &request1).unwrap();
        let response1 = serde_json::json!({"result": "a"});
        middleware
            .on_response(&ctx, &request1, response1.clone())
            .unwrap();

        // Request2 should not hit cache
        let decision = middleware.on_request(&ctx, &request2).unwrap();
        assert!(matches!(decision, MiddlewareDecision::Continue));

        // Request1 should hit cache
        let decision = middleware.on_request(&ctx, &request1).unwrap();
        match decision {
            MiddlewareDecision::Respond(cached) => assert_eq!(cached, response1),
            MiddlewareDecision::Continue => panic!("Expected cache hit"),
        }
    }

    #[test]
    fn test_caching_middleware_tool_exclusion() {
        let middleware =
            ResponseCachingMiddleware::new().exclude_tools(vec!["excluded_tool".to_string()]);
        let ctx = test_context();

        let excluded_request = test_request(
            "tools/call",
            Some(serde_json::json!({"name": "excluded_tool", "arguments": {}})),
        );
        let included_request = test_request(
            "tools/call",
            Some(serde_json::json!({"name": "included_tool", "arguments": {}})),
        );

        // Excluded tool should not be cached
        middleware.on_request(&ctx, &excluded_request).unwrap();
        middleware
            .on_response(&ctx, &excluded_request, serde_json::json!({}))
            .unwrap();

        let decision = middleware.on_request(&ctx, &excluded_request).unwrap();
        assert!(matches!(decision, MiddlewareDecision::Continue));

        // Included tool should be cached
        middleware.on_request(&ctx, &included_request).unwrap();
        let response = serde_json::json!({"result": "included"});
        middleware
            .on_response(&ctx, &included_request, response.clone())
            .unwrap();

        let decision = middleware.on_request(&ctx, &included_request).unwrap();
        match decision {
            MiddlewareDecision::Respond(cached) => assert_eq!(cached, response),
            MiddlewareDecision::Continue => panic!("Expected cache hit for included tool"),
        }
    }

    #[test]
    fn test_caching_middleware_disable_method() {
        let middleware = ResponseCachingMiddleware::new().disable_tools_list();
        let ctx = test_context();
        let request = test_request("tools/list", None);

        // Should not cache
        middleware.on_request(&ctx, &request).unwrap();
        middleware
            .on_response(&ctx, &request, serde_json::json!({}))
            .unwrap();

        let decision = middleware.on_request(&ctx, &request).unwrap();
        assert!(matches!(decision, MiddlewareDecision::Continue));
    }

    #[test]
    fn test_caching_middleware_clear() {
        let middleware = ResponseCachingMiddleware::new();
        let ctx = test_context();
        let request = test_request("tools/list", None);

        // Cache a response
        middleware.on_request(&ctx, &request).unwrap();
        middleware
            .on_response(&ctx, &request, serde_json::json!({}))
            .unwrap();

        // Verify cached
        let decision = middleware.on_request(&ctx, &request).unwrap();
        assert!(matches!(decision, MiddlewareDecision::Respond(_)));

        // Clear cache
        middleware.clear();

        // Should miss now
        let decision = middleware.on_request(&ctx, &request).unwrap();
        assert!(matches!(decision, MiddlewareDecision::Continue));
    }

    #[test]
    fn test_caching_middleware_invalidate() {
        let middleware = ResponseCachingMiddleware::new();
        let ctx = test_context();
        let request = test_request("tools/list", None);

        // Cache a response
        middleware.on_request(&ctx, &request).unwrap();
        middleware
            .on_response(&ctx, &request, serde_json::json!({}))
            .unwrap();

        // Invalidate specific entry
        middleware.invalidate("tools/list", None);

        // Should miss now
        let decision = middleware.on_request(&ctx, &request).unwrap();
        assert!(matches!(decision, MiddlewareDecision::Continue));
    }

    #[test]
    fn test_cache_stats_hit_rate() {
        let stats = CacheStats {
            hits: 75,
            misses: 25,
            entries: 10,
            size_bytes: 1000,
        };

        // 75 hits out of 100 lookups -> 75%.
        assert!((stats.hit_rate() - 75.0).abs() < 0.001);
    }
}