// mcp_proxy/cache.rs
//! Response caching middleware for the proxy.
//!
//! Caches `ReadResource` and `CallTool` responses with per-backend TTL.
//! Cache keys are derived from the request type, name/URI, and arguments.
//!
//! # Configuration
//!
//! ```toml
//! [[backends]]
//! name = "slow-api"
//! transport = "http"
//! url = "http://localhost:8080"
//!
//! [backends.cache]
//! resource_ttl_seconds = 300
//! tool_ttl_seconds = 60
//! max_entries = 1000
//! ```

20use std::convert::Infallible;
21use std::future::Future;
22use std::pin::Pin;
23use std::sync::Arc;
24use std::sync::atomic::{AtomicU64, Ordering};
25use std::task::{Context, Poll};
26use std::time::Duration;
27
28use moka::future::Cache;
29use serde::Serialize;
30use tower::Service;
31use tower_mcp::router::{RouterRequest, RouterResponse};
32use tower_mcp_types::protocol::McpRequest;
33
34use crate::config::BackendCacheConfig;
35
/// Per-backend cache with separate resource and tool caches (different TTLs).
///
/// A `None` cache means caching is disabled for that request kind (the
/// corresponding TTL was configured as 0). `stats` sits behind an `Arc`
/// so clones of this struct — and the `CacheHandle` — share one counter set.
#[derive(Clone)]
struct BackendCache {
    // Namespace prefix matched against tool names / resource URIs.
    namespace: String,
    // Cache for `ReadResource` responses; `None` when resource caching is off.
    resource_cache: Option<Cache<String, RouterResponse>>,
    // Cache for `CallTool` responses; `None` when tool caching is off.
    tool_cache: Option<Cache<String, RouterResponse>>,
    // Shared hit/miss counters for this backend.
    stats: Arc<CacheStats>,
}
44
/// Atomic hit/miss counters for a backend cache.
///
/// Accessed with `Ordering::Relaxed` throughout: the counters are
/// statistics only and impose no synchronization on cached data.
struct CacheStats {
    // Requests served from cache.
    hits: AtomicU64,
    // Requests that fell through to the inner service.
    misses: AtomicU64,
}
50
51impl CacheStats {
52    fn new() -> Self {
53        Self {
54            hits: AtomicU64::new(0),
55            misses: AtomicU64::new(0),
56        }
57    }
58}
59
60/// Snapshot of cache statistics for a single backend.
61///
62/// Returned by [`CacheHandle::stats()`] to report hit/miss rates
63/// and entry counts per cached namespace.
64#[derive(Serialize, Clone)]
65pub struct CacheStatsSnapshot {
66    /// Backend namespace this cache covers.
67    pub namespace: String,
68    /// Total cache hits.
69    pub hits: u64,
70    /// Total cache misses.
71    pub misses: u64,
72    /// Hit rate as a fraction (0.0-1.0).
73    pub hit_rate: f64,
74    /// Current number of cached entries.
75    pub entry_count: u64,
76}
77
/// Shared handle for querying cache stats and clearing caches.
///
/// Cheap to clone: it holds an `Arc` to the same per-backend caches
/// used by the `CacheService` that created it.
#[derive(Clone)]
pub struct CacheHandle {
    // One entry per configured backend, shared with the service.
    caches: Arc<Vec<BackendCache>>,
}
83
84impl CacheHandle {
85    /// Get a snapshot of cache statistics for all backends.
86    pub fn stats(&self) -> Vec<CacheStatsSnapshot> {
87        self.caches
88            .iter()
89            .map(|bc| {
90                let hits = bc.stats.hits.load(Ordering::Relaxed);
91                let misses = bc.stats.misses.load(Ordering::Relaxed);
92                let total = hits + misses;
93                let entry_count = bc.resource_cache.as_ref().map_or(0, |c| c.entry_count())
94                    + bc.tool_cache.as_ref().map_or(0, |c| c.entry_count());
95                CacheStatsSnapshot {
96                    namespace: bc.namespace.clone(),
97                    hits,
98                    misses,
99                    hit_rate: if total > 0 {
100                        hits as f64 / total as f64
101                    } else {
102                        0.0
103                    },
104                    entry_count,
105                }
106            })
107            .collect()
108    }
109
110    /// Clear all cache entries and reset stats.
111    pub fn clear(&self) {
112        for bc in self.caches.iter() {
113            if let Some(c) = &bc.resource_cache {
114                c.invalidate_all();
115            }
116            if let Some(c) = &bc.tool_cache {
117                c.invalidate_all();
118            }
119            bc.stats.hits.store(0, Ordering::Relaxed);
120            bc.stats.misses.store(0, Ordering::Relaxed);
121        }
122    }
123}
124
/// Tower service that caches resource reads and tool call results.
///
/// Wraps an inner service: cacheable requests are looked up first and,
/// on a miss, stored after a successful inner call. Requests with no
/// matching backend cache pass straight through.
#[derive(Clone)]
pub struct CacheService<S> {
    // The wrapped downstream service.
    inner: S,
    // Per-backend caches, shared with any `CacheHandle` clones.
    caches: Arc<Vec<BackendCache>>,
}
131
132impl<S> CacheService<S> {
133    /// Create a new cache service and return it with a shareable handle.
134    pub fn new(inner: S, configs: Vec<(String, &BackendCacheConfig)>) -> (Self, CacheHandle) {
135        let caches: Vec<BackendCache> = configs
136            .into_iter()
137            .map(|(namespace, cfg)| {
138                let resource_cache = if cfg.resource_ttl_seconds > 0 {
139                    Some(
140                        Cache::builder()
141                            .max_capacity(cfg.max_entries)
142                            .time_to_live(Duration::from_secs(cfg.resource_ttl_seconds))
143                            .build(),
144                    )
145                } else {
146                    None
147                };
148                let tool_cache = if cfg.tool_ttl_seconds > 0 {
149                    Some(
150                        Cache::builder()
151                            .max_capacity(cfg.max_entries)
152                            .time_to_live(Duration::from_secs(cfg.tool_ttl_seconds))
153                            .build(),
154                    )
155                } else {
156                    None
157                };
158                BackendCache {
159                    namespace,
160                    resource_cache,
161                    tool_cache,
162                    stats: Arc::new(CacheStats::new()),
163                }
164            })
165            .collect();
166        let caches = Arc::new(caches);
167        let handle = CacheHandle {
168            caches: Arc::clone(&caches),
169        };
170        (Self { inner, caches }, handle)
171    }
172}
173
174/// Extract cache key and find the matching backend cache + stats.
175fn resolve_cache<'a>(
176    caches: &'a [BackendCache],
177    req: &McpRequest,
178) -> Option<(
179    &'a Cache<String, RouterResponse>,
180    String,
181    &'a Arc<CacheStats>,
182)> {
183    match req {
184        McpRequest::ReadResource(params) => {
185            let key = format!("res:{}", params.uri);
186            for bc in caches {
187                if params.uri.starts_with(&bc.namespace) {
188                    return bc.resource_cache.as_ref().map(|c| (c, key, &bc.stats));
189                }
190            }
191            None
192        }
193        McpRequest::CallTool(params) => {
194            let args = serde_json::to_string(&params.arguments).unwrap_or_default();
195            let key = format!("tool:{}:{}", params.name, args);
196            for bc in caches {
197                if params.name.starts_with(&bc.namespace) {
198                    return bc.tool_cache.as_ref().map(|c| (c, key, &bc.stats));
199                }
200            }
201            None
202        }
203        _ => None,
204    }
205}
206
207impl<S> Service<RouterRequest> for CacheService<S>
208where
209    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
210        + Clone
211        + Send
212        + 'static,
213    S::Future: Send,
214{
215    type Response = RouterResponse;
216    type Error = Infallible;
217    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
218
219    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
220        self.inner.poll_ready(cx)
221    }
222
223    fn call(&mut self, req: RouterRequest) -> Self::Future {
224        let caches = Arc::clone(&self.caches);
225
226        if let Some((cache, key, stats)) = resolve_cache(&caches, &req.inner) {
227            let cache = cache.clone();
228            let stats = Arc::clone(stats);
229            let mut inner = self.inner.clone();
230
231            return Box::pin(async move {
232                // Cache hit -- return with current request ID
233                if let Some(cached) = cache.get(&key).await {
234                    stats.hits.fetch_add(1, Ordering::Relaxed);
235                    return Ok(RouterResponse {
236                        id: req.id,
237                        inner: cached.inner,
238                    });
239                }
240
241                stats.misses.fetch_add(1, Ordering::Relaxed);
242                let result = inner.call(req).await;
243
244                // Only cache successful MCP responses
245                let Ok(ref resp) = result;
246                if resp.inner.is_ok() {
247                    cache.insert(key, resp.clone()).await;
248                }
249
250                result
251            });
252        }
253
254        // No caching for this request type or backend
255        let fut = self.inner.call(req);
256        Box::pin(fut)
257    }
258}
259
#[cfg(test)]
mod tests {
    use tower_mcp::protocol::{McpRequest, McpResponse};

