1use std::convert::Infallible;
21use std::future::Future;
22use std::pin::Pin;
23use std::sync::Arc;
24use std::sync::atomic::{AtomicU64, Ordering};
25use std::task::{Context, Poll};
26use std::time::Duration;
27
28use moka::future::Cache;
29use serde::Serialize;
30use tower::Service;
31use tower_mcp::router::{RouterRequest, RouterResponse};
32use tower_mcp_types::protocol::McpRequest;
33
34use crate::config::BackendCacheConfig;
35
/// Per-backend caches plus hit/miss counters, keyed off a namespace prefix.
#[derive(Clone)]
struct BackendCache {
    /// Namespace prefix; requests whose resource URI or tool name starts
    /// with this string are served from (and populate) these caches.
    namespace: String,
    /// Cache for `ReadResource` responses; `None` when the configured
    /// resource TTL is 0 (caching disabled for resources).
    resource_cache: Option<Cache<String, RouterResponse>>,
    /// Cache for `CallTool` responses; `None` when the configured tool
    /// TTL is 0 (caching disabled for tools).
    tool_cache: Option<Cache<String, RouterResponse>>,
    /// Counters shared with the `CacheHandle` for out-of-band reporting.
    stats: Arc<CacheStats>,
}
44
/// Atomic hit/miss counters for one backend, updated lock-free on the
/// request path and read/reset through `CacheHandle`.
struct CacheStats {
    hits: AtomicU64,
    misses: AtomicU64,
}
50
51impl CacheStats {
52 fn new() -> Self {
53 Self {
54 hits: AtomicU64::new(0),
55 misses: AtomicU64::new(0),
56 }
57 }
58}
59
/// Point-in-time, serializable view of one backend cache's statistics.
#[derive(Serialize, Clone)]
pub struct CacheStatsSnapshot {
    /// Namespace prefix identifying the backend.
    pub namespace: String,
    /// Number of cacheable requests answered from cache.
    pub hits: u64,
    /// Number of cacheable requests that had to hit the backend.
    pub misses: u64,
    /// hits / (hits + misses); 0.0 when no cacheable requests were seen.
    pub hit_rate: f64,
    /// Combined entry count of the resource and tool caches.
    pub entry_count: u64,
}
77
/// Cheap-to-clone handle onto the caches owned by a `CacheService`,
/// used to report statistics and invalidate entries out-of-band.
#[derive(Clone)]
pub struct CacheHandle {
    caches: Arc<Vec<BackendCache>>,
}
83
84impl CacheHandle {
85 pub fn stats(&self) -> Vec<CacheStatsSnapshot> {
87 self.caches
88 .iter()
89 .map(|bc| {
90 let hits = bc.stats.hits.load(Ordering::Relaxed);
91 let misses = bc.stats.misses.load(Ordering::Relaxed);
92 let total = hits + misses;
93 let entry_count = bc.resource_cache.as_ref().map_or(0, |c| c.entry_count())
94 + bc.tool_cache.as_ref().map_or(0, |c| c.entry_count());
95 CacheStatsSnapshot {
96 namespace: bc.namespace.clone(),
97 hits,
98 misses,
99 hit_rate: if total > 0 {
100 hits as f64 / total as f64
101 } else {
102 0.0
103 },
104 entry_count,
105 }
106 })
107 .collect()
108 }
109
110 pub fn clear(&self) {
112 for bc in self.caches.iter() {
113 if let Some(c) = &bc.resource_cache {
114 c.invalidate_all();
115 }
116 if let Some(c) = &bc.tool_cache {
117 c.invalidate_all();
118 }
119 bc.stats.hits.store(0, Ordering::Relaxed);
120 bc.stats.misses.store(0, Ordering::Relaxed);
121 }
122 }
123}
124
/// Tower middleware that caches successful resource reads and tool calls
/// per backend namespace, delegating everything else to `inner`.
#[derive(Clone)]
pub struct CacheService<S> {
    inner: S,
    caches: Arc<Vec<BackendCache>>,
}
131
132impl<S> CacheService<S> {
133 pub fn new(inner: S, configs: Vec<(String, &BackendCacheConfig)>) -> (Self, CacheHandle) {
135 let caches: Vec<BackendCache> = configs
136 .into_iter()
137 .map(|(namespace, cfg)| {
138 let resource_cache = if cfg.resource_ttl_seconds > 0 {
139 Some(
140 Cache::builder()
141 .max_capacity(cfg.max_entries)
142 .time_to_live(Duration::from_secs(cfg.resource_ttl_seconds))
143 .build(),
144 )
145 } else {
146 None
147 };
148 let tool_cache = if cfg.tool_ttl_seconds > 0 {
149 Some(
150 Cache::builder()
151 .max_capacity(cfg.max_entries)
152 .time_to_live(Duration::from_secs(cfg.tool_ttl_seconds))
153 .build(),
154 )
155 } else {
156 None
157 };
158 BackendCache {
159 namespace,
160 resource_cache,
161 tool_cache,
162 stats: Arc::new(CacheStats::new()),
163 }
164 })
165 .collect();
166 let caches = Arc::new(caches);
167 let handle = CacheHandle {
168 caches: Arc::clone(&caches),
169 };
170 (Self { inner, caches }, handle)
171 }
172}
173
174fn resolve_cache<'a>(
176 caches: &'a [BackendCache],
177 req: &McpRequest,
178) -> Option<(
179 &'a Cache<String, RouterResponse>,
180 String,
181 &'a Arc<CacheStats>,
182)> {
183 match req {
184 McpRequest::ReadResource(params) => {
185 let key = format!("res:{}", params.uri);
186 for bc in caches {
187 if params.uri.starts_with(&bc.namespace) {
188 return bc.resource_cache.as_ref().map(|c| (c, key, &bc.stats));
189 }
190 }
191 None
192 }
193 McpRequest::CallTool(params) => {
194 let args = serde_json::to_string(¶ms.arguments).unwrap_or_default();
195 let key = format!("tool:{}:{}", params.name, args);
196 for bc in caches {
197 if params.name.starts_with(&bc.namespace) {
198 return bc.tool_cache.as_ref().map(|c| (c, key, &bc.stats));
199 }
200 }
201 None
202 }
203 _ => None,
204 }
205}
206
207impl<S> Service<RouterRequest> for CacheService<S>
208where
209 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
210 + Clone
211 + Send
212 + 'static,
213 S::Future: Send,
214{
215 type Response = RouterResponse;
216 type Error = Infallible;
217 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
218
219 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
220 self.inner.poll_ready(cx)
221 }
222
223 fn call(&mut self, req: RouterRequest) -> Self::Future {
224 let caches = Arc::clone(&self.caches);
225
226 if let Some((cache, key, stats)) = resolve_cache(&caches, &req.inner) {
227 let cache = cache.clone();
228 let stats = Arc::clone(stats);
229 let mut inner = self.inner.clone();
230
231 return Box::pin(async move {
232 if let Some(cached) = cache.get(&key).await {
234 stats.hits.fetch_add(1, Ordering::Relaxed);
235 return Ok(RouterResponse {
236 id: req.id,
237 inner: cached.inner,
238 });
239 }
240
241 stats.misses.fetch_add(1, Ordering::Relaxed);
242 let result = inner.call(req).await;
243
244 let Ok(ref resp) = result;
246 if resp.inner.is_ok() {
247 cache.insert(key, resp.clone()).await;
248 }
249
250 result
251 });
252 }
253
254 let fut = self.inner.call(req);
256 Box::pin(fut)
257 }
258}
259
#[cfg(test)]
mod tests {
    use tower_mcp::protocol::{McpRequest, McpResponse};

