1use std::convert::Infallible;
31use std::future::Future;
32use std::pin::Pin;
33use std::sync::Arc;
34use std::sync::atomic::{AtomicU64, Ordering};
35use std::task::{Context, Poll};
36use std::time::Duration;
37
38use moka::future::Cache;
39use serde::Serialize;
40use tower::{Layer, Service};
41use tower_mcp::router::{RouterRequest, RouterResponse};
42use tower_mcp_types::protocol::McpRequest;
43
44use crate::config::{BackendCacheConfig, CacheBackendConfig};
45
/// A single cache store holding [`RouterResponse`]s.
///
/// `build_caches` creates one store per (namespace, cache kind) pair: a
/// resource store and/or a tool store, each only when its TTL is non-zero.
#[derive(Clone)]
pub(crate) enum CacheStore {
    /// In-process moka cache; capacity and TTL are fixed when it is built.
    Memory(Cache<String, RouterResponse>),
    /// Redis-backed store (requires the `redis-cache` feature).
    ///
    /// Entries are stored as JSON under `{prefix}{key}` with a Redis expiry
    /// derived from `ttl` in whole seconds.
    #[cfg(feature = "redis-cache")]
    Redis {
        client: redis::Client,
        prefix: String,
        ttl: Duration,
    },
    /// SQLite-backed store (requires the `sqlite-cache` feature).
    ///
    /// Entries are stored as JSON rows in the `cache_entries` table with an
    /// absolute `expires_at` Unix timestamp (seconds) computed from `ttl`.
    /// The connection sits behind `Arc<Mutex<_>>`, so clones of this variant
    /// share the same connection.
    #[cfg(feature = "sqlite-cache")]
    Sqlite {
        conn: Arc<std::sync::Mutex<rusqlite::Connection>>,
        ttl: Duration,
    },
}
72
73impl CacheStore {
74 async fn get(&self, key: &str) -> Option<RouterResponse> {
76 match self {
77 CacheStore::Memory(cache) => cache.get(key).await,
78 #[cfg(feature = "redis-cache")]
79 CacheStore::Redis {
80 client,
81 prefix,
82 ttl: _,
83 } => {
84 let full_key = format!("{prefix}{key}");
85 let mut conn = client.get_multiplexed_async_connection().await.ok()?;
86 let data: Option<String> =
87 redis::AsyncCommands::get(&mut conn, &full_key).await.ok()?;
88 data.and_then(|s| serde_json::from_str(&s).ok())
89 }
90 #[cfg(feature = "sqlite-cache")]
91 CacheStore::Sqlite { conn, ttl: _ } => {
92 let key = key.to_string();
93 let conn = conn.lock().ok()?;
94 let now = std::time::SystemTime::now()
95 .duration_since(std::time::UNIX_EPOCH)
96 .unwrap_or_default()
97 .as_secs() as i64;
98 let result: Option<String> = conn
99 .query_row(
100 "SELECT value FROM cache_entries WHERE key = ?1 AND expires_at > ?2",
101 rusqlite::params![key, now],
102 |row| row.get(0),
103 )
104 .ok();
105 result.and_then(|s| serde_json::from_str(&s).ok())
106 }
107 }
108 }
109
110 async fn insert(&self, key: String, value: RouterResponse) {
112 match self {
113 CacheStore::Memory(cache) => {
114 cache.insert(key, value).await;
115 }
116 #[cfg(feature = "redis-cache")]
117 CacheStore::Redis {
118 client,
119 prefix,
120 ttl,
121 } => {
122 let full_key = format!("{prefix}{key}");
123 if let Ok(json) = serde_json::to_string(&value)
124 && let Ok(mut conn) = client.get_multiplexed_async_connection().await
125 {
126 let _: Result<(), _> =
127 redis::AsyncCommands::set_ex(&mut conn, &full_key, &json, ttl.as_secs())
128 .await;
129 }
130 }
131 #[cfg(feature = "sqlite-cache")]
132 CacheStore::Sqlite { conn, ttl } => {
133 if let Ok(json) = serde_json::to_string(&value) {
134 let expires_at = std::time::SystemTime::now()
135 .duration_since(std::time::UNIX_EPOCH)
136 .unwrap_or_default()
137 .as_secs() as i64
138 + ttl.as_secs() as i64;
139 if let Ok(conn) = conn.lock() {
140 let _ = conn.execute(
141 "INSERT OR REPLACE INTO cache_entries (key, value, expires_at) VALUES (?1, ?2, ?3)",
142 rusqlite::params![key, json, expires_at],
143 );
144 }
145 }
146 }
147 }
148 }
149
150 async fn invalidate_all(&self) {
152 match self {
153 CacheStore::Memory(cache) => {
154 cache.invalidate_all();
155 }
156 #[cfg(feature = "redis-cache")]
157 CacheStore::Redis {
158 client,
159 prefix,
160 ttl: _,
161 } => {
162 if let Ok(mut conn) = client.get_multiplexed_async_connection().await {
163 let pattern = format!("{prefix}*");
164 let keys: Vec<String> = redis::AsyncCommands::keys(&mut conn, &pattern)
165 .await
166 .unwrap_or_default();
167 if !keys.is_empty() {
168 let _: Result<(), _> = redis::AsyncCommands::del(&mut conn, &keys).await;
169 }
170 }
171 }
172 #[cfg(feature = "sqlite-cache")]
173 CacheStore::Sqlite { conn, ttl: _ } => {
174 if let Ok(conn) = conn.lock() {
175 let _ = conn.execute("DELETE FROM cache_entries", []);
176 }
177 }
178 }
179 }
180
181 async fn entry_count(&self) -> u64 {
183 match self {
184 CacheStore::Memory(cache) => cache.entry_count(),
185 #[cfg(feature = "redis-cache")]
186 CacheStore::Redis {
187 client,
188 prefix,
189 ttl: _,
190 } => {
191 if let Ok(mut conn) = client.get_multiplexed_async_connection().await {
192 let pattern = format!("{prefix}*");
193 let keys: Vec<String> = redis::AsyncCommands::keys(&mut conn, &pattern)
194 .await
195 .unwrap_or_default();
196 keys.len() as u64
197 } else {
198 0
199 }
200 }
201 #[cfg(feature = "sqlite-cache")]
202 CacheStore::Sqlite { conn, ttl: _ } => {
203 let now = std::time::SystemTime::now()
204 .duration_since(std::time::UNIX_EPOCH)
205 .unwrap_or_default()
206 .as_secs() as i64;
207 if let Ok(conn) = conn.lock() {
208 conn.query_row(
209 "SELECT COUNT(*) FROM cache_entries WHERE expires_at > ?1",
210 rusqlite::params![now],
211 |row| row.get::<_, i64>(0),
212 )
213 .unwrap_or(0) as u64
214 } else {
215 0
216 }
217 }
218 }
219 }
220}
221
222fn build_cache_store(
225 backend_config: &CacheBackendConfig,
226 ttl: Duration,
227 max_entries: u64,
228) -> CacheStore {
229 match backend_config.backend.as_str() {
230 #[cfg(feature = "redis-cache")]
231 "redis" => {
232 let url = backend_config.url.as_deref().unwrap_or("redis://127.0.0.1");
233 let client =
234 redis::Client::open(url).expect("invalid Redis URL in cache configuration");
235 CacheStore::Redis {
236 client,
237 prefix: backend_config.prefix.clone(),
238 ttl,
239 }
240 }
241 #[cfg(feature = "sqlite-cache")]
242 "sqlite" => {
243 let path = backend_config.url.as_deref().unwrap_or("cache.db");
244 let conn =
245 rusqlite::Connection::open(path).expect("failed to open SQLite cache database");
246 conn.execute_batch(
247 "CREATE TABLE IF NOT EXISTS cache_entries (
248 key TEXT PRIMARY KEY,
249 value TEXT NOT NULL,
250 expires_at INTEGER NOT NULL
251 )",
252 )
253 .expect("failed to create SQLite cache table");
254 CacheStore::Sqlite {
255 conn: Arc::new(std::sync::Mutex::new(conn)),
256 ttl,
257 }
258 }
259 _ => CacheStore::Memory(
261 Cache::builder()
262 .max_capacity(max_entries)
263 .time_to_live(ttl)
264 .build(),
265 ),
266 }
267}
268
/// Cache state for one backend namespace.
///
/// Holds the namespace prefix, optional resource and tool stores (absent when
/// the corresponding TTL is zero) and hit/miss counters shared by both stores.
#[derive(Clone)]
struct BackendCache {
    /// Prefix matched (via `starts_with`) against resource URIs and tool names.
    namespace: String,
    /// Store for `ReadResource` responses; `None` disables resource caching.
    resource_cache: Option<CacheStore>,
    /// Store for `CallTool` responses; `None` disables tool caching.
    tool_cache: Option<CacheStore>,
    /// Hit/miss counters for this namespace, shared with `CacheHandle`.
    stats: Arc<CacheStats>,
}
277
/// Hit/miss counters for one namespace.
///
/// Request futures update them with atomic adds, and `CacheHandle` reads and
/// resets them.
struct CacheStats {
    /// Requests answered from the cache.
    hits: AtomicU64,
    /// Requests forwarded to the inner service.
    misses: AtomicU64,
}