    use super::CacheService;
    use crate::config::BackendCacheConfig;
    use crate::test_util::{MockService, call_service};

    /// Build a `CallTool` request for `name` with a fixed argument payload.
    fn tool_call(name: &str) -> McpRequest {
        McpRequest::CallTool(tower_mcp::protocol::CallToolParams {
            name: name.to_string(),
            arguments: serde_json::json!({"key": "value"}),
            meta: None,
            task: None,
        })
    }

    /// Config with the given TTLs (seconds) and room for 100 entries.
    /// Extracted to remove the five copies of the same struct literal.
    fn cache_cfg(resource_ttl_seconds: u64, tool_ttl_seconds: u64) -> BackendCacheConfig {
        BackendCacheConfig {
            resource_ttl_seconds,
            tool_ttl_seconds,
            max_entries: 100,
        }
    }

    #[tokio::test]
    async fn test_cache_hit_returns_same_result() {
        let mock = MockService::with_tools(&["fs/read"]);
        let cfg = cache_cfg(60, 60);
        let (mut svc, _handle) = CacheService::new(mock, vec![("fs/".to_string(), &cfg)]);

        let resp1 = call_service(&mut svc, tool_call("fs/read")).await;
        let resp2 = call_service(&mut svc, tool_call("fs/read")).await;

        // Both should succeed with same content
        match (resp1.inner.unwrap(), resp2.inner.unwrap()) {
            (McpResponse::CallTool(r1), McpResponse::CallTool(r2)) => {
                assert_eq!(r1.all_text(), r2.all_text());
            }
            _ => panic!("expected CallTool responses"),
        }
    }