    use super::CacheService;
    use crate::config::BackendCacheConfig;
    use crate::test_util::{MockService, call_service};

    /// CallTool request for `name` with a fixed argument payload.
    fn tool_call(name: &str) -> McpRequest {
        let params = tower_mcp::protocol::CallToolParams {
            name: name.to_string(),
            arguments: serde_json::json!({"key": "value"}),
            meta: None,
            task: None,
        };
        McpRequest::CallTool(params)
    }

    /// Config with both TTLs set to `ttl` seconds and room for 100 entries.
    fn config(ttl: u64) -> BackendCacheConfig {
        BackendCacheConfig {
            resource_ttl_seconds: ttl,
            tool_ttl_seconds: ttl,
            max_entries: 100,
        }
    }

    #[tokio::test]
    async fn test_cache_hit_returns_same_result() {
        let cfg = config(60);
        let mock = MockService::with_tools(&["fs/read"]);
        let (mut svc, _handle) = CacheService::new(mock, vec![("fs/".to_string(), &cfg)]);

        let first = call_service(&mut svc, tool_call("fs/read")).await;
        let second = call_service(&mut svc, tool_call("fs/read")).await;

        match (first.inner.unwrap(), second.inner.unwrap()) {
            (McpResponse::CallTool(a), McpResponse::CallTool(b)) => {
                assert_eq!(a.all_text(), b.all_text());
            }
            _ => panic!("expected CallTool responses"),
        }
    }