impl CacheStats {
    /// Fresh counters, both starting at zero.
    fn new() -> Self {
        let zeroed = || AtomicU64::new(0);
        Self {
            hits: zeroed(),
            misses: zeroed(),
        }
    }
}
292
293#[derive(Serialize, Clone)]
298#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
299pub struct CacheStatsSnapshot {
300 pub namespace: String,
302 pub hits: u64,
304 pub misses: u64,
306 pub hit_rate: f64,
308 pub entry_count: u64,
310}
311
312#[derive(Clone)]
314pub struct CacheHandle {
315 caches: Arc<Vec<BackendCache>>,
316}
317
318impl CacheHandle {
319 pub async fn stats(&self) -> Vec<CacheStatsSnapshot> {
321 let mut snapshots = Vec::with_capacity(self.caches.len());
322 for bc in self.caches.iter() {
323 let hits = bc.stats.hits.load(Ordering::Relaxed);
324 let misses = bc.stats.misses.load(Ordering::Relaxed);
325 let total = hits + misses;
326 let resource_count = match &bc.resource_cache {
327 Some(store) => store.entry_count().await,
328 None => 0,
329 };
330 let tool_count = match &bc.tool_cache {
331 Some(store) => store.entry_count().await,
332 None => 0,
333 };
334 snapshots.push(CacheStatsSnapshot {
335 namespace: bc.namespace.clone(),
336 hits,
337 misses,
338 hit_rate: if total > 0 {
339 hits as f64 / total as f64
340 } else {
341 0.0
342 },
343 entry_count: resource_count + tool_count,
344 });
345 }
346 snapshots
347 }
348
349 pub async fn clear(&self) {
351 for bc in self.caches.iter() {
352 if let Some(store) = &bc.resource_cache {
353 store.invalidate_all().await;
354 }
355 if let Some(store) = &bc.tool_cache {
356 store.invalidate_all().await;
357 }
358 bc.stats.hits.store(0, Ordering::Relaxed);
359 bc.stats.misses.store(0, Ordering::Relaxed);
360 }
361 }
362}
363
364fn build_caches(
369 configs: Vec<(String, &BackendCacheConfig)>,
370 backend_config: &CacheBackendConfig,
371) -> Arc<Vec<BackendCache>> {
372 let caches: Vec<BackendCache> = configs
373 .into_iter()
374 .map(|(namespace, cfg)| {
375 let resource_cache = if cfg.resource_ttl_seconds > 0 {
376 Some(build_cache_store(
377 backend_config,
378 Duration::from_secs(cfg.resource_ttl_seconds),
379 cfg.max_entries,
380 ))
381 } else {
382 None
383 };
384 let tool_cache = if cfg.tool_ttl_seconds > 0 {
385 Some(build_cache_store(
386 backend_config,
387 Duration::from_secs(cfg.tool_ttl_seconds),
388 cfg.max_entries,
389 ))
390 } else {
391 None
392 };
393 BackendCache {
394 namespace,
395 resource_cache,
396 tool_cache,
397 stats: Arc::new(CacheStats::new()),
398 }
399 })
400 .collect();
401 Arc::new(caches)
402}
403
404#[derive(Clone)]
435pub struct CacheLayer {
436 caches: Arc<Vec<BackendCache>>,
437}
438
439impl CacheLayer {
440 pub fn new(
446 configs: Vec<(String, &BackendCacheConfig)>,
447 backend_config: &CacheBackendConfig,
448 ) -> (Self, CacheHandle) {
449 let caches = build_caches(configs, backend_config);
450 let handle = CacheHandle {
451 caches: Arc::clone(&caches),
452 };
453 (Self { caches }, handle)
454 }
455}
456
457impl<S> Layer<S> for CacheLayer {
458 type Service = CacheService<S>;
459
460 fn layer(&self, inner: S) -> Self::Service {
461 CacheService {
462 inner,
463 caches: Arc::clone(&self.caches),
464 }
465 }
466}
467
468#[derive(Clone)]
470pub struct CacheService<S> {
471 inner: S,
472 caches: Arc<Vec<BackendCache>>,
473}
474
475impl<S> CacheService<S> {
476 pub fn new(
478 inner: S,
479 configs: Vec<(String, &BackendCacheConfig)>,
480 backend_config: &CacheBackendConfig,
481 ) -> (Self, CacheHandle) {
482 let caches = build_caches(configs, backend_config);
483 let handle = CacheHandle {
484 caches: Arc::clone(&caches),
485 };
486 (Self { inner, caches }, handle)
487 }
488}
489
490fn resolve_cache<'a>(
492 caches: &'a [BackendCache],
493 req: &McpRequest,
494) -> Option<(&'a CacheStore, String, &'a Arc<CacheStats>)> {
495 match req {
496 McpRequest::ReadResource(params) => {
497 let key = format!("res:{}", params.uri);
498 for bc in caches {
499 if params.uri.starts_with(&bc.namespace) {
500 return bc.resource_cache.as_ref().map(|c| (c, key, &bc.stats));
501 }
502 }
503 None
504 }
505 McpRequest::CallTool(params) => {
506 let args = serde_json::to_string(¶ms.arguments).unwrap_or_default();
507 let key = format!("tool:{}:{}", params.name, args);
508 for bc in caches {
509 if params.name.starts_with(&bc.namespace) {
510 return bc.tool_cache.as_ref().map(|c| (c, key, &bc.stats));
511 }
512 }
513 None
514 }
515 _ => None,
516 }
517}
518
519impl<S> Service<RouterRequest> for CacheService<S>
520where
521 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
522 + Clone
523 + Send
524 + 'static,
525 S::Future: Send,
526{
527 type Response = RouterResponse;
528 type Error = Infallible;
529 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
530
531 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
532 self.inner.poll_ready(cx)
533 }
534
535 fn call(&mut self, req: RouterRequest) -> Self::Future {
536 let caches = Arc::clone(&self.caches);
537
538 if let Some((store, key, stats)) = resolve_cache(&caches, &req.inner) {
539 let store = store.clone();
540 let stats = Arc::clone(stats);
541 let mut inner = self.inner.clone();
542
543 return Box::pin(async move {
544 if let Some(cached) = store.get(&key).await {
546 stats.hits.fetch_add(1, Ordering::Relaxed);
547 return Ok(RouterResponse {
548 id: req.id,
549 inner: cached.inner,
550 });
551 }
552
553 stats.misses.fetch_add(1, Ordering::Relaxed);
554 let result = inner.call(req).await;
555
556 let Ok(ref resp) = result;
558 if resp.inner.is_ok() {
559 store.insert(key, resp.clone()).await;
560 }
561
562 result
563 });
564 }
565
566 let fut = self.inner.call(req);
568 Box::pin(fut)
569 }
570}
571
#[cfg(test)]
mod tests {
    use tower_mcp::protocol::{McpRequest, McpResponse};