    #[tokio::test]
    async fn test_cache_disabled_passes_through() {
        let mock = MockService::with_tools(&["fs/read"]);
        // Zero TTLs disable both caches entirely.
        let cfg = cache_cfg(0, 0);
        let (mut svc, _handle) = CacheService::new(mock, vec![("fs/".to_string(), &cfg)]);

        let resp = call_service(&mut svc, tool_call("fs/read")).await;
        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_cache_non_matching_namespace_passes_through() {
        let mock = MockService::with_tools(&["db/query"]);
        let cfg = cache_cfg(60, 60);
        let (mut svc, _handle) = CacheService::new(mock, vec![("fs/".to_string(), &cfg)]);

        let resp = call_service(&mut svc, tool_call("db/query")).await;
        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_cache_list_tools_not_cached() {
        let mock = MockService::with_tools(&["fs/read"]);
        let cfg = cache_cfg(60, 60);
        let (mut svc, _handle) = CacheService::new(mock, vec![("fs/".to_string(), &cfg)]);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        assert!(resp.inner.is_ok(), "list_tools should pass through");
    }

    #[tokio::test]
    async fn test_cache_stats_tracks_hits_and_misses() {
        let mock = MockService::with_tools(&["fs/read"]);
        let cfg = cache_cfg(60, 60);
        let (mut svc, handle) = CacheService::new(mock, vec![("fs/".to_string(), &cfg)]);

        // First call = miss
        let _ = call_service(&mut svc, tool_call("fs/read")).await;
        let stats = handle.stats();
        assert_eq!(stats.len(), 1);
        assert_eq!(stats[0].hits, 0);
        assert_eq!(stats[0].misses, 1);

        // Second call = hit
        let _ = call_service(&mut svc, tool_call("fs/read")).await;
        let stats = handle.stats();
        assert_eq!(stats[0].hits, 1);
        assert_eq!(stats[0].misses, 1);
        assert!((stats[0].hit_rate - 0.5).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn test_cache_clear_resets_stats() {
        let mock = MockService::with_tools(&["fs/read"]);
        let cfg = cache_cfg(60, 60);
        let (mut svc, handle) = CacheService::new(mock, vec![("fs/".to_string(), &cfg)]);

        let _ = call_service(&mut svc, tool_call("fs/read")).await;
        let _ = call_service(&mut svc, tool_call("fs/read")).await;

        handle.clear();
        let stats = handle.stats();
        assert_eq!(stats[0].hits, 0);
        assert_eq!(stats[0].misses, 0);
    }
}