    #[tokio::test]
    async fn test_cache_disabled_passes_through() {
        // TTL 0 on both caches means nothing is cached.
        let cfg = config(0);
        let mock = MockService::with_tools(&["fs/read"]);
        let (mut svc, _handle) = CacheService::new(mock, vec![("fs/".to_string(), &cfg)]);

        let resp = call_service(&mut svc, tool_call("fs/read")).await;
        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_cache_non_matching_namespace_passes_through() {
        let cfg = config(60);
        let mock = MockService::with_tools(&["db/query"]);
        let (mut svc, _handle) = CacheService::new(mock, vec![("fs/".to_string(), &cfg)]);

        let resp = call_service(&mut svc, tool_call("db/query")).await;
        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_cache_list_tools_not_cached() {
        let cfg = config(60);
        let mock = MockService::with_tools(&["fs/read"]);
        let (mut svc, _handle) = CacheService::new(mock, vec![("fs/".to_string(), &cfg)]);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        assert!(resp.inner.is_ok(), "list_tools should pass through");
    }

    #[tokio::test]
    async fn test_cache_stats_tracks_hits_and_misses() {
        let cfg = config(60);
        let mock = MockService::with_tools(&["fs/read"]);
        let (mut svc, handle) = CacheService::new(mock, vec![("fs/".to_string(), &cfg)]);

        // First call misses and populates the cache.
        let _ = call_service(&mut svc, tool_call("fs/read")).await;
        let stats = handle.stats();
        assert_eq!(stats.len(), 1);
        assert_eq!(stats[0].hits, 0);
        assert_eq!(stats[0].misses, 1);

        // Second identical call is served from cache.
        let _ = call_service(&mut svc, tool_call("fs/read")).await;
        let stats = handle.stats();
        assert_eq!(stats[0].hits, 1);
        assert_eq!(stats[0].misses, 1);
        assert!((stats[0].hit_rate - 0.5).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn test_cache_clear_resets_stats() {
        let cfg = config(60);
        let mock = MockService::with_tools(&["fs/read"]);
        let (mut svc, handle) = CacheService::new(mock, vec![("fs/".to_string(), &cfg)]);

        let _ = call_service(&mut svc, tool_call("fs/read")).await;
        let _ = call_service(&mut svc, tool_call("fs/read")).await;

        handle.clear();
        let stats = handle.stats();
        assert_eq!(stats[0].hits, 0);
        assert_eq!(stats[0].misses, 0);
    }
}