    use super::CacheService;
    use crate::config::{BackendCacheConfig, CacheBackendConfig};
    use crate::test_util::{MockService, call_service};

    /// Build a `CallTool` request for `name` with fixed arguments, so repeated
    /// calls map to the same cache key.
    fn tool_call(name: &str) -> McpRequest {
        McpRequest::CallTool(tower_mcp::protocol::CallToolParams {
            name: name.to_string(),
            arguments: serde_json::json!({"key": "value"}),
            meta: None,
            task: None,
        })
    }

    /// Default backend configuration.
    ///
    /// `test_cache_store_memory_get_insert` asserts that this selects the
    /// in-memory store.
    fn default_backend_config() -> CacheBackendConfig {
        CacheBackendConfig::default()
    }

    #[tokio::test]
    async fn test_cache_hit_returns_same_result() {
        let mock = MockService::with_tools(&["fs/read"]);
        let cfg = BackendCacheConfig {
            resource_ttl_seconds: 60,
            tool_ttl_seconds: 60,
            max_entries: 100,
        };
        let (mut svc, _handle) = CacheService::new(
            mock,
            vec![("fs/".to_string(), &cfg)],
            &default_backend_config(),
        );

        let resp1 = call_service(&mut svc, tool_call("fs/read")).await;
        let resp2 = call_service(&mut svc, tool_call("fs/read")).await;

        match (resp1.inner.unwrap(), resp2.inner.unwrap()) {
            // The second (cached) response must carry the same tool output.
            (McpResponse::CallTool(r1), McpResponse::CallTool(r2)) => {
                assert_eq!(r1.all_text(), r2.all_text());
            }
            _ => panic!("expected CallTool responses"),
        }
    }

    #[tokio::test]
    async fn test_cache_disabled_passes_through() {
        let mock = MockService::with_tools(&["fs/read"]);
        // Zero TTLs mean no stores are built for the namespace.
        let cfg = BackendCacheConfig {
            resource_ttl_seconds: 0,
            tool_ttl_seconds: 0,
            max_entries: 100,
        };
        let (mut svc, _handle) = CacheService::new(
            mock,
            vec![("fs/".to_string(), &cfg)],
            &default_backend_config(),
        );

        let resp = call_service(&mut svc, tool_call("fs/read")).await;
        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_cache_non_matching_namespace_passes_through() {
        let mock = MockService::with_tools(&["db/query"]);
        let cfg = BackendCacheConfig {
            resource_ttl_seconds: 60,
            tool_ttl_seconds: 60,
            max_entries: 100,
        };
        let (mut svc, _handle) = CacheService::new(
            mock,
            vec![("fs/".to_string(), &cfg)],
            &default_backend_config(),
        );

        // "db/query" does not start with "fs/", so no cache applies.
        let resp = call_service(&mut svc, tool_call("db/query")).await;
        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_cache_list_tools_not_cached() {
        let mock = MockService::with_tools(&["fs/read"]);
        let cfg = BackendCacheConfig {
            resource_ttl_seconds: 60,
            tool_ttl_seconds: 60,
            max_entries: 100,
        };
        let (mut svc, _handle) = CacheService::new(
            mock,
            vec![("fs/".to_string(), &cfg)],
            &default_backend_config(),
        );

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        assert!(resp.inner.is_ok(), "list_tools should pass through");
    }

    #[tokio::test]
    async fn test_cache_stats_tracks_hits_and_misses() {
        let mock = MockService::with_tools(&["fs/read"]);
        let cfg = BackendCacheConfig {
            resource_ttl_seconds: 60,
            tool_ttl_seconds: 60,
            max_entries: 100,
        };
        let (mut svc, handle) = CacheService::new(
            mock,
            vec![("fs/".to_string(), &cfg)],
            &default_backend_config(),
        );

        let _ = call_service(&mut svc, tool_call("fs/read")).await;
        // First call: cache is empty, so it counts as a miss.
        let stats = handle.stats().await;
        assert_eq!(stats.len(), 1);
        assert_eq!(stats[0].hits, 0);
        assert_eq!(stats[0].misses, 1);

        let _ = call_service(&mut svc, tool_call("fs/read")).await;
        // Second identical call: served from the cache.
        let stats = handle.stats().await;
        assert_eq!(stats[0].hits, 1);
        assert_eq!(stats[0].misses, 1);
        assert!((stats[0].hit_rate - 0.5).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn test_cache_clear_resets_stats() {
        let mock = MockService::with_tools(&["fs/read"]);
        let cfg = BackendCacheConfig {
            resource_ttl_seconds: 60,
            tool_ttl_seconds: 60,
            max_entries: 100,
        };
        let (mut svc, handle) = CacheService::new(
            mock,
            vec![("fs/".to_string(), &cfg)],
            &default_backend_config(),
        );

        let _ = call_service(&mut svc, tool_call("fs/read")).await;
        let _ = call_service(&mut svc, tool_call("fs/read")).await;

        handle.clear().await;
        let stats = handle.stats().await;
        assert_eq!(stats[0].hits, 0);
        assert_eq!(stats[0].misses, 0);
    }

    #[tokio::test]
    async fn test_cache_layer_produces_working_service() {
        use super::CacheLayer;
        use tower::Layer;

        let cfg = BackendCacheConfig {
            resource_ttl_seconds: 60,
            tool_ttl_seconds: 60,
            max_entries: 100,
        };
        let (layer, handle) =
            CacheLayer::new(vec![("fs/".to_string(), &cfg)], &default_backend_config());

        let mock = MockService::with_tools(&["fs/read"]);
        let mut svc = layer.layer(mock);

        let _ = call_service(&mut svc, tool_call("fs/read")).await;
        // First call through the layered service is a miss.
        let stats = handle.stats().await;
        assert_eq!(stats[0].misses, 1);
        assert_eq!(stats[0].hits, 0);

        let _ = call_service(&mut svc, tool_call("fs/read")).await;
        // Repeat is a hit.
        let stats = handle.stats().await;
        assert_eq!(stats[0].hits, 1);
        assert_eq!(stats[0].misses, 1);
    }

    #[tokio::test]
    async fn test_cache_layer_shares_state_across_services() {
        use super::CacheLayer;
        use tower::Layer;

        let cfg = BackendCacheConfig {
            resource_ttl_seconds: 60,
            tool_ttl_seconds: 60,
            max_entries: 100,
        };
        let (layer, handle) =
            CacheLayer::new(vec![("fs/".to_string(), &cfg)], &default_backend_config());

        let mock1 = MockService::with_tools(&["fs/read"]);
        // Two services from the same layer share one set of caches.
        let mut svc1 = layer.layer(mock1);

        let mock2 = MockService::with_tools(&["fs/read"]);
        let mut svc2 = layer.layer(mock2);

        let _ = call_service(&mut svc1, tool_call("fs/read")).await;
        // svc1 populates the shared cache (miss).
        assert_eq!(handle.stats().await[0].misses, 1);

        let _ = call_service(&mut svc2, tool_call("fs/read")).await;
        // svc2 sees svc1's entry (hit).
        assert_eq!(handle.stats().await[0].hits, 1);
        assert_eq!(handle.stats().await[0].misses, 1);
    }

    #[tokio::test]
    async fn test_cache_layer_handle_clear() {
        use super::CacheLayer;
        use tower::Layer;

        let cfg = BackendCacheConfig {
            resource_ttl_seconds: 60,
            tool_ttl_seconds: 60,
            max_entries: 100,
        };
        let (layer, handle) =
            CacheLayer::new(vec![("fs/".to_string(), &cfg)], &default_backend_config());

        let mock = MockService::with_tools(&["fs/read"]);
        let mut svc = layer.layer(mock);

        let _ = call_service(&mut svc, tool_call("fs/read")).await;
        let _ = call_service(&mut svc, tool_call("fs/read")).await;
        assert_eq!(handle.stats().await[0].hits, 1);

        handle.clear().await;
        let stats = handle.stats().await;
        assert_eq!(stats[0].hits, 0);
        assert_eq!(stats[0].misses, 0);
    }

    #[tokio::test]
    async fn test_cache_store_memory_get_insert() {
        use super::{CacheStore, build_cache_store};

        let store = build_cache_store(&default_backend_config(), Duration::from_secs(60), 100);
        assert!(matches!(store, CacheStore::Memory(_)));

        assert!(store.get("key1").await.is_none());
        // Nothing has been inserted, so the store reports no entries.
        assert_eq!(store.entry_count().await, 0);
    }

    #[cfg(feature = "redis-cache")]
    #[test]
    fn test_cache_store_redis_construction() {
        use super::build_cache_store;

        // Construction only; no Redis server is contacted here.
        let cfg = CacheBackendConfig {
            backend: "redis".to_string(),
            url: Some("redis://127.0.0.1:6379".to_string()),
            prefix: "test:".to_string(),
        };
        let store = build_cache_store(&cfg, Duration::from_secs(60), 100);
        assert!(matches!(store, super::CacheStore::Redis { .. }));
    }

    #[cfg(feature = "sqlite-cache")]
    #[tokio::test]
    async fn test_cache_store_sqlite_construction() {
        use super::build_cache_store;

        // Per-process temp directory so parallel test runs do not collide.
        let dir = std::env::temp_dir().join(format!("mcp-proxy-test-{}", std::process::id()));
        std::fs::create_dir_all(&dir).unwrap();
        let db_path = dir.join("test_cache.db");

        let cfg = CacheBackendConfig {
            backend: "sqlite".to_string(),
            url: Some(db_path.to_string_lossy().to_string()),
            prefix: "test:".to_string(),
        };
        let store = build_cache_store(&cfg, Duration::from_secs(60), 100);
        assert!(matches!(store, super::CacheStore::Sqlite { .. }));
        assert_eq!(store.entry_count().await, 0);

        let _ = std::fs::remove_dir_all(&dir);
        // Best-effort cleanup; a leftover temp dir is harmless.
    }

    use std::time::Duration;
}